diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index ef7708dce..0dc3a8bf6 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -162,63 +162,58 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): ) break + relevant_queries = [ + rel_query + for rel_query in finished_queries_states.keys() + if not finished_queries_states[rel_query].finished_extending_context + ] + prev_sizes = [ - len(batched_query_state.triplets) - for batched_query_state in finished_queries_states.values() + len(finished_queries_states[rel_query].triplets) for rel_query in relevant_queries ] completions = await asyncio.gather( *[ generate_completion( - query=batched_query, - context=batched_query_state.context_text, + query=rel_query, + context=finished_queries_states[rel_query].context_text, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, ) - for batched_query, batched_query_state in finished_queries_states.items() - if not batched_query_state.finished_extending_context + for rel_query in relevant_queries ], ) # Get new triplets, and merge them with existing ones, filtering out duplicates new_triplets_batch = await self.get_context(query_batch=completions) - for batched_query, batched_new_triplets in zip( - finished_queries_states.keys(), new_triplets_batch - ): - finished_queries_states[batched_query].triplets = list( + for rel_query, batched_new_triplets in zip(relevant_queries, new_triplets_batch): + finished_queries_states[rel_query].triplets = list( dict.fromkeys( - finished_queries_states[batched_query].triplets + batched_new_triplets + finished_queries_states[rel_query].triplets + batched_new_triplets ) ) # Resolve new triplets to text context_text_batch = await asyncio.gather( *[ - self.resolve_edges_to_text(batched_query_state.triplets) - for batched_query_state in finished_queries_states.values() - if not batched_query_state.finished_extending_context + self.resolve_edges_to_text(finished_queries_states[rel_query].triplets) + for rel_query in relevant_queries ] ) # Update context_texts in query states - for batched_query, batched_context_text in zip( - finished_queries_states.keys(), context_text_batch - ): - if not finished_queries_states[batched_query].finished_extending_context: - finished_queries_states[batched_query].context_text = batched_context_text + for rel_query, batched_context_text in zip(relevant_queries, context_text_batch): + finished_queries_states[rel_query].context_text = batched_context_text new_sizes = [ - len(batched_query_state.triplets) - for batched_query_state in finished_queries_states.values() + len(finished_queries_states[rel_query].triplets) for rel_query in relevant_queries ] - for batched_query, prev_size, new_size in zip( - finished_queries_states.keys(), prev_sizes, new_sizes - ): + for rel_query, prev_size, new_size in zip(relevant_queries, prev_sizes, new_sizes): # Mark done queries accordingly if prev_size == new_size: - finished_queries_states[batched_query].finished_extending_context = True + finished_queries_states[rel_query].finished_extending_context = True logger.info( f"Context extension: round {round_idx} - " diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 5675c8c70..114578aa9 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -245,7 +245,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): ] ) - for batched_query, batched_reasoning in zip(query_batch, reasoning_batch): + for batched_query, batched_reasoning in zip( + query_state_tracker.keys(), reasoning_batch + ): query_state_tracker[batched_query].reasoning = batched_reasoning for batched_query, batched_query_state in query_state_tracker.items(): @@ -274,7 +276,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): ) for batched_query, batched_followup_question in zip( - query_batch, followup_question_batch + query_state_tracker.keys(), followup_question_batch ): query_state_tracker[batched_query].followup_question = batched_followup_question diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py index 159fa2df4..a6fb05270 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py @@ -806,7 +806,7 @@ async def test_get_completion_batch_queries_empty_context(mock_edge): @pytest.mark.asyncio async def test_get_completion_batch_queries_duplicate_queries(mock_edge): - """Test get_completion retrieves context when not provided.""" + """Test get_completion batch queries with duplicate queries.""" mock_graph_engine = AsyncMock() mock_graph_engine.is_empty = AsyncMock(return_value=False)