From 7e3c24100b0606ecf7d9b71ab76bba9be9c64e68 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Tue, 4 Nov 2025 15:09:33 +0100 Subject: [PATCH 1/6] refactor: add structured output to completion retrievers --- .../retrieval/EntityCompletionRetriever.py | 11 +++- .../modules/retrieval/completion_retriever.py | 13 +++- ..._completion_context_extension_retriever.py | 4 ++ .../graph_completion_cot_retriever.py | 65 ++++--------------- .../retrieval/graph_completion_retriever.py | 3 + .../modules/retrieval/temporal_retriever.py | 9 ++- cognee/modules/retrieval/utils/completion.py | 26 +------- .../feedback/generate_improved_answers.py | 6 +- .../graph_completion_retriever_cot_test.py | 20 ++++-- 9 files changed, 67 insertions(+), 90 deletions(-) diff --git a/cognee/modules/retrieval/EntityCompletionRetriever.py b/cognee/modules/retrieval/EntityCompletionRetriever.py index 6086977ce..1f1ddad0a 100644 --- a/cognee/modules/retrieval/EntityCompletionRetriever.py +++ b/cognee/modules/retrieval/EntityCompletionRetriever.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Optional, List +from typing import Any, Optional, List, Type from cognee.shared.logging_utils import get_logger from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor @@ -85,7 +85,11 @@ class EntityCompletionRetriever(BaseRetriever): return None async def get_completion( - self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None + self, + query: str, + context: Optional[Any] = None, + session_id: Optional[str] = None, + response_model: Type = str, ) -> List[str]: """ Generate completion using provided context or fetch new context. @@ -102,6 +106,7 @@ class EntityCompletionRetriever(BaseRetriever): fetched if not provided. (default None) - session_id (Optional[str]): Optional session identifier for caching. If None, defaults to 'default_session'. (default None) + - response_model (Type): The Pydantic model type for structured output. (default str) Returns: -------- @@ -133,6 +138,7 @@ class EntityCompletionRetriever(BaseRetriever): user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, conversation_history=conversation_history, + response_model=response_model, ), ) else: @@ -141,6 +147,7 @@ class EntityCompletionRetriever(BaseRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, + response_model=response_model, ) if session_save: diff --git a/cognee/modules/retrieval/completion_retriever.py b/cognee/modules/retrieval/completion_retriever.py index bb568924d..f071e41de 100644 --- a/cognee/modules/retrieval/completion_retriever.py +++ b/cognee/modules/retrieval/completion_retriever.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Optional +from typing import Any, Optional, Type from cognee.shared.logging_utils import get_logger from cognee.infrastructure.databases.vector import get_vector_engine @@ -75,7 +75,11 @@ class CompletionRetriever(BaseRetriever): raise NoDataError("No data found in the system, please add data first.") from error async def get_completion( - self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None + self, + query: str, + context: Optional[Any] = None, + session_id: Optional[str] = None, + response_model: Type = str, ) -> str: """ Generates an LLM completion using the context. @@ -91,6 +95,7 @@ class CompletionRetriever(BaseRetriever): completion; if None, it retrieves the context for the query. (default None) - session_id (Optional[str]): Optional session identifier for caching. If None, defaults to 'default_session'. (default None) + - response_model (Type): The Pydantic model type for structured output. (default str) Returns: -------- @@ -118,6 +123,7 @@ class CompletionRetriever(BaseRetriever): system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, conversation_history=conversation_history, + response_model=response_model, ), ) else: @@ -127,6 +133,7 @@ class CompletionRetriever(BaseRetriever): user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, + response_model=response_model, ) if session_save: @@ -137,4 +144,4 @@ class CompletionRetriever(BaseRetriever): session_id=session_id, ) - return completion + return [completion] diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index 58b6b586f..6b2c6a9e6 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -56,6 +56,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): context: Optional[List[Edge]] = None, session_id: Optional[str] = None, context_extension_rounds=4, + response_model: Type = str, ) -> List[str]: """ Extends the context for a given query by retrieving related triplets and generating new @@ -76,6 +77,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): defaults to 'default_session'. (default None) - context_extension_rounds: The maximum number of rounds to extend the context with new triplets before halting. (default 4) + - response_model (Type): The Pydantic model type for structured output. (default str) Returns: -------- @@ -143,6 +145,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, conversation_history=conversation_history, + response_model=response_model, ), ) else: @@ -152,6 +155,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, + response_model=response_model, ) if self.save_interaction and context_text and triplets and completion: diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 299db6855..39255fe68 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -7,7 +7,7 @@ from cognee.shared.logging_utils import get_logger from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.utils.completion import ( - generate_structured_completion, + generate_completion, summarize_text, ) from cognee.modules.retrieval.utils.session_cache import ( @@ -44,7 +44,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): questions based on reasoning. The public methods are: - get_completion - - get_structured_completion Instance variables include: - validation_system_prompt_path @@ -121,7 +120,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): triplets += await self.get_context(followup_question) context_text = await self.resolve_edges_to_text(list(set(triplets))) - completion = await generate_structured_completion( + completion = await generate_completion( query=query, context=context_text, user_prompt_path=self.user_prompt_path, @@ -165,24 +164,28 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): return completion, context_text, triplets - async def get_structured_completion( + async def get_completion( self, query: str, context: Optional[List[Edge]] = None, session_id: Optional[str] = None, - max_iter: int = 4, + max_iter=4, response_model: Type = str, - ) -> Any: + ) -> List[str]: """ - Generate structured completion responses based on a user query and contextual information. + Generate completion responses based on a user query and contextual information. - This method applies the same chain-of-thought logic as get_completion but returns + This method interacts with a language model client to retrieve a structured response, + using a series of iterations to refine the answers and generate follow-up questions + based on reasoning derived from previous outputs. It raises exceptions if the context + retrieval fails or if the model encounters issues in generating outputs. It returns structured output using the provided response model. Parameters: ----------- + - query (str): The user's query to be processed and answered. - - context (Optional[List[Edge]]): Optional context that may assist in answering the query. + - context (Optional[Any]): Optional context that may assist in answering the query. If not provided, it will be fetched based on the query. (default None) - session_id (Optional[str]): Optional session identifier for caching. If None, defaults to 'default_session'. (default None) @@ -192,7 +195,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): Returns: -------- - - Any: The generated structured completion based on the response model. + + - List[str]: A list containing the generated answer to the user's query. """ # Check if session saving is enabled cache_config = CacheConfig() @@ -228,45 +232,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): session_id=session_id, ) - return completion - - async def get_completion( - self, - query: str, - context: Optional[List[Edge]] = None, - session_id: Optional[str] = None, - max_iter=4, - ) -> List[str]: - """ - Generate completion responses based on a user query and contextual information. - - This method interacts with a language model client to retrieve a structured response, - using a series of iterations to refine the answers and generate follow-up questions - based on reasoning derived from previous outputs. It raises exceptions if the context - retrieval fails or if the model encounters issues in generating outputs. - - Parameters: - ----------- - - - query (str): The user's query to be processed and answered. - - context (Optional[Any]): Optional context that may assist in answering the query. - If not provided, it will be fetched based on the query. (default None) - - session_id (Optional[str]): Optional session identifier for caching. If None, - defaults to 'default_session'. (default None) - - max_iter: The maximum number of iterations to refine the answer and generate - follow-up questions. (default 4) - - Returns: - -------- - - - List[str]: A list containing the generated answer to the user's query. - """ - completion = await self.get_structured_completion( - query=query, - context=context, - session_id=session_id, - max_iter=max_iter, - response_model=str, - ) - return [completion] diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index b7ab4edae..b544e8ead 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -146,6 +146,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): query: str, context: Optional[List[Edge]] = None, session_id: Optional[str] = None, + response_model: Type = str, ) -> List[str]: """ Generates a completion using graph connections context based on a query. @@ -188,6 +189,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, conversation_history=conversation_history, + response_model=response_model, ), ) else: @@ -197,6 +199,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, + response_model=response_model, ) if self.save_interaction and context and triplets and completion: diff --git a/cognee/modules/retrieval/temporal_retriever.py b/cognee/modules/retrieval/temporal_retriever.py index ec68d37bb..38d69ec80 100644 --- a/cognee/modules/retrieval/temporal_retriever.py +++ b/cognee/modules/retrieval/temporal_retriever.py @@ -146,7 +146,11 @@ class TemporalRetriever(GraphCompletionRetriever): return self.descriptions_to_string(top_k_events) async def get_completion( - self, query: str, context: Optional[str] = None, session_id: Optional[str] = None + self, + query: str, + context: Optional[str] = None, + session_id: Optional[str] = None, + response_model: Type = str, ) -> List[str]: """ Generates a response using the query and optional context. @@ -159,6 +163,7 @@ class TemporalRetriever(GraphCompletionRetriever): retrieved based on the query. (default None) - session_id (Optional[str]): Optional session identifier for caching. If None, defaults to 'default_session'. (default None) + - response_model (Type): The Pydantic model type for structured output. (default str) Returns: -------- @@ -186,6 +191,7 @@ class TemporalRetriever(GraphCompletionRetriever): user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, conversation_history=conversation_history, + response_model=response_model, ), ) else: @@ -194,6 +200,7 @@ class TemporalRetriever(GraphCompletionRetriever): context=context, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, + response_model=response_model, ) if session_save: diff --git a/cognee/modules/retrieval/utils/completion.py b/cognee/modules/retrieval/utils/completion.py index db7a10252..b77d7ef90 100644 --- a/cognee/modules/retrieval/utils/completion.py +++ b/cognee/modules/retrieval/utils/completion.py @@ -3,7 +3,7 @@ from cognee.infrastructure.llm.LLMGateway import LLMGateway from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt -async def generate_structured_completion( +async def generate_completion( query: str, context: str, user_prompt_path: str, @@ -11,8 +11,8 @@ async def generate_structured_completion( system_prompt: Optional[str] = None, conversation_history: Optional[str] = None, response_model: Type = str, -) -> Any: - """Generates a structured completion using LLM with given context and prompts.""" +) -> str: + """Generates a completion using LLM with given context and prompts.""" args = {"question": query, "context": context} user_prompt = render_prompt(user_prompt_path, args) system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path) @@ -28,26 +28,6 @@ async def generate_structured_completion( ) -async def generate_completion( - query: str, - context: str, - user_prompt_path: str, - system_prompt_path: str, - system_prompt: Optional[str] = None, - conversation_history: Optional[str] = None, -) -> str: - """Generates a completion using LLM with given context and prompts.""" - return await generate_structured_completion( - query=query, - context=context, - user_prompt_path=user_prompt_path, - system_prompt_path=system_prompt_path, - system_prompt=system_prompt, - conversation_history=conversation_history, - response_model=str, - ) - - async def summarize_text( text: str, system_prompt_path: str = "summarize_search_results.txt", diff --git a/cognee/tasks/feedback/generate_improved_answers.py b/cognee/tasks/feedback/generate_improved_answers.py index e439cf9e5..d2b143d29 100644 --- a/cognee/tasks/feedback/generate_improved_answers.py +++ b/cognee/tasks/feedback/generate_improved_answers.py @@ -61,7 +61,7 @@ async def _generate_improved_answer_for_single_interaction( ) retrieved_context = await retriever.get_context(query_text) - completion = await retriever.get_structured_completion( + completion = await retriever.get_completion( query=query_text, context=retrieved_context, response_model=ImprovedAnswerResponse, @@ -70,9 +70,9 @@ async def _generate_improved_answer_for_single_interaction( new_context_text = await retriever.resolve_edges_to_text(retrieved_context) if completion: - enrichment.improved_answer = completion.answer + enrichment.improved_answer = completion[0].answer enrichment.new_context = new_context_text - enrichment.explanation = completion.explanation + enrichment.explanation = completion[0].explanation return enrichment else: logger.warning( diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index 7fcfe0d6b..bf10dc023 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -206,16 +206,22 @@ class TestGraphCompletionCoTRetriever: retriever = GraphCompletionCotRetriever() # Test with string response model (default) - string_answer = await retriever.get_structured_completion("Who works at Figma?") - assert isinstance(string_answer, str), f"Expected str, got {type(string_answer).__name__}" - assert string_answer.strip(), "Answer should not be empty" + string_answer = await retriever.get_completion("Who works at Figma?") + assert isinstance(string_answer, list), f"Expected str, got {type(string_answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in string_answer), ( + "Answer should not be empty" + ) # Test with structured response model - structured_answer = await retriever.get_structured_completion( + structured_answer = await retriever.get_completion( "Who works at Figma?", response_model=TestAnswer ) - assert isinstance(structured_answer, TestAnswer), ( + assert isinstance(structured_answer, list), ( + f"Expected list, got {type(structured_answer).__name__}" + ) + assert all(isinstance(item, TestAnswer) for item in string_answer), ( f"Expected TestAnswer, got {type(structured_answer).__name__}" ) - assert structured_answer.answer.strip(), "Answer field should not be empty" - assert structured_answer.explanation.strip(), "Explanation field should not be empty" + + assert structured_answer[0].answer.strip(), "Answer field should not be empty" + assert structured_answer[0].explanation.strip(), "Explanation field should not be empty" From 33b05163811c7f7f74db39abfea711103b823132 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Tue, 4 Nov 2025 15:27:03 +0100 Subject: [PATCH 2/6] test: fix completion tests --- ...letion_retriever_context_extension_test.py | 59 +++++++++++++++++++ .../graph_completion_retriever_cot_test.py | 2 +- 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py index 0e21fe351..5335a3ca7 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -2,6 +2,7 @@ import os import pytest import pathlib from typing import Optional, Union +from pydantic import BaseModel import cognee from cognee.low_level import setup, DataPoint @@ -12,6 +13,11 @@ from cognee.modules.retrieval.graph_completion_context_extension_retriever impor ) +class TestAnswer(BaseModel): + answer: str + explanation: str + + class TestGraphCompletionWithContextExtensionRetriever: @pytest.mark.asyncio async def test_graph_completion_extension_context_simple(self): @@ -175,3 +181,56 @@ class TestGraphCompletionWithContextExtensionRetriever: assert all(isinstance(item, str) and item.strip() for item in answer), ( "Answer must contain only non-empty strings" ) + + @pytest.mark.asyncio + async def test_get_structured_completion_extension_context(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, + ".cognee_system/test_get_structured_completion_extension_context", + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, + ".data_storage/test_get_structured_completion_extension_context", + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + + class Person(DataPoint): + name: str + works_for: Company + + company1 = Company(name="Figma") + person1 = Person(name="Steve Rodger", works_for=company1) + + entities = [company1, person1] + await add_data_points(entities) + + retriever = GraphCompletionContextExtensionRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("Who works at Figma?") + assert isinstance(string_answer, list), f"Expected str, got {type(string_answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in string_answer), ( + "Answer should not be empty" + ) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Who works at Figma?", response_model=TestAnswer + ) + assert isinstance(structured_answer, list), ( + f"Expected list, got {type(structured_answer).__name__}" + ) + assert all(isinstance(item, TestAnswer) for item in structured_answer), ( + f"Expected TestAnswer, got {type(structured_answer).__name__}" + ) + + assert structured_answer[0].answer.strip(), "Answer field should not be empty" + assert structured_answer[0].explanation.strip(), "Explanation field should not be empty" diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index bf10dc023..731e9fccf 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -219,7 +219,7 @@ class TestGraphCompletionCoTRetriever: assert isinstance(structured_answer, list), ( f"Expected list, got {type(structured_answer).__name__}" ) - assert all(isinstance(item, TestAnswer) for item in string_answer), ( + assert all(isinstance(item, TestAnswer) for item in structured_answer), ( f"Expected TestAnswer, got {type(structured_answer).__name__}" ) From 215ef7f3c213ea0be6b0a295be8069ba0364878d Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Wed, 5 Nov 2025 17:29:40 +0100 Subject: [PATCH 3/6] test: add retriever tests --- .../entity_completion_retriever_test.py | 65 +++++++++++++++ ...letion_retriever_context_extension_test.py | 6 +- .../graph_completion_retriever_cot_test.py | 6 +- .../graph_completion_retriever_test.py | 57 +++++++++++++ .../rag_completion_retriever_test.py | 79 +++++++++++++++++++ .../retrieval/temporal_retriever_test.py | 64 +++++++++++++++ 6 files changed, 271 insertions(+), 6 deletions(-) create mode 100644 cognee/tests/unit/modules/retrieval/entity_completion_retriever_test.py diff --git a/cognee/tests/unit/modules/retrieval/entity_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/entity_completion_retriever_test.py new file mode 100644 index 000000000..064f4a31a --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/entity_completion_retriever_test.py @@ -0,0 +1,65 @@ +import os +import pytest +import pathlib +from pydantic import BaseModel + +import cognee +from cognee.low_level import setup +from cognee.tasks.storage import add_data_points +from cognee.modules.engine.models import Entity, EntityType +from cognee.modules.retrieval.EntityCompletionRetriever import EntityCompletionRetriever +from cognee.modules.retrieval.entity_extractors.DummyEntityExtractor import DummyEntityExtractor +from cognee.modules.retrieval.context_providers.DummyContextProvider import DummyContextProvider + + +class TestAnswer(BaseModel): + answer: str + explanation: str + + +# TODO: Add more tests, similar to other retrievers. +# TODO: For the tests, one needs to define an Entity Extractor and a Context Provider. +class TestEntityCompletionRetriever: + @pytest.mark.asyncio + async def test_get_entity_structured_completion(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_get_entity_structured_completion" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_get_entity_structured_completion" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + entity_type = EntityType(name="Person", description="A human individual") + entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist") + + entities = [entity] + await add_data_points(entities) + + retriever = EntityCompletionRetriever(DummyEntityExtractor(), DummyContextProvider()) + + # Test with string response model (default) + string_answer = await retriever.get_completion("Who is Albert Einstein?") + assert isinstance(string_answer, list), f"Expected str, got {type(string_answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in string_answer), ( + "Answer should not be empty" + ) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Who is Albert Einstein?", response_model=TestAnswer + ) + assert isinstance(structured_answer, list), ( + f"Expected list, got {type(structured_answer).__name__}" + ) + assert all(isinstance(item, TestAnswer) for item in structured_answer), ( + f"Expected TestAnswer, got {type(structured_answer).__name__}" + ) + + assert structured_answer[0].answer.strip(), "Answer field should not be empty" + assert structured_answer[0].explanation.strip(), "Explanation field should not be empty" diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py index 5335a3ca7..d15e55c23 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -183,15 +183,15 @@ class TestGraphCompletionWithContextExtensionRetriever: ) @pytest.mark.asyncio - async def test_get_structured_completion_extension_context(self): + async def test_get_graph_structured_completion_extension_context(self): system_directory_path = os.path.join( pathlib.Path(__file__).parent, - ".cognee_system/test_get_structured_completion_extension_context", + ".cognee_system/test_get_graph_structured_completion_extension_context", ) cognee.config.system_root_directory(system_directory_path) data_directory_path = os.path.join( pathlib.Path(__file__).parent, - ".data_storage/test_get_structured_completion_extension_context", + ".data_storage/test_get_graph_structured_completion_extension_context", ) cognee.config.data_root_directory(data_directory_path) diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index 731e9fccf..79e4bcec3 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -176,13 +176,13 @@ class TestGraphCompletionCoTRetriever: ) @pytest.mark.asyncio - async def test_get_structured_completion(self): + async def test_get_graph_structured_completion_cot(self): system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion" + pathlib.Path(__file__).parent, ".cognee_system/test_get_graph_structured_completion_cot" ) cognee.config.system_root_directory(system_directory_path) data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion" + pathlib.Path(__file__).parent, ".data_storage/test_get_graph_structured_completion_cot" ) cognee.config.data_root_directory(data_directory_path) diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py index f462baced..e320fcef1 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py @@ -2,6 +2,7 @@ import os import pytest import pathlib from typing import Optional, Union +from pydantic import BaseModel import cognee from cognee.low_level import setup, DataPoint @@ -10,6 +11,11 @@ from cognee.tasks.storage import add_data_points from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever +class TestAnswer(BaseModel): + answer: str + explanation: str + + class TestGraphCompletionRetriever: @pytest.mark.asyncio async def test_graph_completion_context_simple(self): @@ -221,3 +227,54 @@ class TestGraphCompletionRetriever: context = await retriever.get_context("Who works at Figma?") assert context == [], "Context should be empty on an empty graph" + + @pytest.mark.asyncio + async def test_get_graph_structured_completion(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_get_graph_structured_completion" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_get_graph_structured_completion" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + + class Person(DataPoint): + name: str + works_for: Company + + company1 = Company(name="Figma") + person1 = Person(name="Steve Rodger", works_for=company1) + + entities = [company1, person1] + await add_data_points(entities) + + retriever = GraphCompletionRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("Who works at Figma?") + assert isinstance(string_answer, list), f"Expected str, got {type(string_answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in string_answer), ( + "Answer should not be empty" + ) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Who works at Figma?", response_model=TestAnswer + ) + assert isinstance(structured_answer, list), ( + f"Expected list, got {type(structured_answer).__name__}" + ) + assert all(isinstance(item, TestAnswer) for item in structured_answer), ( + f"Expected TestAnswer, got {type(structured_answer).__name__}" + ) + + assert structured_answer[0].answer.strip(), "Answer field should not be empty" + assert structured_answer[0].explanation.strip(), "Explanation field should not be empty" diff --git a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py index 252af8352..248ecc047 100644 --- a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py @@ -3,6 +3,7 @@ from typing import List import pytest import pathlib import cognee +from pydantic import BaseModel from cognee.low_level import setup from cognee.tasks.storage import add_data_points from cognee.infrastructure.databases.vector import get_vector_engine @@ -26,6 +27,11 @@ class DocumentChunkWithEntities(DataPoint): metadata: dict = {"index_fields": ["text"]} +class TestAnswer(BaseModel): + answer: str + explanation: str + + class TestRAGCompletionRetriever: @pytest.mark.asyncio async def test_rag_completion_context_simple(self): @@ -202,3 +208,76 @@ class TestRAGCompletionRetriever: context = await retriever.get_context("Christina Mayer") assert context == "", "Returned context should be empty on an empty graph" + + @pytest.mark.asyncio + async def test_get_rag_structured_completion(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_get_rag_structured_completion" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_get_rag_structured_completion" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + document = TextDocument( + name="Steve Rodger's career", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + + entities = [chunk1, chunk2, chunk3] + await add_data_points(entities) + + retriever = CompletionRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("Where does Steve work?") + assert isinstance(string_answer, list), f"Expected str, got {type(string_answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in string_answer), ( + "Answer should not be empty" + ) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Where does Steve work?", response_model=TestAnswer + ) + assert isinstance(structured_answer, list), ( + f"Expected list, got {type(structured_answer).__name__}" + ) + assert all(isinstance(item, TestAnswer) for item in structured_answer), ( + f"Expected TestAnswer, got {type(structured_answer).__name__}" + ) + + assert structured_answer[0].answer.strip(), "Answer field should not be empty" + assert structured_answer[0].explanation.strip(), "Explanation field should not be empty" diff --git a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py index a322cb237..5b274c822 100644 --- a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py @@ -1,7 +1,13 @@ import asyncio +import os +import pathlib +import cognee from types import SimpleNamespace import pytest +from pydantic import BaseModel +from cognee.low_level import setup, DataPoint +from cognee.tasks.storage import add_data_points from cognee.modules.retrieval.temporal_retriever import TemporalRetriever @@ -141,6 +147,64 @@ async def test_filter_top_k_events_error_handling(): await tr.filter_top_k_events([{}], []) +class TestAnswer(BaseModel): + answer: str + explanation: str + + +@pytest.mark.asyncio +async def test_get_temporal_structured_completion(): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_get_temporal_structured_completion" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_get_temporal_structured_completion" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + + class Person(DataPoint): + name: str + works_for: Company + works_since: int + + company1 = Company(name="Figma") + person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015) + + entities = [company1, person1] + await add_data_points(entities) + + retriever = TemporalRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("When did Steve start working at Figma?") + assert isinstance(string_answer, list), f"Expected str, got {type(string_answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in string_answer), ( + "Answer should not be empty" + ) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "When did Steve start working at Figma??", response_model=TestAnswer + ) + assert isinstance(structured_answer, list), ( + f"Expected list, got {type(structured_answer).__name__}" + ) + assert all(isinstance(item, TestAnswer) for item in structured_answer), ( + f"Expected TestAnswer, got {type(structured_answer).__name__}" + ) + + assert structured_answer[0].answer.strip(), "Answer field should not be empty" + assert structured_answer[0].explanation.strip(), "Explanation field should not be empty" + + class _FakeRetriever(TemporalRetriever): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) From da5055a0a96ce9647fef7245ab312301e4237165 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Thu, 6 Nov 2025 17:11:15 +0100 Subject: [PATCH 4/6] test: add one test that covers all retrievers. delete others --- .../retrieval/EntityCompletionRetriever.py | 2 +- .../modules/retrieval/completion_retriever.py | 4 +- ..._completion_context_extension_retriever.py | 4 +- .../graph_completion_cot_retriever.py | 2 +- .../retrieval/graph_completion_retriever.py | 2 +- .../modules/retrieval/temporal_retriever.py | 2 +- cognee/modules/retrieval/utils/completion.py | 2 +- .../entity_completion_retriever_test.py | 65 ------ ...letion_retriever_context_extension_test.py | 60 ----- .../graph_completion_retriever_cot_test.py | 58 ----- .../graph_completion_retriever_test.py | 58 ----- .../rag_completion_retriever_test.py | 81 +------ .../retrieval/structured_output_tests.py | 206 ++++++++++++++++++ .../retrieval/temporal_retriever_test.py | 66 ------ 14 files changed, 216 insertions(+), 396 deletions(-) delete mode 100644 cognee/tests/unit/modules/retrieval/entity_completion_retriever_test.py create mode 100644 cognee/tests/unit/modules/retrieval/structured_output_tests.py diff --git a/cognee/modules/retrieval/EntityCompletionRetriever.py b/cognee/modules/retrieval/EntityCompletionRetriever.py index 1f1ddad0a..14996f902 100644 --- a/cognee/modules/retrieval/EntityCompletionRetriever.py +++ b/cognee/modules/retrieval/EntityCompletionRetriever.py @@ -90,7 +90,7 @@ class EntityCompletionRetriever(BaseRetriever): context: Optional[Any] = None, session_id: Optional[str] = None, response_model: Type = str, - ) -> List[str]: + ) -> List[Any]: """ Generate completion using provided context or fetch new context. diff --git a/cognee/modules/retrieval/completion_retriever.py b/cognee/modules/retrieval/completion_retriever.py index f071e41de..126ebcab8 100644 --- a/cognee/modules/retrieval/completion_retriever.py +++ b/cognee/modules/retrieval/completion_retriever.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Optional, Type +from typing import Any, Optional, Type, List from cognee.shared.logging_utils import get_logger from cognee.infrastructure.databases.vector import get_vector_engine @@ -80,7 +80,7 @@ class CompletionRetriever(BaseRetriever): context: Optional[Any] = None, session_id: Optional[str] = None, response_model: Type = str, - ) -> str: + ) -> List[Any]: """ Generates an LLM completion using the context. diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index 6b2c6a9e6..b07d11fd2 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -1,5 +1,5 @@ import asyncio -from typing import Optional, List, Type +from typing import Optional, List, Type, Any from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge from cognee.shared.logging_utils import get_logger from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever @@ -57,7 +57,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): session_id: Optional[str] = None, context_extension_rounds=4, response_model: Type = str, - ) -> List[str]: + ) -> List[Any]: """ Extends the context for a given query by retrieving related triplets and generating new completions based on them. diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 39255fe68..eb8f502cb 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -171,7 +171,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): session_id: Optional[str] = None, max_iter=4, response_model: Type = str, - ) -> List[str]: + ) -> List[Any]: """ Generate completion responses based on a user query and contextual information. diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index b544e8ead..df77a11ac 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -147,7 +147,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): context: Optional[List[Edge]] = None, session_id: Optional[str] = None, response_model: Type = str, - ) -> List[str]: + ) -> List[Any]: """ Generates a completion using graph connections context based on a query. diff --git a/cognee/modules/retrieval/temporal_retriever.py b/cognee/modules/retrieval/temporal_retriever.py index 38d69ec80..f3da02c15 100644 --- a/cognee/modules/retrieval/temporal_retriever.py +++ b/cognee/modules/retrieval/temporal_retriever.py @@ -151,7 +151,7 @@ class TemporalRetriever(GraphCompletionRetriever): context: Optional[str] = None, session_id: Optional[str] = None, response_model: Type = str, - ) -> List[str]: + ) -> List[Any]: """ Generates a response using the query and optional context. diff --git a/cognee/modules/retrieval/utils/completion.py b/cognee/modules/retrieval/utils/completion.py index b77d7ef90..c90ce77f4 100644 --- a/cognee/modules/retrieval/utils/completion.py +++ b/cognee/modules/retrieval/utils/completion.py @@ -11,7 +11,7 @@ async def generate_completion( system_prompt: Optional[str] = None, conversation_history: Optional[str] = None, response_model: Type = str, -) -> str: +) -> Any: """Generates a completion using LLM with given context and prompts.""" args = {"question": query, "context": context} user_prompt = render_prompt(user_prompt_path, args) diff --git a/cognee/tests/unit/modules/retrieval/entity_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/entity_completion_retriever_test.py deleted file mode 100644 index 064f4a31a..000000000 --- a/cognee/tests/unit/modules/retrieval/entity_completion_retriever_test.py +++ /dev/null @@ -1,65 +0,0 @@ -import os -import pytest -import pathlib -from pydantic import BaseModel - -import cognee -from cognee.low_level import setup -from cognee.tasks.storage import add_data_points -from cognee.modules.engine.models import Entity, EntityType -from cognee.modules.retrieval.EntityCompletionRetriever import EntityCompletionRetriever -from cognee.modules.retrieval.entity_extractors.DummyEntityExtractor import DummyEntityExtractor -from cognee.modules.retrieval.context_providers.DummyContextProvider import DummyContextProvider - - -class TestAnswer(BaseModel): - answer: str - explanation: str - - -# TODO: Add more tests, similar to other retrievers. -# TODO: For the tests, one needs to define an Entity Extractor and a Context Provider. -class TestEntityCompletionRetriever: - @pytest.mark.asyncio - async def test_get_entity_structured_completion(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_get_entity_structured_completion" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_get_entity_structured_completion" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - entity_type = EntityType(name="Person", description="A human individual") - entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist") - - entities = [entity] - await add_data_points(entities) - - retriever = EntityCompletionRetriever(DummyEntityExtractor(), DummyContextProvider()) - - # Test with string response model (default) - string_answer = await retriever.get_completion("Who is Albert Einstein?") - assert isinstance(string_answer, list), f"Expected str, got {type(string_answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in string_answer), ( - "Answer should not be empty" - ) - - # Test with structured response model - structured_answer = await retriever.get_completion( - "Who is Albert Einstein?", response_model=TestAnswer - ) - assert isinstance(structured_answer, list), ( - f"Expected list, got {type(structured_answer).__name__}" - ) - assert all(isinstance(item, TestAnswer) for item in structured_answer), ( - f"Expected TestAnswer, got {type(structured_answer).__name__}" - ) - - assert structured_answer[0].answer.strip(), "Answer field should not be empty" - assert structured_answer[0].explanation.strip(), "Explanation field should not be empty" diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py index d15e55c23..29c8b7c95 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -2,7 +2,6 @@ import os import pytest import pathlib from typing import Optional, Union -from pydantic import BaseModel import cognee from cognee.low_level import setup, DataPoint @@ -12,12 +11,6 @@ from cognee.modules.retrieval.graph_completion_context_extension_retriever impor GraphCompletionContextExtensionRetriever, ) - -class TestAnswer(BaseModel): - answer: str - explanation: str - - class TestGraphCompletionWithContextExtensionRetriever: @pytest.mark.asyncio async def test_graph_completion_extension_context_simple(self): @@ -181,56 +174,3 @@ class TestGraphCompletionWithContextExtensionRetriever: assert all(isinstance(item, str) and item.strip() for item in answer), ( "Answer must contain only non-empty strings" ) - - @pytest.mark.asyncio - async def test_get_graph_structured_completion_extension_context(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_get_graph_structured_completion_extension_context", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_get_graph_structured_completion_extension_context", - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - class Company(DataPoint): - name: str - - class Person(DataPoint): - name: str - works_for: Company - - company1 = Company(name="Figma") - person1 = Person(name="Steve Rodger", works_for=company1) - - entities = [company1, person1] - await add_data_points(entities) - - retriever = GraphCompletionContextExtensionRetriever() - - # Test with string response model (default) - string_answer = await retriever.get_completion("Who works at Figma?") - assert isinstance(string_answer, list), f"Expected str, got {type(string_answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in string_answer), ( - "Answer should not be empty" - ) - - # Test with structured response model - structured_answer = await retriever.get_completion( - "Who works at Figma?", response_model=TestAnswer - ) - assert isinstance(structured_answer, list), ( - f"Expected list, got {type(structured_answer).__name__}" - ) - assert all(isinstance(item, TestAnswer) for item in structured_answer), ( - f"Expected TestAnswer, got {type(structured_answer).__name__}" - ) - - assert structured_answer[0].answer.strip(), "Answer field should not be empty" - assert structured_answer[0].explanation.strip(), "Explanation field should not be empty" diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index 79e4bcec3..ac58793be 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -2,7 +2,6 @@ import os import pytest import pathlib from typing import Optional, Union -from pydantic import BaseModel import cognee from cognee.low_level import setup, DataPoint @@ -10,12 +9,6 @@ from cognee.modules.graph.utils import resolve_edges_to_text from cognee.tasks.storage import add_data_points from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever - -class TestAnswer(BaseModel): - answer: str - explanation: str - - class TestGraphCompletionCoTRetriever: @pytest.mark.asyncio async def test_graph_completion_cot_context_simple(self): @@ -174,54 +167,3 @@ class TestGraphCompletionCoTRetriever: assert all(isinstance(item, str) and item.strip() for item in answer), ( "Answer must contain only non-empty strings" ) - - @pytest.mark.asyncio - async def test_get_graph_structured_completion_cot(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_get_graph_structured_completion_cot" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_get_graph_structured_completion_cot" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - class Company(DataPoint): - name: str - - class Person(DataPoint): - name: str - works_for: Company - - company1 = Company(name="Figma") - person1 = Person(name="Steve Rodger", works_for=company1) - - entities = [company1, person1] - await add_data_points(entities) - - retriever = GraphCompletionCotRetriever() - - # Test with string response model (default) - string_answer = await retriever.get_completion("Who works at Figma?") - assert isinstance(string_answer, list), f"Expected str, got {type(string_answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in string_answer), ( - "Answer should not be empty" - ) - - # Test with structured response model - structured_answer = await retriever.get_completion( - "Who works at Figma?", response_model=TestAnswer - ) - assert isinstance(structured_answer, list), ( - f"Expected list, got {type(structured_answer).__name__}" - ) - assert all(isinstance(item, TestAnswer) for item in structured_answer), ( - f"Expected TestAnswer, got {type(structured_answer).__name__}" - ) - - assert structured_answer[0].answer.strip(), "Answer field should not be empty" - assert structured_answer[0].explanation.strip(), "Explanation field should not be empty" diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py index e320fcef1..21e2af199 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py @@ -2,7 +2,6 @@ import os import pytest import pathlib from typing import Optional, Union -from pydantic import BaseModel import cognee from cognee.low_level import setup, DataPoint @@ -10,12 +9,6 @@ from cognee.modules.graph.utils import resolve_edges_to_text from cognee.tasks.storage import add_data_points from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever - -class TestAnswer(BaseModel): - answer: str - explanation: str - - class TestGraphCompletionRetriever: @pytest.mark.asyncio async def test_graph_completion_context_simple(self): @@ -227,54 +220,3 @@ class TestGraphCompletionRetriever: context = await retriever.get_context("Who works at Figma?") assert context == [], "Context should be empty on an empty graph" - - @pytest.mark.asyncio - async def test_get_graph_structured_completion(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_get_graph_structured_completion" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_get_graph_structured_completion" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - class Company(DataPoint): - name: str - - class Person(DataPoint): - name: str - works_for: Company - - company1 = Company(name="Figma") - person1 = Person(name="Steve Rodger", works_for=company1) - - entities = [company1, person1] - await add_data_points(entities) - - retriever = GraphCompletionRetriever() - - # Test with string response model (default) - string_answer = await retriever.get_completion("Who works at Figma?") - assert isinstance(string_answer, list), f"Expected str, got {type(string_answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in string_answer), ( - "Answer should not be empty" - ) - - # Test with structured response model - structured_answer = await retriever.get_completion( - "Who works at Figma?", response_model=TestAnswer - ) - assert isinstance(structured_answer, list), ( - f"Expected list, got {type(structured_answer).__name__}" - ) - assert all(isinstance(item, TestAnswer) for item in structured_answer), ( - f"Expected TestAnswer, got {type(structured_answer).__name__}" - ) - - assert structured_answer[0].answer.strip(), "Answer field should not be empty" - assert structured_answer[0].explanation.strip(), "Explanation field should not be empty" diff --git a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py index 248ecc047..37876794f 100644 --- a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py @@ -3,7 +3,7 @@ from typing import List import pytest import pathlib import cognee -from pydantic import BaseModel + from cognee.low_level import setup from cognee.tasks.storage import add_data_points from cognee.infrastructure.databases.vector import get_vector_engine @@ -26,12 +26,6 @@ class DocumentChunkWithEntities(DataPoint): metadata: dict = {"index_fields": ["text"]} - -class TestAnswer(BaseModel): - answer: str - explanation: str - - class TestRAGCompletionRetriever: @pytest.mark.asyncio async def test_rag_completion_context_simple(self): @@ -208,76 +202,3 @@ class TestRAGCompletionRetriever: context = await retriever.get_context("Christina Mayer") assert context == "", "Returned context should be empty on an empty graph" - - @pytest.mark.asyncio - async def test_get_rag_structured_completion(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_get_rag_structured_completion" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_get_rag_structured_completion" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - document = TextDocument( - name="Steve Rodger's career", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - - entities = [chunk1, chunk2, chunk3] - await add_data_points(entities) - - retriever = CompletionRetriever() - - # Test with string response model (default) - string_answer = await retriever.get_completion("Where does Steve work?") - assert isinstance(string_answer, list), f"Expected str, got {type(string_answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in string_answer), ( - "Answer should not be empty" - ) - - # Test with structured response model - structured_answer = await retriever.get_completion( - "Where does Steve work?", response_model=TestAnswer - ) - assert isinstance(structured_answer, list), ( - f"Expected list, got {type(structured_answer).__name__}" - ) - assert all(isinstance(item, TestAnswer) for item in structured_answer), ( - f"Expected TestAnswer, got {type(structured_answer).__name__}" - ) - - assert structured_answer[0].answer.strip(), "Answer field should not be empty" - assert structured_answer[0].explanation.strip(), "Explanation field should not be empty" diff --git a/cognee/tests/unit/modules/retrieval/structured_output_tests.py b/cognee/tests/unit/modules/retrieval/structured_output_tests.py new file mode 100644 index 000000000..95b4b9c20 --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/structured_output_tests.py @@ -0,0 +1,206 @@ +import asyncio + +import pytest +import cognee +import pathlib +import os + +from pydantic import BaseModel +from cognee.low_level import setup, DataPoint +from cognee.tasks.storage import add_data_points +from cognee.modules.chunking.models import DocumentChunk +from cognee.modules.data.processing.document_types import TextDocument +from cognee.modules.engine.models import Entity, EntityType +from cognee.modules.retrieval.entity_extractors.DummyEntityExtractor import DummyEntityExtractor +from cognee.modules.retrieval.context_providers.DummyContextProvider import DummyContextProvider +from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever +from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever +from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( + GraphCompletionContextExtensionRetriever, +) +from cognee.modules.retrieval.EntityCompletionRetriever import EntityCompletionRetriever +from cognee.modules.retrieval.temporal_retriever import TemporalRetriever +from cognee.modules.retrieval.completion_retriever import CompletionRetriever + + +class TestAnswer(BaseModel): + answer: str + explanation: str + + +def _assert_string_answer(answer: list[str]): + assert isinstance(answer, list), f"Expected str, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), "Items should be strings" + assert all(item.strip() for item in answer), "Items should not be empty" + + +def _assert_structured_answer(answer: list[TestAnswer]): + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(x, TestAnswer) for x in answer), "Items should be TestAnswer" + assert all(x.answer.strip() for x in answer), "Answer text should not be empty" + assert all(x.explanation.strip() for x in answer), "Explanation should not be empty" + + +async def _test_get_structured_graph_completion_cot(): + retriever = GraphCompletionCotRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("Who works at Figma?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Who works at Figma?", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +async def _test_get_structured_graph_completion(): + retriever = GraphCompletionRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("Who works at Figma?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Who works at Figma?", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +async def _test_get_structured_graph_completion_temporal(): + retriever = TemporalRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("When did Steve start working at Figma?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "When did Steve start working at Figma??", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +async def _test_get_structured_graph_completion_rag(): + retriever = CompletionRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("Where does Steve work?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Where does Steve work?", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +async def _test_get_structured_graph_completion_context_extension(): + retriever = GraphCompletionContextExtensionRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("Who works at Figma?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Who works at Figma?", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +async def _test_get_structured_entity_completion(): + retriever = EntityCompletionRetriever(DummyEntityExtractor(), DummyContextProvider()) + + # Test with string response model (default) + string_answer = await retriever.get_completion("Who is Albert Einstein?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Who is Albert Einstein?", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +class TestStructuredOutputCompletion: + @pytest.mark.asyncio + async def test_get_structured_completion(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + + class Person(DataPoint): + name: str + works_for: Company + works_since: int + + company1 = Company(name="Figma") + person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015) + + entities = [company1, person1] + await add_data_points(entities) + + document = TextDocument( + name="Steve Rodger's career", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + + entities = [chunk1, chunk2, chunk3] + await add_data_points(entities) + + entity_type = EntityType(name="Person", description="A human individual") + entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist") + + entities = [entity] + await add_data_points(entities) + + await asyncio.gather( + _test_get_structured_graph_completion_cot(), + _test_get_structured_graph_completion(), + _test_get_structured_graph_completion_temporal(), + _test_get_structured_graph_completion_rag(), + _test_get_structured_graph_completion_context_extension(), + _test_get_structured_entity_completion(), + ) diff --git a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py index 5b274c822..22b2d3fe9 100644 --- a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py @@ -1,13 +1,6 @@ -import asyncio -import os -import pathlib -import cognee from types import SimpleNamespace import pytest -from pydantic import BaseModel -from cognee.low_level import setup, DataPoint -from cognee.tasks.storage import add_data_points from cognee.modules.retrieval.temporal_retriever import TemporalRetriever @@ -146,65 +139,6 @@ async def test_filter_top_k_events_error_handling(): with pytest.raises((KeyError, TypeError)): await tr.filter_top_k_events([{}], []) - -class TestAnswer(BaseModel): - answer: str - explanation: str - - -@pytest.mark.asyncio -async def test_get_temporal_structured_completion(): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_get_temporal_structured_completion" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_get_temporal_structured_completion" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - class Company(DataPoint): - name: str - - class Person(DataPoint): - name: str - works_for: Company - works_since: int - - company1 = Company(name="Figma") - person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015) - - entities = [company1, person1] - await add_data_points(entities) - - retriever = TemporalRetriever() - - # Test with string response model (default) - string_answer = await retriever.get_completion("When did Steve start working at Figma?") - assert isinstance(string_answer, list), f"Expected str, got {type(string_answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in string_answer), ( - "Answer should not be empty" - ) - - # Test with structured response model - structured_answer = await retriever.get_completion( - "When did Steve start working at Figma??", response_model=TestAnswer - ) - assert isinstance(structured_answer, list), ( - f"Expected list, got {type(structured_answer).__name__}" - ) - assert all(isinstance(item, TestAnswer) for item in structured_answer), ( - f"Expected TestAnswer, got {type(structured_answer).__name__}" - ) - - assert structured_answer[0].answer.strip(), "Answer field should not be empty" - assert structured_answer[0].explanation.strip(), "Explanation field should not be empty" - - class _FakeRetriever(TemporalRetriever): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) From 72ba8d0dcb0306cfc5c618c854089bd68f4b9d3f Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Thu, 6 Nov 2025 17:12:33 +0100 Subject: [PATCH 5/6] chore: ruff format --- .../graph_completion_retriever_context_extension_test.py | 1 + .../modules/retrieval/graph_completion_retriever_cot_test.py | 1 + .../unit/modules/retrieval/graph_completion_retriever_test.py | 1 + .../unit/modules/retrieval/rag_completion_retriever_test.py | 1 + cognee/tests/unit/modules/retrieval/temporal_retriever_test.py | 1 + 5 files changed, 5 insertions(+) diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py index 29c8b7c95..0e21fe351 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -11,6 +11,7 @@ from cognee.modules.retrieval.graph_completion_context_extension_retriever impor GraphCompletionContextExtensionRetriever, ) + class TestGraphCompletionWithContextExtensionRetriever: @pytest.mark.asyncio async def test_graph_completion_extension_context_simple(self): diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index ac58793be..206cfaf84 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -9,6 +9,7 @@ from cognee.modules.graph.utils import resolve_edges_to_text from cognee.tasks.storage import add_data_points from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever + class TestGraphCompletionCoTRetriever: @pytest.mark.asyncio async def test_graph_completion_cot_context_simple(self): diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py index 21e2af199..f462baced 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py @@ -9,6 +9,7 @@ from cognee.modules.graph.utils import resolve_edges_to_text from cognee.tasks.storage import add_data_points from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever + class TestGraphCompletionRetriever: @pytest.mark.asyncio async def test_graph_completion_context_simple(self): diff --git a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py index 37876794f..9bfed68f3 100644 --- a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py @@ -26,6 +26,7 @@ class DocumentChunkWithEntities(DataPoint): metadata: dict = {"index_fields": ["text"]} + class TestRAGCompletionRetriever: @pytest.mark.asyncio async def test_rag_completion_context_simple(self): diff --git a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py index 22b2d3fe9..c3c6a47f6 100644 --- a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py @@ -139,6 +139,7 @@ async def test_filter_top_k_events_error_handling(): with pytest.raises((KeyError, TypeError)): await tr.filter_top_k_events([{}], []) + class _FakeRetriever(TemporalRetriever): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) From 4ab53c9d64a1cde20c6b38e78eb2583bb43fbf65 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Fri, 7 Nov 2025 10:00:17 +0100 Subject: [PATCH 6/6] changes based on PR comments --- cognee/modules/retrieval/base_graph_retriever.py | 10 +++++++--- cognee/modules/retrieval/base_retriever.py | 10 +++++++--- ...d_output_tests.py => structured_output_test.py} | 14 ++++++-------- .../modules/retrieval/summaries_retriever_test.py | 2 +- 4 files changed, 21 insertions(+), 15 deletions(-) rename cognee/tests/unit/modules/retrieval/{structured_output_tests.py => structured_output_test.py} (94%) diff --git a/cognee/modules/retrieval/base_graph_retriever.py b/cognee/modules/retrieval/base_graph_retriever.py index b0abc2991..b203309ba 100644 --- a/cognee/modules/retrieval/base_graph_retriever.py +++ b/cognee/modules/retrieval/base_graph_retriever.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Any, List, Optional, Type from abc import ABC, abstractmethod from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge @@ -14,7 +14,11 @@ class BaseGraphRetriever(ABC): @abstractmethod async def get_completion( - self, query: str, context: Optional[List[Edge]] = None, session_id: Optional[str] = None - ) -> str: + self, + query: str, + context: Optional[List[Edge]] = None, + session_id: Optional[str] = None, + response_model: Type = str, + ) -> List[Any]: """Generates a response using the query and optional context (triplets).""" pass diff --git a/cognee/modules/retrieval/base_retriever.py b/cognee/modules/retrieval/base_retriever.py index 1533dd44f..b88c741b8 100644 --- a/cognee/modules/retrieval/base_retriever.py +++ b/cognee/modules/retrieval/base_retriever.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any, Optional, Type, List class BaseRetriever(ABC): @@ -12,7 +12,11 @@ class BaseRetriever(ABC): @abstractmethod async def get_completion( - self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None - ) -> Any: + self, + query: str, + context: Optional[Any] = None, + session_id: Optional[str] = None, + response_model: Type = str, + ) -> List[Any]: """Generates a response using the query and optional context.""" pass diff --git a/cognee/tests/unit/modules/retrieval/structured_output_tests.py b/cognee/tests/unit/modules/retrieval/structured_output_test.py similarity index 94% rename from cognee/tests/unit/modules/retrieval/structured_output_tests.py rename to cognee/tests/unit/modules/retrieval/structured_output_test.py index 95b4b9c20..4ad3019ff 100644 --- a/cognee/tests/unit/modules/retrieval/structured_output_tests.py +++ b/cognee/tests/unit/modules/retrieval/structured_output_test.py @@ -196,11 +196,9 @@ class TestStructuredOutputCompletion: entities = [entity] await add_data_points(entities) - await asyncio.gather( - _test_get_structured_graph_completion_cot(), - _test_get_structured_graph_completion(), - _test_get_structured_graph_completion_temporal(), - _test_get_structured_graph_completion_rag(), - _test_get_structured_graph_completion_context_extension(), - _test_get_structured_entity_completion(), - ) + await _test_get_structured_graph_completion_cot() + await _test_get_structured_graph_completion() + await _test_get_structured_graph_completion_temporal() + await _test_get_structured_graph_completion_rag() + await _test_get_structured_graph_completion_context_extension() + await _test_get_structured_entity_completion() diff --git a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py index fc96081bf..5f4b93425 100644 --- a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py @@ -13,7 +13,7 @@ from cognee.modules.retrieval.exceptions.exceptions import NoDataError from cognee.modules.retrieval.summaries_retriever import SummariesRetriever -class TextSummariesRetriever: +class TestSummariesRetriever: @pytest.mark.asyncio async def test_chunk_context(self): system_directory_path = os.path.join(