From ecae650a2873cb94b78198ac10bf216ac523d179 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Thu, 23 Oct 2025 12:30:55 +0200 Subject: [PATCH] refactor: unify structured and str completion --- .../graph_completion_cot_retriever.py | 99 +++++++++---------- cognee/modules/retrieval/utils/completion.py | 29 +++++- 2 files changed, 71 insertions(+), 57 deletions(-) diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index d785a1494..299db6855 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -6,7 +6,10 @@ from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge from cognee.shared.logging_utils import get_logger from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever -from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text +from cognee.modules.retrieval.utils.completion import ( + generate_structured_completion, + summarize_text, +) from cognee.modules.retrieval.utils.session_cache import ( save_conversation_history, get_conversation_history, @@ -82,12 +85,20 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): self, query: str, context: Optional[List[Edge]] = None, - session_id: Optional[str] = None, + conversation_history: str = "", max_iter: int = 4, response_model: Type = str, ) -> tuple[Any, str, List[Edge]]: """ - Run chain-of-thought completion with optional structured output and session caching. + Run chain-of-thought completion with optional structured output. + + Parameters: + ----------- + - query: User query + - context: Optional pre-fetched context edges + - conversation_history: Optional conversation history string + - max_iter: Maximum CoT iterations + - response_model: Type for structured output (str for plain text) Returns: -------- @@ -99,16 +110,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): triplets = [] completion = "" - # Retrieve conversation history if session saving is enabled - cache_config = CacheConfig() - user = session_user.get() - user_id = getattr(user, "id", None) - session_save = user_id and cache_config.caching - - conversation_history = "" - if session_save: - conversation_history = await get_conversation_history(session_id=session_id) - for round_idx in range(max_iter + 1): if round_idx == 0: if context is None: @@ -120,29 +121,15 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): triplets += await self.get_context(followup_question) context_text = await self.resolve_edges_to_text(list(set(triplets))) - if response_model is str: - completion = await generate_completion( - query=query, - context=context_text, - user_prompt_path=self.user_prompt_path, - system_prompt_path=self.system_prompt_path, - system_prompt=self.system_prompt, - conversation_history=conversation_history if session_save else None, - ) - else: - args = {"question": query, "context": context_text} - user_prompt = render_prompt(self.user_prompt_path, args) - system_prompt = ( - self.system_prompt - if self.system_prompt - else read_query_prompt(self.system_prompt_path) - ) - - completion = await LLMGateway.acreate_structured_output( - text_input=user_prompt, - system_prompt=system_prompt, - response_model=response_model, - ) + completion = await generate_structured_completion( + query=query, + context=context_text, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, + conversation_history=conversation_history if conversation_history else None, + response_model=response_model, + ) logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}") @@ -176,16 +163,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}" ) - # Save to session cache - if session_save: - context_summary = await summarize_text(context_text) - await save_conversation_history( - query=query, - context_summary=context_summary, - answer=str(completion), - session_id=session_id, - ) - return completion, context_text, triplets async def get_structured_completion( @@ -217,10 +194,21 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): -------- - Any: The generated structured completion based on the response model. """ + # Check if session saving is enabled + cache_config = CacheConfig() + user = session_user.get() + user_id = getattr(user, "id", None) + session_save = user_id and cache_config.caching + + # Load conversation history if enabled + conversation_history = "" + if session_save: + conversation_history = await get_conversation_history(session_id=session_id) + completion, context_text, triplets = await self._run_cot_completion( query=query, context=context, - session_id=session_id, + conversation_history=conversation_history, max_iter=max_iter, response_model=response_model, ) @@ -230,6 +218,16 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): question=query, answer=str(completion), context=context_text, triplets=triplets ) + # Save to session cache if enabled + if session_save: + context_summary = await summarize_text(context_text) + await save_conversation_history( + query=query, + context_summary=context_summary, + answer=str(completion), + session_id=session_id, + ) + return completion async def get_completion( @@ -263,7 +261,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): - List[str]: A list containing the generated answer to the user's query. """ - completion, context_text, triplets = await self._run_cot_completion( + completion = await self.get_structured_completion( query=query, context=context, session_id=session_id, @@ -271,9 +269,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): response_model=str, ) - if self.save_interaction and context and triplets and completion: - await self.save_qa( - question=query, answer=completion, context=context_text, triplets=triplets - ) - return [completion] diff --git a/cognee/modules/retrieval/utils/completion.py b/cognee/modules/retrieval/utils/completion.py index 6b6b6190e..db7a10252 100644 --- a/cognee/modules/retrieval/utils/completion.py +++ b/cognee/modules/retrieval/utils/completion.py @@ -1,17 +1,18 @@ -from typing import Optional +from typing import Optional, Type, Any from cognee.infrastructure.llm.LLMGateway import LLMGateway from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt -async def generate_completion( +async def generate_structured_completion( query: str, context: str, user_prompt_path: str, system_prompt_path: str, system_prompt: Optional[str] = None, conversation_history: Optional[str] = None, -) -> str: - """Generates a completion using LLM with given context and prompts.""" + response_model: Type = str, +) -> Any: + """Generates a structured completion using LLM with given context and prompts.""" args = {"question": query, "context": context} user_prompt = render_prompt(user_prompt_path, args) system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path) @@ -23,6 +24,26 @@ async def generate_completion( return await LLMGateway.acreate_structured_output( text_input=user_prompt, system_prompt=system_prompt, + response_model=response_model, + ) + + +async def generate_completion( + query: str, + context: str, + user_prompt_path: str, + system_prompt_path: str, + system_prompt: Optional[str] = None, + conversation_history: Optional[str] = None, +) -> str: + """Generates a completion using LLM with given context and prompts.""" + return await generate_structured_completion( + query=query, + context=context, + user_prompt_path=user_prompt_path, + system_prompt_path=system_prompt_path, + system_prompt=system_prompt, + conversation_history=conversation_history, response_model=str, )