feat: allow structured output in the cot retriever

lxobr 2025-10-20 23:43:41 +02:00
parent 97eb89386e
commit 1e1fac3261
2 changed files with 201 additions and 51 deletions
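
In short: GraphCompletionCotRetriever gains a get_structured_completion method that runs
the same chain-of-thought loop as get_completion, but can return a validated Pydantic
model instead of a plain string. A minimal usage sketch, mirroring the new test below
(the Answer model and the example query are illustrative, not part of the commit):

    from pydantic import BaseModel
    from cognee.modules.retrieval.graph_completion_cot_retriever import (
        GraphCompletionCotRetriever,
    )

    class Answer(BaseModel):
        answer: str
        explanation: str

    async def ask():
        retriever = GraphCompletionCotRetriever()

        # Default response_model=str behaves like get_completion and returns plain text
        text = await retriever.get_structured_completion("Who works at Figma?")

        # Passing a Pydantic model returns a validated instance of that model
        result = await retriever.get_structured_completion(
            "Who works at Figma?", response_model=Answer
        )
        return text, result.answer, result.explanation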


@@ -1,4 +1,6 @@
import json
from typing import Optional, List, Type, Any
from pydantic import BaseModel
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.shared.logging_utils import get_logger
@@ -10,6 +12,20 @@ from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
logger = get_logger()

def _as_answer_text(completion: Any) -> str:
    """Convert completion to human-readable text for validation and follow-up prompts."""
    if isinstance(completion, str):
        return completion
    if isinstance(completion, BaseModel):
        # Add notice that this is a structured response
        json_str = completion.model_dump_json(indent=2)
        return f"[Structured Response]\n{json_str}"
    try:
        return json.dumps(completion, indent=2)
    except TypeError:
        return str(completion)

class GraphCompletionCotRetriever(GraphCompletionRetriever):
    """
    Handles graph completion by generating responses based on a series of interactions with
@@ -54,6 +70,134 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
        self.followup_system_prompt_path = followup_system_prompt_path
        self.followup_user_prompt_path = followup_user_prompt_path

    async def _run_cot_completion(
        self,
        query: str,
        context: Optional[List[Edge]] = None,
        max_iter: int = 4,
        response_model: Type = str,
    ) -> tuple[Any, str, List[Edge]]:
        """
        Run chain-of-thought completion with optional structured output.

        Returns:
        --------

        - completion_result: The generated completion (string or structured model)
        - context_text: The resolved context text
        - triplets: The list of triplets used
        """
        followup_question = ""
        triplets = []
        completion = ""

        for round_idx in range(max_iter + 1):
            if round_idx == 0:
                if context is None:
                    triplets = await self.get_context(query)
                    context_text = await self.resolve_edges_to_text(triplets)
                else:
                    context_text = await self.resolve_edges_to_text(context)
            else:
                triplets += await self.get_context(followup_question)
                context_text = await self.resolve_edges_to_text(list(set(triplets)))

            if response_model is str:
                completion = await generate_completion(
                    query=query,
                    context=context_text,
                    user_prompt_path=self.user_prompt_path,
                    system_prompt_path=self.system_prompt_path,
                    system_prompt=self.system_prompt,
                )
            else:
                args = {"question": query, "context": context_text}
                user_prompt = render_prompt(self.user_prompt_path, args)
                system_prompt = (
                    self.system_prompt
                    if self.system_prompt
                    else read_query_prompt(self.system_prompt_path)
                )
                completion = await LLMGateway.acreate_structured_output(
                    text_input=user_prompt,
                    system_prompt=system_prompt,
                    response_model=response_model,
                )

            logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")

            if round_idx < max_iter:
                answer_text = _as_answer_text(completion)
                valid_args = {"query": query, "answer": answer_text, "context": context_text}
                valid_user_prompt = render_prompt(
                    filename=self.validation_user_prompt_path, context=valid_args
                )
                valid_system_prompt = read_query_prompt(
                    prompt_file_name=self.validation_system_prompt_path
                )
                reasoning = await LLMGateway.acreate_structured_output(
                    text_input=valid_user_prompt,
                    system_prompt=valid_system_prompt,
                    response_model=str,
                )
                followup_args = {"query": query, "answer": answer_text, "reasoning": reasoning}
                followup_prompt = render_prompt(
                    filename=self.followup_user_prompt_path, context=followup_args
                )
                followup_system = read_query_prompt(
                    prompt_file_name=self.followup_system_prompt_path
                )
                followup_question = await LLMGateway.acreate_structured_output(
                    text_input=followup_prompt, system_prompt=followup_system, response_model=str
                )
                logger.info(
                    f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
                )

        return completion, context_text, triplets

    async def get_structured_completion(
        self,
        query: str,
        context: Optional[List[Edge]] = None,
        max_iter: int = 4,
        response_model: Type = str,
    ) -> Any:
        """
        Generate structured completion responses based on a user query and contextual information.

        This method applies the same chain-of-thought logic as get_completion but returns
        structured output using the provided response model.

        Parameters:
        -----------

        - query (str): The user's query to be processed and answered.
        - context (Optional[List[Edge]]): Optional context that may assist in answering the query.
          If not provided, it will be fetched based on the query. (default None)
        - max_iter (int): The maximum number of iterations to refine the answer and generate
          follow-up questions. (default 4)
        - response_model (Type): The Pydantic model type for structured output. (default str)

        Returns:
        --------

        - Any: The generated structured completion based on the response model.
        """
        completion, context_text, triplets = await self._run_cot_completion(
            query=query,
            context=context,
            max_iter=max_iter,
            response_model=response_model,
        )
        if self.save_interaction and context and triplets and completion:
            await self.save_qa(
                question=query, answer=str(completion), context=context_text, triplets=triplets
            )
        return completion

    async def get_completion(
        self,
        query: str,
@@ -82,57 +226,12 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
        - List[str]: A list containing the generated answer to the user's query.
        """
        followup_question = ""
        triplets = []
        completion = ""
        for round_idx in range(max_iter + 1):
            if round_idx == 0:
                if context is None:
                    triplets = await self.get_context(query)
                    context_text = await self.resolve_edges_to_text(triplets)
                else:
                    context_text = await self.resolve_edges_to_text(context)
            else:
                triplets += await self.get_context(followup_question)
                context_text = await self.resolve_edges_to_text(list(set(triplets)))
            completion = await generate_completion(
                query=query,
                context=context_text,
                user_prompt_path=self.user_prompt_path,
                system_prompt_path=self.system_prompt_path,
                system_prompt=self.system_prompt,
            )
            logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
            if round_idx < max_iter:
                valid_args = {"query": query, "answer": completion, "context": context_text}
                valid_user_prompt = render_prompt(
                    filename=self.validation_user_prompt_path, context=valid_args
                )
                valid_system_prompt = read_query_prompt(
                    prompt_file_name=self.validation_system_prompt_path
                )
                reasoning = await LLMGateway.acreate_structured_output(
                    text_input=valid_user_prompt,
                    system_prompt=valid_system_prompt,
                    response_model=str,
                )
                followup_args = {"query": query, "answer": completion, "reasoning": reasoning}
                followup_prompt = render_prompt(
                    filename=self.followup_user_prompt_path, context=followup_args
                )
                followup_system = read_query_prompt(
                    prompt_file_name=self.followup_system_prompt_path
                )
                followup_question = await LLMGateway.acreate_structured_output(
                    text_input=followup_prompt, system_prompt=followup_system, response_model=str
                )
                logger.info(
                    f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
                )
        completion, context_text, triplets = await self._run_cot_completion(
            query=query,
            context=context,
            max_iter=max_iter,
            response_model=str,
        )
        if self.save_interaction and context and triplets and completion:
            await self.save_qa(

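The validation and follow-up prompts in the loop above still consume plain text, so the
_as_answer_text helper normalizes whatever the model returned before it is rendered into
those prompts. A standalone sketch of its behavior (the Person model here is illustrative):

    from pydantic import BaseModel
    from cognee.modules.retrieval.graph_completion_cot_retriever import _as_answer_text

    class Person(BaseModel):
        name: str

    print(_as_answer_text("plain answer"))          # strings pass through unchanged
    print(_as_answer_text(Person(name="Ada")))      # "[Structured Response]" + indented JSON
    print(_as_answer_text({"works_for": "Figma"}))  # JSON-serializable values use json.dumps
    print(_as_answer_text(object()))                # non-serializable values fall back to str()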

@@ -2,6 +2,7 @@ import os
import pytest
import pathlib
from typing import Optional, Union
from pydantic import BaseModel
import cognee
from cognee.low_level import setup, DataPoint
@@ -10,6 +11,11 @@ from cognee.tasks.storage import add_data_points
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever

class TestAnswer(BaseModel):
    answer: str
    explanation: str


class TestGraphCompletionCoTRetriever:
    @pytest.mark.asyncio
    async def test_graph_completion_cot_context_simple(self):
@@ -168,3 +174,48 @@ class TestGraphCompletionCoTRetriever:
        assert all(isinstance(item, str) and item.strip() for item in answer), (
            "Answer must contain only non-empty strings"
        )

    @pytest.mark.asyncio
    async def test_get_structured_completion(self):
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion"
        )
        cognee.config.data_root_directory(data_directory_path)

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        class Company(DataPoint):
            name: str

        class Person(DataPoint):
            name: str
            works_for: Company

        company1 = Company(name="Figma")
        person1 = Person(name="Steve Rodger", works_for=company1)
        entities = [company1, person1]
        await add_data_points(entities)

        retriever = GraphCompletionCotRetriever()

        # Test with string response model (default)
        string_answer = await retriever.get_structured_completion("Who works at Figma?")
        assert isinstance(string_answer, str), f"Expected str, got {type(string_answer).__name__}"
        assert string_answer.strip(), "Answer should not be empty"

        # Test with structured response model
        structured_answer = await retriever.get_structured_completion(
            "Who works at Figma?", response_model=TestAnswer
        )
        assert isinstance(structured_answer, TestAnswer), (
            f"Expected TestAnswer, got {type(structured_answer).__name__}"
        )
        assert structured_answer.answer.strip(), "Answer field should not be empty"
        assert structured_answer.explanation.strip(), "Explanation field should not be empty"
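
Since response_model is forwarded straight to LLMGateway.acreate_structured_output, nothing
in the retriever itself restricts it to flat models like TestAnswer. A hedged sketch with a
nested model, assuming the underlying gateway supports nested schemas (the model names and
query are illustrative):

    from typing import List
    from pydantic import BaseModel

    class Employment(BaseModel):
        person: str
        company: str

    class EmploymentReport(BaseModel):
        matches: List[Employment]

    async def report(retriever):
        result = await retriever.get_structured_completion(
            "Who works where?", response_model=EmploymentReport
        )
        for m in result.matches:
            print(m.person, "->", m.company)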