Return reranker scores (#758)

* add search reranker scores to search output

* bump version

* updates
This commit is contained in:
Preston Rasmussen 2025-07-23 16:05:48 -04:00 committed by GitHub
parent 266f3f396c
commit 17747ff58d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 88 additions and 62 deletions

View file

@ -959,7 +959,7 @@ class Graphiti:
nodes = await get_mentioned_nodes(self.driver, episodes)
return SearchResults(edges=edges, nodes=nodes, episodes=[], communities=[])
return SearchResults(edges=edges, nodes=nodes)
async def add_triplet(self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode):
if source_node.name_embedding is None:

View file

@ -80,12 +80,7 @@ async def search(
cross_encoder = clients.cross_encoder
if query.strip() == '':
return SearchResults(
edges=[],
nodes=[],
episodes=[],
communities=[],
)
return SearchResults()
query_vector = (
query_vector
if query_vector is not None
@ -94,7 +89,12 @@ async def search(
# if group_ids is empty, set it to None
group_ids = group_ids if group_ids and group_ids != [''] else None
edges, nodes, episodes, communities = await semaphore_gather(
(
(edges, edge_reranker_scores),
(nodes, node_reranker_scores),
(episodes, episode_reranker_scores),
(communities, community_reranker_scores),
) = await semaphore_gather(
edge_search(
driver,
cross_encoder,
@ -146,9 +146,13 @@ async def search(
results = SearchResults(
edges=edges,
edge_reranker_scores=edge_reranker_scores,
nodes=nodes,
node_reranker_scores=node_reranker_scores,
episodes=episodes,
episode_reranker_scores=episode_reranker_scores,
communities=communities,
community_reranker_scores=community_reranker_scores,
)
latency = (time() - start) * 1000
@ -170,9 +174,9 @@ async def edge_search(
bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT,
reranker_min_score: float = 0,
) -> list[EntityEdge]:
) -> tuple[list[EntityEdge], list[float]]:
if config is None:
return []
return [], []
search_results: list[list[EntityEdge]] = list(
await semaphore_gather(
*[
@ -215,15 +219,16 @@ async def edge_search(
edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
reranked_uuids: list[str] = []
edge_scores: list[float] = []
if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
reranked_uuids, edge_scores = rrf(search_result_uuids, min_score=reranker_min_score)
elif config.reranker == EdgeReranker.mmr:
search_result_uuids_and_vectors = await get_embeddings_for_edges(
driver, list(edge_uuid_map.values())
)
reranked_uuids = maximal_marginal_relevance(
reranked_uuids, edge_scores = maximal_marginal_relevance(
query_vector,
search_result_uuids_and_vectors,
config.mmr_lambda,
@ -235,12 +240,13 @@ async def edge_search(
reranked_uuids = [
fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
]
edge_scores = [score for _, score in reranked_facts if score >= reranker_min_score]
elif config.reranker == EdgeReranker.node_distance:
if center_node_uuid is None:
raise SearchRerankerError('No center node provided for Node Distance reranker')
# use rrf as a preliminary sort
sorted_result_uuids = rrf(
sorted_result_uuids, node_scores = rrf(
[[edge.uuid for edge in result] for result in search_results],
min_score=reranker_min_score,
)
@ -253,7 +259,7 @@ async def edge_search(
source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map]
reranked_node_uuids = await node_distance_reranker(
reranked_node_uuids, edge_scores = await node_distance_reranker(
driver, source_uuids, center_node_uuid, min_score=reranker_min_score
)
@ -265,7 +271,7 @@ async def edge_search(
if config.reranker == EdgeReranker.episode_mentions:
reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes))
return reranked_edges[:limit]
return reranked_edges[:limit], edge_scores[:limit]
async def node_search(
@ -280,9 +286,9 @@ async def node_search(
bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT,
reranker_min_score: float = 0,
) -> list[EntityNode]:
) -> tuple[list[EntityNode], list[float]]:
if config is None:
return []
return [], []
search_results: list[list[EntityNode]] = list(
await semaphore_gather(
*[
@ -319,14 +325,15 @@ async def node_search(
node_uuid_map = {node.uuid: node for result in search_results for node in result}
reranked_uuids: list[str] = []
node_scores: list[float] = []
if config.reranker == NodeReranker.rrf:
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
reranked_uuids, node_scores = rrf(search_result_uuids, min_score=reranker_min_score)
elif config.reranker == NodeReranker.mmr:
search_result_uuids_and_vectors = await get_embeddings_for_nodes(
driver, list(node_uuid_map.values())
)
reranked_uuids = maximal_marginal_relevance(
reranked_uuids, node_scores = maximal_marginal_relevance(
query_vector,
search_result_uuids_and_vectors,
config.mmr_lambda,
@ -341,23 +348,24 @@ async def node_search(
for name, score in reranked_node_names
if score >= reranker_min_score
]
node_scores = [score for _, score in reranked_node_names if score >= reranker_min_score]
elif config.reranker == NodeReranker.episode_mentions:
reranked_uuids = await episode_mentions_reranker(
reranked_uuids, node_scores = await episode_mentions_reranker(
driver, search_result_uuids, min_score=reranker_min_score
)
elif config.reranker == NodeReranker.node_distance:
if center_node_uuid is None:
raise SearchRerankerError('No center node provided for Node Distance reranker')
reranked_uuids = await node_distance_reranker(
reranked_uuids, node_scores = await node_distance_reranker(
driver,
rrf(search_result_uuids, min_score=reranker_min_score),
rrf(search_result_uuids, min_score=reranker_min_score)[0],
center_node_uuid,
min_score=reranker_min_score,
)
reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
return reranked_nodes[:limit]
return reranked_nodes[:limit], node_scores[:limit]
async def episode_search(
@ -370,9 +378,9 @@ async def episode_search(
search_filter: SearchFilters,
limit=DEFAULT_SEARCH_LIMIT,
reranker_min_score: float = 0,
) -> list[EpisodicNode]:
) -> tuple[list[EpisodicNode], list[float]]:
if config is None:
return []
return [], []
search_results: list[list[EpisodicNode]] = list(
await semaphore_gather(
*[
@ -385,12 +393,13 @@ async def episode_search(
episode_uuid_map = {episode.uuid: episode for result in search_results for episode in result}
reranked_uuids: list[str] = []
episode_scores: list[float] = []
if config.reranker == EpisodeReranker.rrf:
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
reranked_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
elif config.reranker == EpisodeReranker.cross_encoder:
# use rrf as a preliminary reranker
rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
rrf_result_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
rrf_results = [episode_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
content_to_uuid_map = {episode.content: episode.uuid for episode in rrf_results}
@ -401,10 +410,11 @@ async def episode_search(
for content, score in reranked_contents
if score >= reranker_min_score
]
episode_scores = [score for _, score in reranked_contents if score >= reranker_min_score]
reranked_episodes = [episode_uuid_map[uuid] for uuid in reranked_uuids]
return reranked_episodes[:limit]
return reranked_episodes[:limit], episode_scores[:limit]
async def community_search(
@ -416,9 +426,9 @@ async def community_search(
config: CommunitySearchConfig | None,
limit=DEFAULT_SEARCH_LIMIT,
reranker_min_score: float = 0,
) -> list[CommunityNode]:
) -> tuple[list[CommunityNode], list[float]]:
if config is None:
return []
return [], []
search_results: list[list[CommunityNode]] = list(
await semaphore_gather(
@ -437,14 +447,15 @@ async def community_search(
}
reranked_uuids: list[str] = []
community_scores: list[float] = []
if config.reranker == CommunityReranker.rrf:
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
reranked_uuids, community_scores = rrf(search_result_uuids, min_score=reranker_min_score)
elif config.reranker == CommunityReranker.mmr:
search_result_uuids_and_vectors = await get_embeddings_for_communities(
driver, list(community_uuid_map.values())
)
reranked_uuids = maximal_marginal_relevance(
reranked_uuids, community_scores = maximal_marginal_relevance(
query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score
)
elif config.reranker == CommunityReranker.cross_encoder:
@ -453,7 +464,8 @@ async def community_search(
reranked_uuids = [
name_to_uuid_map[name] for name, score in reranked_nodes if score >= reranker_min_score
]
community_scores = [score for _, score in reranked_nodes if score >= reranker_min_score]
reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
return reranked_communities[:limit]
return reranked_communities[:limit], community_scores[:limit]

View file

@ -119,7 +119,11 @@ class SearchConfig(BaseModel):
class SearchResults(BaseModel):
edges: list[EntityEdge]
nodes: list[EntityNode]
episodes: list[EpisodicNode]
communities: list[CommunityNode]
edges: list[EntityEdge] = Field(default_factory=list)
edge_reranker_scores: list[float] = Field(default_factory=list)
nodes: list[EntityNode] = Field(default_factory=list)
node_reranker_scores: list[float] = Field(default_factory=list)
episodes: list[EpisodicNode] = Field(default_factory=list)
episode_reranker_scores: list[float] = Field(default_factory=list)
communities: list[CommunityNode] = Field(default_factory=list)
community_reranker_scores: list[float] = Field(default_factory=list)

View file

@ -294,13 +294,13 @@ async def edge_bfs_search(
query = (
"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
UNWIND relationships(path) AS rel
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
WHERE r.uuid = rel.uuid
AND r.group_id IN $group_ids
"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
UNWIND relationships(path) AS rel
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
WHERE r.uuid = rel.uuid
AND r.group_id IN $group_ids
"""
+ filter_query
+ """
RETURN DISTINCT
@ -445,11 +445,11 @@ async def node_bfs_search(
query = (
"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
WHERE n.group_id = origin.group_id
AND origin.group_id IN $group_ids
"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
WHERE n.group_id = origin.group_id
AND origin.group_id IN $group_ids
"""
+ filter_query
+ ENTITY_NODE_RETURN
+ """
@ -672,7 +672,7 @@ async def hybrid_node_search(
}
result_uuids = [[node.uuid for node in result] for result in results]
ranked_uuids = rrf(result_uuids)
ranked_uuids, _ = rrf(result_uuids)
relevant_nodes: list[EntityNode] = [node_uuid_map[uuid] for uuid in ranked_uuids]
@ -914,7 +914,9 @@ async def get_edge_invalidation_candidates(
# takes in a list of rankings of uuids
def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[str]:
def rrf(
results: list[list[str]], rank_const=1, min_score: float = 0
) -> tuple[list[str], list[float]]:
scores: dict[str, float] = defaultdict(float)
for result in results:
for i, uuid in enumerate(result):
@ -925,7 +927,9 @@ def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[st
sorted_uuids = [term[0] for term in scored_uuids]
return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [
scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score
]
async def node_distance_reranker(
@ -933,7 +937,7 @@ async def node_distance_reranker(
node_uuids: list[str],
center_node_uuid: str,
min_score: float = 0,
) -> list[str]:
) -> tuple[list[str], list[float]]:
# filter the center node's own uuid out of node_uuids
filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
scores: dict[str, float] = {center_node_uuid: 0.0}
@ -970,14 +974,16 @@ async def node_distance_reranker(
scores[center_node_uuid] = 0.1
filtered_uuids = [center_node_uuid] + filtered_uuids
return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score]
return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score], [
1 / scores[uuid] for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score
]
async def episode_mentions_reranker(
driver: GraphDriver, node_uuids: list[list[str]], min_score: float = 0
) -> list[str]:
) -> tuple[list[str], list[float]]:
# use rrf as a preliminary ranker
sorted_uuids = rrf(node_uuids)
sorted_uuids, _ = rrf(node_uuids)
scores: dict[str, float] = {}
# Find the shortest path to center node
@ -998,7 +1004,9 @@ async def episode_mentions_reranker(
# rerank by each uuid's score, ascending
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [
scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score
]
def maximal_marginal_relevance(
@ -1006,7 +1014,7 @@ def maximal_marginal_relevance(
candidates: dict[str, list[float]],
mmr_lambda: float = DEFAULT_MMR_LAMBDA,
min_score: float = -2.0,
) -> list[str]:
) -> tuple[list[str], list[float]]:
start = time()
query_array = np.array(query_vector)
candidate_arrays: dict[str, NDArray] = {}
@ -1037,7 +1045,9 @@ def maximal_marginal_relevance(
end = time()
logger.debug(f'Completed MMR reranking in {(end - start) * 1000} ms')
return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score]
return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score], [
mmr_scores[uuid] for uuid in uuids if mmr_scores[uuid] >= min_score
]
async def get_embeddings_for_nodes(

View file

@ -1,7 +1,7 @@
[project]
name = "graphiti-core"
description = "A temporal graph building library"
version = "0.17.11"
version = "0.18.0"
authors = [
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },

View file

@ -71,7 +71,7 @@ async def test_graphiti_init():
)
results = await graphiti.search_(
query='What is the hall of portrait?',
query='Who is Tania',
search_filter=search_filter,
)

2
uv.lock generated
View file

@ -746,7 +746,7 @@ wheels = [
[[package]]
name = "graphiti-core"
version = "0.17.10"
version = "0.18.0"
source = { editable = "." }
dependencies = [
{ name = "diskcache" },