feat: multiquery triplet search (#1991)
<!-- .github/pull_request_template.md -->
## Description
<!--
Please provide a clear, human-generated description of the changes in
this PR.
DO NOT use AI-generated descriptions. We want to understand your thought
process and reasoning.
-->
- Adds batch search support to `brute_force_triplet_search` with a new
`query_batch` parameter that accepts a list of queries in addition to
the existing single `query` parameter.
- Introduces a new `NodeEdgeVectorSearch` class that encapsulates vector
search operations, handling embedding and distance retrieval for both
single-query and batch-query modes.
- Returns `List[List[Edge]]` (one list per query) when using
`query_batch`, instead of the single `List[Edge]` format used for single
queries.
- Adds comprehensive test coverage including new test files and cases
for the `NodeEdgeVectorSearch` class, batch search functionality, and
edge cases for both single and batch modes.
- Refactors code by extracting vector search logic into the new class
and adding a helper function `_get_top_triplet_importances` to reduce
code duplication and improve maintainability.
## Acceptance Criteria
<!--
* Key requirements to the new feature or modification;
* Proof that the changes work and meet the requirements;
* Include instructions on how to verify the changes. Describe how to
test it locally;
* Proof that it's sufficiently tested.
-->
## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):
## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->
## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [x] My code follows the project's coding standards and style
guidelines
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [x] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been
submitted already
- [x] I have linked any relevant issues in the description
- [x] My commits have clear and descriptive messages
## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **New Features**
* Added batch-query support to triplet search; batch returns per-query
nested results while single-query remains flat.
* Introduced a unified vector search controller to embed queries and
retrieve node/edge distances across collections.
* **Bug Fixes**
* Improved input validation and safer error handling for missing
collections and batch failures.
* Stopped adding duplicate skeleton edge links after edge creation.
* **Tests**
* Added comprehensive unit and integration tests covering single/batch
flows and edge cases.
<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
commit
4765f9e4a0
7 changed files with 918 additions and 111 deletions
|
|
@ -215,9 +215,6 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
edge_penalty=triplet_distance_penalty,
|
||||
)
|
||||
self.add_edge(edge)
|
||||
|
||||
source_node.add_skeleton_edge(edge)
|
||||
target_node.add_skeleton_edge(edge)
|
||||
else:
|
||||
raise EntityNotFoundError(
|
||||
message=f"Edge references nonexistent nodes: {source_id} -> {target_id}"
|
||||
|
|
|
|||
|
|
@ -1,21 +1,18 @@
|
|||
import asyncio
|
||||
import time
|
||||
from typing import List, Optional, Type
|
||||
from typing import List, Optional, Type, Union
|
||||
|
||||
from cognee.shared.logging_utils import get_logger, ERROR
|
||||
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
|
||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.modules.retrieval.utils.node_edge_vector_search import NodeEdgeVectorSearch
|
||||
|
||||
logger = get_logger(level=ERROR)
|
||||
|
||||
|
||||
def format_triplets(edges):
|
||||
"""Formats edges into human-readable triplet strings."""
|
||||
triplets = []
|
||||
for edge in edges:
|
||||
node1 = edge.node1
|
||||
|
|
@ -24,12 +21,10 @@ def format_triplets(edges):
|
|||
node1_attributes = node1.attributes
|
||||
node2_attributes = node2.attributes
|
||||
|
||||
# Filter only non-None properties
|
||||
node1_info = {key: value for key, value in node1_attributes.items() if value is not None}
|
||||
node2_info = {key: value for key, value in node2_attributes.items() if value is not None}
|
||||
edge_info = {key: value for key, value in edge_attributes.items() if value is not None}
|
||||
|
||||
# Create the formatted triplet
|
||||
triplet = f"Node1: {node1_info}\nEdge: {edge_info}\nNode2: {node2_info}\n\n\n"
|
||||
triplets.append(triplet)
|
||||
|
||||
|
|
@ -51,7 +46,6 @@ async def get_memory_fragment(
|
|||
|
||||
try:
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
await memory_fragment.project_graph_from_db(
|
||||
graph_engine,
|
||||
node_properties_to_project=properties_to_project,
|
||||
|
|
@ -61,20 +55,64 @@ async def get_memory_fragment(
|
|||
relevant_ids_to_filter=relevant_ids_to_filter,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
|
||||
except EntityNotFoundError:
|
||||
# This is expected behavior - continue with empty fragment
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Error during memory fragment creation: {str(e)}")
|
||||
# Still return the fragment even if projection failed
|
||||
pass
|
||||
|
||||
return memory_fragment
|
||||
|
||||
|
||||
async def _get_top_triplet_importances(
|
||||
memory_fragment: Optional[CogneeGraph],
|
||||
vector_search: NodeEdgeVectorSearch,
|
||||
properties_to_project: Optional[List[str]],
|
||||
node_type: Optional[Type],
|
||||
node_name: Optional[List[str]],
|
||||
triplet_distance_penalty: float,
|
||||
wide_search_limit: Optional[int],
|
||||
top_k: int,
|
||||
query_list_length: Optional[int] = None,
|
||||
) -> Union[List[Edge], List[List[Edge]]]:
|
||||
"""Creates memory fragment (if needed), maps distances, and calculates top triplet importances.
|
||||
|
||||
Args:
|
||||
query_list_length: Number of queries in batch mode (None for single-query mode).
|
||||
When None, node_distances/edge_distances are flat lists; when set, they are list-of-lists.
|
||||
|
||||
Returns:
|
||||
List[Edge]: For single-query mode (query_list_length is None).
|
||||
List[List[Edge]]: For batch mode (query_list_length is set), one list per query.
|
||||
"""
|
||||
if memory_fragment is None:
|
||||
if wide_search_limit is None:
|
||||
relevant_node_ids = None
|
||||
else:
|
||||
relevant_node_ids = vector_search.extract_relevant_node_ids()
|
||||
|
||||
memory_fragment = await get_memory_fragment(
|
||||
properties_to_project=properties_to_project,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
relevant_ids_to_filter=relevant_node_ids,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
|
||||
await memory_fragment.map_vector_distances_to_graph_nodes(
|
||||
node_distances=vector_search.node_distances, query_list_length=query_list_length
|
||||
)
|
||||
await memory_fragment.map_vector_distances_to_graph_edges(
|
||||
edge_distances=vector_search.edge_distances, query_list_length=query_list_length
|
||||
)
|
||||
|
||||
return await memory_fragment.calculate_top_triplet_importances(
|
||||
k=top_k, query_list_length=query_list_length
|
||||
)
|
||||
|
||||
|
||||
async def brute_force_triplet_search(
|
||||
query: str,
|
||||
query: Optional[str] = None,
|
||||
query_batch: Optional[List[str]] = None,
|
||||
top_k: int = 5,
|
||||
collections: Optional[List[str]] = None,
|
||||
properties_to_project: Optional[List[str]] = None,
|
||||
|
|
@ -83,33 +121,49 @@ async def brute_force_triplet_search(
|
|||
node_name: Optional[List[str]] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> List[Edge]:
|
||||
) -> Union[List[Edge], List[List[Edge]]]:
|
||||
"""
|
||||
Performs a brute force search to retrieve the top triplets from the graph.
|
||||
|
||||
Args:
|
||||
query (str): The search query.
|
||||
query (Optional[str]): The search query (single query mode). Exactly one of query or query_batch must be provided.
|
||||
query_batch (Optional[List[str]]): List of search queries (batch mode). Exactly one of query or query_batch must be provided.
|
||||
top_k (int): The number of top results to retrieve.
|
||||
collections (Optional[List[str]]): List of collections to query.
|
||||
properties_to_project (Optional[List[str]]): List of properties to project.
|
||||
memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse.
|
||||
node_type: node type to filter
|
||||
node_name: node name to filter
|
||||
wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections
|
||||
wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections.
|
||||
Ignored in batch mode (always None to project full graph).
|
||||
triplet_distance_penalty (Optional[float]): Default distance penalty in graph projection
|
||||
|
||||
Returns:
|
||||
list: The top triplet results.
|
||||
List[Edge]: The top triplet results for single query mode (flat list).
|
||||
List[List[Edge]]: List of top triplet results (one per query) for batch mode (list-of-lists).
|
||||
|
||||
Note:
|
||||
In single-query mode, node_distances and edge_distances are stored as flat lists.
|
||||
In batch mode, they are stored as list-of-lists (one list per query).
|
||||
"""
|
||||
if not query or not isinstance(query, str):
|
||||
if query is not None and query_batch is not None:
|
||||
raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.")
|
||||
if query is None and query_batch is None:
|
||||
raise ValueError("Must provide either 'query' or 'query_batch'.")
|
||||
if query is not None and (not query or not isinstance(query, str)):
|
||||
raise ValueError("The query must be a non-empty string.")
|
||||
if query_batch is not None:
|
||||
if not isinstance(query_batch, list) or not query_batch:
|
||||
raise ValueError("query_batch must be a non-empty list of strings.")
|
||||
if not all(isinstance(q, str) and q for q in query_batch):
|
||||
raise ValueError("All items in query_batch must be non-empty strings.")
|
||||
if top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer.")
|
||||
|
||||
# Setting wide search limit based on the parameters
|
||||
non_global_search = node_name is None
|
||||
|
||||
wide_search_limit = wide_search_top_k if non_global_search else None
|
||||
query_list_length = len(query_batch) if query_batch is not None else None
|
||||
wide_search_limit = (
|
||||
None if query_list_length else (wide_search_top_k if node_name is None else None)
|
||||
)
|
||||
|
||||
if collections is None:
|
||||
collections = [
|
||||
|
|
@ -123,77 +177,37 @@ async def brute_force_triplet_search(
|
|||
collections.append("EdgeType_relationship_name")
|
||||
|
||||
try:
|
||||
vector_engine = get_vector_engine()
|
||||
except Exception as e:
|
||||
logger.error("Failed to initialize vector engine: %s", e)
|
||||
raise RuntimeError("Initialization error") from e
|
||||
vector_search = NodeEdgeVectorSearch()
|
||||
|
||||
query_vector = (await vector_engine.embedding_engine.embed_text([query]))[0]
|
||||
|
||||
async def search_in_collection(collection_name: str):
|
||||
try:
|
||||
return await vector_engine.search(
|
||||
collection_name=collection_name, query_vector=query_vector, limit=wide_search_limit
|
||||
)
|
||||
except CollectionNotFoundError:
|
||||
return []
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[search_in_collection(collection_name) for collection_name in collections]
|
||||
await vector_search.embed_and_retrieve_distances(
|
||||
query=None if query_list_length else query,
|
||||
query_batch=query_batch if query_list_length else None,
|
||||
collections=collections,
|
||||
wide_search_limit=wide_search_limit,
|
||||
)
|
||||
|
||||
if all(not item for item in results):
|
||||
return []
|
||||
if not vector_search.has_results():
|
||||
return [[] for _ in range(query_list_length)] if query_list_length else []
|
||||
|
||||
# Final statistics
|
||||
vector_collection_search_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {vector_collection_search_time:.2f}s"
|
||||
results = await _get_top_triplet_importances(
|
||||
memory_fragment,
|
||||
vector_search,
|
||||
properties_to_project,
|
||||
node_type,
|
||||
node_name,
|
||||
triplet_distance_penalty,
|
||||
wide_search_limit,
|
||||
top_k,
|
||||
query_list_length=query_list_length,
|
||||
)
|
||||
|
||||
node_distances = {collection: result for collection, result in zip(collections, results)}
|
||||
|
||||
edge_distances = node_distances.get("EdgeType_relationship_name", None)
|
||||
|
||||
if wide_search_limit is not None:
|
||||
relevant_ids_to_filter = list(
|
||||
{
|
||||
str(getattr(scored_node, "id"))
|
||||
for collection_name, score_collection in node_distances.items()
|
||||
if collection_name != "EdgeType_relationship_name"
|
||||
and isinstance(score_collection, (list, tuple))
|
||||
for scored_node in score_collection
|
||||
if getattr(scored_node, "id", None)
|
||||
}
|
||||
)
|
||||
else:
|
||||
relevant_ids_to_filter = None
|
||||
|
||||
if memory_fragment is None:
|
||||
memory_fragment = await get_memory_fragment(
|
||||
properties_to_project=properties_to_project,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
relevant_ids_to_filter=relevant_ids_to_filter,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
|
||||
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
|
||||
await memory_fragment.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
||||
results = await memory_fragment.calculate_top_triplet_importances(k=top_k)
|
||||
|
||||
return results
|
||||
|
||||
except CollectionNotFoundError:
|
||||
return []
|
||||
return [[] for _ in range(query_list_length)] if query_list_length else []
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
"Error during brute force search for query: %s. Error: %s",
|
||||
query,
|
||||
query_batch if query_list_length else [query],
|
||||
error,
|
||||
)
|
||||
raise error
|
||||
|
|
|
|||
174
cognee/modules/retrieval/utils/node_edge_vector_search.py
Normal file
174
cognee/modules/retrieval/utils/node_edge_vector_search.py
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
import asyncio
|
||||
import time
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from cognee.shared.logging_utils import get_logger, ERROR
|
||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
logger = get_logger(level=ERROR)
|
||||
|
||||
|
||||
class NodeEdgeVectorSearch:
    """Manages vector search and distance retrieval for graph nodes and edges.

    Two modes are supported:
      * Single-query mode: ``node_distances`` maps each collection name to a flat
        list of scored results and ``edge_distances`` is a flat list.
      * Batch mode: both are list-of-lists shaped, one inner list per query.
    """

    def __init__(self, edge_collection: str = "EdgeType_relationship_name", vector_engine=None):
        """
        Args:
            edge_collection: Name of the collection whose results are treated as
                edge distances rather than node distances.
            vector_engine: Optional pre-built vector engine (useful for tests);
                when omitted, one is created via get_vector_engine().
        """
        self.edge_collection = edge_collection
        self.vector_engine = vector_engine or self._init_vector_engine()
        self.query_vector: Optional[Any] = None
        self.node_distances: dict[str, list[Any]] = {}
        self.edge_distances: list[Any] = []
        # None in single-query mode; the number of queries in batch mode.
        self.query_list_length: Optional[int] = None

    def _init_vector_engine(self):
        """Builds the default vector engine, wrapping any failure in RuntimeError."""
        try:
            return get_vector_engine()
        except Exception as e:
            logger.error("Failed to initialize vector engine: %s", e)
            raise RuntimeError("Initialization error") from e

    async def embed_and_retrieve_distances(
        self,
        query: Optional[str] = None,
        query_batch: Optional[List[str]] = None,
        collections: Optional[List[str]] = None,
        wide_search_limit: Optional[int] = None,
    ):
        """Embeds query/queries and retrieves vector distances from all collections.

        Args:
            query: The single search query; exactly one of query/query_batch is required.
            query_batch: The list of search queries (batch mode).
            collections: Non-empty list of collection names to search.
            wide_search_limit: Per-collection result limit (single-query mode only;
                batch search currently runs without a limit).

        Raises:
            ValueError: If both or neither of query/query_batch are provided,
                or if collections is empty/None.
        """
        if query is not None and query_batch is not None:
            raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.")
        if query is None and query_batch is None:
            raise ValueError("Must provide either 'query' or 'query_batch'.")
        if not collections:
            raise ValueError("'collections' must be a non-empty list.")

        start_time = time.time()

        if query_batch is not None:
            self.query_list_length = len(query_batch)
            search_results = await self._run_batch_search(collections, query_batch)
        else:
            self.query_list_length = None
            search_results = await self._run_single_search(collections, query, wide_search_limit)

        elapsed_time = time.time() - start_time
        collections_with_results = sum(1 for result in search_results if any(result))
        logger.info(
            f"Vector collection retrieval completed: Retrieved distances from "
            f"{collections_with_results} collections in {elapsed_time:.2f}s"
        )

        self.set_distances_from_results(collections, search_results, self.query_list_length)

    def has_results(self) -> bool:
        """Checks if any collections returned results (mode-aware)."""
        if self.query_list_length is None:
            # Single-query mode: flat lists per collection.
            if self.edge_distances and any(self.edge_distances):
                return True
            return any(
                bool(collection_results) for collection_results in self.node_distances.values()
            )

        # Batch mode: list-of-lists, so inspect one level deeper.
        if self.edge_distances and any(inner_list for inner_list in self.edge_distances):
            return True
        return any(
            any(results_per_query for results_per_query in collection_results)
            for collection_results in self.node_distances.values()
        )

    def extract_relevant_node_ids(self) -> List[str]:
        """Extracts unique node IDs from search results (single-query mode only)."""
        if self.query_list_length is not None:
            # Batch mode stores nested per-query lists; ID extraction is not supported.
            return []
        relevant_node_ids = set()
        for scored_results in self.node_distances.values():
            for scored_node in scored_results:
                node_id = getattr(scored_node, "id", None)
                if node_id:
                    relevant_node_ids.add(str(node_id))
        return list(relevant_node_ids)

    def set_distances_from_results(
        self,
        collections: List[str],
        search_results: List[List[Any]],
        query_list_length: Optional[int] = None,
    ):
        """Separates search results into node and edge distances with stable shapes.

        Ensures all collections are present in the output, even if empty:
        - Batch mode: missing/empty collections become one empty list per query.
        - Single mode: missing/empty collections become [].
        """
        self.node_distances = {}
        self.edge_distances = (
            [] if query_list_length is None else [[] for _ in range(query_list_length)]
        )
        for collection, result in zip(collections, search_results):
            if not result:
                empty_result = (
                    [] if query_list_length is None else [[] for _ in range(query_list_length)]
                )
                if collection == self.edge_collection:
                    self.edge_distances = empty_result
                else:
                    self.node_distances[collection] = empty_result
            else:
                if collection == self.edge_collection:
                    self.edge_distances = result
                else:
                    self.node_distances[collection] = result

    async def _run_batch_search(
        self, collections: List[str], query_batch: List[str]
    ) -> List[List[Any]]:
        """Runs batch search across all collections and returns list-of-lists per collection."""
        search_tasks = [
            self._search_batch_collection(collection, query_batch) for collection in collections
        ]
        return await asyncio.gather(*search_tasks)

    async def _search_batch_collection(
        self, collection_name: str, query_batch: List[str]
    ) -> List[List[Any]]:
        """Searches one collection with batch queries and returns list-of-lists."""
        try:
            return await self.vector_engine.batch_search(
                collection_name=collection_name, query_texts=query_batch, limit=None
            )
        except CollectionNotFoundError:
            # One independent empty list per query. NOT `[[]] * len(query_batch)`,
            # which would alias a single shared list across every query slot.
            return [[] for _ in query_batch]

    async def _run_single_search(
        self, collections: List[str], query: str, wide_search_limit: Optional[int]
    ) -> List[List[Any]]:
        """Runs single query search and returns flat lists per collection.

        Returns a list where each element is a collection's results (flat list).
        These are stored as flat lists in node_distances/edge_distances for
        single-query mode.
        """
        await self._embed_query(query)
        search_tasks = [
            self._search_single_collection(self.vector_engine, wide_search_limit, collection)
            for collection in collections
        ]
        search_results = await asyncio.gather(*search_tasks)
        return search_results

    async def _embed_query(self, query: str):
        """Embeds the query and stores the resulting vector."""
        query_embeddings = await self.vector_engine.embedding_engine.embed_text([query])
        self.query_vector = query_embeddings[0]

    async def _search_single_collection(
        self, vector_engine: Any, wide_search_limit: Optional[int], collection_name: str
    ):
        """Searches one collection and returns results, or [] if the collection is missing."""
        try:
            return await vector_engine.search(
                collection_name=collection_name,
                query_vector=self.query_vector,
                limit=wide_search_limit,
            )
        except CollectionNotFoundError:
            return []
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
import os
|
||||
import pathlib
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import cognee
|
||||
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
|
||||
|
||||
|
||||
skip_without_provider = pytest.mark.skipif(
|
||||
not (os.getenv("OPENAI_API_KEY") or os.getenv("AZURE_OPENAI_API_KEY")),
|
||||
reason="requires embedding/vector provider credentials",
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def clean_environment():
    """Point cognee at isolated storage directories and prune state before/after the test."""
    repo_root = pathlib.Path(__file__).parent.parent.parent.parent
    cognee.config.system_root_directory(
        str(repo_root / ".cognee_system/test_brute_force_triplet_search_e2e")
    )
    cognee.config.data_root_directory(
        str(repo_root / ".data_storage/test_brute_force_triplet_search_e2e")
    )

    # Start from a clean slate.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    yield

    # Best-effort cleanup; never fail the test on teardown.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
|
||||
|
||||
|
||||
@skip_without_provider
@pytest.mark.asyncio
async def test_brute_force_triplet_search_end_to_end(clean_environment):
    """Minimal end-to-end exercise of single and batch triplet search."""

    text = """
    Cognee is an open-source AI memory engine that structures data into searchable formats for use with AI agents.
    The company focuses on persistent memory systems using knowledge graphs and vector search.
    It is a Berlin-based startup building infrastructure for context-aware AI applications.
    """

    await cognee.add(text)
    await cognee.cognify()

    # Single-query mode returns a flat list of Edge objects.
    single = await brute_force_triplet_search(query="What is NLP?", top_k=1)
    assert isinstance(single, list)
    # all() is vacuously True on an empty result, matching the guarded original.
    assert all(isinstance(item, Edge) for item in single)

    # Batch mode returns one list per query, in order.
    queries = ["What is Cognee?", "What is the company's focus?"]
    batch = await brute_force_triplet_search(query_batch=queries, top_k=1)

    assert isinstance(batch, list)
    assert len(batch) == len(queries)
    for per_query_results in batch:
        assert isinstance(per_query_results, list)
        assert all(isinstance(item, Edge) for item in per_query_results)
|
||||
|
|
@ -718,3 +718,49 @@ async def test_calculate_top_triplet_importances_raises_on_missing_attribute(set
|
|||
|
||||
with pytest.raises(ValueError):
|
||||
await graph.calculate_top_triplet_importances(k=1, query_list_length=1)
|
||||
|
||||
|
||||
def test_normalize_query_distance_lists_flat_list_single_query(setup_graph):
|
||||
"""Test that flat list is normalized to list-of-lists with length 1 for single-query mode."""
|
||||
graph = setup_graph
|
||||
flat_list = [MockScoredResult("node1", 0.95), MockScoredResult("node2", 0.87)]
|
||||
|
||||
result = graph._normalize_query_distance_lists(flat_list, query_list_length=None, name="test")
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0] == flat_list
|
||||
|
||||
|
||||
def test_normalize_query_distance_lists_nested_list_batch_mode(setup_graph):
|
||||
"""Test that nested list is used as-is when query_list_length matches."""
|
||||
graph = setup_graph
|
||||
nested_list = [
|
||||
[MockScoredResult("node1", 0.95)],
|
||||
[MockScoredResult("node2", 0.87)],
|
||||
]
|
||||
|
||||
result = graph._normalize_query_distance_lists(nested_list, query_list_length=2, name="test")
|
||||
|
||||
assert len(result) == 2
|
||||
assert result == nested_list
|
||||
|
||||
|
||||
def test_normalize_query_distance_lists_raises_on_length_mismatch(setup_graph):
|
||||
"""Test that ValueError is raised when nested list length doesn't match query_list_length."""
|
||||
graph = setup_graph
|
||||
nested_list = [
|
||||
[MockScoredResult("node1", 0.95)],
|
||||
[MockScoredResult("node2", 0.87)],
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="test has 2 query lists, but query_list_length is 3"):
|
||||
graph._normalize_query_distance_lists(nested_list, query_list_length=3, name="test")
|
||||
|
||||
|
||||
def test_normalize_query_distance_lists_empty_list(setup_graph):
|
||||
"""Test that empty list returns empty list."""
|
||||
graph = setup_graph
|
||||
|
||||
result = graph._normalize_query_distance_lists([], query_list_length=None, name="test")
|
||||
|
||||
assert result == []
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ async def test_brute_force_triplet_search_empty_query():
|
|||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_none_query():
|
||||
"""Test that None query raises ValueError."""
|
||||
with pytest.raises(ValueError, match="The query must be a non-empty string."):
|
||||
with pytest.raises(ValueError, match="Must provide either 'query' or 'query_batch'."):
|
||||
await brute_force_triplet_search(query=None)
|
||||
|
||||
|
||||
|
|
@ -57,7 +57,7 @@ async def test_brute_force_triplet_search_wide_search_limit_global_search():
|
|||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(
|
||||
|
|
@ -79,7 +79,7 @@ async def test_brute_force_triplet_search_wide_search_limit_filtered_search():
|
|||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(
|
||||
|
|
@ -101,7 +101,7 @@ async def test_brute_force_triplet_search_wide_search_default():
|
|||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(query="test", node_name=None)
|
||||
|
|
@ -119,7 +119,7 @@ async def test_brute_force_triplet_search_default_collections():
|
|||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(query="test")
|
||||
|
|
@ -149,7 +149,7 @@ async def test_brute_force_triplet_search_custom_collections():
|
|||
custom_collections = ["CustomCol1", "CustomCol2"]
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(query="test", collections=custom_collections)
|
||||
|
|
@ -171,7 +171,7 @@ async def test_brute_force_triplet_search_always_includes_edge_collection():
|
|||
collections_without_edge = ["Entity_name", "TextSummary_text"]
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(query="test", collections=collections_without_edge)
|
||||
|
|
@ -194,7 +194,7 @@ async def test_brute_force_triplet_search_all_collections_empty():
|
|||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
results = await brute_force_triplet_search(query="test")
|
||||
|
|
@ -216,7 +216,7 @@ async def test_brute_force_triplet_search_embeds_query():
|
|||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(query=query_text)
|
||||
|
|
@ -249,7 +249,7 @@ async def test_brute_force_triplet_search_extracts_node_ids_global_search():
|
|||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -279,7 +279,7 @@ async def test_brute_force_triplet_search_reuses_provided_fragment():
|
|||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -311,7 +311,7 @@ async def test_brute_force_triplet_search_creates_fragment_when_not_provided():
|
|||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -340,7 +340,7 @@ async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation
|
|||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -351,7 +351,9 @@ async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation
|
|||
custom_top_k = 15
|
||||
await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"])
|
||||
|
||||
mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k)
|
||||
mock_fragment.calculate_top_triplet_importances.assert_called_once_with(
|
||||
k=custom_top_k, query_list_length=None
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -430,7 +432,7 @@ async def test_brute_force_triplet_search_deduplicates_node_ids():
|
|||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -471,7 +473,7 @@ async def test_brute_force_triplet_search_excludes_edge_collection():
|
|||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -523,7 +525,7 @@ async def test_brute_force_triplet_search_skips_nodes_without_ids():
|
|||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -564,7 +566,7 @@ async def test_brute_force_triplet_search_handles_tuple_results():
|
|||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -606,7 +608,7 @@ async def test_brute_force_triplet_search_mixed_empty_collections():
|
|||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -689,7 +691,7 @@ async def test_brute_force_triplet_search_vector_engine_init_error():
|
|||
"""Test brute_force_triplet_search handles vector engine initialization error (lines 145-147)."""
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine"
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine"
|
||||
) as mock_get_vector_engine,
|
||||
):
|
||||
mock_get_vector_engine.side_effect = Exception("Initialization error")
|
||||
|
|
@ -716,7 +718,7 @@ async def test_brute_force_triplet_search_collection_not_found_error():
|
|||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -743,7 +745,7 @@ async def test_brute_force_triplet_search_generic_exception():
|
|||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
):
|
||||
|
|
@ -769,7 +771,7 @@ async def test_brute_force_triplet_search_with_node_name_sets_relevant_ids_to_no
|
|||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -804,7 +806,7 @@ async def test_brute_force_triplet_search_collection_not_found_at_top_level():
|
|||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -815,3 +817,237 @@ async def test_brute_force_triplet_search_collection_not_found_at_top_level():
|
|||
result = await brute_force_triplet_search(query="test query")
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_single_query_regression():
|
||||
"""Test that single-query mode maintains legacy behavior (flat list, ID filtering)."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("node1", 0.95)])
|
||||
|
||||
mock_fragment = AsyncMock(
|
||||
map_vector_distances_to_graph_nodes=AsyncMock(),
|
||||
map_vector_distances_to_graph_edges=AsyncMock(),
|
||||
calculate_top_triplet_importances=AsyncMock(return_value=[]),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
|
||||
return_value=mock_fragment,
|
||||
) as mock_get_fragment,
|
||||
):
|
||||
result = await brute_force_triplet_search(
|
||||
query="q1", query_batch=None, wide_search_top_k=10, node_name=None
|
||||
)
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert not (result and isinstance(result[0], list))
|
||||
mock_get_fragment.assert_called_once()
|
||||
call_kwargs = mock_get_fragment.call_args[1]
|
||||
assert call_kwargs["relevant_ids_to_filter"] is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_batch_wiring_happy_path():
|
||||
"""Test that batch mode returns list-of-lists and skips ID filtering."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.batch_search = AsyncMock(
|
||||
return_value=[
|
||||
[MockScoredResult("node1", 0.95)],
|
||||
[MockScoredResult("node2", 0.87)],
|
||||
]
|
||||
)
|
||||
|
||||
mock_fragment = AsyncMock(
|
||||
map_vector_distances_to_graph_nodes=AsyncMock(),
|
||||
map_vector_distances_to_graph_edges=AsyncMock(),
|
||||
calculate_top_triplet_importances=AsyncMock(return_value=[[], []]),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
|
||||
return_value=mock_fragment,
|
||||
) as mock_get_fragment,
|
||||
):
|
||||
result = await brute_force_triplet_search(query_batch=["q1", "q2"])
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 2
|
||||
assert isinstance(result[0], list)
|
||||
assert isinstance(result[1], list)
|
||||
mock_get_fragment.assert_called_once()
|
||||
call_kwargs = mock_get_fragment.call_args[1]
|
||||
assert call_kwargs["relevant_ids_to_filter"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_shape_propagation_to_graph():
|
||||
"""Test that query_list_length is passed through to graph mapping methods."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.batch_search = AsyncMock(
|
||||
return_value=[
|
||||
[MockScoredResult("node1", 0.95)],
|
||||
[MockScoredResult("node2", 0.87)],
|
||||
]
|
||||
)
|
||||
|
||||
mock_fragment = AsyncMock(
|
||||
map_vector_distances_to_graph_nodes=AsyncMock(),
|
||||
map_vector_distances_to_graph_edges=AsyncMock(),
|
||||
calculate_top_triplet_importances=AsyncMock(return_value=[[], []]),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
|
||||
return_value=mock_fragment,
|
||||
),
|
||||
):
|
||||
await brute_force_triplet_search(query_batch=["q1", "q2"])
|
||||
|
||||
mock_fragment.map_vector_distances_to_graph_nodes.assert_called_once()
|
||||
node_call_kwargs = mock_fragment.map_vector_distances_to_graph_nodes.call_args[1]
|
||||
assert "query_list_length" in node_call_kwargs
|
||||
assert node_call_kwargs["query_list_length"] == 2
|
||||
|
||||
mock_fragment.map_vector_distances_to_graph_edges.assert_called_once()
|
||||
edge_call_kwargs = mock_fragment.map_vector_distances_to_graph_edges.call_args[1]
|
||||
assert "query_list_length" in edge_call_kwargs
|
||||
assert edge_call_kwargs["query_list_length"] == 2
|
||||
|
||||
mock_fragment.calculate_top_triplet_importances.assert_called_once()
|
||||
importance_call_kwargs = mock_fragment.calculate_top_triplet_importances.call_args[1]
|
||||
assert "query_list_length" in importance_call_kwargs
|
||||
assert importance_call_kwargs["query_list_length"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_batch_path_comprehensive():
|
||||
"""Test batch mode: returns list-of-lists, skips ID filtering, passes None for wide_search_limit."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
|
||||
def batch_search_side_effect(*args, **kwargs):
|
||||
collection_name = kwargs.get("collection_name")
|
||||
if collection_name == "Entity_name":
|
||||
return [
|
||||
[MockScoredResult("node1", 0.95)],
|
||||
[MockScoredResult("node2", 0.87)],
|
||||
]
|
||||
elif collection_name == "EdgeType_relationship_name":
|
||||
return [
|
||||
[MockScoredResult("edge1", 0.92)],
|
||||
[MockScoredResult("edge2", 0.88)],
|
||||
]
|
||||
return [[], []]
|
||||
|
||||
mock_vector_engine.batch_search = AsyncMock(side_effect=batch_search_side_effect)
|
||||
|
||||
mock_fragment = AsyncMock(
|
||||
map_vector_distances_to_graph_nodes=AsyncMock(),
|
||||
map_vector_distances_to_graph_edges=AsyncMock(),
|
||||
calculate_top_triplet_importances=AsyncMock(return_value=[[], []]),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
|
||||
return_value=mock_fragment,
|
||||
) as mock_get_fragment,
|
||||
):
|
||||
result = await brute_force_triplet_search(
|
||||
query_batch=["q1", "q2"], collections=["Entity_name", "EdgeType_relationship_name"]
|
||||
)
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 2
|
||||
assert isinstance(result[0], list)
|
||||
assert isinstance(result[1], list)
|
||||
|
||||
mock_get_fragment.assert_called_once()
|
||||
fragment_call_kwargs = mock_get_fragment.call_args[1]
|
||||
assert fragment_call_kwargs["relevant_ids_to_filter"] is None
|
||||
|
||||
batch_search_calls = mock_vector_engine.batch_search.call_args_list
|
||||
assert len(batch_search_calls) > 0
|
||||
for call in batch_search_calls:
|
||||
assert call[1]["limit"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_batch_error_fallback():
|
||||
"""Test that CollectionNotFoundError in batch mode returns [[], []] matching batch length."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.batch_search = AsyncMock(
|
||||
side_effect=CollectionNotFoundError("Collection not found")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
result = await brute_force_triplet_search(query_batch=["q1", "q2"])
|
||||
|
||||
assert result == [[], []]
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cognee_graph_mapping_batch_shapes():
|
||||
"""Test that CogneeGraph mapping methods accept list-of-lists with query_list_length set."""
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
|
||||
|
||||
graph = CogneeGraph()
|
||||
node1 = Node("node1", {"name": "Node1"})
|
||||
node2 = Node("node2", {"name": "Node2"})
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
|
||||
edge = Edge(node1, node2, attributes={"edge_text": "relates_to"})
|
||||
graph.add_edge(edge)
|
||||
|
||||
node_distances_batch = {
|
||||
"Entity_name": [
|
||||
[MockScoredResult("node1", 0.95)],
|
||||
[MockScoredResult("node2", 0.87)],
|
||||
]
|
||||
}
|
||||
|
||||
edge_distances_batch = [
|
||||
[MockScoredResult("edge1", 0.92, payload={"text": "relates_to"})],
|
||||
[MockScoredResult("edge2", 0.88, payload={"text": "relates_to"})],
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_nodes(
|
||||
node_distances=node_distances_batch, query_list_length=2
|
||||
)
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
edge_distances=edge_distances_batch, query_list_length=2
|
||||
)
|
||||
|
||||
assert node1.attributes.get("vector_distance") == [0.95, 3.5]
|
||||
assert node2.attributes.get("vector_distance") == [3.5, 0.87]
|
||||
assert edge.attributes.get("vector_distance") == [0.92, 0.88]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,273 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from cognee.modules.retrieval.utils.node_edge_vector_search import NodeEdgeVectorSearch
|
||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||
|
||||
|
||||
class MockScoredResult:
|
||||
"""Mock class for vector search results."""
|
||||
|
||||
def __init__(self, id, score, payload=None):
|
||||
self.id = id
|
||||
self.score = score
|
||||
self.payload = payload or {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_node_edge_vector_search_single_query_shape():
|
||||
"""Test that single query mode produces flat lists (not list-of-lists)."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
|
||||
node_results = [MockScoredResult("node1", 0.95), MockScoredResult("node2", 0.87)]
|
||||
edge_results = [MockScoredResult("edge1", 0.92)]
|
||||
|
||||
def search_side_effect(*args, **kwargs):
|
||||
collection_name = kwargs.get("collection_name")
|
||||
if collection_name == "EdgeType_relationship_name":
|
||||
return edge_results
|
||||
return node_results
|
||||
|
||||
mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
|
||||
|
||||
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
|
||||
collections = ["Entity_name", "EdgeType_relationship_name"]
|
||||
|
||||
await vector_search.embed_and_retrieve_distances(
|
||||
query="test query", query_batch=None, collections=collections, wide_search_limit=10
|
||||
)
|
||||
|
||||
assert vector_search.query_list_length is None
|
||||
assert vector_search.edge_distances == edge_results
|
||||
assert vector_search.node_distances["Entity_name"] == node_results
|
||||
mock_vector_engine.embedding_engine.embed_text.assert_called_once_with(["test query"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_node_edge_vector_search_batch_query_shape_and_empties():
|
||||
"""Test that batch query mode produces list-of-lists with correct length and handles empty collections."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
|
||||
query_batch = ["query a", "query b"]
|
||||
node_results_query_a = [MockScoredResult("node1", 0.95)]
|
||||
node_results_query_b = [MockScoredResult("node2", 0.87)]
|
||||
edge_results_query_a = [MockScoredResult("edge1", 0.92)]
|
||||
edge_results_query_b = []
|
||||
|
||||
def batch_search_side_effect(*args, **kwargs):
|
||||
collection_name = kwargs.get("collection_name")
|
||||
if collection_name == "EdgeType_relationship_name":
|
||||
return [edge_results_query_a, edge_results_query_b]
|
||||
elif collection_name == "Entity_name":
|
||||
return [node_results_query_a, node_results_query_b]
|
||||
elif collection_name == "MissingCollection":
|
||||
raise CollectionNotFoundError("Collection not found")
|
||||
return [[], []]
|
||||
|
||||
mock_vector_engine.batch_search = AsyncMock(side_effect=batch_search_side_effect)
|
||||
|
||||
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
|
||||
collections = [
|
||||
"Entity_name",
|
||||
"EdgeType_relationship_name",
|
||||
"MissingCollection",
|
||||
"EmptyCollection",
|
||||
]
|
||||
|
||||
await vector_search.embed_and_retrieve_distances(
|
||||
query=None, query_batch=query_batch, collections=collections, wide_search_limit=None
|
||||
)
|
||||
|
||||
assert vector_search.query_list_length == 2
|
||||
assert len(vector_search.edge_distances) == 2
|
||||
assert vector_search.edge_distances[0] == edge_results_query_a
|
||||
assert vector_search.edge_distances[1] == edge_results_query_b
|
||||
assert len(vector_search.node_distances["Entity_name"]) == 2
|
||||
assert vector_search.node_distances["Entity_name"][0] == node_results_query_a
|
||||
assert vector_search.node_distances["Entity_name"][1] == node_results_query_b
|
||||
assert len(vector_search.node_distances["MissingCollection"]) == 2
|
||||
assert vector_search.node_distances["MissingCollection"] == [[], []]
|
||||
assert len(vector_search.node_distances["EmptyCollection"]) == 2
|
||||
assert vector_search.node_distances["EmptyCollection"] == [[], []]
|
||||
mock_vector_engine.embedding_engine.embed_text.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_node_edge_vector_search_input_validation_both_provided():
|
||||
"""Test that providing both query and query_batch raises ValueError."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
|
||||
collections = ["Entity_name"]
|
||||
|
||||
with pytest.raises(ValueError, match="Cannot provide both 'query' and 'query_batch'"):
|
||||
await vector_search.embed_and_retrieve_distances(
|
||||
query="test", query_batch=["test1", "test2"], collections=collections
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_node_edge_vector_search_input_validation_neither_provided():
|
||||
"""Test that providing neither query nor query_batch raises ValueError."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
|
||||
collections = ["Entity_name"]
|
||||
|
||||
with pytest.raises(ValueError, match="Must provide either 'query' or 'query_batch'"):
|
||||
await vector_search.embed_and_retrieve_distances(
|
||||
query=None, query_batch=None, collections=collections
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_node_edge_vector_search_extract_relevant_node_ids_single_query():
|
||||
"""Test that extract_relevant_node_ids returns IDs for single query mode."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
|
||||
vector_search.query_list_length = None
|
||||
vector_search.node_distances = {
|
||||
"Entity_name": [MockScoredResult("node1", 0.95), MockScoredResult("node2", 0.87)],
|
||||
"TextSummary_text": [MockScoredResult("node1", 0.90), MockScoredResult("node3", 0.92)],
|
||||
}
|
||||
|
||||
node_ids = vector_search.extract_relevant_node_ids()
|
||||
assert set(node_ids) == {"node1", "node2", "node3"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_node_edge_vector_search_extract_relevant_node_ids_batch():
|
||||
"""Test that extract_relevant_node_ids returns empty list for batch mode."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
|
||||
vector_search.query_list_length = 2
|
||||
vector_search.node_distances = {
|
||||
"Entity_name": [
|
||||
[MockScoredResult("node1", 0.95)],
|
||||
[MockScoredResult("node2", 0.87)],
|
||||
],
|
||||
}
|
||||
|
||||
node_ids = vector_search.extract_relevant_node_ids()
|
||||
assert node_ids == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_node_edge_vector_search_has_results_single_query():
|
||||
"""Test has_results returns True when results exist and False when only empties."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
|
||||
|
||||
vector_search.edge_distances = [MockScoredResult("edge1", 0.92)]
|
||||
vector_search.node_distances = {}
|
||||
assert vector_search.has_results() is True
|
||||
|
||||
vector_search.edge_distances = []
|
||||
vector_search.node_distances = {"Entity_name": [MockScoredResult("node1", 0.95)]}
|
||||
assert vector_search.has_results() is True
|
||||
|
||||
vector_search.edge_distances = []
|
||||
vector_search.node_distances = {}
|
||||
assert vector_search.has_results() is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_node_edge_vector_search_has_results_batch():
|
||||
"""Test has_results works correctly for batch mode with list-of-lists."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
|
||||
vector_search.query_list_length = 2
|
||||
|
||||
vector_search.edge_distances = [[MockScoredResult("edge1", 0.92)], []]
|
||||
vector_search.node_distances = {}
|
||||
assert vector_search.has_results() is True
|
||||
|
||||
vector_search.edge_distances = [[], []]
|
||||
vector_search.node_distances = {
|
||||
"Entity_name": [[MockScoredResult("node1", 0.95)], []],
|
||||
}
|
||||
assert vector_search.has_results() is True
|
||||
|
||||
vector_search.edge_distances = [[], []]
|
||||
vector_search.node_distances = {"Entity_name": [[], []]}
|
||||
assert vector_search.has_results() is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_node_edge_vector_search_single_query_collection_not_found():
|
||||
"""Test that CollectionNotFoundError in single query mode returns empty list."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(
|
||||
side_effect=CollectionNotFoundError("Collection not found")
|
||||
)
|
||||
|
||||
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
|
||||
collections = ["MissingCollection"]
|
||||
|
||||
await vector_search.embed_and_retrieve_distances(
|
||||
query="test query", query_batch=None, collections=collections, wide_search_limit=10
|
||||
)
|
||||
|
||||
assert vector_search.node_distances["MissingCollection"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_node_edge_vector_search_missing_collections_single_query():
|
||||
"""Test that missing collections in single-query mode are handled gracefully with empty lists."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
|
||||
node_result = MockScoredResult("node1", 0.95)
|
||||
|
||||
def search_side_effect(*args, **kwargs):
|
||||
collection_name = kwargs.get("collection_name")
|
||||
if collection_name == "Entity_name":
|
||||
return [node_result]
|
||||
elif collection_name == "MissingCollection":
|
||||
raise CollectionNotFoundError("Collection not found")
|
||||
return []
|
||||
|
||||
mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
|
||||
|
||||
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
|
||||
collections = ["Entity_name", "MissingCollection", "EmptyCollection"]
|
||||
|
||||
await vector_search.embed_and_retrieve_distances(
|
||||
query="test query", query_batch=None, collections=collections, wide_search_limit=10
|
||||
)
|
||||
|
||||
assert len(vector_search.node_distances["Entity_name"]) == 1
|
||||
assert vector_search.node_distances["Entity_name"][0].id == "node1"
|
||||
assert vector_search.node_distances["Entity_name"][0].score == 0.95
|
||||
assert vector_search.node_distances["MissingCollection"] == []
|
||||
assert vector_search.node_distances["EmptyCollection"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_node_edge_vector_search_has_results_batch_nodes_only():
|
||||
"""Test has_results returns True when only node distances are populated in batch mode."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
|
||||
vector_search.query_list_length = 2
|
||||
vector_search.edge_distances = [[], []]
|
||||
vector_search.node_distances = {
|
||||
"Entity_name": [[MockScoredResult("node1", 0.95)], []],
|
||||
}
|
||||
|
||||
assert vector_search.has_results() is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_node_edge_vector_search_has_results_batch_edges_only():
|
||||
"""Test has_results returns True when only edge distances are populated in batch mode."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
|
||||
vector_search.query_list_length = 2
|
||||
vector_search.edge_distances = [[MockScoredResult("edge1", 0.92)], []]
|
||||
vector_search.node_distances = {}
|
||||
|
||||
assert vector_search.has_results() is True
|
||||
Loading…
Add table
Reference in a new issue