From f4d038b385a47d28df427aaf163d674dcc193e74 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Thu, 23 Oct 2025 11:31:11 +0200 Subject: [PATCH] chore: pre-align cot retriever with dev --- .../graph_completion_cot_retriever.py | 234 +++++++----------- 1 file changed, 83 insertions(+), 151 deletions(-) diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 55cbcfce5..3f6ca81be 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -1,31 +1,22 @@ -import json +import asyncio from typing import Optional, List, Type, Any -from pydantic import BaseModel from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge from cognee.shared.logging_utils import get_logger from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever -from cognee.modules.retrieval.utils.completion import generate_completion +from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text +from cognee.modules.retrieval.utils.session_cache import ( + save_conversation_history, + get_conversation_history, +) from cognee.infrastructure.llm.LLMGateway import LLMGateway from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt +from cognee.context_global_variables import session_user +from cognee.infrastructure.databases.cache.config import CacheConfig logger = get_logger() -def _as_answer_text(completion: Any) -> str: - """Convert completion to human-readable text for validation and follow-up prompts.""" - if isinstance(completion, str): - return completion - if isinstance(completion, BaseModel): - # Add notice that this is a structured response - json_str = completion.model_dump_json(indent=2) - return f"[Structured Response]\n{json_str}" - try: - return json.dumps(completion, indent=2) - except TypeError: - return str(completion) - - class GraphCompletionCotRetriever(GraphCompletionRetriever): """ Handles graph completion by generating responses based on a series of interactions with @@ -70,138 +61,11 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): self.followup_system_prompt_path = followup_system_prompt_path self.followup_user_prompt_path = followup_user_prompt_path - async def _run_cot_completion( - self, - query: str, - context: Optional[List[Edge]] = None, - max_iter: int = 4, - response_model: Type = str, - ) -> tuple[Any, str, List[Edge]]: - """ - Run chain-of-thought completion with optional structured output. - - Returns: - -------- - - completion_result: The generated completion (string or structured model) - - context_text: The resolved context text - - triplets: The list of triplets used - """ - followup_question = "" - triplets = [] - completion = "" - - for round_idx in range(max_iter + 1): - if round_idx == 0: - if context is None: - triplets = await self.get_context(query) - context_text = await self.resolve_edges_to_text(triplets) - else: - context_text = await self.resolve_edges_to_text(context) - else: - triplets += await self.get_context(followup_question) - context_text = await self.resolve_edges_to_text(list(set(triplets))) - - if response_model is str: - completion = await generate_completion( - query=query, - context=context_text, - user_prompt_path=self.user_prompt_path, - system_prompt_path=self.system_prompt_path, - system_prompt=self.system_prompt, - ) - else: - args = {"question": query, "context": context_text} - user_prompt = render_prompt(self.user_prompt_path, args) - system_prompt = ( - self.system_prompt - if self.system_prompt - else read_query_prompt(self.system_prompt_path) - ) - - completion = await LLMGateway.acreate_structured_output( - text_input=user_prompt, - system_prompt=system_prompt, - response_model=response_model, - ) - - logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}") - - if round_idx < max_iter: - answer_text = _as_answer_text(completion) - valid_args = {"query": query, "answer": answer_text, "context": context_text} - valid_user_prompt = render_prompt( - filename=self.validation_user_prompt_path, context=valid_args - ) - valid_system_prompt = read_query_prompt( - prompt_file_name=self.validation_system_prompt_path - ) - - reasoning = await LLMGateway.acreate_structured_output( - text_input=valid_user_prompt, - system_prompt=valid_system_prompt, - response_model=str, - ) - followup_args = {"query": query, "answer": answer_text, "reasoning": reasoning} - followup_prompt = render_prompt( - filename=self.followup_user_prompt_path, context=followup_args - ) - followup_system = read_query_prompt( - prompt_file_name=self.followup_system_prompt_path - ) - - followup_question = await LLMGateway.acreate_structured_output( - text_input=followup_prompt, system_prompt=followup_system, response_model=str - ) - logger.info( - f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}" - ) - - return completion, context_text, triplets - - async def get_structured_completion( - self, - query: str, - context: Optional[List[Edge]] = None, - max_iter: int = 4, - response_model: Type = str, - ) -> Any: - """ - Generate structured completion responses based on a user query and contextual information. - - This method applies the same chain-of-thought logic as get_completion but returns - structured output using the provided response model. - - Parameters: - ----------- - - query (str): The user's query to be processed and answered. - - context (Optional[List[Edge]]): Optional context that may assist in answering the query. - If not provided, it will be fetched based on the query. (default None) - - max_iter: The maximum number of iterations to refine the answer and generate - follow-up questions. (default 4) - - response_model (Type): The Pydantic model type for structured output. (default str) - - Returns: - -------- - - Any: The generated structured completion based on the response model. - """ - completion, context_text, triplets = await self._run_cot_completion( - query=query, - context=context, - max_iter=max_iter, - response_model=response_model, - ) - - if self.save_interaction and context and triplets and completion: - await self.save_qa( - question=query, answer=str(completion), context=context_text, triplets=triplets - ) - - return completion - async def get_completion( self, query: str, context: Optional[List[Edge]] = None, + session_id: Optional[str] = None, max_iter=4, ) -> List[str]: """ @@ -218,6 +82,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): - query (str): The user's query to be processed and answered. - context (Optional[Any]): Optional context that may assist in answering the query. If not provided, it will be fetched based on the query. (default None) + - session_id (Optional[str]): Optional session identifier for caching. If None, + defaults to 'default_session'. (default None) - max_iter: The maximum number of iterations to refine the answer and generate follow-up questions. (default 4) @@ -226,16 +92,82 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): - List[str]: A list containing the generated answer to the user's query. """ - completion, context_text, triplets = await self._run_cot_completion( - query=query, - context=context, - max_iter=max_iter, - response_model=str, - ) + followup_question = "" + triplets = [] + completion = "" + + # Retrieve conversation history if session saving is enabled + cache_config = CacheConfig() + user = session_user.get() + user_id = getattr(user, "id", None) + session_save = user_id and cache_config.caching + + conversation_history = "" + if session_save: + conversation_history = await get_conversation_history(session_id=session_id) + + for round_idx in range(max_iter + 1): + if round_idx == 0: + if context is None: + triplets = await self.get_context(query) + context_text = await self.resolve_edges_to_text(triplets) + else: + context_text = await self.resolve_edges_to_text(context) + else: + triplets += await self.get_context(followup_question) + context_text = await self.resolve_edges_to_text(list(set(triplets))) + + completion = await generate_completion( + query=query, + context=context_text, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, + conversation_history=conversation_history if session_save else None, + ) + logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}") + if round_idx < max_iter: + valid_args = {"query": query, "answer": completion, "context": context_text} + valid_user_prompt = render_prompt( + filename=self.validation_user_prompt_path, context=valid_args + ) + valid_system_prompt = read_query_prompt( + prompt_file_name=self.validation_system_prompt_path + ) + + reasoning = await LLMGateway.acreate_structured_output( + text_input=valid_user_prompt, + system_prompt=valid_system_prompt, + response_model=str, + ) + followup_args = {"query": query, "answer": completion, "reasoning": reasoning} + followup_prompt = render_prompt( + filename=self.followup_user_prompt_path, context=followup_args + ) + followup_system = read_query_prompt( + prompt_file_name=self.followup_system_prompt_path + ) + + followup_question = await LLMGateway.acreate_structured_output( + text_input=followup_prompt, system_prompt=followup_system, response_model=str + ) + logger.info( + f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}" + ) if self.save_interaction and context and triplets and completion: await self.save_qa( question=query, answer=completion, context=context_text, triplets=triplets ) + # Save to session cache + if session_save: + context_summary = await summarize_text(context_text) + await save_conversation_history( + query=query, + context_summary=context_summary, + answer=completion, + session_id=session_id, + ) + return [completion]