refactor: add structured output to completion retrievers

This commit is contained in:
Andrej Milicevic 2025-11-04 15:09:33 +01:00
parent 8d7c4d5384
commit 7e3c24100b
9 changed files with 67 additions and 90 deletions

View file

@ -1,5 +1,5 @@
import asyncio import asyncio
from typing import Any, Optional, List from typing import Any, Optional, List, Type
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor
@ -85,7 +85,11 @@ class EntityCompletionRetriever(BaseRetriever):
return None return None
async def get_completion( async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None self,
query: str,
context: Optional[Any] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> List[str]: ) -> List[str]:
""" """
Generate completion using provided context or fetch new context. Generate completion using provided context or fetch new context.
@ -102,6 +106,7 @@ class EntityCompletionRetriever(BaseRetriever):
fetched if not provided. (default None) fetched if not provided. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None, - session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None) defaults to 'default_session'. (default None)
- response_model (Type): The Pydantic model type for structured output. (default str)
Returns: Returns:
-------- --------
@ -133,6 +138,7 @@ class EntityCompletionRetriever(BaseRetriever):
user_prompt_path=self.user_prompt_path, user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
conversation_history=conversation_history, conversation_history=conversation_history,
response_model=response_model,
), ),
) )
else: else:
@ -141,6 +147,7 @@ class EntityCompletionRetriever(BaseRetriever):
context=context, context=context,
user_prompt_path=self.user_prompt_path, user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
response_model=response_model,
) )
if session_save: if session_save:

View file

@ -1,5 +1,5 @@
import asyncio import asyncio
from typing import Any, Optional from typing import Any, Optional, Type
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
@ -75,7 +75,11 @@ class CompletionRetriever(BaseRetriever):
raise NoDataError("No data found in the system, please add data first.") from error raise NoDataError("No data found in the system, please add data first.") from error
async def get_completion( async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None self,
query: str,
context: Optional[Any] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> str: ) -> str:
""" """
Generates an LLM completion using the context. Generates an LLM completion using the context.
@ -91,6 +95,7 @@ class CompletionRetriever(BaseRetriever):
completion; if None, it retrieves the context for the query. (default None) completion; if None, it retrieves the context for the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None, - session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None) defaults to 'default_session'. (default None)
- response_model (Type): The Pydantic model type for structured output. (default str)
Returns: Returns:
-------- --------
@ -118,6 +123,7 @@ class CompletionRetriever(BaseRetriever):
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt, system_prompt=self.system_prompt,
conversation_history=conversation_history, conversation_history=conversation_history,
response_model=response_model,
), ),
) )
else: else:
@ -127,6 +133,7 @@ class CompletionRetriever(BaseRetriever):
user_prompt_path=self.user_prompt_path, user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt, system_prompt=self.system_prompt,
response_model=response_model,
) )
if session_save: if session_save:
@ -137,4 +144,4 @@ class CompletionRetriever(BaseRetriever):
session_id=session_id, session_id=session_id,
) )
return completion return [completion]

View file

@ -56,6 +56,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
context: Optional[List[Edge]] = None, context: Optional[List[Edge]] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
context_extension_rounds=4, context_extension_rounds=4,
response_model: Type = str,
) -> List[str]: ) -> List[str]:
""" """
Extends the context for a given query by retrieving related triplets and generating new Extends the context for a given query by retrieving related triplets and generating new
@ -76,6 +77,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
defaults to 'default_session'. (default None) defaults to 'default_session'. (default None)
- context_extension_rounds: The maximum number of rounds to extend the context with - context_extension_rounds: The maximum number of rounds to extend the context with
new triplets before halting. (default 4) new triplets before halting. (default 4)
- response_model (Type): The Pydantic model type for structured output. (default str)
Returns: Returns:
-------- --------
@ -143,6 +145,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt, system_prompt=self.system_prompt,
conversation_history=conversation_history, conversation_history=conversation_history,
response_model=response_model,
), ),
) )
else: else:
@ -152,6 +155,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
user_prompt_path=self.user_prompt_path, user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt, system_prompt=self.system_prompt,
response_model=response_model,
) )
if self.save_interaction and context_text and triplets and completion: if self.save_interaction and context_text and triplets and completion:

View file

@ -7,7 +7,7 @@ from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.utils.completion import ( from cognee.modules.retrieval.utils.completion import (
generate_structured_completion, generate_completion,
summarize_text, summarize_text,
) )
from cognee.modules.retrieval.utils.session_cache import ( from cognee.modules.retrieval.utils.session_cache import (
@ -44,7 +44,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
questions based on reasoning. The public methods are: questions based on reasoning. The public methods are:
- get_completion - get_completion
- get_structured_completion
Instance variables include: Instance variables include:
- validation_system_prompt_path - validation_system_prompt_path
@ -121,7 +120,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
triplets += await self.get_context(followup_question) triplets += await self.get_context(followup_question)
context_text = await self.resolve_edges_to_text(list(set(triplets))) context_text = await self.resolve_edges_to_text(list(set(triplets)))
completion = await generate_structured_completion( completion = await generate_completion(
query=query, query=query,
context=context_text, context=context_text,
user_prompt_path=self.user_prompt_path, user_prompt_path=self.user_prompt_path,
@ -165,24 +164,28 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
return completion, context_text, triplets return completion, context_text, triplets
async def get_structured_completion( async def get_completion(
self, self,
query: str, query: str,
context: Optional[List[Edge]] = None, context: Optional[List[Edge]] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
max_iter: int = 4, max_iter=4,
response_model: Type = str, response_model: Type = str,
) -> Any: ) -> List[str]:
""" """
Generate structured completion responses based on a user query and contextual information. Generate completion responses based on a user query and contextual information.
This method applies the same chain-of-thought logic as get_completion but returns This method interacts with a language model client to retrieve a structured response,
using a series of iterations to refine the answers and generate follow-up questions
based on reasoning derived from previous outputs. It raises exceptions if the context
retrieval fails or if the model encounters issues in generating outputs. It returns
structured output using the provided response model. structured output using the provided response model.
Parameters: Parameters:
----------- -----------
- query (str): The user's query to be processed and answered. - query (str): The user's query to be processed and answered.
- context (Optional[List[Edge]]): Optional context that may assist in answering the query. - context (Optional[Any]): Optional context that may assist in answering the query.
If not provided, it will be fetched based on the query. (default None) If not provided, it will be fetched based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None, - session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None) defaults to 'default_session'. (default None)
@ -192,7 +195,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
Returns: Returns:
-------- --------
- Any: The generated structured completion based on the response model.
- List[str]: A list containing the generated answer to the user's query.
""" """
# Check if session saving is enabled # Check if session saving is enabled
cache_config = CacheConfig() cache_config = CacheConfig()
@ -228,45 +232,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
session_id=session_id, session_id=session_id,
) )
return completion
async def get_completion(
self,
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
max_iter=4,
) -> List[str]:
"""
Generate completion responses based on a user query and contextual information.
This method interacts with a language model client to retrieve a structured response,
using a series of iterations to refine the answers and generate follow-up questions
based on reasoning derived from previous outputs. It raises exceptions if the context
retrieval fails or if the model encounters issues in generating outputs.
Parameters:
-----------
- query (str): The user's query to be processed and answered.
- context (Optional[Any]): Optional context that may assist in answering the query.
If not provided, it will be fetched based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
- max_iter: The maximum number of iterations to refine the answer and generate
follow-up questions. (default 4)
Returns:
--------
- List[str]: A list containing the generated answer to the user's query.
"""
completion = await self.get_structured_completion(
query=query,
context=context,
session_id=session_id,
max_iter=max_iter,
response_model=str,
)
return [completion] return [completion]

View file

@ -146,6 +146,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
query: str, query: str,
context: Optional[List[Edge]] = None, context: Optional[List[Edge]] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
response_model: Type = str,
) -> List[str]: ) -> List[str]:
""" """
Generates a completion using graph connections context based on a query. Generates a completion using graph connections context based on a query.
@ -188,6 +189,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt, system_prompt=self.system_prompt,
conversation_history=conversation_history, conversation_history=conversation_history,
response_model=response_model,
), ),
) )
else: else:
@ -197,6 +199,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
user_prompt_path=self.user_prompt_path, user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt, system_prompt=self.system_prompt,
response_model=response_model,
) )
if self.save_interaction and context and triplets and completion: if self.save_interaction and context and triplets and completion:

View file

@ -146,7 +146,11 @@ class TemporalRetriever(GraphCompletionRetriever):
return self.descriptions_to_string(top_k_events) return self.descriptions_to_string(top_k_events)
async def get_completion( async def get_completion(
self, query: str, context: Optional[str] = None, session_id: Optional[str] = None self,
query: str,
context: Optional[str] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> List[str]: ) -> List[str]:
""" """
Generates a response using the query and optional context. Generates a response using the query and optional context.
@ -159,6 +163,7 @@ class TemporalRetriever(GraphCompletionRetriever):
retrieved based on the query. (default None) retrieved based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None, - session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None) defaults to 'default_session'. (default None)
- response_model (Type): The Pydantic model type for structured output. (default str)
Returns: Returns:
-------- --------
@ -186,6 +191,7 @@ class TemporalRetriever(GraphCompletionRetriever):
user_prompt_path=self.user_prompt_path, user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
conversation_history=conversation_history, conversation_history=conversation_history,
response_model=response_model,
), ),
) )
else: else:
@ -194,6 +200,7 @@ class TemporalRetriever(GraphCompletionRetriever):
context=context, context=context,
user_prompt_path=self.user_prompt_path, user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
response_model=response_model,
) )
if session_save: if session_save:

View file

@ -3,7 +3,7 @@ from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
async def generate_structured_completion( async def generate_completion(
query: str, query: str,
context: str, context: str,
user_prompt_path: str, user_prompt_path: str,
@ -11,8 +11,8 @@ async def generate_structured_completion(
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
conversation_history: Optional[str] = None, conversation_history: Optional[str] = None,
response_model: Type = str, response_model: Type = str,
) -> Any: ) -> str:
"""Generates a structured completion using LLM with given context and prompts.""" """Generates a completion using LLM with given context and prompts."""
args = {"question": query, "context": context} args = {"question": query, "context": context}
user_prompt = render_prompt(user_prompt_path, args) user_prompt = render_prompt(user_prompt_path, args)
system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path) system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)
@ -28,26 +28,6 @@ async def generate_structured_completion(
) )
async def generate_completion(
query: str,
context: str,
user_prompt_path: str,
system_prompt_path: str,
system_prompt: Optional[str] = None,
conversation_history: Optional[str] = None,
) -> str:
"""Generates a completion using LLM with given context and prompts."""
return await generate_structured_completion(
query=query,
context=context,
user_prompt_path=user_prompt_path,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
conversation_history=conversation_history,
response_model=str,
)
async def summarize_text( async def summarize_text(
text: str, text: str,
system_prompt_path: str = "summarize_search_results.txt", system_prompt_path: str = "summarize_search_results.txt",

View file

@ -61,7 +61,7 @@ async def _generate_improved_answer_for_single_interaction(
) )
retrieved_context = await retriever.get_context(query_text) retrieved_context = await retriever.get_context(query_text)
completion = await retriever.get_structured_completion( completion = await retriever.get_completion(
query=query_text, query=query_text,
context=retrieved_context, context=retrieved_context,
response_model=ImprovedAnswerResponse, response_model=ImprovedAnswerResponse,
@ -70,9 +70,9 @@ async def _generate_improved_answer_for_single_interaction(
new_context_text = await retriever.resolve_edges_to_text(retrieved_context) new_context_text = await retriever.resolve_edges_to_text(retrieved_context)
if completion: if completion:
enrichment.improved_answer = completion.answer enrichment.improved_answer = completion[0].answer
enrichment.new_context = new_context_text enrichment.new_context = new_context_text
enrichment.explanation = completion.explanation enrichment.explanation = completion[0].explanation
return enrichment return enrichment
else: else:
logger.warning( logger.warning(

View file

@ -206,16 +206,22 @@ class TestGraphCompletionCoTRetriever:
retriever = GraphCompletionCotRetriever() retriever = GraphCompletionCotRetriever()
# Test with string response model (default) # Test with string response model (default)
string_answer = await retriever.get_structured_completion("Who works at Figma?") string_answer = await retriever.get_completion("Who works at Figma?")
assert isinstance(string_answer, str), f"Expected str, got {type(string_answer).__name__}" assert isinstance(string_answer, list), f"Expected list, got {type(string_answer).__name__}"
assert string_answer.strip(), "Answer should not be empty" assert all(isinstance(item, str) and item.strip() for item in string_answer), (
"Answer should not be empty"
)
# Test with structured response model # Test with structured response model
structured_answer = await retriever.get_structured_completion( structured_answer = await retriever.get_completion(
"Who works at Figma?", response_model=TestAnswer "Who works at Figma?", response_model=TestAnswer
) )
assert isinstance(structured_answer, TestAnswer), ( assert isinstance(structured_answer, list), (
f"Expected list, got {type(structured_answer).__name__}"
)
assert all(isinstance(item, TestAnswer) for item in structured_answer), (
f"Expected TestAnswer, got {type(structured_answer).__name__}" f"Expected TestAnswer, got {type(structured_answer).__name__}"
) )
assert structured_answer.answer.strip(), "Answer field should not be empty"
assert structured_answer.explanation.strip(), "Explanation field should not be empty" assert structured_answer[0].answer.strip(), "Answer field should not be empty"
assert structured_answer[0].explanation.strip(), "Explanation field should not be empty"