From 98e8d226ebf5e70f61e486dd69a2997752ecc0cf Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Fri, 16 Jan 2026 11:57:03 +0100 Subject: [PATCH] test: add tests for batch query graph completions --- .../graph_completion_cot_retriever.py | 14 +- .../retrieval/graph_completion_retriever.py | 2 +- ...letion_retriever_context_extension_test.py | 270 ++++++++++++++++++ .../graph_completion_retriever_cot_test.py | 129 +++++++++ .../graph_completion_retriever_test.py | 156 ++++++++++ 5 files changed, 567 insertions(+), 4 deletions(-) diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index eec8ba101..160419ee9 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -171,7 +171,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): async def get_completion( self, query: Optional[str] = None, - context: Optional[List[Edge]] = None, + context: Optional[List[Edge] | List[List[Edge]]] = None, session_id: Optional[str] = None, max_iter=4, response_model: Type = str, @@ -216,16 +216,22 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): completion_results = [] if query_batch and len(query_batch) > 0: + if not context: + # Having a list is necessary to zip through it + context = [] + for query in query_batch: + context.append(None) + completion_results = await asyncio.gather( *[ self._run_cot_completion( query=query, - context=context, + context=context_el, conversation_history=conversation_history, max_iter=max_iter, response_model=response_model, ) - for query in query_batch + for query, context_el in zip(query_batch, context) ] ) else: @@ -237,11 +243,13 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): response_model=response_model, ) + # TODO: Handle save interaction for batch queries if self.save_interaction and context and triplets and completion: await self.save_qa( question=query, answer=str(completion), context=context_text, triplets=triplets ) + # TODO: Handle session save interaction for batch queries # Save to session cache if enabled if session_save: context_summary = await summarize_text(context_text) diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index 5c2c3bf1b..f740496d0 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -262,7 +262,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): session_id=session_id, ) - return [completion] + return completion if isinstance(completion, list) else [completion] async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None: """ diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py index 9095af69c..567edde51 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -1,4 +1,5 @@ import pytest +from itertools import cycle from unittest.mock import AsyncMock, patch, MagicMock from uuid import UUID @@ -467,3 +468,272 @@ async def test_get_completion_zero_extension_rounds(mock_edge): assert isinstance(completion, list) assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_without_context(mock_edge): + """Test get_completion batch queries retrieves context when not provided.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[[mock_edge], [mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], context_extension_rounds=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_with_provided_context(mock_edge): + """Test get_completion batch queries uses provided context.""" + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], + context=[[mock_edge], [mock_edge]], + context_extension_rounds=1, + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_context_extension_rounds(mock_edge): + """Test get_completion batch queries with multiple context extension rounds.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + # Create a second edge for extension rounds + mock_edge2 = MagicMock(spec=Edge) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object( + retriever, + "get_context", + new_callable=AsyncMock, + side_effect=[[[mock_edge], [mock_edge]], [[mock_edge2], [mock_edge2]]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + side_effect=cycle(["Resolved context", "Extended context"]), # Different contexts + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Extension query", + "Generated answer", + "Generated answer", + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], context_extension_rounds=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_context_extension_stops_early(mock_edge): + """Test get_completion batch queries stops early when no new triplets found.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Extension query", + "Generated answer", + "Generated answer", + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + # When get_context returns same triplets, the loop should stop early + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], + context=[[mock_edge], [mock_edge]], + context_extension_rounds=4, + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_zero_extension_rounds(mock_edge): + """Test get_completion batch queries with zero context extension rounds.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object( + retriever, + "get_context", + new_callable=AsyncMock, + return_value=[[mock_edge], [mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], context_extension_rounds=0 + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_with_response_model(mock_edge): + """Test get_completion batch queries with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object( + retriever, + "get_context", + new_callable=AsyncMock, + return_value=[[mock_edge], [mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Extension query", + TestModel(answer="Test answer"), + TestModel(answer="Test answer"), + ], # Extension query, then final answer + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], + response_model=TestModel, + context_extension_rounds=1, + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert isinstance(completion[0], TestModel) and isinstance(completion[1], TestModel) diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index 9f3147512..5ae901594 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -1,6 +1,7 @@ import pytest from unittest.mock import AsyncMock, patch, MagicMock from uuid import UUID +from itertools import cycle from cognee.modules.retrieval.graph_completion_cot_retriever import ( GraphCompletionCotRetriever, @@ -686,3 +687,131 @@ async def test_as_answer_text_with_basemodel(): assert isinstance(result, str) assert "[Structured Response]" in result assert "test answer" in result + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_with_context(mock_edge): + """Test get_completion batch queries with provided context.""" + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=cycle(["validation_result", "followup_question"]), + ), + ): + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], + context=[[mock_edge], [mock_edge]], + max_iter=1, + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_without_context(mock_edge): + """Test get_completion batch queries without provided context.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + ): + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], max_iter=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +# +@pytest.mark.asyncio +async def test_get_completion_batch_queries_with_response_model(mock_edge): + """Test get_completion of batch queries with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], response_model=TestModel, max_iter=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert isinstance(completion[0], TestModel) and isinstance(completion[1], TestModel) diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py index c22f30fd0..ae0adb729 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py @@ -646,3 +646,159 @@ async def test_get_completion_with_save_interaction_all_conditions_met(mock_edge assert len(completion) == 1 assert completion[0] == "Generated answer" mock_add_data.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_with_context(mock_edge): + """Test get_completion correctly handles batch queries.""" + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], context=[[mock_edge], [mock_edge]] + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_without_context(mock_edge): + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[[mock_edge], [mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion(query_batch=["test query 1", "test query 2"]) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_with_response_model(mock_edge): + """Test get_completion of batch queries with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[[mock_edge], [mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], response_model=TestModel + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert isinstance(completion[0], TestModel) and isinstance(completion[1], TestModel) + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_empty_context(mock_edge): + """Test get_completion with empty context.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[[], []], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion(query_batch=["test query 1", "test query 2"]) + + assert isinstance(completion, list) + assert len(completion) == 2