fix: correct list-ordering logic issue flagged by CodeRabbit

This commit is contained in:
Andrej Milicevic 2026-01-16 19:32:15 +01:00
parent 98e8d226eb
commit 17554466ba
2 changed files with 25 additions and 23 deletions

View file

@ -102,21 +102,20 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
triplets = await self.get_context(query_batch=query_batch)
context_text = ""
context_texts = ""
if triplets and isinstance(triplets[0], list):
context_texts = await asyncio.gather(
*[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets]
)
else:
context_text = await self.resolve_edges_to_text(triplets)
context_texts = await asyncio.gather(
*[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets]
)
round_idx = 1
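# Queries, together with their triplets and context, are blanked out as they finish
# extending across iterations, so each query's final state must be kept for the final generation.
# That final state is stored in the finished_queries_data dict, which is also populated up front.
# NOTE(review): original_query_batch is bound to the same list object as query_batch, so the
# in-place `query_batch[i] = ""` clearing further down would also change it — confirm whether a copy is needed.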
original_query_batch = query_batch
saved_triplets = []
saved_context_texts = []
finished_queries_data = {}
for i, query in enumerate(query_batch):
finished_queries_data[query] = (triplets[i], context_texts[i])
while round_idx <= context_extension_rounds:
logger.info(
f"Context extension: round {round_idx} - generating next graph locational query."
@ -147,6 +146,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
],
)
# Get new triplets, and merge them with existing ones, filtering out duplicates
new_triplets = await self.get_context(query_batch=completions)
for i, (triplets_element, new_triplets_element) in enumerate(
zip(triplets, new_triplets)
@ -160,15 +160,14 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
new_sizes = [len(triplets_element) for triplets_element in triplets]
for i, (query, prev_size, new_size, triplet_element, context_text) in enumerate(
for i, (query, prev_size, new_size, triplets_element, context_text) in enumerate(
zip(query_batch, prev_sizes, new_sizes, triplets, context_texts)
):
finished_queries_data[query] = (triplets_element, context_text)
if prev_size == new_size:
# In this case, we can stop trying to extend the context of this query
query_batch[i] = ""
saved_triplets.append(triplet_element)
triplets[i] = []
saved_context_texts.append(context_text)
context_texts[i] = ""
logger.info(
@ -181,12 +180,11 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
# Reset variables for the final generations. They contain the final state
# of triplets and contexts for each query, after all extension iterations.
query_batch = original_query_batch
context_texts = saved_context_texts if len(saved_context_texts) > 0 else context_texts
triplets = saved_triplets if len(saved_triplets) > 0 else triplets
if len(query_batch) == 1:
triplets = [] if not triplets else triplets[0]
context_text = "" if not context_texts else context_texts[0]
triplets = []
context_texts = []
for query in query_batch:
triplets.append(finished_queries_data[query][0])
context_texts.append(finished_queries_data[query][1])
# Check if we need to generate context summary for caching
cache_config = CacheConfig()
@ -224,11 +222,10 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
],
)
if self.save_interaction and context_text and triplets and completion:
if isinstance(completion, list):
completion = completion[0]
# TODO: Support batch queries when saving interactions; only the first result is saved for now.
if self.save_interaction and context_texts and triplets and completion:
await self.save_qa(
question=query, answer=completion[0], context=context_text, triplets=triplets
question=query, answer=completion[0], context=context_texts[0], triplets=triplets[0]
)
if session_save:

View file

@ -569,7 +569,12 @@ async def test_get_completion_batch_queries_context_extension_rounds(mock_edge):
),
patch(
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
side_effect=cycle(["Resolved context", "Extended context"]), # Different contexts
side_effect=[
"Resolved context",
"Resolved context",
"Extended context",
"Extended context",
], # Different contexts
),
patch(
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",