diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py
index ef805a127..3c3603f01 100644
--- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py
+++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py
@@ -3,6 +3,7 @@ from typing import List, Optional, Type
 from cognee.shared.logging_utils import get_logger, ERROR
 from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
 from cognee.infrastructure.databases.graph import get_graph_engine
+from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
 from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
 from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
 from cognee.modules.retrieval.utils.node_edge_vector_search import NodeEdgeVectorSearch
@@ -74,7 +75,11 @@ async def _get_top_triplet_importances(
 ) -> List[Edge]:
     """Creates memory fragment (if needed), maps distances, and calculates top triplet importances."""
     if memory_fragment is None:
-        relevant_node_ids = vector_search.extract_relevant_node_ids() if wide_search_limit else None
+        if wide_search_limit is None:
+            relevant_node_ids = None
+        else:
+            relevant_node_ids = vector_search.extract_relevant_node_ids()
+
         memory_fragment = await get_memory_fragment(
             properties_to_project=properties_to_project,
             node_type=node_type,
@@ -157,6 +162,8 @@ async def brute_force_triplet_search(
             wide_search_limit,
             top_k,
         )
+    except CollectionNotFoundError:
+        return []
     except Exception as error:
         logger.error(
             "Error during brute force search for query: %s. Error: %s",
diff --git a/cognee/modules/retrieval/utils/node_edge_vector_search.py b/cognee/modules/retrieval/utils/node_edge_vector_search.py
index 777751cf2..08f76218c 100644
--- a/cognee/modules/retrieval/utils/node_edge_vector_search.py
+++ b/cognee/modules/retrieval/utils/node_edge_vector_search.py
@@ -12,12 +12,20 @@ logger = get_logger(level=ERROR)
 class NodeEdgeVectorSearch:
     """Manages vector search and distance retrieval for graph nodes and edges."""
 
-    def __init__(self, edge_collection: str = "EdgeType_relationship_name"):
+    def __init__(self, edge_collection: str = "EdgeType_relationship_name", vector_engine=None):
         self.edge_collection = edge_collection
+        self.vector_engine = vector_engine or self._init_vector_engine()
         self.query_vector: Optional[Any] = None
         self.node_distances: dict[str, list[Any]] = {}
         self.edge_distances: Optional[list[Any]] = None
 
+    def _init_vector_engine(self):
+        try:
+            return get_vector_engine()
+        except Exception as e:
+            logger.error("Failed to initialize vector engine: %s", e)
+            raise RuntimeError("Initialization error") from e
+
     def has_results(self) -> bool:
         """Checks if any collections returned results."""
         return bool(self.edge_distances) or any(self.node_distances.values())
@@ -42,18 +50,20 @@ class NodeEdgeVectorSearch:
         }
         return list(relevant_node_ids)
 
+    async def _embed_query(self, query: str):
+        """Embeds the query and stores the resulting vector."""
+        query_embeddings = await self.vector_engine.embedding_engine.embed_text([query])
+        self.query_vector = query_embeddings[0]
+
     async def embed_and_retrieve_distances(
         self, query: str, collections: List[str], wide_search_limit: Optional[int]
     ):
         """Embeds query and retrieves vector distances from all collections."""
-        vector_engine = get_vector_engine()
-
-        query_embeddings = await vector_engine.embedding_engine.embed_text([query])
-        self.query_vector = query_embeddings[0]
+        await self._embed_query(query)
 
         start_time = time.time()
         search_tasks = [
-            self._search_single_collection(vector_engine, wide_search_limit, collection)
+            self._search_single_collection(self.vector_engine, wide_search_limit, collection)
             for collection in collections
         ]
         search_results = await asyncio.gather(*search_tasks)
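
For context (not part of the patch): a minimal sketch of how the new injectable vector_engine parameter could be exercised in a test, assuming only the constructor signature and the _embed_query behavior introduced above. FakeVectorEngine and FakeEmbeddingEngine are hypothetical test doubles, not part of cognee.

```python
import asyncio

from cognee.modules.retrieval.utils.node_edge_vector_search import NodeEdgeVectorSearch


class FakeEmbeddingEngine:
    async def embed_text(self, texts):
        # One dummy vector per input text.
        return [[0.0, 0.0, 0.0] for _ in texts]


class FakeVectorEngine:
    # Only the attribute used by _embed_query is stubbed here.
    embedding_engine = FakeEmbeddingEngine()


async def main():
    # Injecting the fake engine bypasses _init_vector_engine entirely,
    # so no real vector database configuration is required.
    search = NodeEdgeVectorSearch(vector_engine=FakeVectorEngine())
    await search._embed_query("example query")
    assert search.query_vector == [0.0, 0.0, 0.0]


asyncio.run(main())
```

A side effect of the refactor worth noting: the engine is now resolved once in the constructor and reused by both _embed_query and _search_single_collection, rather than being fetched via get_vector_engine() on every call to embed_and_retrieve_distances.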