diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index bec9b15fd..f67c026d3 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -215,9 +215,6 @@ class CogneeGraph(CogneeAbstractGraph): edge_penalty=triplet_distance_penalty, ) self.add_edge(edge) - - source_node.add_skeleton_edge(edge) - target_node.add_skeleton_edge(edge) else: raise EntityNotFoundError( message=f"Edge references nonexistent nodes: {source_id} -> {target_id}" diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index a70fa661b..ce84c1423 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -1,21 +1,18 @@ -import asyncio -import time -from typing import List, Optional, Type +from typing import List, Optional, Type, Union from cognee.shared.logging_utils import get_logger, ERROR from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError -from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError from cognee.infrastructure.databases.graph import get_graph_engine -from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge -from cognee.modules.users.models import User -from cognee.shared.utils import send_telemetry +from cognee.modules.retrieval.utils.node_edge_vector_search import NodeEdgeVectorSearch logger = get_logger(level=ERROR) def format_triplets(edges): + """Formats edges into human-readable triplet strings.""" triplets = [] for edge in edges: node1 = edge.node1 @@ -24,12 +21,10 @@ def format_triplets(edges): 
node1_attributes = node1.attributes node2_attributes = node2.attributes - # Filter only non-None properties node1_info = {key: value for key, value in node1_attributes.items() if value is not None} node2_info = {key: value for key, value in node2_attributes.items() if value is not None} edge_info = {key: value for key, value in edge_attributes.items() if value is not None} - # Create the formatted triplet triplet = f"Node1: {node1_info}\nEdge: {edge_info}\nNode2: {node2_info}\n\n\n" triplets.append(triplet) @@ -51,7 +46,6 @@ async def get_memory_fragment( try: graph_engine = await get_graph_engine() - await memory_fragment.project_graph_from_db( graph_engine, node_properties_to_project=properties_to_project, @@ -61,20 +55,64 @@ async def get_memory_fragment( relevant_ids_to_filter=relevant_ids_to_filter, triplet_distance_penalty=triplet_distance_penalty, ) - except EntityNotFoundError: - # This is expected behavior - continue with empty fragment pass except Exception as e: logger.error(f"Error during memory fragment creation: {str(e)}") - # Still return the fragment even if projection failed - pass return memory_fragment +async def _get_top_triplet_importances( + memory_fragment: Optional[CogneeGraph], + vector_search: NodeEdgeVectorSearch, + properties_to_project: Optional[List[str]], + node_type: Optional[Type], + node_name: Optional[List[str]], + triplet_distance_penalty: float, + wide_search_limit: Optional[int], + top_k: int, + query_list_length: Optional[int] = None, +) -> Union[List[Edge], List[List[Edge]]]: + """Creates memory fragment (if needed), maps distances, and calculates top triplet importances. + + Args: + query_list_length: Number of queries in batch mode (None for single-query mode). + When None, node_distances/edge_distances are flat lists; when set, they are list-of-lists. + + Returns: + List[Edge]: For single-query mode (query_list_length is None). + List[List[Edge]]: For batch mode (query_list_length is set), one list per query. 
+ """ + if memory_fragment is None: + if wide_search_limit is None: + relevant_node_ids = None + else: + relevant_node_ids = vector_search.extract_relevant_node_ids() + + memory_fragment = await get_memory_fragment( + properties_to_project=properties_to_project, + node_type=node_type, + node_name=node_name, + relevant_ids_to_filter=relevant_node_ids, + triplet_distance_penalty=triplet_distance_penalty, + ) + + await memory_fragment.map_vector_distances_to_graph_nodes( + node_distances=vector_search.node_distances, query_list_length=query_list_length + ) + await memory_fragment.map_vector_distances_to_graph_edges( + edge_distances=vector_search.edge_distances, query_list_length=query_list_length + ) + + return await memory_fragment.calculate_top_triplet_importances( + k=top_k, query_list_length=query_list_length + ) + + async def brute_force_triplet_search( - query: str, + query: Optional[str] = None, + query_batch: Optional[List[str]] = None, top_k: int = 5, collections: Optional[List[str]] = None, properties_to_project: Optional[List[str]] = None, @@ -83,33 +121,49 @@ async def brute_force_triplet_search( node_name: Optional[List[str]] = None, wide_search_top_k: Optional[int] = 100, triplet_distance_penalty: Optional[float] = 3.5, -) -> List[Edge]: +) -> Union[List[Edge], List[List[Edge]]]: """ Performs a brute force search to retrieve the top triplets from the graph. Args: - query (str): The search query. + query (Optional[str]): The search query (single query mode). Exactly one of query or query_batch must be provided. + query_batch (Optional[List[str]]): List of search queries (batch mode). Exactly one of query or query_batch must be provided. top_k (int): The number of top results to retrieve. collections (Optional[List[str]]): List of collections to query. properties_to_project (Optional[List[str]]): List of properties to project. memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse. 
node_type: node type to filter node_name: node name to filter - wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections + wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections. + Ignored in batch mode (always None to project full graph). triplet_distance_penalty (Optional[float]): Default distance penalty in graph projection Returns: - list: The top triplet results. + List[Edge]: The top triplet results for single query mode (flat list). + List[List[Edge]]: List of top triplet results (one per query) for batch mode (list-of-lists). + + Note: + In single-query mode, node_distances and edge_distances are stored as flat lists. + In batch mode, they are stored as list-of-lists (one list per query). """ - if not query or not isinstance(query, str): + if query is not None and query_batch is not None: + raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.") + if query is None and query_batch is None: + raise ValueError("Must provide either 'query' or 'query_batch'.") + if query is not None and (not query or not isinstance(query, str)): raise ValueError("The query must be a non-empty string.") + if query_batch is not None: + if not isinstance(query_batch, list) or not query_batch: + raise ValueError("query_batch must be a non-empty list of strings.") + if not all(isinstance(q, str) and q for q in query_batch): + raise ValueError("All items in query_batch must be non-empty strings.") if top_k <= 0: raise ValueError("top_k must be a positive integer.") - # Setting wide search limit based on the parameters - non_global_search = node_name is None - - wide_search_limit = wide_search_top_k if non_global_search else None + query_list_length = len(query_batch) if query_batch is not None else None + wide_search_limit = ( + None if query_list_length else (wide_search_top_k if node_name is None else None) + ) if collections is None: collections = [ @@ -123,77 +177,37 @@ async def 
brute_force_triplet_search( collections.append("EdgeType_relationship_name") try: - vector_engine = get_vector_engine() - except Exception as e: - logger.error("Failed to initialize vector engine: %s", e) - raise RuntimeError("Initialization error") from e + vector_search = NodeEdgeVectorSearch() - query_vector = (await vector_engine.embedding_engine.embed_text([query]))[0] - - async def search_in_collection(collection_name: str): - try: - return await vector_engine.search( - collection_name=collection_name, query_vector=query_vector, limit=wide_search_limit - ) - except CollectionNotFoundError: - return [] - - try: - start_time = time.time() - - results = await asyncio.gather( - *[search_in_collection(collection_name) for collection_name in collections] + await vector_search.embed_and_retrieve_distances( + query=None if query_list_length else query, + query_batch=query_batch if query_list_length else None, + collections=collections, + wide_search_limit=wide_search_limit, ) - if all(not item for item in results): - return [] + if not vector_search.has_results(): + return [[] for _ in range(query_list_length)] if query_list_length else [] - # Final statistics - vector_collection_search_time = time.time() - start_time - logger.info( - f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {vector_collection_search_time:.2f}s" + results = await _get_top_triplet_importances( + memory_fragment, + vector_search, + properties_to_project, + node_type, + node_name, + triplet_distance_penalty, + wide_search_limit, + top_k, + query_list_length=query_list_length, ) - node_distances = {collection: result for collection, result in zip(collections, results)} - - edge_distances = node_distances.get("EdgeType_relationship_name", None) - - if wide_search_limit is not None: - relevant_ids_to_filter = list( - { - str(getattr(scored_node, "id")) - for collection_name, score_collection in node_distances.items() - if 
collection_name != "EdgeType_relationship_name" - and isinstance(score_collection, (list, tuple)) - for scored_node in score_collection - if getattr(scored_node, "id", None) - } - ) - else: - relevant_ids_to_filter = None - - if memory_fragment is None: - memory_fragment = await get_memory_fragment( - properties_to_project=properties_to_project, - node_type=node_type, - node_name=node_name, - relevant_ids_to_filter=relevant_ids_to_filter, - triplet_distance_penalty=triplet_distance_penalty, - ) - - await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances) - await memory_fragment.map_vector_distances_to_graph_edges(edge_distances=edge_distances) - - results = await memory_fragment.calculate_top_triplet_importances(k=top_k) - return results - except CollectionNotFoundError: - return [] + return [[] for _ in range(query_list_length)] if query_list_length else [] except Exception as error: logger.error( "Error during brute force search for query: %s. Error: %s", - query, + query_batch if query_list_length else [query], error, ) raise error diff --git a/cognee/modules/retrieval/utils/node_edge_vector_search.py b/cognee/modules/retrieval/utils/node_edge_vector_search.py new file mode 100644 index 000000000..558b9bc0c --- /dev/null +++ b/cognee/modules/retrieval/utils/node_edge_vector_search.py @@ -0,0 +1,174 @@ +import asyncio +import time +from typing import Any, List, Optional + +from cognee.shared.logging_utils import get_logger, ERROR +from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError +from cognee.infrastructure.databases.vector import get_vector_engine + +logger = get_logger(level=ERROR) + + +class NodeEdgeVectorSearch: + """Manages vector search and distance retrieval for graph nodes and edges.""" + + def __init__(self, edge_collection: str = "EdgeType_relationship_name", vector_engine=None): + self.edge_collection = edge_collection + self.vector_engine = vector_engine or self._init_vector_engine() 
+ self.query_vector: Optional[Any] = None + self.node_distances: dict[str, list[Any]] = {} + self.edge_distances: list[Any] = [] + self.query_list_length: Optional[int] = None + + def _init_vector_engine(self): + try: + return get_vector_engine() + except Exception as e: + logger.error("Failed to initialize vector engine: %s", e) + raise RuntimeError("Initialization error") from e + + async def embed_and_retrieve_distances( + self, + query: Optional[str] = None, + query_batch: Optional[List[str]] = None, + collections: Optional[List[str]] = None, + wide_search_limit: Optional[int] = None, + ): + """Embeds query/queries and retrieves vector distances from all collections.""" + if query is not None and query_batch is not None: + raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.") + if query is None and query_batch is None: + raise ValueError("Must provide either 'query' or 'query_batch'.") + if not collections: + raise ValueError("'collections' must be a non-empty list.") + + start_time = time.time() + + if query_batch is not None: + self.query_list_length = len(query_batch) + search_results = await self._run_batch_search(collections, query_batch) + else: + self.query_list_length = None + search_results = await self._run_single_search(collections, query, wide_search_limit) + + elapsed_time = time.time() - start_time + collections_with_results = sum(1 for result in search_results if any(result)) + logger.info( + f"Vector collection retrieval completed: Retrieved distances from " + f"{collections_with_results} collections in {elapsed_time:.2f}s" + ) + + self.set_distances_from_results(collections, search_results, self.query_list_length) + + def has_results(self) -> bool: + """Checks if any collections returned results.""" + if self.query_list_length is None: + if self.edge_distances and any(self.edge_distances): + return True + return any( + bool(collection_results) for collection_results in self.node_distances.values() + ) + + if 
self.edge_distances and any(inner_list for inner_list in self.edge_distances): + return True + return any( + any(results_per_query for results_per_query in collection_results) + for collection_results in self.node_distances.values() + ) + + def extract_relevant_node_ids(self) -> List[str]: + """Extracts unique node IDs from search results.""" + if self.query_list_length is not None: + return [] + relevant_node_ids = set() + for scored_results in self.node_distances.values(): + for scored_node in scored_results: + node_id = getattr(scored_node, "id", None) + if node_id: + relevant_node_ids.add(str(node_id)) + return list(relevant_node_ids) + + def set_distances_from_results( + self, + collections: List[str], + search_results: List[List[Any]], + query_list_length: Optional[int] = None, + ): + """Separates search results into node and edge distances with stable shapes. + + Ensures all collections are present in the output, even if empty: + - Batch mode: missing/empty collections become [[]] * query_list_length + - Single mode: missing/empty collections become [] + """ + self.node_distances = {} + self.edge_distances = ( + [] if query_list_length is None else [[] for _ in range(query_list_length)] + ) + for collection, result in zip(collections, search_results): + if not result: + empty_result = ( + [] if query_list_length is None else [[] for _ in range(query_list_length)] + ) + if collection == self.edge_collection: + self.edge_distances = empty_result + else: + self.node_distances[collection] = empty_result + else: + if collection == self.edge_collection: + self.edge_distances = result + else: + self.node_distances[collection] = result + + async def _run_batch_search( + self, collections: List[str], query_batch: List[str] + ) -> List[List[Any]]: + """Runs batch search across all collections and returns list-of-lists per collection.""" + search_tasks = [ + self._search_batch_collection(collection, query_batch) for collection in collections + ] + return await 
asyncio.gather(*search_tasks) + + async def _search_batch_collection( + self, collection_name: str, query_batch: List[str] + ) -> List[List[Any]]: + """Searches one collection with batch queries and returns list-of-lists.""" + try: + return await self.vector_engine.batch_search( + collection_name=collection_name, query_texts=query_batch, limit=None + ) + except CollectionNotFoundError: + return [[] for _ in query_batch] + + async def _run_single_search( + self, collections: List[str], query: str, wide_search_limit: Optional[int] + ) -> List[List[Any]]: + """Runs single query search and returns flat lists per collection. + + Returns a list where each element is a collection's results (flat list). + These are stored as flat lists in node_distances/edge_distances for single-query mode. + """ + await self._embed_query(query) + search_tasks = [ + self._search_single_collection(self.vector_engine, wide_search_limit, collection) + for collection in collections + ] + search_results = await asyncio.gather(*search_tasks) + return search_results + + async def _embed_query(self, query: str): + """Embeds the query and stores the resulting vector.""" + query_embeddings = await self.vector_engine.embedding_engine.embed_text([query]) + self.query_vector = query_embeddings[0] + + async def _search_single_collection( + self, vector_engine: Any, wide_search_limit: Optional[int], collection_name: str + ): + """Searches one collection and returns results or empty list if not found.""" + try: + return await vector_engine.search( + collection_name=collection_name, + query_vector=self.query_vector, + limit=wide_search_limit, + ) + except CollectionNotFoundError: + return [] diff --git a/cognee/tests/integration/retrieval/test_brute_force_triplet_search_with_cognify.py b/cognee/tests/integration/retrieval/test_brute_force_triplet_search_with_cognify.py new file mode 100644 index 000000000..e07ddbd96 --- /dev/null +++ 
b/cognee/tests/integration/retrieval/test_brute_force_triplet_search_with_cognify.py @@ -0,0 +1,67 @@ +import os +import pathlib + +import pytest +import pytest_asyncio +import cognee + +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search + + +skip_without_provider = pytest.mark.skipif( + not (os.getenv("OPENAI_API_KEY") or os.getenv("AZURE_OPENAI_API_KEY")), + reason="requires embedding/vector provider credentials", +) + + +@pytest_asyncio.fixture +async def clean_environment(): + """Configure isolated storage and ensure cleanup before/after.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_brute_force_triplet_search_e2e") + data_directory_path = str(base_dir / ".data_storage/test_brute_force_triplet_search_e2e") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@skip_without_provider +@pytest.mark.asyncio +async def test_brute_force_triplet_search_end_to_end(clean_environment): + """Minimal end-to-end exercise of single and batch triplet search.""" + + text = """ + Cognee is an open-source AI memory engine that structures data into searchable formats for use with AI agents. + The company focuses on persistent memory systems using knowledge graphs and vector search. + It is a Berlin-based startup building infrastructure for context-aware AI applications. 
+ """ + + await cognee.add(text) + await cognee.cognify() + + single_result = await brute_force_triplet_search(query="What is NLP?", top_k=1) + assert isinstance(single_result, list) + if single_result: + assert all(isinstance(edge, Edge) for edge in single_result) + + batch_queries = ["What is Cognee?", "What is the company's focus?"] + batch_result = await brute_force_triplet_search(query_batch=batch_queries, top_k=1) + + assert isinstance(batch_result, list) + assert len(batch_result) == len(batch_queries) + assert all(isinstance(per_query, list) for per_query in batch_result) + for per_query in batch_result: + if per_query: + assert all(isinstance(edge, Edge) for edge in per_query) diff --git a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py index 41f12e73a..a13031ac5 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py @@ -718,3 +718,49 @@ async def test_calculate_top_triplet_importances_raises_on_missing_attribute(set with pytest.raises(ValueError): await graph.calculate_top_triplet_importances(k=1, query_list_length=1) + + +def test_normalize_query_distance_lists_flat_list_single_query(setup_graph): + """Test that flat list is normalized to list-of-lists with length 1 for single-query mode.""" + graph = setup_graph + flat_list = [MockScoredResult("node1", 0.95), MockScoredResult("node2", 0.87)] + + result = graph._normalize_query_distance_lists(flat_list, query_list_length=None, name="test") + + assert len(result) == 1 + assert result[0] == flat_list + + +def test_normalize_query_distance_lists_nested_list_batch_mode(setup_graph): + """Test that nested list is used as-is when query_list_length matches.""" + graph = setup_graph + nested_list = [ + [MockScoredResult("node1", 0.95)], + [MockScoredResult("node2", 0.87)], + ] + + result = graph._normalize_query_distance_lists(nested_list, query_list_length=2, name="test") + + assert 
len(result) == 2 + assert result == nested_list + + +def test_normalize_query_distance_lists_raises_on_length_mismatch(setup_graph): + """Test that ValueError is raised when nested list length doesn't match query_list_length.""" + graph = setup_graph + nested_list = [ + [MockScoredResult("node1", 0.95)], + [MockScoredResult("node2", 0.87)], + ] + + with pytest.raises(ValueError, match="test has 2 query lists, but query_list_length is 3"): + graph._normalize_query_distance_lists(nested_list, query_list_length=3, name="test") + + +def test_normalize_query_distance_lists_empty_list(setup_graph): + """Test that empty list returns empty list.""" + graph = setup_graph + + result = graph._normalize_query_distance_lists([], query_list_length=None, name="test") + + assert result == [] diff --git a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py index b7cbe08d7..fcbfd2434 100644 --- a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +++ b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py @@ -30,7 +30,7 @@ async def test_brute_force_triplet_search_empty_query(): @pytest.mark.asyncio async def test_brute_force_triplet_search_none_query(): """Test that None query raises ValueError.""" - with pytest.raises(ValueError, match="The query must be a non-empty string."): + with pytest.raises(ValueError, match="Must provide either 'query' or 'query_batch'."): await brute_force_triplet_search(query=None) @@ -57,7 +57,7 @@ async def test_brute_force_triplet_search_wide_search_limit_global_search(): mock_vector_engine.search = AsyncMock(return_value=[]) with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ): await brute_force_triplet_search( @@ -79,7 +79,7 @@ async def 
test_brute_force_triplet_search_wide_search_limit_filtered_search(): mock_vector_engine.search = AsyncMock(return_value=[]) with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ): await brute_force_triplet_search( @@ -101,7 +101,7 @@ async def test_brute_force_triplet_search_wide_search_default(): mock_vector_engine.search = AsyncMock(return_value=[]) with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ): await brute_force_triplet_search(query="test", node_name=None) @@ -119,7 +119,7 @@ async def test_brute_force_triplet_search_default_collections(): mock_vector_engine.search = AsyncMock(return_value=[]) with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ): await brute_force_triplet_search(query="test") @@ -149,7 +149,7 @@ async def test_brute_force_triplet_search_custom_collections(): custom_collections = ["CustomCol1", "CustomCol2"] with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ): await brute_force_triplet_search(query="test", collections=custom_collections) @@ -171,7 +171,7 @@ async def test_brute_force_triplet_search_always_includes_edge_collection(): collections_without_edge = ["Entity_name", "TextSummary_text"] with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ): await 
brute_force_triplet_search(query="test", collections=collections_without_edge) @@ -194,7 +194,7 @@ async def test_brute_force_triplet_search_all_collections_empty(): mock_vector_engine.search = AsyncMock(return_value=[]) with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ): results = await brute_force_triplet_search(query="test") @@ -216,7 +216,7 @@ async def test_brute_force_triplet_search_embeds_query(): mock_vector_engine.search = AsyncMock(return_value=[]) with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ): await brute_force_triplet_search(query=query_text) @@ -249,7 +249,7 @@ async def test_brute_force_triplet_search_extracts_node_ids_global_search(): with ( patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ), patch( @@ -279,7 +279,7 @@ async def test_brute_force_triplet_search_reuses_provided_fragment(): with ( patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ), patch( @@ -311,7 +311,7 @@ async def test_brute_force_triplet_search_creates_fragment_when_not_provided(): with ( patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ), patch( @@ -340,7 +340,7 @@ async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation with ( patch( - 
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ), patch( @@ -351,7 +351,9 @@ async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation custom_top_k = 15 await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"]) - mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k) + mock_fragment.calculate_top_triplet_importances.assert_called_once_with( + k=custom_top_k, query_list_length=None + ) @pytest.mark.asyncio @@ -430,7 +432,7 @@ async def test_brute_force_triplet_search_deduplicates_node_ids(): with ( patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ), patch( @@ -471,7 +473,7 @@ async def test_brute_force_triplet_search_excludes_edge_collection(): with ( patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ), patch( @@ -523,7 +525,7 @@ async def test_brute_force_triplet_search_skips_nodes_without_ids(): with ( patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ), patch( @@ -564,7 +566,7 @@ async def test_brute_force_triplet_search_handles_tuple_results(): with ( patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ), patch( @@ -606,7 +608,7 @@ async def test_brute_force_triplet_search_mixed_empty_collections(): with ( patch( - 
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ), patch( @@ -689,7 +691,7 @@ async def test_brute_force_triplet_search_vector_engine_init_error(): """Test brute_force_triplet_search handles vector engine initialization error (lines 145-147).""" with ( patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine" + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine" ) as mock_get_vector_engine, ): mock_get_vector_engine.side_effect = Exception("Initialization error") @@ -716,7 +718,7 @@ async def test_brute_force_triplet_search_collection_not_found_error(): with ( patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ), patch( @@ -743,7 +745,7 @@ async def test_brute_force_triplet_search_generic_exception(): with ( patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ), ): @@ -769,7 +771,7 @@ async def test_brute_force_triplet_search_with_node_name_sets_relevant_ids_to_no with ( patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ), patch( @@ -804,7 +806,7 @@ async def test_brute_force_triplet_search_collection_not_found_at_top_level(): with ( patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", return_value=mock_vector_engine, ), patch( @@ -815,3 +817,237 @@ async def test_brute_force_triplet_search_collection_not_found_at_top_level(): 
result = await brute_force_triplet_search(query="test query") assert result == [] + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_single_query_regression(): + """Test that single-query mode maintains legacy behavior (flat list, ID filtering).""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("node1", 0.95)]) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment, + ): + result = await brute_force_triplet_search( + query="q1", query_batch=None, wide_search_top_k=10, node_name=None + ) + + assert isinstance(result, list) + assert not (result and isinstance(result[0], list)) + mock_get_fragment.assert_called_once() + call_kwargs = mock_get_fragment.call_args[1] + assert call_kwargs["relevant_ids_to_filter"] is not None + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_batch_wiring_happy_path(): + """Test that batch mode returns list-of-lists and skips ID filtering.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.batch_search = AsyncMock( + return_value=[ + [MockScoredResult("node1", 0.95)], + [MockScoredResult("node2", 0.87)], + ] + ) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[[], []]), + ) + + 
with ( + patch( + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment, + ): + result = await brute_force_triplet_search(query_batch=["q1", "q2"]) + + assert isinstance(result, list) + assert len(result) == 2 + assert isinstance(result[0], list) + assert isinstance(result[1], list) + mock_get_fragment.assert_called_once() + call_kwargs = mock_get_fragment.call_args[1] + assert call_kwargs["relevant_ids_to_filter"] is None + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_shape_propagation_to_graph(): + """Test that query_list_length is passed through to graph mapping methods.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.batch_search = AsyncMock( + return_value=[ + [MockScoredResult("node1", 0.95)], + [MockScoredResult("node2", 0.87)], + ] + ) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[[], []]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ), + ): + await brute_force_triplet_search(query_batch=["q1", "q2"]) + + mock_fragment.map_vector_distances_to_graph_nodes.assert_called_once() + node_call_kwargs = mock_fragment.map_vector_distances_to_graph_nodes.call_args[1] + assert "query_list_length" in node_call_kwargs + assert node_call_kwargs["query_list_length"] == 2 + + mock_fragment.map_vector_distances_to_graph_edges.assert_called_once() + edge_call_kwargs = 
mock_fragment.map_vector_distances_to_graph_edges.call_args[1] + assert "query_list_length" in edge_call_kwargs + assert edge_call_kwargs["query_list_length"] == 2 + + mock_fragment.calculate_top_triplet_importances.assert_called_once() + importance_call_kwargs = mock_fragment.calculate_top_triplet_importances.call_args[1] + assert "query_list_length" in importance_call_kwargs + assert importance_call_kwargs["query_list_length"] == 2 + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_batch_path_comprehensive(): + """Test batch mode: returns list-of-lists, skips ID filtering, passes None for wide_search_limit.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + + def batch_search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [ + [MockScoredResult("node1", 0.95)], + [MockScoredResult("node2", 0.87)], + ] + elif collection_name == "EdgeType_relationship_name": + return [ + [MockScoredResult("edge1", 0.92)], + [MockScoredResult("edge2", 0.88)], + ] + return [[], []] + + mock_vector_engine.batch_search = AsyncMock(side_effect=batch_search_side_effect) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[[], []]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment, + ): + result = await brute_force_triplet_search( + query_batch=["q1", "q2"], collections=["Entity_name", "EdgeType_relationship_name"] + ) + + assert isinstance(result, list) + assert len(result) == 2 + assert isinstance(result[0], list) + assert isinstance(result[1], list) + + 
mock_get_fragment.assert_called_once() + fragment_call_kwargs = mock_get_fragment.call_args[1] + assert fragment_call_kwargs["relevant_ids_to_filter"] is None + + batch_search_calls = mock_vector_engine.batch_search.call_args_list + assert len(batch_search_calls) > 0 + for call in batch_search_calls: + assert call[1]["limit"] is None + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_batch_error_fallback(): + """Test that CollectionNotFoundError in batch mode returns [[], []] matching batch length.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.batch_search = AsyncMock( + side_effect=CollectionNotFoundError("Collection not found") + ) + + with patch( + "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine", + return_value=mock_vector_engine, + ): + result = await brute_force_triplet_search(query_batch=["q1", "q2"]) + + assert result == [[], []] + assert len(result) == 2 + + +@pytest.mark.asyncio +async def test_cognee_graph_mapping_batch_shapes(): + """Test that CogneeGraph mapping methods accept list-of-lists with query_list_length set.""" + from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge + + graph = CogneeGraph() + node1 = Node("node1", {"name": "Node1"}) + node2 = Node("node2", {"name": "Node2"}) + graph.add_node(node1) + graph.add_node(node2) + + edge = Edge(node1, node2, attributes={"edge_text": "relates_to"}) + graph.add_edge(edge) + + node_distances_batch = { + "Entity_name": [ + [MockScoredResult("node1", 0.95)], + [MockScoredResult("node2", 0.87)], + ] + } + + edge_distances_batch = [ + [MockScoredResult("edge1", 0.92, payload={"text": "relates_to"})], + [MockScoredResult("edge2", 0.88, payload={"text": "relates_to"})], + ] + + await graph.map_vector_distances_to_graph_nodes( + node_distances=node_distances_batch, query_list_length=2 + ) + await graph.map_vector_distances_to_graph_edges( + edge_distances=edge_distances_batch, 
query_list_length=2 + ) + + assert node1.attributes.get("vector_distance") == [0.95, 3.5] + assert node2.attributes.get("vector_distance") == [3.5, 0.87] + assert edge.attributes.get("vector_distance") == [0.92, 0.88] diff --git a/cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py b/cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py new file mode 100644 index 000000000..98d76ddef --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py @@ -0,0 +1,273 @@ +import pytest +from unittest.mock import AsyncMock + +from cognee.modules.retrieval.utils.node_edge_vector_search import NodeEdgeVectorSearch +from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError + + +class MockScoredResult: + """Mock class for vector search results.""" + + def __init__(self, id, score, payload=None): + self.id = id + self.score = score + self.payload = payload or {} + + +@pytest.mark.asyncio +async def test_node_edge_vector_search_single_query_shape(): + """Test that single query mode produces flat lists (not list-of-lists).""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + + node_results = [MockScoredResult("node1", 0.95), MockScoredResult("node2", 0.87)] + edge_results = [MockScoredResult("edge1", 0.92)] + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "EdgeType_relationship_name": + return edge_results + return node_results + + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine) + collections = ["Entity_name", "EdgeType_relationship_name"] + + await vector_search.embed_and_retrieve_distances( + query="test query", query_batch=None, collections=collections, wide_search_limit=10 + ) + + assert 
vector_search.query_list_length is None + assert vector_search.edge_distances == edge_results + assert vector_search.node_distances["Entity_name"] == node_results + mock_vector_engine.embedding_engine.embed_text.assert_called_once_with(["test query"]) + + +@pytest.mark.asyncio +async def test_node_edge_vector_search_batch_query_shape_and_empties(): + """Test that batch query mode produces list-of-lists with correct length and handles empty collections.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + + query_batch = ["query a", "query b"] + node_results_query_a = [MockScoredResult("node1", 0.95)] + node_results_query_b = [MockScoredResult("node2", 0.87)] + edge_results_query_a = [MockScoredResult("edge1", 0.92)] + edge_results_query_b = [] + + def batch_search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "EdgeType_relationship_name": + return [edge_results_query_a, edge_results_query_b] + elif collection_name == "Entity_name": + return [node_results_query_a, node_results_query_b] + elif collection_name == "MissingCollection": + raise CollectionNotFoundError("Collection not found") + return [[], []] + + mock_vector_engine.batch_search = AsyncMock(side_effect=batch_search_side_effect) + + vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine) + collections = [ + "Entity_name", + "EdgeType_relationship_name", + "MissingCollection", + "EmptyCollection", + ] + + await vector_search.embed_and_retrieve_distances( + query=None, query_batch=query_batch, collections=collections, wide_search_limit=None + ) + + assert vector_search.query_list_length == 2 + assert len(vector_search.edge_distances) == 2 + assert vector_search.edge_distances[0] == edge_results_query_a + assert vector_search.edge_distances[1] == edge_results_query_b + assert len(vector_search.node_distances["Entity_name"]) == 2 + assert vector_search.node_distances["Entity_name"][0] == 
node_results_query_a + assert vector_search.node_distances["Entity_name"][1] == node_results_query_b + assert len(vector_search.node_distances["MissingCollection"]) == 2 + assert vector_search.node_distances["MissingCollection"] == [[], []] + assert len(vector_search.node_distances["EmptyCollection"]) == 2 + assert vector_search.node_distances["EmptyCollection"] == [[], []] + mock_vector_engine.embedding_engine.embed_text.assert_not_called() + + +@pytest.mark.asyncio +async def test_node_edge_vector_search_input_validation_both_provided(): + """Test that providing both query and query_batch raises ValueError.""" + mock_vector_engine = AsyncMock() + vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine) + collections = ["Entity_name"] + + with pytest.raises(ValueError, match="Cannot provide both 'query' and 'query_batch'"): + await vector_search.embed_and_retrieve_distances( + query="test", query_batch=["test1", "test2"], collections=collections + ) + + +@pytest.mark.asyncio +async def test_node_edge_vector_search_input_validation_neither_provided(): + """Test that providing neither query nor query_batch raises ValueError.""" + mock_vector_engine = AsyncMock() + vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine) + collections = ["Entity_name"] + + with pytest.raises(ValueError, match="Must provide either 'query' or 'query_batch'"): + await vector_search.embed_and_retrieve_distances( + query=None, query_batch=None, collections=collections + ) + + +@pytest.mark.asyncio +async def test_node_edge_vector_search_extract_relevant_node_ids_single_query(): + """Test that extract_relevant_node_ids returns IDs for single query mode.""" + mock_vector_engine = AsyncMock() + vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine) + vector_search.query_list_length = None + vector_search.node_distances = { + "Entity_name": [MockScoredResult("node1", 0.95), MockScoredResult("node2", 0.87)], + "TextSummary_text": 
[MockScoredResult("node1", 0.90), MockScoredResult("node3", 0.92)], + } + + node_ids = vector_search.extract_relevant_node_ids() + assert set(node_ids) == {"node1", "node2", "node3"} + + +@pytest.mark.asyncio +async def test_node_edge_vector_search_extract_relevant_node_ids_batch(): + """Test that extract_relevant_node_ids returns empty list for batch mode.""" + mock_vector_engine = AsyncMock() + vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine) + vector_search.query_list_length = 2 + vector_search.node_distances = { + "Entity_name": [ + [MockScoredResult("node1", 0.95)], + [MockScoredResult("node2", 0.87)], + ], + } + + node_ids = vector_search.extract_relevant_node_ids() + assert node_ids == [] + + +@pytest.mark.asyncio +async def test_node_edge_vector_search_has_results_single_query(): + """Test has_results returns True when results exist and False when only empties.""" + mock_vector_engine = AsyncMock() + vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine) + + vector_search.edge_distances = [MockScoredResult("edge1", 0.92)] + vector_search.node_distances = {} + assert vector_search.has_results() is True + + vector_search.edge_distances = [] + vector_search.node_distances = {"Entity_name": [MockScoredResult("node1", 0.95)]} + assert vector_search.has_results() is True + + vector_search.edge_distances = [] + vector_search.node_distances = {} + assert vector_search.has_results() is False + + +@pytest.mark.asyncio +async def test_node_edge_vector_search_has_results_batch(): + """Test has_results works correctly for batch mode with list-of-lists.""" + mock_vector_engine = AsyncMock() + vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine) + vector_search.query_list_length = 2 + + vector_search.edge_distances = [[MockScoredResult("edge1", 0.92)], []] + vector_search.node_distances = {} + assert vector_search.has_results() is True + + vector_search.edge_distances = [[], []] + vector_search.node_distances = { + 
"Entity_name": [[MockScoredResult("node1", 0.95)], []], + } + assert vector_search.has_results() is True + + vector_search.edge_distances = [[], []] + vector_search.node_distances = {"Entity_name": [[], []]} + assert vector_search.has_results() is False + + +@pytest.mark.asyncio +async def test_node_edge_vector_search_single_query_collection_not_found(): + """Test that CollectionNotFoundError in single query mode returns empty list.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock( + side_effect=CollectionNotFoundError("Collection not found") + ) + + vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine) + collections = ["MissingCollection"] + + await vector_search.embed_and_retrieve_distances( + query="test query", query_batch=None, collections=collections, wide_search_limit=10 + ) + + assert vector_search.node_distances["MissingCollection"] == [] + + +@pytest.mark.asyncio +async def test_node_edge_vector_search_missing_collections_single_query(): + """Test that missing collections in single-query mode are handled gracefully with empty lists.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + + node_result = MockScoredResult("node1", 0.95) + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [node_result] + elif collection_name == "MissingCollection": + raise CollectionNotFoundError("Collection not found") + return [] + + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine) + collections = ["Entity_name", "MissingCollection", "EmptyCollection"] + + await 
vector_search.embed_and_retrieve_distances( + query="test query", query_batch=None, collections=collections, wide_search_limit=10 + ) + + assert len(vector_search.node_distances["Entity_name"]) == 1 + assert vector_search.node_distances["Entity_name"][0].id == "node1" + assert vector_search.node_distances["Entity_name"][0].score == 0.95 + assert vector_search.node_distances["MissingCollection"] == [] + assert vector_search.node_distances["EmptyCollection"] == [] + + +@pytest.mark.asyncio +async def test_node_edge_vector_search_has_results_batch_nodes_only(): + """Test has_results returns True when only node distances are populated in batch mode.""" + mock_vector_engine = AsyncMock() + vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine) + vector_search.query_list_length = 2 + vector_search.edge_distances = [[], []] + vector_search.node_distances = { + "Entity_name": [[MockScoredResult("node1", 0.95)], []], + } + + assert vector_search.has_results() is True + + +@pytest.mark.asyncio +async def test_node_edge_vector_search_has_results_batch_edges_only(): + """Test has_results returns True when only edge distances are populated in batch mode.""" + mock_vector_engine = AsyncMock() + vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine) + vector_search.query_list_length = 2 + vector_search.edge_distances = [[MockScoredResult("edge1", 0.92)], []] + vector_search.node_distances = {} + + assert vector_search.has_results() is True