fix node distance reranker (#231)

This commit is contained in:
Preston Rasmussen 2024-12-06 12:08:54 -05:00 committed by GitHub
parent f59a7e6b78
commit 6a152ab91a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 7 additions and 2 deletions

View file

@ -631,7 +631,7 @@ async def node_distance_reranker(
) -> list[str]: ) -> list[str]:
# filter out node_uuid center node node uuid # filter out node_uuid center node node uuid
filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids)) filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
scores: dict[str, float] = {} scores: dict[str, float] = {center_node_uuid: 0.0}
# Find the shortest path to center node # Find the shortest path to center node
query = Query(""" query = Query("""
@ -649,9 +649,13 @@ async def node_distance_reranker(
for result in path_results: for result in path_results:
uuid = result['uuid'] uuid = result['uuid']
score = result['score'] if 'score' in result else float('inf') score = result['score']
scores[uuid] = score scores[uuid] = score
for uuid in filtered_uuids:
if uuid not in scores:
scores[uuid] = float('inf')
# rerank on shortest distance # rerank on shortest distance
filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid]) filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])

View file

@ -72,6 +72,7 @@ async def test_graphiti_init():
COMBINED_HYBRID_SEARCH_CROSS_ENCODER, COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
group_ids=['test'], group_ids=['test'],
) )
pretty_results = { pretty_results = {
'edges': [edge.fact for edge in results.edges], 'edges': [edge.fact for edge in results.edges],
'nodes': [node.name for node in results.nodes], 'nodes': [node.name for node in results.nodes],