From 85cf8e58404031e8d8f8a7d0a254d57ad7ca23e9 Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Thu, 12 Sep 2024 11:23:45 -0400 Subject: [PATCH] Improve node distance reranker speed (#107) * much faster * clean up code * variable rename --- examples/podcast/podcast_runner.py | 2 +- graphiti_core/search/search_utils.py | 49 +++++++++++++++------------- tests/test_graphiti_int.py | 5 --- 3 files changed, 28 insertions(+), 28 deletions(-) diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py index d3b2404c..792d5720 100644 --- a/examples/podcast/podcast_runner.py +++ b/examples/podcast/podcast_runner.py @@ -63,7 +63,7 @@ async def main(use_bulk: bool = True): messages = parse_podcast_messages() if not use_bulk: - for i, message in enumerate(messages[3:14]): + for i, message in enumerate(messages[3:130]): await client.add_episode( name=f'Message {i}', episode_body=f'{message.speaker_name} ({message.role}): {message.content}', diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 5b63d300..38f3bd6e 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -496,34 +496,39 @@ async def node_distance_reranker( sorted_uuids = rrf(results) scores: dict[str, float] = {} - for uuid in sorted_uuids: - # Find the shortest path to center node - records, _, _ = await driver.execute_query( - """ + # Find the shortest path to center node + query = Query(""" MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity) - MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO*1..10]->(n:Entity) - WHERE center.uuid = $center_uuid AND n.uuid IN [source.uuid, target.uuid] - RETURN min(length(p)) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid - """, - edge_uuid=uuid, - center_uuid=center_node_uuid, - ) - distance = 0.01 + MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: source.uuid}) + RETURN length(p) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid + """) - for record in records: - if ( - record['source_uuid'] == center_node_uuid - or record['target_uuid'] == center_node_uuid - ): - continue - distance = record['score'] + path_results = await asyncio.gather( + *[ + driver.execute_query( + query, + edge_uuid=uuid, + center_uuid=center_node_uuid, + ) + for uuid in sorted_uuids + ] + ) + + for uuid, result in zip(sorted_uuids, path_results): + records = result[0] + record = records[0] if len(records) > 0 else None + distance: float = record['score'] if record is not None else float('inf') + if record is not None and ( + record['source_uuid'] == center_node_uuid or record['target_uuid'] == center_node_uuid + ): + distance = 0 if uuid in scores: - scores[uuid] = min(1 / distance, scores[uuid]) + scores[uuid] = min(distance, scores[uuid]) else: - scores[uuid] = 1 / distance + scores[uuid] = distance # rerank on shortest distance - sorted_uuids.sort(reverse=True, key=lambda cur_uuid: scores[cur_uuid]) + sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid]) return sorted_uuids diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py index e73500f1..682f9d50 100644 --- a/tests/test_graphiti_int.py +++ b/tests/test_graphiti_int.py @@ -73,11 +73,6 @@ def format_context(facts): async def test_graphiti_init(): logger = setup_logging() graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD) - await graphiti.build_communities() - - edges = await graphiti.search('Freakenomics guest', group_ids=['1']) - - logger.info('\nQUERY: Freakenomics guest\n' + format_context([edge.fact for edge in edges])) edges = await graphiti.search('tania tetlow', group_ids=['1'])