""" Copyright 2024, Zep Software, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import logging from collections import defaultdict from time import time from neo4j import AsyncDriver from graphiti_core.cross_encoder.client import CrossEncoderClient from graphiti_core.edges import EntityEdge from graphiti_core.embedder import EmbedderClient from graphiti_core.errors import SearchRerankerError from graphiti_core.helpers import semaphore_gather from graphiti_core.nodes import CommunityNode, EntityNode from graphiti_core.search.search_config import ( DEFAULT_SEARCH_LIMIT, CommunityReranker, CommunitySearchConfig, EdgeReranker, EdgeSearchConfig, EdgeSearchMethod, NodeReranker, NodeSearchConfig, NodeSearchMethod, SearchConfig, SearchResults, ) from graphiti_core.search.search_filters import SearchFilters from graphiti_core.search.search_utils import ( community_fulltext_search, community_similarity_search, edge_bfs_search, edge_fulltext_search, edge_similarity_search, episode_mentions_reranker, maximal_marginal_relevance, node_bfs_search, node_distance_reranker, node_fulltext_search, node_similarity_search, rrf, ) logger = logging.getLogger(__name__) async def search( driver: AsyncDriver, embedder: EmbedderClient, cross_encoder: CrossEncoderClient, query: str, group_ids: list[str] | None, config: SearchConfig, search_filter: SearchFilters, center_node_uuid: str | None = None, bfs_origin_node_uuids: list[str] | None = None, ) -> SearchResults: start = time() if query.strip() == '': return SearchResults( edges=[], nodes=[], communities=[], ) query_vector = await embedder.create(input_data=[query.replace('\n', ' ')]) # if group_ids is empty, set it to None group_ids = group_ids if group_ids else None edges, nodes, communities = await semaphore_gather( edge_search( driver, cross_encoder, query, query_vector, group_ids, config.edge_config, search_filter, center_node_uuid, bfs_origin_node_uuids, config.limit, ), node_search( driver, cross_encoder, query, query_vector, group_ids, config.node_config, search_filter, center_node_uuid, bfs_origin_node_uuids, config.limit, ), community_search( driver, cross_encoder, query, query_vector, group_ids, config.community_config, bfs_origin_node_uuids, config.limit, ), ) results = SearchResults( edges=edges, nodes=nodes, communities=communities, ) latency = (time() - start) * 1000 logger.debug(f'search returned context for query {query} in {latency} ms') return results async def edge_search( driver: AsyncDriver, cross_encoder: CrossEncoderClient, query: str, query_vector: list[float], group_ids: list[str] | None, config: EdgeSearchConfig | None, search_filter: SearchFilters, center_node_uuid: str | None = None, bfs_origin_node_uuids: list[str] | None = None, limit=DEFAULT_SEARCH_LIMIT, ) -> list[EntityEdge]: if config is None: return [] search_results: list[list[EntityEdge]] = list( await semaphore_gather( *[ edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit), edge_similarity_search( driver, query_vector, None, None, search_filter, group_ids, 2 * limit, config.sim_min_score, ), edge_bfs_search( driver, bfs_origin_node_uuids, config.bfs_max_depth, search_filter, 2 * limit ), ] ) ) if EdgeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None: source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result] search_results.append( await edge_bfs_search( driver, source_node_uuids, config.bfs_max_depth, search_filter, 2 * limit ) ) edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result} reranked_uuids: list[str] = [] if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions: search_result_uuids = [[edge.uuid for edge in result] for result in search_results] reranked_uuids = rrf(search_result_uuids) elif config.reranker == EdgeReranker.mmr: search_result_uuids_and_vectors = [ (edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024) for result in search_results for edge in result ] reranked_uuids = maximal_marginal_relevance( query_vector, search_result_uuids_and_vectors, config.mmr_lambda ) elif config.reranker == EdgeReranker.cross_encoder: search_result_uuids = [[edge.uuid for edge in result] for result in search_results] rrf_result_uuids = rrf(search_result_uuids) rrf_edges = [edge_uuid_map[uuid] for uuid in rrf_result_uuids][:limit] fact_to_uuid_map = {edge.fact: edge.uuid for edge in rrf_edges} reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys())) reranked_uuids = [fact_to_uuid_map[fact] for fact, _ in reranked_facts] elif config.reranker == EdgeReranker.node_distance: if center_node_uuid is None: raise SearchRerankerError('No center node provided for Node Distance reranker') # use rrf as a preliminary sort sorted_result_uuids = rrf([[edge.uuid for edge in result] for result in search_results]) sorted_results = [edge_uuid_map[uuid] for uuid in sorted_result_uuids] # node distance reranking source_to_edge_uuid_map = defaultdict(list) for edge in sorted_results: source_to_edge_uuid_map[edge.source_node_uuid].append(edge.uuid) source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map] reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid) for node_uuid in reranked_node_uuids: reranked_uuids.extend(source_to_edge_uuid_map[node_uuid]) reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids] if config.reranker == EdgeReranker.episode_mentions: reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes)) return reranked_edges[:limit] async def node_search( driver: AsyncDriver, cross_encoder: CrossEncoderClient, query: str, query_vector: list[float], group_ids: list[str] | None, config: NodeSearchConfig | None, search_filter: SearchFilters, center_node_uuid: str | None = None, bfs_origin_node_uuids: list[str] | None = None, limit=DEFAULT_SEARCH_LIMIT, ) -> list[EntityNode]: if config is None: return [] search_results: list[list[EntityNode]] = list( await semaphore_gather( *[ node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit), node_similarity_search( driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score ), node_bfs_search( driver, bfs_origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit ), ] ) ) if NodeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None: origin_node_uuids = [node.uuid for result in search_results for node in result] search_results.append( await node_bfs_search( driver, origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit ) ) search_result_uuids = [[node.uuid for node in result] for result in search_results] node_uuid_map = {node.uuid: node for result in search_results for node in result} reranked_uuids: list[str] = [] if config.reranker == NodeReranker.rrf: reranked_uuids = rrf(search_result_uuids) elif config.reranker == NodeReranker.mmr: search_result_uuids_and_vectors = [ (node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024) for result in search_results for node in result ] reranked_uuids = maximal_marginal_relevance( query_vector, search_result_uuids_and_vectors, config.mmr_lambda ) elif config.reranker == NodeReranker.cross_encoder: # use rrf as a preliminary reranker rrf_result_uuids = rrf(search_result_uuids) rrf_results = [node_uuid_map[uuid] for uuid in rrf_result_uuids][:limit] summary_to_uuid_map = {node.summary: node.uuid for node in rrf_results} reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys())) reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries] elif config.reranker == NodeReranker.episode_mentions: reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids) elif config.reranker == NodeReranker.node_distance: if center_node_uuid is None: raise SearchRerankerError('No center node provided for Node Distance reranker') reranked_uuids = await node_distance_reranker( driver, rrf(search_result_uuids), center_node_uuid ) reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids] return reranked_nodes[:limit] async def community_search( driver: AsyncDriver, cross_encoder: CrossEncoderClient, query: str, query_vector: list[float], group_ids: list[str] | None, config: CommunitySearchConfig | None, bfs_origin_node_uuids: list[str] | None = None, limit=DEFAULT_SEARCH_LIMIT, ) -> list[CommunityNode]: if config is None: return [] search_results: list[list[CommunityNode]] = list( await semaphore_gather( *[ community_fulltext_search(driver, query, group_ids, 2 * limit), community_similarity_search( driver, query_vector, group_ids, 2 * limit, config.sim_min_score ), ] ) ) search_result_uuids = [[community.uuid for community in result] for result in search_results] community_uuid_map = { community.uuid: community for result in search_results for community in result } reranked_uuids: list[str] = [] if config.reranker == CommunityReranker.rrf: reranked_uuids = rrf(search_result_uuids) elif config.reranker == CommunityReranker.mmr: search_result_uuids_and_vectors = [ ( community.uuid, community.name_embedding if community.name_embedding is not None else [0.0] * 1024, ) for result in search_results for community in result ] reranked_uuids = maximal_marginal_relevance( query_vector, search_result_uuids_and_vectors, config.mmr_lambda ) elif config.reranker == CommunityReranker.cross_encoder: summary_to_uuid_map = { node.summary: node.uuid for result in search_results for node in result } reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys())) reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries] reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids] return reranked_communities[:limit]