fix: fix failing tests

This commit is contained in:
Andrej Milicevic 2026-01-15 19:29:21 +01:00
parent 40667e63c9
commit 2655df9b21
3 changed files with 22 additions and 16 deletions

View file

@ -95,16 +95,15 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
if query: if query:
# This is done mostly to avoid duplicating a lot of code unnecessarily # This is done mostly to avoid duplicating a lot of code unnecessarily
query_batch = [query] query_batch = [query]
query = None
if triplets: if triplets:
triplets = [triplets] triplets = [triplets]
if triplets is None: if triplets is None:
triplets = await self.get_context(query, query_batch) triplets = await self.get_context(query_batch=query_batch)
context_text = "" context_text = ""
context_texts = "" context_texts = ""
if isinstance(triplets[0], list): if triplets and isinstance(triplets[0], list):
context_texts = await asyncio.gather( context_texts = await asyncio.gather(
*[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets] *[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets]
) )
@ -123,6 +122,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
f"Context extension: round {round_idx} - generating next graph locational query." f"Context extension: round {round_idx} - generating next graph locational query."
) )
# Filter out the queries that cannot be extended further, and their associated contexts
query_batch = [query for query in query_batch if query] query_batch = [query for query in query_batch if query]
triplets = [triplet_element for triplet_element in triplets if triplet_element] triplets = [triplet_element for triplet_element in triplets if triplet_element]
context_texts = [context_text for context_text in context_texts if context_text] context_texts = [context_text for context_text in context_texts if context_text]
@ -181,8 +181,12 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
# Reset variables for the final generations. They contain the final state # Reset variables for the final generations. They contain the final state
# of triplets and contexts for each query, after all extension iterations. # of triplets and contexts for each query, after all extension iterations.
query_batch = original_query_batch query_batch = original_query_batch
context_texts = saved_context_texts context_texts = saved_context_texts if len(saved_context_texts) > 0 else context_texts
triplets = saved_triplets triplets = saved_triplets if len(saved_triplets) > 0 else triplets
if len(query_batch) == 1:
triplets = [] if not triplets else triplets[0]
context_text = "" if not context_texts else context_texts[0]
# Check if we need to generate context summary for caching # Check if we need to generate context summary for caching
cache_config = CacheConfig() cache_config = CacheConfig()
@ -221,8 +225,10 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
) )
if self.save_interaction and context_text and triplets and completion: if self.save_interaction and context_text and triplets and completion:
if isinstance(completion, list):
completion = completion[0]
await self.save_qa( await self.save_qa(
question=query, answer=completion, context=context_text, triplets=triplets question=query, answer=completion[0], context=context_text, triplets=triplets
) )
if session_save: if session_save:
@ -233,4 +239,4 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
session_id=session_id, session_id=session_id,
) )
return [completion] return completion if isinstance(completion, list) else [completion]

View file

@ -197,7 +197,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
context_text = "" context_text = ""
context_texts = "" context_texts = ""
if isinstance(triplets[0], list): if triplets and isinstance(triplets[0], list):
context_texts = await asyncio.gather( context_texts = await asyncio.gather(
*[resolve_edges_to_text(triplets_element) for triplets_element in triplets] *[resolve_edges_to_text(triplets_element) for triplets_element in triplets]
) )

View file

@ -81,7 +81,7 @@ async def test_get_completion_without_context(mock_edge):
), ),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge], return_value=[[mock_edge]],
), ),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
@ -157,7 +157,7 @@ async def test_get_completion_context_extension_rounds(mock_edge):
retriever, retriever,
"get_context", "get_context",
new_callable=AsyncMock, new_callable=AsyncMock,
side_effect=[[mock_edge], [mock_edge2]], side_effect=[[[mock_edge]], [[mock_edge2]]],
), ),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
@ -194,7 +194,7 @@ async def test_get_completion_context_extension_stops_early(mock_edge):
retriever = GraphCompletionContextExtensionRetriever() retriever = GraphCompletionContextExtensionRetriever()
with ( with (
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context", return_value="Resolved context",
@ -240,7 +240,7 @@ async def test_get_completion_with_session(mock_edge):
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine, return_value=mock_graph_engine,
), ),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context", return_value="Resolved context",
@ -304,7 +304,7 @@ async def test_get_completion_with_save_interaction(mock_edge):
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine, return_value=mock_graph_engine,
), ),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context", return_value="Resolved context",
@ -361,7 +361,7 @@ async def test_get_completion_with_response_model(mock_edge):
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine, return_value=mock_graph_engine,
), ),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context", return_value="Resolved context",
@ -403,7 +403,7 @@ async def test_get_completion_with_session_no_user_id(mock_edge):
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine, return_value=mock_graph_engine,
), ),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context", return_value="Resolved context",
@ -446,7 +446,7 @@ async def test_get_completion_zero_extension_rounds(mock_edge):
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine, return_value=mock_graph_engine,
), ),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch( patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context", return_value="Resolved context",