Return reranker scores (#758)

* add search reranker scores to search output

* bump version

* updates
This commit is contained in:
Preston Rasmussen 2025-07-23 16:05:48 -04:00 committed by GitHub
parent 266f3f396c
commit 17747ff58d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 88 additions and 62 deletions

View file

@ -959,7 +959,7 @@ class Graphiti:
nodes = await get_mentioned_nodes(self.driver, episodes) nodes = await get_mentioned_nodes(self.driver, episodes)
return SearchResults(edges=edges, nodes=nodes, episodes=[], communities=[]) return SearchResults(edges=edges, nodes=nodes)
async def add_triplet(self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode): async def add_triplet(self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode):
if source_node.name_embedding is None: if source_node.name_embedding is None:

View file

@ -80,12 +80,7 @@ async def search(
cross_encoder = clients.cross_encoder cross_encoder = clients.cross_encoder
if query.strip() == '': if query.strip() == '':
return SearchResults( return SearchResults()
edges=[],
nodes=[],
episodes=[],
communities=[],
)
query_vector = ( query_vector = (
query_vector query_vector
if query_vector is not None if query_vector is not None
@ -94,7 +89,12 @@ async def search(
# if group_ids is empty, set it to None # if group_ids is empty, set it to None
group_ids = group_ids if group_ids and group_ids != [''] else None group_ids = group_ids if group_ids and group_ids != [''] else None
edges, nodes, episodes, communities = await semaphore_gather( (
(edges, edge_reranker_scores),
(nodes, node_reranker_scores),
(episodes, episode_reranker_scores),
(communities, community_reranker_scores),
) = await semaphore_gather(
edge_search( edge_search(
driver, driver,
cross_encoder, cross_encoder,
@ -146,9 +146,13 @@ async def search(
results = SearchResults( results = SearchResults(
edges=edges, edges=edges,
edge_reranker_scores=edge_reranker_scores,
nodes=nodes, nodes=nodes,
node_reranker_scores=node_reranker_scores,
episodes=episodes, episodes=episodes,
episode_reranker_scores=episode_reranker_scores,
communities=communities, communities=communities,
community_reranker_scores=community_reranker_scores,
) )
latency = (time() - start) * 1000 latency = (time() - start) * 1000
@ -170,9 +174,9 @@ async def edge_search(
bfs_origin_node_uuids: list[str] | None = None, bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT, limit=DEFAULT_SEARCH_LIMIT,
reranker_min_score: float = 0, reranker_min_score: float = 0,
) -> list[EntityEdge]: ) -> tuple[list[EntityEdge], list[float]]:
if config is None: if config is None:
return [] return [], []
search_results: list[list[EntityEdge]] = list( search_results: list[list[EntityEdge]] = list(
await semaphore_gather( await semaphore_gather(
*[ *[
@ -215,15 +219,16 @@ async def edge_search(
edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result} edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
reranked_uuids: list[str] = [] reranked_uuids: list[str] = []
edge_scores: list[float] = []
if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions: if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
search_result_uuids = [[edge.uuid for edge in result] for result in search_results] search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score) reranked_uuids, edge_scores = rrf(search_result_uuids, min_score=reranker_min_score)
elif config.reranker == EdgeReranker.mmr: elif config.reranker == EdgeReranker.mmr:
search_result_uuids_and_vectors = await get_embeddings_for_edges( search_result_uuids_and_vectors = await get_embeddings_for_edges(
driver, list(edge_uuid_map.values()) driver, list(edge_uuid_map.values())
) )
reranked_uuids = maximal_marginal_relevance( reranked_uuids, edge_scores = maximal_marginal_relevance(
query_vector, query_vector,
search_result_uuids_and_vectors, search_result_uuids_and_vectors,
config.mmr_lambda, config.mmr_lambda,
@ -235,12 +240,13 @@ async def edge_search(
reranked_uuids = [ reranked_uuids = [
fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
] ]
edge_scores = [score for _, score in reranked_facts if score >= reranker_min_score]
elif config.reranker == EdgeReranker.node_distance: elif config.reranker == EdgeReranker.node_distance:
if center_node_uuid is None: if center_node_uuid is None:
raise SearchRerankerError('No center node provided for Node Distance reranker') raise SearchRerankerError('No center node provided for Node Distance reranker')
# use rrf as a preliminary sort # use rrf as a preliminary sort
sorted_result_uuids = rrf( sorted_result_uuids, node_scores = rrf(
[[edge.uuid for edge in result] for result in search_results], [[edge.uuid for edge in result] for result in search_results],
min_score=reranker_min_score, min_score=reranker_min_score,
) )
@ -253,7 +259,7 @@ async def edge_search(
source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map] source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map]
reranked_node_uuids = await node_distance_reranker( reranked_node_uuids, edge_scores = await node_distance_reranker(
driver, source_uuids, center_node_uuid, min_score=reranker_min_score driver, source_uuids, center_node_uuid, min_score=reranker_min_score
) )
@ -265,7 +271,7 @@ async def edge_search(
if config.reranker == EdgeReranker.episode_mentions: if config.reranker == EdgeReranker.episode_mentions:
reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes)) reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes))
return reranked_edges[:limit] return reranked_edges[:limit], edge_scores[:limit]
async def node_search( async def node_search(
@ -280,9 +286,9 @@ async def node_search(
bfs_origin_node_uuids: list[str] | None = None, bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT, limit=DEFAULT_SEARCH_LIMIT,
reranker_min_score: float = 0, reranker_min_score: float = 0,
) -> list[EntityNode]: ) -> tuple[list[EntityNode], list[float]]:
if config is None: if config is None:
return [] return [], []
search_results: list[list[EntityNode]] = list( search_results: list[list[EntityNode]] = list(
await semaphore_gather( await semaphore_gather(
*[ *[
@ -319,14 +325,15 @@ async def node_search(
node_uuid_map = {node.uuid: node for result in search_results for node in result} node_uuid_map = {node.uuid: node for result in search_results for node in result}
reranked_uuids: list[str] = [] reranked_uuids: list[str] = []
node_scores: list[float] = []
if config.reranker == NodeReranker.rrf: if config.reranker == NodeReranker.rrf:
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score) reranked_uuids, node_scores = rrf(search_result_uuids, min_score=reranker_min_score)
elif config.reranker == NodeReranker.mmr: elif config.reranker == NodeReranker.mmr:
search_result_uuids_and_vectors = await get_embeddings_for_nodes( search_result_uuids_and_vectors = await get_embeddings_for_nodes(
driver, list(node_uuid_map.values()) driver, list(node_uuid_map.values())
) )
reranked_uuids = maximal_marginal_relevance( reranked_uuids, node_scores = maximal_marginal_relevance(
query_vector, query_vector,
search_result_uuids_and_vectors, search_result_uuids_and_vectors,
config.mmr_lambda, config.mmr_lambda,
@ -341,23 +348,24 @@ async def node_search(
for name, score in reranked_node_names for name, score in reranked_node_names
if score >= reranker_min_score if score >= reranker_min_score
] ]
node_scores = [score for _, score in reranked_node_names if score >= reranker_min_score]
elif config.reranker == NodeReranker.episode_mentions: elif config.reranker == NodeReranker.episode_mentions:
reranked_uuids = await episode_mentions_reranker( reranked_uuids, node_scores = await episode_mentions_reranker(
driver, search_result_uuids, min_score=reranker_min_score driver, search_result_uuids, min_score=reranker_min_score
) )
elif config.reranker == NodeReranker.node_distance: elif config.reranker == NodeReranker.node_distance:
if center_node_uuid is None: if center_node_uuid is None:
raise SearchRerankerError('No center node provided for Node Distance reranker') raise SearchRerankerError('No center node provided for Node Distance reranker')
reranked_uuids = await node_distance_reranker( reranked_uuids, node_scores = await node_distance_reranker(
driver, driver,
rrf(search_result_uuids, min_score=reranker_min_score), rrf(search_result_uuids, min_score=reranker_min_score)[0],
center_node_uuid, center_node_uuid,
min_score=reranker_min_score, min_score=reranker_min_score,
) )
reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids] reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
return reranked_nodes[:limit] return reranked_nodes[:limit], node_scores[:limit]
async def episode_search( async def episode_search(
@ -370,9 +378,9 @@ async def episode_search(
search_filter: SearchFilters, search_filter: SearchFilters,
limit=DEFAULT_SEARCH_LIMIT, limit=DEFAULT_SEARCH_LIMIT,
reranker_min_score: float = 0, reranker_min_score: float = 0,
) -> list[EpisodicNode]: ) -> tuple[list[EpisodicNode], list[float]]:
if config is None: if config is None:
return [] return [], []
search_results: list[list[EpisodicNode]] = list( search_results: list[list[EpisodicNode]] = list(
await semaphore_gather( await semaphore_gather(
*[ *[
@ -385,12 +393,13 @@ async def episode_search(
episode_uuid_map = {episode.uuid: episode for result in search_results for episode in result} episode_uuid_map = {episode.uuid: episode for result in search_results for episode in result}
reranked_uuids: list[str] = [] reranked_uuids: list[str] = []
episode_scores: list[float] = []
if config.reranker == EpisodeReranker.rrf: if config.reranker == EpisodeReranker.rrf:
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score) reranked_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
elif config.reranker == EpisodeReranker.cross_encoder: elif config.reranker == EpisodeReranker.cross_encoder:
# use rrf as a preliminary reranker # use rrf as a preliminary reranker
rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score) rrf_result_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
rrf_results = [episode_uuid_map[uuid] for uuid in rrf_result_uuids][:limit] rrf_results = [episode_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
content_to_uuid_map = {episode.content: episode.uuid for episode in rrf_results} content_to_uuid_map = {episode.content: episode.uuid for episode in rrf_results}
@ -401,10 +410,11 @@ async def episode_search(
for content, score in reranked_contents for content, score in reranked_contents
if score >= reranker_min_score if score >= reranker_min_score
] ]
episode_scores = [score for _, score in reranked_contents if score >= reranker_min_score]
reranked_episodes = [episode_uuid_map[uuid] for uuid in reranked_uuids] reranked_episodes = [episode_uuid_map[uuid] for uuid in reranked_uuids]
return reranked_episodes[:limit] return reranked_episodes[:limit], episode_scores[:limit]
async def community_search( async def community_search(
@ -416,9 +426,9 @@ async def community_search(
config: CommunitySearchConfig | None, config: CommunitySearchConfig | None,
limit=DEFAULT_SEARCH_LIMIT, limit=DEFAULT_SEARCH_LIMIT,
reranker_min_score: float = 0, reranker_min_score: float = 0,
) -> list[CommunityNode]: ) -> tuple[list[CommunityNode], list[float]]:
if config is None: if config is None:
return [] return [], []
search_results: list[list[CommunityNode]] = list( search_results: list[list[CommunityNode]] = list(
await semaphore_gather( await semaphore_gather(
@ -437,14 +447,15 @@ async def community_search(
} }
reranked_uuids: list[str] = [] reranked_uuids: list[str] = []
community_scores: list[float] = []
if config.reranker == CommunityReranker.rrf: if config.reranker == CommunityReranker.rrf:
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score) reranked_uuids, community_scores = rrf(search_result_uuids, min_score=reranker_min_score)
elif config.reranker == CommunityReranker.mmr: elif config.reranker == CommunityReranker.mmr:
search_result_uuids_and_vectors = await get_embeddings_for_communities( search_result_uuids_and_vectors = await get_embeddings_for_communities(
driver, list(community_uuid_map.values()) driver, list(community_uuid_map.values())
) )
reranked_uuids = maximal_marginal_relevance( reranked_uuids, community_scores = maximal_marginal_relevance(
query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score
) )
elif config.reranker == CommunityReranker.cross_encoder: elif config.reranker == CommunityReranker.cross_encoder:
@ -453,7 +464,8 @@ async def community_search(
reranked_uuids = [ reranked_uuids = [
name_to_uuid_map[name] for name, score in reranked_nodes if score >= reranker_min_score name_to_uuid_map[name] for name, score in reranked_nodes if score >= reranker_min_score
] ]
community_scores = [score for _, score in reranked_nodes if score >= reranker_min_score]
reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids] reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
return reranked_communities[:limit] return reranked_communities[:limit], community_scores[:limit]

View file

@ -119,7 +119,11 @@ class SearchConfig(BaseModel):
class SearchResults(BaseModel): class SearchResults(BaseModel):
edges: list[EntityEdge] edges: list[EntityEdge] = Field(default_factory=list)
nodes: list[EntityNode] edge_reranker_scores: list[float] = Field(default_factory=list)
episodes: list[EpisodicNode] nodes: list[EntityNode] = Field(default_factory=list)
communities: list[CommunityNode] node_reranker_scores: list[float] = Field(default_factory=list)
episodes: list[EpisodicNode] = Field(default_factory=list)
episode_reranker_scores: list[float] = Field(default_factory=list)
communities: list[CommunityNode] = Field(default_factory=list)
community_reranker_scores: list[float] = Field(default_factory=list)

View file

@ -294,13 +294,13 @@ async def edge_bfs_search(
query = ( query = (
""" """
UNWIND $bfs_origin_node_uuids AS origin_uuid UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
UNWIND relationships(path) AS rel UNWIND relationships(path) AS rel
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity) MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
WHERE r.uuid = rel.uuid WHERE r.uuid = rel.uuid
AND r.group_id IN $group_ids AND r.group_id IN $group_ids
""" """
+ filter_query + filter_query
+ """ + """
RETURN DISTINCT RETURN DISTINCT
@ -445,11 +445,11 @@ async def node_bfs_search(
query = ( query = (
""" """
UNWIND $bfs_origin_node_uuids AS origin_uuid UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
WHERE n.group_id = origin.group_id WHERE n.group_id = origin.group_id
AND origin.group_id IN $group_ids AND origin.group_id IN $group_ids
""" """
+ filter_query + filter_query
+ ENTITY_NODE_RETURN + ENTITY_NODE_RETURN
+ """ + """
@ -672,7 +672,7 @@ async def hybrid_node_search(
} }
result_uuids = [[node.uuid for node in result] for result in results] result_uuids = [[node.uuid for node in result] for result in results]
ranked_uuids = rrf(result_uuids) ranked_uuids, _ = rrf(result_uuids)
relevant_nodes: list[EntityNode] = [node_uuid_map[uuid] for uuid in ranked_uuids] relevant_nodes: list[EntityNode] = [node_uuid_map[uuid] for uuid in ranked_uuids]
@ -914,7 +914,9 @@ async def get_edge_invalidation_candidates(
# takes in a list of rankings of uuids # takes in a list of rankings of uuids
def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[str]: def rrf(
results: list[list[str]], rank_const=1, min_score: float = 0
) -> tuple[list[str], list[float]]:
scores: dict[str, float] = defaultdict(float) scores: dict[str, float] = defaultdict(float)
for result in results: for result in results:
for i, uuid in enumerate(result): for i, uuid in enumerate(result):
@ -925,7 +927,9 @@ def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[st
sorted_uuids = [term[0] for term in scored_uuids] sorted_uuids = [term[0] for term in scored_uuids]
return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score] return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [
scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score
]
async def node_distance_reranker( async def node_distance_reranker(
@ -933,7 +937,7 @@ async def node_distance_reranker(
node_uuids: list[str], node_uuids: list[str],
center_node_uuid: str, center_node_uuid: str,
min_score: float = 0, min_score: float = 0,
) -> list[str]: ) -> tuple[list[str], list[float]]:
# filter the center node's uuid out of node_uuids # filter the center node's uuid out of node_uuids
filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids)) filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
scores: dict[str, float] = {center_node_uuid: 0.0} scores: dict[str, float] = {center_node_uuid: 0.0}
@ -970,14 +974,16 @@ async def node_distance_reranker(
scores[center_node_uuid] = 0.1 scores[center_node_uuid] = 0.1
filtered_uuids = [center_node_uuid] + filtered_uuids filtered_uuids = [center_node_uuid] + filtered_uuids
return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score] return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score], [
1 / scores[uuid] for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score
]
async def episode_mentions_reranker( async def episode_mentions_reranker(
driver: GraphDriver, node_uuids: list[list[str]], min_score: float = 0 driver: GraphDriver, node_uuids: list[list[str]], min_score: float = 0
) -> list[str]: ) -> tuple[list[str], list[float]]:
# use rrf as a preliminary ranker # use rrf as a preliminary ranker
sorted_uuids = rrf(node_uuids) sorted_uuids, _ = rrf(node_uuids)
scores: dict[str, float] = {} scores: dict[str, float] = {}
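# Compute a score for each node (presumably its episode-mention count; this reranker takes no center node - confirm) # Compute a score for each node (presumably its episode-mention count; this reranker takes no center node - confirm)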
@ -998,7 +1004,9 @@ async def episode_mentions_reranker(
# rerank by ascending score # rerank by ascending score
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid]) sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score] return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [
scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score
]
def maximal_marginal_relevance( def maximal_marginal_relevance(
@ -1006,7 +1014,7 @@ def maximal_marginal_relevance(
candidates: dict[str, list[float]], candidates: dict[str, list[float]],
mmr_lambda: float = DEFAULT_MMR_LAMBDA, mmr_lambda: float = DEFAULT_MMR_LAMBDA,
min_score: float = -2.0, min_score: float = -2.0,
) -> list[str]: ) -> tuple[list[str], list[float]]:
start = time() start = time()
query_array = np.array(query_vector) query_array = np.array(query_vector)
candidate_arrays: dict[str, NDArray] = {} candidate_arrays: dict[str, NDArray] = {}
@ -1037,7 +1045,9 @@ def maximal_marginal_relevance(
end = time() end = time()
logger.debug(f'Completed MMR reranking in {(end - start) * 1000} ms') logger.debug(f'Completed MMR reranking in {(end - start) * 1000} ms')
return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score] return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score], [
mmr_scores[uuid] for uuid in uuids if mmr_scores[uuid] >= min_score
]
async def get_embeddings_for_nodes( async def get_embeddings_for_nodes(

View file

@ -1,7 +1,7 @@
[project] [project]
name = "graphiti-core" name = "graphiti-core"
description = "A temporal graph building library" description = "A temporal graph building library"
version = "0.17.11" version = "0.18.0"
authors = [ authors = [
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" }, { "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" }, { "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },

View file

@ -71,7 +71,7 @@ async def test_graphiti_init():
) )
results = await graphiti.search_( results = await graphiti.search_(
query='What is the hall of portrait?', query='Who is Tania',
search_filter=search_filter, search_filter=search_filter,
) )

2
uv.lock generated
View file

@ -746,7 +746,7 @@ wheels = [
[[package]] [[package]]
name = "graphiti-core" name = "graphiti-core"
version = "0.17.10" version = "0.18.0"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "diskcache" }, { name = "diskcache" },