add reranker_min_score (#355)

* add reranker_min_score

* update divide by 0 case

* center node always gets a score of .1

* linter
Preston Rasmussen, 2025-04-15 12:33:37 -04:00, committed by GitHub
parent c8d5c45269 · commit 11e19a35b7
5 changed files with 79 additions and 36 deletions
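For callers, the new knob lives on SearchConfig and threads through to every reranker touched below. A minimal usage sketch (the 0.5 threshold is hypothetical, and it assumes `search_` accepts a `config` keyword as in the library's hybrid search API):

```python
from graphiti_core import Graphiti
from graphiti_core.search.search_config import SearchConfig

async def search_with_threshold(graphiti: Graphiti):
    # Hypothetical threshold: drop candidates the reranker scores below 0.5.
    # The default of 0 keeps prior behavior (nothing is filtered out).
    config = SearchConfig(reranker_min_score=0.5)
    return await graphiti.search_(query='Who is the user?', config=config)
```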

graphiti_core/cross_encoder/openai_reranker_client.py

@@ -17,9 +17,9 @@ limitations under the License.
 import logging
 from typing import Any
 
+import numpy as np
 import openai
 from openai import AsyncAzureOpenAI, AsyncOpenAI
-from pydantic import BaseModel
 
 from ..helpers import semaphore_gather
 from ..llm_client import LLMConfig, RateLimitError
@@ -31,10 +31,6 @@ logger = logging.getLogger(__name__)
 DEFAULT_MODEL = 'gpt-4.1-nano'
 
 
-class BooleanClassifier(BaseModel):
-    isTrue: bool
-
-
 class OpenAIRerankerClient(CrossEncoderClient):
     def __init__(
         self,
@@ -107,11 +103,15 @@ class OpenAIRerankerClient(CrossEncoderClient):
            ]
            scores: list[float] = []
            for top_logprobs in responses_top_logprobs:
-                for logprob in top_logprobs:
-                    if bool(logprob.token):
-                        scores.append(logprob.logprob)
+                if len(top_logprobs) == 0:
+                    continue
+                norm_logprobs = np.exp(top_logprobs[0].logprob)
+                if bool(top_logprobs[0].token):
+                    scores.append(norm_logprobs)
+                else:
+                    scores.append(1 - norm_logprobs)
 
-            results = [(passage, score) for passage, score in zip(passages, scores, strict=False)]
+            results = [(passage, score) for passage, score in zip(passages, scores, strict=True)]
            results.sort(reverse=True, key=lambda x: x[1])
            return results
        except openai.RateLimitError as e:
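The rewritten loop looks only at the first (top) logprob and exponentiates it, so scores land in (0, 1] and are directly comparable to a reranker_min_score threshold; when the top token is falsy, the complement is used. A standalone sketch of that scoring rule (the token/logprob pair mimics the OpenAI top_logprobs shape):

```python
import numpy as np

def score_from_top_logprob(token: str, logprob: float) -> float:
    # exp(logprob) recovers the probability the model assigned to the token
    # it emitted; logprob <= 0, so the result lands in (0, 1].
    prob = float(np.exp(logprob))
    # Mirror the committed rule: truthy token -> prob, falsy -> complement.
    return prob if bool(token) else 1 - prob

# A model ~90% confident in a positive answer: exp(-0.105) ≈ 0.9
assert abs(score_from_top_logprob('True', -0.105) - 0.9) < 0.01
```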

graphiti_core/search/search.py

@@ -92,6 +92,7 @@ async def search(
                center_node_uuid,
                bfs_origin_node_uuids,
                config.limit,
+                config.reranker_min_score,
            ),
            node_search(
                driver,
@@ -104,6 +105,7 @@ async def search(
                center_node_uuid,
                bfs_origin_node_uuids,
                config.limit,
+                config.reranker_min_score,
            ),
            community_search(
                driver,
@@ -112,8 +114,8 @@ async def search(
                query_vector,
                group_ids,
                config.community_config,
-                bfs_origin_node_uuids,
                config.limit,
+                config.reranker_min_score,
            ),
        )
@@ -141,6 +143,7 @@ async def edge_search(
    center_node_uuid: str | None = None,
    bfs_origin_node_uuids: list[str] | None = None,
    limit=DEFAULT_SEARCH_LIMIT,
+    reranker_min_score: float = 0,
) -> list[EntityEdge]:
    if config is None:
        return []
@@ -180,7 +183,7 @@ async def edge_search(
    if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
        search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
 
-        reranked_uuids = rrf(search_result_uuids)
+        reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
    elif config.reranker == EdgeReranker.mmr:
        search_result_uuids_and_vectors = [
            (edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024)
@@ -188,23 +191,31 @@ async def edge_search(
            for edge in result
        ]
 
        reranked_uuids = maximal_marginal_relevance(
-            query_vector, search_result_uuids_and_vectors, config.mmr_lambda
+            query_vector,
+            search_result_uuids_and_vectors,
+            config.mmr_lambda,
+            min_score=reranker_min_score,
        )
    elif config.reranker == EdgeReranker.cross_encoder:
        search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
-        rrf_result_uuids = rrf(search_result_uuids)
+        rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
        rrf_edges = [edge_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
 
        fact_to_uuid_map = {edge.fact: edge.uuid for edge in rrf_edges}
        reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
-        reranked_uuids = [fact_to_uuid_map[fact] for fact, _ in reranked_facts]
+        reranked_uuids = [
+            fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
+        ]
    elif config.reranker == EdgeReranker.node_distance:
        if center_node_uuid is None:
            raise SearchRerankerError('No center node provided for Node Distance reranker')
 
        # use rrf as a preliminary sort
-        sorted_result_uuids = rrf([[edge.uuid for edge in result] for result in search_results])
+        sorted_result_uuids = rrf(
+            [[edge.uuid for edge in result] for result in search_results],
+            min_score=reranker_min_score,
+        )
        sorted_results = [edge_uuid_map[uuid] for uuid in sorted_result_uuids]
 
        # node distance reranking
@@ -214,7 +225,9 @@ async def edge_search(
        source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map]
-        reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid)
+        reranked_node_uuids = await node_distance_reranker(
+            driver, source_uuids, center_node_uuid, min_score=reranker_min_score
+        )
 
        for node_uuid in reranked_node_uuids:
            reranked_uuids.extend(source_to_edge_uuid_map[node_uuid])
@@ -238,6 +251,7 @@ async def node_search(
    center_node_uuid: str | None = None,
    bfs_origin_node_uuids: list[str] | None = None,
    limit=DEFAULT_SEARCH_LIMIT,
+    reranker_min_score: float = 0,
) -> list[EntityNode]:
    if config is None:
        return []
@@ -269,7 +283,7 @@ async def node_search(
    reranked_uuids: list[str] = []
    if config.reranker == NodeReranker.rrf:
-        reranked_uuids = rrf(search_result_uuids)
+        reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
    elif config.reranker == NodeReranker.mmr:
        search_result_uuids_and_vectors = [
            (node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024)
@@ -277,24 +291,36 @@ async def node_search(
            for node in result
        ]
 
        reranked_uuids = maximal_marginal_relevance(
-            query_vector, search_result_uuids_and_vectors, config.mmr_lambda
+            query_vector,
+            search_result_uuids_and_vectors,
+            config.mmr_lambda,
+            min_score=reranker_min_score,
        )
    elif config.reranker == NodeReranker.cross_encoder:
        # use rrf as a preliminary reranker
-        rrf_result_uuids = rrf(search_result_uuids)
+        rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
        rrf_results = [node_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
 
        summary_to_uuid_map = {node.summary: node.uuid for node in rrf_results}
        reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
-        reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries]
+        reranked_uuids = [
+            summary_to_uuid_map[fact]
+            for fact, score in reranked_summaries
+            if score >= reranker_min_score
+        ]
    elif config.reranker == NodeReranker.episode_mentions:
-        reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
+        reranked_uuids = await episode_mentions_reranker(
+            driver, search_result_uuids, min_score=reranker_min_score
+        )
    elif config.reranker == NodeReranker.node_distance:
        if center_node_uuid is None:
            raise SearchRerankerError('No center node provided for Node Distance reranker')
        reranked_uuids = await node_distance_reranker(
-            driver, rrf(search_result_uuids), center_node_uuid
+            driver,
+            rrf(search_result_uuids, min_score=reranker_min_score),
+            center_node_uuid,
+            min_score=reranker_min_score,
        )
 
    reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
@@ -309,8 +335,8 @@ async def community_search(
    query_vector: list[float],
    group_ids: list[str] | None,
    config: CommunitySearchConfig | None,
-    bfs_origin_node_uuids: list[str] | None = None,
    limit=DEFAULT_SEARCH_LIMIT,
+    reranker_min_score: float = 0,
) -> list[CommunityNode]:
    if config is None:
        return []
@@ -333,7 +359,7 @@ async def community_search(
    reranked_uuids: list[str] = []
    if config.reranker == CommunityReranker.rrf:
-        reranked_uuids = rrf(search_result_uuids)
+        reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
    elif config.reranker == CommunityReranker.mmr:
        search_result_uuids_and_vectors = [
            (
@@ -344,14 +370,21 @@ async def community_search(
            for community in result
        ]
 
        reranked_uuids = maximal_marginal_relevance(
-            query_vector, search_result_uuids_and_vectors, config.mmr_lambda
+            query_vector,
+            search_result_uuids_and_vectors,
+            config.mmr_lambda,
+            min_score=reranker_min_score,
        )
    elif config.reranker == CommunityReranker.cross_encoder:
        summary_to_uuid_map = {
            node.summary: node.uuid for result in search_results for node in result
        }
        reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
-        reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries]
+        reranked_uuids = [
+            summary_to_uuid_map[fact]
+            for fact, score in reranked_summaries
+            if score >= reranker_min_score
+        ]
 
    reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
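The cross-encoder branches all follow the same pattern: rank returns (text, score) pairs best-first, and uuids are kept only when the score clears the threshold. A small illustration with hypothetical facts, scores, and uuids:

```python
# Hypothetical cross-encoder output: (fact, score) pairs, best first.
reranked_facts = [('alice knows bob', 0.92), ('bob owns a car', 0.40)]
fact_to_uuid_map = {'alice knows bob': 'uuid-1', 'bob owns a car': 'uuid-2'}

reranker_min_score = 0.5
reranked_uuids = [
    fact_to_uuid_map[fact]
    for fact, score in reranked_facts
    if score >= reranker_min_score
]
assert reranked_uuids == ['uuid-1']
```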

graphiti_core/search/search_config.py

@@ -97,6 +97,7 @@ class SearchConfig(BaseModel):
    node_config: NodeSearchConfig | None = Field(default=None)
    community_config: CommunitySearchConfig | None = Field(default=None)
    limit: int = Field(default=DEFAULT_SEARCH_LIMIT)
+    reranker_min_score: float = Field(default=0)
 
 
class SearchResults(BaseModel):

graphiti_core/search/search_utils.py

@@ -229,8 +229,8 @@ async def edge_similarity_search(
    query: LiteralString = (
-        """
-            MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
-            """
+        """
+        MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
+        """
        + group_filter_query
        + filter_query
        + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
@@ -718,7 +718,7 @@ async def get_relevant_edges(
# takes in a list of rankings of uuids
-def rrf(results: list[list[str]], rank_const=1) -> list[str]:
+def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[str]:
    scores: dict[str, float] = defaultdict(float)
    for result in results:
        for i, uuid in enumerate(result):
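Since rrf scores are reciprocal-rank sums rather than probabilities, min_score here thresholds the fused score: an item ranked first in one list earns 1 / (0 + rank_const) = 1.0, first in two lists earns 2.0, and so on. A self-contained sketch of the fuse-then-filter behavior, assuming the usual reciprocal-rank accumulation:

```python
from collections import defaultdict

def rrf_sketch(results: list[list[str]], rank_const: int = 1, min_score: float = 0) -> list[str]:
    # Accumulate reciprocal-rank scores across all result lists.
    scores: dict[str, float] = defaultdict(float)
    for result in results:
        for i, uuid in enumerate(result):
            scores[uuid] += 1 / (i + rank_const)
    sorted_uuids = sorted(scores, key=lambda uuid: scores[uuid], reverse=True)
    # New in this commit: drop uuids whose fused score is below min_score.
    return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]

# 'a' tops both lists (1.0 + 1.0 = 2.0); 'b' is second in one list (0.5).
assert rrf_sketch([['a', 'b'], ['a']], min_score=1.0) == ['a']
```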
@@ -729,11 +729,14 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
    sorted_uuids = [term[0] for term in scored_uuids]
 
-    return sorted_uuids
+    return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
 
 
async def node_distance_reranker(
-    driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str
+    driver: AsyncDriver,
+    node_uuids: list[str],
+    center_node_uuid: str,
+    min_score: float = 0,
) -> list[str]:
    # filter out node_uuid center node node uuid
    filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
@@ -767,12 +770,15 @@ async def node_distance_reranker(
    # add back in filtered center uuid if it was filtered out
    if center_node_uuid in node_uuids:
+        scores[center_node_uuid] = 0.1
        filtered_uuids = [center_node_uuid] + filtered_uuids
 
-    return filtered_uuids
+    return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score]
 
 
-async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[str]]) -> list[str]:
+async def episode_mentions_reranker(
+    driver: AsyncDriver, node_uuids: list[list[str]], min_score: float = 0
+) -> list[str]:
    # use rrf as a preliminary ranker
    sorted_uuids = rrf(node_uuids)
 
    scores: dict[str, float] = {}
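The 0.1 pin is the "divide by 0" fix from the commit message: the center node's graph distance to itself is 0, and the new return statement scores by inverse distance, so a zero distance would divide by zero. Pinning it to 0.1 both avoids the division and guarantees the center node a score of 1 / 0.1 = 10, high enough to survive any typical threshold. An illustration with hypothetical distances:

```python
# Hypothetical graph distances; the center node is pinned to 0.1, never 0.
distances = {'center': 0.1, 'near': 1.0, 'far': 4.0}

min_score = 0.5
# Score by inverse distance, as in the updated return statement:
# 1/0.1 = 10.0, 1/1.0 = 1.0, 1/4.0 = 0.25 -> 'far' is filtered out.
kept = [uuid for uuid, dist in distances.items() if (1 / dist) >= min_score]
assert kept == ['center', 'near']
```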
@@ -796,13 +802,14 @@ async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[str]]) -> list[str]:
    # rerank on shortest distance
    sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
 
-    return sorted_uuids
+    return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
 
 
def maximal_marginal_relevance(
    query_vector: list[float],
    candidates: list[tuple[str, list[float]]],
    mmr_lambda: float = DEFAULT_MMR_LAMBDA,
+    min_score: float = 0,
):
    candidates_with_mmr: list[tuple[str, float]] = []
    for candidate in candidates:
@@ -812,4 +819,6 @@ def maximal_marginal_relevance(
    candidates_with_mmr.sort(reverse=True, key=lambda c: c[1])
 
-    return list(set([candidate[0] for candidate in candidates_with_mmr]))
+    return list(
+        set([candidate[0] for candidate in candidates_with_mmr if candidate[1] >= min_score])
+    )
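For MMR, the threshold applies to the marginal-relevance score itself: relevance to the query traded off against redundancy with the other candidates via mmr_lambda. The hunk above only shows the final sort and filter, so the scoring below is one common MMR formulation, not necessarily the library's exact one; dot-product similarities on normalized embeddings are assumed:

```python
import numpy as np

def mmr_sketch(
    query_vector: list[float],
    candidates: list[tuple[str, list[float]]],
    mmr_lambda: float = 0.5,
    min_score: float = 0,
) -> list[str]:
    uuids = [uuid for uuid, _ in candidates]
    vectors = np.array([vector for _, vector in candidates], dtype=float)
    query = np.array(query_vector, dtype=float)

    relevance = vectors @ query          # similarity to the query
    pairwise = vectors @ vectors.T       # similarity between candidates
    np.fill_diagonal(pairwise, -np.inf)  # ignore self-similarity
    redundancy = pairwise.max(axis=1)    # worst-case overlap per candidate

    mmr = mmr_lambda * relevance - (1 - mmr_lambda) * redundancy
    ranked = sorted(zip(uuids, mmr), key=lambda pair: pair[1], reverse=True)
    # New in this commit: scores below min_score are dropped.
    return [uuid for uuid, score in ranked if score >= min_score]
```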

tests/test_graphiti_int.py

@@ -66,7 +66,7 @@ async def test_graphiti_init():
    graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
 
    results = await graphiti.search_(
-        query="Who is Alice's friend?",
+        query='Who is the user?',
    )
 
    pretty_results = search_results_to_context_string(results)