refactor: minor tweaks
This commit is contained in:
parent
c79af6c8cc
commit
fad75e21c1
2 changed files with 24 additions and 7 deletions
|
|
@@ -3,6 +3,7 @@ from typing import List, Optional, Type
|
|||
from cognee.shared.logging_utils import get_logger, ERROR
|
||||
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.modules.retrieval.utils.node_edge_vector_search import NodeEdgeVectorSearch
|
||||
|
|
@@ -74,7 +75,11 @@ async def _get_top_triplet_importances(
|
|||
) -> List[Edge]:
|
||||
"""Creates memory fragment (if needed), maps distances, and calculates top triplet importances."""
|
||||
if memory_fragment is None:
|
||||
relevant_node_ids = vector_search.extract_relevant_node_ids() if wide_search_limit else None
|
||||
if wide_search_limit is None:
|
||||
relevant_node_ids = None
|
||||
else:
|
||||
relevant_node_ids = vector_search.extract_relevant_node_ids()
|
||||
|
||||
memory_fragment = await get_memory_fragment(
|
||||
properties_to_project=properties_to_project,
|
||||
node_type=node_type,
|
||||
|
|
@@ -157,6 +162,8 @@ async def brute_force_triplet_search(
|
|||
wide_search_limit,
|
||||
top_k,
|
||||
)
|
||||
except CollectionNotFoundError:
|
||||
return []
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
"Error during brute force search for query: %s. Error: %s",
|
||||
|
|
|
|||
|
|
@@ -12,12 +12,20 @@ logger = get_logger(level=ERROR)
|
|||
class NodeEdgeVectorSearch:
|
||||
"""Manages vector search and distance retrieval for graph nodes and edges."""
|
||||
|
||||
def __init__(self, edge_collection: str = "EdgeType_relationship_name"):
|
||||
def __init__(self, edge_collection: str = "EdgeType_relationship_name", vector_engine=None):
|
||||
self.edge_collection = edge_collection
|
||||
self.vector_engine = vector_engine or self._init_vector_engine()
|
||||
self.query_vector: Optional[Any] = None
|
||||
self.node_distances: dict[str, list[Any]] = {}
|
||||
self.edge_distances: Optional[list[Any]] = None
|
||||
|
||||
def _init_vector_engine(self):
    """Build the default vector engine, wrapping any failure in RuntimeError."""
    try:
        engine = get_vector_engine()
    except Exception as exc:
        # Surface the underlying cause in the log, then re-raise as a
        # chained RuntimeError so callers see a uniform failure type.
        logger.error("Failed to initialize vector engine: %s", exc)
        raise RuntimeError("Initialization error") from exc
    return engine
|
||||
|
||||
def has_results(self) -> bool:
|
||||
"""Checks if any collections returned results."""
|
||||
return bool(self.edge_distances) or any(self.node_distances.values())
|
||||
|
|
@@ -42,18 +50,20 @@ class NodeEdgeVectorSearch:
|
|||
}
|
||||
return list(relevant_node_ids)
|
||||
|
||||
async def _embed_query(self, query: str):
|
||||
"""Embeds the query and stores the resulting vector."""
|
||||
query_embeddings = await self.vector_engine.embedding_engine.embed_text([query])
|
||||
self.query_vector = query_embeddings[0]
|
||||
|
||||
async def embed_and_retrieve_distances(
|
||||
self, query: str, collections: List[str], wide_search_limit: Optional[int]
|
||||
):
|
||||
"""Embeds query and retrieves vector distances from all collections."""
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
query_embeddings = await vector_engine.embedding_engine.embed_text([query])
|
||||
self.query_vector = query_embeddings[0]
|
||||
await self._embed_query(query)
|
||||
|
||||
start_time = time.time()
|
||||
search_tasks = [
|
||||
self._search_single_collection(vector_engine, wide_search_limit, collection)
|
||||
self._search_single_collection(self.vector_engine, wide_search_limit, collection)
|
||||
for collection in collections
|
||||
]
|
||||
search_results = await asyncio.gather(*search_tasks)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue