diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py index 0e21fe351..5335a3ca7 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -2,6 +2,7 @@ import os import pytest import pathlib from typing import Optional, Union +from pydantic import BaseModel import cognee from cognee.low_level import setup, DataPoint @@ -12,6 +13,11 @@ from cognee.modules.retrieval.graph_completion_context_extension_retriever impor ) +class TestAnswer(BaseModel): + answer: str + explanation: str + + class TestGraphCompletionWithContextExtensionRetriever: @pytest.mark.asyncio async def test_graph_completion_extension_context_simple(self): @@ -175,3 +181,56 @@ class TestGraphCompletionWithContextExtensionRetriever: assert all(isinstance(item, str) and item.strip() for item in answer), ( "Answer must contain only non-empty strings" ) + + @pytest.mark.asyncio + async def test_get_structured_completion_extension_context(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, + ".cognee_system/test_get_structured_completion_extension_context", + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, + ".data_storage/test_get_structured_completion_extension_context", + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + + class Person(DataPoint): + name: str + works_for: Company + + company1 = Company(name="Figma") + person1 = Person(name="Steve Rodger", works_for=company1) + + entities = [company1, person1] + await add_data_points(entities) + + retriever = GraphCompletionContextExtensionRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("Who works at Figma?") + assert isinstance(string_answer, list), f"Expected str, got {type(string_answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in string_answer), ( + "Answer should not be empty" + ) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Who works at Figma?", response_model=TestAnswer + ) + assert isinstance(structured_answer, list), ( + f"Expected list, got {type(structured_answer).__name__}" + ) + assert all(isinstance(item, TestAnswer) for item in structured_answer), ( + f"Expected TestAnswer, got {type(structured_answer).__name__}" + ) + + assert structured_answer[0].answer.strip(), "Answer field should not be empty" + assert structured_answer[0].explanation.strip(), "Explanation field should not be empty" diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index bf10dc023..731e9fccf 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -219,7 +219,7 @@ class TestGraphCompletionCoTRetriever: assert isinstance(structured_answer, list), ( f"Expected list, got {type(structured_answer).__name__}" ) - assert all(isinstance(item, TestAnswer) for item in string_answer), ( + assert all(isinstance(item, TestAnswer) for item in structured_answer), ( f"Expected TestAnswer, got {type(structured_answer).__name__}" )