Bulk ingestion (#698)

* partial

* update

* update

* update

* update

* updates

* updates

* update

* update
Preston Rasmussen 2025-07-10 12:14:49 -04:00 committed by GitHub
parent 94df836396
commit 0675ac2b7d
10 changed files with 2351 additions and 2466 deletions

View file

@@ -25,6 +25,8 @@ from pydantic import BaseModel, Field
from transcript_parser import parse_podcast_messages
from graphiti_core import Graphiti
from graphiti_core.nodes import EpisodeType
from graphiti_core.utils.bulk_utils import RawEpisode
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
load_dotenv()
@@ -67,7 +69,7 @@ class IsPresidentOf(BaseModel):
"""Relationship between a person and the entity they are a president of"""
async def main():
async def main(use_bulk: bool = False):
setup_logging()
client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
await clear_data(client.driver)
@@ -75,21 +77,43 @@ async def main():
messages = parse_podcast_messages()
group_id = str(uuid4())
for i, message in enumerate(messages[3:14]):
episodes = await client.retrieve_episodes(message.actual_timestamp, 3, group_ids=[group_id])
episode_uuids = [episode.uuid for episode in episodes]
await client.add_episode(
name=f'Message {i}',
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
reference_time=message.actual_timestamp,
source_description='Podcast Transcript',
raw_episodes: list[RawEpisode] = []
for i, message in enumerate(messages[3:7]):
raw_episodes.append(
RawEpisode(
name=f'Message {i}',
content=f'{message.speaker_name} ({message.role}): {message.content}',
reference_time=message.actual_timestamp,
source=EpisodeType.message,
source_description='Podcast Transcript',
)
)
if use_bulk:
await client.add_episode_bulk(
raw_episodes,
group_id=group_id,
entity_types={'Person': Person},
edge_types={'IS_PRESIDENT_OF': IsPresidentOf},
edge_type_map={('Person', 'Entity'): ['PRESIDENT_OF']},
previous_episode_uuids=episode_uuids,
)
else:
for i, message in enumerate(messages[3:14]):
episodes = await client.retrieve_episodes(
message.actual_timestamp, 3, group_ids=[group_id]
)
episode_uuids = [episode.uuid for episode in episodes]
await client.add_episode(
name=f'Message {i}',
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
reference_time=message.actual_timestamp,
source_description='Podcast Transcript',
group_id=group_id,
entity_types={'Person': Person},
edge_types={'IS_PRESIDENT_OF': IsPresidentOf},
edge_type_map={('Person', 'Entity'): ['PRESIDENT_OF']},
previous_episode_uuids=episode_uuids,
)
asyncio.run(main())
asyncio.run(main(True))
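For orientation, here is a minimal, self-contained sketch of driving the new bulk path outside the podcast runner. The Neo4j credentials, the Person model, and the two sample messages are placeholders, and add_episode_bulk is still flagged experimental in the core diff below.

import asyncio
from datetime import datetime, timezone

from pydantic import BaseModel, Field

from graphiti_core import Graphiti
from graphiti_core.nodes import EpisodeType
from graphiti_core.utils.bulk_utils import RawEpisode


class Person(BaseModel):
    """A person mentioned in the transcript (illustrative entity type)."""

    occupation: str | None = Field(None, description='Occupation of the person')


async def ingest_bulk() -> None:
    # Placeholder connection details; point these at your own Neo4j instance.
    client = Graphiti('bolt://localhost:7687', 'neo4j', 'password')

    raw_episodes = [
        RawEpisode(
            name=f'Message {i}',
            content=text,
            source=EpisodeType.message,
            source_description='Podcast Transcript',
            reference_time=datetime.now(timezone.utc),
        )
        for i, text in enumerate(
            ['Alice (host): Welcome back.', 'Bob (guest): Glad to be here.']
        )
    ]

    # Experimental bulk API introduced by this PR.
    await client.add_episode_bulk(
        raw_episodes,
        group_id='demo-group',
        entity_types={'Person': Person},
    )


if __name__ == '__main__':
    asyncio.run(ingest_bulk())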

View file

@@ -44,6 +44,7 @@ class GeminiRerankerClient(CrossEncoderClient):
"""
Google Gemini Reranker Client
"""
def __init__(
self,
config: LLMConfig | None = None,

View file

@@ -46,6 +46,7 @@ class GeminiEmbedder(EmbedderClient):
"""
Google Gemini Embedder Client
"""
def __init__(
self,
config: GeminiEmbedderConfig | None = None,

View file

@@ -57,7 +57,6 @@ from graphiti_core.utils.bulk_utils import (
add_nodes_and_edges_bulk,
dedupe_edges_bulk,
dedupe_nodes_bulk,
extract_edge_dates_bulk,
extract_nodes_and_edges_bulk,
resolve_edge_pointers,
retrieve_previous_episodes_bulk,
@@ -508,7 +507,7 @@ class Graphiti:
entity_edges = resolved_edges + invalidated_edges + duplicate_of_edges
episodic_edges = build_episodic_edges(nodes, episode, now)
episodic_edges = build_episodic_edges(nodes, episode.uuid, now)
episode.entity_edges = [edge.uuid for edge in entity_edges]
@@ -536,8 +535,16 @@ class Graphiti:
except Exception as e:
raise e
#### WIP: USE AT YOUR OWN RISK ####
async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str = ''):
##### EXPERIMENTAL #####
async def add_episode_bulk(
self,
bulk_episodes: list[RawEpisode],
group_id: str = '',
entity_types: dict[str, BaseModel] | None = None,
excluded_entity_types: list[str] | None = None,
edge_types: dict[str, BaseModel] | None = None,
edge_type_map: dict[tuple[str, str], list[str]] | None = None,
):
"""
Process multiple episodes in bulk and update the graph.
@@ -580,8 +587,17 @@ class Graphiti:
validate_group_id(group_id)
# Create default edge type map
edge_type_map_default = (
{('Entity', 'Entity'): list(edge_types.keys())}
if edge_types is not None
else {('Entity', 'Entity'): []}
)
episodes = [
EpisodicNode(
await EpisodicNode.get_by_uuid(self.driver, episode.uuid)
if episode.uuid is not None
else EpisodicNode(
name=episode.name,
labels=[],
source=episode.source,
@@ -594,68 +610,106 @@ class Graphiti:
for episode in bulk_episodes
]
# Save all the episodes
await semaphore_gather(
*[episode.save(self.driver) for episode in episodes],
max_coroutines=self.max_coroutines,
episodes_by_uuid: dict[str, EpisodicNode] = {
episode.uuid: episode for episode in episodes
}
# Save all episodes
await add_nodes_and_edges_bulk(
driver=self.driver,
episodic_nodes=episodes,
episodic_edges=[],
entity_nodes=[],
entity_edges=[],
embedder=self.embedder,
)
# Get previous episode context for each episode
episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
episode_context = await retrieve_previous_episodes_bulk(self.driver, episodes)
# Extract all nodes and edges
(
extracted_nodes,
extracted_edges,
episodic_edges,
) = await extract_nodes_and_edges_bulk(self.clients, episode_pairs, None, None)
# Generate embeddings
await semaphore_gather(
*[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
*[edge.generate_embedding(self.embedder) for edge in extracted_edges],
max_coroutines=self.max_coroutines,
# Extract all nodes and edges for each episode
extracted_nodes_bulk, extracted_edges_bulk = await extract_nodes_and_edges_bulk(
self.clients,
episode_context,
edge_type_map=edge_type_map or edge_type_map_default,
edge_types=edge_types,
entity_types=entity_types,
excluded_entity_types=excluded_entity_types,
)
# Dedupe extracted nodes, compress extracted edges
(nodes, uuid_map), extracted_edges_timestamped = await semaphore_gather(
dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
max_coroutines=self.max_coroutines,
# Dedupe extracted nodes in memory
nodes_by_episode, uuid_map = await dedupe_nodes_bulk(
self.clients, extracted_nodes_bulk, episode_context, entity_types
)
# save nodes to KG
await semaphore_gather(
*[node.save(self.driver) for node in nodes],
max_coroutines=self.max_coroutines,
)
episodic_edges: list[EpisodicEdge] = []
for episode_uuid, nodes in nodes_by_episode.items():
episodic_edges.extend(build_episodic_edges(nodes, episode_uuid, now))
# re-map edge pointers so that they don't point to discarded dupe nodes
extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
extracted_edges_timestamped, uuid_map
)
episodic_edges_with_resolved_pointers: list[EpisodicEdge] = resolve_edge_pointers(
episodic_edges, uuid_map
extracted_edges_bulk_updated: list[list[EntityEdge]] = [
resolve_edge_pointers(edges, uuid_map) for edges in extracted_edges_bulk
]
# Dedupe extracted edges in memory
edges_by_episode = await dedupe_edges_bulk(
self.clients,
extracted_edges_bulk_updated,
episode_context,
[],
edge_types or {},
edge_type_map or edge_type_map_default,
)
# save episodic edges to KG
await semaphore_gather(
*[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers],
max_coroutines=self.max_coroutines,
# Extract node attributes
nodes_by_uuid: dict[str, EntityNode] = {
node.uuid: node for nodes in nodes_by_episode.values() for node in nodes
}
extract_attributes_params: list[tuple[EntityNode, list[EpisodicNode]]] = []
for node in nodes_by_uuid.values():
episode_uuids: list[str] = []
for episode_uuid, mentioned_nodes in nodes_by_episode.items():
for mentioned_node in mentioned_nodes:
if node.uuid == mentioned_node.uuid:
episode_uuids.append(episode_uuid)
break
episode_mentions: list[EpisodicNode] = [
episodes_by_uuid[episode_uuid] for episode_uuid in episode_uuids
]
episode_mentions.sort(key=lambda x: x.valid_at, reverse=True)
extract_attributes_params.append((node, episode_mentions))
new_hydrated_nodes: list[list[EntityNode]] = await semaphore_gather(
*[
extract_attributes_from_nodes(
self.clients,
[params[0]],
params[1][0],
params[1][0:],
entity_types,
)
for params in extract_attributes_params
]
)
# Dedupe extracted edges
edges = await dedupe_edges_bulk(
self.driver, self.llm_client, extracted_edges_with_resolved_pointers
)
logger.debug(f'extracted edge length: {len(edges)}')
hydrated_nodes = [node for nodes in new_hydrated_nodes for node in nodes]
# invalidate edges
# TODO: Resolve nodes and edges against the existing graph
edges_by_uuid: dict[str, EntityEdge] = {
edge.uuid: edge for edges in edges_by_episode.values() for edge in edges
}
# save edges to KG
await semaphore_gather(
*[edge.save(self.driver) for edge in edges],
max_coroutines=self.max_coroutines,
# save data to KG
await add_nodes_and_edges_bulk(
self.driver,
episodes,
episodic_edges,
hydrated_nodes,
list(edges_by_uuid.values()),
self.embedder,
)
end = time()
@@ -828,7 +882,7 @@ class Graphiti:
await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
)[0]
resolved_edge, invalidated_edges = await resolve_extracted_edge(
resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
self.llm_client,
updated_edge,
related_edges,

View file

@@ -23,9 +23,9 @@ from .models import Message, PromptFunction, PromptVersion
class EdgeDuplicate(BaseModel):
duplicate_fact_id: int = Field(
duplicate_facts: list[int] = Field(
...,
description='id of the duplicate fact. If no duplicate facts are found, default to -1.',
description='List of ids of any duplicate facts. If no duplicate facts are found, default to empty list.',
)
contradicted_facts: list[int] = Field(
...,
@@ -75,8 +75,9 @@ def edge(context: dict[str, Any]) -> list[Message]:
</NEW EDGE>
Task:
If the New Edges represents the same factual information as any edge in Existing Edges, return the id of the duplicate fact.
If the NEW EDGE is not a duplicate of any of the EXISTING EDGES, return -1.
If the NEW EDGE represents the same factual information as any edge in EXISTING EDGES, return the id of the duplicate fact
as part of the list of duplicate_facts.
If the NEW EDGE is not a duplicate of any of the EXISTING EDGES, return an empty list.
Guidelines:
1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
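For reviewers of the schema change above: the model now returns a list of ids instead of a single sentinel value. A purely illustrative Python literal showing the expected shape (the ids are made up):

# Illustrative only: structured output under the new EdgeDuplicate schema.
# An empty duplicate_facts list replaces the old duplicate_fact_id = -1 sentinel.
example_edge_duplicate_response = {
    'duplicate_facts': [0, 3],  # ids of EXISTING EDGES that express the same fact
    'contradicted_facts': [5],  # ids of EXISTING EDGES contradicted by the new fact
}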

View file

@@ -32,9 +32,9 @@ class NodeDuplicate(BaseModel):
...,
description='Name of the entity. Should be the most complete and descriptive name of the entity. Do not include any JSON formatting in the Entity name such as {}.',
)
additional_duplicates: list[int] = Field(
duplicates: list[int] = Field(
...,
description='idx of additional duplicate entities. Use this list if the entity has multiple duplicates among existing entities.',
description='idx of all duplicate entities.',
)
@@ -94,7 +94,7 @@ def node(context: dict[str, Any]) -> list[Message]:
1. Compare `new_entity` against each item in `existing_entities`.
2. If it refers to the same real-world object or concept, collect its index.
3. Let `duplicate_idx` = the *first* collected index, or -1 if none.
4. Let `additional_duplicates` = the list of *any other* collected indices (empty list if none).
4. Let `duplicates` = the list of *all* collected indices (empty list if none).
Also return the full name of the NEW ENTITY (whether it is the name of the NEW ENTITY, a node it
is a duplicate of, or a combination of the two).
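Similarly for the node prompt, with additional_duplicates renamed to duplicates, a single resolution would look roughly like the following Python literal (indices and name are made up, and only the fields visible in this diff are shown):

# Illustrative only: a node resolution under the renamed `duplicates` field.
example_node_resolution = {
    'duplicate_idx': 4,      # first matching index in existing_entities, or -1 if none
    'name': 'Jimmy Carter',  # most complete name for the entity
    'duplicates': [4, 7],    # all matching indices (empty list if none)
}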

View file

@@ -16,50 +16,40 @@ limitations under the License.
import logging
import typing
from collections import defaultdict
from datetime import datetime
from math import ceil
from numpy import dot, sqrt
from pydantic import BaseModel
import numpy as np
from pydantic import BaseModel, Field
from typing_extensions import Any
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
from graphiti_core.embedder import EmbedderClient
from graphiti_core.graph_queries import (
get_entity_edge_save_bulk_query,
get_entity_node_save_bulk_query,
)
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
from graphiti_core.llm_client import LLMClient
from graphiti_core.helpers import DEFAULT_DATABASE, normalize_l2, semaphore_gather
from graphiti_core.models.edges.edge_db_queries import (
EPISODIC_EDGE_SAVE_BULK,
)
from graphiti_core.models.nodes.node_db_queries import (
EPISODIC_NODE_SAVE_BULK,
)
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
from graphiti_core.utils.maintenance.edge_operations import (
build_episodic_edges,
dedupe_edge_list,
dedupe_extracted_edges,
extract_edges,
resolve_extracted_edge,
)
from graphiti_core.utils.maintenance.graph_data_operations import (
EPISODE_WINDOW_LEN,
retrieve_episodes,
)
from graphiti_core.utils.maintenance.node_operations import (
dedupe_extracted_nodes,
dedupe_node_list,
extract_nodes,
resolve_extracted_nodes,
)
from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates
logger = logging.getLogger(__name__)
@@ -68,6 +58,7 @@ CHUNK_SIZE = 10
class RawEpisode(BaseModel):
name: str
uuid: str | None = Field(default=None)
content: str
source_description: str
source: EpisodeType
@@ -179,233 +170,258 @@ async def add_nodes_and_edges_bulk_tx(
async def extract_nodes_and_edges_bulk(
clients: GraphitiClients,
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
edge_type_map: dict[tuple[str, str], list[str]],
entity_types: dict[str, BaseModel] | None = None,
excluded_entity_types: list[str] | None = None,
) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
extracted_nodes_bulk = await semaphore_gather(
edge_types: dict[str, BaseModel] | None = None,
) -> tuple[list[list[EntityNode]], list[list[EntityEdge]]]:
extracted_nodes_bulk: list[list[EntityNode]] = await semaphore_gather(
*[
extract_nodes(clients, episode, previous_episodes, entity_types, excluded_entity_types)
for episode, previous_episodes in episode_tuples
]
)
episodes, previous_episodes_list = (
[episode[0] for episode in episode_tuples],
[episode[1] for episode in episode_tuples],
)
extracted_edges_bulk = await semaphore_gather(
extracted_edges_bulk: list[list[EntityEdge]] = await semaphore_gather(
*[
extract_edges(
clients,
episode,
extracted_nodes_bulk[i],
previous_episodes_list[i],
{},
episode.group_id,
previous_episodes,
edge_type_map=edge_type_map,
group_id=episode.group_id,
edge_types=edge_types,
)
for i, episode in enumerate(episodes)
for i, (episode, previous_episodes) in enumerate(episode_tuples)
]
)
episodic_edges: list[EpisodicEdge] = []
for i, episode in enumerate(episodes):
episodic_edges += build_episodic_edges(extracted_nodes_bulk[i], episode, episode.created_at)
nodes: list[EntityNode] = []
for extracted_nodes in extracted_nodes_bulk:
nodes += extracted_nodes
edges: list[EntityEdge] = []
for extracted_edges in extracted_edges_bulk:
edges += extracted_edges
return nodes, edges, episodic_edges
return extracted_nodes_bulk, extracted_edges_bulk
async def dedupe_nodes_bulk(
driver: GraphDriver,
llm_client: LLMClient,
extracted_nodes: list[EntityNode],
) -> tuple[list[EntityNode], dict[str, str]]:
# Compress nodes
nodes, uuid_map = node_name_match(extracted_nodes)
clients: GraphitiClients,
extracted_nodes: list[list[EntityNode]],
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
entity_types: dict[str, BaseModel] | None = None,
) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
embedder = clients.embedder
min_score = 0.8
compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map)
node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
existing_nodes_chunks: list[list[EntityNode]] = list(
await semaphore_gather(
*[get_relevant_nodes(driver, node_chunk, SearchFilters()) for node_chunk in node_chunks]
)
# generate embeddings
await semaphore_gather(
*[create_entity_node_embeddings(embedder, nodes) for nodes in extracted_nodes]
)
results: list[tuple[list[EntityNode], dict[str, str]]] = list(
await semaphore_gather(
*[
dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i])
for i, node_chunk in enumerate(node_chunks)
]
)
# Find similar results
dedupe_tuples: list[tuple[list[EntityNode], list[EntityNode]]] = []
for i, nodes_i in enumerate(extracted_nodes):
existing_nodes: list[EntityNode] = []
for j, nodes_j in enumerate(extracted_nodes):
if i == j:
continue
existing_nodes += nodes_j
candidates_i: list[EntityNode] = []
for node in nodes_i:
for existing_node in existing_nodes:
# Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
# This approach will cast a wider net than BM25, which is ideal for this use case
node_words = set(node.name.lower().split())
existing_node_words = set(existing_node.name.lower().split())
has_overlap = not node_words.isdisjoint(existing_node_words)
if has_overlap:
candidates_i.append(existing_node)
continue
# Check for semantic similarity even if there is no overlap
similarity = np.dot(
normalize_l2(node.name_embedding or []),
normalize_l2(existing_node.name_embedding or []),
)
if similarity >= min_score:
candidates_i.append(existing_node)
dedupe_tuples.append((nodes_i, candidates_i))
# Determine Node Resolutions
bulk_node_resolutions: list[
tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]
] = await semaphore_gather(
*[
resolve_extracted_nodes(
clients,
dedupe_tuple[0],
episode_tuples[i][0],
episode_tuples[i][1],
entity_types,
existing_nodes_override=dedupe_tuples[i][1],
)
for i, dedupe_tuple in enumerate(dedupe_tuples)
]
)
final_nodes: list[EntityNode] = []
for result in results:
final_nodes.extend(result[0])
partial_uuid_map = result[1]
compressed_map.update(partial_uuid_map)
# Collect all duplicate pairs sorted by uuid
duplicate_pairs: list[tuple[EntityNode, EntityNode]] = []
for _, _, duplicates in bulk_node_resolutions:
for duplicate in duplicates:
n, m = duplicate
if n.uuid < m.uuid:
duplicate_pairs.append((n, m))
else:
duplicate_pairs.append((m, n))
return final_nodes, compressed_map
# Build full deduplication map
duplicate_map: dict[str, str] = {}
for value, key in duplicate_pairs:
if key.uuid in duplicate_map:
existing_value = duplicate_map[key.uuid]
duplicate_map[key.uuid] = value.uuid if value.uuid < existing_value else existing_value
else:
duplicate_map[key.uuid] = value.uuid
# Now we compress the duplicate_map, so that 3 -> 2 and 2 -> 1 becomes 3 -> 1 (sorted by uuid)
compressed_map: dict[str, str] = compress_uuid_map(duplicate_map)
node_uuid_map: dict[str, EntityNode] = {
node.uuid: node for nodes in extracted_nodes for node in nodes
}
nodes_by_episode: dict[str, list[EntityNode]] = {}
for i, nodes in enumerate(extracted_nodes):
episode = episode_tuples[i][0]
nodes_by_episode[episode.uuid] = [
node_uuid_map[compressed_map.get(node.uuid, node.uuid)] for node in nodes
]
return nodes_by_episode, compressed_map
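A side note on the candidate filter used above (and again in dedupe_edges_bulk below): two extracted names or facts are only sent to the LLM for resolution if their lowercase token sets overlap, or if the cosine similarity of their L2-normalized embeddings clears the threshold. A self-contained sketch of that test, with a local stand-in for normalize_l2 and made-up names and vectors:

# Sketch of the cross-episode candidate test used by the bulk dedupe helpers.
# _normalize is a local stand-in for graphiti_core.helpers.normalize_l2.
import numpy as np

MIN_SCORE = 0.8  # node threshold used above; the edge path uses 0.6


def _normalize(vec: list[float]) -> np.ndarray:
    arr = np.array(vec, dtype=float)
    norm = np.linalg.norm(arr)
    return arr / norm if norm > 0 else arr


def is_candidate(name_a: str, emb_a: list[float], name_b: str, emb_b: list[float]) -> bool:
    # Cheap lexical check first: any shared lowercase token counts as an overlap.
    if not set(name_a.lower().split()).isdisjoint(set(name_b.lower().split())):
        return True
    # Otherwise fall back to cosine similarity of the normalized embeddings.
    return float(np.dot(_normalize(emb_a), _normalize(emb_b))) >= MIN_SCORE


print(is_candidate('Jimmy Carter', [0.1, 0.9], 'President Carter', [0.2, 0.8]))      # True: token overlap
print(is_candidate('Jimmy Carter', [0.1, 0.9], 'POTUS 39', [0.12, 0.88]))            # True: similar embeddings
print(is_candidate('Jimmy Carter', [1.0, 0.0], 'Habitat for Humanity', [0.0, 1.0]))  # False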
async def dedupe_edges_bulk(
driver: GraphDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
) -> list[EntityEdge]:
# First compress edges
compressed_edges = await compress_edges(llm_client, extracted_edges)
clients: GraphitiClients,
extracted_edges: list[list[EntityEdge]],
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
_entities: list[EntityNode],
edge_types: dict[str, BaseModel],
_edge_type_map: dict[tuple[str, str], list[str]],
) -> dict[str, list[EntityEdge]]:
embedder = clients.embedder
min_score = 0.6
edge_chunks = [
compressed_edges[i : i + CHUNK_SIZE] for i in range(0, len(compressed_edges), CHUNK_SIZE)
]
relevant_edges_chunks: list[list[EntityEdge]] = list(
await semaphore_gather(
*[get_relevant_edges(driver, edge_chunk, SearchFilters()) for edge_chunk in edge_chunks]
)
# generate embeddings
await semaphore_gather(
*[create_entity_edge_embeddings(embedder, edges) for edges in extracted_edges]
)
resolved_edge_chunks: list[list[EntityEdge]] = list(
await semaphore_gather(
*[
dedupe_extracted_edges(llm_client, edge_chunk, relevant_edges_chunks[i])
for i, edge_chunk in enumerate(edge_chunks)
]
)
# Find similar results
dedupe_tuples: list[tuple[EpisodicNode, EntityEdge, list[EntityEdge]]] = []
for i, edges_i in enumerate(extracted_edges):
existing_edges: list[EntityEdge] = []
for j, edges_j in enumerate(extracted_edges):
if i == j:
continue
existing_edges += edges_j
for edge in edges_i:
candidates: list[EntityEdge] = []
for existing_edge in existing_edges:
# Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
# This approach will cast a wider net than BM25, which is ideal for this use case
edge_words = set(edge.fact.lower().split())
existing_edge_words = set(existing_edge.fact.lower().split())
has_overlap = not edge_words.isdisjoint(existing_edge_words)
if has_overlap:
candidates.append(existing_edge)
continue
# Check for semantic similarity even if there is no overlap
similarity = np.dot(
normalize_l2(edge.fact_embedding or []),
normalize_l2(existing_edge.fact_embedding or []),
)
if similarity >= min_score:
candidates.append(existing_edge)
dedupe_tuples.append((episode_tuples[i][0], edge, candidates))
bulk_edge_resolutions: list[
tuple[EntityEdge, EntityEdge, list[EntityEdge]]
] = await semaphore_gather(
*[
resolve_extracted_edge(
clients.llm_client, edge, candidates, candidates, episode, edge_types
)
for episode, edge, candidates in dedupe_tuples
]
)
edges = [edge for edge_chunk in resolved_edge_chunks for edge in edge_chunk]
return edges
duplicate_pairs: list[tuple[EntityEdge, EntityEdge]] = []
for i, (_, _, duplicates) in enumerate(bulk_edge_resolutions):
episode, edge, candidates = dedupe_tuples[i]
for duplicate in duplicates:
if edge.uuid < duplicate.uuid:
duplicate_pairs.append((edge, duplicate))
else:
duplicate_pairs.append((duplicate, edge))
def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str, str]]:
uuid_map: dict[str, str] = {}
name_map: dict[str, EntityNode] = {}
for node in nodes:
if node.name in name_map:
uuid_map[node.uuid] = name_map[node.name].uuid
continue
name_map[node.name] = node
return [node for node in name_map.values()], uuid_map
async def compress_nodes(
llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
) -> tuple[list[EntityNode], dict[str, str]]:
# We want to first compress the nodes by deduplicating nodes across each of the episodes added in bulk
if len(nodes) == 0:
return nodes, uuid_map
# Our approach involves us deduplicating chunks of nodes in parallel.
# We want n chunks of size n so that n ** 2 == len(nodes).
# We want chunk sizes to be at least 10 for optimizing LLM processing time
chunk_size = max(int(sqrt(len(nodes))), CHUNK_SIZE)
# First calculate similarity scores between nodes
similarity_scores: list[tuple[int, int, float]] = [
(i, j, dot(n.name_embedding or [], m.name_embedding or []))
for i, n in enumerate(nodes)
for j, m in enumerate(nodes[:i])
]
# We now sort by semantic similarity
similarity_scores.sort(key=lambda score_tuple: score_tuple[2])
# initialize our chunks based on chunk size
node_chunks: list[list[EntityNode]] = [[] for _ in range(ceil(len(nodes) / chunk_size))]
# Draft the most similar nodes into the same chunk
while len(similarity_scores) > 0:
i, j, _ = similarity_scores.pop()
# determine if any of the nodes have already been drafted into a chunk
n = nodes[i]
m = nodes[j]
# make sure the shortest chunks get preference
node_chunks.sort(reverse=True, key=lambda chunk: len(chunk))
n_chunk = max([i if n in chunk else -1 for i, chunk in enumerate(node_chunks)])
m_chunk = max([i if m in chunk else -1 for i, chunk in enumerate(node_chunks)])
# both nodes already in a chunk
if n_chunk > -1 and m_chunk > -1:
continue
# n has a chunk and that chunk is not full
elif n_chunk > -1 and len(node_chunks[n_chunk]) < chunk_size:
# put m in the same chunk as n
node_chunks[n_chunk].append(m)
# m has a chunk and that chunk is not full
elif m_chunk > -1 and len(node_chunks[m_chunk]) < chunk_size:
# put n in the same chunk as m
node_chunks[m_chunk].append(n)
# neither node has a chunk or the chunk is full
# Build full deduplication map
duplicate_map: dict[str, str] = {}
for value, key in duplicate_pairs:
if key.uuid in duplicate_map:
existing_value = duplicate_map[key.uuid]
duplicate_map[key.uuid] = value.uuid if value.uuid < existing_value else existing_value
else:
# add both nodes to the shortest chunk
node_chunks[-1].extend([n, m])
duplicate_map[key.uuid] = value.uuid
results = await semaphore_gather(
*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks]
)
# Now we compress the duplicate_map, so that 3 -> 2 and 2 -> 1 becomes 3 -> 1 (sorted by uuid)
compressed_map: dict[str, str] = compress_uuid_map(duplicate_map)
extended_map = dict(uuid_map)
compressed_nodes: list[EntityNode] = []
for node_chunk, uuid_map_chunk in results:
compressed_nodes += node_chunk
extended_map.update(uuid_map_chunk)
edge_uuid_map: dict[str, EntityEdge] = {
edge.uuid: edge for edges in extracted_edges for edge in edges
}
# Check if we have removed all duplicates
if len(compressed_nodes) == len(nodes):
compressed_uuid_map = compress_uuid_map(extended_map)
return compressed_nodes, compressed_uuid_map
edges_by_episode: dict[str, list[EntityEdge]] = {}
for i, edges in enumerate(extracted_edges):
episode = episode_tuples[i][0]
return await compress_nodes(llm_client, compressed_nodes, extended_map)
edges_by_episode[episode.uuid] = [
edge_uuid_map[compressed_map.get(edge.uuid, edge.uuid)] for edge in edges
]
async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list[EntityEdge]:
if len(edges) == 0:
return edges
# We only want to dedupe edges that are between the same pair of nodes
# We build a map of the edges based on their source and target nodes.
edge_chunks = chunk_edges_by_nodes(edges)
results = await semaphore_gather(
*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks]
)
compressed_edges: list[EntityEdge] = []
for edge_chunk in results:
compressed_edges += edge_chunk
# Check if we have removed all duplicates
if len(compressed_edges) == len(edges):
return compressed_edges
return await compress_edges(llm_client, compressed_edges)
return edges_by_episode
def compress_uuid_map(uuid_map: dict[str, str]) -> dict[str, str]:
# make sure all uuid values aren't mapped to other uuids
compressed_map = {}
for key, uuid in uuid_map.items():
curr_value = uuid
while curr_value in uuid_map:
curr_value = uuid_map[curr_value]
compressed_map[key] = curr_value
def find_min_uuid(start: str) -> str:
path = []
visited = set()
curr = start
while curr in uuid_map and curr not in visited:
visited.add(curr)
path.append(curr)
curr = uuid_map[curr]
# Also include the last resolved value (could be outside the map)
path.append(curr)
# Resolve to lex smallest UUID in the path
min_uuid = min(path)
# Assign all UUIDs in the path to the min_uuid
for node in path:
compressed_map[node] = min_uuid
return min_uuid
for key in uuid_map:
if key not in compressed_map:
find_min_uuid(key)
return compressed_map
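To make the chain-compression comment concrete: compress_uuid_map resolves every uuid along a duplicate chain to the lexicographically smallest uuid on that path. A toy, standalone rendering of the same idea, re-implemented locally for illustration (single letters stand in for uuids):

def compress(uuid_map: dict[str, str]) -> dict[str, str]:
    # Walk each chain, then map every member of the path to the smallest uuid seen.
    compressed: dict[str, str] = {}
    for start in uuid_map:
        if start in compressed:
            continue
        path: list[str] = []
        curr = start
        while curr in uuid_map and curr not in path:
            path.append(curr)
            curr = uuid_map[curr]
        path.append(curr)  # include the terminal value, which may be outside the map
        target = min(path)
        for member in path:
            compressed[member] = target
    return compressed


# 'c' -> 'b' and 'b' -> 'a' collapse so that both 'c' and 'b' point straight at 'a'.
print(compress({'c': 'b', 'b': 'a'}))  # {'c': 'a', 'b': 'a', 'a': 'a'}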
@@ -420,63 +436,3 @@ def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)
return edges
async def extract_edge_dates_bulk(
llm_client: LLMClient,
extracted_edges: list[EntityEdge],
episode_pairs: list[tuple[EpisodicNode, list[EpisodicNode]]],
) -> list[EntityEdge]:
edges: list[EntityEdge] = []
# confirm that all of our edges have at least one episode
for edge in extracted_edges:
if edge.episodes is not None and len(edge.episodes) > 0:
edges.append(edge)
episode_uuid_map: dict[str, tuple[EpisodicNode, list[EpisodicNode]]] = {
episode.uuid: (episode, previous_episodes) for episode, previous_episodes in episode_pairs
}
results = await semaphore_gather(
*[
extract_edge_dates(
llm_client,
edge,
episode_uuid_map[edge.episodes[0]][0], # type: ignore
episode_uuid_map[edge.episodes[0]][1], # type: ignore
)
for edge in edges
]
)
for i, result in enumerate(results):
valid_at = result[0]
invalid_at = result[1]
edge = edges[i]
edge.valid_at = valid_at
edge.invalid_at = invalid_at
if edge.invalid_at:
edge.expired_at = utc_now()
return edges
def chunk_edges_by_nodes(edges: list[EntityEdge]) -> list[list[EntityEdge]]:
# We only want to dedupe edges that are between the same pair of nodes
# We build a map of the edges based on their source and target nodes.
edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
for edge in edges:
# We drop loop edges
if edge.source_node_uuid == edge.target_node_uuid:
continue
# Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution
pointers = [edge.source_node_uuid, edge.target_node_uuid]
pointers.sort()
edge_chunk_map[pointers[0] + pointers[1]].append(edge)
edge_chunks = [chunk for chunk in edge_chunk_map.values()]
return edge_chunks

View file

@@ -45,15 +45,15 @@ logger = logging.getLogger(__name__)
def build_episodic_edges(
entity_nodes: list[EntityNode],
episode: EpisodicNode,
episode_uuid: str,
created_at: datetime,
) -> list[EpisodicEdge]:
episodic_edges: list[EpisodicEdge] = [
EpisodicEdge(
source_node_uuid=episode.uuid,
source_node_uuid=episode_uuid,
target_node_uuid=node.uuid,
created_at=created_at,
group_id=episode.group_id,
group_id=node.group_id,
)
for node in entity_nodes
]
@@ -68,19 +68,23 @@ def build_duplicate_of_edges(
created_at: datetime,
duplicate_nodes: list[tuple[EntityNode, EntityNode]],
) -> list[EntityEdge]:
is_duplicate_of_edges: list[EntityEdge] = [
EntityEdge(
source_node_uuid=source_node.uuid,
target_node_uuid=target_node.uuid,
name='IS_DUPLICATE_OF',
group_id=episode.group_id,
fact=f'{source_node.name} is a duplicate of {target_node.name}',
episodes=[episode.uuid],
created_at=created_at,
valid_at=created_at,
is_duplicate_of_edges: list[EntityEdge] = []
for source_node, target_node in duplicate_nodes:
if source_node.uuid == target_node.uuid:
continue
is_duplicate_of_edges.append(
EntityEdge(
source_node_uuid=source_node.uuid,
target_node_uuid=target_node.uuid,
name='IS_DUPLICATE_OF',
group_id=episode.group_id,
fact=f'{source_node.name} is a duplicate of {target_node.name}',
episodes=[episode.uuid],
created_at=created_at,
valid_at=created_at,
)
)
for source_node, target_node in duplicate_nodes
]
return is_duplicate_of_edges
@@ -240,50 +244,6 @@ async def extract_edges(
return edges
async def dedupe_extracted_edges(
llm_client: LLMClient,
extracted_edges: list[EntityEdge],
existing_edges: list[EntityEdge],
) -> list[EntityEdge]:
# Create edge map
edge_map: dict[str, EntityEdge] = {}
for edge in existing_edges:
edge_map[edge.uuid] = edge
# Prepare context for LLM
context = {
'extracted_edges': [
{'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in extracted_edges
],
'existing_edges': [
{'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges
],
}
llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.edge(context))
duplicate_data = llm_response.get('duplicates', [])
logger.debug(f'Extracted unique edges: {duplicate_data}')
duplicate_uuid_map: dict[str, str] = {}
for duplicate in duplicate_data:
uuid_value = duplicate['duplicate_of']
duplicate_uuid_map[duplicate['uuid']] = uuid_value
# Get full edge data
edges: list[EntityEdge] = []
for edge in extracted_edges:
if edge.uuid in duplicate_uuid_map:
existing_uuid = duplicate_uuid_map[edge.uuid]
existing_edge = edge_map[existing_uuid]
# Add current episode to the episodes list
existing_edge.episodes += edge.episodes
edges.append(existing_edge)
else:
edges.append(edge)
return edges
async def resolve_extracted_edges(
clients: GraphitiClients,
extracted_edges: list[EntityEdge],
@@ -335,7 +295,7 @@ async def resolve_extracted_edges(
edge_types_lst.append(extracted_edge_types)
# resolve edges with related edges in the graph and find invalidation candidates
results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
results: list[tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]] = list(
await semaphore_gather(
*[
resolve_extracted_edge(
@@ -416,9 +376,9 @@ async def resolve_extracted_edge(
existing_edges: list[EntityEdge],
episode: EpisodicNode,
edge_types: dict[str, BaseModel] | None = None,
) -> tuple[EntityEdge, list[EntityEdge]]:
) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
if len(related_edges) == 0 and len(existing_edges) == 0:
return extracted_edge, []
return extracted_edge, [], []
start = time()
@@ -457,15 +417,16 @@ async def resolve_extracted_edge(
model_size=ModelSize.small,
)
duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1)
resolved_edge = (
related_edges[duplicate_fact_id]
if 0 <= duplicate_fact_id < len(related_edges)
else extracted_edge
duplicate_fact_ids: list[int] = list(
filter(lambda i: 0 <= i < len(related_edges), llm_response.get('duplicate_facts', []))
)
if duplicate_fact_id >= 0 and episode is not None:
resolved_edge = extracted_edge
for duplicate_fact_id in duplicate_fact_ids:
resolved_edge = related_edges[duplicate_fact_id]
break
if duplicate_fact_ids and episode is not None:
resolved_edge.episodes.append(episode.uuid)
contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])
@@ -519,59 +480,12 @@ async def resolve_extracted_edge(
break
# Determine which contradictory edges need to be expired
invalidated_edges = resolve_edge_contradictions(resolved_edge, invalidation_candidates)
return resolved_edge, invalidated_edges
async def dedupe_extracted_edge(
llm_client: LLMClient,
extracted_edge: EntityEdge,
related_edges: list[EntityEdge],
episode: EpisodicNode | None = None,
) -> EntityEdge:
if len(related_edges) == 0:
return extracted_edge
start = time()
# Prepare context for LLM
related_edges_context = [
{'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
]
extracted_edge_context = {
'fact': extracted_edge.fact,
}
context = {
'related_edges': related_edges_context,
'extracted_edges': extracted_edge_context,
}
llm_response = await llm_client.generate_response(
prompt_library.dedupe_edges.edge(context),
response_model=EdgeDuplicate,
model_size=ModelSize.small,
invalidated_edges: list[EntityEdge] = resolve_edge_contradictions(
resolved_edge, invalidation_candidates
)
duplicate_edges: list[EntityEdge] = [related_edges[idx] for idx in duplicate_fact_ids]
duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1)
edge = (
related_edges[duplicate_fact_id]
if 0 <= duplicate_fact_id < len(related_edges)
else extracted_edge
)
if duplicate_fact_id >= 0 and episode is not None:
edge.episodes.append(episode.uuid)
end = time()
logger.debug(
f'Resolved Edge: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms'
)
return edge
return resolved_edge, invalidated_edges, duplicate_edges
async def dedupe_edge_list(

View file

@@ -176,62 +176,13 @@ async def extract_nodes(
return extracted_nodes
async def dedupe_extracted_nodes(
llm_client: LLMClient,
extracted_nodes: list[EntityNode],
existing_nodes: list[EntityNode],
) -> tuple[list[EntityNode], dict[str, str]]:
start = time()
# build existing node map
node_map: dict[str, EntityNode] = {}
for node in existing_nodes:
node_map[node.uuid] = node
# Prepare context for LLM
existing_nodes_context = [
{'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in existing_nodes
]
extracted_nodes_context = [
{'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in extracted_nodes
]
context = {
'existing_nodes': existing_nodes_context,
'extracted_nodes': extracted_nodes_context,
}
llm_response = await llm_client.generate_response(prompt_library.dedupe_nodes.node(context))
duplicate_data = llm_response.get('duplicates', [])
end = time()
logger.debug(f'Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms')
uuid_map: dict[str, str] = {}
for duplicate in duplicate_data:
uuid_value = duplicate['duplicate_of']
uuid_map[duplicate['uuid']] = uuid_value
nodes: list[EntityNode] = []
for node in extracted_nodes:
if node.uuid in uuid_map:
existing_uuid = uuid_map[node.uuid]
existing_node = node_map[existing_uuid]
nodes.append(existing_node)
else:
nodes.append(node)
return nodes, uuid_map
async def resolve_extracted_nodes(
clients: GraphitiClients,
extracted_nodes: list[EntityNode],
episode: EpisodicNode | None = None,
previous_episodes: list[EpisodicNode] | None = None,
entity_types: dict[str, BaseModel] | None = None,
existing_nodes_override: list[EntityNode] | None = None,
) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
llm_client = clients.llm_client
driver = clients.driver
@@ -249,9 +200,13 @@ async def resolve_extracted_nodes(
]
)
existing_nodes_dict: dict[str, EntityNode] = {
node.uuid: node for result in search_results for node in result.nodes
}
candidate_nodes: list[EntityNode] = (
[node for result in search_results for node in result.nodes]
if existing_nodes_override is None
else existing_nodes_override
)
existing_nodes_dict: dict[str, EntityNode] = {node.uuid: node for node in candidate_nodes}
existing_nodes: list[EntityNode] = list(existing_nodes_dict.values())
@@ -321,13 +276,11 @@ async def resolve_extracted_nodes(
resolved_nodes.append(resolved_node)
uuid_map[extracted_node.uuid] = resolved_node.uuid
additional_duplicates: list[int] = resolution.get('additional_duplicates', [])
for idx in additional_duplicates:
duplicates: list[int] = resolution.get('duplicates', [])
for idx in duplicates:
existing_node = existing_nodes[idx] if idx < len(existing_nodes) else resolved_node
if existing_node == resolved_node:
continue
node_duplicates.append((resolved_node, existing_nodes[idx]))
node_duplicates.append((resolved_node, existing_node))
logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')

uv.lock (generated): file diff suppressed because it is too large.