From b88e4242ade6c2810f1b5aec404bbcecb0323bd2 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Sun, 18 Jan 2026 22:53:16 +0100 Subject: [PATCH] fix: PR comments fixes --- ..._completion_context_extension_retriever.py | 87 +++++++++++-------- .../graph_completion_cot_retriever.py | 26 +++--- .../retrieval/graph_completion_retriever.py | 41 +++++++-- .../utils/brute_force_triplet_search.py | 16 ++-- .../retrieval/utils/validate_queries.py | 14 +++ 5 files changed, 116 insertions(+), 68 deletions(-) create mode 100644 cognee/modules/retrieval/utils/validate_queries.py diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index 1d20d4404..f7603faba 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -1,6 +1,7 @@ import asyncio from typing import Optional, List, Type, Any from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.modules.retrieval.utils.validate_queries import validate_queries from cognee.shared.logging_utils import get_logger from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text @@ -90,20 +91,25 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): - List[str]: A list containing the generated answer based on the query and the extended context. """ - triplets = context + # TODO: This may be unnecessary in this retriever, will check later + query_validation = validate_queries(query, query_batch) + if not query_validation[0]: + raise ValueError(query_validation[1]) + + triplets_batch = context if query: # This is done mostly to avoid duplicating a lot of code unnecessarily query_batch = [query] - if triplets: - triplets = [triplets] + if triplets_batch: + triplets_batch = [triplets_batch] - if triplets is None: - triplets = await self.get_context(query_batch=query_batch) + if triplets_batch is None: + triplets_batch = await self.get_context(query_batch=query_batch) context_text = "" - context_texts = await asyncio.gather( - *[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets] + context_text_batch = await asyncio.gather( + *[self.resolve_edges_to_text(triplets) for triplets in triplets_batch] ) round_idx = 1 @@ -114,7 +120,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): original_query_batch = query_batch finished_queries_data = {} for i, query in enumerate(query_batch): - finished_queries_data[query] = (triplets[i], context_texts[i]) + finished_queries_data[query] = (triplets_batch[i], context_text_batch[i]) while round_idx <= context_extension_rounds: logger.info( @@ -123,15 +129,17 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): # Filter out the queries that cannot be extended further, and their associated contexts query_batch = [query for query in query_batch if query] - triplets = [triplet_element for triplet_element in triplets if triplet_element] - context_texts = [context_text for context_text in context_texts if context_text] + triplets_batch = [triplets for triplets in triplets_batch if triplets] + context_text_batch = [ + context_text for context_text in context_text_batch if context_text + ] if len(query_batch) == 0: logger.info( f"Context extension: round {round_idx} – no new triplets found; stopping early." ) break - prev_sizes = [len(triplets_element) for triplets_element in triplets] + prev_sizes = [len(triplets) for triplets in triplets_batch] completions = await asyncio.gather( *[ @@ -142,33 +150,31 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, ) - for query, context in zip(query_batch, context_texts) + for query, context in zip(query_batch, context_text_batch) ], ) # Get new triplets, and merge them with existing ones, filtering out duplicates - new_triplets = await self.get_context(query_batch=completions) - for i, (triplets_element, new_triplets_element) in enumerate( - zip(triplets, new_triplets) - ): - triplets_element += new_triplets_element - triplets[i] = list(dict.fromkeys(triplets_element)) + new_triplets_batch = await self.get_context(query_batch=completions) + for i, (triplets, new_triplets) in enumerate(zip(triplets_batch, new_triplets_batch)): + triplets += new_triplets + triplets_batch[i] = list(dict.fromkeys(triplets)) - context_texts = await asyncio.gather( - *[self.resolve_edges_to_text(triplets_element) for triplets_element in triplets] + context_text_batch = await asyncio.gather( + *[self.resolve_edges_to_text(triplets) for triplets in triplets_batch] ) - new_sizes = [len(triplets_element) for triplets_element in triplets] + new_sizes = [len(triplets) for triplets in triplets_batch] - for i, (query, prev_size, new_size, triplets_element, context_text) in enumerate( - zip(query_batch, prev_sizes, new_sizes, triplets, context_texts) + for i, (batched_query, prev_size, new_size, triplets, context_text) in enumerate( + zip(query_batch, prev_sizes, new_sizes, triplets_batch, context_text_batch) ): - finished_queries_data[query] = (triplets_element, context_text) + finished_queries_data[query] = (triplets, context_text) if prev_size == new_size: # In this case, we can stop trying to extend the context of this query query_batch[i] = "" - triplets[i] = [] - context_texts[i] = "" + triplets_batch[i] = [] + context_text_batch[i] = "" logger.info( f"Context extension: round {round_idx} - " @@ -180,11 +186,11 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): # Reset variables for the final generations. They contain the final state # of triplets and contexts for each query, after all extension iterations. query_batch = original_query_batch - triplets = [] - context_texts = [] - for query in query_batch: - triplets.append(finished_queries_data[query][0]) - context_texts.append(finished_queries_data[query][1]) + triplets_batch = [] + context_text_batch = [] + for batched_query in query_batch: + triplets_batch.append(finished_queries_data[batched_query][0]) + context_text_batch.append(finished_queries_data[batched_query][1]) # Check if we need to generate context summary for caching cache_config = CacheConfig() @@ -192,6 +198,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): user_id = getattr(user, "id", None) session_save = user_id and cache_config.caching + completion_batch = [] + if session_save: conversation_history = await get_conversation_history(session_id=session_id) @@ -208,24 +216,27 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): ), ) else: - completion = await asyncio.gather( + completion_batch = await asyncio.gather( *[ generate_completion( - query=query, - context=context_text, + query=batched_query, + context=batched_context_text, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, response_model=response_model, ) - for query, context_text in zip(query_batch, context_texts) + for batched_query, batched_context_text in zip(query_batch, context_text_batch) ], ) # TODO: Do batch queries for save interaction - if self.save_interaction and context_texts and triplets and completion: + if self.save_interaction and context_text_batch and triplets_batch and completion_batch: await self.save_qa( - question=query, answer=completion[0], context=context_texts[0], triplets=triplets[0] + question=query, + answer=completion_batch[0], + context=context_text_batch[0], + triplets=triplets_batch[0], ) if session_save: @@ -236,4 +247,4 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): session_id=session_id, ) - return completion if isinstance(completion, list) else [completion] + return completion_batch if completion_batch else [completion] diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 160419ee9..4ecdc910a 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -3,6 +3,7 @@ import json from typing import Optional, List, Type, Any from pydantic import BaseModel from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.modules.retrieval.utils.validate_queries import validate_queries from cognee.shared.logging_utils import get_logger from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever @@ -203,6 +204,10 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): - List[str]: A list containing the generated answer to the user's query. """ + query_validation = validate_queries(query, query_batch) + if not query_validation[0]: + raise ValueError(query_validation[1]) + # Check if session saving is enabled cache_config = CacheConfig() user = session_user.get() @@ -214,24 +219,25 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): if session_save: conversation_history = await get_conversation_history(session_id=session_id) - completion_results = [] + context_batch = context + completion_batch = [] if query_batch and len(query_batch) > 0: - if not context: + if not context_batch: # Having a list is necessary to zip through it - context = [] - for query in query_batch: - context.append(None) + context_batch = [] + for _ in query_batch: + context_batch.append(None) - completion_results = await asyncio.gather( + completion_batch = await asyncio.gather( *[ self._run_cot_completion( query=query, - context=context_el, + context=context, conversation_history=conversation_history, max_iter=max_iter, response_model=response_model, ) - for query, context_el in zip(query_batch, context) + for batched_query, context in zip(query_batch, context_batch) ] ) else: @@ -260,7 +266,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): session_id=session_id, ) - if completion_results: - return [completion for completion, _, _ in completion_results] + if completion_batch: + return [completion for completion, _, _ in completion_batch] return [completion] diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index f740496d0..a1b4c3833 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -4,6 +4,7 @@ from uuid import NAMESPACE_OID, uuid5 from cognee.infrastructure.engine import DataPoint from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.modules.retrieval.utils.validate_queries import validate_queries from cognee.tasks.storage import add_data_points from cognee.modules.graph.utils import resolve_edges_to_text from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses @@ -150,15 +151,33 @@ class GraphCompletionRetriever(BaseGraphRetriever): triplets = await self.get_triplets(query, query_batch) - if len(triplets) == 0: - logger.warning("Empty context was provided to the completion") - return [] + if query_batch: + for batched_triplets, batched_query in zip(triplets, query_batch): + if len(batched_triplets) == 0: + logger.warning( + f"Empty context was provided to the completion for the query: {batched_query}" + ) + entity_nodes_batch = [] + for batched_triplets in triplets: + entity_nodes_batch.append(get_entity_nodes_from_triplets(batched_triplets)) - # context = await self.resolve_edges_to_text(triplets) + await asyncio.gather( + *[ + update_node_access_timestamps(batched_entity_nodes) + for batched_entity_nodes in entity_nodes_batch + ] + ) + else: + if len(triplets) == 0: + logger.warning("Empty context was provided to the completion") + return [] - entity_nodes = get_entity_nodes_from_triplets(triplets) + # context = await self.resolve_edges_to_text(triplets) + + entity_nodes = get_entity_nodes_from_triplets(triplets) + + await update_node_access_timestamps(entity_nodes) - await update_node_access_timestamps(entity_nodes) return triplets async def convert_retrieved_objects_to_context(self, triplets: List[Edge]): @@ -190,15 +209,19 @@ class GraphCompletionRetriever(BaseGraphRetriever): - Any: A generated completion based on the query and context provided. """ + query_validation = validate_queries(query, query_batch) + if not query_validation[0]: + raise ValueError(query_validation[1]) + triplets = context if triplets is None: triplets = await self.get_context(query, query_batch) context_text = "" - context_texts = "" + context_text_batch = [] if triplets and isinstance(triplets[0], list): - context_texts = await asyncio.gather( + context_text_batch = await asyncio.gather( *[resolve_edges_to_text(triplets_element) for triplets_element in triplets] ) else: @@ -236,7 +259,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): system_prompt=self.system_prompt, response_model=response_model, ) - for query, context in zip(query_batch, context_texts) + for query, context in zip(query_batch, context_text_batch) ], ) else: diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index ce84c1423..c0a0e7fab 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -1,5 +1,6 @@ from typing import List, Optional, Type, Union +from cognee.modules.retrieval.utils.validate_queries import validate_queries from cognee.shared.logging_utils import get_logger, ERROR from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError from cognee.infrastructure.databases.graph import get_graph_engine @@ -146,17 +147,10 @@ async def brute_force_triplet_search( In single-query mode, node_distances and edge_distances are stored as flat lists. In batch mode, they are stored as list-of-lists (one list per query). """ - if query is not None and query_batch is not None: - raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.") - if query is None and query_batch is None: - raise ValueError("Must provide either 'query' or 'query_batch'.") - if query is not None and (not query or not isinstance(query, str)): - raise ValueError("The query must be a non-empty string.") - if query_batch is not None: - if not isinstance(query_batch, list) or not query_batch: - raise ValueError("query_batch must be a non-empty list of strings.") - if not all(isinstance(q, str) and q for q in query_batch): - raise ValueError("All items in query_batch must be non-empty strings.") + query_validation = validate_queries(query, query_batch) + if not query_validation[0]: + raise ValueError(query_validation[1]) + if top_k <= 0: raise ValueError("top_k must be a positive integer.") diff --git a/cognee/modules/retrieval/utils/validate_queries.py b/cognee/modules/retrieval/utils/validate_queries.py new file mode 100644 index 000000000..913b0d665 --- /dev/null +++ b/cognee/modules/retrieval/utils/validate_queries.py @@ -0,0 +1,14 @@ +def validate_queries(query, query_batch) -> tuple[bool, str]: + if query is not None and query_batch is not None: + return False, "Cannot provide both 'query' and 'query_batch'; use exactly one." + if query is None and query_batch is None: + return False, "Must provide either 'query' or 'query_batch'." + if query is not None and (not query or not isinstance(query, str)): + return False, "The query must be a non-empty string." + if query_batch is not None: + if not isinstance(query_batch, list) or not query_batch: + return False, "query_batch must be a non-empty list of strings." + if not all(isinstance(q, str) and q for q in query_batch): + return False, "All items in query_batch must be non-empty strings." + + return True, ""