From f9cb490ad96073e424489821eb5dd273ed966336 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Thu, 8 Jan 2026 17:20:28 +0100 Subject: [PATCH 01/14] refactor: brute_force_triplet_search.py --- .../utils/brute_force_triplet_search.py | 216 ++++++++++++------ 1 file changed, 141 insertions(+), 75 deletions(-) diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index a70fa661b..5f367ca7f 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -1,6 +1,6 @@ import asyncio import time -from typing import List, Optional, Type +from typing import Any, List, Optional, Type from cognee.shared.logging_utils import get_logger, ERROR from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError @@ -9,13 +9,12 @@ from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.vector import get_vector_engine from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge -from cognee.modules.users.models import User -from cognee.shared.utils import send_telemetry logger = get_logger(level=ERROR) def format_triplets(edges): + """Formats edges into human-readable triplet strings.""" triplets = [] for edge in edges: node1 = edge.node1 @@ -51,7 +50,6 @@ async def get_memory_fragment( try: graph_engine = await get_graph_engine() - await memory_fragment.project_graph_from_db( graph_engine, node_properties_to_project=properties_to_project, @@ -61,18 +59,142 @@ async def get_memory_fragment( relevant_ids_to_filter=relevant_ids_to_filter, triplet_distance_penalty=triplet_distance_penalty, ) - except EntityNotFoundError: - # This is expected behavior - continue with empty fragment pass except Exception as e: logger.error(f"Error during memory fragment creation: {str(e)}") - # Still return the fragment even if projection failed - pass return memory_fragment +class _BruteForceTripletSearchEngine: + """Internal search engine for brute force triplet search operations.""" + + def __init__( + self, + query: str, + top_k: int, + collections: List[str], + properties_to_project: Optional[List[str]], + memory_fragment: Optional[CogneeGraph], + node_type: Optional[Type], + node_name: Optional[List[str]], + wide_search_limit: Optional[int], + triplet_distance_penalty: float, + ): + self.query = query + self.top_k = top_k + self.collections = collections + self.properties_to_project = properties_to_project + self.memory_fragment = memory_fragment + self.node_type = node_type + self.node_name = node_name + self.wide_search_limit = wide_search_limit + self.triplet_distance_penalty = triplet_distance_penalty + self.vector_engine = self._load_vector_engine() + self.query_vector = None + self.node_distances = None + self.edge_distances = None + + async def search(self) -> List[Edge]: + """Orchestrates the brute force triplet search workflow.""" + await self._embed_query_text() + await self._retrieve_and_set_vector_distances() + + if not (self.edge_distances or any(self.node_distances.values())): + return [] + + await self._ensure_memory_fragment_is_loaded() + await self._map_distances_to_memory_fragment() + + return await self.memory_fragment.calculate_top_triplet_importances(k=self.top_k) + + def _load_vector_engine(self): + """Loads the vector engine instance.""" + try: + return get_vector_engine() + except Exception as e: + logger.error("Failed to initialize vector engine: %s", e) + raise RuntimeError("Initialization error") from e + + async def _embed_query_text(self): + """Converts query text into embedding vector.""" + query_embeddings = await self.vector_engine.embedding_engine.embed_text([self.query]) + self.query_vector = query_embeddings[0] + + async def _retrieve_and_set_vector_distances(self): + """Searches all collections in parallel and sets node/edge distances directly.""" + start_time = time.time() + search_results = await self._run_parallel_collection_searches() + elapsed_time = time.time() - start_time + + collections_with_results = sum(1 for result in search_results if result) + logger.info( + f"Vector collection retrieval completed: Retrieved distances from " + f"{collections_with_results} collections in {elapsed_time:.2f}s" + ) + + self.node_distances = {} + for collection, result in zip(self.collections, search_results): + if collection == "EdgeType_relationship_name": + self.edge_distances = result + else: + self.node_distances[collection] = result + + async def _run_parallel_collection_searches(self) -> List[List[Any]]: + """Executes vector searches across all collections concurrently.""" + search_tasks = [ + self._search_single_collection(collection_name) for collection_name in self.collections + ] + return await asyncio.gather(*search_tasks) + + async def _search_single_collection(self, collection_name: str): + """Searches one collection and returns results or empty list if not found.""" + try: + return await self.vector_engine.search( + collection_name=collection_name, + query_vector=self.query_vector, + limit=self.wide_search_limit, + ) + except CollectionNotFoundError: + return [] + + async def _ensure_memory_fragment_is_loaded(self): + """Loads memory fragment if not already provided.""" + if self.memory_fragment is None: + relevant_node_ids = self._extract_relevant_node_ids_for_filtering() + self.memory_fragment = await get_memory_fragment( + properties_to_project=self.properties_to_project, + node_type=self.node_type, + node_name=self.node_name, + relevant_ids_to_filter=relevant_node_ids, + triplet_distance_penalty=self.triplet_distance_penalty, + ) + + def _extract_relevant_node_ids_for_filtering(self) -> Optional[List[str]]: + """Extracts unique node IDs from search results to filter graph projection.""" + if self.wide_search_limit is None: + return None + + relevant_node_ids = { + str(getattr(scored_node, "id")) + for score_collection in self.node_distances.values() + if isinstance(score_collection, (list, tuple)) + for scored_node in score_collection + if getattr(scored_node, "id", None) + } + return list(relevant_node_ids) + + async def _map_distances_to_memory_fragment(self): + """Maps vector distances to nodes and edges in the memory fragment.""" + await self.memory_fragment.map_vector_distances_to_graph_nodes( + node_distances=self.node_distances + ) + await self.memory_fragment.map_vector_distances_to_graph_edges( + edge_distances=self.edge_distances + ) + + async def brute_force_triplet_search( query: str, top_k: int = 5, @@ -108,7 +230,6 @@ async def brute_force_triplet_search( # Setting wide search limit based on the parameters non_global_search = node_name is None - wide_search_limit = wide_search_top_k if non_global_search else None if collections is None: @@ -123,73 +244,18 @@ async def brute_force_triplet_search( collections.append("EdgeType_relationship_name") try: - vector_engine = get_vector_engine() - except Exception as e: - logger.error("Failed to initialize vector engine: %s", e) - raise RuntimeError("Initialization error") from e - - query_vector = (await vector_engine.embedding_engine.embed_text([query]))[0] - - async def search_in_collection(collection_name: str): - try: - return await vector_engine.search( - collection_name=collection_name, query_vector=query_vector, limit=wide_search_limit - ) - except CollectionNotFoundError: - return [] - - try: - start_time = time.time() - - results = await asyncio.gather( - *[search_in_collection(collection_name) for collection_name in collections] + engine = _BruteForceTripletSearchEngine( + query=query, + top_k=top_k, + collections=collections, + properties_to_project=properties_to_project, + memory_fragment=memory_fragment, + node_type=node_type, + node_name=node_name, + wide_search_limit=wide_search_limit, + triplet_distance_penalty=triplet_distance_penalty, ) - - if all(not item for item in results): - return [] - - # Final statistics - vector_collection_search_time = time.time() - start_time - logger.info( - f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {vector_collection_search_time:.2f}s" - ) - - node_distances = {collection: result for collection, result in zip(collections, results)} - - edge_distances = node_distances.get("EdgeType_relationship_name", None) - - if wide_search_limit is not None: - relevant_ids_to_filter = list( - { - str(getattr(scored_node, "id")) - for collection_name, score_collection in node_distances.items() - if collection_name != "EdgeType_relationship_name" - and isinstance(score_collection, (list, tuple)) - for scored_node in score_collection - if getattr(scored_node, "id", None) - } - ) - else: - relevant_ids_to_filter = None - - if memory_fragment is None: - memory_fragment = await get_memory_fragment( - properties_to_project=properties_to_project, - node_type=node_type, - node_name=node_name, - relevant_ids_to_filter=relevant_ids_to_filter, - triplet_distance_penalty=triplet_distance_penalty, - ) - - await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances) - await memory_fragment.map_vector_distances_to_graph_edges(edge_distances=edge_distances) - - results = await memory_fragment.calculate_top_triplet_importances(k=top_k) - - return results - - except CollectionNotFoundError: - return [] + return await engine.search() except Exception as error: logger.error( "Error during brute force search for query: %s. Error: %s", From 876120853f11bbaf1fd66d65d23d113d067928c1 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Fri, 9 Jan 2026 09:10:29 +0100 Subject: [PATCH 02/14] refactor: brute_force_triplet_search.py with context class --- .../utils/brute_force_triplet_search.py | 185 +++++++++--------- 1 file changed, 90 insertions(+), 95 deletions(-) diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index 5f367ca7f..50d16edb2 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -23,12 +23,10 @@ def format_triplets(edges): node1_attributes = node1.attributes node2_attributes = node2.attributes - # Filter only non-None properties node1_info = {key: value for key, value in node1_attributes.items() if value is not None} node2_info = {key: value for key, value in node2_attributes.items() if value is not None} edge_info = {key: value for key, value in edge_attributes.items() if value is not None} - # Create the formatted triplet triplet = f"Node1: {node1_info}\nEdge: {edge_info}\nNode2: {node2_info}\n\n\n" triplets.append(triplet) @@ -67,8 +65,8 @@ async def get_memory_fragment( return memory_fragment -class _BruteForceTripletSearchEngine: - """Internal search engine for brute force triplet search operations.""" +class TripletSearchContext: + """Pure state container for triplet search operations.""" def __init__( self, @@ -76,7 +74,6 @@ class _BruteForceTripletSearchEngine: top_k: int, collections: List[str], properties_to_project: Optional[List[str]], - memory_fragment: Optional[CogneeGraph], node_type: Optional[Type], node_name: Optional[List[str]], wide_search_limit: Optional[int], @@ -86,92 +83,20 @@ class _BruteForceTripletSearchEngine: self.top_k = top_k self.collections = collections self.properties_to_project = properties_to_project - self.memory_fragment = memory_fragment self.node_type = node_type self.node_name = node_name self.wide_search_limit = wide_search_limit self.triplet_distance_penalty = triplet_distance_penalty - self.vector_engine = self._load_vector_engine() + self.query_vector = None self.node_distances = None self.edge_distances = None - async def search(self) -> List[Edge]: - """Orchestrates the brute force triplet search workflow.""" - await self._embed_query_text() - await self._retrieve_and_set_vector_distances() + def has_results(self) -> bool: + """Checks if any collections returned results.""" + return bool(self.edge_distances or any(self.node_distances.values())) - if not (self.edge_distances or any(self.node_distances.values())): - return [] - - await self._ensure_memory_fragment_is_loaded() - await self._map_distances_to_memory_fragment() - - return await self.memory_fragment.calculate_top_triplet_importances(k=self.top_k) - - def _load_vector_engine(self): - """Loads the vector engine instance.""" - try: - return get_vector_engine() - except Exception as e: - logger.error("Failed to initialize vector engine: %s", e) - raise RuntimeError("Initialization error") from e - - async def _embed_query_text(self): - """Converts query text into embedding vector.""" - query_embeddings = await self.vector_engine.embedding_engine.embed_text([self.query]) - self.query_vector = query_embeddings[0] - - async def _retrieve_and_set_vector_distances(self): - """Searches all collections in parallel and sets node/edge distances directly.""" - start_time = time.time() - search_results = await self._run_parallel_collection_searches() - elapsed_time = time.time() - start_time - - collections_with_results = sum(1 for result in search_results if result) - logger.info( - f"Vector collection retrieval completed: Retrieved distances from " - f"{collections_with_results} collections in {elapsed_time:.2f}s" - ) - - self.node_distances = {} - for collection, result in zip(self.collections, search_results): - if collection == "EdgeType_relationship_name": - self.edge_distances = result - else: - self.node_distances[collection] = result - - async def _run_parallel_collection_searches(self) -> List[List[Any]]: - """Executes vector searches across all collections concurrently.""" - search_tasks = [ - self._search_single_collection(collection_name) for collection_name in self.collections - ] - return await asyncio.gather(*search_tasks) - - async def _search_single_collection(self, collection_name: str): - """Searches one collection and returns results or empty list if not found.""" - try: - return await self.vector_engine.search( - collection_name=collection_name, - query_vector=self.query_vector, - limit=self.wide_search_limit, - ) - except CollectionNotFoundError: - return [] - - async def _ensure_memory_fragment_is_loaded(self): - """Loads memory fragment if not already provided.""" - if self.memory_fragment is None: - relevant_node_ids = self._extract_relevant_node_ids_for_filtering() - self.memory_fragment = await get_memory_fragment( - properties_to_project=self.properties_to_project, - node_type=self.node_type, - node_name=self.node_name, - relevant_ids_to_filter=relevant_node_ids, - triplet_distance_penalty=self.triplet_distance_penalty, - ) - - def _extract_relevant_node_ids_for_filtering(self) -> Optional[List[str]]: + def extract_relevant_node_ids(self) -> Optional[List[str]]: """Extracts unique node IDs from search results to filter graph projection.""" if self.wide_search_limit is None: return None @@ -185,14 +110,76 @@ class _BruteForceTripletSearchEngine: } return list(relevant_node_ids) - async def _map_distances_to_memory_fragment(self): - """Maps vector distances to nodes and edges in the memory fragment.""" - await self.memory_fragment.map_vector_distances_to_graph_nodes( - node_distances=self.node_distances - ) - await self.memory_fragment.map_vector_distances_to_graph_edges( - edge_distances=self.edge_distances + def set_distances_from_results(self, search_results: List[List[Any]]): + """Separates search results into node and edge distances.""" + self.node_distances = {} + for collection, result in zip(self.collections, search_results): + if collection == "EdgeType_relationship_name": + self.edge_distances = result + else: + self.node_distances[collection] = result + + +async def _search_single_collection( + vector_engine: Any, search_context: TripletSearchContext, collection_name: str +): + """Searches one collection and returns results or empty list if not found.""" + try: + return await vector_engine.search( + collection_name=collection_name, + query_vector=search_context.query_vector, + limit=search_context.wide_search_limit, ) + except CollectionNotFoundError: + return [] + + +async def _embed_and_retrieve_distances(search_context: TripletSearchContext): + """Embeds query and retrieves vector distances from all collections.""" + vector_engine = get_vector_engine() + + query_embeddings = await vector_engine.embedding_engine.embed_text([search_context.query]) + search_context.query_vector = query_embeddings[0] + + start_time = time.time() + search_tasks = [ + _search_single_collection(vector_engine, search_context, collection) + for collection in search_context.collections + ] + search_results = await asyncio.gather(*search_tasks) + + elapsed_time = time.time() - start_time + collections_with_results = sum(1 for result in search_results if result) + logger.info( + f"Vector collection retrieval completed: Retrieved distances from " + f"{collections_with_results} collections in {elapsed_time:.2f}s" + ) + + search_context.set_distances_from_results(search_results) + + +async def _create_memory_fragment(search_context: TripletSearchContext) -> CogneeGraph: + """Creates memory fragment using search context properties.""" + relevant_node_ids = search_context.extract_relevant_node_ids() + return await get_memory_fragment( + properties_to_project=search_context.properties_to_project, + node_type=search_context.node_type, + node_name=search_context.node_name, + relevant_ids_to_filter=relevant_node_ids, + triplet_distance_penalty=search_context.triplet_distance_penalty, + ) + + +async def _map_distances_to_fragment( + search_context: TripletSearchContext, memory_fragment: CogneeGraph +): + """Maps vector distances from search context to memory fragment.""" + await memory_fragment.map_vector_distances_to_graph_nodes( + node_distances=search_context.node_distances + ) + await memory_fragment.map_vector_distances_to_graph_edges( + edge_distances=search_context.edge_distances + ) async def brute_force_triplet_search( @@ -228,9 +215,7 @@ async def brute_force_triplet_search( if top_k <= 0: raise ValueError("top_k must be a positive integer.") - # Setting wide search limit based on the parameters - non_global_search = node_name is None - wide_search_limit = wide_search_top_k if non_global_search else None + wide_search_limit = wide_search_top_k if node_name is None else None if collections is None: collections = [ @@ -244,18 +229,28 @@ async def brute_force_triplet_search( collections.append("EdgeType_relationship_name") try: - engine = _BruteForceTripletSearchEngine( + search_context = TripletSearchContext( query=query, top_k=top_k, collections=collections, properties_to_project=properties_to_project, - memory_fragment=memory_fragment, node_type=node_type, node_name=node_name, wide_search_limit=wide_search_limit, triplet_distance_penalty=triplet_distance_penalty, ) - return await engine.search() + + await _embed_and_retrieve_distances(search_context) + + if not search_context.has_results(): + return [] + + if memory_fragment is None: + memory_fragment = await _create_memory_fragment(search_context) + + await _map_distances_to_fragment(search_context, memory_fragment) + + return await memory_fragment.calculate_top_triplet_importances(k=search_context.top_k) except Exception as error: logger.error( "Error during brute force search for query: %s. Error: %s", From c79af6c8cccb241200d4596a87a068c013854a74 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Fri, 9 Jan 2026 10:05:44 +0100 Subject: [PATCH 03/14] refactor: brute_force_triplet_search.py and node_edge_vector_search.py --- .../utils/brute_force_triplet_search.py | 170 ++++-------------- .../utils/node_edge_vector_search.py | 81 +++++++++ 2 files changed, 119 insertions(+), 132 deletions(-) create mode 100644 cognee/modules/retrieval/utils/node_edge_vector_search.py diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index 50d16edb2..ef805a127 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -1,14 +1,11 @@ -import asyncio -import time -from typing import Any, List, Optional, Type +from typing import List, Optional, Type from cognee.shared.logging_utils import get_logger, ERROR from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError -from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError from cognee.infrastructure.databases.graph import get_graph_engine -from cognee.infrastructure.databases.vector import get_vector_engine from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.modules.retrieval.utils.node_edge_vector_search import NodeEdgeVectorSearch logger = get_logger(level=ERROR) @@ -65,122 +62,36 @@ async def get_memory_fragment( return memory_fragment -class TripletSearchContext: - """Pure state container for triplet search operations.""" - - def __init__( - self, - query: str, - top_k: int, - collections: List[str], - properties_to_project: Optional[List[str]], - node_type: Optional[Type], - node_name: Optional[List[str]], - wide_search_limit: Optional[int], - triplet_distance_penalty: float, - ): - self.query = query - self.top_k = top_k - self.collections = collections - self.properties_to_project = properties_to_project - self.node_type = node_type - self.node_name = node_name - self.wide_search_limit = wide_search_limit - self.triplet_distance_penalty = triplet_distance_penalty - - self.query_vector = None - self.node_distances = None - self.edge_distances = None - - def has_results(self) -> bool: - """Checks if any collections returned results.""" - return bool(self.edge_distances or any(self.node_distances.values())) - - def extract_relevant_node_ids(self) -> Optional[List[str]]: - """Extracts unique node IDs from search results to filter graph projection.""" - if self.wide_search_limit is None: - return None - - relevant_node_ids = { - str(getattr(scored_node, "id")) - for score_collection in self.node_distances.values() - if isinstance(score_collection, (list, tuple)) - for scored_node in score_collection - if getattr(scored_node, "id", None) - } - return list(relevant_node_ids) - - def set_distances_from_results(self, search_results: List[List[Any]]): - """Separates search results into node and edge distances.""" - self.node_distances = {} - for collection, result in zip(self.collections, search_results): - if collection == "EdgeType_relationship_name": - self.edge_distances = result - else: - self.node_distances[collection] = result - - -async def _search_single_collection( - vector_engine: Any, search_context: TripletSearchContext, collection_name: str -): - """Searches one collection and returns results or empty list if not found.""" - try: - return await vector_engine.search( - collection_name=collection_name, - query_vector=search_context.query_vector, - limit=search_context.wide_search_limit, +async def _get_top_triplet_importances( + memory_fragment: Optional[CogneeGraph], + vector_search: NodeEdgeVectorSearch, + properties_to_project: Optional[List[str]], + node_type: Optional[Type], + node_name: Optional[List[str]], + triplet_distance_penalty: float, + wide_search_limit: Optional[int], + top_k: int, +) -> List[Edge]: + """Creates memory fragment (if needed), maps distances, and calculates top triplet importances.""" + if memory_fragment is None: + relevant_node_ids = vector_search.extract_relevant_node_ids() if wide_search_limit else None + memory_fragment = await get_memory_fragment( + properties_to_project=properties_to_project, + node_type=node_type, + node_name=node_name, + relevant_ids_to_filter=relevant_node_ids, + triplet_distance_penalty=triplet_distance_penalty, ) - except CollectionNotFoundError: - return [] - -async def _embed_and_retrieve_distances(search_context: TripletSearchContext): - """Embeds query and retrieves vector distances from all collections.""" - vector_engine = get_vector_engine() - - query_embeddings = await vector_engine.embedding_engine.embed_text([search_context.query]) - search_context.query_vector = query_embeddings[0] - - start_time = time.time() - search_tasks = [ - _search_single_collection(vector_engine, search_context, collection) - for collection in search_context.collections - ] - search_results = await asyncio.gather(*search_tasks) - - elapsed_time = time.time() - start_time - collections_with_results = sum(1 for result in search_results if result) - logger.info( - f"Vector collection retrieval completed: Retrieved distances from " - f"{collections_with_results} collections in {elapsed_time:.2f}s" - ) - - search_context.set_distances_from_results(search_results) - - -async def _create_memory_fragment(search_context: TripletSearchContext) -> CogneeGraph: - """Creates memory fragment using search context properties.""" - relevant_node_ids = search_context.extract_relevant_node_ids() - return await get_memory_fragment( - properties_to_project=search_context.properties_to_project, - node_type=search_context.node_type, - node_name=search_context.node_name, - relevant_ids_to_filter=relevant_node_ids, - triplet_distance_penalty=search_context.triplet_distance_penalty, - ) - - -async def _map_distances_to_fragment( - search_context: TripletSearchContext, memory_fragment: CogneeGraph -): - """Maps vector distances from search context to memory fragment.""" await memory_fragment.map_vector_distances_to_graph_nodes( - node_distances=search_context.node_distances + node_distances=vector_search.node_distances ) await memory_fragment.map_vector_distances_to_graph_edges( - edge_distances=search_context.edge_distances + edge_distances=vector_search.edge_distances ) + return await memory_fragment.calculate_top_triplet_importances(k=top_k) + async def brute_force_triplet_search( query: str, @@ -229,28 +140,23 @@ async def brute_force_triplet_search( collections.append("EdgeType_relationship_name") try: - search_context = TripletSearchContext( - query=query, - top_k=top_k, - collections=collections, - properties_to_project=properties_to_project, - node_type=node_type, - node_name=node_name, - wide_search_limit=wide_search_limit, - triplet_distance_penalty=triplet_distance_penalty, - ) + vector_search = NodeEdgeVectorSearch() - await _embed_and_retrieve_distances(search_context) + await vector_search.embed_and_retrieve_distances(query, collections, wide_search_limit) - if not search_context.has_results(): + if not vector_search.has_results(): return [] - if memory_fragment is None: - memory_fragment = await _create_memory_fragment(search_context) - - await _map_distances_to_fragment(search_context, memory_fragment) - - return await memory_fragment.calculate_top_triplet_importances(k=search_context.top_k) + return await _get_top_triplet_importances( + memory_fragment, + vector_search, + properties_to_project, + node_type, + node_name, + triplet_distance_penalty, + wide_search_limit, + top_k, + ) except Exception as error: logger.error( "Error during brute force search for query: %s. Error: %s", diff --git a/cognee/modules/retrieval/utils/node_edge_vector_search.py b/cognee/modules/retrieval/utils/node_edge_vector_search.py new file mode 100644 index 000000000..777751cf2 --- /dev/null +++ b/cognee/modules/retrieval/utils/node_edge_vector_search.py @@ -0,0 +1,81 @@ +import asyncio +import time +from typing import Any, List, Optional + +from cognee.shared.logging_utils import get_logger, ERROR +from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError +from cognee.infrastructure.databases.vector import get_vector_engine + +logger = get_logger(level=ERROR) + + +class NodeEdgeVectorSearch: + """Manages vector search and distance retrieval for graph nodes and edges.""" + + def __init__(self, edge_collection: str = "EdgeType_relationship_name"): + self.edge_collection = edge_collection + self.query_vector: Optional[Any] = None + self.node_distances: dict[str, list[Any]] = {} + self.edge_distances: Optional[list[Any]] = None + + def has_results(self) -> bool: + """Checks if any collections returned results.""" + return bool(self.edge_distances) or any(self.node_distances.values()) + + def set_distances_from_results(self, collections: List[str], search_results: List[List[Any]]): + """Separates search results into node and edge distances.""" + self.node_distances = {} + for collection, result in zip(collections, search_results): + if collection == self.edge_collection: + self.edge_distances = result + else: + self.node_distances[collection] = result + + def extract_relevant_node_ids(self) -> List[str]: + """Extracts unique node IDs from search results.""" + relevant_node_ids = { + str(getattr(scored_node, "id")) + for score_collection in self.node_distances.values() + if isinstance(score_collection, (list, tuple)) + for scored_node in score_collection + if getattr(scored_node, "id", None) + } + return list(relevant_node_ids) + + async def embed_and_retrieve_distances( + self, query: str, collections: List[str], wide_search_limit: Optional[int] + ): + """Embeds query and retrieves vector distances from all collections.""" + vector_engine = get_vector_engine() + + query_embeddings = await vector_engine.embedding_engine.embed_text([query]) + self.query_vector = query_embeddings[0] + + start_time = time.time() + search_tasks = [ + self._search_single_collection(vector_engine, wide_search_limit, collection) + for collection in collections + ] + search_results = await asyncio.gather(*search_tasks) + + elapsed_time = time.time() - start_time + collections_with_results = sum(1 for result in search_results if result) + logger.info( + f"Vector collection retrieval completed: Retrieved distances from " + f"{collections_with_results} collections in {elapsed_time:.2f}s" + ) + + self.set_distances_from_results(collections, search_results) + + async def _search_single_collection( + self, vector_engine: Any, wide_search_limit: Optional[int], collection_name: str + ): + """Searches one collection and returns results or empty list if not found.""" + try: + return await vector_engine.search( + collection_name=collection_name, + query_vector=self.query_vector, + limit=wide_search_limit, + ) + except CollectionNotFoundError: + return [] From fad75e21c1bcf8765265bcf4d2b794cc03a3a403 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Fri, 9 Jan 2026 14:53:47 +0100 Subject: [PATCH 04/14] refactor: minor tweaks --- .../utils/brute_force_triplet_search.py | 9 +++++++- .../utils/node_edge_vector_search.py | 22 ++++++++++++++----- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index ef805a127..3c3603f01 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -3,6 +3,7 @@ from typing import List, Optional, Type from cognee.shared.logging_utils import get_logger, ERROR from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge from cognee.modules.retrieval.utils.node_edge_vector_search import NodeEdgeVectorSearch @@ -74,7 +75,11 @@ async def _get_top_triplet_importances( ) -> List[Edge]: """Creates memory fragment (if needed), maps distances, and calculates top triplet importances.""" if memory_fragment is None: - relevant_node_ids = vector_search.extract_relevant_node_ids() if wide_search_limit else None + if wide_search_limit is None: + relevant_node_ids = None + else: + relevant_node_ids = vector_search.extract_relevant_node_ids() + memory_fragment = await get_memory_fragment( properties_to_project=properties_to_project, node_type=node_type, @@ -157,6 +162,8 @@ async def brute_force_triplet_search( wide_search_limit, top_k, ) + except CollectionNotFoundError: + return [] except Exception as error: logger.error( "Error during brute force search for query: %s. Error: %s", diff --git a/cognee/modules/retrieval/utils/node_edge_vector_search.py b/cognee/modules/retrieval/utils/node_edge_vector_search.py index 777751cf2..08f76218c 100644 --- a/cognee/modules/retrieval/utils/node_edge_vector_search.py +++ b/cognee/modules/retrieval/utils/node_edge_vector_search.py @@ -12,12 +12,20 @@ logger = get_logger(level=ERROR) class NodeEdgeVectorSearch: """Manages vector search and distance retrieval for graph nodes and edges.""" - def __init__(self, edge_collection: str = "EdgeType_relationship_name"): + def __init__(self, edge_collection: str = "EdgeType_relationship_name", vector_engine=None): self.edge_collection = edge_collection + self.vector_engine = vector_engine or self._init_vector_engine() self.query_vector: Optional[Any] = None self.node_distances: dict[str, list[Any]] = {} self.edge_distances: Optional[list[Any]] = None + def _init_vector_engine(self): + try: + return get_vector_engine() + except Exception as e: + logger.error("Failed to initialize vector engine: %s", e) + raise RuntimeError("Initialization error") from e + def has_results(self) -> bool: """Checks if any collections returned results.""" return bool(self.edge_distances) or any(self.node_distances.values()) @@ -42,18 +50,20 @@ class NodeEdgeVectorSearch: } return list(relevant_node_ids) + async def _embed_query(self, query: str): + """Embeds the query and stores the resulting vector.""" + query_embeddings = await self.vector_engine.embedding_engine.embed_text([query]) + self.query_vector = query_embeddings[0] + async def embed_and_retrieve_distances( self, query: str, collections: List[str], wide_search_limit: Optional[int] ): """Embeds query and retrieves vector distances from all collections.""" - vector_engine = get_vector_engine() - - query_embeddings = await vector_engine.embedding_engine.embed_text([query]) - self.query_vector = query_embeddings[0] + await self._embed_query(query) start_time = time.time() search_tasks = [ - self._search_single_collection(vector_engine, wide_search_limit, collection) + self._search_single_collection(self.vector_engine, wide_search_limit, collection) for collection in collections ] search_results = await asyncio.gather(*search_tasks) From 58dd518690d61632dd92814ebc2c43a4b6932612 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Fri, 9 Jan 2026 14:54:10 +0100 Subject: [PATCH 05/14] chore: update tests --- .../test_brute_force_triplet_search.py | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py index b7cbe08d7..00db1e794 100644 --- a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +++ b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py @@ -57,7 +57,7 @@ async def test_brute_force_triplet_search_wide_search_limit_global_search(): mock_vector_engine.search = AsyncMock(return_value=[]) with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ): await brute_force_triplet_search( @@ -79,7 +79,7 @@ async def test_brute_force_triplet_search_wide_search_limit_filtered_search(): mock_vector_engine.search = AsyncMock(return_value=[]) with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ): await brute_force_triplet_search( @@ -101,7 +101,7 @@ async def test_brute_force_triplet_search_wide_search_default(): mock_vector_engine.search = AsyncMock(return_value=[]) with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ): await brute_force_triplet_search(query="test", node_name=None) @@ -119,7 +119,7 @@ async def test_brute_force_triplet_search_default_collections(): mock_vector_engine.search = AsyncMock(return_value=[]) with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ): await brute_force_triplet_search(query="test") @@ -149,7 +149,7 @@ async def test_brute_force_triplet_search_custom_collections(): custom_collections = ["CustomCol1", "CustomCol2"] with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ): await brute_force_triplet_search(query="test", collections=custom_collections) @@ -171,7 +171,7 @@ async def test_brute_force_triplet_search_always_includes_edge_collection(): collections_without_edge = ["Entity_name", "TextSummary_text"] with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ): await brute_force_triplet_search(query="test", collections=collections_without_edge) @@ -194,7 +194,7 @@ async def test_brute_force_triplet_search_all_collections_empty(): mock_vector_engine.search = AsyncMock(return_value=[]) with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ): results = await brute_force_triplet_search(query="test") @@ -216,7 +216,7 @@ async def test_brute_force_triplet_search_embeds_query(): mock_vector_engine.search = AsyncMock(return_value=[]) with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ): await brute_force_triplet_search(query=query_text) @@ -249,7 +249,7 @@ async def test_brute_force_triplet_search_extracts_node_ids_global_search(): with ( patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ), patch( @@ -279,7 +279,7 @@ async def test_brute_force_triplet_search_reuses_provided_fragment(): with ( patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ), patch( @@ -311,7 +311,7 @@ async def test_brute_force_triplet_search_creates_fragment_when_not_provided(): with ( patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ), patch( @@ -340,7 +340,7 @@ async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation with ( patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ), patch( @@ -430,7 +430,7 @@ async def test_brute_force_triplet_search_deduplicates_node_ids(): with ( patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ), patch( @@ -471,7 +471,7 @@ async def test_brute_force_triplet_search_excludes_edge_collection(): with ( patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ), patch( @@ -523,7 +523,7 @@ async def test_brute_force_triplet_search_skips_nodes_without_ids(): with ( patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ), patch( @@ -564,7 +564,7 @@ async def test_brute_force_triplet_search_handles_tuple_results(): with ( patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ), patch( @@ -606,7 +606,7 @@ async def test_brute_force_triplet_search_mixed_empty_collections(): with ( patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ), patch( @@ -689,7 +689,7 @@ async def test_brute_force_triplet_search_vector_engine_init_error(): """Test brute_force_triplet_search handles vector engine initialization error (lines 145-147).""" with ( patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine" + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine" ) as mock_get_vector_engine, ): mock_get_vector_engine.side_effect = Exception("Initialization error") @@ -716,7 +716,7 @@ async def test_brute_force_triplet_search_collection_not_found_error(): with ( patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ), patch( @@ -743,7 +743,7 @@ async def test_brute_force_triplet_search_generic_exception(): with ( patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ), ): @@ -769,7 +769,7 @@ async def test_brute_force_triplet_search_with_node_name_sets_relevant_ids_to_no with ( patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ), patch( @@ -804,7 +804,7 @@ async def test_brute_force_triplet_search_collection_not_found_at_top_level(): with ( patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ), patch( From 701a92cdec2df38b0e6684930a7cdae2ee14c68a Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Mon, 12 Jan 2026 13:27:18 +0100 Subject: [PATCH 06/14] feat: add batch search to node_edge_vector_search.py --- .../utils/brute_force_triplet_search.py | 4 +- .../utils/node_edge_vector_search.py | 120 ++++++++++++++---- 2 files changed, 100 insertions(+), 24 deletions(-) diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index 3c3603f01..a39ef50e1 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -147,7 +147,9 @@ async def brute_force_triplet_search( try: vector_search = NodeEdgeVectorSearch() - await vector_search.embed_and_retrieve_distances(query, collections, wide_search_limit) + await vector_search.embed_and_retrieve_distances( + query=query, collections=collections, wide_search_limit=wide_search_limit + ) if not vector_search.has_results(): return [] diff --git a/cognee/modules/retrieval/utils/node_edge_vector_search.py b/cognee/modules/retrieval/utils/node_edge_vector_search.py index 08f76218c..e8dd0dc48 100644 --- a/cognee/modules/retrieval/utils/node_edge_vector_search.py +++ b/cognee/modules/retrieval/utils/node_edge_vector_search.py @@ -16,8 +16,9 @@ class NodeEdgeVectorSearch: self.edge_collection = edge_collection self.vector_engine = vector_engine or self._init_vector_engine() self.query_vector: Optional[Any] = None - self.node_distances: dict[str, list[Any]] = {} - self.edge_distances: Optional[list[Any]] = None + self.node_distances: dict[str, list[list[Any]]] = {} + self.edge_distances: list[list[Any]] = [] + self.query_list_length: Optional[int] = None def _init_vector_engine(self): try: @@ -28,26 +29,56 @@ class NodeEdgeVectorSearch: def has_results(self) -> bool: """Checks if any collections returned results.""" - return bool(self.edge_distances) or any(self.node_distances.values()) + if self.query_list_length is None: + if self.edge_distances and any(self.edge_distances): + return True + return any( + bool(collection_results) for collection_results in self.node_distances.values() + ) - def set_distances_from_results(self, collections: List[str], search_results: List[List[Any]]): - """Separates search results into node and edge distances.""" + if self.edge_distances and any(self.edge_distances): + return True + return any( + any(results_per_query for results_per_query in collection_results) + for collection_results in self.node_distances.values() + ) + + def set_distances_from_results( + self, + collections: List[str], + search_results: List[List[Any]], + query_list_length: Optional[int] = None, + ): + """Separates search results into node and edge distances with stable shapes.""" self.node_distances = {} + self.edge_distances = ( + [] if query_list_length is None else [[] for _ in range(query_list_length)] + ) for collection, result in zip(collections, search_results): - if collection == self.edge_collection: - self.edge_distances = result + if not result: + empty_result = ( + [] if query_list_length is None else [[] for _ in range(query_list_length)] + ) + if collection == self.edge_collection: + self.edge_distances = empty_result + else: + self.node_distances[collection] = empty_result else: - self.node_distances[collection] = result + if collection == self.edge_collection: + self.edge_distances = result + else: + self.node_distances[collection] = result def extract_relevant_node_ids(self) -> List[str]: """Extracts unique node IDs from search results.""" - relevant_node_ids = { - str(getattr(scored_node, "id")) - for score_collection in self.node_distances.values() - if isinstance(score_collection, (list, tuple)) - for scored_node in score_collection - if getattr(scored_node, "id", None) - } + if self.query_list_length is not None: + return [] + relevant_node_ids = set() + for scored_results in self.node_distances.values(): + for scored_node in scored_results: + node_id = getattr(scored_node, "id", None) + if node_id: + relevant_node_ids.add(str(node_id)) return list(relevant_node_ids) async def _embed_query(self, query: str): @@ -55,27 +86,70 @@ class NodeEdgeVectorSearch: query_embeddings = await self.vector_engine.embedding_engine.embed_text([query]) self.query_vector = query_embeddings[0] - async def embed_and_retrieve_distances( - self, query: str, collections: List[str], wide_search_limit: Optional[int] - ): - """Embeds query and retrieves vector distances from all collections.""" - await self._embed_query(query) + async def _run_batch_search( + self, collections: List[str], query_batch: List[str] + ) -> List[List[Any]]: + """Runs batch search across all collections and returns list-of-lists per collection.""" + search_tasks = [ + self._search_batch_collection(collection, query_batch) for collection in collections + ] + return await asyncio.gather(*search_tasks) - start_time = time.time() + async def _search_batch_collection( + self, collection_name: str, query_batch: List[str] + ) -> List[List[Any]]: + """Searches one collection with batch queries and returns list-of-lists.""" + try: + return await self.vector_engine.batch_search( + collection_name=collection_name, query_texts=query_batch, limit=None + ) + except CollectionNotFoundError: + return [[]] * len(query_batch) + + async def _run_single_search( + self, collections: List[str], query: str, wide_search_limit: Optional[int] + ) -> List[List[Any]]: + """Runs single query search and wraps results in list-of-lists for shape consistency.""" + await self._embed_query(query) search_tasks = [ self._search_single_collection(self.vector_engine, wide_search_limit, collection) for collection in collections ] search_results = await asyncio.gather(*search_tasks) + return search_results + + async def embed_and_retrieve_distances( + self, + query: Optional[str] = None, + query_batch: Optional[List[str]] = None, + collections: List[str] = None, + wide_search_limit: Optional[int] = None, + ): + """Embeds query/queries and retrieves vector distances from all collections.""" + if query is not None and query_batch is not None: + raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.") + if query is None and query_batch is None: + raise ValueError("Must provide either 'query' or 'query_batch'.") + if not collections: + raise ValueError("'collections' must be a non-empty list.") + + start_time = time.time() + + if query_batch is not None: + self.query_list_length = len(query_batch) + search_results = await self._run_batch_search(collections, query_batch) + else: + self.query_list_length = None + search_results = await self._run_single_search(collections, query, wide_search_limit) elapsed_time = time.time() - start_time - collections_with_results = sum(1 for result in search_results if result) + collections_with_results = sum(1 for result in search_results if any(result)) logger.info( f"Vector collection retrieval completed: Retrieved distances from " f"{collections_with_results} collections in {elapsed_time:.2f}s" ) - self.set_distances_from_results(collections, search_results) + self.set_distances_from_results(collections, search_results, self.query_list_length) async def _search_single_collection( self, vector_engine: Any, wide_search_limit: Optional[int], collection_name: str From fce018d43d56b76d185cae065801c030935f8cc3 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Mon, 12 Jan 2026 13:38:24 +0100 Subject: [PATCH 07/14] test: add tests for node_edge_vector_search.py --- .../utils/node_edge_vector_search.py | 4 +- .../retrieval/test_node_edge_vector_search.py | 214 ++++++++++++++++++ 2 files changed, 216 insertions(+), 2 deletions(-) create mode 100644 cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py diff --git a/cognee/modules/retrieval/utils/node_edge_vector_search.py b/cognee/modules/retrieval/utils/node_edge_vector_search.py index e8dd0dc48..db9acc121 100644 --- a/cognee/modules/retrieval/utils/node_edge_vector_search.py +++ b/cognee/modules/retrieval/utils/node_edge_vector_search.py @@ -36,7 +36,7 @@ class NodeEdgeVectorSearch: bool(collection_results) for collection_results in self.node_distances.values() ) - if self.edge_distances and any(self.edge_distances): + if self.edge_distances and any(inner_list for inner_list in self.edge_distances): return True return any( any(results_per_query for results_per_query in collection_results) @@ -109,7 +109,7 @@ class NodeEdgeVectorSearch: async def _run_single_search( self, collections: List[str], query: str, wide_search_limit: Optional[int] ) -> List[List[Any]]: - """Runs single query search and wraps results in list-of-lists for shape consistency.""" + """Runs single query search and returns list-of-lists per collection.""" await self._embed_query(query) search_tasks = [ self._search_single_collection(self.vector_engine, wide_search_limit, collection) diff --git a/cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py b/cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py new file mode 100644 index 000000000..d93dce42b --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py @@ -0,0 +1,214 @@ +import pytest +from unittest.mock import AsyncMock + +from cognee.modules.retrieval.utils.node_edge_vector_search import NodeEdgeVectorSearch +from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError + + +class MockScoredResult: + """Mock class for vector search results.""" + + def __init__(self, id, score, payload=None): + self.id = id + self.score = score + self.payload = payload or {} + + +@pytest.mark.asyncio +async def test_node_edge_vector_search_single_query_shape(): + """Test that single query mode produces flat lists (not list-of-lists).""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + + node_results = [MockScoredResult("node1", 0.95), MockScoredResult("node2", 0.87)] + edge_results = [MockScoredResult("edge1", 0.92)] + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "EdgeType_relationship_name": + return edge_results + return node_results + + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine) + collections = ["Entity_name", "EdgeType_relationship_name"] + + await vector_search.embed_and_retrieve_distances( + query="test query", query_batch=None, collections=collections, wide_search_limit=10 + ) + + assert vector_search.query_list_length is None + assert vector_search.edge_distances == edge_results + assert vector_search.node_distances["Entity_name"] == node_results + mock_vector_engine.embedding_engine.embed_text.assert_called_once_with(["test query"]) + + +@pytest.mark.asyncio +async def test_node_edge_vector_search_batch_query_shape_and_empties(): + """Test that batch query mode produces list-of-lists with correct length and handles empty collections.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + + query_batch = ["query a", "query b"] + node_results_query_a = [MockScoredResult("node1", 0.95)] + node_results_query_b = [MockScoredResult("node2", 0.87)] + edge_results_query_a = [MockScoredResult("edge1", 0.92)] + edge_results_query_b = [] + + def batch_search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "EdgeType_relationship_name": + return [edge_results_query_a, edge_results_query_b] + elif collection_name == "Entity_name": + return [node_results_query_a, node_results_query_b] + elif collection_name == "MissingCollection": + raise CollectionNotFoundError("Collection not found") + return [[], []] + + mock_vector_engine.batch_search = AsyncMock(side_effect=batch_search_side_effect) + + vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine) + collections = [ + "Entity_name", + "EdgeType_relationship_name", + "MissingCollection", + "EmptyCollection", + ] + + await vector_search.embed_and_retrieve_distances( + query=None, query_batch=query_batch, collections=collections, wide_search_limit=None + ) + + assert vector_search.query_list_length == 2 + assert len(vector_search.edge_distances) == 2 + assert vector_search.edge_distances[0] == edge_results_query_a + assert vector_search.edge_distances[1] == edge_results_query_b + assert len(vector_search.node_distances["Entity_name"]) == 2 + assert vector_search.node_distances["Entity_name"][0] == node_results_query_a + assert vector_search.node_distances["Entity_name"][1] == node_results_query_b + assert len(vector_search.node_distances["MissingCollection"]) == 2 + assert vector_search.node_distances["MissingCollection"] == [[], []] + assert len(vector_search.node_distances["EmptyCollection"]) == 2 + assert vector_search.node_distances["EmptyCollection"] == [[], []] + mock_vector_engine.embedding_engine.embed_text.assert_not_called() + + +@pytest.mark.asyncio +async def test_node_edge_vector_search_input_validation_both_provided(): + """Test that providing both query and query_batch raises ValueError.""" + mock_vector_engine = AsyncMock() + vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine) + collections = ["Entity_name"] + + with pytest.raises(ValueError, match="Cannot provide both 'query' and 'query_batch'"): + await vector_search.embed_and_retrieve_distances( + query="test", query_batch=["test1", "test2"], collections=collections + ) + + +@pytest.mark.asyncio +async def test_node_edge_vector_search_input_validation_neither_provided(): + """Test that providing neither query nor query_batch raises ValueError.""" + mock_vector_engine = AsyncMock() + vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine) + collections = ["Entity_name"] + + with pytest.raises(ValueError, match="Must provide either 'query' or 'query_batch'"): + await vector_search.embed_and_retrieve_distances( + query=None, query_batch=None, collections=collections + ) + + +@pytest.mark.asyncio +async def test_node_edge_vector_search_extract_relevant_node_ids_single_query(): + """Test that extract_relevant_node_ids returns IDs for single query mode.""" + mock_vector_engine = AsyncMock() + vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine) + vector_search.query_list_length = None + vector_search.node_distances = { + "Entity_name": [MockScoredResult("node1", 0.95), MockScoredResult("node2", 0.87)], + "TextSummary_text": [MockScoredResult("node1", 0.90), MockScoredResult("node3", 0.92)], + } + + node_ids = vector_search.extract_relevant_node_ids() + assert set(node_ids) == {"node1", "node2", "node3"} + + +@pytest.mark.asyncio +async def test_node_edge_vector_search_extract_relevant_node_ids_batch(): + """Test that extract_relevant_node_ids returns empty list for batch mode.""" + mock_vector_engine = AsyncMock() + vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine) + vector_search.query_list_length = 2 + vector_search.node_distances = { + "Entity_name": [ + [MockScoredResult("node1", 0.95)], + [MockScoredResult("node2", 0.87)], + ], + } + + node_ids = vector_search.extract_relevant_node_ids() + assert node_ids == [] + + +@pytest.mark.asyncio +async def test_node_edge_vector_search_has_results_single_query(): + """Test has_results returns True when results exist and False when only empties.""" + mock_vector_engine = AsyncMock() + vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine) + + vector_search.edge_distances = [MockScoredResult("edge1", 0.92)] + vector_search.node_distances = {} + assert vector_search.has_results() is True + + vector_search.edge_distances = [] + vector_search.node_distances = {"Entity_name": [MockScoredResult("node1", 0.95)]} + assert vector_search.has_results() is True + + vector_search.edge_distances = [] + vector_search.node_distances = {} + assert vector_search.has_results() is False + + +@pytest.mark.asyncio +async def test_node_edge_vector_search_has_results_batch(): + """Test has_results works correctly for batch mode with list-of-lists.""" + mock_vector_engine = AsyncMock() + vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine) + vector_search.query_list_length = 2 + + vector_search.edge_distances = [[MockScoredResult("edge1", 0.92)], []] + vector_search.node_distances = {} + assert vector_search.has_results() is True + + vector_search.edge_distances = [[], []] + vector_search.node_distances = { + "Entity_name": [[MockScoredResult("node1", 0.95)], []], + } + assert vector_search.has_results() is True + + vector_search.edge_distances = [[], []] + vector_search.node_distances = {"Entity_name": [[], []]} + assert vector_search.has_results() is False + + +@pytest.mark.asyncio +async def test_node_edge_vector_search_single_query_collection_not_found(): + """Test that CollectionNotFoundError in single query mode returns empty list.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock( + side_effect=CollectionNotFoundError("Collection not found") + ) + + vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine) + collections = ["MissingCollection"] + + await vector_search.embed_and_retrieve_distances( + query="test query", query_batch=None, collections=collections, wide_search_limit=10 + ) + + assert vector_search.node_distances["MissingCollection"] == [] From 5ac288afa3e82c40579901a0463ea290e56e8197 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Mon, 12 Jan 2026 13:44:04 +0100 Subject: [PATCH 08/14] chore: tweak type hints --- .../modules/retrieval/utils/node_edge_vector_search.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/cognee/modules/retrieval/utils/node_edge_vector_search.py b/cognee/modules/retrieval/utils/node_edge_vector_search.py index db9acc121..ff2d98eb8 100644 --- a/cognee/modules/retrieval/utils/node_edge_vector_search.py +++ b/cognee/modules/retrieval/utils/node_edge_vector_search.py @@ -16,8 +16,8 @@ class NodeEdgeVectorSearch: self.edge_collection = edge_collection self.vector_engine = vector_engine or self._init_vector_engine() self.query_vector: Optional[Any] = None - self.node_distances: dict[str, list[list[Any]]] = {} - self.edge_distances: list[list[Any]] = [] + self.node_distances: dict[str, list[Any]] = {} + self.edge_distances: list[Any] = [] self.query_list_length: Optional[int] = None def _init_vector_engine(self): @@ -109,7 +109,11 @@ class NodeEdgeVectorSearch: async def _run_single_search( self, collections: List[str], query: str, wide_search_limit: Optional[int] ) -> List[List[Any]]: - """Runs single query search and returns list-of-lists per collection.""" + """Runs single query search and returns flat lists per collection. + + Returns a list where each element is a collection's results (flat list). + These are stored as flat lists in node_distances/edge_distances for single-query mode. + """ await self._embed_query(query) search_tasks = [ self._search_single_collection(self.vector_engine, wide_search_limit, collection) From 7833189001b4ca4fc0b5a46f869d9cc8632c73e8 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Mon, 12 Jan 2026 14:40:17 +0100 Subject: [PATCH 09/14] feat: enable batch search in brute_force_triplet_search --- .../utils/brute_force_triplet_search.py | 74 ++++++++++++++----- 1 file changed, 56 insertions(+), 18 deletions(-) diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index a39ef50e1..ce84c1423 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Type +from typing import List, Optional, Type, Union from cognee.shared.logging_utils import get_logger, ERROR from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError @@ -72,8 +72,18 @@ async def _get_top_triplet_importances( triplet_distance_penalty: float, wide_search_limit: Optional[int], top_k: int, -) -> List[Edge]: - """Creates memory fragment (if needed), maps distances, and calculates top triplet importances.""" + query_list_length: Optional[int] = None, +) -> Union[List[Edge], List[List[Edge]]]: + """Creates memory fragment (if needed), maps distances, and calculates top triplet importances. + + Args: + query_list_length: Number of queries in batch mode (None for single-query mode). + When None, node_distances/edge_distances are flat lists; when set, they are list-of-lists. + + Returns: + List[Edge]: For single-query mode (query_list_length is None). + List[List[Edge]]: For batch mode (query_list_length is set), one list per query. + """ if memory_fragment is None: if wide_search_limit is None: relevant_node_ids = None @@ -89,17 +99,20 @@ async def _get_top_triplet_importances( ) await memory_fragment.map_vector_distances_to_graph_nodes( - node_distances=vector_search.node_distances + node_distances=vector_search.node_distances, query_list_length=query_list_length ) await memory_fragment.map_vector_distances_to_graph_edges( - edge_distances=vector_search.edge_distances + edge_distances=vector_search.edge_distances, query_list_length=query_list_length ) - return await memory_fragment.calculate_top_triplet_importances(k=top_k) + return await memory_fragment.calculate_top_triplet_importances( + k=top_k, query_list_length=query_list_length + ) async def brute_force_triplet_search( - query: str, + query: Optional[str] = None, + query_batch: Optional[List[str]] = None, top_k: int = 5, collections: Optional[List[str]] = None, properties_to_project: Optional[List[str]] = None, @@ -108,30 +121,49 @@ async def brute_force_triplet_search( node_name: Optional[List[str]] = None, wide_search_top_k: Optional[int] = 100, triplet_distance_penalty: Optional[float] = 3.5, -) -> List[Edge]: +) -> Union[List[Edge], List[List[Edge]]]: """ Performs a brute force search to retrieve the top triplets from the graph. Args: - query (str): The search query. + query (Optional[str]): The search query (single query mode). Exactly one of query or query_batch must be provided. + query_batch (Optional[List[str]]): List of search queries (batch mode). Exactly one of query or query_batch must be provided. top_k (int): The number of top results to retrieve. collections (Optional[List[str]]): List of collections to query. properties_to_project (Optional[List[str]]): List of properties to project. memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse. node_type: node type to filter node_name: node name to filter - wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections + wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections. + Ignored in batch mode (always None to project full graph). triplet_distance_penalty (Optional[float]): Default distance penalty in graph projection Returns: - list: The top triplet results. + List[Edge]: The top triplet results for single query mode (flat list). + List[List[Edge]]: List of top triplet results (one per query) for batch mode (list-of-lists). + + Note: + In single-query mode, node_distances and edge_distances are stored as flat lists. + In batch mode, they are stored as list-of-lists (one list per query). """ - if not query or not isinstance(query, str): + if query is not None and query_batch is not None: + raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.") + if query is None and query_batch is None: + raise ValueError("Must provide either 'query' or 'query_batch'.") + if query is not None and (not query or not isinstance(query, str)): raise ValueError("The query must be a non-empty string.") + if query_batch is not None: + if not isinstance(query_batch, list) or not query_batch: + raise ValueError("query_batch must be a non-empty list of strings.") + if not all(isinstance(q, str) and q for q in query_batch): + raise ValueError("All items in query_batch must be non-empty strings.") if top_k <= 0: raise ValueError("top_k must be a positive integer.") - wide_search_limit = wide_search_top_k if node_name is None else None + query_list_length = len(query_batch) if query_batch is not None else None + wide_search_limit = ( + None if query_list_length else (wide_search_top_k if node_name is None else None) + ) if collections is None: collections = [ @@ -148,13 +180,16 @@ async def brute_force_triplet_search( vector_search = NodeEdgeVectorSearch() await vector_search.embed_and_retrieve_distances( - query=query, collections=collections, wide_search_limit=wide_search_limit + query=None if query_list_length else query, + query_batch=query_batch if query_list_length else None, + collections=collections, + wide_search_limit=wide_search_limit, ) if not vector_search.has_results(): - return [] + return [[] for _ in range(query_list_length)] if query_list_length else [] - return await _get_top_triplet_importances( + results = await _get_top_triplet_importances( memory_fragment, vector_search, properties_to_project, @@ -163,13 +198,16 @@ async def brute_force_triplet_search( triplet_distance_penalty, wide_search_limit, top_k, + query_list_length=query_list_length, ) + + return results except CollectionNotFoundError: - return [] + return [[] for _ in range(query_list_length)] if query_list_length else [] except Exception as error: logger.error( "Error during brute force search for query: %s. Error: %s", - query, + query_batch if query_list_length else [query], error, ) raise error From c20304a92ad7accf687dc9a0f07f5240fd89c58a Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Mon, 12 Jan 2026 14:40:55 +0100 Subject: [PATCH 10/14] tests: update and expand test_brute_force_triplet_search.py and test_node_edge_vector_search.py --- .../test_brute_force_triplet_search.py | 240 +++++++++++++++++- .../retrieval/test_node_edge_vector_search.py | 26 ++ 2 files changed, 264 insertions(+), 2 deletions(-) diff --git a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py index 00db1e794..fcbfd2434 100644 --- a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +++ b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py @@ -30,7 +30,7 @@ async def test_brute_force_triplet_search_empty_query(): @pytest.mark.asyncio async def test_brute_force_triplet_search_none_query(): """Test that None query raises ValueError.""" - with pytest.raises(ValueError, match="The query must be a non-empty string."): + with pytest.raises(ValueError, match="Must provide either 'query' or 'query_batch'."): await brute_force_triplet_search(query=None) @@ -351,7 +351,9 @@ async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation custom_top_k = 15 await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"]) - mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k) + mock_fragment.calculate_top_triplet_importances.assert_called_once_with( + k=custom_top_k, query_list_length=None + ) @pytest.mark.asyncio @@ -815,3 +817,237 @@ async def test_brute_force_triplet_search_collection_not_found_at_top_level(): result = await brute_force_triplet_search(query="test query") assert result == [] + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_single_query_regression(): + """Test that single-query mode maintains legacy behavior (flat list, ID filtering).""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("node1", 0.95)]) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment, + ): + result = await brute_force_triplet_search( + query="q1", query_batch=None, wide_search_top_k=10, node_name=None + ) + + assert isinstance(result, list) + assert not (result and isinstance(result[0], list)) + mock_get_fragment.assert_called_once() + call_kwargs = mock_get_fragment.call_args[1] + assert call_kwargs["relevant_ids_to_filter"] is not None + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_batch_wiring_happy_path(): + """Test that batch mode returns list-of-lists and skips ID filtering.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.batch_search = AsyncMock( + return_value=[ + [MockScoredResult("node1", 0.95)], + [MockScoredResult("node2", 0.87)], + ] + ) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[[], []]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment, + ): + result = await brute_force_triplet_search(query_batch=["q1", "q2"]) + + assert isinstance(result, list) + assert len(result) == 2 + assert isinstance(result[0], list) + assert isinstance(result[1], list) + mock_get_fragment.assert_called_once() + call_kwargs = mock_get_fragment.call_args[1] + assert call_kwargs["relevant_ids_to_filter"] is None + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_shape_propagation_to_graph(): + """Test that query_list_length is passed through to graph mapping methods.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.batch_search = AsyncMock( + return_value=[ + [MockScoredResult("node1", 0.95)], + [MockScoredResult("node2", 0.87)], + ] + ) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[[], []]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ), + ): + await brute_force_triplet_search(query_batch=["q1", "q2"]) + + mock_fragment.map_vector_distances_to_graph_nodes.assert_called_once() + node_call_kwargs = mock_fragment.map_vector_distances_to_graph_nodes.call_args[1] + assert "query_list_length" in node_call_kwargs + assert node_call_kwargs["query_list_length"] == 2 + + mock_fragment.map_vector_distances_to_graph_edges.assert_called_once() + edge_call_kwargs = mock_fragment.map_vector_distances_to_graph_edges.call_args[1] + assert "query_list_length" in edge_call_kwargs + assert edge_call_kwargs["query_list_length"] == 2 + + mock_fragment.calculate_top_triplet_importances.assert_called_once() + importance_call_kwargs = mock_fragment.calculate_top_triplet_importances.call_args[1] + assert "query_list_length" in importance_call_kwargs + assert importance_call_kwargs["query_list_length"] == 2 + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_batch_path_comprehensive(): + """Test batch mode: returns list-of-lists, skips ID filtering, passes None for wide_search_limit.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + + def batch_search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [ + [MockScoredResult("node1", 0.95)], + [MockScoredResult("node2", 0.87)], + ] + elif collection_name == "EdgeType_relationship_name": + return [ + [MockScoredResult("edge1", 0.92)], + [MockScoredResult("edge2", 0.88)], + ] + return [[], []] + + mock_vector_engine.batch_search = AsyncMock(side_effect=batch_search_side_effect) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[[], []]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment, + ): + result = await brute_force_triplet_search( + query_batch=["q1", "q2"], collections=["Entity_name", "EdgeType_relationship_name"] + ) + + assert isinstance(result, list) + assert len(result) == 2 + assert isinstance(result[0], list) + assert isinstance(result[1], list) + + mock_get_fragment.assert_called_once() + fragment_call_kwargs = mock_get_fragment.call_args[1] + assert fragment_call_kwargs["relevant_ids_to_filter"] is None + + batch_search_calls = mock_vector_engine.batch_search.call_args_list + assert len(batch_search_calls) > 0 + for call in batch_search_calls: + assert call[1]["limit"] is None + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_batch_error_fallback(): + """Test that CollectionNotFoundError in batch mode returns [[], []] matching batch length.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.batch_search = AsyncMock( + side_effect=CollectionNotFoundError("Collection not found") + ) + + with patch( + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", + return_value=mock_vector_engine, + ): + result = await brute_force_triplet_search(query_batch=["q1", "q2"]) + + assert result == [[], []] + assert len(result) == 2 + + +@pytest.mark.asyncio +async def test_cognee_graph_mapping_batch_shapes(): + """Test that CogneeGraph mapping methods accept list-of-lists with query_list_length set.""" + from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge + + graph = CogneeGraph() + node1 = Node("node1", {"name": "Node1"}) + node2 = Node("node2", {"name": "Node2"}) + graph.add_node(node1) + graph.add_node(node2) + + edge = Edge(node1, node2, attributes={"edge_text": "relates_to"}) + graph.add_edge(edge) + + node_distances_batch = { + "Entity_name": [ + [MockScoredResult("node1", 0.95)], + [MockScoredResult("node2", 0.87)], + ] + } + + edge_distances_batch = [ + [MockScoredResult("edge1", 0.92, payload={"text": "relates_to"})], + [MockScoredResult("edge2", 0.88, payload={"text": "relates_to"})], + ] + + await graph.map_vector_distances_to_graph_nodes( + node_distances=node_distances_batch, query_list_length=2 + ) + await graph.map_vector_distances_to_graph_edges( + edge_distances=edge_distances_batch, query_list_length=2 + ) + + assert node1.attributes.get("vector_distance") == [0.95, 3.5] + assert node2.attributes.get("vector_distance") == [3.5, 0.87] + assert edge.attributes.get("vector_distance") == [0.92, 0.88] diff --git a/cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py b/cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py index d93dce42b..1fd169fcc 100644 --- a/cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py +++ b/cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py @@ -212,3 +212,29 @@ async def test_node_edge_vector_search_single_query_collection_not_found(): ) assert vector_search.node_distances["MissingCollection"] == [] + + +@pytest.mark.asyncio +async def test_node_edge_vector_search_has_results_batch_nodes_only(): + """Test has_results returns True when only node distances are populated in batch mode.""" + mock_vector_engine = AsyncMock() + vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine) + vector_search.query_list_length = 2 + vector_search.edge_distances = [[], []] + vector_search.node_distances = { + "Entity_name": [[MockScoredResult("node1", 0.95)], []], + } + + assert vector_search.has_results() is True + + +@pytest.mark.asyncio +async def test_node_edge_vector_search_has_results_batch_edges_only(): + """Test has_results returns True when only edge distances are populated in batch mode.""" + mock_vector_engine = AsyncMock() + vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine) + vector_search.query_list_length = 2 + vector_search.edge_distances = [[MockScoredResult("edge1", 0.92)], []] + vector_search.node_distances = {} + + assert vector_search.has_results() is True From 1c8d0f6da1416b0747c3c75f5f6355beb8100243 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Mon, 12 Jan 2026 14:51:50 +0100 Subject: [PATCH 11/14] chore: update tests and minor tweaks --- .../utils/node_edge_vector_search.py | 7 ++- .../unit/modules/graph/cognee_graph_test.py | 46 +++++++++++++++++++ .../retrieval/test_node_edge_vector_search.py | 33 +++++++++++++ 3 files changed, 85 insertions(+), 1 deletion(-) diff --git a/cognee/modules/retrieval/utils/node_edge_vector_search.py b/cognee/modules/retrieval/utils/node_edge_vector_search.py index ff2d98eb8..80116f6f2 100644 --- a/cognee/modules/retrieval/utils/node_edge_vector_search.py +++ b/cognee/modules/retrieval/utils/node_edge_vector_search.py @@ -49,7 +49,12 @@ class NodeEdgeVectorSearch: search_results: List[List[Any]], query_list_length: Optional[int] = None, ): - """Separates search results into node and edge distances with stable shapes.""" + """Separates search results into node and edge distances with stable shapes. + + Ensures all collections are present in the output, even if empty: + - Batch mode: missing/empty collections become [[]] * query_list_length + - Single mode: missing/empty collections become [] + """ self.node_distances = {} self.edge_distances = ( [] if query_list_length is None else [[] for _ in range(query_list_length)] diff --git a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py index 41f12e73a..a13031ac5 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py @@ -718,3 +718,49 @@ async def test_calculate_top_triplet_importances_raises_on_missing_attribute(set with pytest.raises(ValueError): await graph.calculate_top_triplet_importances(k=1, query_list_length=1) + + +def test_normalize_query_distance_lists_flat_list_single_query(setup_graph): + """Test that flat list is normalized to list-of-lists with length 1 for single-query mode.""" + graph = setup_graph + flat_list = [MockScoredResult("node1", 0.95), MockScoredResult("node2", 0.87)] + + result = graph._normalize_query_distance_lists(flat_list, query_list_length=None, name="test") + + assert len(result) == 1 + assert result[0] == flat_list + + +def test_normalize_query_distance_lists_nested_list_batch_mode(setup_graph): + """Test that nested list is used as-is when query_list_length matches.""" + graph = setup_graph + nested_list = [ + [MockScoredResult("node1", 0.95)], + [MockScoredResult("node2", 0.87)], + ] + + result = graph._normalize_query_distance_lists(nested_list, query_list_length=2, name="test") + + assert len(result) == 2 + assert result == nested_list + + +def test_normalize_query_distance_lists_raises_on_length_mismatch(setup_graph): + """Test that ValueError is raised when nested list length doesn't match query_list_length.""" + graph = setup_graph + nested_list = [ + [MockScoredResult("node1", 0.95)], + [MockScoredResult("node2", 0.87)], + ] + + with pytest.raises(ValueError, match="test has 2 query lists, but query_list_length is 3"): + graph._normalize_query_distance_lists(nested_list, query_list_length=3, name="test") + + +def test_normalize_query_distance_lists_empty_list(setup_graph): + """Test that empty list returns empty list.""" + graph = setup_graph + + result = graph._normalize_query_distance_lists([], query_list_length=None, name="test") + + assert result == [] diff --git a/cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py b/cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py index 1fd169fcc..98d76ddef 100644 --- a/cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py +++ b/cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py @@ -214,6 +214,39 @@ async def test_node_edge_vector_search_single_query_collection_not_found(): assert vector_search.node_distances["MissingCollection"] == [] +@pytest.mark.asyncio +async def test_node_edge_vector_search_missing_collections_single_query(): + """Test that missing collections in single-query mode are handled gracefully with empty lists.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + + node_result = MockScoredResult("node1", 0.95) + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [node_result] + elif collection_name == "MissingCollection": + raise CollectionNotFoundError("Collection not found") + return [] + + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine) + collections = ["Entity_name", "MissingCollection", "EmptyCollection"] + + await vector_search.embed_and_retrieve_distances( + query="test query", query_batch=None, collections=collections, wide_search_limit=10 + ) + + assert len(vector_search.node_distances["Entity_name"]) == 1 + assert vector_search.node_distances["Entity_name"][0].id == "node1" + assert vector_search.node_distances["Entity_name"][0].score == 0.95 + assert vector_search.node_distances["MissingCollection"] == [] + assert vector_search.node_distances["EmptyCollection"] == [] + + @pytest.mark.asyncio async def test_node_edge_vector_search_has_results_batch_nodes_only(): """Test has_results returns True when only node distances are populated in batch mode.""" From 872795f0cc76b6d08ed94d188f10bc6b0b42babd Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Mon, 12 Jan 2026 18:16:30 +0100 Subject: [PATCH 12/14] test: add integration test for brute_force_triplet_search --- ...brute_force_triplet_search_with_cognify.py | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 cognee/tests/integration/retrieval/test_brute_force_triplet_search_with_cognify.py diff --git a/cognee/tests/integration/retrieval/test_brute_force_triplet_search_with_cognify.py b/cognee/tests/integration/retrieval/test_brute_force_triplet_search_with_cognify.py new file mode 100644 index 000000000..e07ddbd96 --- /dev/null +++ b/cognee/tests/integration/retrieval/test_brute_force_triplet_search_with_cognify.py @@ -0,0 +1,67 @@ +import os +import pathlib + +import pytest +import pytest_asyncio +import cognee + +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search + + +skip_without_provider = pytest.mark.skipif( + not (os.getenv("OPENAI_API_KEY") or os.getenv("AZURE_OPENAI_API_KEY")), + reason="requires embedding/vector provider credentials", +) + + +@pytest_asyncio.fixture +async def clean_environment(): + """Configure isolated storage and ensure cleanup before/after.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_brute_force_triplet_search_e2e") + data_directory_path = str(base_dir / ".data_storage/test_brute_force_triplet_search_e2e") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@skip_without_provider +@pytest.mark.asyncio +async def test_brute_force_triplet_search_end_to_end(clean_environment): + """Minimal end-to-end exercise of single and batch triplet search.""" + + text = """ + Cognee is an open-source AI memory engine that structures data into searchable formats for use with AI agents. + The company focuses on persistent memory systems using knowledge graphs and vector search. + It is a Berlin-based startup building infrastructure for context-aware AI applications. + """ + + await cognee.add(text) + await cognee.cognify() + + single_result = await brute_force_triplet_search(query="What is NLP?", top_k=1) + assert isinstance(single_result, list) + if single_result: + assert all(isinstance(edge, Edge) for edge in single_result) + + batch_queries = ["What is Cognee?", "What is the company's focus?"] + batch_result = await brute_force_triplet_search(query_batch=batch_queries, top_k=1) + + assert isinstance(batch_result, list) + assert len(batch_result) == len(batch_queries) + assert all(isinstance(per_query, list) for per_query in batch_result) + for per_query in batch_result: + if per_query: + assert all(isinstance(edge, Edge) for edge in per_query) From c609b73cdad17b7750631462c524cc69c2c5b847 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Tue, 13 Jan 2026 11:22:04 +0100 Subject: [PATCH 13/14] refactor: improve methods order --- .../utils/node_edge_vector_search.py | 98 +++++++++---------- 1 file changed, 49 insertions(+), 49 deletions(-) diff --git a/cognee/modules/retrieval/utils/node_edge_vector_search.py b/cognee/modules/retrieval/utils/node_edge_vector_search.py index 80116f6f2..558b9bc0c 100644 --- a/cognee/modules/retrieval/utils/node_edge_vector_search.py +++ b/cognee/modules/retrieval/utils/node_edge_vector_search.py @@ -27,6 +27,39 @@ class NodeEdgeVectorSearch: logger.error("Failed to initialize vector engine: %s", e) raise RuntimeError("Initialization error") from e + async def embed_and_retrieve_distances( + self, + query: Optional[str] = None, + query_batch: Optional[List[str]] = None, + collections: List[str] = None, + wide_search_limit: Optional[int] = None, + ): + """Embeds query/queries and retrieves vector distances from all collections.""" + if query is not None and query_batch is not None: + raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.") + if query is None and query_batch is None: + raise ValueError("Must provide either 'query' or 'query_batch'.") + if not collections: + raise ValueError("'collections' must be a non-empty list.") + + start_time = time.time() + + if query_batch is not None: + self.query_list_length = len(query_batch) + search_results = await self._run_batch_search(collections, query_batch) + else: + self.query_list_length = None + search_results = await self._run_single_search(collections, query, wide_search_limit) + + elapsed_time = time.time() - start_time + collections_with_results = sum(1 for result in search_results if any(result)) + logger.info( + f"Vector collection retrieval completed: Retrieved distances from " + f"{collections_with_results} collections in {elapsed_time:.2f}s" + ) + + self.set_distances_from_results(collections, search_results, self.query_list_length) + def has_results(self) -> bool: """Checks if any collections returned results.""" if self.query_list_length is None: @@ -43,6 +76,18 @@ class NodeEdgeVectorSearch: for collection_results in self.node_distances.values() ) + def extract_relevant_node_ids(self) -> List[str]: + """Extracts unique node IDs from search results.""" + if self.query_list_length is not None: + return [] + relevant_node_ids = set() + for scored_results in self.node_distances.values(): + for scored_node in scored_results: + node_id = getattr(scored_node, "id", None) + if node_id: + relevant_node_ids.add(str(node_id)) + return list(relevant_node_ids) + def set_distances_from_results( self, collections: List[str], @@ -74,23 +119,6 @@ class NodeEdgeVectorSearch: else: self.node_distances[collection] = result - def extract_relevant_node_ids(self) -> List[str]: - """Extracts unique node IDs from search results.""" - if self.query_list_length is not None: - return [] - relevant_node_ids = set() - for scored_results in self.node_distances.values(): - for scored_node in scored_results: - node_id = getattr(scored_node, "id", None) - if node_id: - relevant_node_ids.add(str(node_id)) - return list(relevant_node_ids) - - async def _embed_query(self, query: str): - """Embeds the query and stores the resulting vector.""" - query_embeddings = await self.vector_engine.embedding_engine.embed_text([query]) - self.query_vector = query_embeddings[0] - async def _run_batch_search( self, collections: List[str], query_batch: List[str] ) -> List[List[Any]]: @@ -127,38 +155,10 @@ class NodeEdgeVectorSearch: search_results = await asyncio.gather(*search_tasks) return search_results - async def embed_and_retrieve_distances( - self, - query: Optional[str] = None, - query_batch: Optional[List[str]] = None, - collections: List[str] = None, - wide_search_limit: Optional[int] = None, - ): - """Embeds query/queries and retrieves vector distances from all collections.""" - if query is not None and query_batch is not None: - raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.") - if query is None and query_batch is None: - raise ValueError("Must provide either 'query' or 'query_batch'.") - if not collections: - raise ValueError("'collections' must be a non-empty list.") - - start_time = time.time() - - if query_batch is not None: - self.query_list_length = len(query_batch) - search_results = await self._run_batch_search(collections, query_batch) - else: - self.query_list_length = None - search_results = await self._run_single_search(collections, query, wide_search_limit) - - elapsed_time = time.time() - start_time - collections_with_results = sum(1 for result in search_results if any(result)) - logger.info( - f"Vector collection retrieval completed: Retrieved distances from " - f"{collections_with_results} collections in {elapsed_time:.2f}s" - ) - - self.set_distances_from_results(collections, search_results, self.query_list_length) + async def _embed_query(self, query: str): + """Embeds the query and stores the resulting vector.""" + query_embeddings = await self.vector_engine.embedding_engine.embed_text([query]) + self.query_vector = query_embeddings[0] async def _search_single_collection( self, vector_engine: Any, wide_search_limit: Optional[int], collection_name: str From 08779398b062fd63287bf5e79b5e5733d45bfe0e Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Tue, 13 Jan 2026 16:15:49 +0100 Subject: [PATCH 14/14] fix: deduplicate skeleton edges --- cognee/modules/graph/cognee_graph/CogneeGraph.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index bec9b15fd..f67c026d3 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -215,9 +215,6 @@ class CogneeGraph(CogneeAbstractGraph): edge_penalty=triplet_distance_penalty, ) self.add_edge(edge) - - source_node.add_skeleton_edge(edge) - target_node.add_skeleton_edge(edge) else: raise EntityNotFoundError( message=f"Edge references nonexistent nodes: {source_id} -> {target_id}"