Bulk ingestion (#698)
* partial * update * update * update * update * updates * updates * update * update
parent 94df836396
commit 0675ac2b7d

10 changed files with 2351 additions and 2466 deletions
@@ -25,6 +25,8 @@ from pydantic import BaseModel, Field
 from transcript_parser import parse_podcast_messages

 from graphiti_core import Graphiti
+from graphiti_core.nodes import EpisodeType
+from graphiti_core.utils.bulk_utils import RawEpisode
 from graphiti_core.utils.maintenance.graph_data_operations import clear_data

 load_dotenv()
@@ -67,7 +69,7 @@ class IsPresidentOf(BaseModel):
     """Relationship between a person and the entity they are a president of"""


-async def main():
+async def main(use_bulk: bool = False):
     setup_logging()
     client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
     await clear_data(client.driver)
@@ -75,21 +77,43 @@ async def main():
     messages = parse_podcast_messages()
     group_id = str(uuid4())

-    for i, message in enumerate(messages[3:14]):
-        episodes = await client.retrieve_episodes(message.actual_timestamp, 3, group_ids=[group_id])
-        episode_uuids = [episode.uuid for episode in episodes]
-
-        await client.add_episode(
-            name=f'Message {i}',
-            episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
-            reference_time=message.actual_timestamp,
-            source_description='Podcast Transcript',
-            group_id=group_id,
-            entity_types={'Person': Person},
-            edge_types={'IS_PRESIDENT_OF': IsPresidentOf},
-            edge_type_map={('Person', 'Entity'): ['PRESIDENT_OF']},
-            previous_episode_uuids=episode_uuids,
-        )
+    raw_episodes: list[RawEpisode] = []
+    for i, message in enumerate(messages[3:7]):
+        raw_episodes.append(
+            RawEpisode(
+                name=f'Message {i}',
+                content=f'{message.speaker_name} ({message.role}): {message.content}',
+                reference_time=message.actual_timestamp,
+                source=EpisodeType.message,
+                source_description='Podcast Transcript',
+            )
+        )
+    if use_bulk:
+        await client.add_episode_bulk(
+            raw_episodes,
+            group_id=group_id,
+            entity_types={'Person': Person},
+            edge_types={'IS_PRESIDENT_OF': IsPresidentOf},
+            edge_type_map={('Person', 'Entity'): ['PRESIDENT_OF']},
+        )
+    else:
+        for i, message in enumerate(messages[3:14]):
+            episodes = await client.retrieve_episodes(
+                message.actual_timestamp, 3, group_ids=[group_id]
+            )
+            episode_uuids = [episode.uuid for episode in episodes]
+
+            await client.add_episode(
+                name=f'Message {i}',
+                episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
+                reference_time=message.actual_timestamp,
+                source_description='Podcast Transcript',
+                group_id=group_id,
+                entity_types={'Person': Person},
+                edge_types={'IS_PRESIDENT_OF': IsPresidentOf},
+                edge_type_map={('Person', 'Entity'): ['PRESIDENT_OF']},
+                previous_episode_uuids=episode_uuids,
+            )


-asyncio.run(main())
+asyncio.run(main(True))
@@ -44,6 +44,7 @@ class GeminiRerankerClient(CrossEncoderClient):
     """
     Google Gemini Reranker Client
     """

     def __init__(
         self,
         config: LLMConfig | None = None,
@@ -46,6 +46,7 @@ class GeminiEmbedder(EmbedderClient):
     """
     Google Gemini Embedder Client
     """

     def __init__(
         self,
         config: GeminiEmbedderConfig | None = None,
@@ -57,7 +57,6 @@ from graphiti_core.utils.bulk_utils import (
     add_nodes_and_edges_bulk,
     dedupe_edges_bulk,
     dedupe_nodes_bulk,
-    extract_edge_dates_bulk,
     extract_nodes_and_edges_bulk,
     resolve_edge_pointers,
     retrieve_previous_episodes_bulk,
@@ -508,7 +507,7 @@ class Graphiti:

         entity_edges = resolved_edges + invalidated_edges + duplicate_of_edges

-        episodic_edges = build_episodic_edges(nodes, episode, now)
+        episodic_edges = build_episodic_edges(nodes, episode.uuid, now)

         episode.entity_edges = [edge.uuid for edge in entity_edges]
@@ -536,8 +535,16 @@ class Graphiti:
         except Exception as e:
             raise e

-    #### WIP: USE AT YOUR OWN RISK ####
-    async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str = ''):
+    ##### EXPERIMENTAL #####
+    async def add_episode_bulk(
+        self,
+        bulk_episodes: list[RawEpisode],
+        group_id: str = '',
+        entity_types: dict[str, BaseModel] | None = None,
+        excluded_entity_types: list[str] | None = None,
+        edge_types: dict[str, BaseModel] | None = None,
+        edge_type_map: dict[tuple[str, str], list[str]] | None = None,
+    ):
         """
         Process multiple episodes in bulk and update the graph.
@@ -580,8 +587,17 @@ class Graphiti:

         validate_group_id(group_id)

+        # Create default edge type map
+        edge_type_map_default = (
+            {('Entity', 'Entity'): list(edge_types.keys())}
+            if edge_types is not None
+            else {('Entity', 'Entity'): []}
+        )
+
         episodes = [
-            EpisodicNode(
+            await EpisodicNode.get_by_uuid(self.driver, episode.uuid)
+            if episode.uuid is not None
+            else EpisodicNode(
                 name=episode.name,
                 labels=[],
                 source=episode.source,
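For illustration, with a single declared edge type (as in the podcast example earlier in this diff) the default map built above reduces to one wildcard entry. A minimal standalone sketch, with IsPresidentOf standing in for that example model:

    from pydantic import BaseModel

    class IsPresidentOf(BaseModel):
        """Stand-in for the IsPresidentOf model in the podcast example above."""

    edge_types = {'IS_PRESIDENT_OF': IsPresidentOf}
    edge_type_map_default = (
        {('Entity', 'Entity'): list(edge_types.keys())}
        if edge_types is not None
        else {('Entity', 'Entity'): []}
    )
    print(edge_type_map_default)  # {('Entity', 'Entity'): ['IS_PRESIDENT_OF']}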
@@ -594,68 +610,106 @@ class Graphiti:
             for episode in bulk_episodes
         ]

-        # Save all the episodes
-        await semaphore_gather(
-            *[episode.save(self.driver) for episode in episodes],
-            max_coroutines=self.max_coroutines,
+        episodes_by_uuid: dict[str, EpisodicNode] = {
+            episode.uuid: episode for episode in episodes
+        }
+
+        # Save all episodes
+        await add_nodes_and_edges_bulk(
+            driver=self.driver,
+            episodic_nodes=episodes,
+            episodic_edges=[],
+            entity_nodes=[],
+            entity_edges=[],
+            embedder=self.embedder,
         )

         # Get previous episode context for each episode
-        episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
+        episode_context = await retrieve_previous_episodes_bulk(self.driver, episodes)

-        # Extract all nodes and edges
-        (
-            extracted_nodes,
-            extracted_edges,
-            episodic_edges,
-        ) = await extract_nodes_and_edges_bulk(self.clients, episode_pairs, None, None)
-
-        # Generate embeddings
-        await semaphore_gather(
-            *[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
-            *[edge.generate_embedding(self.embedder) for edge in extracted_edges],
-            max_coroutines=self.max_coroutines,
+        # Extract all nodes and edges for each episode
+        extracted_nodes_bulk, extracted_edges_bulk = await extract_nodes_and_edges_bulk(
+            self.clients,
+            episode_context,
+            edge_type_map=edge_type_map or edge_type_map_default,
+            edge_types=edge_types,
+            entity_types=entity_types,
+            excluded_entity_types=excluded_entity_types,
         )

-        # Dedupe extracted nodes, compress extracted edges
-        (nodes, uuid_map), extracted_edges_timestamped = await semaphore_gather(
-            dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
-            extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
-            max_coroutines=self.max_coroutines,
+        # Dedupe extracted nodes in memory
+        nodes_by_episode, uuid_map = await dedupe_nodes_bulk(
+            self.clients, extracted_nodes_bulk, episode_context, entity_types
         )

-        # save nodes to KG
-        await semaphore_gather(
-            *[node.save(self.driver) for node in nodes],
-            max_coroutines=self.max_coroutines,
-        )
+        episodic_edges: list[EpisodicEdge] = []
+        for episode_uuid, nodes in nodes_by_episode.items():
+            episodic_edges.extend(build_episodic_edges(nodes, episode_uuid, now))

         # re-map edge pointers so that they don't point to discard dupe nodes
-        extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
-            extracted_edges_timestamped, uuid_map
-        )
-        episodic_edges_with_resolved_pointers: list[EpisodicEdge] = resolve_edge_pointers(
-            episodic_edges, uuid_map
+        extracted_edges_bulk_updated: list[list[EntityEdge]] = [
+            resolve_edge_pointers(edges, uuid_map) for edges in extracted_edges_bulk
+        ]
+
+        # Dedupe extracted edges in memory
+        edges_by_episode = await dedupe_edges_bulk(
+            self.clients,
+            extracted_edges_bulk_updated,
+            episode_context,
+            [],
+            edge_types or {},
+            edge_type_map or edge_type_map_default,
         )

-        # save episodic edges to KG
-        await semaphore_gather(
-            *[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers],
-            max_coroutines=self.max_coroutines,
+        # Extract node attributes
+        nodes_by_uuid: dict[str, EntityNode] = {
+            node.uuid: node for nodes in nodes_by_episode.values() for node in nodes
+        }
+
+        extract_attributes_params: list[tuple[EntityNode, list[EpisodicNode]]] = []
+        for node in nodes_by_uuid.values():
+            episode_uuids: list[str] = []
+            for episode_uuid, mentioned_nodes in nodes_by_episode.items():
+                for mentioned_node in mentioned_nodes:
+                    if node.uuid == mentioned_node.uuid:
+                        episode_uuids.append(episode_uuid)
+                        break
+
+            episode_mentions: list[EpisodicNode] = [
+                episodes_by_uuid[episode_uuid] for episode_uuid in episode_uuids
+            ]
+            episode_mentions.sort(key=lambda x: x.valid_at, reverse=True)
+
+            extract_attributes_params.append((node, episode_mentions))
+
+        new_hydrated_nodes: list[list[EntityNode]] = await semaphore_gather(
+            *[
+                extract_attributes_from_nodes(
+                    self.clients,
+                    [params[0]],
+                    params[1][0],
+                    params[1][0:],
+                    entity_types,
+                )
+                for params in extract_attributes_params
+            ]
         )

-        # Dedupe extracted edges
-        edges = await dedupe_edges_bulk(
-            self.driver, self.llm_client, extracted_edges_with_resolved_pointers
-        )
-        logger.debug(f'extracted edge length: {len(edges)}')
+        hydrated_nodes = [node for nodes in new_hydrated_nodes for node in nodes]

-        # invalidate edges
+        # TODO: Resolve nodes and edges against the existing graph
+        edges_by_uuid: dict[str, EntityEdge] = {
+            edge.uuid: edge for edges in edges_by_episode.values() for edge in edges
+        }

-        # save edges to KG
-        await semaphore_gather(
-            *[edge.save(self.driver) for edge in edges],
-            max_coroutines=self.max_coroutines,
+        # save data to KG
+        await add_nodes_and_edges_bulk(
+            self.driver,
+            episodes,
+            episodic_edges,
+            hydrated_nodes,
+            list(edges_by_uuid.values()),
+            self.embedder,
         )

         end = time()
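For orientation, a minimal sketch of driving the new bulk path end to end. The connection details and the sample message are placeholders, and the call shape mirrors the podcast example earlier in this diff:

    import asyncio
    from datetime import datetime, timezone

    from graphiti_core import Graphiti
    from graphiti_core.nodes import EpisodeType
    from graphiti_core.utils.bulk_utils import RawEpisode


    async def ingest(messages: list[str]) -> None:
        # Placeholder credentials; substitute your own Neo4j URI, user, and password.
        client = Graphiti('bolt://localhost:7687', 'neo4j', 'password')
        raw_episodes = [
            RawEpisode(
                name=f'Message {i}',
                content=message,
                reference_time=datetime.now(timezone.utc),
                source=EpisodeType.message,
                source_description='Podcast Transcript',
            )
            for i, message in enumerate(messages)
        ]
        # One call now saves the episodes, extracts and dedupes nodes/edges in
        # memory, and persists everything via add_nodes_and_edges_bulk.
        await client.add_episode_bulk(raw_episodes, group_id='podcast')


    asyncio.run(ingest(['host: Welcome to the show.']))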
@@ -828,7 +882,7 @@ class Graphiti:
             await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
         )[0]

-        resolved_edge, invalidated_edges = await resolve_extracted_edge(
+        resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
             self.llm_client,
             updated_edge,
             related_edges,
@@ -23,9 +23,9 @@ from .models import Message, PromptFunction, PromptVersion


 class EdgeDuplicate(BaseModel):
-    duplicate_fact_id: int = Field(
+    duplicate_facts: list[int] = Field(
         ...,
-        description='id of the duplicate fact. If no duplicate facts are found, default to -1.',
+        description='List of ids of any duplicate facts. If no duplicate facts are found, default to empty list.',
     )
     contradicted_facts: list[int] = Field(
         ...,
@@ -75,8 +75,9 @@ def edge(context: dict[str, Any]) -> list[Message]:
        </NEW EDGE>

        Task:
-       If the New Edges represents the same factual information as any edge in Existing Edges, return the id of the duplicate fact.
-       If the NEW EDGE is not a duplicate of any of the EXISTING EDGES, return -1.
+       If the New Edges represents the same factual information as any edge in Existing Edges, return the id of the duplicate fact
+       as part of the list of duplicate_facts.
+       If the NEW EDGE is not a duplicate of any of the EXISTING EDGES, return an empty list.

        Guidelines:
        1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
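Because duplicate_facts is now a list of candidate ids, consumers filter it down to indices that actually fall inside the candidate list before picking a resolved edge (resolve_extracted_edge later in this diff does exactly that). A small standalone sketch with stand-in values:

    llm_response = {'duplicate_facts': [2, 7, -1], 'contradicted_facts': []}  # example payload
    related_edges = ['edge-a', 'edge-b', 'edge-c']  # stand-ins for EntityEdge objects

    duplicate_fact_ids = [
        i for i in llm_response.get('duplicate_facts', []) if 0 <= i < len(related_edges)
    ]
    resolved = related_edges[duplicate_fact_ids[0]] if duplicate_fact_ids else None
    print(duplicate_fact_ids, resolved)  # [2] edge-c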
@@ -32,9 +32,9 @@ class NodeDuplicate(BaseModel):
         ...,
         description='Name of the entity. Should be the most complete and descriptive name of the entity. Do not include any JSON formatting in the Entity name such as {}.',
     )
-    additional_duplicates: list[int] = Field(
+    duplicates: list[int] = Field(
         ...,
-        description='idx of additional duplicate entities. Use this list if the entity has multiple duplicates among existing entities.',
+        description='idx of all duplicate entities.',
     )
@@ -94,7 +94,7 @@ def node(context: dict[str, Any]) -> list[Message]:
        1. Compare `new_entity` against each item in `existing_entities`.
        2. If it refers to the same real‐world object or concept, collect its index.
        3. Let `duplicate_idx` = the *first* collected index, or –1 if none.
-       4. Let `additional_duplicates` = the list of *any other* collected indices (empty list if none).
+       4. Let `duplicates` = the list of *all* collected indices (empty list if none).

        Also return the full name of the NEW ENTITY (whether it is the name of the NEW ENTITY, a node it
        is a duplicate of, or a combination of the two).
@@ -16,50 +16,40 @@ limitations under the License.

 import logging
 import typing
-from collections import defaultdict
 from datetime import datetime
-from math import ceil

-from numpy import dot, sqrt
-from pydantic import BaseModel
+import numpy as np
+from pydantic import BaseModel, Field
 from typing_extensions import Any

 from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
-from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
+from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.graph_queries import (
     get_entity_edge_save_bulk_query,
     get_entity_node_save_bulk_query,
 )
 from graphiti_core.graphiti_types import GraphitiClients
-from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
-from graphiti_core.llm_client import LLMClient
+from graphiti_core.helpers import DEFAULT_DATABASE, normalize_l2, semaphore_gather
 from graphiti_core.models.edges.edge_db_queries import (
     EPISODIC_EDGE_SAVE_BULK,
 )
 from graphiti_core.models.nodes.node_db_queries import (
     EPISODIC_NODE_SAVE_BULK,
 )
-from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
-from graphiti_core.search.search_filters import SearchFilters
-from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes
-from graphiti_core.utils.datetime_utils import utc_now
+from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
 from graphiti_core.utils.maintenance.edge_operations import (
-    build_episodic_edges,
-    dedupe_edge_list,
-    dedupe_extracted_edges,
     extract_edges,
+    resolve_extracted_edge,
 )
 from graphiti_core.utils.maintenance.graph_data_operations import (
     EPISODE_WINDOW_LEN,
     retrieve_episodes,
 )
 from graphiti_core.utils.maintenance.node_operations import (
-    dedupe_extracted_nodes,
-    dedupe_node_list,
     extract_nodes,
+    resolve_extracted_nodes,
 )
-from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates

 logger = logging.getLogger(__name__)
@@ -68,6 +58,7 @@ CHUNK_SIZE = 10


 class RawEpisode(BaseModel):
     name: str
+    uuid: str | None = Field(default=None)
     content: str
     source_description: str
     source: EpisodeType
@@ -179,233 +170,258 @@ async def add_nodes_and_edges_bulk_tx(
 async def extract_nodes_and_edges_bulk(
     clients: GraphitiClients,
     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
+    edge_type_map: dict[tuple[str, str], list[str]],
     entity_types: dict[str, BaseModel] | None = None,
     excluded_entity_types: list[str] | None = None,
-) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
-    extracted_nodes_bulk = await semaphore_gather(
+    edge_types: dict[str, BaseModel] | None = None,
+) -> tuple[list[list[EntityNode]], list[list[EntityEdge]]]:
+    extracted_nodes_bulk: list[list[EntityNode]] = await semaphore_gather(
         *[
             extract_nodes(clients, episode, previous_episodes, entity_types, excluded_entity_types)
             for episode, previous_episodes in episode_tuples
         ]
     )

-    episodes, previous_episodes_list = (
-        [episode[0] for episode in episode_tuples],
-        [episode[1] for episode in episode_tuples],
-    )
-
-    extracted_edges_bulk = await semaphore_gather(
+    extracted_edges_bulk: list[list[EntityEdge]] = await semaphore_gather(
         *[
             extract_edges(
                 clients,
                 episode,
                 extracted_nodes_bulk[i],
-                previous_episodes_list[i],
-                {},
-                episode.group_id,
+                previous_episodes,
+                edge_type_map=edge_type_map,
+                group_id=episode.group_id,
+                edge_types=edge_types,
             )
-            for i, episode in enumerate(episodes)
+            for i, (episode, previous_episodes) in enumerate(episode_tuples)
         ]
     )

-    episodic_edges: list[EpisodicEdge] = []
-    for i, episode in enumerate(episodes):
-        episodic_edges += build_episodic_edges(extracted_nodes_bulk[i], episode, episode.created_at)
-
-    nodes: list[EntityNode] = []
-    for extracted_nodes in extracted_nodes_bulk:
-        nodes += extracted_nodes
-
-    edges: list[EntityEdge] = []
-    for extracted_edges in extracted_edges_bulk:
-        edges += extracted_edges
-
-    return nodes, edges, episodic_edges
+    return extracted_nodes_bulk, extracted_edges_bulk


 async def dedupe_nodes_bulk(
-    driver: GraphDriver,
-    llm_client: LLMClient,
-    extracted_nodes: list[EntityNode],
-) -> tuple[list[EntityNode], dict[str, str]]:
-    # Compress nodes
-    nodes, uuid_map = node_name_match(extracted_nodes)
-
-    compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map)
-
-    node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
-
-    existing_nodes_chunks: list[list[EntityNode]] = list(
-        await semaphore_gather(
-            *[get_relevant_nodes(driver, node_chunk, SearchFilters()) for node_chunk in node_chunks]
-        )
-    )
-
-    results: list[tuple[list[EntityNode], dict[str, str]]] = list(
-        await semaphore_gather(
-            *[
-                dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i])
-                for i, node_chunk in enumerate(node_chunks)
-            ]
-        )
-    )
-
-    final_nodes: list[EntityNode] = []
-    for result in results:
-        final_nodes.extend(result[0])
-        partial_uuid_map = result[1]
-        compressed_map.update(partial_uuid_map)
-
-    return final_nodes, compressed_map
+    clients: GraphitiClients,
+    extracted_nodes: list[list[EntityNode]],
+    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
+    entity_types: dict[str, BaseModel] | None = None,
+) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
+    embedder = clients.embedder
+    min_score = 0.8
+
+    # generate embeddings
+    await semaphore_gather(
+        *[create_entity_node_embeddings(embedder, nodes) for nodes in extracted_nodes]
+    )
+
+    # Find similar results
+    dedupe_tuples: list[tuple[list[EntityNode], list[EntityNode]]] = []
+    for i, nodes_i in enumerate(extracted_nodes):
+        existing_nodes: list[EntityNode] = []
+        for j, nodes_j in enumerate(extracted_nodes):
+            if i == j:
+                continue
+            existing_nodes += nodes_j
+
+        candidates_i: list[EntityNode] = []
+        for node in nodes_i:
+            for existing_node in existing_nodes:
+                # Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
+                # This approach will cast a wider net than BM25, which is ideal for this use case
+                node_words = set(node.name.lower().split())
+                existing_node_words = set(existing_node.name.lower().split())
+                has_overlap = not node_words.isdisjoint(existing_node_words)
+                if has_overlap:
+                    candidates_i.append(existing_node)
+                    continue
+
+                # Check for semantic similarity even if there is no overlap
+                similarity = np.dot(
+                    normalize_l2(node.name_embedding or []),
+                    normalize_l2(existing_node.name_embedding or []),
+                )
+                if similarity >= min_score:
+                    candidates_i.append(existing_node)
+
+        dedupe_tuples.append((nodes_i, candidates_i))
+
+    # Determine Node Resolutions
+    bulk_node_resolutions: list[
+        tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]
+    ] = await semaphore_gather(
+        *[
+            resolve_extracted_nodes(
+                clients,
+                dedupe_tuple[0],
+                episode_tuples[i][0],
+                episode_tuples[i][1],
+                entity_types,
+                existing_nodes_override=dedupe_tuples[i][1],
+            )
+            for i, dedupe_tuple in enumerate(dedupe_tuples)
+        ]
+    )
+
+    # Collect all duplicate pairs sorted by uuid
+    duplicate_pairs: list[tuple[EntityNode, EntityNode]] = []
+    for _, _, duplicates in bulk_node_resolutions:
+        for duplicate in duplicates:
+            n, m = duplicate
+            if n.uuid < m.uuid:
+                duplicate_pairs.append((n, m))
+            else:
+                duplicate_pairs.append((m, n))
+
+    # Build full deduplication map
+    duplicate_map: dict[str, str] = {}
+    for value, key in duplicate_pairs:
+        if key.uuid in duplicate_map:
+            existing_value = duplicate_map[key.uuid]
+            duplicate_map[key.uuid] = value.uuid if value.uuid < existing_value else existing_value
+        else:
+            duplicate_map[key.uuid] = value.uuid
+
+    # Now we compress the duplicate_map, so that 3 -> 2 and 2 -> becomes 3 -> 1 (sorted by uuid)
+    compressed_map: dict[str, str] = compress_uuid_map(duplicate_map)
+
+    node_uuid_map: dict[str, EntityNode] = {
+        node.uuid: node for nodes in extracted_nodes for node in nodes
+    }
+
+    nodes_by_episode: dict[str, list[EntityNode]] = {}
+    for i, nodes in enumerate(extracted_nodes):
+        episode = episode_tuples[i][0]
+
+        nodes_by_episode[episode.uuid] = [
+            node_uuid_map[compressed_map.get(node.uuid, node.uuid)] for node in nodes
+        ]
+
+    return nodes_by_episode, compressed_map


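The candidate test above is worth calling out: a node from one episode becomes a dedupe candidate for another episode's node if their names share a word, or if their L2-normalised name embeddings have a dot product of at least min_score (0.8). A self-contained sketch of the same check, with a local l2_normalize standing in for graphiti_core.helpers.normalize_l2:

    import numpy as np

    def l2_normalize(vec: list[float]) -> np.ndarray:
        arr = np.asarray(vec, dtype=float)
        norm = np.linalg.norm(arr)
        return arr / norm if norm > 0 else arr

    def is_candidate(name_a: str, emb_a: list[float], name_b: str, emb_b: list[float],
                     min_score: float = 0.8) -> bool:
        # Cheap BM25 stand-in: any shared lowercase word makes the pair a candidate.
        if not set(name_a.lower().split()).isdisjoint(set(name_b.lower().split())):
            return True
        # Otherwise fall back to cosine similarity of the normalised embeddings.
        return float(np.dot(l2_normalize(emb_a), l2_normalize(emb_b))) >= min_score

    print(is_candidate('Jimmy Carter', [1.0, 0.0], 'President Carter', [0.0, 1.0]))  # True (word overlap)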
 async def dedupe_edges_bulk(
-    driver: GraphDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
-) -> list[EntityEdge]:
-    # First compress edges
-    compressed_edges = await compress_edges(llm_client, extracted_edges)
-
-    edge_chunks = [
-        compressed_edges[i : i + CHUNK_SIZE] for i in range(0, len(compressed_edges), CHUNK_SIZE)
-    ]
-
-    relevant_edges_chunks: list[list[EntityEdge]] = list(
-        await semaphore_gather(
-            *[get_relevant_edges(driver, edge_chunk, SearchFilters()) for edge_chunk in edge_chunks]
-        )
-    )
-
-    resolved_edge_chunks: list[list[EntityEdge]] = list(
-        await semaphore_gather(
-            *[
-                dedupe_extracted_edges(llm_client, edge_chunk, relevant_edges_chunks[i])
-                for i, edge_chunk in enumerate(edge_chunks)
-            ]
-        )
-    )
-
-    edges = [edge for edge_chunk in resolved_edge_chunks for edge in edge_chunk]
-    return edges
+    clients: GraphitiClients,
+    extracted_edges: list[list[EntityEdge]],
+    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
+    _entities: list[EntityNode],
+    edge_types: dict[str, BaseModel],
+    _edge_type_map: dict[tuple[str, str], list[str]],
+) -> dict[str, list[EntityEdge]]:
+    embedder = clients.embedder
+    min_score = 0.6
+
+    # generate embeddings
+    await semaphore_gather(
+        *[create_entity_edge_embeddings(embedder, edges) for edges in extracted_edges]
+    )
+
+    # Find similar results
+    dedupe_tuples: list[tuple[EpisodicNode, EntityEdge, list[EntityEdge]]] = []
+    for i, edges_i in enumerate(extracted_edges):
+        existing_edges: list[EntityEdge] = []
+        for j, edges_j in enumerate(extracted_edges):
+            if i == j:
+                continue
+            existing_edges += edges_j
+
+        for edge in edges_i:
+            candidates: list[EntityEdge] = []
+            for existing_edge in existing_edges:
+                # Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
+                # This approach will cast a wider net than BM25, which is ideal for this use case
+                edge_words = set(edge.fact.lower().split())
+                existing_edge_words = set(existing_edge.fact.lower().split())
+                has_overlap = not edge_words.isdisjoint(existing_edge_words)
+                if has_overlap:
+                    candidates.append(existing_edge)
+                    continue
+
+                # Check for semantic similarity even if there is no overlap
+                similarity = np.dot(
+                    normalize_l2(edge.fact_embedding or []),
+                    normalize_l2(existing_edge.fact_embedding or []),
+                )
+                if similarity >= min_score:
+                    candidates.append(existing_edge)
+
+            dedupe_tuples.append((episode_tuples[i][0], edge, candidates))
+
+    bulk_edge_resolutions: list[
+        tuple[EntityEdge, EntityEdge, list[EntityEdge]]
+    ] = await semaphore_gather(
+        *[
+            resolve_extracted_edge(
+                clients.llm_client, edge, candidates, candidates, episode, edge_types
+            )
+            for episode, edge, candidates in dedupe_tuples
+        ]
+    )
+
+    duplicate_pairs: list[tuple[EntityEdge, EntityEdge]] = []
+    for i, (_, _, duplicates) in enumerate(bulk_edge_resolutions):
+        episode, edge, candidates = dedupe_tuples[i]
+        for duplicate in duplicates:
+            if edge.uuid < duplicate.uuid:
+                duplicate_pairs.append((edge, duplicate))
+            else:
+                duplicate_pairs.append((duplicate, edge))
+
+    # Build full deduplication map
+    duplicate_map: dict[str, str] = {}
+    for value, key in duplicate_pairs:
+        if key.uuid in duplicate_map:
+            existing_value = duplicate_map[key.uuid]
+            duplicate_map[key.uuid] = value.uuid if value.uuid < existing_value else existing_value
+        else:
+            duplicate_map[key.uuid] = value.uuid
+
+    # Now we compress the duplicate_map, so that 3 -> 2 and 2 -> becomes 3 -> 1 (sorted by uuid)
+    compressed_map: dict[str, str] = compress_uuid_map(duplicate_map)
+
+    edge_uuid_map: dict[str, EntityEdge] = {
+        edge.uuid: edge for edges in extracted_edges for edge in edges
+    }
+
+    edges_by_episode: dict[str, list[EntityEdge]] = {}
+    for i, edges in enumerate(extracted_edges):
+        episode = episode_tuples[i][0]
+
+        edges_by_episode[episode.uuid] = [
+            edge_uuid_map[compressed_map.get(edge.uuid, edge.uuid)] for edge in edges
+        ]
+
+    return edges_by_episode


-def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str, str]]:
-    uuid_map: dict[str, str] = {}
-    name_map: dict[str, EntityNode] = {}
-    for node in nodes:
-        if node.name in name_map:
-            uuid_map[node.uuid] = name_map[node.name].uuid
-            continue
-
-        name_map[node.name] = node
-
-    return [node for node in name_map.values()], uuid_map
-
-
-async def compress_nodes(
-    llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
-) -> tuple[list[EntityNode], dict[str, str]]:
-    # We want to first compress the nodes by deduplicating nodes across each of the episodes added in bulk
-    if len(nodes) == 0:
-        return nodes, uuid_map
-
-    # Our approach involves us deduplicating chunks of nodes in parallel.
-    # We want n chunks of size n so that n ** 2 == len(nodes).
-    # We want chunk sizes to be at least 10 for optimizing LLM processing time
-    chunk_size = max(int(sqrt(len(nodes))), CHUNK_SIZE)
-
-    # First calculate similarity scores between nodes
-    similarity_scores: list[tuple[int, int, float]] = [
-        (i, j, dot(n.name_embedding or [], m.name_embedding or []))
-        for i, n in enumerate(nodes)
-        for j, m in enumerate(nodes[:i])
-    ]
-
-    # We now sort by semantic similarity
-    similarity_scores.sort(key=lambda score_tuple: score_tuple[2])
-
-    # initialize our chunks based on chunk size
-    node_chunks: list[list[EntityNode]] = [[] for _ in range(ceil(len(nodes) / chunk_size))]
-
-    # Draft the most similar nodes into the same chunk
-    while len(similarity_scores) > 0:
-        i, j, _ = similarity_scores.pop()
-        # determine if any of the nodes have already been drafted into a chunk
-        n = nodes[i]
-        m = nodes[j]
-        # make sure the shortest chunks get preference
-        node_chunks.sort(reverse=True, key=lambda chunk: len(chunk))
-
-        n_chunk = max([i if n in chunk else -1 for i, chunk in enumerate(node_chunks)])
-        m_chunk = max([i if m in chunk else -1 for i, chunk in enumerate(node_chunks)])
-
-        # both nodes already in a chunk
-        if n_chunk > -1 and m_chunk > -1:
-            continue
-
-        # n has a chunk and that chunk is not full
-        elif n_chunk > -1 and len(node_chunks[n_chunk]) < chunk_size:
-            # put m in the same chunk as n
-            node_chunks[n_chunk].append(m)
-
-        # m has a chunk and that chunk is not full
-        elif m_chunk > -1 and len(node_chunks[m_chunk]) < chunk_size:
-            # put n in the same chunk as m
-            node_chunks[m_chunk].append(n)
-
-        # neither node has a chunk or the chunk is full
-        else:
-            # add both nodes to the shortest chunk
-            node_chunks[-1].extend([n, m])
-
-    results = await semaphore_gather(
-        *[dedupe_node_list(llm_client, chunk) for chunk in node_chunks]
-    )
-
-    extended_map = dict(uuid_map)
-    compressed_nodes: list[EntityNode] = []
-    for node_chunk, uuid_map_chunk in results:
-        compressed_nodes += node_chunk
-        extended_map.update(uuid_map_chunk)
-
-    # Check if we have removed all duplicates
-    if len(compressed_nodes) == len(nodes):
-        compressed_uuid_map = compress_uuid_map(extended_map)
-        return compressed_nodes, compressed_uuid_map
-
-    return await compress_nodes(llm_client, compressed_nodes, extended_map)
-
-
-async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list[EntityEdge]:
-    if len(edges) == 0:
-        return edges
-    # We only want to dedupe edges that are between the same pair of nodes
-    # We build a map of the edges based on their source and target nodes.
-    edge_chunks = chunk_edges_by_nodes(edges)
-
-    results = await semaphore_gather(
-        *[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks]
-    )
-
-    compressed_edges: list[EntityEdge] = []
-    for edge_chunk in results:
-        compressed_edges += edge_chunk
-
-    # Check if we have removed all duplicates
-    if len(compressed_edges) == len(edges):
-        return compressed_edges
-
-    return await compress_edges(llm_client, compressed_edges)


 def compress_uuid_map(uuid_map: dict[str, str]) -> dict[str, str]:
-    # make sure all uuid values aren't mapped to other uuids
     compressed_map = {}
-    for key, uuid in uuid_map.items():
-        curr_value = uuid
-        while curr_value in uuid_map:
-            curr_value = uuid_map[curr_value]
-
-        compressed_map[key] = curr_value
+
+    def find_min_uuid(start: str) -> str:
+        path = []
+        visited = set()
+        curr = start
+
+        while curr in uuid_map and curr not in visited:
+            visited.add(curr)
+            path.append(curr)
+            curr = uuid_map[curr]
+
+        # Also include the last resolved value (could be outside the map)
+        path.append(curr)
+
+        # Resolve to lex smallest UUID in the path
+        min_uuid = min(path)
+
+        # Assign all UUIDs in the path to the min_uuid
+        for node in path:
+            compressed_map[node] = min_uuid
+
+        return min_uuid
+
+    for key in uuid_map:
+        if key not in compressed_map:
+            find_min_uuid(key)

     return compressed_map
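A quick usage sketch of the rewritten compress_uuid_map: every uuid on a duplicate chain resolves to the lexicographically smallest uuid on its path, and terminal values gain identity entries (the import path is assumed from the module shown in this hunk):

    from graphiti_core.utils.bulk_utils import compress_uuid_map

    uuid_map = {'c': 'b', 'b': 'a', 'e': 'd'}
    print(compress_uuid_map(uuid_map))
    # Per the implementation above: {'c': 'a', 'b': 'a', 'a': 'a', 'e': 'd', 'd': 'd'}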
@@ -420,63 +436,3 @@ def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
         edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)

     return edges
-
-
-async def extract_edge_dates_bulk(
-    llm_client: LLMClient,
-    extracted_edges: list[EntityEdge],
-    episode_pairs: list[tuple[EpisodicNode, list[EpisodicNode]]],
-) -> list[EntityEdge]:
-    edges: list[EntityEdge] = []
-    # confirm that all of our edges have at least one episode
-    for edge in extracted_edges:
-        if edge.episodes is not None and len(edge.episodes) > 0:
-            edges.append(edge)
-
-    episode_uuid_map: dict[str, tuple[EpisodicNode, list[EpisodicNode]]] = {
-        episode.uuid: (episode, previous_episodes) for episode, previous_episodes in episode_pairs
-    }
-
-    results = await semaphore_gather(
-        *[
-            extract_edge_dates(
-                llm_client,
-                edge,
-                episode_uuid_map[edge.episodes[0]][0],  # type: ignore
-                episode_uuid_map[edge.episodes[0]][1],  # type: ignore
-            )
-            for edge in edges
-        ]
-    )
-
-    for i, result in enumerate(results):
-        valid_at = result[0]
-        invalid_at = result[1]
-        edge = edges[i]
-
-        edge.valid_at = valid_at
-        edge.invalid_at = invalid_at
-        if edge.invalid_at:
-            edge.expired_at = utc_now()
-
-    return edges
-
-
-def chunk_edges_by_nodes(edges: list[EntityEdge]) -> list[list[EntityEdge]]:
-    # We only want to dedupe edges that are between the same pair of nodes
-    # We build a map of the edges based on their source and target nodes.
-    edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
-    for edge in edges:
-        # We drop loop edges
-        if edge.source_node_uuid == edge.target_node_uuid:
-            continue
-
-        # Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution
-        pointers = [edge.source_node_uuid, edge.target_node_uuid]
-        pointers.sort()
-
-        edge_chunk_map[pointers[0] + pointers[1]].append(edge)
-
-    edge_chunks = [chunk for chunk in edge_chunk_map.values()]
-
-    return edge_chunks
@@ -45,15 +45,15 @@ logger = logging.getLogger(__name__)

 def build_episodic_edges(
     entity_nodes: list[EntityNode],
-    episode: EpisodicNode,
+    episode_uuid: str,
     created_at: datetime,
 ) -> list[EpisodicEdge]:
     episodic_edges: list[EpisodicEdge] = [
         EpisodicEdge(
-            source_node_uuid=episode.uuid,
+            source_node_uuid=episode_uuid,
             target_node_uuid=node.uuid,
             created_at=created_at,
-            group_id=episode.group_id,
+            group_id=node.group_id,
         )
         for node in entity_nodes
     ]
@@ -68,19 +68,23 @@ def build_duplicate_of_edges(
     created_at: datetime,
     duplicate_nodes: list[tuple[EntityNode, EntityNode]],
 ) -> list[EntityEdge]:
-    is_duplicate_of_edges: list[EntityEdge] = [
-        EntityEdge(
-            source_node_uuid=source_node.uuid,
-            target_node_uuid=target_node.uuid,
-            name='IS_DUPLICATE_OF',
-            group_id=episode.group_id,
-            fact=f'{source_node.name} is a duplicate of {target_node.name}',
-            episodes=[episode.uuid],
-            created_at=created_at,
-            valid_at=created_at,
-        )
-        for source_node, target_node in duplicate_nodes
-    ]
+    is_duplicate_of_edges: list[EntityEdge] = []
+    for source_node, target_node in duplicate_nodes:
+        if source_node.uuid == target_node.uuid:
+            continue
+
+        is_duplicate_of_edges.append(
+            EntityEdge(
+                source_node_uuid=source_node.uuid,
+                target_node_uuid=target_node.uuid,
+                name='IS_DUPLICATE_OF',
+                group_id=episode.group_id,
+                fact=f'{source_node.name} is a duplicate of {target_node.name}',
+                episodes=[episode.uuid],
+                created_at=created_at,
+                valid_at=created_at,
+            )
+        )

     return is_duplicate_of_edges
@@ -240,50 +244,6 @@ async def extract_edges(
     return edges


-async def dedupe_extracted_edges(
-    llm_client: LLMClient,
-    extracted_edges: list[EntityEdge],
-    existing_edges: list[EntityEdge],
-) -> list[EntityEdge]:
-    # Create edge map
-    edge_map: dict[str, EntityEdge] = {}
-    for edge in existing_edges:
-        edge_map[edge.uuid] = edge
-
-    # Prepare context for LLM
-    context = {
-        'extracted_edges': [
-            {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in extracted_edges
-        ],
-        'existing_edges': [
-            {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges
-        ],
-    }
-
-    llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.edge(context))
-    duplicate_data = llm_response.get('duplicates', [])
-    logger.debug(f'Extracted unique edges: {duplicate_data}')
-
-    duplicate_uuid_map: dict[str, str] = {}
-    for duplicate in duplicate_data:
-        uuid_value = duplicate['duplicate_of']
-        duplicate_uuid_map[duplicate['uuid']] = uuid_value
-
-    # Get full edge data
-    edges: list[EntityEdge] = []
-    for edge in extracted_edges:
-        if edge.uuid in duplicate_uuid_map:
-            existing_uuid = duplicate_uuid_map[edge.uuid]
-            existing_edge = edge_map[existing_uuid]
-            # Add current episode to the episodes list
-            existing_edge.episodes += edge.episodes
-            edges.append(existing_edge)
-        else:
-            edges.append(edge)
-
-    return edges
-
-
 async def resolve_extracted_edges(
     clients: GraphitiClients,
     extracted_edges: list[EntityEdge],
@@ -335,7 +295,7 @@ async def resolve_extracted_edges(
         edge_types_lst.append(extracted_edge_types)

     # resolve edges with related edges in the graph and find invalidation candidates
-    results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
+    results: list[tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]] = list(
         await semaphore_gather(
             *[
                 resolve_extracted_edge(
@@ -416,9 +376,9 @@ async def resolve_extracted_edge(
     existing_edges: list[EntityEdge],
     episode: EpisodicNode,
     edge_types: dict[str, BaseModel] | None = None,
-) -> tuple[EntityEdge, list[EntityEdge]]:
+) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
     if len(related_edges) == 0 and len(existing_edges) == 0:
-        return extracted_edge, []
+        return extracted_edge, [], []

     start = time()
@@ -457,15 +417,16 @@ async def resolve_extracted_edge(
         model_size=ModelSize.small,
     )

-    duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1)
-
-    resolved_edge = (
-        related_edges[duplicate_fact_id]
-        if 0 <= duplicate_fact_id < len(related_edges)
-        else extracted_edge
+    duplicate_fact_ids: list[int] = list(
+        filter(lambda i: 0 <= i < len(related_edges), llm_response.get('duplicate_facts', []))
     )

-    if duplicate_fact_id >= 0 and episode is not None:
+    resolved_edge = extracted_edge
+    for duplicate_fact_id in duplicate_fact_ids:
+        resolved_edge = related_edges[duplicate_fact_id]
+        break
+
+    if duplicate_fact_ids and episode is not None:
         resolved_edge.episodes.append(episode.uuid)

     contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])
@@ -519,59 +480,12 @@ async def resolve_extracted_edge(
             break

     # Determine which contradictory edges need to be expired
-    invalidated_edges = resolve_edge_contradictions(resolved_edge, invalidation_candidates)
-
-    return resolved_edge, invalidated_edges
-
-
-async def dedupe_extracted_edge(
-    llm_client: LLMClient,
-    extracted_edge: EntityEdge,
-    related_edges: list[EntityEdge],
-    episode: EpisodicNode | None = None,
-) -> EntityEdge:
-    if len(related_edges) == 0:
-        return extracted_edge
-
-    start = time()
-
-    # Prepare context for LLM
-    related_edges_context = [
-        {'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
-    ]
-
-    extracted_edge_context = {
-        'fact': extracted_edge.fact,
-    }
-
-    context = {
-        'related_edges': related_edges_context,
-        'extracted_edges': extracted_edge_context,
-    }
-
-    llm_response = await llm_client.generate_response(
-        prompt_library.dedupe_edges.edge(context),
-        response_model=EdgeDuplicate,
-        model_size=ModelSize.small,
-    )
-
-    duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1)
-
-    edge = (
-        related_edges[duplicate_fact_id]
-        if 0 <= duplicate_fact_id < len(related_edges)
-        else extracted_edge
-    )
-
-    if duplicate_fact_id >= 0 and episode is not None:
-        edge.episodes.append(episode.uuid)
-
-    end = time()
-    logger.debug(
-        f'Resolved Edge: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms'
-    )
-
-    return edge
+    invalidated_edges: list[EntityEdge] = resolve_edge_contradictions(
+        resolved_edge, invalidation_candidates
+    )
+    duplicate_edges: list[EntityEdge] = [related_edges[idx] for idx in duplicate_fact_ids]
+
+    return resolved_edge, invalidated_edges, duplicate_edges


 async def dedupe_edge_list(
@@ -176,62 +176,13 @@ async def extract_nodes(
     return extracted_nodes


-async def dedupe_extracted_nodes(
-    llm_client: LLMClient,
-    extracted_nodes: list[EntityNode],
-    existing_nodes: list[EntityNode],
-) -> tuple[list[EntityNode], dict[str, str]]:
-    start = time()
-
-    # build existing node map
-    node_map: dict[str, EntityNode] = {}
-    for node in existing_nodes:
-        node_map[node.uuid] = node
-
-    # Prepare context for LLM
-    existing_nodes_context = [
-        {'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in existing_nodes
-    ]
-
-    extracted_nodes_context = [
-        {'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in extracted_nodes
-    ]
-
-    context = {
-        'existing_nodes': existing_nodes_context,
-        'extracted_nodes': extracted_nodes_context,
-    }
-
-    llm_response = await llm_client.generate_response(prompt_library.dedupe_nodes.node(context))
-
-    duplicate_data = llm_response.get('duplicates', [])
-
-    end = time()
-    logger.debug(f'Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms')
-
-    uuid_map: dict[str, str] = {}
-    for duplicate in duplicate_data:
-        uuid_value = duplicate['duplicate_of']
-        uuid_map[duplicate['uuid']] = uuid_value
-
-    nodes: list[EntityNode] = []
-    for node in extracted_nodes:
-        if node.uuid in uuid_map:
-            existing_uuid = uuid_map[node.uuid]
-            existing_node = node_map[existing_uuid]
-            nodes.append(existing_node)
-        else:
-            nodes.append(node)
-
-    return nodes, uuid_map
-
-
 async def resolve_extracted_nodes(
     clients: GraphitiClients,
     extracted_nodes: list[EntityNode],
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
     entity_types: dict[str, BaseModel] | None = None,
+    existing_nodes_override: list[EntityNode] | None = None,
 ) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
     llm_client = clients.llm_client
     driver = clients.driver
@@ -249,9 +200,13 @@ async def resolve_extracted_nodes(
         ]
     )

-    existing_nodes_dict: dict[str, EntityNode] = {
-        node.uuid: node for result in search_results for node in result.nodes
-    }
+    candidate_nodes: list[EntityNode] = (
+        [node for result in search_results for node in result.nodes]
+        if existing_nodes_override is None
+        else existing_nodes_override
+    )
+
+    existing_nodes_dict: dict[str, EntityNode] = {node.uuid: node for node in candidate_nodes}

     existing_nodes: list[EntityNode] = list(existing_nodes_dict.values())
@@ -321,13 +276,11 @@ async def resolve_extracted_nodes(
         resolved_nodes.append(resolved_node)
         uuid_map[extracted_node.uuid] = resolved_node.uuid

-        additional_duplicates: list[int] = resolution.get('additional_duplicates', [])
-        for idx in additional_duplicates:
+        duplicates: list[int] = resolution.get('duplicates', [])
+        for idx in duplicates:
             existing_node = existing_nodes[idx] if idx < len(existing_nodes) else resolved_node
-            if existing_node == resolved_node:
-                continue
-
-            node_duplicates.append((resolved_node, existing_nodes[idx]))
+            node_duplicates.append((resolved_node, existing_node))

     logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')