From 2655df9b218b4116d84c555e950dc3a92775d44d Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Thu, 15 Jan 2026 19:29:21 +0100 Subject: [PATCH] fix: fix failing tests --- ..._completion_context_extension_retriever.py | 20 ++++++++++++------- .../retrieval/graph_completion_retriever.py | 2 +- ...letion_retriever_context_extension_test.py | 16 +++++++-------- 3 files changed, 22 insertions(+), 16 deletions(-) diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index 7774ba9e5..ddf4ef615 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -95,16 +95,15 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): if query: # This is done mostly to avoid duplicating a lot of code unnecessarily query_batch = [query] - query = None if triplets: triplets = [triplets] if triplets is None: - triplets = await self.get_context(query, query_batch) + triplets = await self.get_context(query_batch=query_batch) context_text = "" context_texts = "" - if isinstance(triplets[0], list): + if triplets and isinstance(triplets[0], list): context_texts = await asyncio.gather( *[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets] ) @@ -123,6 +122,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): f"Context extension: round {round_idx} - generating next graph locational query." ) + # Filter out the queries that cannot be extended further, and their associated contexts query_batch = [query for query in query_batch if query] triplets = [triplet_element for triplet_element in triplets if triplet_element] context_texts = [context_text for context_text in context_texts if context_text] @@ -181,8 +181,12 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): # Reset variables for the final generations. They contain the final state # of triplets and contexts for each query, after all extension iterations. query_batch = original_query_batch - context_texts = saved_context_texts - triplets = saved_triplets + context_texts = saved_context_texts if len(saved_context_texts) > 0 else context_texts + triplets = saved_triplets if len(saved_triplets) > 0 else triplets + + if len(query_batch) == 1: + triplets = [] if not triplets else triplets[0] + context_text = "" if not context_texts else context_texts[0] # Check if we need to generate context summary for caching cache_config = CacheConfig() @@ -221,8 +225,10 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): ) if self.save_interaction and context_text and triplets and completion: + if isinstance(completion, list): + completion = completion[0] await self.save_qa( - question=query, answer=completion, context=context_text, triplets=triplets + question=query, answer=completion[0], context=context_text, triplets=triplets ) if session_save: @@ -233,4 +239,4 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): session_id=session_id, ) - return [completion] + return completion if isinstance(completion, list) else [completion] diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index 1b78e18ac..5c2c3bf1b 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -197,7 +197,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): context_text = "" context_texts = "" - if isinstance(triplets[0], list): + if triplets and isinstance(triplets[0], list): context_texts = await asyncio.gather( *[resolve_edges_to_text(triplets_element) for triplets_element in triplets] ) diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py index 6a9b07d38..9095af69c 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -81,7 +81,7 @@ async def test_get_completion_without_context(mock_edge): ), patch( "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], + return_value=[[mock_edge]], ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", @@ -157,7 +157,7 @@ async def test_get_completion_context_extension_rounds(mock_edge): retriever, "get_context", new_callable=AsyncMock, - side_effect=[[mock_edge], [mock_edge2]], + side_effect=[[[mock_edge]], [[mock_edge2]]], ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", @@ -194,7 +194,7 @@ async def test_get_completion_context_extension_stops_early(mock_edge): retriever = GraphCompletionContextExtensionRetriever() with ( - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", return_value="Resolved context", @@ -240,7 +240,7 @@ async def test_get_completion_with_session(mock_edge): "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", return_value=mock_graph_engine, ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", return_value="Resolved context", @@ -304,7 +304,7 @@ async def test_get_completion_with_save_interaction(mock_edge): "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", return_value=mock_graph_engine, ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", return_value="Resolved context", @@ -361,7 +361,7 @@ async def test_get_completion_with_response_model(mock_edge): "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", return_value=mock_graph_engine, ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", return_value="Resolved context", @@ -403,7 +403,7 @@ async def test_get_completion_with_session_no_user_id(mock_edge): "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", return_value=mock_graph_engine, ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", return_value="Resolved context", @@ -446,7 +446,7 @@ async def test_get_completion_zero_extension_rounds(mock_edge): "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", return_value=mock_graph_engine, ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", return_value="Resolved context",