graphiti/graphiti_core/search/search.py
Preston Rasmussen 088029a80c
node label filters (#265)
* node label filters

* update

* add search filters

* updates

* bump versions

* update tests

* test update
2025-02-21 12:38:01 -05:00

358 lines
12 KiB
Python

"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
from collections import defaultdict
from time import time
from neo4j import AsyncDriver
from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.edges import EntityEdge
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import SearchRerankerError
from graphiti_core.helpers import semaphore_gather
from graphiti_core.nodes import CommunityNode, EntityNode
from graphiti_core.search.search_config import (
DEFAULT_SEARCH_LIMIT,
CommunityReranker,
CommunitySearchConfig,
EdgeReranker,
EdgeSearchConfig,
EdgeSearchMethod,
NodeReranker,
NodeSearchConfig,
NodeSearchMethod,
SearchConfig,
SearchResults,
)
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import (
community_fulltext_search,
community_similarity_search,
edge_bfs_search,
edge_fulltext_search,
edge_similarity_search,
episode_mentions_reranker,
maximal_marginal_relevance,
node_bfs_search,
node_distance_reranker,
node_fulltext_search,
node_similarity_search,
rrf,
)
# Module logger; `search` reports per-query latency through it at DEBUG level.
logger = logging.getLogger(__name__)
async def search(
    driver: AsyncDriver,
    embedder: EmbedderClient,
    cross_encoder: CrossEncoderClient,
    query: str,
    group_ids: list[str] | None,
    config: SearchConfig,
    search_filter: SearchFilters,
    center_node_uuid: str | None = None,
    bfs_origin_node_uuids: list[str] | None = None,
) -> SearchResults:
    """Search edges, nodes and communities for ``query``.

    The query text is embedded once. The edge, node and community sub-searches
    share that vector and are awaited together through ``semaphore_gather``.
    Each sub-search is driven by its own section of ``config``, and all are
    capped at ``config.limit`` results.

    Args:
        driver: Neo4j async driver passed to every sub-search.
        embedder: Client used to embed the query text.
        cross_encoder: Client handed to the sub-searches for cross-encoder
            reranking.
        query: Free-text query. A whitespace-only query returns empty results
            immediately, without calling the embedder.
        group_ids: Optional group restriction. An empty list is treated the
            same as ``None``.
        config: Per-category search settings plus the shared result limit.
        search_filter: Filters forwarded to the edge and node sub-searches.
            The community sub-search does not take one.
        center_node_uuid: Optional anchor node, forwarded to the edge and node
            sub-searches for node-distance reranking.
        bfs_origin_node_uuids: Optional BFS start nodes.

    Returns:
        ``SearchResults`` holding the edges, nodes and communities found.
    """
    started = time()

    if not query.strip():
        return SearchResults(edges=[], nodes=[], communities=[])

    # Embed a single-line version of the query; the vector is shared below.
    flattened_query = query.replace('\n', ' ')
    query_vector = await embedder.create(input_data=[flattened_query])

    # An empty group list means "no group restriction".
    if not group_ids:
        group_ids = None

    edges, nodes, communities = await semaphore_gather(
        edge_search(
            driver=driver,
            cross_encoder=cross_encoder,
            query=query,
            query_vector=query_vector,
            group_ids=group_ids,
            config=config.edge_config,
            search_filter=search_filter,
            center_node_uuid=center_node_uuid,
            bfs_origin_node_uuids=bfs_origin_node_uuids,
            limit=config.limit,
        ),
        node_search(
            driver=driver,
            cross_encoder=cross_encoder,
            query=query,
            query_vector=query_vector,
            group_ids=group_ids,
            config=config.node_config,
            search_filter=search_filter,
            center_node_uuid=center_node_uuid,
            bfs_origin_node_uuids=bfs_origin_node_uuids,
            limit=config.limit,
        ),
        community_search(
            driver=driver,
            cross_encoder=cross_encoder,
            query=query,
            query_vector=query_vector,
            group_ids=group_ids,
            config=config.community_config,
            bfs_origin_node_uuids=bfs_origin_node_uuids,
            limit=config.limit,
        ),
    )

    results = SearchResults(edges=edges, nodes=nodes, communities=communities)

    latency = (time() - started) * 1000
    logger.debug(f'search returned context for query {query} in {latency} ms')
    return results
async def edge_search(
    driver: AsyncDriver,
    cross_encoder: CrossEncoderClient,
    query: str,
    query_vector: list[float],
    group_ids: list[str] | None,
    config: EdgeSearchConfig | None,
    search_filter: SearchFilters,
    center_node_uuid: str | None = None,
    bfs_origin_node_uuids: list[str] | None = None,
    limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityEdge]:
    """Find entity edges relevant to ``query`` and rerank them.

    Full-text, vector-similarity and BFS candidate lists, each requested with a
    size of ``2 * limit``, are gathered together. If BFS is enabled in
    ``config.search_methods`` and no explicit BFS origins were given, a second
    BFS is seeded from the source nodes of every candidate gathered so far. The
    merged candidates are then ordered by ``config.reranker``.

    Args:
        driver: Neo4j async driver used by the search helpers.
        cross_encoder: Client used when ``config.reranker`` is cross-encoder.
        query: Free-text query (full-text search and cross-encoder input).
        query_vector: Embedding of ``query``, used for similarity search and MMR.
        group_ids: Optional group restriction passed to the search helpers.
        config: Edge search settings. ``None`` disables edge search.
        search_filter: Filters forwarded to every search helper.
        center_node_uuid: Anchor node required by the node-distance reranker.
        bfs_origin_node_uuids: Optional BFS start nodes, passed as-is
            (possibly ``None``) to the first BFS call.
        limit: Maximum number of edges returned.

    Returns:
        At most ``limit`` edges, best first, or ``[]`` when ``config`` is ``None``.

    Raises:
        SearchRerankerError: If the node-distance reranker is selected without a
            ``center_node_uuid``.
    """
    if config is None:
        return []

    # NOTE(review): the full-text and similarity searches run regardless of
    # config.search_methods; only the seeded BFS below consults it. Confirm
    # this is intended.
    search_results: list[list[EntityEdge]] = list(
        await semaphore_gather(
            *[
                edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
                edge_similarity_search(
                    driver,
                    query_vector,
                    None,
                    None,
                    search_filter,
                    group_ids,
                    2 * limit,
                    config.sim_min_score,
                ),
                edge_bfs_search(
                    driver, bfs_origin_node_uuids, config.bfs_max_depth, search_filter, 2 * limit
                ),
            ]
        )
    )

    # Without explicit origins, seed BFS from the source nodes of the edges
    # found so far.
    if EdgeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
        source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result]
        search_results.append(
            await edge_bfs_search(
                driver, source_node_uuids, config.bfs_max_depth, search_filter, 2 * limit
            )
        )

    edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}

    reranked_uuids: list[str] = []
    if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
        # episode_mentions starts from the RRF order and re-sorts it below.
        search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
        reranked_uuids = rrf(search_result_uuids)
    elif config.reranker == EdgeReranker.mmr:
        # Edges lacking a stored embedding get a zero placeholder. Size it to
        # the query vector so every vector handed to maximal_marginal_relevance
        # has the same dimension as query_vector (it was hard-coded to 1024).
        embedding_dim = len(query_vector)
        search_result_uuids_and_vectors = [
            (
                edge.uuid,
                edge.fact_embedding if edge.fact_embedding is not None else [0.0] * embedding_dim,
            )
            for result in search_results
            for edge in result
        ]
        reranked_uuids = maximal_marginal_relevance(
            query_vector, search_result_uuids_and_vectors, config.mmr_lambda
        )
    elif config.reranker == EdgeReranker.cross_encoder:
        # RRF pre-sort, then only the top `limit` facts go to the cross-encoder.
        search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
        rrf_result_uuids = rrf(search_result_uuids)
        rrf_edges = [edge_uuid_map[uuid] for uuid in rrf_result_uuids[:limit]]
        # NOTE(review): edges with identical fact text collapse to a single
        # entry here (the last one wins), so duplicates are dropped.
        fact_to_uuid_map = {edge.fact: edge.uuid for edge in rrf_edges}
        reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
        reranked_uuids = [fact_to_uuid_map[fact] for fact, _ in reranked_facts]
    elif config.reranker == EdgeReranker.node_distance:
        if center_node_uuid is None:
            raise SearchRerankerError('No center node provided for Node Distance reranker')
        # Use RRF as a preliminary sort.
        sorted_result_uuids = rrf([[edge.uuid for edge in result] for result in search_results])
        sorted_results = [edge_uuid_map[uuid] for uuid in sorted_result_uuids]

        # Group edges by source node, keeping RRF order within each group. The
        # groups are then emitted in the order the reranker ranks their nodes.
        source_to_edge_uuid_map: defaultdict[str, list[str]] = defaultdict(list)
        for edge in sorted_results:
            source_to_edge_uuid_map[edge.source_node_uuid].append(edge.uuid)
        reranked_node_uuids = await node_distance_reranker(
            driver, list(source_to_edge_uuid_map), center_node_uuid
        )
        for node_uuid in reranked_node_uuids:
            reranked_uuids.extend(source_to_edge_uuid_map[node_uuid])

    reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
    if config.reranker == EdgeReranker.episode_mentions:
        # list.sort is stable even with reverse=True. Edges mentioned in more
        # episodes come first, and ties keep their RRF order.
        reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes))
    return reranked_edges[:limit]
async def node_search(
    driver: AsyncDriver,
    cross_encoder: CrossEncoderClient,
    query: str,
    query_vector: list[float],
    group_ids: list[str] | None,
    config: NodeSearchConfig | None,
    search_filter: SearchFilters,
    center_node_uuid: str | None = None,
    bfs_origin_node_uuids: list[str] | None = None,
    limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityNode]:
    """Find entity nodes relevant to ``query`` and rerank them.

    Full-text, vector-similarity and BFS candidate lists, each requested with a
    size of ``2 * limit``, are gathered together. If BFS is enabled in
    ``config.search_methods`` and no explicit BFS origins were given, a second
    BFS is seeded from every candidate node gathered so far. The merged
    candidates are then ordered by ``config.reranker``.

    Args:
        driver: Neo4j async driver used by the search helpers and rerankers.
        cross_encoder: Client used when ``config.reranker`` is cross-encoder.
        query: Free-text query (full-text search and cross-encoder input).
        query_vector: Embedding of ``query``, used for similarity search and MMR.
        group_ids: Optional group restriction passed to the search helpers.
        config: Node search settings. ``None`` disables node search.
        search_filter: Filters forwarded to every search helper.
        center_node_uuid: Anchor node required by the node-distance reranker.
        bfs_origin_node_uuids: Optional BFS start nodes, passed as-is
            (possibly ``None``) to the first BFS call.
        limit: Maximum number of nodes returned.

    Returns:
        At most ``limit`` nodes, best first, or ``[]`` when ``config`` is ``None``.

    Raises:
        SearchRerankerError: If the node-distance reranker is selected without a
            ``center_node_uuid``.
    """
    if config is None:
        return []

    # NOTE(review): the full-text and similarity searches run regardless of
    # config.search_methods; only the seeded BFS below consults it. Confirm
    # this is intended.
    search_results: list[list[EntityNode]] = list(
        await semaphore_gather(
            *[
                node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
                node_similarity_search(
                    driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score
                ),
                node_bfs_search(
                    driver, bfs_origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit
                ),
            ]
        )
    )

    # Without explicit origins, seed BFS from every node found so far.
    if NodeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
        origin_node_uuids = [node.uuid for result in search_results for node in result]
        search_results.append(
            await node_bfs_search(
                driver, origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit
            )
        )

    search_result_uuids = [[node.uuid for node in result] for result in search_results]
    node_uuid_map = {node.uuid: node for result in search_results for node in result}

    reranked_uuids: list[str] = []
    if config.reranker == NodeReranker.rrf:
        reranked_uuids = rrf(search_result_uuids)
    elif config.reranker == NodeReranker.mmr:
        # Nodes lacking a stored embedding get a zero placeholder. Size it to
        # the query vector so every vector handed to maximal_marginal_relevance
        # has the same dimension as query_vector (it was hard-coded to 1024).
        embedding_dim = len(query_vector)
        search_result_uuids_and_vectors = [
            (
                node.uuid,
                node.name_embedding if node.name_embedding is not None else [0.0] * embedding_dim,
            )
            for result in search_results
            for node in result
        ]
        reranked_uuids = maximal_marginal_relevance(
            query_vector, search_result_uuids_and_vectors, config.mmr_lambda
        )
    elif config.reranker == NodeReranker.cross_encoder:
        # RRF pre-sort, then only the top `limit` summaries go to the cross-encoder.
        rrf_result_uuids = rrf(search_result_uuids)
        rrf_results = [node_uuid_map[uuid] for uuid in rrf_result_uuids[:limit]]
        # NOTE(review): nodes with identical summaries (including empty ones)
        # collapse to a single entry here (the last one wins).
        summary_to_uuid_map = {node.summary: node.uuid for node in rrf_results}
        reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
        reranked_uuids = [summary_to_uuid_map[summary] for summary, _ in reranked_summaries]
    elif config.reranker == NodeReranker.episode_mentions:
        reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
    elif config.reranker == NodeReranker.node_distance:
        if center_node_uuid is None:
            raise SearchRerankerError('No center node provided for Node Distance reranker')
        # RRF provides the preliminary order handed to the distance reranker.
        reranked_uuids = await node_distance_reranker(
            driver, rrf(search_result_uuids), center_node_uuid
        )

    reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
    return reranked_nodes[:limit]
async def community_search(
    driver: AsyncDriver,
    cross_encoder: CrossEncoderClient,
    query: str,
    query_vector: list[float],
    group_ids: list[str] | None,
    config: CommunitySearchConfig | None,
    bfs_origin_node_uuids: list[str] | None = None,
    limit=DEFAULT_SEARCH_LIMIT,
) -> list[CommunityNode]:
    """Find community nodes relevant to ``query`` and rerank them.

    Full-text and vector-similarity candidate lists, each requested with a size
    of ``2 * limit``, are gathered together and then ordered by
    ``config.reranker``.

    Args:
        driver: Neo4j async driver used by the search helpers.
        cross_encoder: Client used when ``config.reranker`` is cross-encoder.
        query: Free-text query (full-text search and cross-encoder input).
        query_vector: Embedding of ``query``, used for similarity search and MMR.
        group_ids: Optional group restriction passed to the search helpers.
        config: Community search settings. ``None`` disables community search.
        bfs_origin_node_uuids: Accepted for signature parity with the other
            searches; this function does not use it.
        limit: Maximum number of communities returned.

    Returns:
        At most ``limit`` communities, best first, or ``[]`` when ``config`` is
        ``None``.
    """
    if config is None:
        return []

    search_results: list[list[CommunityNode]] = list(
        await semaphore_gather(
            *[
                community_fulltext_search(driver, query, group_ids, 2 * limit),
                community_similarity_search(
                    driver, query_vector, group_ids, 2 * limit, config.sim_min_score
                ),
            ]
        )
    )

    search_result_uuids = [[community.uuid for community in result] for result in search_results]
    community_uuid_map = {
        community.uuid: community for result in search_results for community in result
    }

    reranked_uuids: list[str] = []
    if config.reranker == CommunityReranker.rrf:
        reranked_uuids = rrf(search_result_uuids)
    elif config.reranker == CommunityReranker.mmr:
        # Communities lacking a stored embedding get a zero placeholder. Size
        # it to the query vector so every vector handed to
        # maximal_marginal_relevance has the same dimension as query_vector
        # (it was hard-coded to 1024).
        embedding_dim = len(query_vector)
        search_result_uuids_and_vectors = [
            (
                community.uuid,
                community.name_embedding
                if community.name_embedding is not None
                else [0.0] * embedding_dim,
            )
            for result in search_results
            for community in result
        ]
        reranked_uuids = maximal_marginal_relevance(
            query_vector, search_result_uuids_and_vectors, config.mmr_lambda
        )
    elif config.reranker == CommunityReranker.cross_encoder:
        # NOTE(review): unlike edge_search/node_search, candidates are not
        # pre-sorted with RRF and truncated to `limit` before cross-encoding.
        # Communities with identical summaries also collapse to one entry.
        summary_to_uuid_map = {
            community.summary: community.uuid
            for result in search_results
            for community in result
        }
        reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
        reranked_uuids = [summary_to_uuid_map[summary] for summary, _ in reranked_summaries]

    reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
    return reranked_communities[:limit]