chore: restore the feedback enrichment cot retriever functionality

This commit is contained in:
lxobr 2025-10-23 12:07:31 +02:00
parent 46e6d87c1f
commit 66a8242cec

View file

@ -1,5 +1,7 @@
import asyncio
import json
from typing import Optional, List, Type, Any
from pydantic import BaseModel
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.shared.logging_utils import get_logger
@ -17,6 +19,20 @@ from cognee.infrastructure.databases.cache.config import CacheConfig
logger = get_logger()
def _as_answer_text(completion: Any) -> str:
"""Convert completion to human-readable text for validation and follow-up prompts."""
if isinstance(completion, str):
return completion
if isinstance(completion, BaseModel):
# Add notice that this is a structured response
json_str = completion.model_dump_json(indent=2)
return f"[Structured Response]\n{json_str}"
try:
return json.dumps(completion, indent=2)
except TypeError:
return str(completion)
class GraphCompletionCotRetriever(GraphCompletionRetriever):
"""
Handles graph completion by generating responses based on a series of interactions with
@ -25,6 +41,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
questions based on reasoning. The public methods are:
- get_completion
- get_structured_completion
Instance variables include:
- validation_system_prompt_path
@ -61,6 +78,160 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
self.followup_system_prompt_path = followup_system_prompt_path
self.followup_user_prompt_path = followup_user_prompt_path
async def _run_cot_completion(
self,
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
max_iter: int = 4,
response_model: Type = str,
) -> tuple[Any, str, List[Edge]]:
"""
Run chain-of-thought completion with optional structured output and session caching.
Returns:
--------
- completion_result: The generated completion (string or structured model)
- context_text: The resolved context text
- triplets: The list of triplets used
"""
followup_question = ""
triplets = []
completion = ""
# Retrieve conversation history if session saving is enabled
cache_config = CacheConfig()
user = session_user.get()
user_id = getattr(user, "id", None)
session_save = user_id and cache_config.caching
conversation_history = ""
if session_save:
conversation_history = await get_conversation_history(session_id=session_id)
for round_idx in range(max_iter + 1):
if round_idx == 0:
if context is None:
triplets = await self.get_context(query)
context_text = await self.resolve_edges_to_text(triplets)
else:
context_text = await self.resolve_edges_to_text(context)
else:
triplets += await self.get_context(followup_question)
context_text = await self.resolve_edges_to_text(list(set(triplets)))
if response_model is str:
completion = await generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history if session_save else None,
)
else:
args = {"question": query, "context": context_text}
user_prompt = render_prompt(self.user_prompt_path, args)
system_prompt = (
self.system_prompt
if self.system_prompt
else read_query_prompt(self.system_prompt_path)
)
completion = await LLMGateway.acreate_structured_output(
text_input=user_prompt,
system_prompt=system_prompt,
response_model=response_model,
)
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
if round_idx < max_iter:
answer_text = _as_answer_text(completion)
valid_args = {"query": query, "answer": answer_text, "context": context_text}
valid_user_prompt = render_prompt(
filename=self.validation_user_prompt_path, context=valid_args
)
valid_system_prompt = read_query_prompt(
prompt_file_name=self.validation_system_prompt_path
)
reasoning = await LLMGateway.acreate_structured_output(
text_input=valid_user_prompt,
system_prompt=valid_system_prompt,
response_model=str,
)
followup_args = {"query": query, "answer": answer_text, "reasoning": reasoning}
followup_prompt = render_prompt(
filename=self.followup_user_prompt_path, context=followup_args
)
followup_system = read_query_prompt(
prompt_file_name=self.followup_system_prompt_path
)
followup_question = await LLMGateway.acreate_structured_output(
text_input=followup_prompt, system_prompt=followup_system, response_model=str
)
logger.info(
f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
)
# Save to session cache
if session_save:
context_summary = await summarize_text(context_text)
await save_conversation_history(
query=query,
context_summary=context_summary,
answer=str(completion),
session_id=session_id,
)
return completion, context_text, triplets
async def get_structured_completion(
self,
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
max_iter: int = 4,
response_model: Type = str,
) -> Any:
"""
Generate structured completion responses based on a user query and contextual information.
This method applies the same chain-of-thought logic as get_completion but returns
structured output using the provided response model.
Parameters:
-----------
- query (str): The user's query to be processed and answered.
- context (Optional[List[Edge]]): Optional context that may assist in answering the query.
If not provided, it will be fetched based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
- max_iter: The maximum number of iterations to refine the answer and generate
follow-up questions. (default 4)
- response_model (Type): The Pydantic model type for structured output. (default str)
Returns:
--------
- Any: The generated structured completion based on the response model.
"""
completion, context_text, triplets = await self._run_cot_completion(
query=query,
context=context,
session_id=session_id,
max_iter=max_iter,
response_model=response_model,
)
if self.save_interaction and context and triplets and completion:
await self.save_qa(
question=query, answer=str(completion), context=context_text, triplets=triplets
)
return completion
async def get_completion(
self,
query: str,
@ -92,82 +263,17 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
- List[str]: A list containing the generated answer to the user's query.
"""
followup_question = ""
triplets = []
completion = ""
# Retrieve conversation history if session saving is enabled
cache_config = CacheConfig()
user = session_user.get()
user_id = getattr(user, "id", None)
session_save = user_id and cache_config.caching
conversation_history = ""
if session_save:
conversation_history = await get_conversation_history(session_id=session_id)
for round_idx in range(max_iter + 1):
if round_idx == 0:
if context is None:
triplets = await self.get_context(query)
context_text = await self.resolve_edges_to_text(triplets)
else:
context_text = await self.resolve_edges_to_text(context)
else:
triplets += await self.get_context(followup_question)
context_text = await self.resolve_edges_to_text(list(set(triplets)))
completion = await generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history if session_save else None,
)
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
if round_idx < max_iter:
valid_args = {"query": query, "answer": completion, "context": context_text}
valid_user_prompt = render_prompt(
filename=self.validation_user_prompt_path, context=valid_args
)
valid_system_prompt = read_query_prompt(
prompt_file_name=self.validation_system_prompt_path
)
reasoning = await LLMGateway.acreate_structured_output(
text_input=valid_user_prompt,
system_prompt=valid_system_prompt,
response_model=str,
)
followup_args = {"query": query, "answer": completion, "reasoning": reasoning}
followup_prompt = render_prompt(
filename=self.followup_user_prompt_path, context=followup_args
)
followup_system = read_query_prompt(
prompt_file_name=self.followup_system_prompt_path
)
followup_question = await LLMGateway.acreate_structured_output(
text_input=followup_prompt, system_prompt=followup_system, response_model=str
)
logger.info(
f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
)
completion, context_text, triplets = await self._run_cot_completion(
query=query,
context=context,
session_id=session_id,
max_iter=max_iter,
response_model=str,
)
if self.save_interaction and context and triplets and completion:
await self.save_qa(
question=query, answer=completion, context=context_text, triplets=triplets
)
# Save to session cache
if session_save:
context_summary = await summarize_text(context_text)
await save_conversation_history(
query=query,
context_summary=context_summary,
answer=completion,
session_id=session_id,
)
return [completion]