add reranker_min_score (#355)
* add reranker_min_score
* update divide by 0 case
* center node always gets a score of .1
* linter

parent c8d5c45269
commit 11e19a35b7
5 changed files with 79 additions and 36 deletions
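
What the change does: SearchConfig gains a reranker_min_score field (default 0, which preserves the old behavior), and the value is threaded through edge, node, and community search into every reranker (RRF, MMR, cross-encoder, node distance, episode mentions) so that candidates scoring below the floor are dropped. A minimal usage sketch, not part of this commit; the import path is an assumption, and the search sub-configs still have to be populated as before:

    # Hypothetical usage of the new field; only limit and reranker_min_score are
    # taken from the SearchConfig hunk further down, the rest is illustrative.
    from graphiti_core.search.search_config import SearchConfig

    config = SearchConfig(
        limit=10,
        reranker_min_score=0.6,  # results the active reranker scores below 0.6 are dropped
    )
    # edge_config / node_config / community_config default to None and must still be
    # set for the corresponding searches to run (each searcher returns [] when its
    # config is None, as shown in the diff below).
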
@@ -17,9 +17,9 @@ limitations under the License.
 import logging
 from typing import Any
 
 import numpy as np
 import openai
 from openai import AsyncAzureOpenAI, AsyncOpenAI
 from pydantic import BaseModel
 
 from ..helpers import semaphore_gather
 from ..llm_client import LLMConfig, RateLimitError
 
@@ -31,10 +31,6 @@ logger = logging.getLogger(__name__)
 DEFAULT_MODEL = 'gpt-4.1-nano'
 
 
-class BooleanClassifier(BaseModel):
-    isTrue: bool
-
-
 class OpenAIRerankerClient(CrossEncoderClient):
     def __init__(
         self,
@@ -107,11 +103,15 @@ class OpenAIRerankerClient(CrossEncoderClient):
             ]
             scores: list[float] = []
             for top_logprobs in responses_top_logprobs:
-                for logprob in top_logprobs:
-                    if bool(logprob.token):
-                        scores.append(logprob.logprob)
+                if len(top_logprobs) == 0:
+                    continue
+                norm_logprobs = np.exp(top_logprobs[0].logprob)
+                if bool(top_logprobs[0].token):
+                    scores.append(norm_logprobs)
+                else:
+                    scores.append(1 - norm_logprobs)
 
-            results = [(passage, score) for passage, score in zip(passages, scores, strict=False)]
+            results = [(passage, score) for passage, score in zip(passages, scores, strict=True)]
             results.sort(reverse=True, key=lambda x: x[1])
             return results
         except openai.RateLimitError as e:
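
Note on the hunk above: rather than appending raw log-probabilities, the reranker now takes only the first top-logprob entry, converts it with np.exp into a probability, and appends either that probability or its complement, so every passage score lands on a 0-to-1 scale that a reranker_min_score cutoff can be compared against. A tiny standalone illustration with made-up values, not code from this repository:

    import numpy as np

    # Made-up top-token log-probabilities for two passages.
    top_logprobs = [-0.105, -2.303]

    # exp(logprob) is the model's probability for that token, so each value is in (0, 1].
    probs = [float(np.exp(lp)) for lp in top_logprobs]
    print(probs)  # roughly [0.90, 0.10], directly comparable to a reranker_min_score cutoff
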
@@ -92,6 +92,7 @@ async def search(
             center_node_uuid,
             bfs_origin_node_uuids,
             config.limit,
+            config.reranker_min_score,
         ),
         node_search(
             driver,
@@ -104,6 +105,7 @@ async def search(
             center_node_uuid,
             bfs_origin_node_uuids,
             config.limit,
+            config.reranker_min_score,
         ),
         community_search(
             driver,
@@ -112,8 +114,8 @@ async def search(
             query_vector,
             group_ids,
             config.community_config,
             bfs_origin_node_uuids,
             config.limit,
+            config.reranker_min_score,
         ),
     )
 
@@ -141,6 +143,7 @@ async def edge_search(
     center_node_uuid: str | None = None,
     bfs_origin_node_uuids: list[str] | None = None,
     limit=DEFAULT_SEARCH_LIMIT,
+    reranker_min_score: float = 0,
 ) -> list[EntityEdge]:
     if config is None:
         return []
@@ -180,7 +183,7 @@ async def edge_search(
     if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
         search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
 
-        reranked_uuids = rrf(search_result_uuids)
+        reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
     elif config.reranker == EdgeReranker.mmr:
         search_result_uuids_and_vectors = [
             (edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024)
@@ -188,23 +191,31 @@ async def edge_search(
             for edge in result
         ]
         reranked_uuids = maximal_marginal_relevance(
-            query_vector, search_result_uuids_and_vectors, config.mmr_lambda
+            query_vector,
+            search_result_uuids_and_vectors,
+            config.mmr_lambda,
+            min_score=reranker_min_score,
         )
     elif config.reranker == EdgeReranker.cross_encoder:
         search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
 
-        rrf_result_uuids = rrf(search_result_uuids)
+        rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
         rrf_edges = [edge_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
 
         fact_to_uuid_map = {edge.fact: edge.uuid for edge in rrf_edges}
         reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
-        reranked_uuids = [fact_to_uuid_map[fact] for fact, _ in reranked_facts]
+        reranked_uuids = [
+            fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
+        ]
     elif config.reranker == EdgeReranker.node_distance:
         if center_node_uuid is None:
             raise SearchRerankerError('No center node provided for Node Distance reranker')
 
         # use rrf as a preliminary sort
-        sorted_result_uuids = rrf([[edge.uuid for edge in result] for result in search_results])
+        sorted_result_uuids = rrf(
+            [[edge.uuid for edge in result] for result in search_results],
+            min_score=reranker_min_score,
+        )
         sorted_results = [edge_uuid_map[uuid] for uuid in sorted_result_uuids]
 
         # node distance reranking
@@ -214,7 +225,9 @@ async def edge_search(
 
         source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map]
 
-        reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid)
+        reranked_node_uuids = await node_distance_reranker(
+            driver, source_uuids, center_node_uuid, min_score=reranker_min_score
+        )
 
         for node_uuid in reranked_node_uuids:
             reranked_uuids.extend(source_to_edge_uuid_map[node_uuid])
@@ -238,6 +251,7 @@ async def node_search(
     center_node_uuid: str | None = None,
     bfs_origin_node_uuids: list[str] | None = None,
     limit=DEFAULT_SEARCH_LIMIT,
+    reranker_min_score: float = 0,
 ) -> list[EntityNode]:
     if config is None:
         return []
@@ -269,7 +283,7 @@ async def node_search(
 
     reranked_uuids: list[str] = []
     if config.reranker == NodeReranker.rrf:
-        reranked_uuids = rrf(search_result_uuids)
+        reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
     elif config.reranker == NodeReranker.mmr:
         search_result_uuids_and_vectors = [
             (node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024)
@@ -277,24 +291,36 @@ async def node_search(
             for node in result
         ]
         reranked_uuids = maximal_marginal_relevance(
-            query_vector, search_result_uuids_and_vectors, config.mmr_lambda
+            query_vector,
+            search_result_uuids_and_vectors,
+            config.mmr_lambda,
+            min_score=reranker_min_score,
        )
     elif config.reranker == NodeReranker.cross_encoder:
         # use rrf as a preliminary reranker
-        rrf_result_uuids = rrf(search_result_uuids)
+        rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
         rrf_results = [node_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
 
         summary_to_uuid_map = {node.summary: node.uuid for node in rrf_results}
 
         reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
-        reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries]
+        reranked_uuids = [
+            summary_to_uuid_map[fact]
+            for fact, score in reranked_summaries
+            if score >= reranker_min_score
+        ]
     elif config.reranker == NodeReranker.episode_mentions:
-        reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
+        reranked_uuids = await episode_mentions_reranker(
+            driver, search_result_uuids, min_score=reranker_min_score
+        )
     elif config.reranker == NodeReranker.node_distance:
         if center_node_uuid is None:
             raise SearchRerankerError('No center node provided for Node Distance reranker')
         reranked_uuids = await node_distance_reranker(
-            driver, rrf(search_result_uuids), center_node_uuid
+            driver,
+            rrf(search_result_uuids, min_score=reranker_min_score),
+            center_node_uuid,
+            min_score=reranker_min_score,
         )
 
     reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
@@ -309,8 +335,8 @@ async def community_search(
     query_vector: list[float],
     group_ids: list[str] | None,
     config: CommunitySearchConfig | None,
     bfs_origin_node_uuids: list[str] | None = None,
     limit=DEFAULT_SEARCH_LIMIT,
+    reranker_min_score: float = 0,
 ) -> list[CommunityNode]:
     if config is None:
         return []
@@ -333,7 +359,7 @@ async def community_search(
 
     reranked_uuids: list[str] = []
     if config.reranker == CommunityReranker.rrf:
-        reranked_uuids = rrf(search_result_uuids)
+        reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
     elif config.reranker == CommunityReranker.mmr:
         search_result_uuids_and_vectors = [
             (
@@ -344,14 +370,21 @@ async def community_search(
             for community in result
         ]
         reranked_uuids = maximal_marginal_relevance(
-            query_vector, search_result_uuids_and_vectors, config.mmr_lambda
+            query_vector,
+            search_result_uuids_and_vectors,
+            config.mmr_lambda,
+            min_score=reranker_min_score,
         )
     elif config.reranker == CommunityReranker.cross_encoder:
         summary_to_uuid_map = {
             node.summary: node.uuid for result in search_results for node in result
         }
         reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
-        reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries]
+        reranked_uuids = [
+            summary_to_uuid_map[fact]
+            for fact, score in reranked_summaries
+            if score >= reranker_min_score
+        ]
 
     reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
 
@@ -97,6 +97,7 @@ class SearchConfig(BaseModel):
     node_config: NodeSearchConfig | None = Field(default=None)
     community_config: CommunitySearchConfig | None = Field(default=None)
     limit: int = Field(default=DEFAULT_SEARCH_LIMIT)
+    reranker_min_score: float = Field(default=0)
 
 
 class SearchResults(BaseModel):
@@ -229,8 +229,8 @@ async def edge_similarity_search(
 
     query: LiteralString = (
-        """
-        MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
+        """
+        MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
         """
         + group_filter_query
         + filter_query
         + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
@@ -718,7 +718,7 @@ async def get_relevant_edges(
 
 
 # takes in a list of rankings of uuids
-def rrf(results: list[list[str]], rank_const=1) -> list[str]:
+def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[str]:
     scores: dict[str, float] = defaultdict(float)
     for result in results:
         for i, uuid in enumerate(result):
@@ -729,11 +729,14 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
 
     sorted_uuids = [term[0] for term in scored_uuids]
 
-    return sorted_uuids
+    return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
 
 
 async def node_distance_reranker(
-    driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str
+    driver: AsyncDriver,
+    node_uuids: list[str],
+    center_node_uuid: str,
+    min_score: float = 0,
 ) -> list[str]:
     # filter out node_uuid center node node uuid
     filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
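
With the two hunks above, rrf gains a min_score parameter and now filters its output by the accumulated reciprocal-rank score instead of returning every uuid. A small self-contained sketch of that behavior, assuming the usual reciprocal-rank accumulation scores[uuid] += 1 / (i + rank_const) for the part of the function body not shown in this diff:

    from collections import defaultdict


    # Sketch only: mirrors the rrf signature and the new return line from the diff above;
    # the accumulation and sort steps are assumed, not copied from the repository.
    def rrf_sketch(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[str]:
        scores: dict[str, float] = defaultdict(float)
        for result in results:
            for i, uuid in enumerate(result):
                scores[uuid] += 1 / (i + rank_const)

        scored_uuids = sorted(scores.items(), key=lambda term: term[1], reverse=True)
        sorted_uuids = [term[0] for term in scored_uuids]
        return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]


    # 'a' and 'b' each score 1/1 + 1/2 = 1.5, 'c' scores 1/3, so min_score=0.5 drops 'c'.
    print(rrf_sketch([['a', 'b', 'c'], ['b', 'a']], min_score=0.5))
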
@@ -767,12 +770,15 @@ async def node_distance_reranker(
 
     # add back in filtered center uuid if it was filtered out
     if center_node_uuid in node_uuids:
+        scores[center_node_uuid] = 0.1
         filtered_uuids = [center_node_uuid] + filtered_uuids
 
-    return filtered_uuids
+    return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score]
 
 
-async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[str]]) -> list[str]:
+async def episode_mentions_reranker(
+    driver: AsyncDriver, node_uuids: list[list[str]], min_score: float = 0
+) -> list[str]:
     # use rrf as a preliminary ranker
     sorted_uuids = rrf(node_uuids)
     scores: dict[str, float] = {}
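
This hunk is where the other two commit-message bullets land: the center node now always gets a score of 0.1, so the new 1 / scores[uuid] conversion in the return line can never divide by zero for it, and since 1 / 0.1 = 10 the center node clears any reasonable threshold. The stored scores appear to be graph distances, so 1 / score acts as a closeness value that min_score filters on. A short arithmetic illustration with made-up distances:

    # Hypothetical distances from the center node, keyed by uuid.
    scores = {'center': 0.1, 'near': 1.0, 'far': 4.0}

    min_score = 0.5
    kept = [uuid for uuid, distance in scores.items() if (1 / distance) >= min_score]
    print(kept)  # ['center', 'near']: 1/0.1 = 10 and 1/1 = 1 pass, 1/4 = 0.25 does not
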
@@ -796,13 +802,14 @@ async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[s
     # rerank on shortest distance
     sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
 
-    return sorted_uuids
+    return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
 
 
 def maximal_marginal_relevance(
     query_vector: list[float],
     candidates: list[tuple[str, list[float]]],
     mmr_lambda: float = DEFAULT_MMR_LAMBDA,
+    min_score: float = 0,
 ):
     candidates_with_mmr: list[tuple[str, float]] = []
     for candidate in candidates:
@@ -812,4 +819,6 @@ def maximal_marginal_relevance(
 
     candidates_with_mmr.sort(reverse=True, key=lambda c: c[1])
 
-    return list(set([candidate[0] for candidate in candidates_with_mmr]))
+    return list(
+        set([candidate[0] for candidate in candidates_with_mmr if candidate[1] >= min_score])
+    )

@@ -66,7 +66,7 @@ async def test_graphiti_init():
     graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
 
     results = await graphiti.search_(
-        query="Who is Alice's friend?",
+        query='Who is the user?',
     )
 
     pretty_results = search_results_to_context_string(results)