fix node distance reranker (#231)
This commit is contained in:
parent
f59a7e6b78
commit
6a152ab91a
2 changed files with 7 additions and 2 deletions
|
|
@ -631,7 +631,7 @@ async def node_distance_reranker(
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
# filter out node_uuid center node node uuid
|
# filter out node_uuid center node node uuid
|
||||||
filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
|
filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
|
||||||
scores: dict[str, float] = {}
|
scores: dict[str, float] = {center_node_uuid: 0.0}
|
||||||
|
|
||||||
# Find the shortest path to center node
|
# Find the shortest path to center node
|
||||||
query = Query("""
|
query = Query("""
|
||||||
|
|
@ -649,9 +649,13 @@ async def node_distance_reranker(
|
||||||
|
|
||||||
for result in path_results:
|
for result in path_results:
|
||||||
uuid = result['uuid']
|
uuid = result['uuid']
|
||||||
score = result['score'] if 'score' in result else float('inf')
|
score = result['score']
|
||||||
scores[uuid] = score
|
scores[uuid] = score
|
||||||
|
|
||||||
|
for uuid in filtered_uuids:
|
||||||
|
if uuid not in scores:
|
||||||
|
scores[uuid] = float('inf')
|
||||||
|
|
||||||
# rerank on shortest distance
|
# rerank on shortest distance
|
||||||
filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -72,6 +72,7 @@ async def test_graphiti_init():
|
||||||
COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
|
COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
|
||||||
group_ids=['test'],
|
group_ids=['test'],
|
||||||
)
|
)
|
||||||
|
|
||||||
pretty_results = {
|
pretty_results = {
|
||||||
'edges': [edge.fact for edge in results.edges],
|
'edges': [edge.fact for edge in results.edges],
|
||||||
'nodes': [node.name for node in results.nodes],
|
'nodes': [node.name for node in results.nodes],
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue