Speed up add episode (#77)
* WIP * updates * use uuid for node dedupe * pret-testing * parallelized node resolution * working add_episode * revert to 4o * format * mypy update * update types
This commit is contained in:
parent
db12ac548d
commit
e9e6039b1e
12 changed files with 427 additions and 177 deletions
|
|
@ -70,6 +70,7 @@ async def main(use_bulk: bool = True):
|
||||||
reference_time=message.actual_timestamp,
|
reference_time=message.actual_timestamp,
|
||||||
source_description='Podcast Transcript',
|
source_description='Podcast Transcript',
|
||||||
)
|
)
|
||||||
|
return
|
||||||
|
|
||||||
episodes: list[RawEpisode] = [
|
episodes: list[RawEpisode] = [
|
||||||
RawEpisode(
|
RawEpisode(
|
||||||
|
|
@ -79,10 +80,10 @@ async def main(use_bulk: bool = True):
|
||||||
source_description='Podcast Transcript',
|
source_description='Podcast Transcript',
|
||||||
reference_time=message.actual_timestamp,
|
reference_time=message.actual_timestamp,
|
||||||
)
|
)
|
||||||
for i, message in enumerate(messages[3:14])
|
for i, message in enumerate(messages[3:20])
|
||||||
]
|
]
|
||||||
|
|
||||||
await client.add_episode_bulk(episodes)
|
await client.add_episode_bulk(episodes)
|
||||||
|
|
||||||
|
|
||||||
asyncio.run(main(True))
|
asyncio.run(main(False))
|
||||||
|
|
|
||||||
|
|
@ -48,14 +48,17 @@ from graphiti_core.utils.bulk_utils import (
|
||||||
retrieve_previous_episodes_bulk,
|
retrieve_previous_episodes_bulk,
|
||||||
)
|
)
|
||||||
from graphiti_core.utils.maintenance.edge_operations import (
|
from graphiti_core.utils.maintenance.edge_operations import (
|
||||||
dedupe_extracted_edges,
|
|
||||||
extract_edges,
|
extract_edges,
|
||||||
|
resolve_extracted_edges,
|
||||||
)
|
)
|
||||||
from graphiti_core.utils.maintenance.graph_data_operations import (
|
from graphiti_core.utils.maintenance.graph_data_operations import (
|
||||||
EPISODE_WINDOW_LEN,
|
EPISODE_WINDOW_LEN,
|
||||||
build_indices_and_constraints,
|
build_indices_and_constraints,
|
||||||
)
|
)
|
||||||
from graphiti_core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
|
from graphiti_core.utils.maintenance.node_operations import (
|
||||||
|
extract_nodes,
|
||||||
|
resolve_extracted_nodes,
|
||||||
|
)
|
||||||
from graphiti_core.utils.maintenance.temporal_operations import (
|
from graphiti_core.utils.maintenance.temporal_operations import (
|
||||||
extract_edge_dates,
|
extract_edge_dates,
|
||||||
invalidate_edges,
|
invalidate_edges,
|
||||||
|
|
@ -177,9 +180,9 @@ class Graphiti:
|
||||||
await build_indices_and_constraints(self.driver)
|
await build_indices_and_constraints(self.driver)
|
||||||
|
|
||||||
async def retrieve_episodes(
|
async def retrieve_episodes(
|
||||||
self,
|
self,
|
||||||
reference_time: datetime,
|
reference_time: datetime,
|
||||||
last_n: int = EPISODE_WINDOW_LEN,
|
last_n: int = EPISODE_WINDOW_LEN,
|
||||||
) -> list[EpisodicNode]:
|
) -> list[EpisodicNode]:
|
||||||
"""
|
"""
|
||||||
Retrieve the last n episodic nodes from the graph.
|
Retrieve the last n episodic nodes from the graph.
|
||||||
|
|
@ -207,14 +210,14 @@ class Graphiti:
|
||||||
return await retrieve_episodes(self.driver, reference_time, last_n)
|
return await retrieve_episodes(self.driver, reference_time, last_n)
|
||||||
|
|
||||||
async def add_episode(
|
async def add_episode(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
episode_body: str,
|
episode_body: str,
|
||||||
source_description: str,
|
source_description: str,
|
||||||
reference_time: datetime,
|
reference_time: datetime,
|
||||||
source: EpisodeType = EpisodeType.message,
|
source: EpisodeType = EpisodeType.message,
|
||||||
success_callback: Callable | None = None,
|
success_callback: Callable | None = None,
|
||||||
error_callback: Callable | None = None,
|
error_callback: Callable | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Process an episode and update the graph.
|
Process an episode and update the graph.
|
||||||
|
|
@ -265,7 +268,6 @@ class Graphiti:
|
||||||
|
|
||||||
nodes: list[EntityNode] = []
|
nodes: list[EntityNode] = []
|
||||||
entity_edges: list[EntityEdge] = []
|
entity_edges: list[EntityEdge] = []
|
||||||
episodic_edges: list[EpisodicEdge] = []
|
|
||||||
embedder = self.llm_client.get_embedder()
|
embedder = self.llm_client.get_embedder()
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
|
|
||||||
|
|
@ -280,6 +282,8 @@ class Graphiti:
|
||||||
valid_at=reference_time,
|
valid_at=reference_time,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Extract entities as nodes
|
||||||
|
|
||||||
extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes)
|
extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes)
|
||||||
logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
||||||
|
|
||||||
|
|
@ -288,57 +292,82 @@ class Graphiti:
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[node.generate_name_embedding(embedder) for node in extracted_nodes]
|
*[node.generate_name_embedding(embedder) for node in extracted_nodes]
|
||||||
)
|
)
|
||||||
existing_nodes = await get_relevant_nodes(extracted_nodes, self.driver)
|
|
||||||
|
# Resolve extracted nodes with nodes already in the graph
|
||||||
|
existing_nodes_lists: list[list[EntityNode]] = list(
|
||||||
|
await asyncio.gather(
|
||||||
|
*[get_relevant_nodes([node], self.driver) for node in extracted_nodes]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
||||||
touched_nodes, _, brand_new_nodes = await dedupe_extracted_nodes(
|
|
||||||
self.llm_client, extracted_nodes, existing_nodes
|
|
||||||
)
|
|
||||||
logger.info(f'Adjusted touched nodes: {[(n.name, n.uuid) for n in touched_nodes]}')
|
|
||||||
nodes.extend(touched_nodes)
|
|
||||||
|
|
||||||
|
mentioned_nodes, _ = await resolve_extracted_nodes(
|
||||||
|
self.llm_client, extracted_nodes, existing_nodes_lists
|
||||||
|
)
|
||||||
|
logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
|
||||||
|
nodes.extend(mentioned_nodes)
|
||||||
|
|
||||||
|
# Extract facts as edges given entity nodes
|
||||||
extracted_edges = await extract_edges(
|
extracted_edges = await extract_edges(
|
||||||
self.llm_client, episode, touched_nodes, previous_episodes
|
self.llm_client, episode, mentioned_nodes, previous_episodes
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# calculate embeddings
|
||||||
await asyncio.gather(*[edge.generate_embedding(embedder) for edge in extracted_edges])
|
await asyncio.gather(*[edge.generate_embedding(embedder) for edge in extracted_edges])
|
||||||
|
|
||||||
existing_edges = await get_relevant_edges(extracted_edges, self.driver)
|
# Resolve extracted edges with edges already in the graph
|
||||||
logger.info(f'Existing edges: {[(e.name, e.uuid) for e in existing_edges]}')
|
existing_edges_list: list[list[EntityEdge]] = list(
|
||||||
|
await asyncio.gather(
|
||||||
|
*[
|
||||||
|
get_relevant_edges(
|
||||||
|
[edge],
|
||||||
|
self.driver,
|
||||||
|
RELEVANT_SCHEMA_LIMIT,
|
||||||
|
edge.source_node_uuid,
|
||||||
|
edge.target_node_uuid,
|
||||||
|
)
|
||||||
|
for edge in extracted_edges
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f'Existing edges lists: {[(e.name, e.uuid) for edges_lst in existing_edges_list for e in edges_lst]}'
|
||||||
|
)
|
||||||
logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}')
|
logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}')
|
||||||
|
|
||||||
deduped_edges = await dedupe_extracted_edges(
|
deduped_edges: list[EntityEdge] = await resolve_extracted_edges(
|
||||||
self.llm_client,
|
self.llm_client, extracted_edges, existing_edges_list
|
||||||
extracted_edges,
|
|
||||||
existing_edges,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
edge_touched_node_uuids = [n.uuid for n in brand_new_nodes]
|
# Extract dates for the newly extracted edges
|
||||||
for edge in deduped_edges:
|
edge_dates = await asyncio.gather(
|
||||||
edge_touched_node_uuids.append(edge.source_node_uuid)
|
*[
|
||||||
edge_touched_node_uuids.append(edge.target_node_uuid)
|
extract_edge_dates(
|
||||||
|
self.llm_client,
|
||||||
|
edge,
|
||||||
|
episode,
|
||||||
|
previous_episodes,
|
||||||
|
)
|
||||||
|
for edge in deduped_edges
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, edge in enumerate(deduped_edges):
|
||||||
|
valid_at = edge_dates[i][0]
|
||||||
|
invalid_at = edge_dates[i][1]
|
||||||
|
|
||||||
for edge in deduped_edges:
|
|
||||||
valid_at, invalid_at, _ = await extract_edge_dates(
|
|
||||||
self.llm_client,
|
|
||||||
edge,
|
|
||||||
episode,
|
|
||||||
previous_episodes,
|
|
||||||
)
|
|
||||||
edge.valid_at = valid_at
|
edge.valid_at = valid_at
|
||||||
edge.invalid_at = invalid_at
|
edge.invalid_at = invalid_at
|
||||||
if edge.invalid_at:
|
if edge.invalid_at is not None:
|
||||||
edge.expired_at = now
|
|
||||||
for edge in existing_edges:
|
|
||||||
valid_at, invalid_at, _ = await extract_edge_dates(
|
|
||||||
self.llm_client,
|
|
||||||
edge,
|
|
||||||
episode,
|
|
||||||
previous_episodes,
|
|
||||||
)
|
|
||||||
edge.valid_at = valid_at
|
|
||||||
edge.invalid_at = invalid_at
|
|
||||||
if edge.invalid_at:
|
|
||||||
edge.expired_at = now
|
edge.expired_at = now
|
||||||
|
|
||||||
|
entity_edges.extend(deduped_edges)
|
||||||
|
|
||||||
|
existing_edges: list[EntityEdge] = [
|
||||||
|
e for edge_lst in existing_edges_list for e in edge_lst
|
||||||
|
]
|
||||||
|
|
||||||
(
|
(
|
||||||
old_edges_with_nodes_pending_invalidation,
|
old_edges_with_nodes_pending_invalidation,
|
||||||
new_edges_with_nodes,
|
new_edges_with_nodes,
|
||||||
|
|
@ -361,30 +390,18 @@ class Graphiti:
|
||||||
for deduped_edge in deduped_edges:
|
for deduped_edge in deduped_edges:
|
||||||
if deduped_edge.uuid == edge.uuid:
|
if deduped_edge.uuid == edge.uuid:
|
||||||
deduped_edge.expired_at = edge.expired_at
|
deduped_edge.expired_at = edge.expired_at
|
||||||
edge_touched_node_uuids.append(edge.source_node_uuid)
|
|
||||||
edge_touched_node_uuids.append(edge.target_node_uuid)
|
|
||||||
logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')
|
logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')
|
||||||
|
|
||||||
edges_to_save = existing_edges + deduped_edges
|
entity_edges.extend(existing_edges)
|
||||||
|
|
||||||
entity_edges.extend(edges_to_save)
|
|
||||||
|
|
||||||
edge_touched_node_uuids = list(set(edge_touched_node_uuids))
|
|
||||||
involved_nodes = [node for node in nodes if node.uuid in edge_touched_node_uuids]
|
|
||||||
|
|
||||||
logger.info(f'Edge touched nodes: {[(n.name, n.uuid) for n in involved_nodes]}')
|
|
||||||
|
|
||||||
logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}')
|
logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}')
|
||||||
|
|
||||||
episodic_edges.extend(
|
episodic_edges: list[EpisodicEdge] = build_episodic_edges(
|
||||||
build_episodic_edges(
|
mentioned_nodes,
|
||||||
# There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
|
episode,
|
||||||
involved_nodes,
|
now,
|
||||||
episode,
|
|
||||||
now,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
# Important to append the episode to the nodes at the end so that self referencing episodic edges are not built
|
|
||||||
logger.info(f'Built episodic edges: {episodic_edges}')
|
logger.info(f'Built episodic edges: {episodic_edges}')
|
||||||
|
|
||||||
# Future optimization would be using batch operations to save nodes and edges
|
# Future optimization would be using batch operations to save nodes and edges
|
||||||
|
|
@ -395,9 +412,7 @@ class Graphiti:
|
||||||
|
|
||||||
end = time()
|
end = time()
|
||||||
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
|
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
|
||||||
# for node in nodes:
|
|
||||||
# if isinstance(node, EntityNode):
|
|
||||||
# await node.update_summary(self.driver)
|
|
||||||
if success_callback:
|
if success_callback:
|
||||||
await success_callback(episode)
|
await success_callback(episode)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -407,8 +422,8 @@ class Graphiti:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def add_episode_bulk(
|
async def add_episode_bulk(
|
||||||
self,
|
self,
|
||||||
bulk_episodes: list[RawEpisode],
|
bulk_episodes: list[RawEpisode],
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Process multiple episodes in bulk and update the graph.
|
Process multiple episodes in bulk and update the graph.
|
||||||
|
|
@ -572,18 +587,18 @@ class Graphiti:
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
async def _search(
|
async def _search(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
timestamp: datetime,
|
timestamp: datetime,
|
||||||
config: SearchConfig,
|
config: SearchConfig,
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
):
|
):
|
||||||
return await hybrid_search(
|
return await hybrid_search(
|
||||||
self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
|
self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_nodes_by_query(
|
async def get_nodes_by_query(
|
||||||
self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
|
self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
"""
|
"""
|
||||||
Retrieve nodes from the graph database based on a text query.
|
Retrieve nodes from the graph database based on a text query.
|
||||||
|
|
|
||||||
|
|
@ -23,12 +23,14 @@ from .models import Message, PromptFunction, PromptVersion
|
||||||
class Prompt(Protocol):
|
class Prompt(Protocol):
|
||||||
v1: PromptVersion
|
v1: PromptVersion
|
||||||
v2: PromptVersion
|
v2: PromptVersion
|
||||||
|
v3: PromptVersion
|
||||||
edge_list: PromptVersion
|
edge_list: PromptVersion
|
||||||
|
|
||||||
|
|
||||||
class Versions(TypedDict):
|
class Versions(TypedDict):
|
||||||
v1: PromptFunction
|
v1: PromptFunction
|
||||||
v2: PromptFunction
|
v2: PromptFunction
|
||||||
|
v3: PromptFunction
|
||||||
edge_list: PromptFunction
|
edge_list: PromptFunction
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -41,17 +43,17 @@ def v1(context: dict[str, Any]) -> list[Message]:
|
||||||
Message(
|
Message(
|
||||||
role='user',
|
role='user',
|
||||||
content=f"""
|
content=f"""
|
||||||
Given the following context, deduplicate facts from a list of new facts given a list of existing facts:
|
Given the following context, deduplicate facts from a list of new facts given a list of existing edges:
|
||||||
|
|
||||||
Existing Facts:
|
Existing Edges:
|
||||||
{json.dumps(context['existing_edges'], indent=2)}
|
{json.dumps(context['existing_edges'], indent=2)}
|
||||||
|
|
||||||
New Facts:
|
New Edges:
|
||||||
{json.dumps(context['extracted_edges'], indent=2)}
|
{json.dumps(context['extracted_edges'], indent=2)}
|
||||||
|
|
||||||
Task:
|
Task:
|
||||||
If any facts in New Facts is a duplicate of a fact in Existing Facts,
|
If any edge in New Edges is a duplicate of an edge in Existing Edges, add their uuids to the output list.
|
||||||
do not return it in the list of unique facts.
|
When finding duplicates edges, synthesize their facts into a short new fact.
|
||||||
|
|
||||||
Guidelines:
|
Guidelines:
|
||||||
1. identical or near identical facts are duplicates
|
1. identical or near identical facts are duplicates
|
||||||
|
|
@ -60,9 +62,11 @@ def v1(context: dict[str, Any]) -> list[Message]:
|
||||||
|
|
||||||
Respond with a JSON object in the following format:
|
Respond with a JSON object in the following format:
|
||||||
{{
|
{{
|
||||||
"unique_facts": [
|
"duplicates": [
|
||||||
{{
|
{{
|
||||||
"uuid": "unique identifier of the fact"
|
"uuid": "uuid of the new node like 5d643020624c42fa9de13f97b1b3fa39",
|
||||||
|
"duplicate_of": "uuid of the existing node",
|
||||||
|
"fact": "one sentence description of the fact"
|
||||||
}}
|
}}
|
||||||
]
|
]
|
||||||
}}
|
}}
|
||||||
|
|
@ -113,6 +117,40 @@ def v2(context: dict[str, Any]) -> list[Message]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def v3(context: dict[str, Any]) -> list[Message]:
|
||||||
|
return [
|
||||||
|
Message(
|
||||||
|
role='system',
|
||||||
|
content='You are a helpful assistant that de-duplicates edges from edge lists.',
|
||||||
|
),
|
||||||
|
Message(
|
||||||
|
role='user',
|
||||||
|
content=f"""
|
||||||
|
Given the following context, determine whether the New Edge represents any of the edges in the list of Existing Edges.
|
||||||
|
|
||||||
|
Existing Edges:
|
||||||
|
{json.dumps(context['existing_edges'], indent=2)}
|
||||||
|
|
||||||
|
New Edge:
|
||||||
|
{json.dumps(context['extracted_edges'], indent=2)}
|
||||||
|
Task:
|
||||||
|
1. If the New Edges represents the same factual information as any edge in Existing Edges, return 'is_duplicate: true' in the
|
||||||
|
response. Otherwise, return 'is_duplicate: false'
|
||||||
|
2. If is_duplicate is true, also return the uuid of the existing edge in the response
|
||||||
|
|
||||||
|
Guidelines:
|
||||||
|
1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
|
||||||
|
|
||||||
|
Respond with a JSON object in the following format:
|
||||||
|
{{
|
||||||
|
"is_duplicate": true or false,
|
||||||
|
"uuid": uuid of the existing edge like "5d643020624c42fa9de13f97b1b3fa39" or null,
|
||||||
|
}}
|
||||||
|
""",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def edge_list(context: dict[str, Any]) -> list[Message]:
|
def edge_list(context: dict[str, Any]) -> list[Message]:
|
||||||
return [
|
return [
|
||||||
Message(
|
Message(
|
||||||
|
|
@ -151,4 +189,4 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
versions: Versions = {'v1': v1, 'v2': v2, 'edge_list': edge_list}
|
versions: Versions = {'v1': v1, 'v2': v2, 'v3': v3, 'edge_list': edge_list}
|
||||||
|
|
|
||||||
|
|
@ -23,13 +23,15 @@ from .models import Message, PromptFunction, PromptVersion
|
||||||
class Prompt(Protocol):
|
class Prompt(Protocol):
|
||||||
v1: PromptVersion
|
v1: PromptVersion
|
||||||
v2: PromptVersion
|
v2: PromptVersion
|
||||||
|
v3: PromptVersion
|
||||||
node_list: PromptVersion
|
node_list: PromptVersion
|
||||||
|
|
||||||
|
|
||||||
class Versions(TypedDict):
|
class Versions(TypedDict):
|
||||||
v1: PromptFunction
|
v1: PromptFunction
|
||||||
v2: PromptFunction
|
v2: PromptFunction
|
||||||
node_list: PromptVersion
|
v3: PromptFunction
|
||||||
|
node_list: PromptFunction
|
||||||
|
|
||||||
|
|
||||||
def v1(context: dict[str, Any]) -> list[Message]:
|
def v1(context: dict[str, Any]) -> list[Message]:
|
||||||
|
|
@ -94,22 +96,22 @@ def v2(context: dict[str, Any]) -> list[Message]:
|
||||||
Important:
|
Important:
|
||||||
If a node in the new nodes is describing the same entity as a node in the existing nodes, mark it as a duplicate!!!
|
If a node in the new nodes is describing the same entity as a node in the existing nodes, mark it as a duplicate!!!
|
||||||
Task:
|
Task:
|
||||||
If any node in New Nodes is a duplicate of a node in Existing Nodes, add their names to the output list
|
If any node in New Nodes is a duplicate of a node in Existing Nodes, add their uuids to the output list
|
||||||
When finding duplicates nodes, synthesize their summaries into a short new summary that contains the
|
When finding duplicates nodes, synthesize their summaries into a short new summary that contains the
|
||||||
relevant information of the summaries of the new and existing nodes.
|
relevant information of the summaries of the new and existing nodes.
|
||||||
|
|
||||||
Guidelines:
|
Guidelines:
|
||||||
1. Use both the name and summary of nodes to determine if they are duplicates,
|
1. Use both the name and summary of nodes to determine if they are duplicates,
|
||||||
duplicate nodes may have different names
|
duplicate nodes may have different names
|
||||||
2. In the output, name should always be the name of the New Node that is a duplicate. duplicate_of should be
|
2. In the output, uuid should always be the uuid of the New Node that is a duplicate. duplicate_of should be
|
||||||
the name of the Existing Node.
|
the uuid of the Existing Node.
|
||||||
|
|
||||||
Respond with a JSON object in the following format:
|
Respond with a JSON object in the following format:
|
||||||
{{
|
{{
|
||||||
"duplicates": [
|
"duplicates": [
|
||||||
{{
|
{{
|
||||||
"name": "name of the new node",
|
"uuid": "uuid of the new node like 5d643020624c42fa9de13f97b1b3fa39",
|
||||||
"duplicate_of": "name of the existing node",
|
"duplicate_of": "uuid of the existing node",
|
||||||
"summary": "Brief summary of the node's role or significance. Takes information from the new and existing nodes"
|
"summary": "Brief summary of the node's role or significance. Takes information from the new and existing nodes"
|
||||||
}}
|
}}
|
||||||
]
|
]
|
||||||
|
|
@ -119,6 +121,44 @@ def v2(context: dict[str, Any]) -> list[Message]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def v3(context: dict[str, Any]) -> list[Message]:
|
||||||
|
return [
|
||||||
|
Message(
|
||||||
|
role='system',
|
||||||
|
content='You are a helpful assistant that de-duplicates nodes from node lists.',
|
||||||
|
),
|
||||||
|
Message(
|
||||||
|
role='user',
|
||||||
|
content=f"""
|
||||||
|
Given the following context, determine whether the New Node represents any of the entities in the list of Existing Nodes.
|
||||||
|
|
||||||
|
Existing Nodes:
|
||||||
|
{json.dumps(context['existing_nodes'], indent=2)}
|
||||||
|
|
||||||
|
New Node:
|
||||||
|
{json.dumps(context['extracted_nodes'], indent=2)}
|
||||||
|
Task:
|
||||||
|
1. If the New Node represents the same entity as any node in Existing Nodes, return 'is_duplicate: true' in the
|
||||||
|
response. Otherwise, return 'is_duplicate: false'
|
||||||
|
2. If is_duplicate is true, also return the uuid of the existing node in the response
|
||||||
|
3. If is_duplicate is true, return a summary that synthesizes the information in the New Node summary and the
|
||||||
|
summary of the Existing Node it is a duplicate of.
|
||||||
|
|
||||||
|
Guidelines:
|
||||||
|
1. Use both the name and summary of nodes to determine if the entities are duplicates,
|
||||||
|
duplicate nodes may have different names
|
||||||
|
|
||||||
|
Respond with a JSON object in the following format:
|
||||||
|
{{
|
||||||
|
"is_duplicate": true or false,
|
||||||
|
"uuid": "uuid of the existing node like 5d643020624c42fa9de13f97b1b3fa39 or null",
|
||||||
|
"summary": "Brief summary of the node's role or significance. Takes information from the new and existing node"
|
||||||
|
}}
|
||||||
|
""",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def node_list(context: dict[str, Any]) -> list[Message]:
|
def node_list(context: dict[str, Any]) -> list[Message]:
|
||||||
return [
|
return [
|
||||||
Message(
|
Message(
|
||||||
|
|
@ -134,19 +174,19 @@ def node_list(context: dict[str, Any]) -> list[Message]:
|
||||||
{json.dumps(context['nodes'], indent=2)}
|
{json.dumps(context['nodes'], indent=2)}
|
||||||
|
|
||||||
Task:
|
Task:
|
||||||
1. Group nodes together such that all duplicate nodes are in the same list of names
|
1. Group nodes together such that all duplicate nodes are in the same list of uuids
|
||||||
2. All duplicate names should be grouped together in the same list
|
2. All duplicate uuids should be grouped together in the same list
|
||||||
3. Also return a new summary that synthesizes the summary into a new short summary
|
3. Also return a new summary that synthesizes the summary into a new short summary
|
||||||
|
|
||||||
Guidelines:
|
Guidelines:
|
||||||
1. Each name from the list of nodes should appear EXACTLY once in your response
|
1. Each uuid from the list of nodes should appear EXACTLY once in your response
|
||||||
2. If a node has no duplicates, it should appear in the response in a list of only one name
|
2. If a node has no duplicates, it should appear in the response in a list of only one uuid
|
||||||
|
|
||||||
Respond with a JSON object in the following format:
|
Respond with a JSON object in the following format:
|
||||||
{{
|
{{
|
||||||
"nodes": [
|
"nodes": [
|
||||||
{{
|
{{
|
||||||
"names": ["myNode", "node that is a duplicate of myNode"],
|
"uuids": ["5d643020624c42fa9de13f97b1b3fa39", "node that is a duplicate of 5d643020624c42fa9de13f97b1b3fa39"],
|
||||||
"summary": "Brief summary of the node summaries that appear in the list of names."
|
"summary": "Brief summary of the node summaries that appear in the list of names."
|
||||||
}}
|
}}
|
||||||
]
|
]
|
||||||
|
|
@ -156,4 +196,4 @@ def node_list(context: dict[str, Any]) -> list[Message]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
versions: Versions = {'v1': v1, 'v2': v2, 'node_list': node_list}
|
versions: Versions = {'v1': v1, 'v2': v2, 'v3': v3, 'node_list': node_list}
|
||||||
|
|
|
||||||
|
|
@ -55,6 +55,7 @@ def v1(context: dict[str, Any]) -> list[Message]:
|
||||||
1. Focus on entities, concepts, or actors that are central to the current episode.
|
1. Focus on entities, concepts, or actors that are central to the current episode.
|
||||||
2. Avoid creating nodes for relationships or actions (these will be handled as edges later).
|
2. Avoid creating nodes for relationships or actions (these will be handled as edges later).
|
||||||
3. Provide a brief but informative summary for each node.
|
3. Provide a brief but informative summary for each node.
|
||||||
|
4. Be as explicit as possible in your node names, using full names and avoiding abbreviations.
|
||||||
|
|
||||||
Respond with a JSON object in the following format:
|
Respond with a JSON object in the following format:
|
||||||
{{
|
{{
|
||||||
|
|
@ -90,6 +91,7 @@ Guidelines:
|
||||||
3. Provide concise but informative summaries for each extracted node.
|
3. Provide concise but informative summaries for each extracted node.
|
||||||
4. Avoid creating nodes for relationships or actions.
|
4. Avoid creating nodes for relationships or actions.
|
||||||
5. Avoid creating nodes for temporal information like dates, times or years (these will be added to edges later).
|
5. Avoid creating nodes for temporal information like dates, times or years (these will be added to edges later).
|
||||||
|
6. Be as explicit as possible in your node names, using full names and avoiding abbreviations.
|
||||||
|
|
||||||
Respond with a JSON object in the following format:
|
Respond with a JSON object in the following format:
|
||||||
{{
|
{{
|
||||||
|
|
|
||||||
|
|
@ -83,7 +83,7 @@ async def hybrid_search(
|
||||||
nodes.extend(await get_mentioned_nodes(driver, episodes))
|
nodes.extend(await get_mentioned_nodes(driver, episodes))
|
||||||
|
|
||||||
if SearchMethod.bm25 in config.search_methods:
|
if SearchMethod.bm25 in config.search_methods:
|
||||||
text_search = await edge_fulltext_search(query, driver, 2 * config.num_edges)
|
text_search = await edge_fulltext_search(driver, query, 2 * config.num_edges)
|
||||||
search_results.append(text_search)
|
search_results.append(text_search)
|
||||||
|
|
||||||
if SearchMethod.cosine_similarity in config.search_methods:
|
if SearchMethod.cosine_similarity in config.search_methods:
|
||||||
|
|
@ -95,7 +95,7 @@ async def hybrid_search(
|
||||||
)
|
)
|
||||||
|
|
||||||
similarity_search = await edge_similarity_search(
|
similarity_search = await edge_similarity_search(
|
||||||
search_vector, driver, 2 * config.num_edges
|
driver, search_vector, 2 * config.num_edges
|
||||||
)
|
)
|
||||||
search_results.append(similarity_search)
|
search_results.append(similarity_search)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -96,14 +96,18 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):
|
||||||
|
|
||||||
|
|
||||||
async def edge_similarity_search(
|
async def edge_similarity_search(
|
||||||
search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
|
driver: AsyncDriver,
|
||||||
|
search_vector: list[float],
|
||||||
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||||
|
source_node_uuid: str = '*',
|
||||||
|
target_node_uuid: str = '*',
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
# vector similarity search over embedded facts
|
# vector similarity search over embedded facts
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
||||||
YIELD relationship AS r, score
|
YIELD relationship AS rel, score
|
||||||
MATCH (n)-[r:RELATES_TO]->(m)
|
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
||||||
RETURN
|
RETURN
|
||||||
r.uuid AS uuid,
|
r.uuid AS uuid,
|
||||||
n.uuid AS source_node_uuid,
|
n.uuid AS source_node_uuid,
|
||||||
|
|
@ -119,6 +123,8 @@ async def edge_similarity_search(
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
""",
|
""",
|
||||||
search_vector=search_vector,
|
search_vector=search_vector,
|
||||||
|
source_uuid=source_node_uuid,
|
||||||
|
target_uuid=target_node_uuid,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -214,7 +220,11 @@ async def entity_fulltext_search(
|
||||||
|
|
||||||
|
|
||||||
async def edge_fulltext_search(
|
async def edge_fulltext_search(
|
||||||
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
|
driver: AsyncDriver,
|
||||||
|
query: str,
|
||||||
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
|
source_node_uuid: str = '*',
|
||||||
|
target_node_uuid: str = '*',
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
# fulltext search over facts
|
# fulltext search over facts
|
||||||
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
|
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
|
||||||
|
|
@ -222,8 +232,8 @@ async def edge_fulltext_search(
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
||||||
YIELD relationship AS r, score
|
YIELD relationship AS rel, score
|
||||||
MATCH (n:Entity)-[r]->(m:Entity)
|
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
||||||
RETURN
|
RETURN
|
||||||
r.uuid AS uuid,
|
r.uuid AS uuid,
|
||||||
n.uuid AS source_node_uuid,
|
n.uuid AS source_node_uuid,
|
||||||
|
|
@ -239,6 +249,8 @@ async def edge_fulltext_search(
|
||||||
ORDER BY score DESC LIMIT $limit
|
ORDER BY score DESC LIMIT $limit
|
||||||
""",
|
""",
|
||||||
query=fuzzy_query,
|
query=fuzzy_query,
|
||||||
|
source_uuid=source_node_uuid,
|
||||||
|
target_uuid=target_node_uuid,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -369,6 +381,9 @@ async def get_relevant_nodes(
|
||||||
async def get_relevant_edges(
|
async def get_relevant_edges(
|
||||||
edges: list[EntityEdge],
|
edges: list[EntityEdge],
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||||
|
source_node_uuid: str = '*',
|
||||||
|
target_node_uuid: str = '*',
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
start = time()
|
start = time()
|
||||||
relevant_edges: list[EntityEdge] = []
|
relevant_edges: list[EntityEdge] = []
|
||||||
|
|
@ -376,11 +391,16 @@ async def get_relevant_edges(
|
||||||
|
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
edge_similarity_search(edge.fact_embedding, driver)
|
edge_similarity_search(
|
||||||
|
driver, edge.fact_embedding, limit, source_node_uuid, target_node_uuid
|
||||||
|
)
|
||||||
for edge in edges
|
for edge in edges
|
||||||
if edge.fact_embedding is not None
|
if edge.fact_embedding is not None
|
||||||
],
|
],
|
||||||
*[edge_fulltext_search(edge.fact, driver) for edge in edges],
|
*[
|
||||||
|
edge_fulltext_search(driver, edge.fact, limit, source_node_uuid, target_node_uuid)
|
||||||
|
for edge in edges
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
for result in results:
|
for result in results:
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,6 @@ limitations under the License.
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import typing
|
import typing
|
||||||
from collections import defaultdict
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from math import ceil
|
from math import ceil
|
||||||
|
|
||||||
|
|
@ -43,6 +42,7 @@ from graphiti_core.utils.maintenance.node_operations import (
|
||||||
extract_nodes,
|
extract_nodes,
|
||||||
)
|
)
|
||||||
from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates
|
from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates
|
||||||
|
from graphiti_core.utils.utils import chunk_edges_by_nodes
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -128,7 +128,7 @@ async def dedupe_nodes_bulk(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
results: list[tuple[list[EntityNode], dict[str, str], list[EntityNode]]] = list(
|
results: list[tuple[list[EntityNode], dict[str, str]]] = list(
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i])
|
dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i])
|
||||||
|
|
@ -265,19 +265,7 @@ async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list
|
||||||
return edges
|
return edges
|
||||||
# We only want to dedupe edges that are between the same pair of nodes
|
# We only want to dedupe edges that are between the same pair of nodes
|
||||||
# We build a map of the edges based on their source and target nodes.
|
# We build a map of the edges based on their source and target nodes.
|
||||||
edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
|
edge_chunks = chunk_edges_by_nodes(edges)
|
||||||
for edge in edges:
|
|
||||||
# We drop loop edges
|
|
||||||
if edge.source_node_uuid == edge.target_node_uuid:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution
|
|
||||||
pointers = [edge.source_node_uuid, edge.target_node_uuid]
|
|
||||||
pointers.sort()
|
|
||||||
|
|
||||||
edge_chunk_map[pointers[0] + pointers[1]].append(edge)
|
|
||||||
|
|
||||||
edge_chunks = [chunk for chunk in edge_chunk_map.values()]
|
|
||||||
|
|
||||||
results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])
|
results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from time import time
|
from time import time
|
||||||
|
|
@ -109,8 +110,8 @@ async def dedupe_extracted_edges(
|
||||||
existing_edges: list[EntityEdge],
|
existing_edges: list[EntityEdge],
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
# Create edge map
|
# Create edge map
|
||||||
edge_map = {}
|
edge_map: dict[str, EntityEdge] = {}
|
||||||
for edge in extracted_edges:
|
for edge in existing_edges:
|
||||||
edge_map[edge.uuid] = edge
|
edge_map[edge.uuid] = edge
|
||||||
|
|
||||||
# Prepare context for LLM
|
# Prepare context for LLM
|
||||||
|
|
@ -124,18 +125,85 @@ async def dedupe_extracted_edges(
|
||||||
}
|
}
|
||||||
|
|
||||||
llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v1(context))
|
llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v1(context))
|
||||||
unique_edge_data = llm_response.get('unique_facts', [])
|
duplicate_data = llm_response.get('duplicates', [])
|
||||||
logger.info(f'Extracted unique edges: {unique_edge_data}')
|
logger.info(f'Extracted unique edges: {duplicate_data}')
|
||||||
|
|
||||||
|
duplicate_uuid_map: dict[str, str] = {}
|
||||||
|
for duplicate in duplicate_data:
|
||||||
|
uuid_value = duplicate['duplicate_of']
|
||||||
|
duplicate_uuid_map[duplicate['uuid']] = uuid_value
|
||||||
|
|
||||||
# Get full edge data
|
# Get full edge data
|
||||||
edges = []
|
edges: list[EntityEdge] = []
|
||||||
for unique_edge in unique_edge_data:
|
for edge in extracted_edges:
|
||||||
edge = edge_map[unique_edge['uuid']]
|
if edge.uuid in duplicate_uuid_map:
|
||||||
edges.append(edge)
|
existing_uuid = duplicate_uuid_map[edge.uuid]
|
||||||
|
existing_edge = edge_map[existing_uuid]
|
||||||
|
edges.append(existing_edge)
|
||||||
|
else:
|
||||||
|
edges.append(edge)
|
||||||
|
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
|
|
||||||
|
async def resolve_extracted_edges(
|
||||||
|
llm_client: LLMClient,
|
||||||
|
extracted_edges: list[EntityEdge],
|
||||||
|
existing_edges_lists: list[list[EntityEdge]],
|
||||||
|
) -> list[EntityEdge]:
|
||||||
|
resolved_edges: list[EntityEdge] = list(
|
||||||
|
await asyncio.gather(
|
||||||
|
*[
|
||||||
|
resolve_extracted_edge(llm_client, extracted_edge, existing_edges)
|
||||||
|
for extracted_edge, existing_edges in zip(extracted_edges, existing_edges_lists)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return resolved_edges
|
||||||
|
|
||||||
|
|
||||||
|
async def resolve_extracted_edge(
|
||||||
|
llm_client: LLMClient, extracted_edge: EntityEdge, existing_edges: list[EntityEdge]
|
||||||
|
) -> EntityEdge:
|
||||||
|
start = time()
|
||||||
|
|
||||||
|
# Prepare context for LLM
|
||||||
|
existing_edges_context = [
|
||||||
|
{'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges
|
||||||
|
]
|
||||||
|
|
||||||
|
extracted_edge_context = {
|
||||||
|
'uuid': extracted_edge.uuid,
|
||||||
|
'name': extracted_edge.name,
|
||||||
|
'fact': extracted_edge.fact,
|
||||||
|
}
|
||||||
|
|
||||||
|
context = {
|
||||||
|
'existing_edges': existing_edges_context,
|
||||||
|
'extracted_edges': extracted_edge_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v3(context))
|
||||||
|
|
||||||
|
is_duplicate: bool = llm_response.get('is_duplicate', False)
|
||||||
|
uuid: str | None = llm_response.get('uuid', None)
|
||||||
|
|
||||||
|
edge = extracted_edge
|
||||||
|
if is_duplicate:
|
||||||
|
for existing_edge in existing_edges:
|
||||||
|
if existing_edge.uuid != uuid:
|
||||||
|
continue
|
||||||
|
edge = existing_edge
|
||||||
|
|
||||||
|
end = time()
|
||||||
|
logger.info(
|
||||||
|
f'Resolved node: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms'
|
||||||
|
)
|
||||||
|
|
||||||
|
return edge
|
||||||
|
|
||||||
|
|
||||||
async def dedupe_edge_list(
|
async def dedupe_edge_list(
|
||||||
llm_client: LLMClient,
|
llm_client: LLMClient,
|
||||||
edges: list[EntityEdge],
|
edges: list[EntityEdge],
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from time import time
|
from time import time
|
||||||
|
|
@ -27,7 +28,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def extract_message_nodes(
|
async def extract_message_nodes(
|
||||||
llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
|
llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
# Prepare context for LLM
|
# Prepare context for LLM
|
||||||
context = {
|
context = {
|
||||||
|
|
@ -48,8 +49,8 @@ async def extract_message_nodes(
|
||||||
|
|
||||||
|
|
||||||
async def extract_json_nodes(
|
async def extract_json_nodes(
|
||||||
llm_client: LLMClient,
|
llm_client: LLMClient,
|
||||||
episode: EpisodicNode,
|
episode: EpisodicNode,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
# Prepare context for LLM
|
# Prepare context for LLM
|
||||||
context = {
|
context = {
|
||||||
|
|
@ -66,9 +67,9 @@ async def extract_json_nodes(
|
||||||
|
|
||||||
|
|
||||||
async def extract_nodes(
|
async def extract_nodes(
|
||||||
llm_client: LLMClient,
|
llm_client: LLMClient,
|
||||||
episode: EpisodicNode,
|
episode: EpisodicNode,
|
||||||
previous_episodes: list[EpisodicNode],
|
previous_episodes: list[EpisodicNode],
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
start = time()
|
start = time()
|
||||||
extracted_node_data: list[dict[str, Any]] = []
|
extracted_node_data: list[dict[str, Any]] = []
|
||||||
|
|
@ -95,29 +96,24 @@ async def extract_nodes(
|
||||||
|
|
||||||
|
|
||||||
async def dedupe_extracted_nodes(
|
async def dedupe_extracted_nodes(
|
||||||
llm_client: LLMClient,
|
llm_client: LLMClient,
|
||||||
extracted_nodes: list[EntityNode],
|
extracted_nodes: list[EntityNode],
|
||||||
existing_nodes: list[EntityNode],
|
existing_nodes: list[EntityNode],
|
||||||
) -> tuple[list[EntityNode], dict[str, str], list[EntityNode]]:
|
) -> tuple[list[EntityNode], dict[str, str]]:
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
# build existing node map
|
# build existing node map
|
||||||
node_map: dict[str, EntityNode] = {}
|
node_map: dict[str, EntityNode] = {}
|
||||||
for node in existing_nodes:
|
for node in existing_nodes:
|
||||||
node_map[node.name] = node
|
node_map[node.uuid] = node
|
||||||
|
|
||||||
# Temp hack
|
|
||||||
new_nodes_map: dict[str, EntityNode] = {}
|
|
||||||
for node in extracted_nodes:
|
|
||||||
new_nodes_map[node.name] = node
|
|
||||||
|
|
||||||
# Prepare context for LLM
|
# Prepare context for LLM
|
||||||
existing_nodes_context = [
|
existing_nodes_context = [
|
||||||
{'name': node.name, 'summary': node.summary} for node in existing_nodes
|
{'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in existing_nodes
|
||||||
]
|
]
|
||||||
|
|
||||||
extracted_nodes_context = [
|
extracted_nodes_context = [
|
||||||
{'name': node.name, 'summary': node.summary} for node in extracted_nodes
|
{'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in extracted_nodes
|
||||||
]
|
]
|
||||||
|
|
||||||
context = {
|
context = {
|
||||||
|
|
@ -134,42 +130,104 @@ async def dedupe_extracted_nodes(
|
||||||
|
|
||||||
uuid_map: dict[str, str] = {}
|
uuid_map: dict[str, str] = {}
|
||||||
for duplicate in duplicate_data:
|
for duplicate in duplicate_data:
|
||||||
uuid = new_nodes_map[duplicate['name']].uuid
|
uuid_value = duplicate['duplicate_of']
|
||||||
uuid_value = node_map[duplicate['duplicate_of']].uuid
|
uuid_map[duplicate['uuid']] = uuid_value
|
||||||
uuid_map[uuid] = uuid_value
|
|
||||||
|
|
||||||
nodes: list[EntityNode] = []
|
nodes: list[EntityNode] = []
|
||||||
brand_new_nodes: list[EntityNode] = []
|
|
||||||
for node in extracted_nodes:
|
for node in extracted_nodes:
|
||||||
if node.uuid in uuid_map:
|
if node.uuid in uuid_map:
|
||||||
existing_uuid = uuid_map[node.uuid]
|
existing_uuid = uuid_map[node.uuid]
|
||||||
# TODO(Preston): This is a bit of a hack I implemented because we were getting incorrect uuids for existing nodes,
|
existing_node = node_map[existing_uuid]
|
||||||
# can you revisit the node dedup function and make it somewhat cleaner and add more comments/tests please?
|
nodes.append(existing_node)
|
||||||
# find an existing node by the uuid from the nodes_map (each key is name, so we need to iterate by uuid value)
|
else:
|
||||||
existing_node = next((v for k, v in node_map.items() if v.uuid == existing_uuid), None)
|
nodes.append(node)
|
||||||
if existing_node:
|
|
||||||
nodes.append(existing_node)
|
|
||||||
|
|
||||||
continue
|
return nodes, uuid_map
|
||||||
brand_new_nodes.append(node)
|
|
||||||
nodes.append(node)
|
|
||||||
|
|
||||||
return nodes, uuid_map, brand_new_nodes
|
|
||||||
|
async def resolve_extracted_nodes(
|
||||||
|
llm_client: LLMClient,
|
||||||
|
extracted_nodes: list[EntityNode],
|
||||||
|
existing_nodes_lists: list[list[EntityNode]],
|
||||||
|
) -> tuple[list[EntityNode], dict[str, str]]:
|
||||||
|
uuid_map: dict[str, str] = {}
|
||||||
|
resolved_nodes: list[EntityNode] = []
|
||||||
|
results: list[tuple[EntityNode, dict[str, str]]] = list(
|
||||||
|
await asyncio.gather(
|
||||||
|
*[
|
||||||
|
resolve_extracted_node(llm_client, extracted_node, existing_nodes)
|
||||||
|
for extracted_node, existing_nodes in zip(extracted_nodes, existing_nodes_lists)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
uuid_map.update(result[1])
|
||||||
|
resolved_nodes.append(result[0])
|
||||||
|
|
||||||
|
return resolved_nodes, uuid_map
|
||||||
|
|
||||||
|
|
||||||
|
async def resolve_extracted_node(
|
||||||
|
llm_client: LLMClient, extracted_node: EntityNode, existing_nodes: list[EntityNode]
|
||||||
|
) -> tuple[EntityNode, dict[str, str]]:
|
||||||
|
start = time()
|
||||||
|
|
||||||
|
# Prepare context for LLM
|
||||||
|
existing_nodes_context = [
|
||||||
|
{'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in existing_nodes
|
||||||
|
]
|
||||||
|
|
||||||
|
extracted_node_context = {
|
||||||
|
'uuid': extracted_node.uuid,
|
||||||
|
'name': extracted_node.name,
|
||||||
|
'summary': extracted_node.summary,
|
||||||
|
}
|
||||||
|
|
||||||
|
context = {
|
||||||
|
'existing_nodes': existing_nodes_context,
|
||||||
|
'extracted_nodes': extracted_node_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
llm_response = await llm_client.generate_response(prompt_library.dedupe_nodes.v3(context))
|
||||||
|
|
||||||
|
is_duplicate: bool = llm_response.get('is_duplicate', False)
|
||||||
|
uuid: str | None = llm_response.get('uuid', None)
|
||||||
|
summary = llm_response.get('summary', '')
|
||||||
|
|
||||||
|
node = extracted_node
|
||||||
|
uuid_map: dict[str, str] = {}
|
||||||
|
if is_duplicate:
|
||||||
|
for existing_node in existing_nodes:
|
||||||
|
if existing_node.uuid != uuid:
|
||||||
|
continue
|
||||||
|
node = existing_node
|
||||||
|
node.summary = summary
|
||||||
|
uuid_map[extracted_node.uuid] = existing_node.uuid
|
||||||
|
|
||||||
|
end = time()
|
||||||
|
logger.info(
|
||||||
|
f'Resolved node: {extracted_node.name} is {node.name}, in {(end - start) * 1000} ms'
|
||||||
|
)
|
||||||
|
|
||||||
|
return node, uuid_map
|
||||||
|
|
||||||
|
|
||||||
async def dedupe_node_list(
|
async def dedupe_node_list(
|
||||||
llm_client: LLMClient,
|
llm_client: LLMClient,
|
||||||
nodes: list[EntityNode],
|
nodes: list[EntityNode],
|
||||||
) -> tuple[list[EntityNode], dict[str, str]]:
|
) -> tuple[list[EntityNode], dict[str, str]]:
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
# build node map
|
# build node map
|
||||||
node_map = {}
|
node_map = {}
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
node_map[node.name] = node
|
node_map[node.uuid] = node
|
||||||
|
|
||||||
# Prepare context for LLM
|
# Prepare context for LLM
|
||||||
nodes_context = [{'name': node.name, 'summary': node.summary} for node in nodes]
|
nodes_context = [
|
||||||
|
{'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in nodes
|
||||||
|
]
|
||||||
|
|
||||||
context = {
|
context = {
|
||||||
'nodes': nodes_context,
|
'nodes': nodes_context,
|
||||||
|
|
@ -188,13 +246,12 @@ async def dedupe_node_list(
|
||||||
unique_nodes = []
|
unique_nodes = []
|
||||||
uuid_map: dict[str, str] = {}
|
uuid_map: dict[str, str] = {}
|
||||||
for node_data in nodes_data:
|
for node_data in nodes_data:
|
||||||
node = node_map[node_data['names'][0]]
|
node = node_map[node_data['uuids'][0]]
|
||||||
node.summary = node_data['summary']
|
node.summary = node_data['summary']
|
||||||
unique_nodes.append(node)
|
unique_nodes.append(node)
|
||||||
|
|
||||||
for name in node_data['names'][1:]:
|
for uuid in node_data['uuids'][1:]:
|
||||||
uuid = node_map[name].uuid
|
uuid_value = node_map[node_data['uuids'][0]].uuid
|
||||||
uuid_value = node_map[node_data['names'][0]].uuid
|
|
||||||
uuid_map[uuid] = uuid_value
|
uuid_map[uuid] = uuid_value
|
||||||
|
|
||||||
return unique_nodes, uuid_map
|
return unique_nodes, uuid_map
|
||||||
|
|
|
||||||
|
|
@ -149,7 +149,7 @@ async def extract_edge_dates(
|
||||||
edge: EntityEdge,
|
edge: EntityEdge,
|
||||||
current_episode: EpisodicNode,
|
current_episode: EpisodicNode,
|
||||||
previous_episodes: List[EpisodicNode],
|
previous_episodes: List[EpisodicNode],
|
||||||
) -> tuple[datetime | None, datetime | None, str]:
|
) -> tuple[datetime | None, datetime | None]:
|
||||||
context = {
|
context = {
|
||||||
'edge_name': edge.name,
|
'edge_name': edge.name,
|
||||||
'edge_fact': edge.fact,
|
'edge_fact': edge.fact,
|
||||||
|
|
@ -180,4 +180,4 @@ async def extract_edge_dates(
|
||||||
|
|
||||||
logger.info(f'Edge date extraction explanation: {explanation}')
|
logger.info(f'Edge date extraction explanation: {explanation}')
|
||||||
|
|
||||||
return valid_at_datetime, invalid_at_datetime, explanation
|
return valid_at_datetime, invalid_at_datetime
|
||||||
|
|
|
||||||
|
|
@ -15,8 +15,9 @@ limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
from graphiti_core.edges import EpisodicEdge
|
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||||
from graphiti_core.nodes import EntityNode, EpisodicNode
|
from graphiti_core.nodes import EntityNode, EpisodicNode
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -37,3 +38,23 @@ def build_episodic_edges(
|
||||||
)
|
)
|
||||||
|
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_edges_by_nodes(edges: list[EntityEdge]) -> list[list[EntityEdge]]:
|
||||||
|
# We only want to dedupe edges that are between the same pair of nodes
|
||||||
|
# We build a map of the edges based on their source and target nodes.
|
||||||
|
edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
|
||||||
|
for edge in edges:
|
||||||
|
# We drop loop edges
|
||||||
|
if edge.source_node_uuid == edge.target_node_uuid:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution
|
||||||
|
pointers = [edge.source_node_uuid, edge.target_node_uuid]
|
||||||
|
pointers.sort()
|
||||||
|
|
||||||
|
edge_chunk_map[pointers[0] + pointers[1]].append(edge)
|
||||||
|
|
||||||
|
edge_chunks = [chunk for chunk in edge_chunk_map.values()]
|
||||||
|
|
||||||
|
return edge_chunks
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue