refactor: add structured output to completion retrievers

This commit is contained in:
Andrej Milicevic 2025-11-04 15:09:33 +01:00
parent 8d7c4d5384
commit 7e3c24100b
9 changed files with 67 additions and 90 deletions

View file

@ -1,5 +1,5 @@
import asyncio
from typing import Any, Optional, List
from typing import Any, Optional, List, Type
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor
@ -85,7 +85,11 @@ class EntityCompletionRetriever(BaseRetriever):
return None
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
self,
query: str,
context: Optional[Any] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> List[str]:
"""
Generate completion using provided context or fetch new context.
@ -102,6 +106,7 @@ class EntityCompletionRetriever(BaseRetriever):
fetched if not provided. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
- response_model (Type): The Pydantic model type for structured output. (default str)
Returns:
--------
@ -133,6 +138,7 @@ class EntityCompletionRetriever(BaseRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
conversation_history=conversation_history,
response_model=response_model,
),
)
else:
@ -141,6 +147,7 @@ class EntityCompletionRetriever(BaseRetriever):
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
response_model=response_model,
)
if session_save:

View file

@ -1,5 +1,5 @@
import asyncio
from typing import Any, Optional
from typing import Any, Optional, Type
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine
@ -75,7 +75,11 @@ class CompletionRetriever(BaseRetriever):
raise NoDataError("No data found in the system, please add data first.") from error
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
self,
query: str,
context: Optional[Any] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> str:
"""
Generates an LLM completion using the context.
@ -91,6 +95,7 @@ class CompletionRetriever(BaseRetriever):
completion; if None, it retrieves the context for the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
- response_model (Type): The Pydantic model type for structured output. (default str)
Returns:
--------
@ -118,6 +123,7 @@ class CompletionRetriever(BaseRetriever):
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history,
response_model=response_model,
),
)
else:
@ -127,6 +133,7 @@ class CompletionRetriever(BaseRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
)
if session_save:
@ -137,4 +144,4 @@ class CompletionRetriever(BaseRetriever):
session_id=session_id,
)
return completion
return [completion]

View file

@ -56,6 +56,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
context_extension_rounds=4,
response_model: Type = str,
) -> List[str]:
"""
Extends the context for a given query by retrieving related triplets and generating new
@ -76,6 +77,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
defaults to 'default_session'. (default None)
- context_extension_rounds: The maximum number of rounds to extend the context with
new triplets before halting. (default 4)
- response_model (Type): The Pydantic model type for structured output. (default str)
Returns:
--------
@ -143,6 +145,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history,
response_model=response_model,
),
)
else:
@ -152,6 +155,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
)
if self.save_interaction and context_text and triplets and completion:

View file

@ -7,7 +7,7 @@ from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.utils.completion import (
generate_structured_completion,
generate_completion,
summarize_text,
)
from cognee.modules.retrieval.utils.session_cache import (
@ -44,7 +44,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
questions based on reasoning. The public methods are:
- get_completion
- get_structured_completion
Instance variables include:
- validation_system_prompt_path
@ -121,7 +120,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
triplets += await self.get_context(followup_question)
context_text = await self.resolve_edges_to_text(list(set(triplets)))
completion = await generate_structured_completion(
completion = await generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
@ -165,24 +164,28 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
return completion, context_text, triplets
async def get_structured_completion(
async def get_completion(
self,
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
max_iter: int = 4,
max_iter=4,
response_model: Type = str,
) -> Any:
) -> List[str]:
"""
Generate structured completion responses based on a user query and contextual information.
Generate completion responses based on a user query and contextual information.
This method applies the same chain-of-thought logic as get_completion but returns
This method interacts with a language model client to retrieve a structured response,
using a series of iterations to refine the answers and generate follow-up questions
based on reasoning derived from previous outputs. It raises exceptions if the context
retrieval fails or if the model encounters issues in generating outputs. It returns
structured output using the provided response model.
Parameters:
-----------
- query (str): The user's query to be processed and answered.
- context (Optional[List[Edge]]): Optional context that may assist in answering the query.
- context (Optional[List[Edge]]): Optional context that may assist in answering the query.
If not provided, it will be fetched based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
@ -192,7 +195,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
Returns:
--------
- Any: The generated structured completion based on the response model.
- List[str]: A list containing the generated answer to the user's query.
"""
# Check if session saving is enabled
cache_config = CacheConfig()
@ -228,45 +232,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
session_id=session_id,
)
return completion
async def get_completion(
self,
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
max_iter=4,
) -> List[str]:
"""
Generate completion responses based on a user query and contextual information.
This method interacts with a language model client to retrieve a structured response,
using a series of iterations to refine the answers and generate follow-up questions
based on reasoning derived from previous outputs. It raises exceptions if the context
retrieval fails or if the model encounters issues in generating outputs.
Parameters:
-----------
- query (str): The user's query to be processed and answered.
- context (Optional[Any]): Optional context that may assist in answering the query.
If not provided, it will be fetched based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
- max_iter: The maximum number of iterations to refine the answer and generate
follow-up questions. (default 4)
Returns:
--------
- List[str]: A list containing the generated answer to the user's query.
"""
completion = await self.get_structured_completion(
query=query,
context=context,
session_id=session_id,
max_iter=max_iter,
response_model=str,
)
return [completion]

View file

@ -146,6 +146,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
query: str,
context: Optional[List[Edge]] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> List[str]:
"""
Generates a completion using graph connections context based on a query.
@ -188,6 +189,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history,
response_model=response_model,
),
)
else:
@ -197,6 +199,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
)
if self.save_interaction and context and triplets and completion:

View file

@ -146,7 +146,11 @@ class TemporalRetriever(GraphCompletionRetriever):
return self.descriptions_to_string(top_k_events)
async def get_completion(
self, query: str, context: Optional[str] = None, session_id: Optional[str] = None
self,
query: str,
context: Optional[str] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> List[str]:
"""
Generates a response using the query and optional context.
@ -159,6 +163,7 @@ class TemporalRetriever(GraphCompletionRetriever):
retrieved based on the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
- response_model (Type): The Pydantic model type for structured output. (default str)
Returns:
--------
@ -186,6 +191,7 @@ class TemporalRetriever(GraphCompletionRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
conversation_history=conversation_history,
response_model=response_model,
),
)
else:
@ -194,6 +200,7 @@ class TemporalRetriever(GraphCompletionRetriever):
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
response_model=response_model,
)
if session_save:

View file

@ -3,7 +3,7 @@ from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
async def generate_structured_completion(
async def generate_completion(
query: str,
context: str,
user_prompt_path: str,
@ -11,8 +11,8 @@ async def generate_structured_completion(
system_prompt: Optional[str] = None,
conversation_history: Optional[str] = None,
response_model: Type = str,
) -> Any:
"""Generates a structured completion using LLM with given context and prompts."""
) -> str:
"""Generates a completion using LLM with given context and prompts."""
args = {"question": query, "context": context}
user_prompt = render_prompt(user_prompt_path, args)
system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)
@ -28,26 +28,6 @@ async def generate_structured_completion(
)
async def generate_completion(
query: str,
context: str,
user_prompt_path: str,
system_prompt_path: str,
system_prompt: Optional[str] = None,
conversation_history: Optional[str] = None,
) -> str:
"""Generates a completion using LLM with given context and prompts."""
return await generate_structured_completion(
query=query,
context=context,
user_prompt_path=user_prompt_path,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
conversation_history=conversation_history,
response_model=str,
)
async def summarize_text(
text: str,
system_prompt_path: str = "summarize_search_results.txt",

View file

@ -61,7 +61,7 @@ async def _generate_improved_answer_for_single_interaction(
)
retrieved_context = await retriever.get_context(query_text)
completion = await retriever.get_structured_completion(
completion = await retriever.get_completion(
query=query_text,
context=retrieved_context,
response_model=ImprovedAnswerResponse,
@ -70,9 +70,9 @@ async def _generate_improved_answer_for_single_interaction(
new_context_text = await retriever.resolve_edges_to_text(retrieved_context)
if completion:
enrichment.improved_answer = completion.answer
enrichment.improved_answer = completion[0].answer
enrichment.new_context = new_context_text
enrichment.explanation = completion.explanation
enrichment.explanation = completion[0].explanation
return enrichment
else:
logger.warning(

View file

@ -206,16 +206,22 @@ class TestGraphCompletionCoTRetriever:
retriever = GraphCompletionCotRetriever()
# Test with string response model (default)
string_answer = await retriever.get_structured_completion("Who works at Figma?")
assert isinstance(string_answer, str), f"Expected str, got {type(string_answer).__name__}"
assert string_answer.strip(), "Answer should not be empty"
string_answer = await retriever.get_completion("Who works at Figma?")
assert isinstance(string_answer, list), f"Expected list, got {type(string_answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in string_answer), (
"Answer should not be empty"
)
# Test with structured response model
structured_answer = await retriever.get_structured_completion(
structured_answer = await retriever.get_completion(
"Who works at Figma?", response_model=TestAnswer
)
assert isinstance(structured_answer, TestAnswer), (
assert isinstance(structured_answer, list), (
f"Expected list, got {type(structured_answer).__name__}"
)
assert all(isinstance(item, TestAnswer) for item in structured_answer), (
f"Expected TestAnswer, got {type(structured_answer).__name__}"
)
assert structured_answer.answer.strip(), "Answer field should not be empty"
assert structured_answer.explanation.strip(), "Explanation field should not be empty"
assert structured_answer[0].answer.strip(), "Answer field should not be empty"
assert structured_answer[0].explanation.strip(), "Explanation field should not be empty"