Improve node distance reranker speed (#107)
* much faster * clean up code * variable rename
This commit is contained in:
parent
8085b52f2a
commit
85cf8e5840
3 changed files with 28 additions and 28 deletions
|
|
@ -63,7 +63,7 @@ async def main(use_bulk: bool = True):
|
|||
messages = parse_podcast_messages()
|
||||
|
||||
if not use_bulk:
|
||||
for i, message in enumerate(messages[3:14]):
|
||||
for i, message in enumerate(messages[3:130]):
|
||||
await client.add_episode(
|
||||
name=f'Message {i}',
|
||||
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
|
||||
|
|
|
|||
|
|
@ -496,34 +496,39 @@ async def node_distance_reranker(
|
|||
sorted_uuids = rrf(results)
|
||||
scores: dict[str, float] = {}
|
||||
|
||||
for uuid in sorted_uuids:
|
||||
# Find the shortest path to center node
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
# Find the shortest path to center node
|
||||
query = Query("""
|
||||
MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
|
||||
MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO*1..10]->(n:Entity)
|
||||
WHERE center.uuid = $center_uuid AND n.uuid IN [source.uuid, target.uuid]
|
||||
RETURN min(length(p)) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
|
||||
""",
|
||||
edge_uuid=uuid,
|
||||
center_uuid=center_node_uuid,
|
||||
)
|
||||
distance = 0.01
|
||||
MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: source.uuid})
|
||||
RETURN length(p) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
|
||||
""")
|
||||
|
||||
for record in records:
|
||||
if (
|
||||
record['source_uuid'] == center_node_uuid
|
||||
or record['target_uuid'] == center_node_uuid
|
||||
):
|
||||
continue
|
||||
distance = record['score']
|
||||
path_results = await asyncio.gather(
|
||||
*[
|
||||
driver.execute_query(
|
||||
query,
|
||||
edge_uuid=uuid,
|
||||
center_uuid=center_node_uuid,
|
||||
)
|
||||
for uuid in sorted_uuids
|
||||
]
|
||||
)
|
||||
|
||||
for uuid, result in zip(sorted_uuids, path_results):
|
||||
records = result[0]
|
||||
record = records[0] if len(records) > 0 else None
|
||||
distance: float = record['score'] if record is not None else float('inf')
|
||||
if record is not None and (
|
||||
record['source_uuid'] == center_node_uuid or record['target_uuid'] == center_node_uuid
|
||||
):
|
||||
distance = 0
|
||||
|
||||
if uuid in scores:
|
||||
scores[uuid] = min(1 / distance, scores[uuid])
|
||||
scores[uuid] = min(distance, scores[uuid])
|
||||
else:
|
||||
scores[uuid] = 1 / distance
|
||||
scores[uuid] = distance
|
||||
|
||||
# rerank on shortest distance
|
||||
sorted_uuids.sort(reverse=True, key=lambda cur_uuid: scores[cur_uuid])
|
||||
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
||||
|
||||
return sorted_uuids
|
||||
|
|
|
|||
|
|
@ -73,11 +73,6 @@ def format_context(facts):
|
|||
async def test_graphiti_init():
|
||||
logger = setup_logging()
|
||||
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
||||
await graphiti.build_communities()
|
||||
|
||||
edges = await graphiti.search('Freakenomics guest', group_ids=['1'])
|
||||
|
||||
logger.info('\nQUERY: Freakenomics guest\n' + format_context([edge.fact for edge in edges]))
|
||||
|
||||
edges = await graphiti.search('tania tetlow', group_ids=['1'])
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue