diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index 0dc3a8bf6..afbaa2978 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -104,10 +104,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): raise QueryValidationError( message="You cannot use batch queries with session saving currently." ) - if query_batch and self.save_interaction: - raise QueryValidationError( - message="Cannot use batch queries with interaction saving currently." - ) is_query_valid, msg = validate_queries(query, query_batch) if not is_query_valid: @@ -264,13 +260,17 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): for batched_query in query_batch: result_completion_batch.append(finished_queries_states[batched_query].completion) - # TODO: Do batch queries for save interaction if self.save_interaction and context_text_batch and triplets_batch and completion_batch: - await self.save_qa( - question=query, - answer=completion_batch[0], - context=context_text_batch[0], - triplets=triplets_batch[0], + await asyncio.gather( + *[ + self.save_qa( + question=batched_query, + answer=finished_queries_states[batched_query].completion, + context=finished_queries_states[batched_query].context_text, + triplets=finished_queries_states[batched_query].triplets, + ) + for batched_query in query_batch + ] ) if session_save: diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 114578aa9..76f27b65e 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -332,10 +332,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): raise QueryValidationError( message="You 
cannot use batch queries with session saving currently." ) - if query_batch and self.save_interaction: - raise QueryValidationError( - message="Cannot use batch queries with interaction saving currently." - ) is_query_valid, msg = validate_queries(query, query_batch) if not is_query_valid: @@ -355,14 +351,28 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): response_model=response_model, ) - # TODO: Handle save interaction for batch queries if self.save_interaction and context and triplets and completion: - await self.save_qa( - question=query, - answer=str(completion[0]), - context=context_text[0], - triplets=triplets[0], - ) + if query_batch: + await asyncio.gather( + *[ + self.save_qa( + question=batched_query, + answer=str(batched_completion), + context=batched_context_text, + triplets=batched_triplet, + ) + for batched_query, batched_completion, batched_context_text, batched_triplet in zip( + query_batch, completion, context_text, triplets + ) + ] + ) + else: + await self.save_qa( + question=query, + answer=str(completion[0]), + context=context_text[0], + triplets=triplets[0], + ) # TODO: Handle session save interaction for batch queries # Save to session cache if enabled diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index d9667a669..cbaec599f 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -220,10 +220,6 @@ class GraphCompletionRetriever(BaseGraphRetriever): raise QueryValidationError( message="You cannot use batch queries with session saving currently." ) - if query_batch and self.save_interaction: - raise QueryValidationError( - message="Cannot use batch queries with interaction saving currently." 
- ) is_query_valid, msg = validate_queries(query, query_batch) if not is_query_valid: @@ -236,7 +232,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): context_text = "" context_text_batch = [] - if triplets and isinstance(triplets[0], list): + if query_batch: context_text_batch = await asyncio.gather( *[resolve_edges_to_text(triplets_element) for triplets_element in triplets] ) @@ -284,9 +280,24 @@ class GraphCompletionRetriever(BaseGraphRetriever): ) if self.save_interaction and context and triplets and completion: - await self.save_qa( - question=query, answer=completion, context=context_text, triplets=triplets - ) + if query: + await self.save_qa( + question=query, answer=completion, context=context_text, triplets=triplets + ) + else: + await asyncio.gather( + *[ + self.save_qa( + question=batched_query, + answer=batched_completion, + context=batched_context_text, + triplets=batched_triplets, + ) + for batched_query, batched_completion, batched_context_text, batched_triplets in zip( + query_batch, completion, context_text_batch, triplets + ) + ] + ) if session_save: await save_conversation_history( diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py index 9ceca96e2..cec71a443 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -804,3 +804,72 @@ async def test_get_completion_batch_queries_duplicate_queries(mock_edge): assert isinstance(completion, list) assert len(completion) == 2 assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_with_save_interaction(mock_edge): + """Test get_completion batch queries with save_interaction enabled.""" + mock_graph_engine = 
AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + mock_graph_engine.add_edges = AsyncMock() + + retriever = GraphCompletionContextExtensionRetriever(save_interaction=True) + + mock_node1 = MagicMock() + mock_node2 = MagicMock() + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object( + retriever, + "get_context", + new_callable=AsyncMock, + return_value=[[mock_edge], [mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Extension query", + "Generated answer", + "Generated answer", + ], # Extension query, then final answer + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", + side_effect=[ + UUID("550e8400-e29b-41d4-a716-446655440000"), + UUID("550e8400-e29b-41d4-a716-446655440000"), + UUID("550e8400-e29b-41d4-a716-446655440001"), + UUID("550e8400-e29b-41d4-a716-446655440001"), + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], + context=[[mock_edge], [mock_edge]], + context_extension_rounds=1, + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + mock_add_data.assert_awaited() diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py 
b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index 1a6155c4f..4b05021c3 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -853,3 +853,87 @@ async def test_get_completion_batch_queries_duplicate_queries(mock_edge): assert isinstance(completion, list) assert len(completion) == 2 assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_with_save_interaction(mock_edge): + """Test get_completion batch queries with save_interaction enabled.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + mock_graph_engine.add_edges = AsyncMock() + + retriever = GraphCompletionCotRetriever(save_interaction=True) + + mock_node1 = MagicMock() + mock_node2 = MagicMock() + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object( + retriever, + "get_context", + new_callable=AsyncMock, + return_value=[[mock_edge], [mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=[ + "validation_result", + "validation_result", + "followup_question", + "followup_question", + ], + ), + patch( + 
"cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", + side_effect=[ + UUID("550e8400-e29b-41d4-a716-446655440000"), + UUID("550e8400-e29b-41d4-a716-446655440000"), + UUID("550e8400-e29b-41d4-a716-446655440001"), + UUID("550e8400-e29b-41d4-a716-446655440001"), + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + # Pass context so save_interaction condition is met + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], + context=[[mock_edge], [mock_edge]], + max_iter=1, + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + mock_add_data.assert_awaited() diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py index a6fb05270..3673b7ace 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py @@ -842,3 +842,44 @@ async def test_get_completion_batch_queries_duplicate_queries(mock_edge): assert isinstance(completion, list) assert len(completion) == 2 assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_with_save_interaction(mock_edge): + """Test get_completion batch queries with save_interaction.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever(save_interaction=True) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + 
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[[mock_edge], [mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 1"], context=None + ) + + assert isinstance(completion, list) + assert len(completion) == 2