fix: fix failing tests

This commit is contained in:
Andrej Milicevic 2026-01-15 19:29:21 +01:00
parent 40667e63c9
commit 2655df9b21
3 changed files with 22 additions and 16 deletions

View file

@ -95,16 +95,15 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
if query:
# This is done mostly to avoid duplicating a lot of code unnecessarily
query_batch = [query]
query = None
if triplets:
triplets = [triplets]
if triplets is None:
triplets = await self.get_context(query, query_batch)
triplets = await self.get_context(query_batch=query_batch)
context_text = ""
context_texts = ""
if isinstance(triplets[0], list):
if triplets and isinstance(triplets[0], list):
context_texts = await asyncio.gather(
*[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets]
)
@ -123,6 +122,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
f"Context extension: round {round_idx} - generating next graph locational query."
)
# Filter out the queries that cannot be extended further, and their associated contexts
query_batch = [query for query in query_batch if query]
triplets = [triplet_element for triplet_element in triplets if triplet_element]
context_texts = [context_text for context_text in context_texts if context_text]
@ -181,8 +181,12 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
# Reset variables for the final generations. They contain the final state
# of triplets and contexts for each query, after all extension iterations.
query_batch = original_query_batch
context_texts = saved_context_texts
triplets = saved_triplets
context_texts = saved_context_texts if len(saved_context_texts) > 0 else context_texts
triplets = saved_triplets if len(saved_triplets) > 0 else triplets
if len(query_batch) == 1:
triplets = [] if not triplets else triplets[0]
context_text = "" if not context_texts else context_texts[0]
# Check if we need to generate context summary for caching
cache_config = CacheConfig()
@ -221,8 +225,10 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
)
if self.save_interaction and context_text and triplets and completion:
if isinstance(completion, list):
completion = completion[0]
await self.save_qa(
question=query, answer=completion, context=context_text, triplets=triplets
question=query, answer=completion[0], context=context_text, triplets=triplets
)
if session_save:
@ -233,4 +239,4 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
session_id=session_id,
)
return [completion]
return completion if isinstance(completion, list) else [completion]

View file

@ -197,7 +197,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
context_text = ""
context_texts = ""
if isinstance(triplets[0], list):
if triplets and isinstance(triplets[0], list):
context_texts = await asyncio.gather(
*[resolve_edges_to_text(triplets_element) for triplets_element in triplets]
)

View file

@ -81,7 +81,7 @@ async def test_get_completion_without_context(mock_edge):
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
return_value=[mock_edge],
return_value=[[mock_edge]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
@ -157,7 +157,7 @@ async def test_get_completion_context_extension_rounds(mock_edge):
retriever,
"get_context",
new_callable=AsyncMock,
side_effect=[[mock_edge], [mock_edge2]],
side_effect=[[[mock_edge]], [[mock_edge2]]],
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
@ -194,7 +194,7 @@ async def test_get_completion_context_extension_stops_early(mock_edge):
retriever = GraphCompletionContextExtensionRetriever()
with (
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
@ -240,7 +240,7 @@ async def test_get_completion_with_session(mock_edge):
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
@ -304,7 +304,7 @@ async def test_get_completion_with_save_interaction(mock_edge):
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
@ -361,7 +361,7 @@ async def test_get_completion_with_response_model(mock_edge):
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
@ -403,7 +403,7 @@ async def test_get_completion_with_session_no_user_id(mock_edge):
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",
@ -446,7 +446,7 @@ async def test_get_completion_zero_extension_rounds(mock_edge):
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[[mock_edge]]),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
return_value="Resolved context",