add reranker_min_score (#355)

* add reranker_min_score

* update divide by 0 case

* center node always gets a score of 0.1

* linter
This commit is contained in:
Preston Rasmussen 2025-04-15 12:33:37 -04:00 committed by GitHub
parent c8d5c45269
commit 11e19a35b7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 79 additions and 36 deletions

View file

@ -17,9 +17,9 @@ limitations under the License.
import logging import logging
from typing import Any from typing import Any
import numpy as np
import openai import openai
from openai import AsyncAzureOpenAI, AsyncOpenAI from openai import AsyncAzureOpenAI, AsyncOpenAI
from pydantic import BaseModel
from ..helpers import semaphore_gather from ..helpers import semaphore_gather
from ..llm_client import LLMConfig, RateLimitError from ..llm_client import LLMConfig, RateLimitError
@ -31,10 +31,6 @@ logger = logging.getLogger(__name__)
DEFAULT_MODEL = 'gpt-4.1-nano' DEFAULT_MODEL = 'gpt-4.1-nano'
class BooleanClassifier(BaseModel):
isTrue: bool
class OpenAIRerankerClient(CrossEncoderClient): class OpenAIRerankerClient(CrossEncoderClient):
def __init__( def __init__(
self, self,
@ -107,11 +103,15 @@ class OpenAIRerankerClient(CrossEncoderClient):
] ]
scores: list[float] = [] scores: list[float] = []
for top_logprobs in responses_top_logprobs: for top_logprobs in responses_top_logprobs:
for logprob in top_logprobs: if len(top_logprobs) == 0:
if bool(logprob.token): continue
scores.append(logprob.logprob) norm_logprobs = np.exp(top_logprobs[0].logprob)
if bool(top_logprobs[0].token):
scores.append(norm_logprobs)
else:
scores.append(1 - norm_logprobs)
results = [(passage, score) for passage, score in zip(passages, scores, strict=False)] results = [(passage, score) for passage, score in zip(passages, scores, strict=True)]
results.sort(reverse=True, key=lambda x: x[1]) results.sort(reverse=True, key=lambda x: x[1])
return results return results
except openai.RateLimitError as e: except openai.RateLimitError as e:

View file

@ -92,6 +92,7 @@ async def search(
center_node_uuid, center_node_uuid,
bfs_origin_node_uuids, bfs_origin_node_uuids,
config.limit, config.limit,
config.reranker_min_score,
), ),
node_search( node_search(
driver, driver,
@ -104,6 +105,7 @@ async def search(
center_node_uuid, center_node_uuid,
bfs_origin_node_uuids, bfs_origin_node_uuids,
config.limit, config.limit,
config.reranker_min_score,
), ),
community_search( community_search(
driver, driver,
@ -112,8 +114,8 @@ async def search(
query_vector, query_vector,
group_ids, group_ids,
config.community_config, config.community_config,
bfs_origin_node_uuids,
config.limit, config.limit,
config.reranker_min_score,
), ),
) )
@ -141,6 +143,7 @@ async def edge_search(
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None, bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT, limit=DEFAULT_SEARCH_LIMIT,
reranker_min_score: float = 0,
) -> list[EntityEdge]: ) -> list[EntityEdge]:
if config is None: if config is None:
return [] return []
@ -180,7 +183,7 @@ async def edge_search(
if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions: if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
search_result_uuids = [[edge.uuid for edge in result] for result in search_results] search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
reranked_uuids = rrf(search_result_uuids) reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
elif config.reranker == EdgeReranker.mmr: elif config.reranker == EdgeReranker.mmr:
search_result_uuids_and_vectors = [ search_result_uuids_and_vectors = [
(edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024) (edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024)
@ -188,23 +191,31 @@ async def edge_search(
for edge in result for edge in result
] ]
reranked_uuids = maximal_marginal_relevance( reranked_uuids = maximal_marginal_relevance(
query_vector, search_result_uuids_and_vectors, config.mmr_lambda query_vector,
search_result_uuids_and_vectors,
config.mmr_lambda,
min_score=reranker_min_score,
) )
elif config.reranker == EdgeReranker.cross_encoder: elif config.reranker == EdgeReranker.cross_encoder:
search_result_uuids = [[edge.uuid for edge in result] for result in search_results] search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
rrf_result_uuids = rrf(search_result_uuids) rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
rrf_edges = [edge_uuid_map[uuid] for uuid in rrf_result_uuids][:limit] rrf_edges = [edge_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
fact_to_uuid_map = {edge.fact: edge.uuid for edge in rrf_edges} fact_to_uuid_map = {edge.fact: edge.uuid for edge in rrf_edges}
reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys())) reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
reranked_uuids = [fact_to_uuid_map[fact] for fact, _ in reranked_facts] reranked_uuids = [
fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
]
elif config.reranker == EdgeReranker.node_distance: elif config.reranker == EdgeReranker.node_distance:
if center_node_uuid is None: if center_node_uuid is None:
raise SearchRerankerError('No center node provided for Node Distance reranker') raise SearchRerankerError('No center node provided for Node Distance reranker')
# use rrf as a preliminary sort # use rrf as a preliminary sort
sorted_result_uuids = rrf([[edge.uuid for edge in result] for result in search_results]) sorted_result_uuids = rrf(
[[edge.uuid for edge in result] for result in search_results],
min_score=reranker_min_score,
)
sorted_results = [edge_uuid_map[uuid] for uuid in sorted_result_uuids] sorted_results = [edge_uuid_map[uuid] for uuid in sorted_result_uuids]
# node distance reranking # node distance reranking
@ -214,7 +225,9 @@ async def edge_search(
source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map] source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map]
reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid) reranked_node_uuids = await node_distance_reranker(
driver, source_uuids, center_node_uuid, min_score=reranker_min_score
)
for node_uuid in reranked_node_uuids: for node_uuid in reranked_node_uuids:
reranked_uuids.extend(source_to_edge_uuid_map[node_uuid]) reranked_uuids.extend(source_to_edge_uuid_map[node_uuid])
@ -238,6 +251,7 @@ async def node_search(
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None, bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT, limit=DEFAULT_SEARCH_LIMIT,
reranker_min_score: float = 0,
) -> list[EntityNode]: ) -> list[EntityNode]:
if config is None: if config is None:
return [] return []
@ -269,7 +283,7 @@ async def node_search(
reranked_uuids: list[str] = [] reranked_uuids: list[str] = []
if config.reranker == NodeReranker.rrf: if config.reranker == NodeReranker.rrf:
reranked_uuids = rrf(search_result_uuids) reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
elif config.reranker == NodeReranker.mmr: elif config.reranker == NodeReranker.mmr:
search_result_uuids_and_vectors = [ search_result_uuids_and_vectors = [
(node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024) (node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024)
@ -277,24 +291,36 @@ async def node_search(
for node in result for node in result
] ]
reranked_uuids = maximal_marginal_relevance( reranked_uuids = maximal_marginal_relevance(
query_vector, search_result_uuids_and_vectors, config.mmr_lambda query_vector,
search_result_uuids_and_vectors,
config.mmr_lambda,
min_score=reranker_min_score,
) )
elif config.reranker == NodeReranker.cross_encoder: elif config.reranker == NodeReranker.cross_encoder:
# use rrf as a preliminary reranker # use rrf as a preliminary reranker
rrf_result_uuids = rrf(search_result_uuids) rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
rrf_results = [node_uuid_map[uuid] for uuid in rrf_result_uuids][:limit] rrf_results = [node_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
summary_to_uuid_map = {node.summary: node.uuid for node in rrf_results} summary_to_uuid_map = {node.summary: node.uuid for node in rrf_results}
reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys())) reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries] reranked_uuids = [
summary_to_uuid_map[fact]
for fact, score in reranked_summaries
if score >= reranker_min_score
]
elif config.reranker == NodeReranker.episode_mentions: elif config.reranker == NodeReranker.episode_mentions:
reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids) reranked_uuids = await episode_mentions_reranker(
driver, search_result_uuids, min_score=reranker_min_score
)
elif config.reranker == NodeReranker.node_distance: elif config.reranker == NodeReranker.node_distance:
if center_node_uuid is None: if center_node_uuid is None:
raise SearchRerankerError('No center node provided for Node Distance reranker') raise SearchRerankerError('No center node provided for Node Distance reranker')
reranked_uuids = await node_distance_reranker( reranked_uuids = await node_distance_reranker(
driver, rrf(search_result_uuids), center_node_uuid driver,
rrf(search_result_uuids, min_score=reranker_min_score),
center_node_uuid,
min_score=reranker_min_score,
) )
reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids] reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
@ -309,8 +335,8 @@ async def community_search(
query_vector: list[float], query_vector: list[float],
group_ids: list[str] | None, group_ids: list[str] | None,
config: CommunitySearchConfig | None, config: CommunitySearchConfig | None,
bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT, limit=DEFAULT_SEARCH_LIMIT,
reranker_min_score: float = 0,
) -> list[CommunityNode]: ) -> list[CommunityNode]:
if config is None: if config is None:
return [] return []
@ -333,7 +359,7 @@ async def community_search(
reranked_uuids: list[str] = [] reranked_uuids: list[str] = []
if config.reranker == CommunityReranker.rrf: if config.reranker == CommunityReranker.rrf:
reranked_uuids = rrf(search_result_uuids) reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
elif config.reranker == CommunityReranker.mmr: elif config.reranker == CommunityReranker.mmr:
search_result_uuids_and_vectors = [ search_result_uuids_and_vectors = [
( (
@ -344,14 +370,21 @@ async def community_search(
for community in result for community in result
] ]
reranked_uuids = maximal_marginal_relevance( reranked_uuids = maximal_marginal_relevance(
query_vector, search_result_uuids_and_vectors, config.mmr_lambda query_vector,
search_result_uuids_and_vectors,
config.mmr_lambda,
min_score=reranker_min_score,
) )
elif config.reranker == CommunityReranker.cross_encoder: elif config.reranker == CommunityReranker.cross_encoder:
summary_to_uuid_map = { summary_to_uuid_map = {
node.summary: node.uuid for result in search_results for node in result node.summary: node.uuid for result in search_results for node in result
} }
reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys())) reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries] reranked_uuids = [
summary_to_uuid_map[fact]
for fact, score in reranked_summaries
if score >= reranker_min_score
]
reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids] reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]

View file

@ -97,6 +97,7 @@ class SearchConfig(BaseModel):
node_config: NodeSearchConfig | None = Field(default=None) node_config: NodeSearchConfig | None = Field(default=None)
community_config: CommunitySearchConfig | None = Field(default=None) community_config: CommunitySearchConfig | None = Field(default=None)
limit: int = Field(default=DEFAULT_SEARCH_LIMIT) limit: int = Field(default=DEFAULT_SEARCH_LIMIT)
reranker_min_score: float = Field(default=0)
class SearchResults(BaseModel): class SearchResults(BaseModel):

View file

@ -229,8 +229,8 @@ async def edge_similarity_search(
query: LiteralString = ( query: LiteralString = (
""" """
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
""" """
+ group_filter_query + group_filter_query
+ filter_query + filter_query
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
@ -718,7 +718,7 @@ async def get_relevant_edges(
# takes in a list of rankings of uuids # takes in a list of rankings of uuids
def rrf(results: list[list[str]], rank_const=1) -> list[str]: def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[str]:
scores: dict[str, float] = defaultdict(float) scores: dict[str, float] = defaultdict(float)
for result in results: for result in results:
for i, uuid in enumerate(result): for i, uuid in enumerate(result):
@ -729,11 +729,14 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
sorted_uuids = [term[0] for term in scored_uuids] sorted_uuids = [term[0] for term in scored_uuids]
return sorted_uuids return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
async def node_distance_reranker( async def node_distance_reranker(
driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str driver: AsyncDriver,
node_uuids: list[str],
center_node_uuid: str,
min_score: float = 0,
) -> list[str]: ) -> list[str]:
# filter out node_uuid center node node uuid # filter out node_uuid center node node uuid
filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids)) filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
@ -767,12 +770,15 @@ async def node_distance_reranker(
# add back in filtered center uuid if it was filtered out # add back in filtered center uuid if it was filtered out
if center_node_uuid in node_uuids: if center_node_uuid in node_uuids:
scores[center_node_uuid] = 0.1
filtered_uuids = [center_node_uuid] + filtered_uuids filtered_uuids = [center_node_uuid] + filtered_uuids
return filtered_uuids return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score]
async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[str]]) -> list[str]: async def episode_mentions_reranker(
driver: AsyncDriver, node_uuids: list[list[str]], min_score: float = 0
) -> list[str]:
# use rrf as a preliminary ranker # use rrf as a preliminary ranker
sorted_uuids = rrf(node_uuids) sorted_uuids = rrf(node_uuids)
scores: dict[str, float] = {} scores: dict[str, float] = {}
@ -796,13 +802,14 @@ async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[s
# rerank on shortest distance # rerank on shortest distance
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid]) sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
return sorted_uuids return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
def maximal_marginal_relevance( def maximal_marginal_relevance(
query_vector: list[float], query_vector: list[float],
candidates: list[tuple[str, list[float]]], candidates: list[tuple[str, list[float]]],
mmr_lambda: float = DEFAULT_MMR_LAMBDA, mmr_lambda: float = DEFAULT_MMR_LAMBDA,
min_score: float = 0,
): ):
candidates_with_mmr: list[tuple[str, float]] = [] candidates_with_mmr: list[tuple[str, float]] = []
for candidate in candidates: for candidate in candidates:
@ -812,4 +819,6 @@ def maximal_marginal_relevance(
candidates_with_mmr.sort(reverse=True, key=lambda c: c[1]) candidates_with_mmr.sort(reverse=True, key=lambda c: c[1])
return list(set([candidate[0] for candidate in candidates_with_mmr])) return list(
set([candidate[0] for candidate in candidates_with_mmr if candidate[1] >= min_score])
)

View file

@ -66,7 +66,7 @@ async def test_graphiti_init():
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD) graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
results = await graphiti.search_( results = await graphiti.search_(
query="Who is Alice's friend?", query='Who is the user?',
) )
pretty_results = search_results_to_context_string(results) pretty_results = search_results_to_context_string(results)