refactor: unify structured and str completion

This commit is contained in:
lxobr 2025-10-23 12:30:55 +02:00
parent 66a8242cec
commit ecae650a28
2 changed files with 71 additions and 57 deletions

View file

@@ -6,7 +6,10 @@ from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text from cognee.modules.retrieval.utils.completion import (
generate_structured_completion,
summarize_text,
)
from cognee.modules.retrieval.utils.session_cache import ( from cognee.modules.retrieval.utils.session_cache import (
save_conversation_history, save_conversation_history,
get_conversation_history, get_conversation_history,
@@ -82,12 +85,20 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
self, self,
query: str, query: str,
context: Optional[List[Edge]] = None, context: Optional[List[Edge]] = None,
session_id: Optional[str] = None, conversation_history: str = "",
max_iter: int = 4, max_iter: int = 4,
response_model: Type = str, response_model: Type = str,
) -> tuple[Any, str, List[Edge]]: ) -> tuple[Any, str, List[Edge]]:
""" """
Run chain-of-thought completion with optional structured output and session caching. Run chain-of-thought completion with optional structured output.
Parameters:
-----------
- query: User query
- context: Optional pre-fetched context edges
- conversation_history: Optional conversation history string
- max_iter: Maximum CoT iterations
- response_model: Type for structured output (str for plain text)
Returns: Returns:
-------- --------
@@ -99,16 +110,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
triplets = [] triplets = []
completion = "" completion = ""
# Retrieve conversation history if session saving is enabled
cache_config = CacheConfig()
user = session_user.get()
user_id = getattr(user, "id", None)
session_save = user_id and cache_config.caching
conversation_history = ""
if session_save:
conversation_history = await get_conversation_history(session_id=session_id)
for round_idx in range(max_iter + 1): for round_idx in range(max_iter + 1):
if round_idx == 0: if round_idx == 0:
if context is None: if context is None:
@@ -120,29 +121,15 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
triplets += await self.get_context(followup_question) triplets += await self.get_context(followup_question)
context_text = await self.resolve_edges_to_text(list(set(triplets))) context_text = await self.resolve_edges_to_text(list(set(triplets)))
if response_model is str: completion = await generate_structured_completion(
completion = await generate_completion( query=query,
query=query, context=context_text,
context=context_text, user_prompt_path=self.user_prompt_path,
user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path,
system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt,
system_prompt=self.system_prompt, conversation_history=conversation_history if conversation_history else None,
conversation_history=conversation_history if session_save else None, response_model=response_model,
) )
else:
args = {"question": query, "context": context_text}
user_prompt = render_prompt(self.user_prompt_path, args)
system_prompt = (
self.system_prompt
if self.system_prompt
else read_query_prompt(self.system_prompt_path)
)
completion = await LLMGateway.acreate_structured_output(
text_input=user_prompt,
system_prompt=system_prompt,
response_model=response_model,
)
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}") logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
@@ -176,16 +163,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}" f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
) )
# Save to session cache
if session_save:
context_summary = await summarize_text(context_text)
await save_conversation_history(
query=query,
context_summary=context_summary,
answer=str(completion),
session_id=session_id,
)
return completion, context_text, triplets return completion, context_text, triplets
async def get_structured_completion( async def get_structured_completion(
@@ -217,10 +194,21 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
-------- --------
- Any: The generated structured completion based on the response model. - Any: The generated structured completion based on the response model.
""" """
# Check if session saving is enabled
cache_config = CacheConfig()
user = session_user.get()
user_id = getattr(user, "id", None)
session_save = user_id and cache_config.caching
# Load conversation history if enabled
conversation_history = ""
if session_save:
conversation_history = await get_conversation_history(session_id=session_id)
completion, context_text, triplets = await self._run_cot_completion( completion, context_text, triplets = await self._run_cot_completion(
query=query, query=query,
context=context, context=context,
session_id=session_id, conversation_history=conversation_history,
max_iter=max_iter, max_iter=max_iter,
response_model=response_model, response_model=response_model,
) )
@@ -230,6 +218,16 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
question=query, answer=str(completion), context=context_text, triplets=triplets question=query, answer=str(completion), context=context_text, triplets=triplets
) )
# Save to session cache if enabled
if session_save:
context_summary = await summarize_text(context_text)
await save_conversation_history(
query=query,
context_summary=context_summary,
answer=str(completion),
session_id=session_id,
)
return completion return completion
async def get_completion( async def get_completion(
@@ -263,7 +261,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
- List[str]: A list containing the generated answer to the user's query. - List[str]: A list containing the generated answer to the user's query.
""" """
completion, context_text, triplets = await self._run_cot_completion( completion = await self.get_structured_completion(
query=query, query=query,
context=context, context=context,
session_id=session_id, session_id=session_id,
@@ -271,9 +269,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
response_model=str, response_model=str,
) )
if self.save_interaction and context and triplets and completion:
await self.save_qa(
question=query, answer=completion, context=context_text, triplets=triplets
)
return [completion] return [completion]

View file

@@ -1,17 +1,18 @@
from typing import Optional from typing import Optional, Type, Any
from cognee.infrastructure.llm.LLMGateway import LLMGateway from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
async def generate_completion( async def generate_structured_completion(
query: str, query: str,
context: str, context: str,
user_prompt_path: str, user_prompt_path: str,
system_prompt_path: str, system_prompt_path: str,
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
conversation_history: Optional[str] = None, conversation_history: Optional[str] = None,
) -> str: response_model: Type = str,
"""Generates a completion using LLM with given context and prompts.""" ) -> Any:
"""Generates a structured completion using LLM with given context and prompts."""
args = {"question": query, "context": context} args = {"question": query, "context": context}
user_prompt = render_prompt(user_prompt_path, args) user_prompt = render_prompt(user_prompt_path, args)
system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path) system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)
@@ -23,6 +24,26 @@ async def generate_completion(
return await LLMGateway.acreate_structured_output( return await LLMGateway.acreate_structured_output(
text_input=user_prompt, text_input=user_prompt,
system_prompt=system_prompt, system_prompt=system_prompt,
response_model=response_model,
)
async def generate_completion(
query: str,
context: str,
user_prompt_path: str,
system_prompt_path: str,
system_prompt: Optional[str] = None,
conversation_history: Optional[str] = None,
) -> str:
"""Generates a completion using LLM with given context and prompts."""
return await generate_structured_completion(
query=query,
context=context,
user_prompt_path=user_prompt_path,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
conversation_history=conversation_history,
response_model=str, response_model=str,
) )