refactor: brute_force_triplet_search.py and node_edge_vector_search.py
This commit is contained in:
parent
876120853f
commit
c79af6c8cc
2 changed files with 119 additions and 132 deletions
|
|
@ -1,14 +1,11 @@
|
||||||
import asyncio
|
from typing import List, Optional, Type
|
||||||
import time
|
|
||||||
from typing import Any, List, Optional, Type
|
|
||||||
|
|
||||||
from cognee.shared.logging_utils import get_logger, ERROR
|
from cognee.shared.logging_utils import get_logger, ERROR
|
||||||
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
|
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
|
||||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
|
||||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||||
|
from cognee.modules.retrieval.utils.node_edge_vector_search import NodeEdgeVectorSearch
|
||||||
|
|
||||||
logger = get_logger(level=ERROR)
|
logger = get_logger(level=ERROR)
|
||||||
|
|
||||||
|
|
@ -65,122 +62,36 @@ async def get_memory_fragment(
|
||||||
return memory_fragment
|
return memory_fragment
|
||||||
|
|
||||||
|
|
||||||
class TripletSearchContext:
|
async def _get_top_triplet_importances(
|
||||||
"""Pure state container for triplet search operations."""
|
memory_fragment: Optional[CogneeGraph],
|
||||||
|
vector_search: NodeEdgeVectorSearch,
|
||||||
def __init__(
|
properties_to_project: Optional[List[str]],
|
||||||
self,
|
node_type: Optional[Type],
|
||||||
query: str,
|
node_name: Optional[List[str]],
|
||||||
top_k: int,
|
triplet_distance_penalty: float,
|
||||||
collections: List[str],
|
wide_search_limit: Optional[int],
|
||||||
properties_to_project: Optional[List[str]],
|
top_k: int,
|
||||||
node_type: Optional[Type],
|
) -> List[Edge]:
|
||||||
node_name: Optional[List[str]],
|
"""Creates memory fragment (if needed), maps distances, and calculates top triplet importances."""
|
||||||
wide_search_limit: Optional[int],
|
if memory_fragment is None:
|
||||||
triplet_distance_penalty: float,
|
relevant_node_ids = vector_search.extract_relevant_node_ids() if wide_search_limit else None
|
||||||
):
|
memory_fragment = await get_memory_fragment(
|
||||||
self.query = query
|
properties_to_project=properties_to_project,
|
||||||
self.top_k = top_k
|
node_type=node_type,
|
||||||
self.collections = collections
|
node_name=node_name,
|
||||||
self.properties_to_project = properties_to_project
|
relevant_ids_to_filter=relevant_node_ids,
|
||||||
self.node_type = node_type
|
triplet_distance_penalty=triplet_distance_penalty,
|
||||||
self.node_name = node_name
|
|
||||||
self.wide_search_limit = wide_search_limit
|
|
||||||
self.triplet_distance_penalty = triplet_distance_penalty
|
|
||||||
|
|
||||||
self.query_vector = None
|
|
||||||
self.node_distances = None
|
|
||||||
self.edge_distances = None
|
|
||||||
|
|
||||||
def has_results(self) -> bool:
|
|
||||||
"""Checks if any collections returned results."""
|
|
||||||
return bool(self.edge_distances or any(self.node_distances.values()))
|
|
||||||
|
|
||||||
def extract_relevant_node_ids(self) -> Optional[List[str]]:
|
|
||||||
"""Extracts unique node IDs from search results to filter graph projection."""
|
|
||||||
if self.wide_search_limit is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
relevant_node_ids = {
|
|
||||||
str(getattr(scored_node, "id"))
|
|
||||||
for score_collection in self.node_distances.values()
|
|
||||||
if isinstance(score_collection, (list, tuple))
|
|
||||||
for scored_node in score_collection
|
|
||||||
if getattr(scored_node, "id", None)
|
|
||||||
}
|
|
||||||
return list(relevant_node_ids)
|
|
||||||
|
|
||||||
def set_distances_from_results(self, search_results: List[List[Any]]):
|
|
||||||
"""Separates search results into node and edge distances."""
|
|
||||||
self.node_distances = {}
|
|
||||||
for collection, result in zip(self.collections, search_results):
|
|
||||||
if collection == "EdgeType_relationship_name":
|
|
||||||
self.edge_distances = result
|
|
||||||
else:
|
|
||||||
self.node_distances[collection] = result
|
|
||||||
|
|
||||||
|
|
||||||
async def _search_single_collection(
|
|
||||||
vector_engine: Any, search_context: TripletSearchContext, collection_name: str
|
|
||||||
):
|
|
||||||
"""Searches one collection and returns results or empty list if not found."""
|
|
||||||
try:
|
|
||||||
return await vector_engine.search(
|
|
||||||
collection_name=collection_name,
|
|
||||||
query_vector=search_context.query_vector,
|
|
||||||
limit=search_context.wide_search_limit,
|
|
||||||
)
|
)
|
||||||
except CollectionNotFoundError:
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
async def _embed_and_retrieve_distances(search_context: TripletSearchContext):
|
|
||||||
"""Embeds query and retrieves vector distances from all collections."""
|
|
||||||
vector_engine = get_vector_engine()
|
|
||||||
|
|
||||||
query_embeddings = await vector_engine.embedding_engine.embed_text([search_context.query])
|
|
||||||
search_context.query_vector = query_embeddings[0]
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
search_tasks = [
|
|
||||||
_search_single_collection(vector_engine, search_context, collection)
|
|
||||||
for collection in search_context.collections
|
|
||||||
]
|
|
||||||
search_results = await asyncio.gather(*search_tasks)
|
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
|
||||||
collections_with_results = sum(1 for result in search_results if result)
|
|
||||||
logger.info(
|
|
||||||
f"Vector collection retrieval completed: Retrieved distances from "
|
|
||||||
f"{collections_with_results} collections in {elapsed_time:.2f}s"
|
|
||||||
)
|
|
||||||
|
|
||||||
search_context.set_distances_from_results(search_results)
|
|
||||||
|
|
||||||
|
|
||||||
async def _create_memory_fragment(search_context: TripletSearchContext) -> CogneeGraph:
|
|
||||||
"""Creates memory fragment using search context properties."""
|
|
||||||
relevant_node_ids = search_context.extract_relevant_node_ids()
|
|
||||||
return await get_memory_fragment(
|
|
||||||
properties_to_project=search_context.properties_to_project,
|
|
||||||
node_type=search_context.node_type,
|
|
||||||
node_name=search_context.node_name,
|
|
||||||
relevant_ids_to_filter=relevant_node_ids,
|
|
||||||
triplet_distance_penalty=search_context.triplet_distance_penalty,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _map_distances_to_fragment(
|
|
||||||
search_context: TripletSearchContext, memory_fragment: CogneeGraph
|
|
||||||
):
|
|
||||||
"""Maps vector distances from search context to memory fragment."""
|
|
||||||
await memory_fragment.map_vector_distances_to_graph_nodes(
|
await memory_fragment.map_vector_distances_to_graph_nodes(
|
||||||
node_distances=search_context.node_distances
|
node_distances=vector_search.node_distances
|
||||||
)
|
)
|
||||||
await memory_fragment.map_vector_distances_to_graph_edges(
|
await memory_fragment.map_vector_distances_to_graph_edges(
|
||||||
edge_distances=search_context.edge_distances
|
edge_distances=vector_search.edge_distances
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return await memory_fragment.calculate_top_triplet_importances(k=top_k)
|
||||||
|
|
||||||
|
|
||||||
async def brute_force_triplet_search(
|
async def brute_force_triplet_search(
|
||||||
query: str,
|
query: str,
|
||||||
|
|
@ -229,28 +140,23 @@ async def brute_force_triplet_search(
|
||||||
collections.append("EdgeType_relationship_name")
|
collections.append("EdgeType_relationship_name")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
search_context = TripletSearchContext(
|
vector_search = NodeEdgeVectorSearch()
|
||||||
query=query,
|
|
||||||
top_k=top_k,
|
|
||||||
collections=collections,
|
|
||||||
properties_to_project=properties_to_project,
|
|
||||||
node_type=node_type,
|
|
||||||
node_name=node_name,
|
|
||||||
wide_search_limit=wide_search_limit,
|
|
||||||
triplet_distance_penalty=triplet_distance_penalty,
|
|
||||||
)
|
|
||||||
|
|
||||||
await _embed_and_retrieve_distances(search_context)
|
await vector_search.embed_and_retrieve_distances(query, collections, wide_search_limit)
|
||||||
|
|
||||||
if not search_context.has_results():
|
if not vector_search.has_results():
|
||||||
return []
|
return []
|
||||||
|
|
||||||
if memory_fragment is None:
|
return await _get_top_triplet_importances(
|
||||||
memory_fragment = await _create_memory_fragment(search_context)
|
memory_fragment,
|
||||||
|
vector_search,
|
||||||
await _map_distances_to_fragment(search_context, memory_fragment)
|
properties_to_project,
|
||||||
|
node_type,
|
||||||
return await memory_fragment.calculate_top_triplet_importances(k=search_context.top_k)
|
node_name,
|
||||||
|
triplet_distance_penalty,
|
||||||
|
wide_search_limit,
|
||||||
|
top_k,
|
||||||
|
)
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Error during brute force search for query: %s. Error: %s",
|
"Error during brute force search for query: %s. Error: %s",
|
||||||
|
|
|
||||||
81
cognee/modules/retrieval/utils/node_edge_vector_search.py
Normal file
81
cognee/modules/retrieval/utils/node_edge_vector_search.py
Normal file
|
|
@ -0,0 +1,81 @@
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
from cognee.shared.logging_utils import get_logger, ERROR
|
||||||
|
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||||
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
|
||||||
|
logger = get_logger(level=ERROR)
|
||||||
|
|
||||||
|
|
||||||
|
class NodeEdgeVectorSearch:
    """Manages vector search and distance retrieval for graph nodes and edges.

    An instance embeds a query, searches a list of vector collections
    concurrently, and splits the per-collection results into edge distances
    (the collection named by ``edge_collection``) and node distances (every
    other collection, keyed by collection name).
    """

    def __init__(self, edge_collection: str = "EdgeType_relationship_name"):
        """Create an empty search state.

        Args:
            edge_collection: Name of the collection whose results are stored
                as edge distances rather than node distances.
        """
        self.edge_collection = edge_collection
        # Embedding of the most recent query; set by embed_and_retrieve_distances.
        self.query_vector: Optional[Any] = None
        # Collection name -> search results for every non-edge collection.
        self.node_distances: dict[str, list[Any]] = {}
        # Search results of the edge collection, or None if it was not searched.
        self.edge_distances: Optional[list[Any]] = None

    def has_results(self) -> bool:
        """Checks if any collections returned results."""
        return bool(self.edge_distances) or any(self.node_distances.values())

    def set_distances_from_results(
        self, collections: List[str], search_results: List[List[Any]]
    ) -> None:
        """Separates search results into node and edge distances.

        Args:
            collections: Collection names, in the same order as
                ``search_results``.
            search_results: One result list per collection.
        """
        # Reset BOTH containers: if the instance is reused for a search that
        # does not include the edge collection, edge distances from the
        # previous search must not leak into has_results().
        self.node_distances = {}
        self.edge_distances = None
        for collection, result in zip(collections, search_results):
            if collection == self.edge_collection:
                self.edge_distances = result
            else:
                self.node_distances[collection] = result

    def extract_relevant_node_ids(self) -> List[str]:
        """Extracts unique node IDs from search results.

        Returns:
            The stringified ``id`` of every scored node that has a truthy
            ``id`` attribute, de-duplicated and in first-seen order.
        """
        # dict.fromkeys de-duplicates while preserving insertion order, so
        # the returned list is stable across runs (a set would not be).
        relevant_node_ids = dict.fromkeys(
            str(scored_node.id)
            for score_collection in self.node_distances.values()
            # Only list/tuple result sets are scanned; other values are skipped.
            if isinstance(score_collection, (list, tuple))
            for scored_node in score_collection
            if getattr(scored_node, "id", None)
        )
        return list(relevant_node_ids)

    async def embed_and_retrieve_distances(
        self, query: str, collections: List[str], wide_search_limit: Optional[int]
    ) -> None:
        """Embeds query and retrieves vector distances from all collections.

        Args:
            query: Text to embed and search for.
            collections: Names of the vector collections to search.
            wide_search_limit: Passed through as the ``limit`` of every
                collection search.

        The results are stored on the instance via
        ``set_distances_from_results``; nothing is returned.
        """
        vector_engine = get_vector_engine()

        query_embeddings = await vector_engine.embedding_engine.embed_text([query])
        self.query_vector = query_embeddings[0]

        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by wall-clock adjustments.
        start_time = time.perf_counter()
        search_tasks = [
            self._search_single_collection(vector_engine, wide_search_limit, collection)
            for collection in collections
        ]
        # gather preserves argument order, which keeps results aligned with
        # `collections` for the zip in set_distances_from_results.
        search_results = await asyncio.gather(*search_tasks)

        elapsed_time = time.perf_counter() - start_time
        collections_with_results = sum(1 for result in search_results if result)
        logger.info(
            f"Vector collection retrieval completed: Retrieved distances from "
            f"{collections_with_results} collections in {elapsed_time:.2f}s"
        )

        self.set_distances_from_results(collections, search_results)

    async def _search_single_collection(
        self, vector_engine: Any, wide_search_limit: Optional[int], collection_name: str
    ):
        """Searches one collection and returns results or empty list if not found.

        Args:
            vector_engine: Engine object exposing an awaitable ``search``.
            wide_search_limit: Forwarded as the search ``limit``.
            collection_name: Collection to search with ``self.query_vector``.
        """
        try:
            return await vector_engine.search(
                collection_name=collection_name,
                query_vector=self.query_vector,
                limit=wide_search_limit,
            )
        except CollectionNotFoundError:
            # A missing collection is treated as "no hits" instead of failing
            # the whole concurrent search.
            return []
|
||||||
Loading…
Add table
Reference in a new issue