From 4ab53c9d64a1cde20c6b38e78eb2583bb43fbf65 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Fri, 7 Nov 2025 10:00:17 +0100 Subject: [PATCH] changes based on PR comments --- cognee/modules/retrieval/base_graph_retriever.py | 10 +++++++--- cognee/modules/retrieval/base_retriever.py | 10 +++++++--- ...d_output_tests.py => structured_output_test.py} | 14 ++++++-------- .../modules/retrieval/summaries_retriever_test.py | 2 +- 4 files changed, 21 insertions(+), 15 deletions(-) rename cognee/tests/unit/modules/retrieval/{structured_output_tests.py => structured_output_test.py} (94%) diff --git a/cognee/modules/retrieval/base_graph_retriever.py b/cognee/modules/retrieval/base_graph_retriever.py index b0abc2991..b203309ba 100644 --- a/cognee/modules/retrieval/base_graph_retriever.py +++ b/cognee/modules/retrieval/base_graph_retriever.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Any, List, Optional, Type from abc import ABC, abstractmethod from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge @@ -14,7 +14,11 @@ class BaseGraphRetriever(ABC): @abstractmethod async def get_completion( - self, query: str, context: Optional[List[Edge]] = None, session_id: Optional[str] = None - ) -> str: + self, + query: str, + context: Optional[List[Edge]] = None, + session_id: Optional[str] = None, + response_model: Type = str, + ) -> List[Any]: """Generates a response using the query and optional context (triplets).""" pass diff --git a/cognee/modules/retrieval/base_retriever.py b/cognee/modules/retrieval/base_retriever.py index 1533dd44f..b88c741b8 100644 --- a/cognee/modules/retrieval/base_retriever.py +++ b/cognee/modules/retrieval/base_retriever.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any, Optional, Type, List class BaseRetriever(ABC): @@ -12,7 +12,11 @@ class BaseRetriever(ABC): @abstractmethod async def get_completion( - self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None - ) -> Any: + self, + query: str, + context: Optional[Any] = None, + session_id: Optional[str] = None, + response_model: Type = str, + ) -> List[Any]: """Generates a response using the query and optional context.""" pass diff --git a/cognee/tests/unit/modules/retrieval/structured_output_tests.py b/cognee/tests/unit/modules/retrieval/structured_output_test.py similarity index 94% rename from cognee/tests/unit/modules/retrieval/structured_output_tests.py rename to cognee/tests/unit/modules/retrieval/structured_output_test.py index 95b4b9c20..4ad3019ff 100644 --- a/cognee/tests/unit/modules/retrieval/structured_output_tests.py +++ b/cognee/tests/unit/modules/retrieval/structured_output_test.py @@ -196,11 +196,9 @@ class TestStructuredOutputCompletion: entities = [entity] await add_data_points(entities) - await asyncio.gather( - _test_get_structured_graph_completion_cot(), - _test_get_structured_graph_completion(), - _test_get_structured_graph_completion_temporal(), - _test_get_structured_graph_completion_rag(), - _test_get_structured_graph_completion_context_extension(), - _test_get_structured_entity_completion(), - ) + await _test_get_structured_graph_completion_cot() + await _test_get_structured_graph_completion() + await _test_get_structured_graph_completion_temporal() + await _test_get_structured_graph_completion_rag() + await _test_get_structured_graph_completion_context_extension() + await _test_get_structured_entity_completion() diff --git a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py index fc96081bf..5f4b93425 100644 --- a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py @@ -13,7 +13,7 @@ from cognee.modules.retrieval.exceptions.exceptions import NoDataError from cognee.modules.retrieval.summaries_retriever import SummariesRetriever -class TextSummariesRetriever: +class TestSummariesRetriever: @pytest.mark.asyncio async def test_chunk_context(self): system_directory_path = os.path.join(