refactor: minor tweaks

This commit is contained in:
lxobr 2026-01-09 14:53:47 +01:00
parent c79af6c8cc
commit fad75e21c1
2 changed files with 24 additions and 7 deletions

View file

@ -3,6 +3,7 @@ from typing import List, Optional, Type
from cognee.shared.logging_utils import get_logger, ERROR
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.retrieval.utils.node_edge_vector_search import NodeEdgeVectorSearch
@ -74,7 +75,11 @@ async def _get_top_triplet_importances(
) -> List[Edge]:
"""Creates memory fragment (if needed), maps distances, and calculates top triplet importances."""
if memory_fragment is None:
relevant_node_ids = vector_search.extract_relevant_node_ids() if wide_search_limit else None
if wide_search_limit is None:
relevant_node_ids = None
else:
relevant_node_ids = vector_search.extract_relevant_node_ids()
memory_fragment = await get_memory_fragment(
properties_to_project=properties_to_project,
node_type=node_type,
@ -157,6 +162,8 @@ async def brute_force_triplet_search(
wide_search_limit,
top_k,
)
except CollectionNotFoundError:
return []
except Exception as error:
logger.error(
"Error during brute force search for query: %s. Error: %s",

View file

@ -12,12 +12,20 @@ logger = get_logger(level=ERROR)
class NodeEdgeVectorSearch:
"""Manages vector search and distance retrieval for graph nodes and edges."""
def __init__(self, edge_collection: str = "EdgeType_relationship_name"):
def __init__(self, edge_collection: str = "EdgeType_relationship_name", vector_engine=None):
self.edge_collection = edge_collection
self.vector_engine = vector_engine or self._init_vector_engine()
self.query_vector: Optional[Any] = None
self.node_distances: dict[str, list[Any]] = {}
self.edge_distances: Optional[list[Any]] = None
def _init_vector_engine(self):
try:
return get_vector_engine()
except Exception as e:
logger.error("Failed to initialize vector engine: %s", e)
raise RuntimeError("Initialization error") from e
def has_results(self) -> bool:
"""Checks if any collections returned results."""
return bool(self.edge_distances) or any(self.node_distances.values())
@ -42,18 +50,20 @@ class NodeEdgeVectorSearch:
}
return list(relevant_node_ids)
async def _embed_query(self, query: str):
"""Embeds the query and stores the resulting vector."""
query_embeddings = await self.vector_engine.embedding_engine.embed_text([query])
self.query_vector = query_embeddings[0]
async def embed_and_retrieve_distances(
self, query: str, collections: List[str], wide_search_limit: Optional[int]
):
"""Embeds query and retrieves vector distances from all collections."""
vector_engine = get_vector_engine()
query_embeddings = await vector_engine.embedding_engine.embed_text([query])
self.query_vector = query_embeddings[0]
await self._embed_query(query)
start_time = time.time()
search_tasks = [
self._search_single_collection(vector_engine, wide_search_limit, collection)
self._search_single_collection(self.vector_engine, wide_search_limit, collection)
for collection in collections
]
search_results = await asyncio.gather(*search_tasks)