refactor: minor tweaks
This commit is contained in:
parent
c79af6c8cc
commit
fad75e21c1
2 changed files with 24 additions and 7 deletions
|
|
@ -3,6 +3,7 @@ from typing import List, Optional, Type
|
||||||
from cognee.shared.logging_utils import get_logger, ERROR
|
from cognee.shared.logging_utils import get_logger, ERROR
|
||||||
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
|
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||||
from cognee.modules.retrieval.utils.node_edge_vector_search import NodeEdgeVectorSearch
|
from cognee.modules.retrieval.utils.node_edge_vector_search import NodeEdgeVectorSearch
|
||||||
|
|
@ -74,7 +75,11 @@ async def _get_top_triplet_importances(
|
||||||
) -> List[Edge]:
|
) -> List[Edge]:
|
||||||
"""Creates memory fragment (if needed), maps distances, and calculates top triplet importances."""
|
"""Creates memory fragment (if needed), maps distances, and calculates top triplet importances."""
|
||||||
if memory_fragment is None:
|
if memory_fragment is None:
|
||||||
relevant_node_ids = vector_search.extract_relevant_node_ids() if wide_search_limit else None
|
if wide_search_limit is None:
|
||||||
|
relevant_node_ids = None
|
||||||
|
else:
|
||||||
|
relevant_node_ids = vector_search.extract_relevant_node_ids()
|
||||||
|
|
||||||
memory_fragment = await get_memory_fragment(
|
memory_fragment = await get_memory_fragment(
|
||||||
properties_to_project=properties_to_project,
|
properties_to_project=properties_to_project,
|
||||||
node_type=node_type,
|
node_type=node_type,
|
||||||
|
|
@ -157,6 +162,8 @@ async def brute_force_triplet_search(
|
||||||
wide_search_limit,
|
wide_search_limit,
|
||||||
top_k,
|
top_k,
|
||||||
)
|
)
|
||||||
|
except CollectionNotFoundError:
|
||||||
|
return []
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Error during brute force search for query: %s. Error: %s",
|
"Error during brute force search for query: %s. Error: %s",
|
||||||
|
|
|
||||||
|
|
@ -12,12 +12,20 @@ logger = get_logger(level=ERROR)
|
||||||
class NodeEdgeVectorSearch:
|
class NodeEdgeVectorSearch:
|
||||||
"""Manages vector search and distance retrieval for graph nodes and edges."""
|
"""Manages vector search and distance retrieval for graph nodes and edges."""
|
||||||
|
|
||||||
def __init__(self, edge_collection: str = "EdgeType_relationship_name"):
|
def __init__(self, edge_collection: str = "EdgeType_relationship_name", vector_engine=None):
|
||||||
self.edge_collection = edge_collection
|
self.edge_collection = edge_collection
|
||||||
|
self.vector_engine = vector_engine or self._init_vector_engine()
|
||||||
self.query_vector: Optional[Any] = None
|
self.query_vector: Optional[Any] = None
|
||||||
self.node_distances: dict[str, list[Any]] = {}
|
self.node_distances: dict[str, list[Any]] = {}
|
||||||
self.edge_distances: Optional[list[Any]] = None
|
self.edge_distances: Optional[list[Any]] = None
|
||||||
|
|
||||||
|
def _init_vector_engine(self):
    """Return the engine produced by ``get_vector_engine()``.

    Raises:
        RuntimeError: if ``get_vector_engine()`` raises; the failure is
            logged first and the original exception is chained as the cause.
    """
    try:
        engine = get_vector_engine()
    except Exception as exc:
        logger.error("Failed to initialize vector engine: %s", exc)
        raise RuntimeError("Initialization error") from exc
    return engine
|
||||||
|
|
||||||
def has_results(self) -> bool:
|
def has_results(self) -> bool:
|
||||||
"""Checks if any collections returned results."""
|
"""Checks if any collections returned results."""
|
||||||
return bool(self.edge_distances) or any(self.node_distances.values())
|
return bool(self.edge_distances) or any(self.node_distances.values())
|
||||||
|
|
@ -42,18 +50,20 @@ class NodeEdgeVectorSearch:
|
||||||
}
|
}
|
||||||
return list(relevant_node_ids)
|
return list(relevant_node_ids)
|
||||||
|
|
||||||
|
async def _embed_query(self, query: str):
|
||||||
|
"""Embeds the query and stores the resulting vector."""
|
||||||
|
query_embeddings = await self.vector_engine.embedding_engine.embed_text([query])
|
||||||
|
self.query_vector = query_embeddings[0]
|
||||||
|
|
||||||
async def embed_and_retrieve_distances(
|
async def embed_and_retrieve_distances(
|
||||||
self, query: str, collections: List[str], wide_search_limit: Optional[int]
|
self, query: str, collections: List[str], wide_search_limit: Optional[int]
|
||||||
):
|
):
|
||||||
"""Embeds query and retrieves vector distances from all collections."""
|
"""Embeds query and retrieves vector distances from all collections."""
|
||||||
vector_engine = get_vector_engine()
|
await self._embed_query(query)
|
||||||
|
|
||||||
query_embeddings = await vector_engine.embedding_engine.embed_text([query])
|
|
||||||
self.query_vector = query_embeddings[0]
|
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
search_tasks = [
|
search_tasks = [
|
||||||
self._search_single_collection(vector_engine, wide_search_limit, collection)
|
self._search_single_collection(self.vector_engine, wide_search_limit, collection)
|
||||||
for collection in collections
|
for collection in collections
|
||||||
]
|
]
|
||||||
search_results = await asyncio.gather(*search_tasks)
|
search_results = await asyncio.gather(*search_tasks)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue