refactor: brute_force_triplet_search.py with context class

lxobr 2026-01-09 09:10:29 +01:00
parent f9cb490ad9
commit 876120853f


@@ -23,12 +23,10 @@ def format_triplets(edges):
         node1_attributes = node1.attributes
         node2_attributes = node2.attributes
-        # Filter only non-None properties
         node1_info = {key: value for key, value in node1_attributes.items() if value is not None}
         node2_info = {key: value for key, value in node2_attributes.items() if value is not None}
         edge_info = {key: value for key, value in edge_attributes.items() if value is not None}
-        # Create the formatted triplet
         triplet = f"Node1: {node1_info}\nEdge: {edge_info}\nNode2: {node2_info}\n\n\n"
         triplets.append(triplet)
@@ -67,8 +65,8 @@ async def get_memory_fragment(
     return memory_fragment
 
 
-class _BruteForceTripletSearchEngine:
-    """Internal search engine for brute force triplet search operations."""
+class TripletSearchContext:
+    """Pure state container for triplet search operations."""
 
     def __init__(
         self,
@@ -76,7 +74,6 @@ class _BruteForceTripletSearchEngine:
         top_k: int,
         collections: List[str],
         properties_to_project: Optional[List[str]],
-        memory_fragment: Optional[CogneeGraph],
         node_type: Optional[Type],
         node_name: Optional[List[str]],
         wide_search_limit: Optional[int],
@@ -86,92 +83,20 @@ class _BruteForceTripletSearchEngine:
         self.top_k = top_k
         self.collections = collections
         self.properties_to_project = properties_to_project
-        self.memory_fragment = memory_fragment
         self.node_type = node_type
         self.node_name = node_name
         self.wide_search_limit = wide_search_limit
         self.triplet_distance_penalty = triplet_distance_penalty
-        self.vector_engine = self._load_vector_engine()
         self.query_vector = None
         self.node_distances = None
         self.edge_distances = None
 
-    async def search(self) -> List[Edge]:
-        """Orchestrates the brute force triplet search workflow."""
-        await self._embed_query_text()
-        await self._retrieve_and_set_vector_distances()
-
-        if not (self.edge_distances or any(self.node_distances.values())):
-            return []
-
-        await self._ensure_memory_fragment_is_loaded()
-        await self._map_distances_to_memory_fragment()
-
-        return await self.memory_fragment.calculate_top_triplet_importances(k=self.top_k)
-
-    def _load_vector_engine(self):
-        """Loads the vector engine instance."""
-        try:
-            return get_vector_engine()
-        except Exception as e:
-            logger.error("Failed to initialize vector engine: %s", e)
-            raise RuntimeError("Initialization error") from e
-
-    async def _embed_query_text(self):
-        """Converts query text into embedding vector."""
-        query_embeddings = await self.vector_engine.embedding_engine.embed_text([self.query])
-        self.query_vector = query_embeddings[0]
-
-    async def _retrieve_and_set_vector_distances(self):
-        """Searches all collections in parallel and sets node/edge distances directly."""
-        start_time = time.time()
-        search_results = await self._run_parallel_collection_searches()
-        elapsed_time = time.time() - start_time
-
-        collections_with_results = sum(1 for result in search_results if result)
-        logger.info(
-            f"Vector collection retrieval completed: Retrieved distances from "
-            f"{collections_with_results} collections in {elapsed_time:.2f}s"
-        )
-
-        self.node_distances = {}
-        for collection, result in zip(self.collections, search_results):
-            if collection == "EdgeType_relationship_name":
-                self.edge_distances = result
-            else:
-                self.node_distances[collection] = result
-
-    async def _run_parallel_collection_searches(self) -> List[List[Any]]:
-        """Executes vector searches across all collections concurrently."""
-        search_tasks = [
-            self._search_single_collection(collection_name) for collection_name in self.collections
-        ]
-        return await asyncio.gather(*search_tasks)
-
-    async def _search_single_collection(self, collection_name: str):
-        """Searches one collection and returns results or empty list if not found."""
-        try:
-            return await self.vector_engine.search(
-                collection_name=collection_name,
-                query_vector=self.query_vector,
-                limit=self.wide_search_limit,
-            )
-        except CollectionNotFoundError:
-            return []
-
-    async def _ensure_memory_fragment_is_loaded(self):
-        """Loads memory fragment if not already provided."""
-        if self.memory_fragment is None:
-            relevant_node_ids = self._extract_relevant_node_ids_for_filtering()
-            self.memory_fragment = await get_memory_fragment(
-                properties_to_project=self.properties_to_project,
-                node_type=self.node_type,
-                node_name=self.node_name,
-                relevant_ids_to_filter=relevant_node_ids,
-                triplet_distance_penalty=self.triplet_distance_penalty,
-            )
-
-    def _extract_relevant_node_ids_for_filtering(self) -> Optional[List[str]]:
+    def has_results(self) -> bool:
+        """Checks if any collections returned results."""
+        return bool(self.edge_distances or any(self.node_distances.values()))
+
+    def extract_relevant_node_ids(self) -> Optional[List[str]]:
         """Extracts unique node IDs from search results to filter graph projection."""
         if self.wide_search_limit is None:
             return None
@@ -185,14 +110,76 @@ class _BruteForceTripletSearchEngine:
         }
         return list(relevant_node_ids)
 
-    async def _map_distances_to_memory_fragment(self):
-        """Maps vector distances to nodes and edges in the memory fragment."""
-        await self.memory_fragment.map_vector_distances_to_graph_nodes(
-            node_distances=self.node_distances
-        )
-        await self.memory_fragment.map_vector_distances_to_graph_edges(
-            edge_distances=self.edge_distances
-        )
+    def set_distances_from_results(self, search_results: List[List[Any]]):
+        """Separates search results into node and edge distances."""
+        self.node_distances = {}
+        for collection, result in zip(self.collections, search_results):
+            if collection == "EdgeType_relationship_name":
+                self.edge_distances = result
+            else:
+                self.node_distances[collection] = result
+
+
+async def _search_single_collection(
+    vector_engine: Any, search_context: TripletSearchContext, collection_name: str
+):
+    """Searches one collection and returns results or empty list if not found."""
+    try:
+        return await vector_engine.search(
+            collection_name=collection_name,
+            query_vector=search_context.query_vector,
+            limit=search_context.wide_search_limit,
+        )
+    except CollectionNotFoundError:
+        return []
+
+
+async def _embed_and_retrieve_distances(search_context: TripletSearchContext):
+    """Embeds query and retrieves vector distances from all collections."""
+    vector_engine = get_vector_engine()
+
+    query_embeddings = await vector_engine.embedding_engine.embed_text([search_context.query])
+    search_context.query_vector = query_embeddings[0]
+
+    start_time = time.time()
+    search_tasks = [
+        _search_single_collection(vector_engine, search_context, collection)
+        for collection in search_context.collections
+    ]
+    search_results = await asyncio.gather(*search_tasks)
+    elapsed_time = time.time() - start_time
+
+    collections_with_results = sum(1 for result in search_results if result)
+    logger.info(
+        f"Vector collection retrieval completed: Retrieved distances from "
+        f"{collections_with_results} collections in {elapsed_time:.2f}s"
+    )
+
+    search_context.set_distances_from_results(search_results)
+
+
+async def _create_memory_fragment(search_context: TripletSearchContext) -> CogneeGraph:
+    """Creates memory fragment using search context properties."""
+    relevant_node_ids = search_context.extract_relevant_node_ids()
+    return await get_memory_fragment(
+        properties_to_project=search_context.properties_to_project,
+        node_type=search_context.node_type,
+        node_name=search_context.node_name,
+        relevant_ids_to_filter=relevant_node_ids,
+        triplet_distance_penalty=search_context.triplet_distance_penalty,
+    )
+
+
+async def _map_distances_to_fragment(
+    search_context: TripletSearchContext, memory_fragment: CogneeGraph
+):
+    """Maps vector distances from search context to memory fragment."""
+    await memory_fragment.map_vector_distances_to_graph_nodes(
+        node_distances=search_context.node_distances
+    )
+    await memory_fragment.map_vector_distances_to_graph_edges(
+        edge_distances=search_context.edge_distances
+    )
 
 
 async def brute_force_triplet_search(
@@ -228,9 +215,7 @@ async def brute_force_triplet_search(
     if top_k <= 0:
         raise ValueError("top_k must be a positive integer.")
 
-    # Setting wide search limit based on the parameters
-    non_global_search = node_name is None
-    wide_search_limit = wide_search_top_k if non_global_search else None
+    wide_search_limit = wide_search_top_k if node_name is None else None
 
     if collections is None:
         collections = [
@@ -244,18 +229,28 @@ async def brute_force_triplet_search(
         collections.append("EdgeType_relationship_name")
 
     try:
-        engine = _BruteForceTripletSearchEngine(
+        search_context = TripletSearchContext(
             query=query,
             top_k=top_k,
             collections=collections,
             properties_to_project=properties_to_project,
-            memory_fragment=memory_fragment,
             node_type=node_type,
             node_name=node_name,
             wide_search_limit=wide_search_limit,
             triplet_distance_penalty=triplet_distance_penalty,
         )
-        return await engine.search()
+
+        await _embed_and_retrieve_distances(search_context)
+
+        if not search_context.has_results():
+            return []
+
+        if memory_fragment is None:
+            memory_fragment = await _create_memory_fragment(search_context)
+
+        await _map_distances_to_fragment(search_context, memory_fragment)
+
+        return await memory_fragment.calculate_top_triplet_importances(k=search_context.top_k)
     except Exception as error:
         logger.error(
             "Error during brute force search for query: %s. Error: %s",