Return reranker scores (#758)
* add search reranker scores to search output
* bump version
* updates
Parent: 266f3f396c
Commit: 17747ff58d
7 changed files with 88 additions and 62 deletions
@@ -959,7 +959,7 @@ class Graphiti:

         nodes = await get_mentioned_nodes(self.driver, episodes)

-        return SearchResults(edges=edges, nodes=nodes, episodes=[], communities=[])
+        return SearchResults(edges=edges, nodes=nodes)

     async def add_triplet(self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode):
         if source_node.name_embedding is None:
@@ -80,12 +80,7 @@ async def search(
     cross_encoder = clients.cross_encoder

     if query.strip() == '':
-        return SearchResults(
-            edges=[],
-            nodes=[],
-            episodes=[],
-            communities=[],
-        )
+        return SearchResults()

     query_vector = (
         query_vector
         if query_vector is not None
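The bare return SearchResults() above relies on the model now giving every field a default (see the SearchResults hunk further down). As a minimal sketch of that pydantic pattern, with hypothetical field names:

from pydantic import BaseModel, Field


class ExampleResults(BaseModel):
    # default_factory builds a fresh list per instance, so ExampleResults()
    # is valid and instances never share the same mutable default.
    items: list[str] = Field(default_factory=list)
    scores: list[float] = Field(default_factory=list)


empty = ExampleResults()                 # items == [], scores == []
partial = ExampleResults(items=['a'])    # scores still defaults to []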
@@ -94,7 +89,12 @@ async def search(

     # if group_ids is empty, set it to None
     group_ids = group_ids if group_ids and group_ids != [''] else None
-    edges, nodes, episodes, communities = await semaphore_gather(
+    (
+        (edges, edge_reranker_scores),
+        (nodes, node_reranker_scores),
+        (episodes, episode_reranker_scores),
+        (communities, community_reranker_scores),
+    ) = await semaphore_gather(
         edge_search(
             driver,
             cross_encoder,
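Each sub-search now returns an (items, scores) pair, and the nested tuple on the left-hand side unpacks all four pairs in a single assignment. A rough sketch of the same pattern, using plain asyncio.gather as a stand-in for graphiti's semaphore_gather helper and made-up search stubs:

import asyncio


async def fake_search(kind: str) -> tuple[list[str], list[float]]:
    # stand-in for edge_search / node_search / episode_search / community_search
    return [f'{kind}-result'], [0.5]


async def main() -> None:
    (
        (edges, edge_scores),
        (nodes, node_scores),
    ) = await asyncio.gather(fake_search('edge'), fake_search('node'))
    print(edges, edge_scores, nodes, node_scores)


asyncio.run(main())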
@@ -146,9 +146,13 @@ async def search(

     results = SearchResults(
         edges=edges,
+        edge_reranker_scores=edge_reranker_scores,
         nodes=nodes,
+        node_reranker_scores=node_reranker_scores,
         episodes=episodes,
+        episode_reranker_scores=episode_reranker_scores,
         communities=communities,
+        community_reranker_scores=community_reranker_scores,
     )

     latency = (time() - start) * 1000
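The latency value is wall-clock time in milliseconds, measured from a start = time() captured earlier in search(). The same measurement pattern in isolation:

from time import time

start = time()
# ... run the search work here ...
latency = (time() - start) * 1000
print(f'search completed in {latency:.1f} ms')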
@@ -170,9 +174,9 @@ async def edge_search(
     bfs_origin_node_uuids: list[str] | None = None,
     limit=DEFAULT_SEARCH_LIMIT,
     reranker_min_score: float = 0,
-) -> list[EntityEdge]:
+) -> tuple[list[EntityEdge], list[float]]:
     if config is None:
-        return []
+        return [], []
     search_results: list[list[EntityEdge]] = list(
         await semaphore_gather(
             *[
@@ -215,15 +219,16 @@ async def edge_search(
     edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}

     reranked_uuids: list[str] = []
+    edge_scores: list[float] = []
     if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
         search_result_uuids = [[edge.uuid for edge in result] for result in search_results]

-        reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
+        reranked_uuids, edge_scores = rrf(search_result_uuids, min_score=reranker_min_score)
     elif config.reranker == EdgeReranker.mmr:
         search_result_uuids_and_vectors = await get_embeddings_for_edges(
             driver, list(edge_uuid_map.values())
         )
-        reranked_uuids = maximal_marginal_relevance(
+        reranked_uuids, edge_scores = maximal_marginal_relevance(
             query_vector,
             search_result_uuids_and_vectors,
             config.mmr_lambda,
@@ -235,12 +240,13 @@ async def edge_search(
         reranked_uuids = [
             fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
         ]
+        edge_scores = [score for _, score in reranked_facts if score >= reranker_min_score]
     elif config.reranker == EdgeReranker.node_distance:
         if center_node_uuid is None:
             raise SearchRerankerError('No center node provided for Node Distance reranker')

         # use rrf as a preliminary sort
-        sorted_result_uuids = rrf(
+        sorted_result_uuids, node_scores = rrf(
             [[edge.uuid for edge in result] for result in search_results],
             min_score=reranker_min_score,
         )
@@ -253,7 +259,7 @@ async def edge_search(

         source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map]

-        reranked_node_uuids = await node_distance_reranker(
+        reranked_node_uuids, edge_scores = await node_distance_reranker(
             driver, source_uuids, center_node_uuid, min_score=reranker_min_score
         )

@@ -265,7 +271,7 @@ async def edge_search(
     if config.reranker == EdgeReranker.episode_mentions:
         reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes))

-    return reranked_edges[:limit]
+    return reranked_edges[:limit], edge_scores[:limit]


 async def node_search(
@@ -280,9 +286,9 @@ async def node_search(
     bfs_origin_node_uuids: list[str] | None = None,
     limit=DEFAULT_SEARCH_LIMIT,
     reranker_min_score: float = 0,
-) -> list[EntityNode]:
+) -> tuple[list[EntityNode], list[float]]:
     if config is None:
-        return []
+        return [], []
     search_results: list[list[EntityNode]] = list(
         await semaphore_gather(
             *[
@@ -319,14 +325,15 @@ async def node_search(
     node_uuid_map = {node.uuid: node for result in search_results for node in result}

     reranked_uuids: list[str] = []
+    node_scores: list[float] = []
     if config.reranker == NodeReranker.rrf:
-        reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
+        reranked_uuids, node_scores = rrf(search_result_uuids, min_score=reranker_min_score)
     elif config.reranker == NodeReranker.mmr:
         search_result_uuids_and_vectors = await get_embeddings_for_nodes(
             driver, list(node_uuid_map.values())
         )

-        reranked_uuids = maximal_marginal_relevance(
+        reranked_uuids, node_scores = maximal_marginal_relevance(
             query_vector,
             search_result_uuids_and_vectors,
             config.mmr_lambda,
@@ -341,23 +348,24 @@ async def node_search(
             for name, score in reranked_node_names
             if score >= reranker_min_score
         ]
+        node_scores = [score for _, score in reranked_node_names if score >= reranker_min_score]
     elif config.reranker == NodeReranker.episode_mentions:
-        reranked_uuids = await episode_mentions_reranker(
+        reranked_uuids, node_scores = await episode_mentions_reranker(
             driver, search_result_uuids, min_score=reranker_min_score
         )
     elif config.reranker == NodeReranker.node_distance:
         if center_node_uuid is None:
             raise SearchRerankerError('No center node provided for Node Distance reranker')
-        reranked_uuids = await node_distance_reranker(
+        reranked_uuids, node_scores = await node_distance_reranker(
             driver,
-            rrf(search_result_uuids, min_score=reranker_min_score),
+            rrf(search_result_uuids, min_score=reranker_min_score)[0],
             center_node_uuid,
             min_score=reranker_min_score,
         )

     reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]

-    return reranked_nodes[:limit]
+    return reranked_nodes[:limit], node_scores[:limit]


 async def episode_search(
@@ -370,9 +378,9 @@ async def episode_search(
     search_filter: SearchFilters,
     limit=DEFAULT_SEARCH_LIMIT,
     reranker_min_score: float = 0,
-) -> list[EpisodicNode]:
+) -> tuple[list[EpisodicNode], list[float]]:
     if config is None:
-        return []
+        return [], []
     search_results: list[list[EpisodicNode]] = list(
         await semaphore_gather(
             *[
@@ -385,12 +393,13 @@ async def episode_search(
     episode_uuid_map = {episode.uuid: episode for result in search_results for episode in result}

     reranked_uuids: list[str] = []
+    episode_scores: list[float] = []
     if config.reranker == EpisodeReranker.rrf:
-        reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
+        reranked_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)

     elif config.reranker == EpisodeReranker.cross_encoder:
         # use rrf as a preliminary reranker
-        rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
+        rrf_result_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
         rrf_results = [episode_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]

         content_to_uuid_map = {episode.content: episode.uuid for episode in rrf_results}
@@ -401,10 +410,11 @@ async def episode_search(
             for content, score in reranked_contents
             if score >= reranker_min_score
         ]
+        episode_scores = [score for _, score in reranked_contents if score >= reranker_min_score]

     reranked_episodes = [episode_uuid_map[uuid] for uuid in reranked_uuids]

-    return reranked_episodes[:limit]
+    return reranked_episodes[:limit], episode_scores[:limit]


 async def community_search(
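The cross-encoder branch follows the same shape in each search function: score the candidate texts against the query, map the surviving texts back to uuids, and keep only scores that clear reranker_min_score. A generic sketch of that mapping step with hand-written scores (this is not the library's cross-encoder client):

def rerank_by_text(
    text_to_uuid: dict[str, str],
    scored_texts: list[tuple[str, float]],
    min_score: float = 0.0,
) -> tuple[list[str], list[float]]:
    # scored_texts is assumed to be (text, relevance) pairs from a
    # cross-encoder style model, sorted by descending relevance.
    uuids = [text_to_uuid[text] for text, score in scored_texts if score >= min_score]
    scores = [score for _, score in scored_texts if score >= min_score]
    return uuids, scores


# toy usage with invented data
uuids, scores = rerank_by_text(
    {'Tania joined in 2021.': 'ep-1', 'Unrelated note.': 'ep-2'},
    [('Tania joined in 2021.', 0.92), ('Unrelated note.', 0.11)],
    min_score=0.5,
)
assert uuids == ['ep-1'] and scores == [0.92]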
@@ -416,9 +426,9 @@ async def community_search(
     config: CommunitySearchConfig | None,
     limit=DEFAULT_SEARCH_LIMIT,
     reranker_min_score: float = 0,
-) -> list[CommunityNode]:
+) -> tuple[list[CommunityNode], list[float]]:
     if config is None:
-        return []
+        return [], []

     search_results: list[list[CommunityNode]] = list(
         await semaphore_gather(
@@ -437,14 +447,15 @@ async def community_search(
     }

     reranked_uuids: list[str] = []
+    community_scores: list[float] = []
     if config.reranker == CommunityReranker.rrf:
-        reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
+        reranked_uuids, community_scores = rrf(search_result_uuids, min_score=reranker_min_score)
     elif config.reranker == CommunityReranker.mmr:
         search_result_uuids_and_vectors = await get_embeddings_for_communities(
             driver, list(community_uuid_map.values())
         )

-        reranked_uuids = maximal_marginal_relevance(
+        reranked_uuids, community_scores = maximal_marginal_relevance(
             query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score
         )
     elif config.reranker == CommunityReranker.cross_encoder:
@@ -453,7 +464,8 @@ async def community_search(
         reranked_uuids = [
             name_to_uuid_map[name] for name, score in reranked_nodes if score >= reranker_min_score
         ]
+        community_scores = [score for _, score in reranked_nodes if score >= reranker_min_score]

     reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]

-    return reranked_communities[:limit]
+    return reranked_communities[:limit], community_scores[:limit]
@@ -119,7 +119,11 @@ class SearchConfig(BaseModel):


 class SearchResults(BaseModel):
-    edges: list[EntityEdge]
-    nodes: list[EntityNode]
-    episodes: list[EpisodicNode]
-    communities: list[CommunityNode]
+    edges: list[EntityEdge] = Field(default_factory=list)
+    edge_reranker_scores: list[float] = Field(default_factory=list)
+    nodes: list[EntityNode] = Field(default_factory=list)
+    node_reranker_scores: list[float] = Field(default_factory=list)
+    episodes: list[EpisodicNode] = Field(default_factory=list)
+    episode_reranker_scores: list[float] = Field(default_factory=list)
+    communities: list[CommunityNode] = Field(default_factory=list)
+    community_reranker_scores: list[float] = Field(default_factory=list)
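With a score list declared alongside each result list, callers can pair every returned item with the score its reranker produced. A short consumption sketch; the import path and the positional pairing of edges with scores follow from the hunks above, but treat the details as illustrative:

from graphiti_core.search.search_config import SearchResults


def print_scored_edges(results: SearchResults) -> None:
    # zip() pairs each edge with the reranker score returned in the same position
    for edge, score in zip(results.edges, results.edge_reranker_scores):
        print(f'{score:.3f}  {edge.fact}')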
@@ -294,13 +294,13 @@ async def edge_bfs_search(

     query = (
         """
-    UNWIND $bfs_origin_node_uuids AS origin_uuid
-    MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
-    UNWIND relationships(path) AS rel
-    MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
-    WHERE r.uuid = rel.uuid
-    AND r.group_id IN $group_ids
-    """
+    UNWIND $bfs_origin_node_uuids AS origin_uuid
+    MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
+    UNWIND relationships(path) AS rel
+    MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
+    WHERE r.uuid = rel.uuid
+    AND r.group_id IN $group_ids
+    """
         + filter_query
         + """
     RETURN DISTINCT
@@ -445,11 +445,11 @@ async def node_bfs_search(

     query = (
         """
-    UNWIND $bfs_origin_node_uuids AS origin_uuid
-    MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
-    WHERE n.group_id = origin.group_id
-    AND origin.group_id IN $group_ids
-    """
+    UNWIND $bfs_origin_node_uuids AS origin_uuid
+    MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
+    WHERE n.group_id = origin.group_id
+    AND origin.group_id IN $group_ids
+    """
         + filter_query
         + ENTITY_NODE_RETURN
         + """
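Both BFS queries are parameterized Cypher: $bfs_origin_node_uuids and $group_ids are bound at run time, and the filter and RETURN fragments are concatenated onto the string. Graphiti executes these through its own GraphDriver abstraction; purely as an illustration, a similar parameterized query could be run with the official neo4j async driver roughly like this (connection details and the simplified query are placeholders):

import asyncio

from neo4j import AsyncGraphDatabase


async def run_bfs_query() -> None:
    # placeholder connection settings, not from the PR
    driver = AsyncGraphDatabase.driver('bolt://localhost:7687', auth=('neo4j', 'password'))
    query = """
    UNWIND $bfs_origin_node_uuids AS origin_uuid
    MATCH (origin:Entity {uuid: origin_uuid})-[:RELATES_TO]->{1,3}(n:Entity)
    WHERE n.group_id = origin.group_id AND origin.group_id IN $group_ids
    RETURN DISTINCT n.uuid AS uuid
    """
    try:
        records, _, _ = await driver.execute_query(
            query,
            bfs_origin_node_uuids=['uuid-1'],
            group_ids=['group-a'],
        )
        print([record['uuid'] for record in records])
    finally:
        await driver.close()


asyncio.run(run_bfs_query())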
@@ -672,7 +672,7 @@ async def hybrid_node_search(
     }
     result_uuids = [[node.uuid for node in result] for result in results]

-    ranked_uuids = rrf(result_uuids)
+    ranked_uuids, _ = rrf(result_uuids)

     relevant_nodes: list[EntityNode] = [node_uuid_map[uuid] for uuid in ranked_uuids]

@@ -914,7 +914,9 @@ async def get_edge_invalidation_candidates(


 # takes in a list of rankings of uuids
-def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[str]:
+def rrf(
+    results: list[list[str]], rank_const=1, min_score: float = 0
+) -> tuple[list[str], list[float]]:
     scores: dict[str, float] = defaultdict(float)
     for result in results:
         for i, uuid in enumerate(result):
@@ -925,7 +927,9 @@ def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[str]:

     sorted_uuids = [term[0] for term in scored_uuids]

-    return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
+    return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [
+        scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score
+    ]


 async def node_distance_reranker(
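rrf fuses multiple ranked lists with reciprocal rank fusion: each list contributes 1 / (rank + rank_const) to a uuid's score, and the function now returns the filtered uuids together with those fused scores. A standalone sketch of the same idea (not the library's exact code):

from collections import defaultdict


def rrf_sketch(
    results: list[list[str]], rank_const: int = 1, min_score: float = 0
) -> tuple[list[str], list[float]]:
    scores: dict[str, float] = defaultdict(float)
    for result in results:
        for i, uuid in enumerate(result):
            # reciprocal rank: earlier positions contribute more
            scores[uuid] += 1 / (i + rank_const)

    sorted_uuids = sorted(scores, key=lambda uuid: scores[uuid], reverse=True)
    kept = [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
    return kept, [scores[uuid] for uuid in kept]


uuids, fused = rrf_sketch([['a', 'b', 'c'], ['b', 'a']])
# 'a' and 'b' each score 1/1 + 1/2 = 1.5, 'c' scores 1/3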
@@ -933,7 +937,7 @@ async def node_distance_reranker(
     node_uuids: list[str],
     center_node_uuid: str,
     min_score: float = 0,
-) -> list[str]:
+) -> tuple[list[str], list[float]]:
     # filter out node_uuid center node node uuid
     filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
     scores: dict[str, float] = {center_node_uuid: 0.0}
@@ -970,14 +974,16 @@ async def node_distance_reranker(
         scores[center_node_uuid] = 0.1
         filtered_uuids = [center_node_uuid] + filtered_uuids

-    return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score]
+    return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score], [
+        1 / scores[uuid] for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score
+    ]


 async def episode_mentions_reranker(
     driver: GraphDriver, node_uuids: list[list[str]], min_score: float = 0
-) -> list[str]:
+) -> tuple[list[str], list[float]]:
     # use rrf as a preliminary ranker
-    sorted_uuids = rrf(node_uuids)
+    sorted_uuids, _ = rrf(node_uuids)
     scores: dict[str, float] = {}

     # Find the shortest path to center node
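node_distance_reranker scores each candidate by the reciprocal of its graph distance to the center node (hence the 1 / scores[uuid] in the return), so nearer nodes rank higher and min_score drops nodes that are too far away. A small sketch of that conversion with made-up distances:

def distance_to_scores(
    distances: dict[str, float], min_score: float = 0.0
) -> tuple[list[str], list[float]]:
    # sort by distance ascending, then score each uuid as 1 / distance
    ordered = sorted(distances, key=lambda uuid: distances[uuid])
    kept = [uuid for uuid in ordered if 1 / distances[uuid] >= min_score]
    return kept, [1 / distances[uuid] for uuid in kept]


uuids, scores = distance_to_scores({'n1': 1.0, 'n2': 2.0, 'n3': 10.0}, min_score=0.2)
# n1 -> 1.0, n2 -> 0.5; n3 (0.1) falls below min_score and is dropped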
@@ -998,7 +1004,9 @@ async def episode_mentions_reranker(
     # rerank on shortest distance
     sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])

-    return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
+    return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [
+        scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score
+    ]


 def maximal_marginal_relevance(
@@ -1006,7 +1014,7 @@ def maximal_marginal_relevance(
     candidates: dict[str, list[float]],
     mmr_lambda: float = DEFAULT_MMR_LAMBDA,
     min_score: float = -2.0,
-) -> list[str]:
+) -> tuple[list[str], list[float]]:
     start = time()
     query_array = np.array(query_vector)
     candidate_arrays: dict[str, NDArray] = {}
@@ -1037,7 +1045,9 @@ def maximal_marginal_relevance(
     end = time()
     logger.debug(f'Completed MMR reranking in {(end - start) * 1000} ms')

-    return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score]
+    return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score], [
+        mmr_scores[uuid] for uuid in uuids if mmr_scores[uuid] >= min_score
+    ]


 async def get_embeddings_for_nodes(
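Maximal marginal relevance trades similarity to the query against redundancy with other candidates, weighted by mmr_lambda, and the function now returns the surviving uuids along with their MMR scores. A generic, self-contained sketch of MMR-style scoring over embedding vectors (assumed normalized; this is not the library's exact implementation):

import numpy as np


def mmr_sketch(
    query_vector: list[float],
    candidates: dict[str, list[float]],
    mmr_lambda: float = 0.5,
    min_score: float = -2.0,
) -> tuple[list[str], list[float]]:
    query = np.array(query_vector, dtype=float)
    vectors = {uuid: np.array(vec, dtype=float) for uuid, vec in candidates.items()}

    scores: dict[str, float] = {}
    for uuid, vec in vectors.items():
        relevance = float(np.dot(query, vec))
        # redundancy: highest similarity to any other candidate
        redundancy = max(
            (float(np.dot(vec, other)) for key, other in vectors.items() if key != uuid),
            default=0.0,
        )
        scores[uuid] = mmr_lambda * relevance - (1 - mmr_lambda) * redundancy

    ordered = sorted(scores, key=lambda uuid: scores[uuid], reverse=True)
    kept = [uuid for uuid in ordered if scores[uuid] >= min_score]
    return kept, [scores[uuid] for uuid in kept]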
@@ -1,7 +1,7 @@
 [project]
 name = "graphiti-core"
 description = "A temporal graph building library"
-version = "0.17.11"
+version = "0.18.0"
 authors = [
     { "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
     { "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
@@ -71,7 +71,7 @@ async def test_graphiti_init():
     )

     results = await graphiti.search_(
-        query='What is the hall of portrait?',
+        query='Who is Tania',
         search_filter=search_filter,
     )

uv.lock (generated)
@@ -746,7 +746,7 @@ wheels = [

 [[package]]
 name = "graphiti-core"
-version = "0.17.10"
+version = "0.18.0"
 source = { editable = "." }
 dependencies = [
     { name = "diskcache" },