refactor: brute_force_triplet_search.py
parent af72dd2fc2
commit f9cb490ad9
1 changed file with 141 additions and 75 deletions
@@ -1,6 +1,6 @@
 import asyncio
 import time
-from typing import List, Optional, Type
+from typing import Any, List, Optional, Type
 
 from cognee.shared.logging_utils import get_logger, ERROR
 from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
@@ -9,13 +9,12 @@ from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
+from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
 from cognee.modules.users.models import User
 from cognee.shared.utils import send_telemetry
 
 logger = get_logger(level=ERROR)
 
 
 def format_triplets(edges):
     """Formats edges into human-readable triplet strings."""
     triplets = []
     for edge in edges:
         node1 = edge.node1
@@ -51,7 +50,6 @@ async def get_memory_fragment(
 
     try:
         graph_engine = await get_graph_engine()
-
         await memory_fragment.project_graph_from_db(
             graph_engine,
             node_properties_to_project=properties_to_project,
@@ -61,18 +59,142 @@ async def get_memory_fragment(
             relevant_ids_to_filter=relevant_ids_to_filter,
             triplet_distance_penalty=triplet_distance_penalty,
         )
 
     except EntityNotFoundError:
         # This is expected behavior - continue with empty fragment
         pass
     except Exception as e:
         logger.error(f"Error during memory fragment creation: {str(e)}")
         # Still return the fragment even if projection failed
         pass
 
     return memory_fragment
 
 
+class _BruteForceTripletSearchEngine:
+    """Internal search engine for brute force triplet search operations."""
+
+    def __init__(
+        self,
+        query: str,
+        top_k: int,
+        collections: List[str],
+        properties_to_project: Optional[List[str]],
+        memory_fragment: Optional[CogneeGraph],
+        node_type: Optional[Type],
+        node_name: Optional[List[str]],
+        wide_search_limit: Optional[int],
+        triplet_distance_penalty: float,
+    ):
+        self.query = query
+        self.top_k = top_k
+        self.collections = collections
+        self.properties_to_project = properties_to_project
+        self.memory_fragment = memory_fragment
+        self.node_type = node_type
+        self.node_name = node_name
+        self.wide_search_limit = wide_search_limit
+        self.triplet_distance_penalty = triplet_distance_penalty
+        self.vector_engine = self._load_vector_engine()
+        self.query_vector = None
+        self.node_distances = None
+        self.edge_distances = None
+
+    async def search(self) -> List[Edge]:
+        """Orchestrates the brute force triplet search workflow."""
+        await self._embed_query_text()
+        await self._retrieve_and_set_vector_distances()
+
+        if not (self.edge_distances or any(self.node_distances.values())):
+            return []
+
+        await self._ensure_memory_fragment_is_loaded()
+        await self._map_distances_to_memory_fragment()
+
+        return await self.memory_fragment.calculate_top_triplet_importances(k=self.top_k)
+
+    def _load_vector_engine(self):
+        """Loads the vector engine instance."""
+        try:
+            return get_vector_engine()
+        except Exception as e:
+            logger.error("Failed to initialize vector engine: %s", e)
+            raise RuntimeError("Initialization error") from e
+
+    async def _embed_query_text(self):
+        """Converts query text into embedding vector."""
+        query_embeddings = await self.vector_engine.embedding_engine.embed_text([self.query])
+        self.query_vector = query_embeddings[0]
+
+    async def _retrieve_and_set_vector_distances(self):
+        """Searches all collections in parallel and sets node/edge distances directly."""
+        start_time = time.time()
+        search_results = await self._run_parallel_collection_searches()
+        elapsed_time = time.time() - start_time
+
+        collections_with_results = sum(1 for result in search_results if result)
+        logger.info(
+            f"Vector collection retrieval completed: Retrieved distances from "
+            f"{collections_with_results} collections in {elapsed_time:.2f}s"
+        )
+
+        self.node_distances = {}
+        for collection, result in zip(self.collections, search_results):
+            if collection == "EdgeType_relationship_name":
+                self.edge_distances = result
+            else:
+                self.node_distances[collection] = result
+
+    async def _run_parallel_collection_searches(self) -> List[List[Any]]:
+        """Executes vector searches across all collections concurrently."""
+        search_tasks = [
+            self._search_single_collection(collection_name) for collection_name in self.collections
+        ]
+        return await asyncio.gather(*search_tasks)
+
+    async def _search_single_collection(self, collection_name: str):
+        """Searches one collection and returns results or empty list if not found."""
+        try:
+            return await self.vector_engine.search(
+                collection_name=collection_name,
+                query_vector=self.query_vector,
+                limit=self.wide_search_limit,
+            )
+        except CollectionNotFoundError:
+            return []
+
+    async def _ensure_memory_fragment_is_loaded(self):
+        """Loads memory fragment if not already provided."""
+        if self.memory_fragment is None:
+            relevant_node_ids = self._extract_relevant_node_ids_for_filtering()
+            self.memory_fragment = await get_memory_fragment(
+                properties_to_project=self.properties_to_project,
+                node_type=self.node_type,
+                node_name=self.node_name,
+                relevant_ids_to_filter=relevant_node_ids,
+                triplet_distance_penalty=self.triplet_distance_penalty,
+            )
+
+    def _extract_relevant_node_ids_for_filtering(self) -> Optional[List[str]]:
+        """Extracts unique node IDs from search results to filter graph projection."""
+        if self.wide_search_limit is None:
+            return None
+
+        relevant_node_ids = {
+            str(getattr(scored_node, "id"))
+            for score_collection in self.node_distances.values()
+            if isinstance(score_collection, (list, tuple))
+            for scored_node in score_collection
+            if getattr(scored_node, "id", None)
+        }
+        return list(relevant_node_ids)
+
+    async def _map_distances_to_memory_fragment(self):
+        """Maps vector distances to nodes and edges in the memory fragment."""
+        await self.memory_fragment.map_vector_distances_to_graph_nodes(
+            node_distances=self.node_distances
+        )
+        await self.memory_fragment.map_vector_distances_to_graph_edges(
+            edge_distances=self.edge_distances
+        )
+
+
 async def brute_force_triplet_search(
     query: str,
     top_k: int = 5,
@@ -108,7 +230,6 @@ async def brute_force_triplet_search(
 
     # Setting wide search limit based on the parameters
     non_global_search = node_name is None
-
     wide_search_limit = wide_search_top_k if non_global_search else None
 
     if collections is None:
@@ -123,73 +244,18 @@ async def brute_force_triplet_search(
         collections.append("EdgeType_relationship_name")
 
-    try:
-        vector_engine = get_vector_engine()
-    except Exception as e:
-        logger.error("Failed to initialize vector engine: %s", e)
-        raise RuntimeError("Initialization error") from e
-
-    query_vector = (await vector_engine.embedding_engine.embed_text([query]))[0]
-
-    async def search_in_collection(collection_name: str):
-        try:
-            return await vector_engine.search(
-                collection_name=collection_name, query_vector=query_vector, limit=wide_search_limit
-            )
-        except CollectionNotFoundError:
-            return []
-
     try:
-        start_time = time.time()
-
-        results = await asyncio.gather(
-            *[search_in_collection(collection_name) for collection_name in collections]
+        engine = _BruteForceTripletSearchEngine(
+            query=query,
+            top_k=top_k,
+            collections=collections,
+            properties_to_project=properties_to_project,
+            memory_fragment=memory_fragment,
+            node_type=node_type,
+            node_name=node_name,
+            wide_search_limit=wide_search_limit,
+            triplet_distance_penalty=triplet_distance_penalty,
         )
-
-        if all(not item for item in results):
-            return []
-
-        # Final statistics
-        vector_collection_search_time = time.time() - start_time
-        logger.info(
-            f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {vector_collection_search_time:.2f}s"
-        )
-
-        node_distances = {collection: result for collection, result in zip(collections, results)}
-
-        edge_distances = node_distances.get("EdgeType_relationship_name", None)
-
-        if wide_search_limit is not None:
-            relevant_ids_to_filter = list(
-                {
-                    str(getattr(scored_node, "id"))
-                    for collection_name, score_collection in node_distances.items()
-                    if collection_name != "EdgeType_relationship_name"
-                    and isinstance(score_collection, (list, tuple))
-                    for scored_node in score_collection
-                    if getattr(scored_node, "id", None)
-                }
-            )
-        else:
-            relevant_ids_to_filter = None
-
-        if memory_fragment is None:
-            memory_fragment = await get_memory_fragment(
-                properties_to_project=properties_to_project,
-                node_type=node_type,
-                node_name=node_name,
-                relevant_ids_to_filter=relevant_ids_to_filter,
-                triplet_distance_penalty=triplet_distance_penalty,
-            )
-
-        await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
-        await memory_fragment.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
-
-        results = await memory_fragment.calculate_top_triplet_importances(k=top_k)
-
-        return results
-
-    except CollectionNotFoundError:
-        return []
+        return await engine.search()
     except Exception as error:
         logger.error(
             "Error during brute force search for query: %s. Error: %s",
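
The concurrency pattern behind _run_parallel_collection_searches (one task per collection, asyncio.gather, missing collections degrading to empty results instead of failing the whole batch) can also be shown standalone. This is a distilled sketch, not cognee's API: search_collection and CollectionNotFound here are stand-ins for vector_engine.search and CollectionNotFoundError.

import asyncio
from typing import Any, Dict, List


class CollectionNotFound(Exception):
    """Stand-in for cognee's CollectionNotFoundError."""


async def search_collection(name: str) -> List[Any]:
    # Stand-in searcher; a real implementation would query a vector store.
    if name == "missing_collection":
        raise CollectionNotFound(name)
    return [f"hit from {name}"]


async def safe_search(name: str) -> List[Any]:
    # Per-task error isolation: a missing collection yields [] rather than
    # cancelling the whole gather.
    try:
        return await search_collection(name)
    except CollectionNotFound:
        return []


async def main():
    collections = ["Entity_name", "missing_collection", "EdgeType_relationship_name"]
    results = await asyncio.gather(*(safe_search(name) for name in collections))
    distances: Dict[str, List[Any]] = dict(zip(collections, results))
    print(distances)


asyncio.run(main())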