From d258b1d7afd753c27b8b2af9948dc3551f398d76 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Mon, 19 Jan 2026 00:03:35 +0100 Subject: [PATCH] fix: fix tests --- .../graph_completion_context_extension_retriever.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index f7603faba..4c9122f85 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -58,11 +58,11 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): async def get_completion( self, query: Optional[str] = None, - query_batch: Optional[List[str]] = None, context: Optional[List[Edge] | List[List[Edge]]] = None, session_id: Optional[str] = None, context_extension_rounds=4, response_model: Type = str, + query_batch: Optional[List[str]] = None, ) -> List[Any]: """ Extends the context for a given query by retrieving related triplets and generating new @@ -107,6 +107,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): if triplets_batch is None: triplets_batch = await self.get_context(query_batch=query_batch) + if not triplets_batch: + return [] + context_text = "" context_text_batch = await asyncio.gather( *[self.resolve_edges_to_text(triplets) for triplets in triplets_batch] @@ -119,8 +122,11 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): # Final state is stored in the finished_queries_data dict, and we populate it at the start as well. original_query_batch = query_batch finished_queries_data = {} - for i, query in enumerate(query_batch): - finished_queries_data[query] = (triplets_batch[i], context_text_batch[i]) + for i, batched_query in enumerate(query_batch): + if not triplets_batch[i]: + query_batch[i] = "" + else: + finished_queries_data[batched_query] = (triplets_batch[i], context_text_batch[i]) while round_idx <= context_extension_rounds: logger.info(