From 40667e63c943c5786ea83b11b00b96542e9a2d17 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Thu, 15 Jan 2026 14:01:10 +0100 Subject: [PATCH 01/19] feat: enable batch queries for graph completion retrievers --- ..._completion_context_extension_retriever.py | 116 +++++++++++++----- .../graph_completion_cot_retriever.py | 35 ++++-- .../retrieval/graph_completion_retriever.py | 62 +++++++--- 3 files changed, 161 insertions(+), 52 deletions(-) diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index fc49a139b..7774ba9e5 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -56,8 +56,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): async def get_completion( self, - query: str, - context: Optional[List[Edge]] = None, + query: Optional[str] = None, + query_batch: Optional[List[str]] = None, + context: Optional[List[Edge] | List[List[Edge]]] = None, session_id: Optional[str] = None, context_extension_rounds=4, response_model: Type = str, @@ -91,46 +92,98 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): """ triplets = context - if triplets is None: - triplets = await self.get_context(query) + if query: + # This is done mostly to avoid duplicating a lot of code unnecessarily + query_batch = [query] + query = None + if triplets: + triplets = [triplets] - context_text = await self.resolve_edges_to_text(triplets) + if triplets is None: + triplets = await self.get_context(query, query_batch) + + context_text = "" + context_texts = "" + if isinstance(triplets[0], list): + context_texts = await asyncio.gather( + *[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets] + ) + else: + context_text = await self.resolve_edges_to_text(triplets) round_idx = 1 + # We will be 
removing queries, and their associated triplets and context, as we go + # through iterations, so we need to save their final states for the final generation. + original_query_batch = query_batch + saved_triplets = [] + saved_context_texts = [] while round_idx <= context_extension_rounds: - prev_size = len(triplets) - logger.info( f"Context extension: round {round_idx} - generating next graph locational query." ) - completion = await generate_completion( - query=query, - context=context_text, - user_prompt_path=self.user_prompt_path, - system_prompt_path=self.system_prompt_path, - system_prompt=self.system_prompt, - ) - triplets += await self.get_context(completion) - triplets = list(set(triplets)) - context_text = await self.resolve_edges_to_text(triplets) - - num_triplets = len(triplets) - - if num_triplets == prev_size: + query_batch = [query for query in query_batch if query] + triplets = [triplet_element for triplet_element in triplets if triplet_element] + context_texts = [context_text for context_text in context_texts if context_text] + if len(query_batch) == 0: logger.info( f"Context extension: round {round_idx} – no new triplets found; stopping early." 
) break + prev_sizes = [len(triplets_element) for triplets_element in triplets] + + completions = await asyncio.gather( + *[ + generate_completion( + query=query, + context=context, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, + ) + for query, context in zip(query_batch, context_texts) + ], + ) + + new_triplets = await self.get_context(query_batch=completions) + for i, (triplets_element, new_triplets_element) in enumerate( + zip(triplets, new_triplets) + ): + triplets_element += new_triplets_element + triplets[i] = list(dict.fromkeys(triplets_element)) + + context_texts = await asyncio.gather( + *[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets] + ) + + new_sizes = [len(triplets_element) for triplets_element in triplets] + + for i, (query, prev_size, new_size, triplet_element, context_text) in enumerate( + zip(query_batch, prev_sizes, new_sizes, triplets, context_texts) + ): + if prev_size == new_size: + # In this case, we can stop trying to extend the context of this query + query_batch[i] = "" + saved_triplets.append(triplet_element) + triplets[i] = [] + saved_context_texts.append(context_text) + context_texts[i] = "" + logger.info( f"Context extension: round {round_idx} - " - f"number of unique retrieved triplets: {num_triplets}" + f"number of unique retrieved triplets for each query : {new_sizes}" ) round_idx += 1 + # Reset variables for the final generations. They contain the final state + # of triplets and contexts for each query, after all extension iterations. 
+ query_batch = original_query_batch + context_texts = saved_context_texts + triplets = saved_triplets + # Check if we need to generate context summary for caching cache_config = CacheConfig() user = session_user.get() @@ -153,13 +206,18 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): ), ) else: - completion = await generate_completion( - query=query, - context=context_text, - user_prompt_path=self.user_prompt_path, - system_prompt_path=self.system_prompt_path, - system_prompt=self.system_prompt, - response_model=response_model, + completion = await asyncio.gather( + *[ + generate_completion( + query=query, + context=context_text, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, + response_model=response_model, + ) + for query, context_text in zip(query_batch, context_texts) + ], ) if self.save_interaction and context_text and triplets and completion: diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 70fcb6cdb..5c0dde0df 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -170,7 +170,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): async def get_completion( self, - query: str, + query: Optional[str] = None, + query_batch: Optional[List[str]] = None, context: Optional[List[Edge]] = None, session_id: Optional[str] = None, max_iter=4, @@ -213,13 +214,28 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): if session_save: conversation_history = await get_conversation_history(session_id=session_id) - completion, context_text, triplets = await self._run_cot_completion( - query=query, - context=context, - conversation_history=conversation_history, - max_iter=max_iter, - response_model=response_model, - ) + completion_results = [] + if query_batch and len(query_batch) 
> 0: + completion_results = await asyncio.gather( + *[ + self._run_cot_completion( + query=query, + context=context, + conversation_history=conversation_history, + max_iter=max_iter, + response_model=response_model, + ) + for query in query_batch + ] + ) + else: + completion, context_text, triplets = await self._run_cot_completion( + query=query, + context=context, + conversation_history=conversation_history, + max_iter=max_iter, + response_model=response_model, + ) if self.save_interaction and context and triplets and completion: await self.save_qa( @@ -236,4 +252,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): session_id=session_id, ) + if completion_results: + return [completion for completion, _, _ in completion_results] + return [completion] diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index bb8b34327..1b78e18ac 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -79,7 +79,11 @@ class GraphCompletionRetriever(BaseGraphRetriever): """ return await resolve_edges_to_text(retrieved_edges) - async def get_triplets(self, query: str) -> List[Edge]: + async def get_triplets( + self, + query: Optional[str] = None, + query_batch: Optional[List[str]] = None, + ) -> List[Edge] | List[List[Edge]]: """ Retrieves relevant graph triplets based on a query string. 
@@ -107,6 +111,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): found_triplets = await brute_force_triplet_search( query, + query_batch, top_k=self.top_k, collections=vector_index_collections or None, node_type=self.node_type, @@ -117,7 +122,11 @@ class GraphCompletionRetriever(BaseGraphRetriever): return found_triplets - async def get_context(self, query: str) -> List[Edge]: + async def get_context( + self, + query: Optional[str] = None, + query_batch: Optional[List[str]] = None, + ) -> List[Edge] | List[List[Edge]]: """ Retrieves and resolves graph triplets into context based on a query. @@ -139,7 +148,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): logger.warning("Search attempt on an empty knowledge graph") return [] - triplets = await self.get_triplets(query) + triplets = await self.get_triplets(query, query_batch) if len(triplets) == 0: logger.warning("Empty context was provided to the completion") @@ -158,10 +167,11 @@ class GraphCompletionRetriever(BaseGraphRetriever): async def get_completion( self, - query: str, - context: Optional[List[Edge]] = None, + query: Optional[str] = None, + context: Optional[List[Edge] | List[List[Edge]]] = None, session_id: Optional[str] = None, response_model: Type = str, + query_batch: Optional[List[str]] = None, ) -> List[Any]: """ Generates a completion using graph connections context based on a query. 
@@ -183,9 +193,16 @@ class GraphCompletionRetriever(BaseGraphRetriever): triplets = context if triplets is None: - triplets = await self.get_context(query) + triplets = await self.get_context(query, query_batch) - context_text = await resolve_edges_to_text(triplets) + context_text = "" + context_texts = "" + if isinstance(triplets[0], list): + context_texts = await asyncio.gather( + *[resolve_edges_to_text(triplets_element) for triplets_element in triplets] + ) + else: + context_text = await resolve_edges_to_text(triplets) cache_config = CacheConfig() user = session_user.get() @@ -208,14 +225,29 @@ class GraphCompletionRetriever(BaseGraphRetriever): ), ) else: - completion = await generate_completion( - query=query, - context=context_text, - user_prompt_path=self.user_prompt_path, - system_prompt_path=self.system_prompt_path, - system_prompt=self.system_prompt, - response_model=response_model, - ) + if query_batch and len(query_batch) > 0: + completion = await asyncio.gather( + *[ + generate_completion( + query=query, + context=context, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, + response_model=response_model, + ) + for query, context in zip(query_batch, context_texts) + ], + ) + else: + completion = await generate_completion( + query=query, + context=context_text, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, + response_model=response_model, + ) if self.save_interaction and context and triplets and completion: await self.save_qa( From 2655df9b218b4116d84c555e950dc3a92775d44d Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Thu, 15 Jan 2026 19:29:21 +0100 Subject: [PATCH 02/19] fix: fix failing tests --- ..._completion_context_extension_retriever.py | 20 ++++++++++++------- .../retrieval/graph_completion_retriever.py | 2 +- ...letion_retriever_context_extension_test.py | 16 +++++++-------- 3 files 
changed, 22 insertions(+), 16 deletions(-) diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index 7774ba9e5..ddf4ef615 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -95,16 +95,15 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): if query: # This is done mostly to avoid duplicating a lot of code unnecessarily query_batch = [query] - query = None if triplets: triplets = [triplets] if triplets is None: - triplets = await self.get_context(query, query_batch) + triplets = await self.get_context(query_batch=query_batch) context_text = "" context_texts = "" - if isinstance(triplets[0], list): + if triplets and isinstance(triplets[0], list): context_texts = await asyncio.gather( *[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets] ) @@ -123,6 +122,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): f"Context extension: round {round_idx} - generating next graph locational query." ) + # Filter out the queries that cannot be extended further, and their associated contexts query_batch = [query for query in query_batch if query] triplets = [triplet_element for triplet_element in triplets if triplet_element] context_texts = [context_text for context_text in context_texts if context_text] @@ -181,8 +181,12 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): # Reset variables for the final generations. They contain the final state # of triplets and contexts for each query, after all extension iterations. 
query_batch = original_query_batch
-        context_texts = saved_context_texts
-        triplets = saved_triplets
+        context_texts = saved_context_texts if len(saved_context_texts) > 0 else context_texts
+        triplets = saved_triplets if len(saved_triplets) > 0 else triplets
+
+        if len(query_batch) == 1:
+            triplets = [] if not triplets else triplets[0]
+            context_text = "" if not context_texts else context_texts[0]
 
         # Check if we need to generate context summary for caching
         cache_config = CacheConfig()
@@ -221,8 +225,10 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
         )
 
         if self.save_interaction and context_text and triplets and completion:
+            if isinstance(completion, list):
+                completion = completion[0]
             await self.save_qa(
-                question=query, answer=completion, context=context_text, triplets=triplets
+                question=query, answer=completion, context=context_text, triplets=triplets
             )
 
         if session_save:
@@ -233,4 +239,4 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
                 session_id=session_id,
             )
 
-        return [completion]
+        return completion if isinstance(completion, list) else [completion]
diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py
index 1b78e18ac..5c2c3bf1b 100644
--- a/cognee/modules/retrieval/graph_completion_retriever.py
+++ b/cognee/modules/retrieval/graph_completion_retriever.py
@@ -197,7 +197,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
 
         context_text = ""
         context_texts = ""
-        if isinstance(triplets[0], list):
+        if triplets and isinstance(triplets[0], list):
             context_texts = await asyncio.gather(
                 *[resolve_edges_to_text(triplets_element) for triplets_element in triplets]
             )
diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py
index 6a9b07d38..9095af69c 100644
--- 
a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -81,7 +81,7 @@ async def test_get_completion_without_context(mock_edge): ), patch( "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], + return_value=[[mock_edge]], ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", @@ -157,7 +157,7 @@ async def test_get_completion_context_extension_rounds(mock_edge): retriever, "get_context", new_callable=AsyncMock, - side_effect=[[mock_edge], [mock_edge2]], + side_effect=[[[mock_edge]], [[mock_edge2]]], ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", @@ -194,7 +194,7 @@ async def test_get_completion_context_extension_stops_early(mock_edge): retriever = GraphCompletionContextExtensionRetriever() with ( - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", return_value="Resolved context", @@ -240,7 +240,7 @@ async def test_get_completion_with_session(mock_edge): "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", return_value=mock_graph_engine, ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", return_value="Resolved context", @@ -304,7 +304,7 @@ async def test_get_completion_with_save_interaction(mock_edge): "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", return_value=mock_graph_engine, ), - patch.object(retriever, "get_context", 
new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", return_value="Resolved context", @@ -361,7 +361,7 @@ async def test_get_completion_with_response_model(mock_edge): "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", return_value=mock_graph_engine, ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", return_value="Resolved context", @@ -403,7 +403,7 @@ async def test_get_completion_with_session_no_user_id(mock_edge): "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", return_value=mock_graph_engine, ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", return_value="Resolved context", @@ -446,7 +446,7 @@ async def test_get_completion_zero_extension_rounds(mock_edge): "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", return_value=mock_graph_engine, ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", return_value="Resolved context", From b05c93bf5ff1e71e1457f49fdbb97012ca19210f Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Thu, 15 Jan 2026 20:40:33 +0100 Subject: [PATCH 03/19] fix: fix sessions tests --- cognee/modules/retrieval/graph_completion_cot_retriever.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 5c0dde0df..eec8ba101 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -171,11 +171,11 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): async def get_completion( self, query: Optional[str] = None, - query_batch: Optional[List[str]] = None, context: Optional[List[Edge]] = None, session_id: Optional[str] = None, max_iter=4, response_model: Type = str, + query_batch: Optional[List[str]] = None, ) -> List[Any]: """ Generate completion responses based on a user query and contextual information. From 98e8d226ebf5e70f61e486dd69a2997752ecc0cf Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Fri, 16 Jan 2026 11:57:03 +0100 Subject: [PATCH 04/19] test: add tests for batch query graph completions --- .../graph_completion_cot_retriever.py | 14 +- .../retrieval/graph_completion_retriever.py | 2 +- ...letion_retriever_context_extension_test.py | 270 ++++++++++++++++++ .../graph_completion_retriever_cot_test.py | 129 +++++++++ .../graph_completion_retriever_test.py | 156 ++++++++++ 5 files changed, 567 insertions(+), 4 deletions(-) diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index eec8ba101..160419ee9 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -171,7 +171,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): async def get_completion( self, query: Optional[str] = None, - context: Optional[List[Edge]] = None, + context: Optional[List[Edge] | List[List[Edge]]] = None, session_id: Optional[str] = None, max_iter=4, response_model: Type = str, @@ -216,16 +216,22 @@ class 
GraphCompletionCotRetriever(GraphCompletionRetriever): completion_results = [] if query_batch and len(query_batch) > 0: + if not context: + # Having a list is necessary to zip through it + context = [] + for query in query_batch: + context.append(None) + completion_results = await asyncio.gather( *[ self._run_cot_completion( query=query, - context=context, + context=context_el, conversation_history=conversation_history, max_iter=max_iter, response_model=response_model, ) - for query in query_batch + for query, context_el in zip(query_batch, context) ] ) else: @@ -237,11 +243,13 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): response_model=response_model, ) + # TODO: Handle save interaction for batch queries if self.save_interaction and context and triplets and completion: await self.save_qa( question=query, answer=str(completion), context=context_text, triplets=triplets ) + # TODO: Handle session save interaction for batch queries # Save to session cache if enabled if session_save: context_summary = await summarize_text(context_text) diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index 5c2c3bf1b..f740496d0 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -262,7 +262,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): session_id=session_id, ) - return [completion] + return completion if isinstance(completion, list) else [completion] async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None: """ diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py index 9095af69c..567edde51 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ 
b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -1,4 +1,5 @@ import pytest +from itertools import cycle from unittest.mock import AsyncMock, patch, MagicMock from uuid import UUID @@ -467,3 +468,272 @@ async def test_get_completion_zero_extension_rounds(mock_edge): assert isinstance(completion, list) assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_without_context(mock_edge): + """Test get_completion batch queries retrieves context when not provided.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[[mock_edge], [mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], context_extension_rounds=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_with_provided_context(mock_edge): + """Test get_completion batch queries uses provided context.""" + retriever = GraphCompletionContextExtensionRetriever() + + 
with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], + context=[[mock_edge], [mock_edge]], + context_extension_rounds=1, + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_context_extension_rounds(mock_edge): + """Test get_completion batch queries with multiple context extension rounds.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + # Create a second edge for extension rounds + mock_edge2 = MagicMock(spec=Edge) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object( + retriever, + "get_context", + new_callable=AsyncMock, + side_effect=[[[mock_edge], [mock_edge]], [[mock_edge2], [mock_edge2]]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + side_effect=cycle(["Resolved context", "Extended context"]), # Different contexts + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Extension query", + "Generated answer", + "Generated answer", + ], + ), + patch( + 
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], context_extension_rounds=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_context_extension_stops_early(mock_edge): + """Test get_completion batch queries stops early when no new triplets found.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Extension query", + "Generated answer", + "Generated answer", + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + # When get_context returns same triplets, the loop should stop early + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], + context=[[mock_edge], [mock_edge]], + context_extension_rounds=4, + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +@pytest.mark.asyncio +async def 
test_get_completion_batch_queries_zero_extension_rounds(mock_edge): + """Test get_completion batch queries with zero context extension rounds.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object( + retriever, + "get_context", + new_callable=AsyncMock, + return_value=[[mock_edge], [mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], context_extension_rounds=0 + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_with_response_model(mock_edge): + """Test get_completion batch queries with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object( + retriever, + "get_context", + new_callable=AsyncMock, + return_value=[[mock_edge], [mock_edge]], + ), + patch( + 
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Extension query", + TestModel(answer="Test answer"), + TestModel(answer="Test answer"), + ], # Extension query, then final answer + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], + response_model=TestModel, + context_extension_rounds=1, + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert isinstance(completion[0], TestModel) and isinstance(completion[1], TestModel) diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index 9f3147512..5ae901594 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -1,6 +1,7 @@ import pytest from unittest.mock import AsyncMock, patch, MagicMock from uuid import UUID +from itertools import cycle from cognee.modules.retrieval.graph_completion_cot_retriever import ( GraphCompletionCotRetriever, @@ -686,3 +687,131 @@ async def test_as_answer_text_with_basemodel(): assert isinstance(result, str) assert "[Structured Response]" in result assert "test answer" in result + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_with_context(mock_edge): + """Test get_completion batch queries with provided context.""" + retriever = GraphCompletionCotRetriever() + + with ( + patch( + 
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=cycle(["validation_result", "followup_question"]), + ), + ): + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], + context=[[mock_edge], [mock_edge]], + max_iter=1, + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_without_context(mock_edge): + """Test get_completion batch queries without provided context.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated 
answer", + ), + ): + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], max_iter=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +# +@pytest.mark.asyncio +async def test_get_completion_batch_queries_with_response_model(mock_edge): + """Test get_completion of batch queries with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], response_model=TestModel, max_iter=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert isinstance(completion[0], TestModel) and isinstance(completion[1], TestModel) diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py index c22f30fd0..ae0adb729 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +++ 
b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py @@ -646,3 +646,159 @@ async def test_get_completion_with_save_interaction_all_conditions_met(mock_edge assert len(completion) == 1 assert completion[0] == "Generated answer" mock_add_data.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_with_context(mock_edge): + """Test get_completion correctly handles batch queries.""" + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], context=[[mock_edge], [mock_edge]] + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_without_context(mock_edge): + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[[mock_edge], [mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + 
return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion(query_batch=["test query 1", "test query 2"]) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_with_response_model(mock_edge): + """Test get_completion of batch queries with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[[mock_edge], [mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], response_model=TestModel + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert isinstance(completion[0], TestModel) and isinstance(completion[1], TestModel) + + +@pytest.mark.asyncio +async def 
test_get_completion_batch_queries_empty_context(mock_edge): + """Test get_completion with empty context.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[[], []], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion(query_batch=["test query 1", "test query 2"]) + + assert isinstance(completion, list) + assert len(completion) == 2 From 17554466ba7c344802c7f4e8036d3ca731bc2258 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Fri, 16 Jan 2026 19:32:15 +0100 Subject: [PATCH 05/19] fix: fix logic issue that coderabbit flagged regarding ordering of lists --- ..._completion_context_extension_retriever.py | 41 +++++++++---------- ...letion_retriever_context_extension_test.py | 7 +++- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index ddf4ef615..1d20d4404 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -102,21 +102,20 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): triplets = await 
self.get_context(query_batch=query_batch) context_text = "" - context_texts = "" - if triplets and isinstance(triplets[0], list): - context_texts = await asyncio.gather( - *[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets] - ) - else: - context_text = await self.resolve_edges_to_text(triplets) + context_texts = await asyncio.gather( + *[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets] + ) round_idx = 1 # We will be removing queries, and their associated triplets and context, as we go # through iterations, so we need to save their final states for the final generation. + # Final state is stored in the finished_queries_data dict, and we populate it at the start as well. original_query_batch = query_batch - saved_triplets = [] - saved_context_texts = [] + finished_queries_data = {} + for i, query in enumerate(query_batch): + finished_queries_data[query] = (triplets[i], context_texts[i]) + while round_idx <= context_extension_rounds: logger.info( f"Context extension: round {round_idx} - generating next graph locational query." 
@@ -147,6 +146,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): ], ) + # Get new triplets, and merge them with existing ones, filtering out duplicates new_triplets = await self.get_context(query_batch=completions) for i, (triplets_element, new_triplets_element) in enumerate( zip(triplets, new_triplets) @@ -160,15 +160,14 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): new_sizes = [len(triplets_element) for triplets_element in triplets] - for i, (query, prev_size, new_size, triplet_element, context_text) in enumerate( + for i, (query, prev_size, new_size, triplets_element, context_text) in enumerate( zip(query_batch, prev_sizes, new_sizes, triplets, context_texts) ): + finished_queries_data[query] = (triplets_element, context_text) if prev_size == new_size: # In this case, we can stop trying to extend the context of this query query_batch[i] = "" - saved_triplets.append(triplet_element) triplets[i] = [] - saved_context_texts.append(context_text) context_texts[i] = "" logger.info( @@ -181,12 +180,11 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): # Reset variables for the final generations. They contain the final state # of triplets and contexts for each query, after all extension iterations. 
query_batch = original_query_batch - context_texts = saved_context_texts if len(saved_context_texts) > 0 else context_texts - triplets = saved_triplets if len(saved_triplets) > 0 else triplets - - if len(query_batch) == 1: - triplets = [] if not triplets else triplets[0] - context_text = "" if not context_texts else context_texts[0] + triplets = [] + context_texts = [] + for query in query_batch: + triplets.append(finished_queries_data[query][0]) + context_texts.append(finished_queries_data[query][1]) # Check if we need to generate context summary for caching cache_config = CacheConfig() @@ -224,11 +222,10 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): ], ) - if self.save_interaction and context_text and triplets and completion: - if isinstance(completion, list): - completion = completion[0] + # TODO: Do batch queries for save interaction + if self.save_interaction and context_texts and triplets and completion: await self.save_qa( - question=query, answer=completion[0], context=context_text, triplets=triplets + question=query, answer=completion[0], context=context_texts[0], triplets=triplets[0] ) if session_save: diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py index 567edde51..8d6310214 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -569,7 +569,12 @@ async def test_get_completion_batch_queries_context_extension_rounds(mock_edge): ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - side_effect=cycle(["Resolved context", "Extended context"]), # Different contexts + side_effect=[ + "Resolved context", + "Resolved context", + "Extended context", + "Extended context", + ], # Different contexts ), patch( 
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", From b88e4242ade6c2810f1b5aec404bbcecb0323bd2 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Sun, 18 Jan 2026 22:53:16 +0100 Subject: [PATCH 06/19] fix: PR comments fixes --- ..._completion_context_extension_retriever.py | 87 +++++++++++-------- .../graph_completion_cot_retriever.py | 26 +++--- .../retrieval/graph_completion_retriever.py | 41 +++++++-- .../utils/brute_force_triplet_search.py | 16 ++-- .../retrieval/utils/validate_queries.py | 14 +++ 5 files changed, 116 insertions(+), 68 deletions(-) create mode 100644 cognee/modules/retrieval/utils/validate_queries.py diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index 1d20d4404..f7603faba 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -1,6 +1,7 @@ import asyncio from typing import Optional, List, Type, Any from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.modules.retrieval.utils.validate_queries import validate_queries from cognee.shared.logging_utils import get_logger from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text @@ -90,20 +91,25 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): - List[str]: A list containing the generated answer based on the query and the extended context. 
""" - triplets = context + # TODO: This may be unnecessary in this retriever, will check later + query_validation = validate_queries(query, query_batch) + if not query_validation[0]: + raise ValueError(query_validation[1]) + + triplets_batch = context if query: # This is done mostly to avoid duplicating a lot of code unnecessarily query_batch = [query] - if triplets: - triplets = [triplets] + if triplets_batch: + triplets_batch = [triplets_batch] - if triplets is None: - triplets = await self.get_context(query_batch=query_batch) + if triplets_batch is None: + triplets_batch = await self.get_context(query_batch=query_batch) context_text = "" - context_texts = await asyncio.gather( - *[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets] + context_text_batch = await asyncio.gather( + *[self.resolve_edges_to_text(triplets) for triplets in triplets_batch] ) round_idx = 1 @@ -114,7 +120,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): original_query_batch = query_batch finished_queries_data = {} for i, query in enumerate(query_batch): - finished_queries_data[query] = (triplets[i], context_texts[i]) + finished_queries_data[query] = (triplets_batch[i], context_text_batch[i]) while round_idx <= context_extension_rounds: logger.info( @@ -123,15 +129,17 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): # Filter out the queries that cannot be extended further, and their associated contexts query_batch = [query for query in query_batch if query] - triplets = [triplet_element for triplet_element in triplets if triplet_element] - context_texts = [context_text for context_text in context_texts if context_text] + triplets_batch = [triplets for triplets in triplets_batch if triplets] + context_text_batch = [ + context_text for context_text in context_text_batch if context_text + ] if len(query_batch) == 0: logger.info( f"Context extension: round {round_idx} – no new triplets found; stopping early." 
) break - prev_sizes = [len(triplets_element) for triplets_element in triplets] + prev_sizes = [len(triplets) for triplets in triplets_batch] completions = await asyncio.gather( *[ @@ -142,33 +150,31 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, ) - for query, context in zip(query_batch, context_texts) + for query, context in zip(query_batch, context_text_batch) ], ) # Get new triplets, and merge them with existing ones, filtering out duplicates - new_triplets = await self.get_context(query_batch=completions) - for i, (triplets_element, new_triplets_element) in enumerate( - zip(triplets, new_triplets) - ): - triplets_element += new_triplets_element - triplets[i] = list(dict.fromkeys(triplets_element)) + new_triplets_batch = await self.get_context(query_batch=completions) + for i, (triplets, new_triplets) in enumerate(zip(triplets_batch, new_triplets_batch)): + triplets += new_triplets + triplets_batch[i] = list(dict.fromkeys(triplets)) - context_texts = await asyncio.gather( - *[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets] + context_text_batch = await asyncio.gather( + *[self.resolve_edges_to_text(triplets) for triplets in triplets_batch] ) - new_sizes = [len(triplets_element) for triplets_element in triplets] + new_sizes = [len(triplets) for triplets in triplets_batch] - for i, (query, prev_size, new_size, triplets_element, context_text) in enumerate( - zip(query_batch, prev_sizes, new_sizes, triplets, context_texts) + for i, (batched_query, prev_size, new_size, triplets, context_text) in enumerate( + zip(query_batch, prev_sizes, new_sizes, triplets_batch, context_text_batch) ): - finished_queries_data[query] = (triplets_element, context_text) + finished_queries_data[query] = (triplets, context_text) if prev_size == new_size: # In this case, we can stop trying to extend the context of this query query_batch[i] = "" - 
triplets[i] = [] - context_texts[i] = "" + triplets_batch[i] = [] + context_text_batch[i] = "" logger.info( f"Context extension: round {round_idx} - " @@ -180,11 +186,11 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): # Reset variables for the final generations. They contain the final state # of triplets and contexts for each query, after all extension iterations. query_batch = original_query_batch - triplets = [] - context_texts = [] - for query in query_batch: - triplets.append(finished_queries_data[query][0]) - context_texts.append(finished_queries_data[query][1]) + triplets_batch = [] + context_text_batch = [] + for batched_query in query_batch: + triplets_batch.append(finished_queries_data[batched_query][0]) + context_text_batch.append(finished_queries_data[batched_query][1]) # Check if we need to generate context summary for caching cache_config = CacheConfig() @@ -192,6 +198,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): user_id = getattr(user, "id", None) session_save = user_id and cache_config.caching + completion_batch = [] + if session_save: conversation_history = await get_conversation_history(session_id=session_id) @@ -208,24 +216,27 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): ), ) else: - completion = await asyncio.gather( + completion_batch = await asyncio.gather( *[ generate_completion( - query=query, - context=context_text, + query=batched_query, + context=batched_context_text, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, response_model=response_model, ) - for query, context_text in zip(query_batch, context_texts) + for batched_query, batched_context_text in zip(query_batch, context_text_batch) ], ) # TODO: Do batch queries for save interaction - if self.save_interaction and context_texts and triplets and completion: + if self.save_interaction and context_text_batch and triplets_batch 
and completion_batch: await self.save_qa( - question=query, answer=completion[0], context=context_texts[0], triplets=triplets[0] + question=query, + answer=completion_batch[0], + context=context_text_batch[0], + triplets=triplets_batch[0], ) if session_save: @@ -236,4 +247,4 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): session_id=session_id, ) - return completion if isinstance(completion, list) else [completion] + return completion_batch if completion_batch else [completion] diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 160419ee9..4ecdc910a 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -3,6 +3,7 @@ import json from typing import Optional, List, Type, Any from pydantic import BaseModel from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.modules.retrieval.utils.validate_queries import validate_queries from cognee.shared.logging_utils import get_logger from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever @@ -203,6 +204,10 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): - List[str]: A list containing the generated answer to the user's query. 
""" + query_validation = validate_queries(query, query_batch) + if not query_validation[0]: + raise ValueError(query_validation[1]) + # Check if session saving is enabled cache_config = CacheConfig() user = session_user.get() @@ -214,24 +219,25 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): if session_save: conversation_history = await get_conversation_history(session_id=session_id) - completion_results = [] + context_batch = context + completion_batch = [] if query_batch and len(query_batch) > 0: - if not context: + if not context_batch: # Having a list is necessary to zip through it - context = [] - for query in query_batch: - context.append(None) + context_batch = [] + for _ in query_batch: + context_batch.append(None) - completion_results = await asyncio.gather( + completion_batch = await asyncio.gather( *[ self._run_cot_completion( query=query, - context=context_el, + context=context, conversation_history=conversation_history, max_iter=max_iter, response_model=response_model, ) - for query, context_el in zip(query_batch, context) + for batched_query, context in zip(query_batch, context_batch) ] ) else: @@ -260,7 +266,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): session_id=session_id, ) - if completion_results: - return [completion for completion, _, _ in completion_results] + if completion_batch: + return [completion for completion, _, _ in completion_batch] return [completion] diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index f740496d0..a1b4c3833 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -4,6 +4,7 @@ from uuid import NAMESPACE_OID, uuid5 from cognee.infrastructure.engine import DataPoint from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.modules.retrieval.utils.validate_queries import validate_queries from 
cognee.tasks.storage import add_data_points from cognee.modules.graph.utils import resolve_edges_to_text from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses @@ -150,15 +151,33 @@ class GraphCompletionRetriever(BaseGraphRetriever): triplets = await self.get_triplets(query, query_batch) - if len(triplets) == 0: - logger.warning("Empty context was provided to the completion") - return [] + if query_batch: + for batched_triplets, batched_query in zip(triplets, query_batch): + if len(batched_triplets) == 0: + logger.warning( + f"Empty context was provided to the completion for the query: {batched_query}" + ) + entity_nodes_batch = [] + for batched_triplets in triplets: + entity_nodes_batch.append(get_entity_nodes_from_triplets(batched_triplets)) - # context = await self.resolve_edges_to_text(triplets) + await asyncio.gather( + *[ + update_node_access_timestamps(batched_entity_nodes) + for batched_entity_nodes in entity_nodes_batch + ] + ) + else: + if len(triplets) == 0: + logger.warning("Empty context was provided to the completion") + return [] - entity_nodes = get_entity_nodes_from_triplets(triplets) + # context = await self.resolve_edges_to_text(triplets) + + entity_nodes = get_entity_nodes_from_triplets(triplets) + + await update_node_access_timestamps(entity_nodes) - await update_node_access_timestamps(entity_nodes) return triplets async def convert_retrieved_objects_to_context(self, triplets: List[Edge]): @@ -190,15 +209,19 @@ class GraphCompletionRetriever(BaseGraphRetriever): - Any: A generated completion based on the query and context provided. 
""" + query_validation = validate_queries(query, query_batch) + if not query_validation[0]: + raise ValueError(query_validation[1]) + triplets = context if triplets is None: triplets = await self.get_context(query, query_batch) context_text = "" - context_texts = "" + context_text_batch = [] if triplets and isinstance(triplets[0], list): - context_texts = await asyncio.gather( + context_text_batch = await asyncio.gather( *[resolve_edges_to_text(triplets_element) for triplets_element in triplets] ) else: @@ -236,7 +259,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): system_prompt=self.system_prompt, response_model=response_model, ) - for query, context in zip(query_batch, context_texts) + for query, context in zip(query_batch, context_text_batch) ], ) else: diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index ce84c1423..c0a0e7fab 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -1,5 +1,6 @@ from typing import List, Optional, Type, Union +from cognee.modules.retrieval.utils.validate_queries import validate_queries from cognee.shared.logging_utils import get_logger, ERROR from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError from cognee.infrastructure.databases.graph import get_graph_engine @@ -146,17 +147,10 @@ async def brute_force_triplet_search( In single-query mode, node_distances and edge_distances are stored as flat lists. In batch mode, they are stored as list-of-lists (one list per query). 
""" - if query is not None and query_batch is not None: - raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.") - if query is None and query_batch is None: - raise ValueError("Must provide either 'query' or 'query_batch'.") - if query is not None and (not query or not isinstance(query, str)): - raise ValueError("The query must be a non-empty string.") - if query_batch is not None: - if not isinstance(query_batch, list) or not query_batch: - raise ValueError("query_batch must be a non-empty list of strings.") - if not all(isinstance(q, str) and q for q in query_batch): - raise ValueError("All items in query_batch must be non-empty strings.") + query_validation = validate_queries(query, query_batch) + if not query_validation[0]: + raise ValueError(query_validation[1]) + if top_k <= 0: raise ValueError("top_k must be a positive integer.") diff --git a/cognee/modules/retrieval/utils/validate_queries.py b/cognee/modules/retrieval/utils/validate_queries.py new file mode 100644 index 000000000..913b0d665 --- /dev/null +++ b/cognee/modules/retrieval/utils/validate_queries.py @@ -0,0 +1,14 @@ +def validate_queries(query, query_batch) -> tuple[bool, str]: + if query is not None and query_batch is not None: + return False, "Cannot provide both 'query' and 'query_batch'; use exactly one." + if query is None and query_batch is None: + return False, "Must provide either 'query' or 'query_batch'." + if query is not None and (not query or not isinstance(query, str)): + return False, "The query must be a non-empty string." + if query_batch is not None: + if not isinstance(query_batch, list) or not query_batch: + return False, "query_batch must be a non-empty list of strings." + if not all(isinstance(q, str) and q for q in query_batch): + return False, "All items in query_batch must be non-empty strings." 
+ + return True, "" From d258b1d7afd753c27b8b2af9948dc3551f398d76 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Mon, 19 Jan 2026 00:03:35 +0100 Subject: [PATCH 07/19] fix: fix tests --- .../graph_completion_context_extension_retriever.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index f7603faba..4c9122f85 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -58,11 +58,11 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): async def get_completion( self, query: Optional[str] = None, - query_batch: Optional[List[str]] = None, context: Optional[List[Edge] | List[List[Edge]]] = None, session_id: Optional[str] = None, context_extension_rounds=4, response_model: Type = str, + query_batch: Optional[List[str]] = None, ) -> List[Any]: """ Extends the context for a given query by retrieving related triplets and generating new @@ -107,6 +107,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): if triplets_batch is None: triplets_batch = await self.get_context(query_batch=query_batch) + if not triplets_batch: + return [] + context_text = "" context_text_batch = await asyncio.gather( *[self.resolve_edges_to_text(triplets) for triplets in triplets_batch] @@ -119,8 +122,11 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): # Final state is stored in the finished_queries_data dict, and we populate it at the start as well. 
original_query_batch = query_batch finished_queries_data = {} - for i, query in enumerate(query_batch): - finished_queries_data[query] = (triplets_batch[i], context_text_batch[i]) + for i, batched_query in enumerate(query_batch): + if not triplets_batch[i]: + query_batch[i] = "" + else: + finished_queries_data[batched_query] = (triplets_batch[i], context_text_batch[i]) while round_idx <= context_extension_rounds: logger.info( From 8e0d1124399fb6f6592c279c564700b4355732fd Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Mon, 19 Jan 2026 10:56:50 +0100 Subject: [PATCH 08/19] test: tiny test fix --- .../graph_completion_retriever_context_extension_test.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py index 8d6310214..7ce0d3a57 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -611,7 +611,12 @@ async def test_get_completion_batch_queries_context_extension_stops_early(mock_e retriever = GraphCompletionContextExtensionRetriever() with ( - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), + patch.object( + retriever, + "get_context", + new_callable=AsyncMock, + return_value=[[mock_edge], [mock_edge]], + ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", return_value="Resolved context", From 8c7b309199630ca264c2d42ce58162a00981566a Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Mon, 19 Jan 2026 23:58:14 +0100 Subject: [PATCH 09/19] fix: fix context extension and cot retrievers --- ..._completion_context_extension_retriever.py | 118 ++++---- .../graph_completion_cot_retriever.py | 255 ++++++++++++------ 
.../retrieval/graph_completion_retriever.py | 21 +- .../utils/brute_force_triplet_search.py | 6 +- .../test_graph_completion_retriever_cot.py | 1 - .../graph_completion_retriever_cot_test.py | 51 ++-- 6 files changed, 286 insertions(+), 166 deletions(-) diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index 4c9122f85..89e8b80a2 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -15,6 +15,19 @@ from cognee.infrastructure.databases.cache.config import CacheConfig logger = get_logger() +class QueryState: + """ + Helper class containing all necessary information about the query state: + the triplets and context associated with it, and also a check whether + it has fully extended the context. + """ + + def __init__(self, triplets: List[Edge], context_text: str, finished_extending_context: bool): + self.triplets = triplets + self.context_text = context_text + self.finished_extending_context = finished_extending_context + + class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): """ Handles graph context completion for question answering tasks, extending context based @@ -91,10 +104,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): - List[str]: A list containing the generated answer based on the query and the extended context. 
""" - # TODO: This may be unnecessary in this retriever, will check later - query_validation = validate_queries(query, query_batch) - if not query_validation[0]: - raise ValueError(query_validation[1]) + is_query_valid, msg = validate_queries(query, query_batch) + if not is_query_valid: + raise ValueError(msg) triplets_batch = context @@ -117,70 +129,87 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): round_idx = 1 - # We will be removing queries, and their associated triplets and context, as we go - # through iterations, so we need to save their final states for the final generation. - # Final state is stored in the finished_queries_data dict, and we populate it at the start as well. - original_query_batch = query_batch - finished_queries_data = {} - for i, batched_query in enumerate(query_batch): - if not triplets_batch[i]: - query_batch[i] = "" - else: - finished_queries_data[batched_query] = (triplets_batch[i], context_text_batch[i]) + # We store queries as keys and their associated states in this dict. + # The state is a 3-item object QueryState, which holds triplets, context text, + # and a boolean marking whether we should continue extending the context for that query. + finished_queries_states = {} + + for batched_query, batched_triplets, batched_context_text in zip( + query_batch, triplets_batch, context_text_batch + ): + # Populating the dict at the start with initial information. + finished_queries_states[batched_query] = QueryState( + batched_triplets, batched_context_text, False + ) while round_idx <= context_extension_rounds: logger.info( f"Context extension: round {round_idx} - generating next graph locational query." 
) - # Filter out the queries that cannot be extended further, and their associated contexts - query_batch = [query for query in query_batch if query] - triplets_batch = [triplets for triplets in triplets_batch if triplets] - context_text_batch = [ - context_text for context_text in context_text_batch if context_text - ] - if len(query_batch) == 0: + if all( + batched_query_state.finished_extending_context + for batched_query_state in finished_queries_states.values() + ): + # We stop early only if all queries in the batch have reached their final state logger.info( f"Context extension: round {round_idx} – no new triplets found; stopping early." ) break - prev_sizes = [len(triplets) for triplets in triplets_batch] + prev_sizes = [ + len(batched_query_state.triplets) + for batched_query_state in finished_queries_states.values() + if not batched_query_state.finished_extending_context + ] completions = await asyncio.gather( *[ generate_completion( - query=query, - context=context, + query=batched_query, + context=batched_query_state.context_text, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, ) - for query, context in zip(query_batch, context_text_batch) + for batched_query, batched_query_state in finished_queries_states.items() + if not batched_query_state.finished_extending_context ], ) # Get new triplets, and merge them with existing ones, filtering out duplicates new_triplets_batch = await self.get_context(query_batch=completions) - for i, (triplets, new_triplets) in enumerate(zip(triplets_batch, new_triplets_batch)): - triplets += new_triplets - triplets_batch[i] = list(dict.fromkeys(triplets)) + for batched_query, batched_new_triplets in zip(query_batch, new_triplets_batch): + finished_queries_states[batched_query].triplets = list( + dict.fromkeys( + finished_queries_states[batched_query].triplets + batched_new_triplets + ) + ) + # Resolve new triplets to text context_text_batch = await 
asyncio.gather( - *[self.resolve_edges_to_text(triplets) for triplets in triplets_batch] + *[ + self.resolve_edges_to_text(batched_query_state.triplets) + for batched_query_state in finished_queries_states.values() + if not batched_query_state.finished_extending_context + ] ) - new_sizes = [len(triplets) for triplets in triplets_batch] + # Update context_texts in query states + for batched_query, batched_context_text in zip(query_batch, context_text_batch): + if not finished_queries_states[batched_query].finished_extending_context: + finished_queries_states[batched_query].context_text = batched_context_text - for i, (batched_query, prev_size, new_size, triplets, context_text) in enumerate( - zip(query_batch, prev_sizes, new_sizes, triplets_batch, context_text_batch) - ): - finished_queries_data[query] = (triplets, context_text) + new_sizes = [ + len(batched_query_state.triplets) + for batched_query_state in finished_queries_states.values() + if not batched_query_state.finished_extending_context + ] + + for batched_query, prev_size, new_size in zip(query_batch, prev_sizes, new_sizes): + # Mark done queries accordingly if prev_size == new_size: - # In this case, we can stop trying to extend the context of this query - query_batch[i] = "" - triplets_batch[i] = [] - context_text_batch[i] = "" + finished_queries_states[batched_query].finished_extending_context = True logger.info( f"Context extension: round {round_idx} - " @@ -189,15 +218,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): round_idx += 1 - # Reset variables for the final generations. They contain the final state - # of triplets and contexts for each query, after all extension iterations. 
- query_batch = original_query_batch - triplets_batch = [] - context_text_batch = [] - for batched_query in query_batch: - triplets_batch.append(finished_queries_data[batched_query][0]) - context_text_batch.append(finished_queries_data[batched_query][1]) - # Check if we need to generate context summary for caching cache_config = CacheConfig() user = session_user.get() @@ -226,13 +246,13 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): *[ generate_completion( query=batched_query, - context=batched_context_text, + context=batched_query_state.context_text, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, response_model=response_model, ) - for batched_query, batched_context_text in zip(query_batch, context_text_batch) + for batched_query, batched_query_state in finished_queries_states.items() ], ) diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 4ecdc910a..58c9bad15 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -23,6 +23,27 @@ from cognee.infrastructure.databases.cache.config import CacheConfig logger = get_logger() +class QueryState: + """ + Helper class containing all necessary information about the query state. + Used to keep track of important information in a more readable way, and + enable as many parallel calls to llms as possible. 
+ """ + + completion: str = "" + triplets: List[Edge] = [] + context_text: str = "" + + answer_text: str = "" + valid_user_prompt: str = "" + valid_system_prompt: str = "" + reasoning: str = "" + + followup_question: str = "" + followup_prompt: str = "" + followup_system: str = "" + + def _as_answer_text(completion: Any) -> str: """Convert completion to human-readable text for validation and follow-up prompts.""" if isinstance(completion, str): @@ -87,12 +108,13 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): async def _run_cot_completion( self, - query: str, - context: Optional[List[Edge]] = None, + query: Optional[str] = None, + query_batch: Optional[List[str]] = None, + context: Optional[List[Edge] | List[List[Edge]]] = None, conversation_history: str = "", max_iter: int = 4, response_model: Type = str, - ) -> tuple[Any, str, List[Edge]]: + ) -> tuple[List[Any], List[str], List[List[Edge]]]: """ Run chain-of-thought completion with optional structured output. @@ -110,64 +132,158 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): - context_text: The resolved context text - triplets: The list of triplets used """ - followup_question = "" - triplets = [] - completion = "" + followup_question_batch = [] + completion_batch = [] + context_text_batch = [] + + if query: + # Treat a single query as a batch of queries, mainly avoiding massive code duplication + query_batch = [query] + if context: + context = [context] + + triplets_batch = context + + # dict containing query -> QueryState key-value pairs + # For every query, we save necessary data so we can execute requests in parallel + query_state_tracker = {} + for batched_query in query_batch: + query_state_tracker[batched_query] = QueryState() for round_idx in range(max_iter + 1): if round_idx == 0: if context is None: - triplets = await self.get_context(query) - context_text = await self.resolve_edges_to_text(triplets) + # Get context, resolve to text, and store info in the query state + 
triplets_batch = await self.get_context(query_batch=query_batch) + context_text_batch = await asyncio.gather( + *[ + self.resolve_edges_to_text(batched_triplets) + for batched_triplets in triplets_batch + ] + ) + for batched_query, batched_triplets, batched_context_text in zip( + query_batch, triplets_batch, context_text_batch + ): + query_state_tracker[batched_query].triplets = batched_triplets + query_state_tracker[batched_query].context_text = batched_context_text else: - context_text = await self.resolve_edges_to_text(context) + # In this case just resolve to text and save to the query state + context_text_batch = await asyncio.gather( + *[ + self.resolve_edges_to_text(batched_context) + for batched_context in context + ] + ) + for batched_query, batched_triplets, batched_context_text in zip( + query_batch, context, context_text_batch + ): + query_state_tracker[batched_query].triplets = batched_triplets + query_state_tracker[batched_query].context_text = batched_context_text else: - triplets += await self.get_context(followup_question) - context_text = await self.resolve_edges_to_text(list(set(triplets))) + # Find new triplets, and update existing query states + followup_triplets_batch = await self.get_context( + query_batch=followup_question_batch + ) + for batched_query, batched_followup_triplets in zip( + query_batch, followup_triplets_batch + ): + query_state_tracker[batched_query].triplets = list( + dict.fromkeys( + query_state_tracker[batched_query].triplets + batched_followup_triplets + ) + ) - completion = await generate_completion( - query=query, - context=context_text, - user_prompt_path=self.user_prompt_path, - system_prompt_path=self.system_prompt_path, - system_prompt=self.system_prompt, - conversation_history=conversation_history if conversation_history else None, - response_model=response_model, + context_text_batch = await asyncio.gather( + *[ + self.resolve_edges_to_text(batched_query_state.triplets) + for batched_query_state in 
query_state_tracker.values() + ] + ) + + for batched_query, batched_context_text in zip(query_batch, context_text_batch): + query_state_tracker[batched_query].context_text = batched_context_text + + completion_batch = await asyncio.gather( + *[ + generate_completion( + query=batched_query, + context=batched_query_state.context_text, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, + conversation_history=conversation_history if conversation_history else None, + response_model=response_model, + ) + for batched_query, batched_query_state in query_state_tracker.items() + ] ) - logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}") + for batched_query, batched_completion in zip(query_batch, completion_batch): + query_state_tracker[batched_query].completion = batched_completion + + logger.info(f"Chain-of-thought: round {round_idx} - answers: {completion_batch}") if round_idx < max_iter: - answer_text = _as_answer_text(completion) - valid_args = {"query": query, "answer": answer_text, "context": context_text} - valid_user_prompt = render_prompt( - filename=self.validation_user_prompt_path, context=valid_args - ) - valid_system_prompt = read_query_prompt( - prompt_file_name=self.validation_system_prompt_path + for batched_query, batched_query_state in query_state_tracker.items(): + batched_query_state.answer_text = _as_answer_text( + batched_query_state.completion + ) + valid_args = { + "query": batched_query, + "answer": batched_query_state.answer_text, + "context": batched_query_state.context_text, + } + batched_query_state.valid_user_prompt = render_prompt( + filename=self.validation_user_prompt_path, + context=valid_args, + ) + batched_query_state.valid_system_prompt = read_query_prompt( + prompt_file_name=self.validation_system_prompt_path + ) + + reasoning_batch = await asyncio.gather( + *[ + LLMGateway.acreate_structured_output( + 
text_input=batched_query_state.valid_user_prompt, + system_prompt=batched_query_state.valid_system_prompt, + response_model=str, + ) + for batched_query_state in query_state_tracker.values() + ] ) - reasoning = await LLMGateway.acreate_structured_output( - text_input=valid_user_prompt, - system_prompt=valid_system_prompt, - response_model=str, - ) - followup_args = {"query": query, "answer": answer_text, "reasoning": reasoning} - followup_prompt = render_prompt( - filename=self.followup_user_prompt_path, context=followup_args - ) - followup_system = read_query_prompt( - prompt_file_name=self.followup_system_prompt_path - ) + for batched_query, batched_reasoning in zip(query_batch, reasoning_batch): + query_state_tracker[batched_query].reasoning = batched_reasoning - followup_question = await LLMGateway.acreate_structured_output( - text_input=followup_prompt, system_prompt=followup_system, response_model=str + for batched_query, batched_query_state in query_state_tracker.items(): + followup_args = { + "query": query, + "answer": batched_query_state.answer_text, + "reasoning": batched_query_state.reasoning, + } + batched_query_state.followup_prompt = render_prompt( + filename=self.followup_user_prompt_path, + context=followup_args, + ) + batched_query_state.followup_system = read_query_prompt( + prompt_file_name=self.followup_system_prompt_path + ) + + followup_question_batch = await asyncio.gather( + *[ + LLMGateway.acreate_structured_output( + text_input=batched_query_state.followup_prompt, + system_prompt=batched_query_state.followup_system, + response_model=str, + ) + for batched_query_state in query_state_tracker.values() + ] ) logger.info( - f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}" + f"Chain-of-thought: round {round_idx} - follow-up questions: {followup_question_batch}" ) - return completion, context_text, triplets + return completion_batch, context_text_batch, triplets_batch async def get_completion( self, @@ -204,9 
+320,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): - List[str]: A list containing the generated answer to the user's query. """ - query_validation = validate_queries(query, query_batch) - if not query_validation[0]: - raise ValueError(query_validation[1]) + is_query_valid, msg = validate_queries(query, query_batch) + if not is_query_valid: + raise ValueError(msg) # Check if session saving is enabled cache_config = CacheConfig() @@ -219,40 +335,22 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): if session_save: conversation_history = await get_conversation_history(session_id=session_id) - context_batch = context - completion_batch = [] - if query_batch and len(query_batch) > 0: - if not context_batch: - # Having a list is necessary to zip through it - context_batch = [] - for _ in query_batch: - context_batch.append(None) - - completion_batch = await asyncio.gather( - *[ - self._run_cot_completion( - query=query, - context=context, - conversation_history=conversation_history, - max_iter=max_iter, - response_model=response_model, - ) - for batched_query, context in zip(query_batch, context_batch) - ] - ) - else: - completion, context_text, triplets = await self._run_cot_completion( - query=query, - context=context, - conversation_history=conversation_history, - max_iter=max_iter, - response_model=response_model, - ) + completion, context_text, triplets = await self._run_cot_completion( + query=query, + query_batch=query_batch, + context=context, + conversation_history=conversation_history, + max_iter=max_iter, + response_model=response_model, + ) # TODO: Handle save interaction for batch queries if self.save_interaction and context and triplets and completion: await self.save_qa( - question=query, answer=str(completion), context=context_text, triplets=triplets + question=query, + answer=str(completion[0]), + context=context_text[0], + triplets=triplets[0], ) # TODO: Handle session save interaction for batch queries @@ -266,7 +364,4 
@@ class GraphCompletionCotRetriever(GraphCompletionRetriever): session_id=session_id, ) - if completion_batch: - return [completion for completion, _, _ in completion_batch] - - return [completion] + return completion diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index a1b4c3833..711c7ab40 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -158,15 +158,18 @@ class GraphCompletionRetriever(BaseGraphRetriever): f"Empty context was provided to the completion for the query: {batched_query}" ) entity_nodes_batch = [] + for batched_triplets in triplets: entity_nodes_batch.append(get_entity_nodes_from_triplets(batched_triplets)) - await asyncio.gather( - *[ - update_node_access_timestamps(batched_entity_nodes) - for batched_entity_nodes in entity_nodes_batch - ] - ) + # Remove duplicates and update node access, if it is enabled + for batched_entity_nodes in entity_nodes_batch: + # from itertools import chain + # + # flattened_entity_nodes = list(chain.from_iterable(entity_nodes_batch)) + # entity_nodes = list(set(flattened_entity_nodes)) + + await update_node_access_timestamps(batched_entity_nodes) else: if len(triplets) == 0: logger.warning("Empty context was provided to the completion") @@ -209,9 +212,9 @@ class GraphCompletionRetriever(BaseGraphRetriever): - Any: A generated completion based on the query and context provided. 
""" - query_validation = validate_queries(query, query_batch) - if not query_validation[0]: - raise ValueError(query_validation[1]) + is_query_valid, msg = validate_queries(query, query_batch) + if not is_query_valid: + raise ValueError(msg) triplets = context diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index c0a0e7fab..764f8d77d 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -147,9 +147,9 @@ async def brute_force_triplet_search( In single-query mode, node_distances and edge_distances are stored as flat lists. In batch mode, they are stored as list-of-lists (one list per query). """ - query_validation = validate_queries(query, query_batch) - if not query_validation[0]: - raise ValueError(query_validation[1]) + is_query_valid, msg = validate_queries(query, query_batch) + if not is_query_valid: + raise ValueError(msg) if top_k <= 0: raise ValueError("top_k must be a positive integer.") diff --git a/cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py b/cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py index 0db035e03..54ca09b8a 100644 --- a/cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +++ b/cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py @@ -1,4 +1,3 @@ -import os import pytest import pathlib import pytest_asyncio diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index 5ae901594..aa71bd9c7 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -1,3 +1,5 @@ +import os + import pytest from unittest.mock import AsyncMock, patch, MagicMock from uuid import 
UUID @@ -80,7 +82,7 @@ async def test_run_cot_completion_round_zero_with_context(mock_edge): "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", return_value="Generated answer", ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", return_value="Generated answer", @@ -106,8 +108,8 @@ async def test_run_cot_completion_round_zero_with_context(mock_edge): max_iter=1, ) - assert completion == "Generated answer" - assert context_text == "Resolved context" + assert completion == ["Generated answer"] + assert context_text == ["Resolved context"] assert len(triplets) >= 1 @@ -126,7 +128,7 @@ async def test_run_cot_completion_round_zero_without_context(mock_edge): ), patch( "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], + return_value=[[mock_edge]], ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", @@ -143,8 +145,8 @@ async def test_run_cot_completion_round_zero_without_context(mock_edge): max_iter=1, ) - assert completion == "Generated answer" - assert context_text == "Resolved context" + assert completion == ["Generated answer"] + assert context_text == ["Resolved context"] assert len(triplets) >= 1 @@ -168,7 +170,7 @@ async def test_run_cot_completion_multiple_rounds(mock_edge): retriever, "get_context", new_callable=AsyncMock, - side_effect=[[mock_edge], [mock_edge2]], + side_effect=[[[mock_edge]], [[mock_edge2]]], ), patch( "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", @@ -200,8 +202,8 @@ async def test_run_cot_completion_multiple_rounds(mock_edge): max_iter=2, ) - assert completion == "Generated answer" - assert context_text == "Resolved context" + assert completion == ["Generated answer"] + assert 
context_text == ["Resolved context"] assert len(triplets) >= 1 @@ -227,7 +229,7 @@ async def test_run_cot_completion_with_conversation_history(mock_edge): max_iter=1, ) - assert completion == "Generated answer" + assert completion == ["Generated answer"] call_kwargs = mock_generate.call_args[1] assert call_kwargs.get("conversation_history") == "Previous conversation" @@ -259,8 +261,9 @@ async def test_run_cot_completion_with_response_model(mock_edge): max_iter=1, ) - assert isinstance(completion, TestModel) - assert completion.answer == "Test answer" + assert isinstance(completion, list) + assert isinstance(completion[0], TestModel) + assert completion[0].answer == "Test answer" @pytest.mark.asyncio @@ -285,7 +288,7 @@ async def test_run_cot_completion_empty_conversation_history(mock_edge): max_iter=1, ) - assert completion == "Generated answer" + assert completion == ["Generated answer"] # Verify conversation_history was passed as None when empty call_kwargs = mock_generate.call_args[1] assert call_kwargs.get("conversation_history") is None @@ -306,7 +309,7 @@ async def test_get_completion_without_context(mock_edge): ), patch( "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], + return_value=[[mock_edge]], ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", @@ -316,7 +319,7 @@ async def test_get_completion_without_context(mock_edge): "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", return_value="Generated answer", ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", return_value="Generated answer", @@ -372,7 +375,7 @@ async def test_get_completion_with_provided_context(mock_edge): mock_config.caching = False 
mock_cache_config.return_value = mock_config - completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1) + completion = await retriever.get_completion("test query", context=[[mock_edge]], max_iter=1) assert isinstance(completion, list) assert len(completion) == 1 @@ -397,7 +400,7 @@ async def test_get_completion_with_session(mock_edge): ), patch( "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], + return_value=[[mock_edge]], ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", @@ -463,7 +466,7 @@ async def test_get_completion_with_save_interaction(mock_edge): "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", return_value="Generated answer", ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", return_value="Generated answer", @@ -528,7 +531,7 @@ async def test_get_completion_with_response_model(mock_edge): ), patch( "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], + return_value=[[mock_edge]], ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", @@ -570,7 +573,7 @@ async def test_get_completion_with_session_no_user_id(mock_edge): ), patch( "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], + return_value=[[mock_edge]], ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", @@ -612,7 +615,7 @@ async def test_get_completion_with_save_interaction_no_context(mock_edge): "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", return_value="Generated answer", ), - patch.object(retriever, 
"get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", return_value="Generated answer", @@ -703,7 +706,7 @@ async def test_get_completion_batch_queries_with_context(mock_edge): "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", return_value="Generated answer", ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", return_value="Generated answer", @@ -749,7 +752,7 @@ async def test_get_completion_batch_queries_without_context(mock_edge): ), patch( "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], + return_value=[[mock_edge]], ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", @@ -790,7 +793,7 @@ async def test_get_completion_batch_queries_with_response_model(mock_edge): ), patch( "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], + return_value=[[mock_edge]], ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", From 58071288ecf9b0c5fd69abd5d85128dc484410fc Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Tue, 20 Jan 2026 10:05:06 +0100 Subject: [PATCH 10/19] fix: tiny fix in cot retreiver --- cognee/modules/retrieval/graph_completion_cot_retriever.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 58c9bad15..0a4dfb9b6 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ 
b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -257,7 +257,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): for batched_query, batched_query_state in query_state_tracker.items(): followup_args = { - "query": query, + "query": batched_query, "answer": batched_query_state.answer_text, "reasoning": batched_query_state.reasoning, } From 952618d5616079d1e0fa45f4e36e57dfe061fc1a Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Tue, 20 Jan 2026 10:18:56 +0100 Subject: [PATCH 11/19] fix: small class definition change in cot retriever --- .../graph_completion_cot_retriever.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 0a4dfb9b6..a985900e8 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -30,18 +30,19 @@ class QueryState: enable as many parallel calls to llms as possible. 
""" - completion: str = "" - triplets: List[Edge] = [] - context_text: str = "" + def __init__(self): + self.completion: str = "" + self.triplets: List[Edge] = [] + self.context_text: str = "" - answer_text: str = "" - valid_user_prompt: str = "" - valid_system_prompt: str = "" - reasoning: str = "" + self.answer_text: str = "" + self.valid_user_prompt: str = "" + self.valid_system_prompt: str = "" + self.reasoning: str = "" - followup_question: str = "" - followup_prompt: str = "" - followup_system: str = "" + self.followup_question: str = "" + self.followup_prompt: str = "" + self.followup_system: str = "" def _as_answer_text(completion: Any) -> str: From 05b5add48025bb8373e72f2e08b7a3acd62c6b54 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Tue, 20 Jan 2026 10:59:39 +0100 Subject: [PATCH 12/19] fix: cot retriever fix --- .../modules/retrieval/graph_completion_cot_retriever.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index a985900e8..5f61d26d0 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -182,12 +182,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): query_state_tracker[batched_query].context_text = batched_context_text else: # Find new triplets, and update existing query states - followup_triplets_batch = await self.get_context( - query_batch=followup_question_batch - ) - for batched_query, batched_followup_triplets in zip( - query_batch, followup_triplets_batch - ): + triplets_batch = await self.get_context(query_batch=followup_question_batch) + + for batched_query, batched_followup_triplets in zip(query_batch, triplets_batch): query_state_tracker[batched_query].triplets = list( dict.fromkeys( query_state_tracker[batched_query].triplets + batched_followup_triplets From 
77fd18a60c0e07579652defba096b8002c6b0df6 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Tue, 20 Jan 2026 11:36:01 +0100 Subject: [PATCH 13/19] fix: raise error for query batch plus sessions and save interaction --- ..._completion_context_extension_retriever.py | 20 ++++++++----- .../graph_completion_cot_retriever.py | 12 ++++++-- .../retrieval/graph_completion_retriever.py | 30 ++++++++++++------- 3 files changed, 40 insertions(+), 22 deletions(-) diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index 89e8b80a2..d005759c9 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -104,6 +104,18 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): - List[str]: A list containing the generated answer based on the query and the extended context. 
""" + + # Check if we need to generate context summary for caching + cache_config = CacheConfig() + user = session_user.get() + user_id = getattr(user, "id", None) + session_save = user_id and cache_config.caching + + if query_batch and session_save: + raise ValueError("You cannot use batch queries with session saving currently.") + if query_batch and self.save_interaction: + raise ValueError("Cannot use batch queries with interaction saving currently.") + is_query_valid, msg = validate_queries(query, query_batch) if not is_query_valid: raise ValueError(msg) @@ -160,7 +172,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): prev_sizes = [ len(batched_query_state.triplets) for batched_query_state in finished_queries_states.values() - if not batched_query_state.finished_extending_context ] completions = await asyncio.gather( @@ -203,7 +214,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): new_sizes = [ len(batched_query_state.triplets) for batched_query_state in finished_queries_states.values() - if not batched_query_state.finished_extending_context ] for batched_query, prev_size, new_size in zip(query_batch, prev_sizes, new_sizes): @@ -218,12 +228,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): round_idx += 1 - # Check if we need to generate context summary for caching - cache_config = CacheConfig() - user = session_user.get() - user_id = getattr(user, "id", None) - session_save = user_id and cache_config.caching - completion_batch = [] if session_save: diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 5f61d26d0..d42dbda0f 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -318,9 +318,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): - List[str]: A list containing the generated answer to 
the user's query. """ - is_query_valid, msg = validate_queries(query, query_batch) - if not is_query_valid: - raise ValueError(msg) # Check if session saving is enabled cache_config = CacheConfig() @@ -328,6 +325,15 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): user_id = getattr(user, "id", None) session_save = user_id and cache_config.caching + if query_batch and session_save: + raise ValueError("You cannot use batch queries with session saving currently.") + if query_batch and self.save_interaction: + raise ValueError("Cannot use batch queries with interaction saving currently.") + + is_query_valid, msg = validate_queries(query, query_batch) + if not is_query_valid: + raise ValueError(msg) + # Load conversation history if enabled conversation_history = "" if session_save: diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index 711c7ab40..d64267c10 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -163,13 +163,16 @@ class GraphCompletionRetriever(BaseGraphRetriever): entity_nodes_batch.append(get_entity_nodes_from_triplets(batched_triplets)) # Remove duplicates and update node access, if it is enabled - for batched_entity_nodes in entity_nodes_batch: - # from itertools import chain - # - # flattened_entity_nodes = list(chain.from_iterable(entity_nodes_batch)) - # entity_nodes = list(set(flattened_entity_nodes)) + import os - await update_node_access_timestamps(batched_entity_nodes) + if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() == "true": + for batched_entity_nodes in entity_nodes_batch: + # from itertools import chain + # + # flattened_entity_nodes = list(chain.from_iterable(entity_nodes_batch)) + # entity_nodes = list(set(flattened_entity_nodes)) + + await update_node_access_timestamps(batched_entity_nodes) else: if len(triplets) == 0: logger.warning("Empty context was provided to 
the completion") @@ -212,6 +215,16 @@ class GraphCompletionRetriever(BaseGraphRetriever): - Any: A generated completion based on the query and context provided. """ + cache_config = CacheConfig() + user = session_user.get() + user_id = getattr(user, "id", None) + session_save = user_id and cache_config.caching + + if query_batch and session_save: + raise ValueError("You cannot use batch queries with session saving currently.") + if query_batch and self.save_interaction: + raise ValueError("Cannot use batch queries with interaction saving currently.") + is_query_valid, msg = validate_queries(query, query_batch) if not is_query_valid: raise ValueError(msg) @@ -230,11 +243,6 @@ class GraphCompletionRetriever(BaseGraphRetriever): else: context_text = await resolve_edges_to_text(triplets) - cache_config = CacheConfig() - user = session_user.get() - user_id = getattr(user, "id", None) - session_save = user_id and cache_config.caching - if session_save: conversation_history = await get_conversation_history(session_id=session_id) From aa2712ddb889a33ea87bf959643144d71c39b220 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Tue, 20 Jan 2026 11:36:58 +0100 Subject: [PATCH 14/19] chore: remove unnecessary comments --- cognee/modules/retrieval/graph_completion_retriever.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index d64267c10..eae23bb0d 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -167,11 +167,6 @@ class GraphCompletionRetriever(BaseGraphRetriever): if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() == "true": for batched_entity_nodes in entity_nodes_batch: - # from itertools import chain - # - # flattened_entity_nodes = list(chain.from_iterable(entity_nodes_batch)) - # entity_nodes = list(set(flattened_entity_nodes)) - await 
update_node_access_timestamps(batched_entity_nodes) else: if len(triplets) == 0: From 8eee3990f7a41215e8a2aa7123a5c0f0269553f9 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Tue, 20 Jan 2026 11:55:41 +0100 Subject: [PATCH 15/19] fix: track followup question in cot retriever --- cognee/modules/retrieval/graph_completion_cot_retriever.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index d42dbda0f..be28cc97c 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -277,6 +277,12 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): for batched_query_state in query_state_tracker.values() ] ) + + for batched_query, batched_followup_question in zip( + query_batch, followup_question_batch + ): + query_state_tracker[batched_query].followup_question = batched_followup_question + logger.info( f"Chain-of-thought: round {round_idx} - follow-up questions: {followup_question_batch}" ) From a5494513d76420856d2fab2a571ec379f3c21288 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Tue, 20 Jan 2026 15:11:58 +0100 Subject: [PATCH 16/19] fix: retrievers work with duplicates, added tests also, plus PR change requests done --- .../retrieval/exceptions/exceptions.py | 10 +++ ..._completion_context_extension_retriever.py | 49 +++++++++------ .../graph_completion_cot_retriever.py | 61 +++++++++---------- .../retrieval/graph_completion_retriever.py | 11 +++- cognee/modules/retrieval/utils/query_state.py | 34 +++++++++++ ...letion_retriever_context_extension_test.py | 57 +++++++++++++++++ .../graph_completion_retriever_cot_test.py | 37 ++++++++++- .../graph_completion_retriever_test.py | 40 ++++++++++++ 8 files changed, 244 insertions(+), 55 deletions(-) create mode 100644 cognee/modules/retrieval/utils/query_state.py diff --git 
a/cognee/modules/retrieval/exceptions/exceptions.py b/cognee/modules/retrieval/exceptions/exceptions.py index 3e934909b..0efaf7351 100644 --- a/cognee/modules/retrieval/exceptions/exceptions.py +++ b/cognee/modules/retrieval/exceptions/exceptions.py @@ -40,3 +40,13 @@ class CollectionDistancesNotFoundError(CogneeValidationError): status_code: int = status.HTTP_404_NOT_FOUND, ): super().__init__(message, name, status_code) + + +class QueryValidationError(CogneeValidationError): + def __init__( + self, + message: str = "Queries not supplied in the correct format.", + name: str = "QueryValidationError", + status_code: int = status.HTTP_422_UNPROCESSABLE_CONTENT, + ): + super().__init__(message, name, status_code) diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index d005759c9..ef7708dce 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -1,6 +1,8 @@ import asyncio from typing import Optional, List, Type, Any from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.modules.retrieval.exceptions.exceptions import QueryValidationError +from cognee.modules.retrieval.utils.query_state import QueryState from cognee.modules.retrieval.utils.validate_queries import validate_queries from cognee.shared.logging_utils import get_logger from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever @@ -15,19 +17,6 @@ from cognee.infrastructure.databases.cache.config import CacheConfig logger = get_logger() -class QueryState: - """ - Helper class containing all necessary information about the query state: - the triplets and context associated with it, and also a check whether - it has fully extended the context. 
- """ - - def __init__(self, triplets: List[Edge], context_text: str, finished_extending_context: bool): - self.triplets = triplets - self.context_text = context_text - self.finished_extending_context = finished_extending_context - - class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): """ Handles graph context completion for question answering tasks, extending context based @@ -112,13 +101,17 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): session_save = user_id and cache_config.caching if query_batch and session_save: - raise ValueError("You cannot use batch queries with session saving currently.") + raise QueryValidationError( + message="You cannot use batch queries with session saving currently." + ) if query_batch and self.save_interaction: - raise ValueError("Cannot use batch queries with interaction saving currently.") + raise QueryValidationError( + message="Cannot use batch queries with interaction saving currently." + ) is_query_valid, msg = validate_queries(query, query_batch) if not is_query_valid: - raise ValueError(msg) + raise QueryValidationError(message=msg) triplets_batch = context @@ -190,7 +183,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): # Get new triplets, and merge them with existing ones, filtering out duplicates new_triplets_batch = await self.get_context(query_batch=completions) - for batched_query, batched_new_triplets in zip(query_batch, new_triplets_batch): + for batched_query, batched_new_triplets in zip( + finished_queries_states.keys(), new_triplets_batch + ): finished_queries_states[batched_query].triplets = list( dict.fromkeys( finished_queries_states[batched_query].triplets + batched_new_triplets @@ -207,7 +202,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): ) # Update context_texts in query states - for batched_query, batched_context_text in zip(query_batch, context_text_batch): + for batched_query, 
batched_context_text in zip( + finished_queries_states.keys(), context_text_batch + ): if not finished_queries_states[batched_query].finished_extending_context: finished_queries_states[batched_query].context_text = batched_context_text @@ -216,7 +213,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): for batched_query_state in finished_queries_states.values() ] - for batched_query, prev_size, new_size in zip(query_batch, prev_sizes, new_sizes): + for batched_query, prev_size, new_size in zip( + finished_queries_states.keys(), prev_sizes, new_sizes + ): # Mark done queries accordingly if prev_size == new_size: finished_queries_states[batched_query].finished_extending_context = True @@ -229,6 +228,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): round_idx += 1 completion_batch = [] + result_completion_batch = [] if session_save: conversation_history = await get_conversation_history(session_id=session_id) @@ -260,6 +260,15 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): ], ) + # Make sure answers are returned for duplicate queries, in the order they were asked. 
+ for batched_query, batched_completion in zip( + finished_queries_states.keys(), completion_batch + ): + finished_queries_states[batched_query].completion = batched_completion + + for batched_query in query_batch: + result_completion_batch.append(finished_queries_states[batched_query].completion) + # TODO: Do batch queries for save interaction if self.save_interaction and context_text_batch and triplets_batch and completion_batch: await self.save_qa( @@ -277,4 +286,4 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): session_id=session_id, ) - return completion_batch if completion_batch else [completion] + return result_completion_batch if result_completion_batch else [completion] diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index be28cc97c..5675c8c70 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -3,6 +3,8 @@ import json from typing import Optional, List, Type, Any from pydantic import BaseModel from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.modules.retrieval.exceptions.exceptions import QueryValidationError +from cognee.modules.retrieval.utils.query_state import QueryState from cognee.modules.retrieval.utils.validate_queries import validate_queries from cognee.shared.logging_utils import get_logger @@ -23,28 +25,6 @@ from cognee.infrastructure.databases.cache.config import CacheConfig logger = get_logger() -class QueryState: - """ - Helper class containing all necessary information about the query state. - Used to keep track of important information in a more readable way, and - enable as many parallel calls to llms as possible. 
- """ - - def __init__(self): - self.completion: str = "" - self.triplets: List[Edge] = [] - self.context_text: str = "" - - self.answer_text: str = "" - self.valid_user_prompt: str = "" - self.valid_system_prompt: str = "" - self.reasoning: str = "" - - self.followup_question: str = "" - self.followup_prompt: str = "" - self.followup_system: str = "" - - def _as_answer_text(completion: Any) -> str: """Convert completion to human-readable text for validation and follow-up prompts.""" if isinstance(completion, str): @@ -155,7 +135,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): if round_idx == 0: if context is None: # Get context, resolve to text, and store info in the query state - triplets_batch = await self.get_context(query_batch=query_batch) + triplets_batch = await self.get_context( + query_batch=list(query_state_tracker.keys()) + ) context_text_batch = await asyncio.gather( *[ self.resolve_edges_to_text(batched_triplets) @@ -163,7 +145,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): ] ) for batched_query, batched_triplets, batched_context_text in zip( - query_batch, triplets_batch, context_text_batch + query_state_tracker.keys(), triplets_batch, context_text_batch ): query_state_tracker[batched_query].triplets = batched_triplets query_state_tracker[batched_query].context_text = batched_context_text @@ -176,7 +158,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): ] ) for batched_query, batched_triplets, batched_context_text in zip( - query_batch, context, context_text_batch + query_state_tracker.keys(), context, context_text_batch ): query_state_tracker[batched_query].triplets = batched_triplets query_state_tracker[batched_query].context_text = batched_context_text @@ -184,7 +166,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): # Find new triplets, and update existing query states triplets_batch = await self.get_context(query_batch=followup_question_batch) - for batched_query, 
batched_followup_triplets in zip(query_batch, triplets_batch): + for batched_query, batched_followup_triplets in zip( + query_state_tracker.keys(), triplets_batch + ): query_state_tracker[batched_query].triplets = list( dict.fromkeys( query_state_tracker[batched_query].triplets + batched_followup_triplets @@ -198,7 +182,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): ] ) - for batched_query, batched_context_text in zip(query_batch, context_text_batch): + for batched_query, batched_context_text in zip( + query_state_tracker.keys(), context_text_batch + ): query_state_tracker[batched_query].context_text = batched_context_text completion_batch = await asyncio.gather( @@ -216,9 +202,18 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): ] ) - for batched_query, batched_completion in zip(query_batch, completion_batch): + for batched_query, batched_completion in zip( + query_state_tracker.keys(), completion_batch + ): query_state_tracker[batched_query].completion = batched_completion + if round_idx == max_iter: + # When we finish all iterations: + # Make sure answers are returned for duplicate queries, in the order they were asked. + completion_batch = [] + for batched_query in query_batch: + completion_batch.append(query_state_tracker[batched_query].completion) + logger.info(f"Chain-of-thought: round {round_idx} - answers: {completion_batch}") if round_idx < max_iter: @@ -332,13 +327,17 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): session_save = user_id and cache_config.caching if query_batch and session_save: - raise ValueError("You cannot use batch queries with session saving currently.") + raise QueryValidationError( + message="You cannot use batch queries with session saving currently." + ) if query_batch and self.save_interaction: - raise ValueError("Cannot use batch queries with interaction saving currently.") + raise QueryValidationError( + message="Cannot use batch queries with interaction saving currently." 
+ ) is_query_valid, msg = validate_queries(query, query_batch) if not is_query_valid: - raise ValueError(msg) + raise QueryValidationError(message=msg) # Load conversation history if enabled conversation_history = "" diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index eae23bb0d..d9667a669 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -4,6 +4,7 @@ from uuid import NAMESPACE_OID, uuid5 from cognee.infrastructure.engine import DataPoint from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.modules.retrieval.exceptions.exceptions import QueryValidationError from cognee.modules.retrieval.utils.validate_queries import validate_queries from cognee.tasks.storage import add_data_points from cognee.modules.graph.utils import resolve_edges_to_text @@ -216,13 +217,17 @@ class GraphCompletionRetriever(BaseGraphRetriever): session_save = user_id and cache_config.caching if query_batch and session_save: - raise ValueError("You cannot use batch queries with session saving currently.") + raise QueryValidationError( + message="You cannot use batch queries with session saving currently." + ) if query_batch and self.save_interaction: - raise ValueError("Cannot use batch queries with interaction saving currently.") + raise QueryValidationError( + message="Cannot use batch queries with interaction saving currently." 
+ ) is_query_valid, msg = validate_queries(query, query_batch) if not is_query_valid: - raise ValueError(msg) + raise QueryValidationError(message=msg) triplets = context diff --git a/cognee/modules/retrieval/utils/query_state.py b/cognee/modules/retrieval/utils/query_state.py new file mode 100644 index 000000000..a926a952e --- /dev/null +++ b/cognee/modules/retrieval/utils/query_state.py @@ -0,0 +1,34 @@ +from typing import List +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge + + +class QueryState: + """ + Helper class containing all necessary information about the query state. + Used (for now) in COT and Context Extension Retrievers to keep track of important information + in a more readable way, and enable as many parallel calls to llms as possible. + """ + + def __init__( + self, + triplets: List[Edge] = None, + context_text: str = "", + finished_extending_context: bool = False, + ): + # Mutual fields for COT and Context Extension + self.triplets = triplets if triplets else [] + self.context_text = context_text + self.completion = "" + + # Context Extension specific + self.finished_extending_context = finished_extending_context + + # COT specific + self.answer_text: str = "" + self.valid_user_prompt: str = "" + self.valid_system_prompt: str = "" + self.reasoning: str = "" + + self.followup_question: str = "" + self.followup_prompt: str = "" + self.followup_system: str = "" diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py index 7ce0d3a57..9ceca96e2 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -747,3 +747,60 @@ async def test_get_completion_batch_queries_with_response_model(mock_edge): assert isinstance(completion, list) assert len(completion) == 
2 assert isinstance(completion[0], TestModel) and isinstance(completion[1], TestModel) + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_duplicate_queries(mock_edge): + """Test get_completion batch queries with duplicate queries.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + # Create a second edge for extension rounds + mock_edge2 = MagicMock(spec=Edge) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object( + retriever, + "get_context", + new_callable=AsyncMock, + side_effect=[[[mock_edge], [mock_edge]], [[mock_edge2], [mock_edge2]]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + side_effect=[ + "Resolved context", + "Resolved context", + "Extended context", + "Extended context", + ], # Different contexts + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Extension query", + "Generated answer", + "Generated answer", + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], context_extension_rounds=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index aa71bd9c7..1a6155c4f 100644 --- 
a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -375,7 +375,7 @@ async def test_get_completion_with_provided_context(mock_edge): mock_config.caching = False mock_cache_config.return_value = mock_config - completion = await retriever.get_completion("test query", context=[[mock_edge]], max_iter=1) + completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1) assert isinstance(completion, list) assert len(completion) == 1 @@ -818,3 +818,38 @@ async def test_get_completion_batch_queries_with_response_model(mock_edge): assert isinstance(completion, list) assert len(completion) == 2 assert isinstance(completion[0], TestModel) and isinstance(completion[1], TestModel) + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_duplicate_queries(mock_edge): + """Test get_completion batch queries without provided context.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[[mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + ): + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 1"], max_iter=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py 
b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py index ae0adb729..159fa2df4 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py @@ -802,3 +802,43 @@ async def test_get_completion_batch_queries_empty_context(mock_edge): assert isinstance(completion, list) assert len(completion) == 2 + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_duplicate_queries(mock_edge): + """Test get_completion retrieves context when not provided.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[[mock_edge], [mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion(query_batch=["test query 1", "test query 1"]) + + assert isinstance(completion, list) + assert len(completion) == 2 + assert completion[0] == "Generated answer" and completion[1] == "Generated answer" From abf1ef9d2943805117ddc0337bfd230832044519 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Tue, 20 Jan 2026 16:26:17 +0100 Subject: [PATCH 17/19] fix: some new fixes --- ..._completion_context_extension_retriever.py | 45 +++++++++---------- .../graph_completion_cot_retriever.py | 
6 ++- .../graph_completion_retriever_test.py | 2 +- 3 files changed, 25 insertions(+), 28 deletions(-) diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index ef7708dce..0dc3a8bf6 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -162,63 +162,58 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): ) break + relevant_queries = [ + rel_query + for rel_query in finished_queries_states.keys() + if not finished_queries_states[rel_query].finished_extending_context + ] + prev_sizes = [ - len(batched_query_state.triplets) - for batched_query_state in finished_queries_states.values() + len(finished_queries_states[rel_query].triplets) for rel_query in relevant_queries ] completions = await asyncio.gather( *[ generate_completion( - query=batched_query, - context=batched_query_state.context_text, + query=rel_query, + context=finished_queries_states[rel_query].context_text, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, ) - for batched_query, batched_query_state in finished_queries_states.items() - if not batched_query_state.finished_extending_context + for rel_query in relevant_queries ], ) # Get new triplets, and merge them with existing ones, filtering out duplicates new_triplets_batch = await self.get_context(query_batch=completions) - for batched_query, batched_new_triplets in zip( - finished_queries_states.keys(), new_triplets_batch - ): - finished_queries_states[batched_query].triplets = list( + for rel_query, batched_new_triplets in zip(relevant_queries, new_triplets_batch): + finished_queries_states[rel_query].triplets = list( dict.fromkeys( - finished_queries_states[batched_query].triplets + batched_new_triplets + 
finished_queries_states[rel_query].triplets + batched_new_triplets ) ) # Resolve new triplets to text context_text_batch = await asyncio.gather( *[ - self.resolve_edges_to_text(batched_query_state.triplets) - for batched_query_state in finished_queries_states.values() - if not batched_query_state.finished_extending_context + self.resolve_edges_to_text(finished_queries_states[rel_query].triplets) + for rel_query in relevant_queries ] ) # Update context_texts in query states - for batched_query, batched_context_text in zip( - finished_queries_states.keys(), context_text_batch - ): - if not finished_queries_states[batched_query].finished_extending_context: - finished_queries_states[batched_query].context_text = batched_context_text + for rel_query, batched_context_text in zip(relevant_queries, context_text_batch): + finished_queries_states[rel_query].context_text = batched_context_text new_sizes = [ - len(batched_query_state.triplets) - for batched_query_state in finished_queries_states.values() + len(finished_queries_states[rel_query].triplets) for rel_query in relevant_queries ] - for batched_query, prev_size, new_size in zip( - finished_queries_states.keys(), prev_sizes, new_sizes - ): + for rel_query, prev_size, new_size in zip(relevant_queries, prev_sizes, new_sizes): # Mark done queries accordingly if prev_size == new_size: - finished_queries_states[batched_query].finished_extending_context = True + finished_queries_states[rel_query].finished_extending_context = True logger.info( f"Context extension: round {round_idx} - " diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 5675c8c70..114578aa9 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -245,7 +245,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): ] ) - for batched_query, batched_reasoning in zip(query_batch, 
reasoning_batch): + for batched_query, batched_reasoning in zip( + query_state_tracker.keys(), reasoning_batch + ): query_state_tracker[batched_query].reasoning = batched_reasoning for batched_query, batched_query_state in query_state_tracker.items(): @@ -274,7 +276,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): ) for batched_query, batched_followup_question in zip( - query_batch, followup_question_batch + query_state_tracker.keys(), followup_question_batch ): query_state_tracker[batched_query].followup_question = batched_followup_question diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py index 159fa2df4..a6fb05270 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py @@ -806,7 +806,7 @@ async def test_get_completion_batch_queries_empty_context(mock_edge): @pytest.mark.asyncio async def test_get_completion_batch_queries_duplicate_queries(mock_edge): - """Test get_completion retrieves context when not provided.""" + """Test get_completion batch queries with duplicate queries.""" mock_graph_engine = AsyncMock() mock_graph_engine.is_empty = AsyncMock(return_value=False) From 5c8475a92a5f0eeb2c8fd3270bfc657d48c1d38a Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Tue, 20 Jan 2026 18:59:18 +0100 Subject: [PATCH 18/19] feat: add batch queries to save interaction --- ..._completion_context_extension_retriever.py | 20 ++--- .../graph_completion_cot_retriever.py | 32 ++++--- .../retrieval/graph_completion_retriever.py | 27 ++++-- ...letion_retriever_context_extension_test.py | 69 +++++++++++++++ .../graph_completion_retriever_cot_test.py | 84 +++++++++++++++++++ .../graph_completion_retriever_test.py | 41 +++++++++ 6 files changed, 244 insertions(+), 29 deletions(-) diff --git 
a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index 0dc3a8bf6..afbaa2978 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -104,10 +104,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): raise QueryValidationError( message="You cannot use batch queries with session saving currently." ) - if query_batch and self.save_interaction: - raise QueryValidationError( - message="Cannot use batch queries with interaction saving currently." - ) is_query_valid, msg = validate_queries(query, query_batch) if not is_query_valid: @@ -264,13 +260,17 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): for batched_query in query_batch: result_completion_batch.append(finished_queries_states[batched_query].completion) - # TODO: Do batch queries for save interaction if self.save_interaction and context_text_batch and triplets_batch and completion_batch: - await self.save_qa( - question=query, - answer=completion_batch[0], - context=context_text_batch[0], - triplets=triplets_batch[0], + await asyncio.gather( + *[ + self.save_qa( + question=batched_query, + answer=finished_queries_states[batched_query].completion, + context=finished_queries_states[batched_query].context_text, + triplets=finished_queries_states[batched_query].triplets, + ) + for batched_query in query_batch + ] ) if session_save: diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 114578aa9..76f27b65e 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -332,10 +332,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): raise QueryValidationError( message="You cannot use 
batch queries with session saving currently." ) - if query_batch and self.save_interaction: - raise QueryValidationError( - message="Cannot use batch queries with interaction saving currently." - ) is_query_valid, msg = validate_queries(query, query_batch) if not is_query_valid: @@ -355,14 +351,28 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): response_model=response_model, ) - # TODO: Handle save interaction for batch queries if self.save_interaction and context and triplets and completion: - await self.save_qa( - question=query, - answer=str(completion[0]), - context=context_text[0], - triplets=triplets[0], - ) + if query_batch: + await asyncio.gather( + *[ + self.save_qa( + question=batched_query, + answer=str(batched_completion), + context=batched_context_text, + triplets=batched_triplet, + ) + for batched_query, batched_completion, batched_context_text, batched_triplet in zip( + query_batch, completion, context_text, triplets + ) + ] + ) + else: + await self.save_qa( + question=query, + answer=str(completion[0]), + context=context_text[0], + triplets=triplets[0], + ) # TODO: Handle session save interaction for batch queries # Save to session cache if enabled diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index d9667a669..cbaec599f 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -220,10 +220,6 @@ class GraphCompletionRetriever(BaseGraphRetriever): raise QueryValidationError( message="You cannot use batch queries with session saving currently." ) - if query_batch and self.save_interaction: - raise QueryValidationError( - message="Cannot use batch queries with interaction saving currently." 
- ) is_query_valid, msg = validate_queries(query, query_batch) if not is_query_valid: @@ -236,7 +232,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): context_text = "" context_text_batch = [] - if triplets and isinstance(triplets[0], list): + if query_batch: context_text_batch = await asyncio.gather( *[resolve_edges_to_text(triplets_element) for triplets_element in triplets] ) @@ -284,9 +280,24 @@ class GraphCompletionRetriever(BaseGraphRetriever): ) if self.save_interaction and context and triplets and completion: - await self.save_qa( - question=query, answer=completion, context=context_text, triplets=triplets - ) + if query: + await self.save_qa( + question=query, answer=completion, context=context_text, triplets=triplets + ) + else: + await asyncio.gather( + *[ + await self.save_qa( + question=batched_query, + answer=batched_completion, + context=batched_context_text, + triplets=batched_triplets, + ) + for batched_query, batched_completion, batched_context_text, batched_triplets in zip( + query_batch, completion, context_text_batch, triplets + ) + ] + ) if session_save: await save_conversation_history( diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py index 9ceca96e2..cec71a443 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -804,3 +804,72 @@ async def test_get_completion_batch_queries_duplicate_queries(mock_edge): assert isinstance(completion, list) assert len(completion) == 2 assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_with_save_interaction(mock_edge): + """Test get_completion batch queries with save_interaction enabled.""" + mock_graph_engine = 
AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + mock_graph_engine.add_edges = AsyncMock() + + retriever = GraphCompletionContextExtensionRetriever(save_interaction=True) + + mock_node1 = MagicMock() + mock_node2 = MagicMock() + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object( + retriever, + "get_context", + new_callable=AsyncMock, + return_value=[[mock_edge], [mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Extension query", + "Generated answer", + "Generated answer", + ], # Extension query, then final answer + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", + side_effect=[ + UUID("550e8400-e29b-41d4-a716-446655440000"), + UUID("550e8400-e29b-41d4-a716-446655440000"), + UUID("550e8400-e29b-41d4-a716-446655440001"), + UUID("550e8400-e29b-41d4-a716-446655440001"), + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], + context=[[mock_edge], [mock_edge]], + context_extension_rounds=1, + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + mock_add_data.assert_awaited() diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py 
b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index 1a6155c4f..4b05021c3 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -853,3 +853,87 @@ async def test_get_completion_batch_queries_duplicate_queries(mock_edge): assert isinstance(completion, list) assert len(completion) == 2 assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_with_save_interaction(mock_edge): + """Test get_completion batch queries with save_interaction enabled.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + mock_graph_engine.add_edges = AsyncMock() + + retriever = GraphCompletionCotRetriever(save_interaction=True) + + mock_node1 = MagicMock() + mock_node2 = MagicMock() + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object( + retriever, + "get_context", + new_callable=AsyncMock, + return_value=[[mock_edge], [mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=[ + "validation_result", + "validation_result", + "followup_question", + "followup_question", + ], + ), + patch( + 
"cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", + side_effect=[ + UUID("550e8400-e29b-41d4-a716-446655440000"), + UUID("550e8400-e29b-41d4-a716-446655440000"), + UUID("550e8400-e29b-41d4-a716-446655440001"), + UUID("550e8400-e29b-41d4-a716-446655440001"), + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + # Pass context so save_interaction condition is met + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 2"], + context=[[mock_edge], [mock_edge]], + max_iter=1, + ) + + assert isinstance(completion, list) + assert len(completion) == 2 + mock_add_data.assert_awaited() diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py index a6fb05270..3673b7ace 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py @@ -842,3 +842,44 @@ async def test_get_completion_batch_queries_duplicate_queries(mock_edge): assert isinstance(completion, list) assert len(completion) == 2 assert completion[0] == "Generated answer" and completion[1] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_batch_queries_with_save_interaction(mock_edge): + """Test get_completion batch queries with save_interaction.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever(save_interaction=True) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + 
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[[mock_edge], [mock_edge]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + query_batch=["test query 1", "test query 1"], context=None + ) + + assert isinstance(completion, list) + assert len(completion) == 2 From a0a3283d8fcad72321e6024140b700d6568594f9 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Tue, 20 Jan 2026 19:23:34 +0100 Subject: [PATCH 19/19] Revert "feat: add batch queries to save interaction" This reverts commit 5c8475a92a5f0eeb2c8fd3270bfc657d48c1d38a. 
--- ..._completion_context_extension_retriever.py | 20 ++--- .../graph_completion_cot_retriever.py | 32 +++---- .../retrieval/graph_completion_retriever.py | 27 ++---- ...letion_retriever_context_extension_test.py | 69 --------------- .../graph_completion_retriever_cot_test.py | 84 ------------------- .../graph_completion_retriever_test.py | 41 --------- 6 files changed, 29 insertions(+), 244 deletions(-) diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index afbaa2978..0dc3a8bf6 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -104,6 +104,10 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): raise QueryValidationError( message="You cannot use batch queries with session saving currently." ) + if query_batch and self.save_interaction: + raise QueryValidationError( + message="Cannot use batch queries with interaction saving currently." 
+ ) is_query_valid, msg = validate_queries(query, query_batch) if not is_query_valid: @@ -260,17 +264,13 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): for batched_query in query_batch: result_completion_batch.append(finished_queries_states[batched_query].completion) + # TODO: Do batch queries for save interaction if self.save_interaction and context_text_batch and triplets_batch and completion_batch: - await asyncio.gather( - *[ - self.save_qa( - question=batched_query, - answer=finished_queries_states[batched_query].completion, - context=finished_queries_states[batched_query].context_text, - triplets=finished_queries_states[batched_query].triplets, - ) - for batched_query in query_batch - ] + await self.save_qa( + question=query, + answer=completion_batch[0], + context=context_text_batch[0], + triplets=triplets_batch[0], ) if session_save: diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 76f27b65e..114578aa9 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -332,6 +332,10 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): raise QueryValidationError( message="You cannot use batch queries with session saving currently." ) + if query_batch and self.save_interaction: + raise QueryValidationError( + message="Cannot use batch queries with interaction saving currently." 
+ ) is_query_valid, msg = validate_queries(query, query_batch) if not is_query_valid: @@ -351,28 +355,14 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): response_model=response_model, ) + # TODO: Handle save interaction for batch queries if self.save_interaction and context and triplets and completion: - if query_batch: - await asyncio.gather( - *[ - self.save_qa( - question=batched_query, - answer=str(batched_completion), - context=batched_context_text, - triplets=batched_triplet, - ) - for batched_query, batched_completion, batched_context_text, batched_triplet in zip( - query_batch, completion, context_text, triplets - ) - ] - ) - else: - await self.save_qa( - question=query, - answer=str(completion[0]), - context=context_text[0], - triplets=triplets[0], - ) + await self.save_qa( + question=query, + answer=str(completion[0]), + context=context_text[0], + triplets=triplets[0], + ) # TODO: Handle session save interaction for batch queries # Save to session cache if enabled diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index cbaec599f..d9667a669 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -220,6 +220,10 @@ class GraphCompletionRetriever(BaseGraphRetriever): raise QueryValidationError( message="You cannot use batch queries with session saving currently." ) + if query_batch and self.save_interaction: + raise QueryValidationError( + message="Cannot use batch queries with interaction saving currently." 
+ ) is_query_valid, msg = validate_queries(query, query_batch) if not is_query_valid: @@ -232,7 +236,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): context_text = "" context_text_batch = [] - if query_batch: + if triplets and isinstance(triplets[0], list): context_text_batch = await asyncio.gather( *[resolve_edges_to_text(triplets_element) for triplets_element in triplets] ) @@ -280,24 +284,9 @@ class GraphCompletionRetriever(BaseGraphRetriever): ) if self.save_interaction and context and triplets and completion: - if query: - await self.save_qa( - question=query, answer=completion, context=context_text, triplets=triplets - ) - else: - await asyncio.gather( - *[ - await self.save_qa( - question=batched_query, - answer=batched_completion, - context=batched_context_text, - triplets=batched_triplets, - ) - for batched_query, batched_completion, batched_context_text, batched_triplets in zip( - query_batch, completion, context_text_batch, triplets - ) - ] - ) + await self.save_qa( + question=query, answer=completion, context=context_text, triplets=triplets + ) if session_save: await save_conversation_history( diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py index cec71a443..9ceca96e2 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -804,72 +804,3 @@ async def test_get_completion_batch_queries_duplicate_queries(mock_edge): assert isinstance(completion, list) assert len(completion) == 2 assert completion[0] == "Generated answer" and completion[1] == "Generated answer" - - -@pytest.mark.asyncio -async def test_get_completion_batch_queries_with_save_interaction(mock_edge): - """Test get_completion batch queries with save_interaction enabled.""" - mock_graph_engine = 
AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - mock_graph_engine.add_edges = AsyncMock() - - retriever = GraphCompletionContextExtensionRetriever(save_interaction=True) - - mock_node1 = MagicMock() - mock_node2 = MagicMock() - mock_edge.node1 = mock_node1 - mock_edge.node2 = mock_node2 - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch.object( - retriever, - "get_context", - new_callable=AsyncMock, - return_value=[[mock_edge], [mock_edge]], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", - side_effect=[ - "Extension query", - "Extension query", - "Generated answer", - "Generated answer", - ], # Extension query, then final answer - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", - side_effect=[ - UUID("550e8400-e29b-41d4-a716-446655440000"), - UUID("550e8400-e29b-41d4-a716-446655440000"), - UUID("550e8400-e29b-41d4-a716-446655440001"), - UUID("550e8400-e29b-41d4-a716-446655440001"), - ], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.add_data_points", - ) as mock_add_data, - patch( - "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" - ) as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion( - query_batch=["test query 1", "test query 2"], - context=[[mock_edge], [mock_edge]], - context_extension_rounds=1, - ) - - assert isinstance(completion, list) - assert len(completion) == 2 - mock_add_data.assert_awaited() diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py 
b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index 4b05021c3..1a6155c4f 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -853,87 +853,3 @@ async def test_get_completion_batch_queries_duplicate_queries(mock_edge): assert isinstance(completion, list) assert len(completion) == 2 assert completion[0] == "Generated answer" and completion[1] == "Generated answer" - - -@pytest.mark.asyncio -async def test_get_completion_batch_queries_with_save_interaction(mock_edge): - """Test get_completion batch queries with save_interaction enabled.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - mock_graph_engine.add_edges = AsyncMock() - - retriever = GraphCompletionCotRetriever(save_interaction=True) - - mock_node1 = MagicMock() - mock_node2 = MagicMock() - mock_edge.node1 = mock_node1 - mock_edge.node2 = mock_node2 - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", - return_value="Generated answer", - ), - patch.object( - retriever, - "get_context", - new_callable=AsyncMock, - return_value=[[mock_edge], [mock_edge]], - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", - return_value="Generated answer", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", - return_value="Rendered prompt", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", - return_value="System prompt", - ), - patch.object( - LLMGateway, - "acreate_structured_output", - new_callable=AsyncMock, - side_effect=[ - "validation_result", - "validation_result", - "followup_question", - "followup_question", - ], - ), - patch( - 
"cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", - side_effect=[ - UUID("550e8400-e29b-41d4-a716-446655440000"), - UUID("550e8400-e29b-41d4-a716-446655440000"), - UUID("550e8400-e29b-41d4-a716-446655440001"), - UUID("550e8400-e29b-41d4-a716-446655440001"), - ], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.add_data_points", - ) as mock_add_data, - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" - ) as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - # Pass context so save_interaction condition is met - completion = await retriever.get_completion( - query_batch=["test query 1", "test query 2"], - context=[[mock_edge], [mock_edge]], - max_iter=1, - ) - - assert isinstance(completion, list) - assert len(completion) == 2 - mock_add_data.assert_awaited() diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py index 3673b7ace..a6fb05270 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py @@ -842,44 +842,3 @@ async def test_get_completion_batch_queries_duplicate_queries(mock_edge): assert isinstance(completion, list) assert len(completion) == 2 assert completion[0] == "Generated answer" and completion[1] == "Generated answer" - - -@pytest.mark.asyncio -async def test_get_completion_batch_queries_with_save_interaction(mock_edge): - """Test get_completion batch queries with save_interaction.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - - retriever = GraphCompletionRetriever(save_interaction=True) - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - 
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[[mock_edge], [mock_edge]], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.generate_completion", - return_value="Generated answer", - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" - ) as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion( - query_batch=["test query 1", "test query 1"], context=None - ) - - assert isinstance(completion, list) - assert len(completion) == 2