From d180e2aeae02b928a3ae057c3f4408288c4f56f4 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Thu, 15 Jan 2026 13:41:19 +0100 Subject: [PATCH] feat: enable batch queries on graph completion retrievers --- ..._completion_context_extension_retriever.py | 116 +++++++++++++----- .../graph_completion_cot_retriever.py | 35 ++++-- .../retrieval/graph_completion_retriever.py | 62 +++++++--- 3 files changed, 161 insertions(+), 52 deletions(-) diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index fc49a139b..7774ba9e5 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -56,8 +56,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): async def get_completion( self, - query: str, - context: Optional[List[Edge]] = None, + query: Optional[str] = None, + query_batch: Optional[List[str]] = None, + context: Optional[List[Edge] | List[List[Edge]]] = None, session_id: Optional[str] = None, context_extension_rounds=4, response_model: Type = str, @@ -91,46 +92,98 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): """ triplets = context - if triplets is None: - triplets = await self.get_context(query) + if query: + # This is done mostly to avoid duplicating a lot of code unnecessarily + query_batch = [query] + query = None + if triplets: + triplets = [triplets] - context_text = await self.resolve_edges_to_text(triplets) + if triplets is None: + triplets = await self.get_context(query, query_batch) + + context_text = "" + context_texts = "" + if isinstance(triplets[0], list): + context_texts = await asyncio.gather( + *[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets] + ) + else: + context_text = await self.resolve_edges_to_text(triplets) round_idx = 1 + # We will be removing queries, and their associated triplets and context, as we go + # through iterations, so we need to save their final states for the final generation. + original_query_batch = query_batch + saved_triplets = [] + saved_context_texts = [] while round_idx <= context_extension_rounds: - prev_size = len(triplets) - logger.info( f"Context extension: round {round_idx} - generating next graph locational query." ) - completion = await generate_completion( - query=query, - context=context_text, - user_prompt_path=self.user_prompt_path, - system_prompt_path=self.system_prompt_path, - system_prompt=self.system_prompt, - ) - triplets += await self.get_context(completion) - triplets = list(set(triplets)) - context_text = await self.resolve_edges_to_text(triplets) - - num_triplets = len(triplets) - - if num_triplets == prev_size: + query_batch = [query for query in query_batch if query] + triplets = [triplet_element for triplet_element in triplets if triplet_element] + context_texts = [context_text for context_text in context_texts if context_text] + if len(query_batch) == 0: logger.info( f"Context extension: round {round_idx} – no new triplets found; stopping early." ) break + prev_sizes = [len(triplets_element) for triplets_element in triplets] + + completions = await asyncio.gather( + *[ + generate_completion( + query=query, + context=context, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, + ) + for query, context in zip(query_batch, context_texts) + ], + ) + + new_triplets = await self.get_context(query_batch=completions) + for i, (triplets_element, new_triplets_element) in enumerate( + zip(triplets, new_triplets) + ): + triplets_element += new_triplets_element + triplets[i] = list(dict.fromkeys(triplets_element)) + + context_texts = await asyncio.gather( + *[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets] + ) + + new_sizes = [len(triplets_element) for triplets_element in triplets] + + for i, (query, prev_size, new_size, triplet_element, context_text) in enumerate( + zip(query_batch, prev_sizes, new_sizes, triplets, context_texts) + ): + if prev_size == new_size: + # In this case, we can stop trying to extend the context of this query + query_batch[i] = "" + saved_triplets.append(triplet_element) + triplets[i] = [] + saved_context_texts.append(context_text) + context_texts[i] = "" + logger.info( f"Context extension: round {round_idx} - " - f"number of unique retrieved triplets: {num_triplets}" + f"number of unique retrieved triplets for each query : {new_sizes}" ) round_idx += 1 + # Reset variables for the final generations. They contain the final state + # of triplets and contexts for each query, after all extension iterations. + query_batch = original_query_batch + context_texts = saved_context_texts + triplets = saved_triplets + # Check if we need to generate context summary for caching cache_config = CacheConfig() user = session_user.get() @@ -153,13 +206,18 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): ), ) else: - completion = await generate_completion( - query=query, - context=context_text, - user_prompt_path=self.user_prompt_path, - system_prompt_path=self.system_prompt_path, - system_prompt=self.system_prompt, - response_model=response_model, + completion = await asyncio.gather( + *[ + generate_completion( + query=query, + context=context_text, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, + response_model=response_model, + ) + for query, context_text in zip(query_batch, context_texts) + ], ) if self.save_interaction and context_text and triplets and completion: diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 70fcb6cdb..5c0dde0df 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -170,7 +170,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): async def get_completion( self, - query: str, + query: Optional[str] = None, + query_batch: Optional[List[str]] = None, context: Optional[List[Edge]] = None, session_id: Optional[str] = None, max_iter=4, @@ -213,13 +214,28 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): if session_save: conversation_history = await get_conversation_history(session_id=session_id) - completion, context_text, triplets = await self._run_cot_completion( - query=query, - context=context, - conversation_history=conversation_history, - max_iter=max_iter, - response_model=response_model, - ) + completion_results = [] + if query_batch and len(query_batch) > 0: + completion_results = await asyncio.gather( + *[ + self._run_cot_completion( + query=query, + context=context, + conversation_history=conversation_history, + max_iter=max_iter, + response_model=response_model, + ) + for query in query_batch + ] + ) + else: + completion, context_text, triplets = await self._run_cot_completion( + query=query, + context=context, + conversation_history=conversation_history, + max_iter=max_iter, + response_model=response_model, + ) if self.save_interaction and context and triplets and completion: await self.save_qa( @@ -236,4 +252,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): session_id=session_id, ) + if completion_results: + return [completion for completion, _, _ in completion_results] + return [completion] diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index bb8b34327..146b51fa9 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -79,7 +79,11 @@ class GraphCompletionRetriever(BaseGraphRetriever): """ return await resolve_edges_to_text(retrieved_edges) - async def get_triplets(self, query: str) -> List[Edge]: + async def get_triplets( + self, + query: Optional[str] = None, + query_batch: Optional[List[str]] = None, + ) -> List[Edge] | List[List[Edge]]: """ Retrieves relevant graph triplets based on a query string. @@ -107,6 +111,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): found_triplets = await brute_force_triplet_search( query, + query_batch, top_k=self.top_k, collections=vector_index_collections or None, node_type=self.node_type, @@ -117,7 +122,11 @@ class GraphCompletionRetriever(BaseGraphRetriever): return found_triplets - async def get_context(self, query: str) -> List[Edge]: + async def get_context( + self, + query: Optional[str] = None, + query_batch: Optional[List[str]] = None, + ) -> List[Edge] | List[List[Edge]]: """ Retrieves and resolves graph triplets into context based on a query. @@ -139,7 +148,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): logger.warning("Search attempt on an empty knowledge graph") return [] - triplets = await self.get_triplets(query) + triplets = await self.get_triplets(query, query_batch) if len(triplets) == 0: logger.warning("Empty context was provided to the completion") @@ -158,8 +167,9 @@ class GraphCompletionRetriever(BaseGraphRetriever): async def get_completion( self, - query: str, - context: Optional[List[Edge]] = None, + query: Optional[str] = None, + query_batch: Optional[List[str]] = None, + context: Optional[List[Edge] | List[List[Edge]]] = None, session_id: Optional[str] = None, response_model: Type = str, ) -> List[Any]: @@ -183,9 +193,16 @@ class GraphCompletionRetriever(BaseGraphRetriever): triplets = context if triplets is None: - triplets = await self.get_context(query) + triplets = await self.get_context(query, query_batch) - context_text = await resolve_edges_to_text(triplets) + context_text = "" + context_texts = "" + if isinstance(triplets[0], list): + context_texts = await asyncio.gather( + *[resolve_edges_to_text(triplets_element) for triplets_element in triplets] + ) + else: + context_text = await resolve_edges_to_text(triplets) cache_config = CacheConfig() user = session_user.get() @@ -208,14 +225,29 @@ class GraphCompletionRetriever(BaseGraphRetriever): ), ) else: - completion = await generate_completion( - query=query, - context=context_text, - user_prompt_path=self.user_prompt_path, - system_prompt_path=self.system_prompt_path, - system_prompt=self.system_prompt, - response_model=response_model, - ) + if query_batch and len(query_batch) > 0: + completion = await asyncio.gather( + *[ + generate_completion( + query=query, + context=context, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, + response_model=response_model, + ) + for query, context in zip(query_batch, context_texts) + ], + ) + else: + completion = await generate_completion( + query=query, + context=context_text if context_text else context_texts, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + system_prompt=self.system_prompt, + response_model=response_model, + ) if self.save_interaction and context and triplets and completion: await self.save_qa(