Improve node distance reranker speed (#107)

* much faster

* clean up code

* variable rename
This commit is contained in:
Preston Rasmussen 2024-09-12 11:23:45 -04:00 committed by GitHub
parent 8085b52f2a
commit 85cf8e5840
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 28 additions and 28 deletions

View file

@@ -63,7 +63,7 @@ async def main(use_bulk: bool = True):
messages = parse_podcast_messages()
if not use_bulk:
for i, message in enumerate(messages[3:14]):
for i, message in enumerate(messages[3:130]):
await client.add_episode(
name=f'Message {i}',
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',

View file

@@ -496,34 +496,39 @@ async def node_distance_reranker(
sorted_uuids = rrf(results)
scores: dict[str, float] = {}
for uuid in sorted_uuids:
# Find the shortest path to center node
records, _, _ = await driver.execute_query(
"""
# Find the shortest path to center node
query = Query("""
MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO*1..10]->(n:Entity)
WHERE center.uuid = $center_uuid AND n.uuid IN [source.uuid, target.uuid]
RETURN min(length(p)) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
""",
edge_uuid=uuid,
center_uuid=center_node_uuid,
)
distance = 0.01
MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: source.uuid})
RETURN length(p) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
""")
for record in records:
if (
record['source_uuid'] == center_node_uuid
or record['target_uuid'] == center_node_uuid
):
continue
distance = record['score']
path_results = await asyncio.gather(
*[
driver.execute_query(
query,
edge_uuid=uuid,
center_uuid=center_node_uuid,
)
for uuid in sorted_uuids
]
)
for uuid, result in zip(sorted_uuids, path_results):
records = result[0]
record = records[0] if len(records) > 0 else None
distance: float = record['score'] if record is not None else float('inf')
if record is not None and (
record['source_uuid'] == center_node_uuid or record['target_uuid'] == center_node_uuid
):
distance = 0
if uuid in scores:
scores[uuid] = min(1 / distance, scores[uuid])
scores[uuid] = min(distance, scores[uuid])
else:
scores[uuid] = 1 / distance
scores[uuid] = distance
# rerank on shortest distance
sorted_uuids.sort(reverse=True, key=lambda cur_uuid: scores[cur_uuid])
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
return sorted_uuids

View file

@@ -73,11 +73,6 @@ def format_context(facts):
async def test_graphiti_init():
logger = setup_logging()
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
await graphiti.build_communities()
edges = await graphiti.search('Freakenomics guest', group_ids=['1'])
logger.info('\nQUERY: Freakenomics guest\n' + format_context([edge.fact for edge in edges]))
edges = await graphiti.search('tania tetlow', group_ids=['1'])