From a0a3283d8fcad72321e6024140b700d6568594f9 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Tue, 20 Jan 2026 19:23:34 +0100 Subject: [PATCH] Revert "feat: add batch queries to save interaction" This reverts commit 5c8475a92a5f0eeb2c8fd3270bfc657d48c1d38a. --- ..._completion_context_extension_retriever.py | 20 ++--- .../graph_completion_cot_retriever.py | 32 +++---- .../retrieval/graph_completion_retriever.py | 27 ++---- ...letion_retriever_context_extension_test.py | 69 --------------- .../graph_completion_retriever_cot_test.py | 84 ------------------- .../graph_completion_retriever_test.py | 41 --------- 6 files changed, 29 insertions(+), 244 deletions(-) diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index afbaa2978..0dc3a8bf6 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -104,6 +104,10 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): raise QueryValidationError( message="You cannot use batch queries with session saving currently." ) + if query_batch and self.save_interaction: + raise QueryValidationError( + message="Cannot use batch queries with interaction saving currently." + ) is_query_valid, msg = validate_queries(query, query_batch) if not is_query_valid: @@ -260,17 +264,13 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): for batched_query in query_batch: result_completion_batch.append(finished_queries_states[batched_query].completion) + # TODO: Do batch queries for save interaction if self.save_interaction and context_text_batch and triplets_batch and completion_batch: - await asyncio.gather( - *[ - self.save_qa( - question=batched_query, - answer=finished_queries_states[batched_query].completion, - context=finished_queries_states[batched_query].context_text, - triplets=finished_queries_states[batched_query].triplets, - ) - for batched_query in query_batch - ] + await self.save_qa( + question=query, + answer=completion_batch[0], + context=context_text_batch[0], + triplets=triplets_batch[0], ) if session_save: diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 76f27b65e..114578aa9 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -332,6 +332,10 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): raise QueryValidationError( message="You cannot use batch queries with session saving currently." ) + if query_batch and self.save_interaction: + raise QueryValidationError( + message="Cannot use batch queries with interaction saving currently." + ) is_query_valid, msg = validate_queries(query, query_batch) if not is_query_valid: @@ -351,28 +355,14 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): response_model=response_model, ) + # TODO: Handle save interaction for batch queries if self.save_interaction and context and triplets and completion: - if query_batch: - await asyncio.gather( - *[ - self.save_qa( - question=batched_query, - answer=str(batched_completion), - context=batched_context_text, - triplets=batched_triplet, - ) - for batched_query, batched_completion, batched_context_text, batched_triplet in zip( - query_batch, completion, context_text, triplets - ) - ] - ) - else: - await self.save_qa( - question=query, - answer=str(completion[0]), - context=context_text[0], - triplets=triplets[0], - ) + await self.save_qa( + question=query, + answer=str(completion[0]), + context=context_text[0], + triplets=triplets[0], + ) # TODO: Handle session save interaction for batch queries # Save to session cache if enabled diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index cbaec599f..d9667a669 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -220,6 +220,10 @@ class GraphCompletionRetriever(BaseGraphRetriever): raise QueryValidationError( message="You cannot use batch queries with session saving currently." ) + if query_batch and self.save_interaction: + raise QueryValidationError( + message="Cannot use batch queries with interaction saving currently." + ) is_query_valid, msg = validate_queries(query, query_batch) if not is_query_valid: @@ -232,7 +236,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): context_text = "" context_text_batch = [] - if query_batch: + if triplets and isinstance(triplets[0], list): context_text_batch = await asyncio.gather( *[resolve_edges_to_text(triplets_element) for triplets_element in triplets] ) @@ -280,24 +284,9 @@ class GraphCompletionRetriever(BaseGraphRetriever): ) if self.save_interaction and context and triplets and completion: - if query: - await self.save_qa( - question=query, answer=completion, context=context_text, triplets=triplets - ) - else: - await asyncio.gather( - *[ - await self.save_qa( - question=batched_query, - answer=batched_completion, - context=batched_context_text, - triplets=batched_triplets, - ) - for batched_query, batched_completion, batched_context_text, batched_triplets in zip( - query_batch, completion, context_text_batch, triplets - ) - ] - ) + await self.save_qa( + question=query, answer=completion, context=context_text, triplets=triplets + ) if session_save: await save_conversation_history( diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py index cec71a443..9ceca96e2 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -804,72 +804,3 @@ async def test_get_completion_batch_queries_duplicate_queries(mock_edge): assert isinstance(completion, list) assert len(completion) == 2 assert completion[0] == "Generated answer" and completion[1] == "Generated answer" - - -@pytest.mark.asyncio -async def test_get_completion_batch_queries_with_save_interaction(mock_edge): - """Test get_completion batch queries with save_interaction enabled.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - mock_graph_engine.add_edges = AsyncMock() - - retriever = GraphCompletionContextExtensionRetriever(save_interaction=True) - - mock_node1 = MagicMock() - mock_node2 = MagicMock() - mock_edge.node1 = mock_node1 - mock_edge.node2 = mock_node2 - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch.object( - retriever, - "get_context", - new_callable=AsyncMock, - return_value=[[mock_edge], [mock_edge]], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", - side_effect=[ - "Extension query", - "Extension query", - "Generated answer", - "Generated answer", - ], # Extension query, then final answer - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", - side_effect=[ - UUID("550e8400-e29b-41d4-a716-446655440000"), - UUID("550e8400-e29b-41d4-a716-446655440000"), - UUID("550e8400-e29b-41d4-a716-446655440001"), - UUID("550e8400-e29b-41d4-a716-446655440001"), - ], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.add_data_points", - ) as mock_add_data, - patch( - "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" - ) as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion( - query_batch=["test query 1", "test query 2"], - context=[[mock_edge], [mock_edge]], - context_extension_rounds=1, - ) - - assert isinstance(completion, list) - assert len(completion) == 2 - mock_add_data.assert_awaited() diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index 4b05021c3..1a6155c4f 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -853,87 +853,3 @@ async def test_get_completion_batch_queries_duplicate_queries(mock_edge): assert isinstance(completion, list) assert len(completion) == 2 assert completion[0] == "Generated answer" and completion[1] == "Generated answer" - - -@pytest.mark.asyncio -async def test_get_completion_batch_queries_with_save_interaction(mock_edge): - """Test get_completion batch queries with save_interaction enabled.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - mock_graph_engine.add_edges = AsyncMock() - - retriever = GraphCompletionCotRetriever(save_interaction=True) - - mock_node1 = MagicMock() - mock_node2 = MagicMock() - mock_edge.node1 = mock_node1 - mock_edge.node2 = mock_node2 - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", - return_value="Generated answer", - ), - patch.object( - retriever, - "get_context", - new_callable=AsyncMock, - return_value=[[mock_edge], [mock_edge]], - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", - return_value="Generated answer", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", - return_value="Rendered prompt", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", - return_value="System prompt", - ), - patch.object( - LLMGateway, - "acreate_structured_output", - new_callable=AsyncMock, - side_effect=[ - "validation_result", - "validation_result", - "followup_question", - "followup_question", - ], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", - side_effect=[ - UUID("550e8400-e29b-41d4-a716-446655440000"), - UUID("550e8400-e29b-41d4-a716-446655440000"), - UUID("550e8400-e29b-41d4-a716-446655440001"), - UUID("550e8400-e29b-41d4-a716-446655440001"), - ], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.add_data_points", - ) as mock_add_data, - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" - ) as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - # Pass context so save_interaction condition is met - completion = await retriever.get_completion( - query_batch=["test query 1", "test query 2"], - context=[[mock_edge], [mock_edge]], - max_iter=1, - ) - - assert isinstance(completion, list) - assert len(completion) == 2 - mock_add_data.assert_awaited() diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py index 3673b7ace..a6fb05270 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py @@ -842,44 +842,3 @@ async def test_get_completion_batch_queries_duplicate_queries(mock_edge): assert isinstance(completion, list) assert len(completion) == 2 assert completion[0] == "Generated answer" and completion[1] == "Generated answer" - - -@pytest.mark.asyncio -async def test_get_completion_batch_queries_with_save_interaction(mock_edge): - """Test get_completion batch queries with save_interaction.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - - retriever = GraphCompletionRetriever(save_interaction=True) - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[[mock_edge], [mock_edge]], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.generate_completion", - return_value="Generated answer", - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" - ) as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion( - query_batch=["test query 1", "test query 1"], context=None - ) - - assert isinstance(completion, list) - assert len(completion) == 2