fix node distance reranker (#231)
This commit is contained in:
parent
f59a7e6b78
commit
6a152ab91a
2 changed files with 7 additions and 2 deletions
|
|
@ -631,7 +631,7 @@ async def node_distance_reranker(
|
|||
) -> list[str]:
|
||||
# filter out node_uuid center node node uuid
|
||||
filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
|
||||
scores: dict[str, float] = {}
|
||||
scores: dict[str, float] = {center_node_uuid: 0.0}
|
||||
|
||||
# Find the shortest path to center node
|
||||
query = Query("""
|
||||
|
|
@ -649,9 +649,13 @@ async def node_distance_reranker(
|
|||
|
||||
for result in path_results:
|
||||
uuid = result['uuid']
|
||||
score = result['score'] if 'score' in result else float('inf')
|
||||
score = result['score']
|
||||
scores[uuid] = score
|
||||
|
||||
for uuid in filtered_uuids:
|
||||
if uuid not in scores:
|
||||
scores[uuid] = float('inf')
|
||||
|
||||
# rerank on shortest distance
|
||||
filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
||||
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ async def test_graphiti_init():
|
|||
COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
|
||||
group_ids=['test'],
|
||||
)
|
||||
|
||||
pretty_results = {
|
||||
'edges': [edge.fact for edge in results.edges],
|
||||
'nodes': [node.name for node in results.nodes],
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue