chore: restore the feedback enrichment cot retriever functionality
This commit is contained in:
parent
46e6d87c1f
commit
66a8242cec
1 changed files with 178 additions and 72 deletions
|
|
@ -1,5 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
from typing import Optional, List, Type, Any
|
from typing import Optional, List, Type, Any
|
||||||
|
from pydantic import BaseModel
|
||||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
||||||
|
|
@ -17,6 +19,20 @@ from cognee.infrastructure.databases.cache.config import CacheConfig
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _as_answer_text(completion: Any) -> str:
|
||||||
|
"""Convert completion to human-readable text for validation and follow-up prompts."""
|
||||||
|
if isinstance(completion, str):
|
||||||
|
return completion
|
||||||
|
if isinstance(completion, BaseModel):
|
||||||
|
# Add notice that this is a structured response
|
||||||
|
json_str = completion.model_dump_json(indent=2)
|
||||||
|
return f"[Structured Response]\n{json_str}"
|
||||||
|
try:
|
||||||
|
return json.dumps(completion, indent=2)
|
||||||
|
except TypeError:
|
||||||
|
return str(completion)
|
||||||
|
|
||||||
|
|
||||||
class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
"""
|
"""
|
||||||
Handles graph completion by generating responses based on a series of interactions with
|
Handles graph completion by generating responses based on a series of interactions with
|
||||||
|
|
@ -25,6 +41,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
questions based on reasoning. The public methods are:
|
questions based on reasoning. The public methods are:
|
||||||
|
|
||||||
- get_completion
|
- get_completion
|
||||||
|
- get_structured_completion
|
||||||
|
|
||||||
Instance variables include:
|
Instance variables include:
|
||||||
- validation_system_prompt_path
|
- validation_system_prompt_path
|
||||||
|
|
@ -61,6 +78,160 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
self.followup_system_prompt_path = followup_system_prompt_path
|
self.followup_system_prompt_path = followup_system_prompt_path
|
||||||
self.followup_user_prompt_path = followup_user_prompt_path
|
self.followup_user_prompt_path = followup_user_prompt_path
|
||||||
|
|
||||||
|
async def _run_cot_completion(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
context: Optional[List[Edge]] = None,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
max_iter: int = 4,
|
||||||
|
response_model: Type = str,
|
||||||
|
) -> tuple[Any, str, List[Edge]]:
|
||||||
|
"""
|
||||||
|
Run chain-of-thought completion with optional structured output and session caching.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
- completion_result: The generated completion (string or structured model)
|
||||||
|
- context_text: The resolved context text
|
||||||
|
- triplets: The list of triplets used
|
||||||
|
"""
|
||||||
|
followup_question = ""
|
||||||
|
triplets = []
|
||||||
|
completion = ""
|
||||||
|
|
||||||
|
# Retrieve conversation history if session saving is enabled
|
||||||
|
cache_config = CacheConfig()
|
||||||
|
user = session_user.get()
|
||||||
|
user_id = getattr(user, "id", None)
|
||||||
|
session_save = user_id and cache_config.caching
|
||||||
|
|
||||||
|
conversation_history = ""
|
||||||
|
if session_save:
|
||||||
|
conversation_history = await get_conversation_history(session_id=session_id)
|
||||||
|
|
||||||
|
for round_idx in range(max_iter + 1):
|
||||||
|
if round_idx == 0:
|
||||||
|
if context is None:
|
||||||
|
triplets = await self.get_context(query)
|
||||||
|
context_text = await self.resolve_edges_to_text(triplets)
|
||||||
|
else:
|
||||||
|
context_text = await self.resolve_edges_to_text(context)
|
||||||
|
else:
|
||||||
|
triplets += await self.get_context(followup_question)
|
||||||
|
context_text = await self.resolve_edges_to_text(list(set(triplets)))
|
||||||
|
|
||||||
|
if response_model is str:
|
||||||
|
completion = await generate_completion(
|
||||||
|
query=query,
|
||||||
|
context=context_text,
|
||||||
|
user_prompt_path=self.user_prompt_path,
|
||||||
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
system_prompt=self.system_prompt,
|
||||||
|
conversation_history=conversation_history if session_save else None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
args = {"question": query, "context": context_text}
|
||||||
|
user_prompt = render_prompt(self.user_prompt_path, args)
|
||||||
|
system_prompt = (
|
||||||
|
self.system_prompt
|
||||||
|
if self.system_prompt
|
||||||
|
else read_query_prompt(self.system_prompt_path)
|
||||||
|
)
|
||||||
|
|
||||||
|
completion = await LLMGateway.acreate_structured_output(
|
||||||
|
text_input=user_prompt,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
response_model=response_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
|
||||||
|
|
||||||
|
if round_idx < max_iter:
|
||||||
|
answer_text = _as_answer_text(completion)
|
||||||
|
valid_args = {"query": query, "answer": answer_text, "context": context_text}
|
||||||
|
valid_user_prompt = render_prompt(
|
||||||
|
filename=self.validation_user_prompt_path, context=valid_args
|
||||||
|
)
|
||||||
|
valid_system_prompt = read_query_prompt(
|
||||||
|
prompt_file_name=self.validation_system_prompt_path
|
||||||
|
)
|
||||||
|
|
||||||
|
reasoning = await LLMGateway.acreate_structured_output(
|
||||||
|
text_input=valid_user_prompt,
|
||||||
|
system_prompt=valid_system_prompt,
|
||||||
|
response_model=str,
|
||||||
|
)
|
||||||
|
followup_args = {"query": query, "answer": answer_text, "reasoning": reasoning}
|
||||||
|
followup_prompt = render_prompt(
|
||||||
|
filename=self.followup_user_prompt_path, context=followup_args
|
||||||
|
)
|
||||||
|
followup_system = read_query_prompt(
|
||||||
|
prompt_file_name=self.followup_system_prompt_path
|
||||||
|
)
|
||||||
|
|
||||||
|
followup_question = await LLMGateway.acreate_structured_output(
|
||||||
|
text_input=followup_prompt, system_prompt=followup_system, response_model=str
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save to session cache
|
||||||
|
if session_save:
|
||||||
|
context_summary = await summarize_text(context_text)
|
||||||
|
await save_conversation_history(
|
||||||
|
query=query,
|
||||||
|
context_summary=context_summary,
|
||||||
|
answer=str(completion),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return completion, context_text, triplets
|
||||||
|
|
||||||
|
async def get_structured_completion(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
context: Optional[List[Edge]] = None,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
max_iter: int = 4,
|
||||||
|
response_model: Type = str,
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Generate structured completion responses based on a user query and contextual information.
|
||||||
|
|
||||||
|
This method applies the same chain-of-thought logic as get_completion but returns
|
||||||
|
structured output using the provided response model.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
- query (str): The user's query to be processed and answered.
|
||||||
|
- context (Optional[List[Edge]]): Optional context that may assist in answering the query.
|
||||||
|
If not provided, it will be fetched based on the query. (default None)
|
||||||
|
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||||
|
defaults to 'default_session'. (default None)
|
||||||
|
- max_iter: The maximum number of iterations to refine the answer and generate
|
||||||
|
follow-up questions. (default 4)
|
||||||
|
- response_model (Type): The Pydantic model type for structured output. (default str)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
- Any: The generated structured completion based on the response model.
|
||||||
|
"""
|
||||||
|
completion, context_text, triplets = await self._run_cot_completion(
|
||||||
|
query=query,
|
||||||
|
context=context,
|
||||||
|
session_id=session_id,
|
||||||
|
max_iter=max_iter,
|
||||||
|
response_model=response_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.save_interaction and context and triplets and completion:
|
||||||
|
await self.save_qa(
|
||||||
|
question=query, answer=str(completion), context=context_text, triplets=triplets
|
||||||
|
)
|
||||||
|
|
||||||
|
return completion
|
||||||
|
|
||||||
async def get_completion(
|
async def get_completion(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
|
@ -92,82 +263,17 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
|
|
||||||
- List[str]: A list containing the generated answer to the user's query.
|
- List[str]: A list containing the generated answer to the user's query.
|
||||||
"""
|
"""
|
||||||
followup_question = ""
|
completion, context_text, triplets = await self._run_cot_completion(
|
||||||
triplets = []
|
query=query,
|
||||||
completion = ""
|
context=context,
|
||||||
|
session_id=session_id,
|
||||||
# Retrieve conversation history if session saving is enabled
|
max_iter=max_iter,
|
||||||
cache_config = CacheConfig()
|
response_model=str,
|
||||||
user = session_user.get()
|
)
|
||||||
user_id = getattr(user, "id", None)
|
|
||||||
session_save = user_id and cache_config.caching
|
|
||||||
|
|
||||||
conversation_history = ""
|
|
||||||
if session_save:
|
|
||||||
conversation_history = await get_conversation_history(session_id=session_id)
|
|
||||||
|
|
||||||
for round_idx in range(max_iter + 1):
|
|
||||||
if round_idx == 0:
|
|
||||||
if context is None:
|
|
||||||
triplets = await self.get_context(query)
|
|
||||||
context_text = await self.resolve_edges_to_text(triplets)
|
|
||||||
else:
|
|
||||||
context_text = await self.resolve_edges_to_text(context)
|
|
||||||
else:
|
|
||||||
triplets += await self.get_context(followup_question)
|
|
||||||
context_text = await self.resolve_edges_to_text(list(set(triplets)))
|
|
||||||
|
|
||||||
completion = await generate_completion(
|
|
||||||
query=query,
|
|
||||||
context=context_text,
|
|
||||||
user_prompt_path=self.user_prompt_path,
|
|
||||||
system_prompt_path=self.system_prompt_path,
|
|
||||||
system_prompt=self.system_prompt,
|
|
||||||
conversation_history=conversation_history if session_save else None,
|
|
||||||
)
|
|
||||||
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
|
|
||||||
if round_idx < max_iter:
|
|
||||||
valid_args = {"query": query, "answer": completion, "context": context_text}
|
|
||||||
valid_user_prompt = render_prompt(
|
|
||||||
filename=self.validation_user_prompt_path, context=valid_args
|
|
||||||
)
|
|
||||||
valid_system_prompt = read_query_prompt(
|
|
||||||
prompt_file_name=self.validation_system_prompt_path
|
|
||||||
)
|
|
||||||
|
|
||||||
reasoning = await LLMGateway.acreate_structured_output(
|
|
||||||
text_input=valid_user_prompt,
|
|
||||||
system_prompt=valid_system_prompt,
|
|
||||||
response_model=str,
|
|
||||||
)
|
|
||||||
followup_args = {"query": query, "answer": completion, "reasoning": reasoning}
|
|
||||||
followup_prompt = render_prompt(
|
|
||||||
filename=self.followup_user_prompt_path, context=followup_args
|
|
||||||
)
|
|
||||||
followup_system = read_query_prompt(
|
|
||||||
prompt_file_name=self.followup_system_prompt_path
|
|
||||||
)
|
|
||||||
|
|
||||||
followup_question = await LLMGateway.acreate_structured_output(
|
|
||||||
text_input=followup_prompt, system_prompt=followup_system, response_model=str
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.save_interaction and context and triplets and completion:
|
if self.save_interaction and context and triplets and completion:
|
||||||
await self.save_qa(
|
await self.save_qa(
|
||||||
question=query, answer=completion, context=context_text, triplets=triplets
|
question=query, answer=completion, context=context_text, triplets=triplets
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save to session cache
|
|
||||||
if session_save:
|
|
||||||
context_summary = await summarize_text(context_text)
|
|
||||||
await save_conversation_history(
|
|
||||||
query=query,
|
|
||||||
context_summary=context_summary,
|
|
||||||
answer=completion,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
return [completion]
|
return [completion]
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue