Improve node distance reranker speed (#107)
* much faster
* clean up code
* variable rename
This commit is contained in:
parent
8085b52f2a
commit
85cf8e5840
3 changed files with 28 additions and 28 deletions
|
|
@ -63,7 +63,7 @@ async def main(use_bulk: bool = True):
|
||||||
messages = parse_podcast_messages()
|
messages = parse_podcast_messages()
|
||||||
|
|
||||||
if not use_bulk:
|
if not use_bulk:
|
||||||
for i, message in enumerate(messages[3:14]):
|
for i, message in enumerate(messages[3:130]):
|
||||||
await client.add_episode(
|
await client.add_episode(
|
||||||
name=f'Message {i}',
|
name=f'Message {i}',
|
||||||
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
|
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
|
||||||
|
|
|
||||||
|
|
@ -496,34 +496,39 @@ async def node_distance_reranker(
|
||||||
sorted_uuids = rrf(results)
|
sorted_uuids = rrf(results)
|
||||||
scores: dict[str, float] = {}
|
scores: dict[str, float] = {}
|
||||||
|
|
||||||
for uuid in sorted_uuids:
|
# Find the shortest path to center node
|
||||||
# Find the shortest path to center node
|
query = Query("""
|
||||||
records, _, _ = await driver.execute_query(
|
|
||||||
"""
|
|
||||||
MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
|
MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
|
||||||
MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO*1..10]->(n:Entity)
|
MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: source.uuid})
|
||||||
WHERE center.uuid = $center_uuid AND n.uuid IN [source.uuid, target.uuid]
|
RETURN length(p) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
|
||||||
RETURN min(length(p)) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
|
""")
|
||||||
""",
|
|
||||||
edge_uuid=uuid,
|
|
||||||
center_uuid=center_node_uuid,
|
|
||||||
)
|
|
||||||
distance = 0.01
|
|
||||||
|
|
||||||
for record in records:
|
path_results = await asyncio.gather(
|
||||||
if (
|
*[
|
||||||
record['source_uuid'] == center_node_uuid
|
driver.execute_query(
|
||||||
or record['target_uuid'] == center_node_uuid
|
query,
|
||||||
):
|
edge_uuid=uuid,
|
||||||
continue
|
center_uuid=center_node_uuid,
|
||||||
distance = record['score']
|
)
|
||||||
|
for uuid in sorted_uuids
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
for uuid, result in zip(sorted_uuids, path_results):
|
||||||
|
records = result[0]
|
||||||
|
record = records[0] if len(records) > 0 else None
|
||||||
|
distance: float = record['score'] if record is not None else float('inf')
|
||||||
|
if record is not None and (
|
||||||
|
record['source_uuid'] == center_node_uuid or record['target_uuid'] == center_node_uuid
|
||||||
|
):
|
||||||
|
distance = 0
|
||||||
|
|
||||||
if uuid in scores:
|
if uuid in scores:
|
||||||
scores[uuid] = min(1 / distance, scores[uuid])
|
scores[uuid] = min(distance, scores[uuid])
|
||||||
else:
|
else:
|
||||||
scores[uuid] = 1 / distance
|
scores[uuid] = distance
|
||||||
|
|
||||||
# rerank on shortest distance
|
# rerank on shortest distance
|
||||||
sorted_uuids.sort(reverse=True, key=lambda cur_uuid: scores[cur_uuid])
|
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
||||||
|
|
||||||
return sorted_uuids
|
return sorted_uuids
|
||||||
|
|
|
||||||
|
|
@ -73,11 +73,6 @@ def format_context(facts):
|
||||||
async def test_graphiti_init():
|
async def test_graphiti_init():
|
||||||
logger = setup_logging()
|
logger = setup_logging()
|
||||||
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
||||||
await graphiti.build_communities()
|
|
||||||
|
|
||||||
edges = await graphiti.search('Freakenomics guest', group_ids=['1'])
|
|
||||||
|
|
||||||
logger.info('\nQUERY: Freakenomics guest\n' + format_context([edge.fact for edge in edges]))
|
|
||||||
|
|
||||||
edges = await graphiti.search('tania tetlow', group_ids=['1'])
|
edges = await graphiti.search('tania tetlow', group_ids=['1'])
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue