Invalidation updates && improvements (#20)
* wip
* wip
* wip
* fix: Linter errors
* fix formatting
* chore: fix ruff
* fix: Duplication

---------

Co-authored-by: Daniel Chalef <131175+danielchalef@users.noreply.github.com>
This commit is contained in:
parent 94873f1083
commit 1f1652f56c

9 changed files with 507 additions and 54 deletions
@@ -28,7 +28,10 @@ from core.utils.bulk_utils import (
    resolve_edge_pointers,
    retrieve_previous_episodes_bulk,
)
from core.utils.maintenance.edge_operations import dedupe_extracted_edges, extract_edges
from core.utils.maintenance.edge_operations import (
    dedupe_extracted_edges,
    extract_edges,
)
from core.utils.maintenance.graph_data_operations import (
    EPISODE_WINDOW_LEN,
    build_indices_and_constraints,
@@ -116,6 +119,7 @@ class Graphiti:
        )

        extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes)
        logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')

        # Calculate Embeddings

@@ -124,14 +128,14 @@ class Graphiti:
        )
        existing_nodes = await get_relevant_nodes(extracted_nodes, self.driver)
        logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
        new_nodes, _ = await dedupe_extracted_nodes(
        touched_nodes, _, brand_new_nodes = await dedupe_extracted_nodes(
            self.llm_client, extracted_nodes, existing_nodes
        )
        logger.info(f'Deduped touched nodes: {[(n.name, n.uuid) for n in new_nodes]}')
        nodes.extend(new_nodes)
        logger.info(f'Adjusted touched nodes: {[(n.name, n.uuid) for n in touched_nodes]}')
        nodes.extend(touched_nodes)

        extracted_edges = await extract_edges(
            self.llm_client, episode, new_nodes, previous_episodes
            self.llm_client, episode, touched_nodes, previous_episodes
        )

        await asyncio.gather(*[edge.generate_embedding(embedder) for edge in extracted_edges])
@@ -140,10 +144,23 @@ class Graphiti:
        logger.info(f'Existing edges: {[(e.name, e.uuid) for e in existing_edges]}')
        logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}')

        # deduped_edges = await dedupe_extracted_edges_v2(
        #     self.llm_client,
        #     extract_node_and_edge_triplets(extracted_edges, nodes),
        #     extract_node_and_edge_triplets(existing_edges, nodes),
        # )

        deduped_edges = await dedupe_extracted_edges(
            self.llm_client, extracted_edges, existing_edges
            self.llm_client,
            extracted_edges,
            existing_edges,
        )

        edge_touched_node_uuids = [n.uuid for n in brand_new_nodes]
        for edge in deduped_edges:
            edge_touched_node_uuids.append(edge.source_node_uuid)
            edge_touched_node_uuids.append(edge.target_node_uuid)

        (
            old_edges_with_nodes_pending_invalidation,
            new_edges_with_nodes,
@@ -155,26 +172,36 @@ class Graphiti:
            self.llm_client,
            old_edges_with_nodes_pending_invalidation,
            new_edges_with_nodes,
            episode,
            previous_episodes,
        )

        entity_edges.extend(invalidated_edges)
        for edge in invalidated_edges:
            edge_touched_node_uuids.append(edge.source_node_uuid)
            edge_touched_node_uuids.append(edge.target_node_uuid)

        edges_to_save = invalidated_edges

        # There may be an overlap between deduped and invalidated edges, so we want to make sure to save the invalidated one
        for deduped_edge in deduped_edges:
            if deduped_edge.uuid not in [edge.uuid for edge in invalidated_edges]:
                edges_to_save.append(deduped_edge)

        entity_edges.extend(edges_to_save)

        edge_touched_node_uuids = list(set(edge_touched_node_uuids))
        involved_nodes = [node for node in nodes if node.uuid in edge_touched_node_uuids]

        logger.info(f'Edge touched nodes: {[(n.name, n.uuid) for n in involved_nodes]}')

        logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')

        logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}')
        entity_edges.extend(deduped_edges)

        new_edges = await dedupe_extracted_edges(
            self.llm_client, extracted_edges, existing_edges
        )

        logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in new_edges]}')

        entity_edges.extend(new_edges)
        episodic_edges.extend(
            build_episodic_edges(
                # There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
                nodes,
                involved_nodes,
                episode,
                now,
            )
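The hunk above resolves the overlap between deduped and invalidated edges by always preferring the invalidated copy when two edges share a uuid. A minimal standalone sketch of that merge, with a hypothetical `SimpleEdge` standing in for the project's `EntityEdge`:

```python
from dataclasses import dataclass

@dataclass
class SimpleEdge:
    uuid: str
    fact: str

def merge_edges_to_save(invalidated: list[SimpleEdge], deduped: list[SimpleEdge]) -> list[SimpleEdge]:
    # Start from the invalidated edges, then add only the deduped edges
    # whose uuid was not already invalidated.
    edges_to_save = list(invalidated)
    invalidated_uuids = {edge.uuid for edge in invalidated}
    for edge in deduped:
        if edge.uuid not in invalidated_uuids:
            edges_to_save.append(edge)
    return edges_to_save

# merge_edges_to_save([SimpleEdge('e1', 'updated fact')],
#                     [SimpleEdge('e1', 'stale copy'), SimpleEdge('e2', 'new fact')])
# keeps the invalidated 'e1' and appends only 'e2'
```

Using a set of uuids for the membership test is a small tightening of the list scan in the diff; the behavior is the same.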
@@ -182,12 +209,6 @@ class Graphiti:
        # Important to append the episode to the nodes at the end so that self referencing episodic edges are not built
        logger.info(f'Built episodic edges: {episodic_edges}')

        # invalidated_edges = await self.invalidate_edges(
        #     episode, new_nodes, new_edges, relevant_schema, previous_episodes
        # )

        # edges.extend(invalidated_edges)

        # Future optimization would be using batch operations to save nodes and edges
        await episode.save(self.driver)
        await asyncio.gather(*[node.save(self.driver) for node in nodes])

@@ -6,11 +6,12 @@ from .models import Message, PromptFunction, PromptVersion

class Prompt(Protocol):
    v1: PromptVersion
    edge_list: PromptVersion
    v2: PromptVersion


class Versions(TypedDict):
    v1: PromptFunction
    v2: PromptFunction
    edge_list: PromptFunction

@@ -54,6 +55,48 @@ def v1(context: dict[str, any]) -> list[Message]:
    ]


def v2(context: dict[str, any]) -> list[Message]:
    return [
        Message(
            role='system',
            content='You are a helpful assistant that de-duplicates relationships from edge lists.',
        ),
        Message(
            role='user',
            content=f"""
        Given the following context, deduplicate edges from a list of new edges given a list of existing edges:

        Existing Edges:
        {json.dumps(context['existing_edges'], indent=2)}

        New Edges:
        {json.dumps(context['extracted_edges'], indent=2)}

        Task:
        1. Start with the list of edges from New Edges.
        2. If any edge in New Edges is a duplicate of an edge in Existing Edges, replace the new edge with the existing
        edge in the list.
        3. Respond with the resulting list of edges.

        Guidelines:
        1. Use both the triplet name and the fact of each edge to determine whether two edges are duplicates;
        duplicate edges may have different names that mean the same thing, with slight variations in their facts.
        2. If you encounter facts that are semantically equivalent or very similar, keep the original edge.

        Respond with a JSON object in the following format:
        {{
            "new_edges": [
                {{
                    "triplet": "source_node_name-edge_name-target_node_name",
                    "fact": "one sentence description of the fact"
                }}
            ]
        }}
        """,
        ),
    ]


def edge_list(context: dict[str, any]) -> list[Message]:
    return [
        Message(
@@ -90,4 +133,4 @@ def edge_list(context: dict[str, any]) -> list[Message]:
    ]


versions: Versions = {'v1': v1, 'edge_list': edge_list}
versions: Versions = {'v1': v1, 'v2': v2, 'edge_list': edge_list}

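With `v2` registered in `versions`, callers reach the new prompt through `prompt_library.dedupe_edges.v2`, which is how `dedupe_extracted_edges_v2` uses it later in this commit. A rough usage sketch with illustrative context values (not project fixtures):

```python
# Illustrative context; the triplets and facts here are made-up examples.
context = {
    'existing_edges': [{'triplet': 'Paul-LOVES-apples', 'fact': 'Paul loves apples'}],
    'extracted_edges': [{'triplet': 'Paul-HATES-apples', 'fact': 'Paul hates apples now'}],
}
messages = prompt_library.dedupe_edges.v2(context)  # the registry entry added above
# llm_response = await llm_client.generate_response(messages)
# deduped = llm_response.get('new_edges', [])
```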
@@ -15,14 +15,20 @@ def v1(context: dict[str, any]) -> list[Message]:
    return [
        Message(
            role='system',
            content='You are an AI assistant that helps determine which relationships in a knowledge graph should be invalidated based on newer information.',
            content='You are an AI assistant that helps determine which relationships in a knowledge graph should be invalidated based solely on explicit contradictions in newer information.',
        ),
        Message(
            role='user',
            content=f"""
        Based on the provided existing edges and new edges with their timestamps, determine which existing relationships, if any, should be invalidated due to contradictions or updates in the new edges.
        Only mark a relationship as invalid if there is clear evidence from new edges that the relationship is no longer true.
        Do not invalidate relationships merely because they weren't mentioned in new edges.
        Do not invalidate relationships merely because they weren't mentioned in new edges. You may use the current episode and previous episodes, as well as the facts of each edge, to understand the context of the relationships.

        Previous Episodes:
        {context['previous_episodes']}

        Current Episode:
        {context['current_episode']}

        Existing Edges (sorted by timestamp, newest first):
        {context['existing_edges']}
@@ -30,19 +36,19 @@ def v1(context: dict[str, any]) -> list[Message]:
        New Edges:
        {context['new_edges']}

        Each edge is formatted as: "UUID | SOURCE_NODE - EDGE_NAME - TARGET_NODE (TIMESTAMP)"
        Each edge is formatted as: "UUID | SOURCE_NODE - EDGE_NAME - TARGET_NODE (Fact: EDGE_FACT) (TIMESTAMP)"

        For each existing edge that should be invalidated, respond with a JSON object in the following format:
        {{
            "invalidated_edges": [
                {{
                    "edge_uuid": "The UUID of the edge to be invalidated (the part before the | character)",
                    "reason": "Brief explanation of why this edge is being invalidated"
                    "fact": "Updated fact of the edge"
                }}
            ]
        }}

        If no relationships need to be invalidated, return an empty list for "invalidated_edges".
        If no relationships need to be invalidated based on these strict criteria, return an empty list for "invalidated_edges".
        """,
        ),
    ]

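Since the schema above now asks for an updated "fact" instead of a "reason", a conforming model reply would look roughly like this (all values made up for illustration):

```python
# Hypothetical response payload matching the updated schema.
example_llm_response = {
    'invalidated_edges': [
        {
            'edge_uuid': 'b3d1c4e2',  # made-up uuid
            'fact': 'Paul is divorced from Jane',
        }
    ]
}
```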
@@ -98,7 +98,7 @@ async def dedupe_nodes_bulk(

    existing_nodes = await get_relevant_nodes(compressed_nodes, driver)

    nodes, partial_uuid_map = await dedupe_extracted_nodes(
    nodes, partial_uuid_map, _ = await dedupe_extracted_nodes(
        llm_client, compressed_nodes, existing_nodes
    )

@@ -8,6 +8,7 @@ from core.edges import EntityEdge, EpisodicEdge
from core.llm_client import LLMClient
from core.nodes import EntityNode, EpisodicNode
from core.prompts import prompt_library
from core.utils.maintenance.temporal_operations import NodeEdgeNodeTriplet

logger = logging.getLogger(__name__)

@@ -179,6 +180,51 @@ async def extract_edges(
    return edges


def create_edge_identifier(
    source_node: EntityNode, edge: EntityEdge, target_node: EntityNode
) -> str:
    return f'{source_node.name}-{edge.name}-{target_node.name}'


async def dedupe_extracted_edges_v2(
    llm_client: LLMClient,
    extracted_edges: list[NodeEdgeNodeTriplet],
    existing_edges: list[NodeEdgeNodeTriplet],
) -> list[NodeEdgeNodeTriplet]:
    # Create edge map
    edge_map = {}
    for n1, edge, n2 in existing_edges:
        edge_map[create_edge_identifier(n1, edge, n2)] = edge
    for n1, edge, n2 in extracted_edges:
        if create_edge_identifier(n1, edge, n2) in edge_map:
            continue
        edge_map[create_edge_identifier(n1, edge, n2)] = edge

    # Prepare context for LLM
    context = {
        'extracted_edges': [
            {'triplet': create_edge_identifier(n1, edge, n2), 'fact': edge.fact}
            for n1, edge, n2 in extracted_edges
        ],
        'existing_edges': [
            {'triplet': create_edge_identifier(n1, edge, n2), 'fact': edge.fact}
            for n1, edge, n2 in existing_edges
        ],
    }
    logger.info(prompt_library.dedupe_edges.v2(context))
    llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v2(context))
    new_edges_data = llm_response.get('new_edges', [])
    logger.info(f'Extracted new edges: {new_edges_data}')

    # Get full edge data
    edges = []
    for edge_data in new_edges_data:
        edge = edge_map[edge_data['triplet']]
        edges.append(edge)

    return edges


async def dedupe_extracted_edges(
    llm_client: LLMClient,
    extracted_edges: list[EntityEdge],
@@ -108,6 +108,11 @@ async def dedupe_extracted_nodes(
    for node in existing_nodes:
        node_map[node.name] = node

    # Temp hack
    new_nodes_map = {}
    for node in extracted_nodes:
        new_nodes_map[node.name] = node

    # Prepare context for LLM
    existing_nodes_context = [
        {'name': node.name, 'summary': node.summary} for node in existing_nodes
@@ -131,20 +136,25 @@ async def dedupe_extracted_nodes(

    uuid_map = {}
    for duplicate in duplicate_data:
        uuid = node_map[duplicate['name']].uuid
        uuid = new_nodes_map[duplicate['name']].uuid
        uuid_value = node_map[duplicate['duplicate_of']].uuid
        uuid_map[uuid] = uuid_value

    nodes = []
    brand_new_nodes = []
    for node in extracted_nodes:
        if node.uuid in uuid_map:
            existing_name = uuid_map[node.name]
            existing_node = node_map[existing_name]
            existing_uuid = uuid_map[node.uuid]
            # TODO(Preston): This is a bit of a hack I implemented because we were getting incorrect uuids for existing nodes,
            # can you revisit the node dedup function and make it somewhat cleaner and add more comments/tests please?
            # find an existing node by uuid in node_map (each key is a name, so we iterate over values and match on uuid)
            existing_node = next((v for k, v in node_map.items() if v.uuid == existing_uuid), None)
            nodes.append(existing_node)
            continue
        brand_new_nodes.append(node)
        nodes.append(node)

    return nodes, uuid_map
    return nodes, uuid_map, brand_new_nodes


async def dedupe_node_list(

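The TODO above flags the indirection this hunk works around: `uuid_map` is keyed by the extracted node's uuid, while `node_map` is keyed by name, so resolving a duplicate means scanning `node_map` values for a matching uuid. A toy walkthrough under hypothetical data:

```python
from dataclasses import dataclass

@dataclass
class Node:
    uuid: str
    name: str

existing = Node('u-existing', 'Paul')       # already in the graph
extracted = Node('u-new', 'Paul')           # flagged by the LLM as a duplicate
node_map = {existing.name: existing}        # keyed by name
uuid_map = {extracted.uuid: existing.uuid}  # duplicate uuid -> canonical uuid

existing_uuid = uuid_map[extracted.uuid]
match = next((v for v in node_map.values() if v.uuid == existing_uuid), None)
assert match is existing  # the duplicate resolves to the stored node
```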
@@ -4,7 +4,7 @@ from typing import List

from core.edges import EntityEdge
from core.llm_client import LLMClient
from core.nodes import EntityNode
from core.nodes import EntityNode, EpisodicNode
from core.prompts import prompt_library

logger = logging.getLogger(__name__)

@@ -12,6 +12,20 @@ logger = logging.getLogger(__name__)
NodeEdgeNodeTriplet = tuple[EntityNode, EntityEdge, EntityNode]


def extract_node_and_edge_triplets(
    edges: list[EntityEdge], nodes: list[EntityNode]
) -> list[NodeEdgeNodeTriplet]:
    return [extract_node_edge_node_triplet(edge, nodes) for edge in edges]


def extract_node_edge_node_triplet(
    edge: EntityEdge, nodes: list[EntityNode]
) -> NodeEdgeNodeTriplet:
    source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None)
    target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None)
    return (source_node, edge, target_node)


def prepare_edges_for_invalidation(
    existing_edges: list[EntityEdge],
    new_edges: list[EntityEdge],
@@ -39,13 +53,22 @@ def prepare_edges_for_invalidation(

async def invalidate_edges(
    llm_client: LLMClient,
    existing_edges_pending_invalidation: List[NodeEdgeNodeTriplet],
    new_edges: List[NodeEdgeNodeTriplet],
) -> List[EntityEdge]:
    existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet],
    new_edges: list[NodeEdgeNodeTriplet],
    current_episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
) -> list[EntityEdge]:
    invalidated_edges = []  # TODO: this is not yet used?

    context = prepare_invalidation_context(existing_edges_pending_invalidation, new_edges)
    context = prepare_invalidation_context(
        existing_edges_pending_invalidation,
        new_edges,
        current_episode,
        previous_episodes,
    )
    logger.info(prompt_library.invalidate_edges.v1(context))
    llm_response = await llm_client.generate_response(prompt_library.invalidate_edges.v1(context))
    logger.info(f'invalidate_edges LLM response: {llm_response}')

    edges_to_invalidate = llm_response.get('invalidated_edges', [])
    invalidated_edges = process_edge_invalidation_llm_response(
@@ -56,21 +79,26 @@ async def invalidate_edges(


def prepare_invalidation_context(
    existing_edges: List[NodeEdgeNodeTriplet], new_edges: List[NodeEdgeNodeTriplet]
    existing_edges: list[NodeEdgeNodeTriplet],
    new_edges: list[NodeEdgeNodeTriplet],
    current_episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
) -> dict:
    return {
        'existing_edges': [
            f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} ({edge.created_at.isoformat()})'
            f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) ({edge.created_at.isoformat()})'
            for source_node, edge, target_node in sorted(
                existing_edges, key=lambda x: x[1].created_at, reverse=True
            )
        ],
        'new_edges': [
            f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} ({edge.created_at.isoformat()})'
            f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) ({edge.created_at.isoformat()})'
            for source_node, edge, target_node in sorted(
                new_edges, key=lambda x: x[1].created_at, reverse=True
            )
        ],
        'current_episode': current_episode.content,
        'previous_episodes': [episode.content for episode in previous_episodes],
    }

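Concretely, one context entry rendered by the updated f-string now looks like this (uuid, names, and timestamp are made up):

```python
example_entry = 'b3d1c4e2 | Paul - IS_MARRIED_TO - Jane (Fact: Paul is married to Jane) (2024-08-26T12:00:00)'
```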
@@ -86,8 +114,9 @@ def process_edge_invalidation_llm_response(
    )
    if edge_to_update:
        edge_to_update.expired_at = datetime.now()
        edge_to_update.fact = edge_to_invalidate['fact']
        invalidated_edges.append(edge_to_update)
        logger.info(
            f"Invalidated edge: {edge_to_update.name} (UUID: {edge_to_update.uuid}). Reason: {edge_to_invalidate['reason']}"
            f"Invalidated edge: {edge_to_update.name} (UUID: {edge_to_update.uuid}). Updated Fact: {edge_to_invalidate['fact']}"
        )
    return invalidated_edges

292  core/utils/search/search_utils.py  Normal file

@@ -0,0 +1,292 @@
import asyncio
import logging
from datetime import datetime
from time import time

from neo4j import AsyncDriver
from neo4j import time as neo4j_time

from core.edges import EntityEdge
from core.nodes import EntityNode

logger = logging.getLogger(__name__)

RELEVANT_SCHEMA_LIMIT = 3


async def bfs(node_ids: list[str], driver: AsyncDriver):
    records, _, _ = await driver.execute_query(
        """
        MATCH (n WHERE n.uuid in $node_ids)-[r]->(m)
        RETURN
            n.uuid AS source_node_uuid,
            n.name AS source_name,
            n.summary AS source_summary,
            m.uuid AS target_node_uuid,
            m.name AS target_name,
            m.summary AS target_summary,
            r.uuid AS uuid,
            r.created_at AS created_at,
            r.name AS name,
            r.fact AS fact,
            r.fact_embedding AS fact_embedding,
            r.episodes AS episodes,
            r.expired_at AS expired_at,
            r.valid_at AS valid_at,
            r.invalid_at AS invalid_at
        """,
        node_ids=node_ids,
    )

    context = {}

    for record in records:
        n_uuid = record['source_node_uuid']
        if n_uuid in context:
            context[n_uuid]['facts'].append(record['fact'])
        else:
            context[n_uuid] = {
                'name': record['source_name'],
                'summary': record['source_summary'],
                'facts': [record['fact']],
            }

        m_uuid = record['target_node_uuid']
        if m_uuid not in context:
            context[m_uuid] = {
                'name': record['target_name'],
                'summary': record['target_summary'],
                'facts': [],
            }
    logger.info(f'bfs search returned context: {context}')
    return context

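# Illustrative shape of the context bfs returns (made-up values): facts are
# grouped under each source node, and a target node carries an empty facts
# list until an outgoing edge of its own is seen.
#
# {
#     'node-uuid-1': {'name': 'Paul', 'summary': '...', 'facts': ['Paul loves apples']},
#     'node-uuid-2': {'name': 'apples', 'summary': '...', 'facts': []},
# }
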
async def edge_similarity_search(
    search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityEdge]:
    # vector similarity search over embedded facts
    records, _, _ = await driver.execute_query(
        """
        CALL db.index.vector.queryRelationships("fact_embedding", 5, $search_vector)
        YIELD relationship AS r, score
        MATCH (n)-[r:RELATES_TO]->(m)
        RETURN
            r.uuid AS uuid,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            r.created_at AS created_at,
            r.name AS name,
            r.fact AS fact,
            r.fact_embedding AS fact_embedding,
            r.episodes AS episodes,
            r.expired_at AS expired_at,
            r.valid_at AS valid_at,
            r.invalid_at AS invalid_at
        ORDER BY score DESC LIMIT $limit
        """,
        search_vector=search_vector,
        limit=limit,
    )

    edges: list[EntityEdge] = []

    for record in records:
        edge = EntityEdge(
            uuid=record['uuid'],
            source_node_uuid=record['source_node_uuid'],
            target_node_uuid=record['target_node_uuid'],
            fact=record['fact'],
            name=record['name'],
            episodes=record['episodes'],
            fact_embedding=record['fact_embedding'],
            created_at=safely_parse_db_date(record['created_at']),
            expired_at=safely_parse_db_date(record['expired_at']),
            valid_at=safely_parse_db_date(record['valid_at']),
            invalid_at=safely_parse_db_date(record['invalid_at']),
        )

        edges.append(edge)

    return edges

async def entity_similarity_search(
    search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
    # vector similarity search over entity names
    records, _, _ = await driver.execute_query(
        """
        CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
        YIELD node AS n, score
        RETURN
            n.uuid AS uuid,
            n.name AS name,
            n.created_at AS created_at,
            n.summary AS summary
        ORDER BY score DESC
        """,
        search_vector=search_vector,
        limit=limit,
    )
    nodes: list[EntityNode] = []

    for record in records:
        nodes.append(
            EntityNode(
                uuid=record['uuid'],
                name=record['name'],
                labels=[],
                created_at=safely_parse_db_date(record['created_at']),
                summary=record['summary'],
            )
        )

    return nodes

async def entity_fulltext_search(
    query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
    # BM25 search to get top nodes
    fuzzy_query = query + '~'
    records, _, _ = await driver.execute_query(
        """
        CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
        RETURN
            node.uuid AS uuid,
            node.name AS name,
            node.created_at AS created_at,
            node.summary AS summary
        ORDER BY score DESC
        LIMIT $limit
        """,
        query=fuzzy_query,
        limit=limit,
    )
    nodes: list[EntityNode] = []

    for record in records:
        nodes.append(
            EntityNode(
                uuid=record['uuid'],
                name=record['name'],
                labels=[],
                created_at=safely_parse_db_date(record['created_at']),
                summary=record['summary'],
            )
        )

    return nodes

async def edge_fulltext_search(
    query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityEdge]:
    # fulltext search over facts
    fuzzy_query = query + '~'

    records, _, _ = await driver.execute_query(
        """
        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
        YIELD relationship AS r, score
        MATCH (n:Entity)-[r]->(m:Entity)
        RETURN
            r.uuid AS uuid,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            r.created_at AS created_at,
            r.name AS name,
            r.fact AS fact,
            r.fact_embedding AS fact_embedding,
            r.episodes AS episodes,
            r.expired_at AS expired_at,
            r.valid_at AS valid_at,
            r.invalid_at AS invalid_at
        ORDER BY score DESC LIMIT $limit
        """,
        query=fuzzy_query,
        limit=limit,
    )

    edges: list[EntityEdge] = []

    for record in records:
        edge = EntityEdge(
            uuid=record['uuid'],
            source_node_uuid=record['source_node_uuid'],
            target_node_uuid=record['target_node_uuid'],
            fact=record['fact'],
            name=record['name'],
            episodes=record['episodes'],
            fact_embedding=record['fact_embedding'],
            created_at=safely_parse_db_date(record['created_at']),
            expired_at=safely_parse_db_date(record['expired_at']),
            valid_at=safely_parse_db_date(record['valid_at']),
            invalid_at=safely_parse_db_date(record['invalid_at']),
        )

        edges.append(edge)

    return edges

def safely_parse_db_date(date_str: neo4j_time.Date) -> datetime | None:
    if date_str:
        return datetime.fromisoformat(date_str.iso_format())
    return None

async def get_relevant_nodes(
    nodes: list[EntityNode],
    driver: AsyncDriver,
) -> list[EntityNode]:
    start = time()
    relevant_nodes: list[EntityNode] = []
    relevant_node_uuids = set()

    results = await asyncio.gather(
        *[entity_fulltext_search(node.name, driver) for node in nodes],
        *[entity_similarity_search(node.name_embedding, driver) for node in nodes],
    )

    for result in results:
        for node in result:
            if node.uuid in relevant_node_uuids:
                continue

            relevant_node_uuids.add(node.uuid)
            relevant_nodes.append(node)

    end = time()
    logger.info(f'Found relevant nodes: {relevant_node_uuids} in {(end - start) * 1000} ms')

    return relevant_nodes


async def get_relevant_edges(
    edges: list[EntityEdge],
    driver: AsyncDriver,
) -> list[EntityEdge]:
    start = time()
    relevant_edges: list[EntityEdge] = []
    relevant_edge_uuids = set()

    results = await asyncio.gather(
        *[edge_similarity_search(edge.fact_embedding, driver) for edge in edges],
        *[edge_fulltext_search(edge.fact, driver) for edge in edges],
    )

    for result in results:
        for edge in result:
            if edge.uuid in relevant_edge_uuids:
                continue

            relevant_edge_uuids.add(edge.uuid)
            relevant_edges.append(edge)

    end = time()
    logger.info(f'Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms')

    return relevant_edges

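The four search helpers compose naturally: full-text and vector variants can be fanned out together the way `get_relevant_nodes` and `get_relevant_edges` do above. A hypothetical wiring sketch; `embed_query` is a stand-in for whatever embedding call produces the `list[float]` vectors these indexes expect, and is not part of this commit:

```python
from neo4j import AsyncDriver

async def related_context(query: str, driver: AsyncDriver, embed_query):
    # embed_query is assumed to be an async callable returning list[float].
    vector = await embed_query(query)
    nodes = await entity_fulltext_search(query, driver)
    nodes += await entity_similarity_search(vector, driver)
    edges = await edge_fulltext_search(query, driver)
    edges += await edge_similarity_search(vector, driver)
    return nodes, edges
```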
26  runner.py

@@ -2,6 +2,7 @@ import asyncio
import logging
import os
import sys
from datetime import datetime

from dotenv import load_dotenv

@@ -42,26 +43,31 @@ async def main():
    await clear_data(client.driver)

    # await client.build_indices()
    await client.add_episode(
        name='Message 1',
        episode_body='Paul: I love apples',
        source_description='WhatsApp Message',
    )
    await client.add_episode(
        name='Message 2',
        episode_body='Paul: I hate apples now',
        source_description='WhatsApp Message',
    )
    await client.add_episode(
        name='Message 3',
        episode_body='Jane: I am married to Paul',
        source_description='WhatsApp Message',
        reference_time=datetime.now(),
    )
    await client.add_episode(
        name='Message 4',
        episode_body='Paul: I have divorced Jane',
        source_description='WhatsApp Message',
        reference_time=datetime.now(),
    )
    await client.add_episode(
        name='Message 5',
        episode_body='Jane: I miss Paul',
        source_description='WhatsApp Message',
        reference_time=datetime.now(),
    )
    await client.add_episode(
        name='Message 6',
        episode_body="Jane: I don't miss Paul anymore, I hate him",
        source_description='WhatsApp Message',
        reference_time=datetime.now(),
    )

    # await client.add_episode(
    #     name="Message 3",
    #     episode_body="Assistant: The best type of apples available are Fuji apples",