diff --git a/graphiti_core/cross_encoder/openai_reranker_client.py b/graphiti_core/cross_encoder/openai_reranker_client.py
index 7682b23b..3e3d7566 100644
--- a/graphiti_core/cross_encoder/openai_reranker_client.py
+++ b/graphiti_core/cross_encoder/openai_reranker_client.py
@@ -17,9 +17,9 @@ limitations under the License.
 import logging
 from typing import Any
 
+import numpy as np
 import openai
 from openai import AsyncAzureOpenAI, AsyncOpenAI
-from pydantic import BaseModel
 
 from ..helpers import semaphore_gather
 from ..llm_client import LLMConfig, RateLimitError
@@ -31,10 +31,6 @@ logger = logging.getLogger(__name__)
 DEFAULT_MODEL = 'gpt-4.1-nano'
 
 
-class BooleanClassifier(BaseModel):
-    isTrue: bool
-
-
 class OpenAIRerankerClient(CrossEncoderClient):
     def __init__(
         self,
@@ -107,11 +103,15 @@ class OpenAIRerankerClient(CrossEncoderClient):
             ]
             scores: list[float] = []
             for top_logprobs in responses_top_logprobs:
-                for logprob in top_logprobs:
-                    if bool(logprob.token):
-                        scores.append(logprob.logprob)
+                if len(top_logprobs) == 0:
+                    continue
+                norm_logprobs = np.exp(top_logprobs[0].logprob)
+                if bool(top_logprobs[0].token):
+                    scores.append(norm_logprobs)
+                else:
+                    scores.append(1 - norm_logprobs)
 
-            results = [(passage, score) for passage, score in zip(passages, scores, strict=False)]
+            results = [(passage, score) for passage, score in zip(passages, scores, strict=True)]
             results.sort(reverse=True, key=lambda x: x[1])
             return results
         except openai.RateLimitError as e:
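Note on the scoring change above: the reranker previously appended raw log-probabilities (values in (-inf, 0]) for every top-logprob token; it now reads only the first entry, converts it to a probability with np.exp, and takes the complement when the token is falsy, so scores land in [0, 1] and are comparable with the reranker_min_score threshold introduced below. The switch to zip(..., strict=True) also turns a silent passage/score count mismatch into an error. A minimal sketch of the new rule, using a hypothetical stand-in for the OpenAI SDK's top-logprob objects:

    import math
    from dataclasses import dataclass


    @dataclass
    class FakeTopLogprob:
        # hypothetical stand-in; only the two fields the reranker reads are modeled
        token: str
        logprob: float


    def passage_score(top_logprobs: list[FakeTopLogprob]) -> float | None:
        # mirrors the new branch above; None means "skip this passage"
        if len(top_logprobs) == 0:
            return None
        p = math.exp(top_logprobs[0].logprob)  # log-probability -> probability
        # NB: bool() of any non-empty token string is True, so the else branch
        # only fires for an empty token
        return p if bool(top_logprobs[0].token) else 1 - p


    print(passage_score([FakeTopLogprob('True', math.log(0.9))]))  # ~0.9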
diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py
index 4eb9a53b..709b86df 100644
--- a/graphiti_core/search/search.py
+++ b/graphiti_core/search/search.py
@@ -92,6 +92,7 @@ async def search(
             center_node_uuid,
             bfs_origin_node_uuids,
             config.limit,
+            config.reranker_min_score,
         ),
         node_search(
             driver,
@@ -104,6 +105,7 @@
             center_node_uuid,
             bfs_origin_node_uuids,
             config.limit,
+            config.reranker_min_score,
         ),
         community_search(
             driver,
@@ -112,8 +114,8 @@
             query_vector,
             group_ids,
             config.community_config,
-            bfs_origin_node_uuids,
             config.limit,
+            config.reranker_min_score,
         ),
     )
@@ -141,6 +143,7 @@ async def edge_search(
     center_node_uuid: str | None = None,
     bfs_origin_node_uuids: list[str] | None = None,
     limit=DEFAULT_SEARCH_LIMIT,
+    reranker_min_score: float = 0,
 ) -> list[EntityEdge]:
     if config is None:
         return []
@@ -180,7 +183,7 @@
 
     if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
         search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
-        reranked_uuids = rrf(search_result_uuids)
+        reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
     elif config.reranker == EdgeReranker.mmr:
         search_result_uuids_and_vectors = [
             (edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024)
@@ -188,23 +191,31 @@
             for edge in result
         ]
         reranked_uuids = maximal_marginal_relevance(
-            query_vector, search_result_uuids_and_vectors, config.mmr_lambda
+            query_vector,
+            search_result_uuids_and_vectors,
+            config.mmr_lambda,
+            min_score=reranker_min_score,
         )
     elif config.reranker == EdgeReranker.cross_encoder:
         search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
-        rrf_result_uuids = rrf(search_result_uuids)
+        rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
         rrf_edges = [edge_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
         fact_to_uuid_map = {edge.fact: edge.uuid for edge in rrf_edges}
         reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
-        reranked_uuids = [fact_to_uuid_map[fact] for fact, _ in reranked_facts]
+        reranked_uuids = [
+            fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
+        ]
     elif config.reranker == EdgeReranker.node_distance:
         if center_node_uuid is None:
             raise SearchRerankerError('No center node provided for Node Distance reranker')
 
         # use rrf as a preliminary sort
-        sorted_result_uuids = rrf([[edge.uuid for edge in result] for result in search_results])
+        sorted_result_uuids = rrf(
+            [[edge.uuid for edge in result] for result in search_results],
+            min_score=reranker_min_score,
+        )
         sorted_results = [edge_uuid_map[uuid] for uuid in sorted_result_uuids]
 
         # node distance reranking
@@ -214,7 +225,9 @@
 
         source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map]
 
-        reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid)
+        reranked_node_uuids = await node_distance_reranker(
+            driver, source_uuids, center_node_uuid, min_score=reranker_min_score
+        )
 
         for node_uuid in reranked_node_uuids:
             reranked_uuids.extend(source_to_edge_uuid_map[node_uuid])
@@ -238,6 +251,7 @@ async def node_search(
     center_node_uuid: str | None = None,
     bfs_origin_node_uuids: list[str] | None = None,
     limit=DEFAULT_SEARCH_LIMIT,
+    reranker_min_score: float = 0,
 ) -> list[EntityNode]:
     if config is None:
         return []
@@ -269,7 +283,7 @@
     reranked_uuids: list[str] = []
 
     if config.reranker == NodeReranker.rrf:
-        reranked_uuids = rrf(search_result_uuids)
+        reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
     elif config.reranker == NodeReranker.mmr:
         search_result_uuids_and_vectors = [
             (node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024)
@@ -277,24 +291,36 @@
             for node in result
         ]
         reranked_uuids = maximal_marginal_relevance(
-            query_vector, search_result_uuids_and_vectors, config.mmr_lambda
+            query_vector,
+            search_result_uuids_and_vectors,
+            config.mmr_lambda,
+            min_score=reranker_min_score,
         )
     elif config.reranker == NodeReranker.cross_encoder:
         # use rrf as a preliminary reranker
-        rrf_result_uuids = rrf(search_result_uuids)
+        rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
         rrf_results = [node_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
         summary_to_uuid_map = {node.summary: node.uuid for node in rrf_results}
         reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
-        reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries]
+        reranked_uuids = [
+            summary_to_uuid_map[fact]
+            for fact, score in reranked_summaries
+            if score >= reranker_min_score
+        ]
     elif config.reranker == NodeReranker.episode_mentions:
-        reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
+        reranked_uuids = await episode_mentions_reranker(
+            driver, search_result_uuids, min_score=reranker_min_score
+        )
     elif config.reranker == NodeReranker.node_distance:
         if center_node_uuid is None:
             raise SearchRerankerError('No center node provided for Node Distance reranker')
         reranked_uuids = await node_distance_reranker(
-            driver, rrf(search_result_uuids), center_node_uuid
+            driver,
+            rrf(search_result_uuids, min_score=reranker_min_score),
+            center_node_uuid,
+            min_score=reranker_min_score,
         )
 
     reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
@@ -309,8 +335,8 @@ async def community_search(
     query_vector: list[float],
     group_ids: list[str] | None,
     config: CommunitySearchConfig | None,
-    bfs_origin_node_uuids: list[str] | None = None,
     limit=DEFAULT_SEARCH_LIMIT,
+    reranker_min_score: float = 0,
 ) -> list[CommunityNode]:
     if config is None:
         return []
@@ -333,7 +359,7 @@
     reranked_uuids: list[str] = []
 
     if config.reranker == CommunityReranker.rrf:
-        reranked_uuids = rrf(search_result_uuids)
+        reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
     elif config.reranker == CommunityReranker.mmr:
         search_result_uuids_and_vectors = [
             (
@@ -344,14 +370,21 @@
             for community in result
         ]
         reranked_uuids = maximal_marginal_relevance(
-            query_vector, search_result_uuids_and_vectors, config.mmr_lambda
+            query_vector,
+            search_result_uuids_and_vectors,
+            config.mmr_lambda,
+            min_score=reranker_min_score,
         )
     elif config.reranker == CommunityReranker.cross_encoder:
         summary_to_uuid_map = {
             node.summary: node.uuid for result in search_results for node in result
         }
         reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
-        reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries]
+        reranked_uuids = [
+            summary_to_uuid_map[fact]
+            for fact, score in reranked_summaries
+            if score >= reranker_min_score
+        ]
 
     reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
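The search() entry point now threads config.reranker_min_score into all three sub-searches, and community_search() drops its unused bfs_origin_node_uuids parameter along the way. A hedged usage sketch (SearchConfig and its fields follow search_config.py below; the commented call shape is an assumption, since search()'s full signature is not shown in this diff):

    from graphiti_core.search.search_config import SearchConfig

    # reranker_min_score defaults to 0, which keeps the old behavior; raising it
    # makes every reranker drop candidates scoring below the threshold
    config = SearchConfig(limit=10, reranker_min_score=0.5)

    # results = await search(
    #     driver, cross_encoder, 'Who is the user?', query_vector,
    #     group_ids=None, config=config,
    # )  # assumed argument order beyond what the hunks above show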
diff --git a/graphiti_core/search/search_config.py b/graphiti_core/search/search_config.py
index 9aa23daa..f0c21bde 100644
--- a/graphiti_core/search/search_config.py
+++ b/graphiti_core/search/search_config.py
@@ -97,6 +97,7 @@ class SearchConfig(BaseModel):
     node_config: NodeSearchConfig | None = Field(default=None)
     community_config: CommunitySearchConfig | None = Field(default=None)
     limit: int = Field(default=DEFAULT_SEARCH_LIMIT)
+    reranker_min_score: float = Field(default=0)
 
 
 class SearchResults(BaseModel):
diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py
index 3f720736..820fb00f 100644
--- a/graphiti_core/search/search_utils.py
+++ b/graphiti_core/search/search_utils.py
@@ -229,8 +229,8 @@ async def edge_similarity_search(
 
     query: LiteralString = (
         """
-            MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
-            """
+        MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
+        """
         + group_filter_query
         + filter_query
         + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
@@ -718,7 +718,7 @@ async def get_relevant_edges(
 
 
 # takes in a list of rankings of uuids
-def rrf(results: list[list[str]], rank_const=1) -> list[str]:
+def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[str]:
     scores: dict[str, float] = defaultdict(float)
     for result in results:
         for i, uuid in enumerate(result):
@@ -729,11 +729,14 @@
 
     sorted_uuids = [term[0] for term in scored_uuids]
 
-    return sorted_uuids
+    return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
 
 
 async def node_distance_reranker(
-    driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str
+    driver: AsyncDriver,
+    node_uuids: list[str],
+    center_node_uuid: str,
+    min_score: float = 0,
 ) -> list[str]:
     # filter out node_uuid center node node uuid
     filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
@@ -767,12 +770,15 @@
 
     # add back in filtered center uuid if it was filtered out
     if center_node_uuid in node_uuids:
+        scores[center_node_uuid] = 0.1
         filtered_uuids = [center_node_uuid] + filtered_uuids
 
-    return filtered_uuids
+    return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score]
 
 
-async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[str]]) -> list[str]:
+async def episode_mentions_reranker(
+    driver: AsyncDriver, node_uuids: list[list[str]], min_score: float = 0
+) -> list[str]:
     # use rrf as a preliminary ranker
     sorted_uuids = rrf(node_uuids)
     scores: dict[str, float] = {}
@@ -796,13 +802,14 @@ async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[s
     # rerank on shortest distance
     sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
 
-    return sorted_uuids
+    return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
 
 
 def maximal_marginal_relevance(
     query_vector: list[float],
     candidates: list[tuple[str, list[float]]],
     mmr_lambda: float = DEFAULT_MMR_LAMBDA,
+    min_score: float = 0,
 ):
     candidates_with_mmr: list[tuple[str, float]] = []
     for candidate in candidates:
@@ -812,4 +819,6 @@
 
     candidates_with_mmr.sort(reverse=True, key=lambda c: c[1])
 
-    return list(set([candidate[0] for candidate in candidates_with_mmr]))
+    return list(
+        set([candidate[0] for candidate in candidates_with_mmr if candidate[1] >= min_score])
+    )
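The reranker helpers above all gain the same min_score cutoff: rrf() and episode_mentions_reranker() filter on their accumulated scores, maximal_marginal_relevance() on the MMR score, and node_distance_reranker() on inverse distance (the center node is assigned distance 0.1, i.e. a score of 10, so it survives any reasonable threshold). A standalone sketch of the updated rrf(); the middle of the function is not shown in the hunks above, so the scored_uuids construction here is an assumption based on the surrounding context lines:

    from collections import defaultdict


    def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[str]:
        # reciprocal-rank fusion: each appearance contributes 1 / (rank + rank_const)
        scores: dict[str, float] = defaultdict(float)
        for result in results:
            for i, uuid in enumerate(result):
                scores[uuid] += 1 / (i + rank_const)

        # assumed from context: order uuids by fused score, descending
        scored_uuids = sorted(scores.items(), reverse=True, key=lambda term: term[1])
        sorted_uuids = [term[0] for term in scored_uuids]

        # new in this diff: drop anything fused below min_score
        return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]


    # 'a' tops both rankings (score 2.0); 'b' and 'c' score 0.5 each and fall
    # below the threshold, while min_score=0 would keep all three
    print(rrf([['a', 'b'], ['a', 'c']], min_score=1.0))  # ['a']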
diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py
index 6fd93900..6739bf06 100644
--- a/tests/test_graphiti_int.py
+++ b/tests/test_graphiti_int.py
@@ -66,7 +66,7 @@
     graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
 
     results = await graphiti.search_(
-        query="Who is Alice's friend?",
+        query='Who is the user?',
     )
 
     pretty_results = search_results_to_context_string(results)
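End to end, a hedged sketch of exercising the new threshold through the public API; the config keyword on search_() is an assumption, since the test above only passes query:

    import asyncio

    from graphiti_core import Graphiti
    from graphiti_core.search.search_config import SearchConfig


    async def main() -> None:
        # connection details are placeholders for a local Neo4j instance
        graphiti = Graphiti('bolt://localhost:7687', 'neo4j', 'password')
        # assumed keyword; the diff only shows search_(query=...):
        # results = await graphiti.search_(
        #     query='Who is the user?', config=SearchConfig(reranker_min_score=0.5)
        # )
        results = await graphiti.search_(query='Who is the user?')
        print(results)


    asyncio.run(main())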