chore: pre-align cot retriever with dev

lxobr 2025-10-23 11:31:11 +02:00
parent 46b19ad02c
commit f4d038b385


@@ -1,31 +1,22 @@
-import json
+import asyncio
 from typing import Optional, List, Type, Any
-from pydantic import BaseModel
 from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
 from cognee.shared.logging_utils import get_logger
 from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
-from cognee.modules.retrieval.utils.completion import generate_completion
+from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
+from cognee.modules.retrieval.utils.session_cache import (
+    save_conversation_history,
+    get_conversation_history,
+)
 from cognee.infrastructure.llm.LLMGateway import LLMGateway
 from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
+from cognee.context_global_variables import session_user
+from cognee.infrastructure.databases.cache.config import CacheConfig
 
 logger = get_logger()
 
-
-def _as_answer_text(completion: Any) -> str:
-    """Convert completion to human-readable text for validation and follow-up prompts."""
-    if isinstance(completion, str):
-        return completion
-    if isinstance(completion, BaseModel):
-        # Add notice that this is a structured response
-        json_str = completion.model_dump_json(indent=2)
-        return f"[Structured Response]\n{json_str}"
-    try:
-        return json.dumps(completion, indent=2)
-    except TypeError:
-        return str(completion)
-
-
 class GraphCompletionCotRetriever(GraphCompletionRetriever):
     """
     Handles graph completion by generating responses based on a series of interactions with
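Note on the new imports: conversation history is only consulted when a user is bound to the request context and caching is enabled. A minimal sketch of that gate, using only names visible in this diff (the surrounding scaffolding is assumed):

    from cognee.context_global_variables import session_user
    from cognee.infrastructure.databases.cache.config import CacheConfig

    cache_config = CacheConfig()   # cache settings, including the `caching` flag
    user = session_user.get()      # context-local user; may be None outside a request
    # History is read and written only when both conditions hold.
    session_save = getattr(user, "id", None) and cache_config.caching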
@@ -70,138 +61,11 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
         self.followup_system_prompt_path = followup_system_prompt_path
         self.followup_user_prompt_path = followup_user_prompt_path
 
-    async def _run_cot_completion(
-        self,
-        query: str,
-        context: Optional[List[Edge]] = None,
-        max_iter: int = 4,
-        response_model: Type = str,
-    ) -> tuple[Any, str, List[Edge]]:
-        """
-        Run chain-of-thought completion with optional structured output.
-
-        Returns:
-        --------
-            - completion_result: The generated completion (string or structured model)
-            - context_text: The resolved context text
-            - triplets: The list of triplets used
-        """
-        followup_question = ""
-        triplets = []
-        completion = ""
-
-        for round_idx in range(max_iter + 1):
-            if round_idx == 0:
-                if context is None:
-                    triplets = await self.get_context(query)
-                    context_text = await self.resolve_edges_to_text(triplets)
-                else:
-                    context_text = await self.resolve_edges_to_text(context)
-            else:
-                triplets += await self.get_context(followup_question)
-                context_text = await self.resolve_edges_to_text(list(set(triplets)))
-
-            if response_model is str:
-                completion = await generate_completion(
-                    query=query,
-                    context=context_text,
-                    user_prompt_path=self.user_prompt_path,
-                    system_prompt_path=self.system_prompt_path,
-                    system_prompt=self.system_prompt,
-                )
-            else:
-                args = {"question": query, "context": context_text}
-                user_prompt = render_prompt(self.user_prompt_path, args)
-                system_prompt = (
-                    self.system_prompt
-                    if self.system_prompt
-                    else read_query_prompt(self.system_prompt_path)
-                )
-                completion = await LLMGateway.acreate_structured_output(
-                    text_input=user_prompt,
-                    system_prompt=system_prompt,
-                    response_model=response_model,
-                )
-
-            logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
-
-            if round_idx < max_iter:
-                answer_text = _as_answer_text(completion)
-                valid_args = {"query": query, "answer": answer_text, "context": context_text}
-                valid_user_prompt = render_prompt(
-                    filename=self.validation_user_prompt_path, context=valid_args
-                )
-                valid_system_prompt = read_query_prompt(
-                    prompt_file_name=self.validation_system_prompt_path
-                )
-                reasoning = await LLMGateway.acreate_structured_output(
-                    text_input=valid_user_prompt,
-                    system_prompt=valid_system_prompt,
-                    response_model=str,
-                )
-                followup_args = {"query": query, "answer": answer_text, "reasoning": reasoning}
-                followup_prompt = render_prompt(
-                    filename=self.followup_user_prompt_path, context=followup_args
-                )
-                followup_system = read_query_prompt(
-                    prompt_file_name=self.followup_system_prompt_path
-                )
-                followup_question = await LLMGateway.acreate_structured_output(
-                    text_input=followup_prompt, system_prompt=followup_system, response_model=str
-                )
-                logger.info(
-                    f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
-                )
-
-        return completion, context_text, triplets
-
-    async def get_structured_completion(
-        self,
-        query: str,
-        context: Optional[List[Edge]] = None,
-        max_iter: int = 4,
-        response_model: Type = str,
-    ) -> Any:
-        """
-        Generate structured completion responses based on a user query and contextual information.
-
-        This method applies the same chain-of-thought logic as get_completion but returns
-        structured output using the provided response model.
-
-        Parameters:
-        -----------
-            - query (str): The user's query to be processed and answered.
-            - context (Optional[List[Edge]]): Optional context that may assist in answering
-              the query. If not provided, it will be fetched based on the query. (default None)
-            - max_iter: The maximum number of iterations to refine the answer and generate
-              follow-up questions. (default 4)
-            - response_model (Type): The Pydantic model type for structured output. (default str)
-
-        Returns:
-        --------
-            - Any: The generated structured completion based on the response model.
-        """
-        completion, context_text, triplets = await self._run_cot_completion(
-            query=query,
-            context=context,
-            max_iter=max_iter,
-            response_model=response_model,
-        )
-
-        if self.save_interaction and context and triplets and completion:
-            await self.save_qa(
-                question=query, answer=str(completion), context=context_text, triplets=triplets
-            )
-
-        return completion
-
     async def get_completion(
         self,
         query: str,
         context: Optional[List[Edge]] = None,
+        session_id: Optional[str] = None,
         max_iter=4,
     ) -> List[str]:
         """
@@ -218,6 +82,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
             - query (str): The user's query to be processed and answered.
             - context (Optional[Any]): Optional context that may assist in answering the query.
               If not provided, it will be fetched based on the query. (default None)
+            - session_id (Optional[str]): Optional session identifier for caching. If None,
+              defaults to 'default_session'. (default None)
             - max_iter: The maximum number of iterations to refine the answer and generate
               follow-up questions. (default 4)
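The 'default_session' fallback in the docstring implies the id can be omitted entirely. A usage sketch (construction of `retriever` elided and assumed):

    # Falls back to the 'default_session' cache bucket per the docstring.
    answers = await retriever.get_completion("Who maintains cognee?")

    # Keeps a named conversation, so follow-up turns share history.
    answers = await retriever.get_completion("And since when?", session_id="chat-42")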
@@ -226,16 +92,82 @@
             - List[str]: A list containing the generated answer to the user's query.
         """
-        completion, context_text, triplets = await self._run_cot_completion(
-            query=query,
-            context=context,
-            max_iter=max_iter,
-            response_model=str,
-        )
+        followup_question = ""
+        triplets = []
+        completion = ""
+
+        # Retrieve conversation history if session saving is enabled
+        cache_config = CacheConfig()
+        user = session_user.get()
+        user_id = getattr(user, "id", None)
+        session_save = user_id and cache_config.caching
+        conversation_history = ""
+        if session_save:
+            conversation_history = await get_conversation_history(session_id=session_id)
+
+        for round_idx in range(max_iter + 1):
+            if round_idx == 0:
+                if context is None:
+                    triplets = await self.get_context(query)
+                    context_text = await self.resolve_edges_to_text(triplets)
+                else:
+                    context_text = await self.resolve_edges_to_text(context)
+            else:
+                triplets += await self.get_context(followup_question)
+                context_text = await self.resolve_edges_to_text(list(set(triplets)))
+
+            completion = await generate_completion(
+                query=query,
+                context=context_text,
+                user_prompt_path=self.user_prompt_path,
+                system_prompt_path=self.system_prompt_path,
+                system_prompt=self.system_prompt,
+                conversation_history=conversation_history if session_save else None,
+            )
+
+            logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
+
+            if round_idx < max_iter:
+                valid_args = {"query": query, "answer": completion, "context": context_text}
+                valid_user_prompt = render_prompt(
+                    filename=self.validation_user_prompt_path, context=valid_args
+                )
+                valid_system_prompt = read_query_prompt(
+                    prompt_file_name=self.validation_system_prompt_path
+                )
+                reasoning = await LLMGateway.acreate_structured_output(
+                    text_input=valid_user_prompt,
+                    system_prompt=valid_system_prompt,
+                    response_model=str,
+                )
+                followup_args = {"query": query, "answer": completion, "reasoning": reasoning}
+                followup_prompt = render_prompt(
+                    filename=self.followup_user_prompt_path, context=followup_args
+                )
+                followup_system = read_query_prompt(
+                    prompt_file_name=self.followup_system_prompt_path
+                )
+                followup_question = await LLMGateway.acreate_structured_output(
+                    text_input=followup_prompt, system_prompt=followup_system, response_model=str
+                )
+                logger.info(
+                    f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
+                )
+
         if self.save_interaction and context and triplets and completion:
             await self.save_qa(
                 question=query, answer=completion, context=context_text, triplets=triplets
             )
 
+        # Save to session cache
+        if session_save:
+            context_summary = await summarize_text(context_text)
+            await save_conversation_history(
+                query=query,
+                context_summary=context_summary,
+                answer=completion,
+                session_id=session_id,
+            )
+
         return [completion]
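For reference, a hedged end-to-end sketch of the new flow. The module path is inferred from the class name and may differ; defaults are assumed for all constructor arguments:

    import asyncio

    # Assumed import path; adjust to wherever GraphCompletionCotRetriever lives.
    from cognee.modules.retrieval.graph_completion_cot_retriever import (
        GraphCompletionCotRetriever,
    )

    async def main():
        retriever = GraphCompletionCotRetriever()
        # Round 0 retrieves context for the query; up to max_iter follow-up
        # rounds widen the triplet set before the final answer is returned.
        answers = await retriever.get_completion(
            query="What does the session cache store?",
            session_id="demo-session",
            max_iter=2,
        )
        print(answers[0])

    asyncio.run(main())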