diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index ddf4ef615..1d20d4404 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -102,21 +102,20 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): triplets = await self.get_context(query_batch=query_batch) context_text = "" - context_texts = "" - if triplets and isinstance(triplets[0], list): - context_texts = await asyncio.gather( - *[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets] - ) - else: - context_text = await self.resolve_edges_to_text(triplets) + context_texts = await asyncio.gather( + *[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets] + ) round_idx = 1 # We will be removing queries, and their associated triplets and context, as we go # through iterations, so we need to save their final states for the final generation. + # Final state is stored in the finished_queries_data dict, and we populate it at the start as well. original_query_batch = query_batch - saved_triplets = [] - saved_context_texts = [] + finished_queries_data = {} + for i, query in enumerate(query_batch): + finished_queries_data[query] = (triplets[i], context_texts[i]) + while round_idx <= context_extension_rounds: logger.info( f"Context extension: round {round_idx} - generating next graph locational query." @@ -147,6 +146,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): ], ) + # Get new triplets, and merge them with existing ones, filtering out duplicates new_triplets = await self.get_context(query_batch=completions) for i, (triplets_element, new_triplets_element) in enumerate( zip(triplets, new_triplets) @@ -160,15 +160,14 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): new_sizes = [len(triplets_element) for triplets_element in triplets] - for i, (query, prev_size, new_size, triplet_element, context_text) in enumerate( + for i, (query, prev_size, new_size, triplets_element, context_text) in enumerate( zip(query_batch, prev_sizes, new_sizes, triplets, context_texts) ): + finished_queries_data[query] = (triplets_element, context_text) if prev_size == new_size: # In this case, we can stop trying to extend the context of this query query_batch[i] = "" - saved_triplets.append(triplet_element) triplets[i] = [] - saved_context_texts.append(context_text) context_texts[i] = "" logger.info( @@ -181,12 +180,11 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): # Reset variables for the final generations. They contain the final state # of triplets and contexts for each query, after all extension iterations. query_batch = original_query_batch - context_texts = saved_context_texts if len(saved_context_texts) > 0 else context_texts - triplets = saved_triplets if len(saved_triplets) > 0 else triplets - - if len(query_batch) == 1: - triplets = [] if not triplets else triplets[0] - context_text = "" if not context_texts else context_texts[0] + triplets = [] + context_texts = [] + for query in query_batch: + triplets.append(finished_queries_data[query][0]) + context_texts.append(finished_queries_data[query][1]) # Check if we need to generate context summary for caching cache_config = CacheConfig() @@ -224,11 +222,10 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): ], ) - if self.save_interaction and context_text and triplets and completion: - if isinstance(completion, list): - completion = completion[0] + # TODO: Do batch queries for save interaction + if self.save_interaction and context_texts and triplets and completion: await self.save_qa( - question=query, answer=completion[0], context=context_text, triplets=triplets + question=query, answer=completion[0], context=context_texts[0], triplets=triplets[0] ) if session_save: diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py index 567edde51..8d6310214 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -569,7 +569,12 @@ async def test_get_completion_batch_queries_context_extension_rounds(mock_edge): ), patch( "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - side_effect=cycle(["Resolved context", "Extended context"]), # Different contexts + side_effect=[ + "Resolved context", + "Resolved context", + "Extended context", + "Extended context", + ], # Different contexts ), patch( "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",