From 1e1fac32611fb7e77201c6937776ea4100844813 Mon Sep 17 00:00:00 2001
From: lxobr <122801072+lxobr@users.noreply.github.com>
Date: Mon, 20 Oct 2025 23:43:41 +0200
Subject: [PATCH] feat: allow structured output in the cot retriever

---
 .../graph_completion_cot_retriever.py       | 201 +++++++++++++-----
 .../graph_completion_retriever_cot_test.py  |  51 +++++
 2 files changed, 201 insertions(+), 51 deletions(-)

diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py
index 4602dd59d..55cbcfce5 100644
--- a/cognee/modules/retrieval/graph_completion_cot_retriever.py
+++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py
@@ -1,4 +1,6 @@
+import json
 from typing import Optional, List, Type, Any
+from pydantic import BaseModel
 
 from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
 from cognee.shared.logging_utils import get_logger
@@ -10,6 +12,20 @@ from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
 logger = get_logger()
 
 
+def _as_answer_text(completion: Any) -> str:
+    """Convert completion to human-readable text for validation and follow-up prompts."""
+    if isinstance(completion, str):
+        return completion
+    if isinstance(completion, BaseModel):
+        # Add notice that this is a structured response
+        json_str = completion.model_dump_json(indent=2)
+        return f"[Structured Response]\n{json_str}"
+    try:
+        return json.dumps(completion, indent=2)
+    except TypeError:
+        return str(completion)
+
+
 class GraphCompletionCotRetriever(GraphCompletionRetriever):
     """
     Handles graph completion by generating responses based on a series of interactions with
@@ -54,6 +70,134 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
         self.followup_system_prompt_path = followup_system_prompt_path
         self.followup_user_prompt_path = followup_user_prompt_path
 
+    async def _run_cot_completion(
+        self,
+        query: str,
+        context: Optional[List[Edge]] = None,
+        max_iter: int = 4,
+        response_model: Type = str,
+    ) -> tuple[Any, str, List[Edge]]:
+        """
+        Run chain-of-thought completion with optional structured output.
+
+        Returns:
+        --------
+        - completion_result: The generated completion (string or structured model)
+        - context_text: The resolved context text
+        - triplets: The list of triplets used
+        """
+        followup_question = ""
+        triplets = []
+        completion = ""
+
+        for round_idx in range(max_iter + 1):
+            if round_idx == 0:
+                if context is None:
+                    triplets = await self.get_context(query)
+                    context_text = await self.resolve_edges_to_text(triplets)
+                else:
+                    context_text = await self.resolve_edges_to_text(context)
+            else:
+                triplets += await self.get_context(followup_question)
+                context_text = await self.resolve_edges_to_text(list(set(triplets)))
+
+            if response_model is str:
+                completion = await generate_completion(
+                    query=query,
+                    context=context_text,
+                    user_prompt_path=self.user_prompt_path,
+                    system_prompt_path=self.system_prompt_path,
+                    system_prompt=self.system_prompt,
+                )
+            else:
+                args = {"question": query, "context": context_text}
+                user_prompt = render_prompt(self.user_prompt_path, args)
+                system_prompt = (
+                    self.system_prompt
+                    if self.system_prompt
+                    else read_query_prompt(self.system_prompt_path)
+                )
+
+                completion = await LLMGateway.acreate_structured_output(
+                    text_input=user_prompt,
+                    system_prompt=system_prompt,
+                    response_model=response_model,
+                )
+
+            logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
+
+            if round_idx < max_iter:
+                answer_text = _as_answer_text(completion)
+                valid_args = {"query": query, "answer": answer_text, "context": context_text}
+                valid_user_prompt = render_prompt(
+                    filename=self.validation_user_prompt_path, context=valid_args
+                )
+                valid_system_prompt = read_query_prompt(
+                    prompt_file_name=self.validation_system_prompt_path
+                )
+
+                reasoning = await LLMGateway.acreate_structured_output(
+                    text_input=valid_user_prompt,
+                    system_prompt=valid_system_prompt,
+                    response_model=str,
+                )
+                followup_args = {"query": query, "answer": answer_text, "reasoning": reasoning}
+                followup_prompt = render_prompt(
+                    filename=self.followup_user_prompt_path, context=followup_args
+                )
+                followup_system = read_query_prompt(
+                    prompt_file_name=self.followup_system_prompt_path
+                )
+
+                followup_question = await LLMGateway.acreate_structured_output(
+                    text_input=followup_prompt, system_prompt=followup_system, response_model=str
+                )
+                logger.info(
+                    f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
+                )
+
+        return completion, context_text, triplets
+
+    async def get_structured_completion(
+        self,
+        query: str,
+        context: Optional[List[Edge]] = None,
+        max_iter: int = 4,
+        response_model: Type = str,
+    ) -> Any:
+        """
+        Generate structured completion responses based on a user query and contextual information.
+
+        This method applies the same chain-of-thought logic as get_completion but returns
+        structured output using the provided response model.
+
+        Parameters:
+        -----------
+        - query (str): The user's query to be processed and answered.
+        - context (Optional[List[Edge]]): Optional context that may assist in answering the query.
+          If not provided, it will be fetched based on the query. (default None)
+        - max_iter: The maximum number of iterations to refine the answer and generate
+          follow-up questions. (default 4)
+        - response_model (Type): The Pydantic model type for structured output. (default str)
+
+        Returns:
+        --------
+        - Any: The generated structured completion based on the response model.
+ """ + completion, context_text, triplets = await self._run_cot_completion( + query=query, + context=context, + max_iter=max_iter, + response_model=response_model, + ) + + if self.save_interaction and context and triplets and completion: + await self.save_qa( + question=query, answer=str(completion), context=context_text, triplets=triplets + ) + + return completion + async def get_completion( self, query: str, @@ -82,57 +226,12 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): - List[str]: A list containing the generated answer to the user's query. """ - followup_question = "" - triplets = [] - completion = "" - - for round_idx in range(max_iter + 1): - if round_idx == 0: - if context is None: - triplets = await self.get_context(query) - context_text = await self.resolve_edges_to_text(triplets) - else: - context_text = await self.resolve_edges_to_text(context) - else: - triplets += await self.get_context(followup_question) - context_text = await self.resolve_edges_to_text(list(set(triplets))) - - completion = await generate_completion( - query=query, - context=context_text, - user_prompt_path=self.user_prompt_path, - system_prompt_path=self.system_prompt_path, - system_prompt=self.system_prompt, - ) - logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}") - if round_idx < max_iter: - valid_args = {"query": query, "answer": completion, "context": context_text} - valid_user_prompt = render_prompt( - filename=self.validation_user_prompt_path, context=valid_args - ) - valid_system_prompt = read_query_prompt( - prompt_file_name=self.validation_system_prompt_path - ) - - reasoning = await LLMGateway.acreate_structured_output( - text_input=valid_user_prompt, - system_prompt=valid_system_prompt, - response_model=str, - ) - followup_args = {"query": query, "answer": completion, "reasoning": reasoning} - followup_prompt = render_prompt( - filename=self.followup_user_prompt_path, context=followup_args - ) - followup_system = read_query_prompt( - prompt_file_name=self.followup_system_prompt_path - ) - - followup_question = await LLMGateway.acreate_structured_output( - text_input=followup_prompt, system_prompt=followup_system, response_model=str - ) - logger.info( - f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}" - ) + completion, context_text, triplets = await self._run_cot_completion( + query=query, + context=context, + max_iter=max_iter, + response_model=str, + ) if self.save_interaction and context and triplets and completion: await self.save_qa( diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index 206cfaf84..7fcfe0d6b 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -2,6 +2,7 @@ import os import pytest import pathlib from typing import Optional, Union +from pydantic import BaseModel import cognee from cognee.low_level import setup, DataPoint @@ -10,6 +11,11 @@ from cognee.tasks.storage import add_data_points from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever +class TestAnswer(BaseModel): + answer: str + explanation: str + + class TestGraphCompletionCoTRetriever: @pytest.mark.asyncio async def test_graph_completion_cot_context_simple(self): @@ -168,3 +174,48 @@ class TestGraphCompletionCoTRetriever: assert all(isinstance(item, str) and item.strip() 
             "Answer must contain only non-empty strings"
         )
+
+    @pytest.mark.asyncio
+    async def test_get_structured_completion(self):
+        system_directory_path = os.path.join(
+            pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion"
+        )
+        cognee.config.system_root_directory(system_directory_path)
+        data_directory_path = os.path.join(
+            pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion"
+        )
+        cognee.config.data_root_directory(data_directory_path)
+
+        await cognee.prune.prune_data()
+        await cognee.prune.prune_system(metadata=True)
+        await setup()
+
+        class Company(DataPoint):
+            name: str
+
+        class Person(DataPoint):
+            name: str
+            works_for: Company
+
+        company1 = Company(name="Figma")
+        person1 = Person(name="Steve Rodger", works_for=company1)
+
+        entities = [company1, person1]
+        await add_data_points(entities)
+
+        retriever = GraphCompletionCotRetriever()
+
+        # Test with string response model (default)
+        string_answer = await retriever.get_structured_completion("Who works at Figma?")
+        assert isinstance(string_answer, str), f"Expected str, got {type(string_answer).__name__}"
+        assert string_answer.strip(), "Answer should not be empty"
+
+        # Test with structured response model
+        structured_answer = await retriever.get_structured_completion(
+            "Who works at Figma?", response_model=TestAnswer
+        )
+        assert isinstance(structured_answer, TestAnswer), (
+            f"Expected TestAnswer, got {type(structured_answer).__name__}"
+        )
+        assert structured_answer.answer.strip(), "Answer field should not be empty"
+        assert structured_answer.explanation.strip(), "Explanation field should not be empty"
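
A minimal usage sketch of the new entry point, assuming a dataset has already been added and cognified; the QuestionAnswer model and the ask() helper are illustrative, not part of this change:

    from pydantic import BaseModel

    from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever


    class QuestionAnswer(BaseModel):
        # illustrative response model; any Pydantic model accepted by
        # LLMGateway.acreate_structured_output can be used here
        answer: str
        explanation: str


    async def ask(question: str) -> QuestionAnswer:
        retriever = GraphCompletionCotRetriever()
        # response_model=str (the default) keeps the existing plain-text behaviour;
        # passing a Pydantic model returns an instance of that model instead
        return await retriever.get_structured_completion(question, response_model=QuestionAnswer)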