Bulk ingestion (#698)
* partial
* update
* update
* update
* update
* updates
* updates
* update
* update
This commit is contained in: parent 94df836396, commit 0675ac2b7d
10 changed files with 2351 additions and 2466 deletions
@@ -25,6 +25,8 @@ from pydantic import BaseModel, Field
from transcript_parser import parse_podcast_messages

from graphiti_core import Graphiti
from graphiti_core.nodes import EpisodeType
from graphiti_core.utils.bulk_utils import RawEpisode
from graphiti_core.utils.maintenance.graph_data_operations import clear_data

load_dotenv()

@@ -67,7 +69,7 @@ class IsPresidentOf(BaseModel):
"""Relationship between a person and the entity they are a president of"""


async def main():
async def main(use_bulk: bool = False):
setup_logging()
client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
await clear_data(client.driver)

@@ -75,21 +77,43 @@ async def main():
messages = parse_podcast_messages()
group_id = str(uuid4())

for i, message in enumerate(messages[3:14]):
episodes = await client.retrieve_episodes(message.actual_timestamp, 3, group_ids=[group_id])
episode_uuids = [episode.uuid for episode in episodes]

await client.add_episode(
name=f'Message {i}',
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
reference_time=message.actual_timestamp,
source_description='Podcast Transcript',
raw_episodes: list[RawEpisode] = []
for i, message in enumerate(messages[3:7]):
raw_episodes.append(
RawEpisode(
name=f'Message {i}',
content=f'{message.speaker_name} ({message.role}): {message.content}',
reference_time=message.actual_timestamp,
source=EpisodeType.message,
source_description='Podcast Transcript',
)
)
if use_bulk:
await client.add_episode_bulk(
raw_episodes,
group_id=group_id,
entity_types={'Person': Person},
edge_types={'IS_PRESIDENT_OF': IsPresidentOf},
edge_type_map={('Person', 'Entity'): ['PRESIDENT_OF']},
previous_episode_uuids=episode_uuids,
)
else:
for i, message in enumerate(messages[3:14]):
episodes = await client.retrieve_episodes(
message.actual_timestamp, 3, group_ids=[group_id]
)
episode_uuids = [episode.uuid for episode in episodes]

await client.add_episode(
name=f'Message {i}',
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
reference_time=message.actual_timestamp,
source_description='Podcast Transcript',
group_id=group_id,
entity_types={'Person': Person},
edge_types={'IS_PRESIDENT_OF': IsPresidentOf},
edge_type_map={('Person', 'Entity'): ['PRESIDENT_OF']},
previous_episode_uuids=episode_uuids,
)


asyncio.run(main())
asyncio.run(main(True))
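A minimal sketch (not part of this commit) of how the example's main() could be driven from the command line instead of hard-coding asyncio.run(main(True)); the --bulk flag name and the __main__ guard are illustrative assumptions, and the snippet presumes it is appended to the example script where main() is defined.

# Hypothetical wrapper around the example's main(); --bulk is an illustrative flag.
import argparse
import asyncio

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Podcast ingestion example')
    parser.add_argument('--bulk', action='store_true', help='use add_episode_bulk instead of add_episode')
    args = parser.parse_args()
    asyncio.run(main(use_bulk=args.bulk))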
@@ -44,6 +44,7 @@ class GeminiRerankerClient(CrossEncoderClient):
"""
Google Gemini Reranker Client
"""

def __init__(
self,
config: LLMConfig | None = None,
@@ -46,6 +46,7 @@ class GeminiEmbedder(EmbedderClient):
"""
Google Gemini Embedder Client
"""

def __init__(
self,
config: GeminiEmbedderConfig | None = None,
@@ -57,7 +57,6 @@ from graphiti_core.utils.bulk_utils import (
add_nodes_and_edges_bulk,
dedupe_edges_bulk,
dedupe_nodes_bulk,
extract_edge_dates_bulk,
extract_nodes_and_edges_bulk,
resolve_edge_pointers,
retrieve_previous_episodes_bulk,
@@ -508,7 +507,7 @@ class Graphiti:

entity_edges = resolved_edges + invalidated_edges + duplicate_of_edges

episodic_edges = build_episodic_edges(nodes, episode, now)
episodic_edges = build_episodic_edges(nodes, episode.uuid, now)

episode.entity_edges = [edge.uuid for edge in entity_edges]
@@ -536,8 +535,16 @@ class Graphiti:
except Exception as e:
raise e

#### WIP: USE AT YOUR OWN RISK ####
async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str = ''):
##### EXPERIMENTAL #####
async def add_episode_bulk(
self,
bulk_episodes: list[RawEpisode],
group_id: str = '',
entity_types: dict[str, BaseModel] | None = None,
excluded_entity_types: list[str] | None = None,
edge_types: dict[str, BaseModel] | None = None,
edge_type_map: dict[tuple[str, str], list[str]] | None = None,
):
"""
Process multiple episodes in bulk and update the graph.
@@ -580,8 +587,17 @@ class Graphiti:

validate_group_id(group_id)

# Create default edge type map
edge_type_map_default = (
{('Entity', 'Entity'): list(edge_types.keys())}
if edge_types is not None
else {('Entity', 'Entity'): []}
)

episodes = [
EpisodicNode(
await EpisodicNode.get_by_uuid(self.driver, episode.uuid)
if episode.uuid is not None
else EpisodicNode(
name=episode.name,
labels=[],
source=episode.source,
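A small self-contained sketch (not part of the diff) of what the default edge type map shown above evaluates to when a caller supplies edge_types but no edge_type_map; the IsPresidentOf placeholder model is only for illustration.

# Sketch: the default map simply allows every provided edge type between generic Entity nodes.
from pydantic import BaseModel

class IsPresidentOf(BaseModel):
    """Hypothetical edge payload used only for this illustration."""

edge_types = {'IS_PRESIDENT_OF': IsPresidentOf}
edge_type_map_default = (
    {('Entity', 'Entity'): list(edge_types.keys())}
    if edge_types is not None
    else {('Entity', 'Entity'): []}
)
assert edge_type_map_default == {('Entity', 'Entity'): ['IS_PRESIDENT_OF']}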
@@ -594,68 +610,106 @@ class Graphiti:
for episode in bulk_episodes
]

# Save all the episodes
await semaphore_gather(
*[episode.save(self.driver) for episode in episodes],
max_coroutines=self.max_coroutines,
episodes_by_uuid: dict[str, EpisodicNode] = {
episode.uuid: episode for episode in episodes
}

# Save all episodes
await add_nodes_and_edges_bulk(
driver=self.driver,
episodic_nodes=episodes,
episodic_edges=[],
entity_nodes=[],
entity_edges=[],
embedder=self.embedder,
)

# Get previous episode context for each episode
episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
episode_context = await retrieve_previous_episodes_bulk(self.driver, episodes)

# Extract all nodes and edges
(
extracted_nodes,
extracted_edges,
episodic_edges,
) = await extract_nodes_and_edges_bulk(self.clients, episode_pairs, None, None)

# Generate embeddings
await semaphore_gather(
*[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
*[edge.generate_embedding(self.embedder) for edge in extracted_edges],
max_coroutines=self.max_coroutines,
# Extract all nodes and edges for each episode
extracted_nodes_bulk, extracted_edges_bulk = await extract_nodes_and_edges_bulk(
self.clients,
episode_context,
edge_type_map=edge_type_map or edge_type_map_default,
edge_types=edge_types,
entity_types=entity_types,
excluded_entity_types=excluded_entity_types,
)

# Dedupe extracted nodes, compress extracted edges
(nodes, uuid_map), extracted_edges_timestamped = await semaphore_gather(
dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
max_coroutines=self.max_coroutines,
# Dedupe extracted nodes in memory
nodes_by_episode, uuid_map = await dedupe_nodes_bulk(
self.clients, extracted_nodes_bulk, episode_context, entity_types
)

# save nodes to KG
await semaphore_gather(
*[node.save(self.driver) for node in nodes],
max_coroutines=self.max_coroutines,
)
episodic_edges: list[EpisodicEdge] = []
for episode_uuid, nodes in nodes_by_episode.items():
episodic_edges.extend(build_episodic_edges(nodes, episode_uuid, now))

# re-map edge pointers so that they don't point to discard dupe nodes
extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
extracted_edges_timestamped, uuid_map
)
episodic_edges_with_resolved_pointers: list[EpisodicEdge] = resolve_edge_pointers(
episodic_edges, uuid_map
extracted_edges_bulk_updated: list[list[EntityEdge]] = [
resolve_edge_pointers(edges, uuid_map) for edges in extracted_edges_bulk
]

# Dedupe extracted edges in memory
edges_by_episode = await dedupe_edges_bulk(
self.clients,
extracted_edges_bulk_updated,
episode_context,
[],
edge_types or {},
edge_type_map or edge_type_map_default,
)

# save episodic edges to KG
await semaphore_gather(
*[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers],
max_coroutines=self.max_coroutines,
# Extract node attributes
nodes_by_uuid: dict[str, EntityNode] = {
node.uuid: node for nodes in nodes_by_episode.values() for node in nodes
}

extract_attributes_params: list[tuple[EntityNode, list[EpisodicNode]]] = []
for node in nodes_by_uuid.values():
episode_uuids: list[str] = []
for episode_uuid, mentioned_nodes in nodes_by_episode.items():
for mentioned_node in mentioned_nodes:
if node.uuid == mentioned_node.uuid:
episode_uuids.append(episode_uuid)
break

episode_mentions: list[EpisodicNode] = [
episodes_by_uuid[episode_uuid] for episode_uuid in episode_uuids
]
episode_mentions.sort(key=lambda x: x.valid_at, reverse=True)

extract_attributes_params.append((node, episode_mentions))

new_hydrated_nodes: list[list[EntityNode]] = await semaphore_gather(
*[
extract_attributes_from_nodes(
self.clients,
[params[0]],
params[1][0],
params[1][0:],
entity_types,
)
for params in extract_attributes_params
]
)

# Dedupe extracted edges
edges = await dedupe_edges_bulk(
self.driver, self.llm_client, extracted_edges_with_resolved_pointers
)
logger.debug(f'extracted edge length: {len(edges)}')
hydrated_nodes = [node for nodes in new_hydrated_nodes for node in nodes]

# invalidate edges
# TODO: Resolve nodes and edges against the existing graph
edges_by_uuid: dict[str, EntityEdge] = {
edge.uuid: edge for edges in edges_by_episode.values() for edge in edges
}

# save edges to KG
await semaphore_gather(
*[edge.save(self.driver) for edge in edges],
max_coroutines=self.max_coroutines,
# save data to KG
await add_nodes_and_edges_bulk(
self.driver,
episodes,
episodic_edges,
hydrated_nodes,
list(edges_by_uuid.values()),
self.embedder,
)

end = time()
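A plain-dictionary sketch (not part of the diff) of the reverse lookup the attribute-extraction loop above performs: for each deduplicated node, collect the episodes whose node lists mention it. The uuids used here are made up for illustration.

# Sketch with plain strings instead of EntityNode/EpisodicNode objects.
nodes_by_episode = {'ep1': ['alice', 'bob'], 'ep2': ['alice']}

episodes_per_node: dict[str, list[str]] = {}
for episode_uuid, mentioned in nodes_by_episode.items():
    for node_uuid in mentioned:
        episodes_per_node.setdefault(node_uuid, []).append(episode_uuid)

assert episodes_per_node == {'alice': ['ep1', 'ep2'], 'bob': ['ep1']}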
@@ -828,7 +882,7 @@ class Graphiti:
await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
)[0]

resolved_edge, invalidated_edges = await resolve_extracted_edge(
resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
self.llm_client,
updated_edge,
related_edges,
@@ -23,9 +23,9 @@ from .models import Message, PromptFunction, PromptVersion


class EdgeDuplicate(BaseModel):
duplicate_fact_id: int = Field(
duplicate_facts: list[int] = Field(
...,
description='id of the duplicate fact. If no duplicate facts are found, default to -1.',
description='List of ids of any duplicate facts. If no duplicate facts are found, default to empty list.',
)
contradicted_facts: list[int] = Field(
...,
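A minimal sketch (not the library's actual module) of the response shape the updated EdgeDuplicate model accepts: duplicate_facts is now a list of ids rather than a single duplicate_fact_id.

# Illustrative mirror of the model above; field descriptions are paraphrased.
from pydantic import BaseModel, Field

class EdgeDuplicate(BaseModel):
    duplicate_facts: list[int] = Field(..., description='ids of duplicate facts, empty list if none')
    contradicted_facts: list[int] = Field(..., description='ids of contradicted facts, empty list if none')

resp = EdgeDuplicate(duplicate_facts=[0, 2], contradicted_facts=[])
assert resp.duplicate_facts == [0, 2]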
@@ -75,8 +75,9 @@ def edge(context: dict[str, Any]) -> list[Message]:
</NEW EDGE>

Task:
If the New Edges represents the same factual information as any edge in Existing Edges, return the id of the duplicate fact.
If the NEW EDGE is not a duplicate of any of the EXISTING EDGES, return -1.
If the New Edges represents the same factual information as any edge in Existing Edges, return the id of the duplicate fact
as part of the list of duplicate_facts.
If the NEW EDGE is not a duplicate of any of the EXISTING EDGES, return an empty list.

Guidelines:
1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
@@ -32,9 +32,9 @@ class NodeDuplicate(BaseModel):
...,
description='Name of the entity. Should be the most complete and descriptive name of the entity. Do not include any JSON formatting in the Entity name such as {}.',
)
additional_duplicates: list[int] = Field(
duplicates: list[int] = Field(
...,
description='idx of additional duplicate entities. Use this list if the entity has multiple duplicates among existing entities.',
description='idx of all duplicate entities.',
)
@@ -94,7 +94,7 @@ def node(context: dict[str, Any]) -> list[Message]:
1. Compare `new_entity` against each item in `existing_entities`.
2. If it refers to the same real‐world object or concept, collect its index.
3. Let `duplicate_idx` = the *first* collected index, or –1 if none.
4. Let `additional_duplicates` = the list of *any other* collected indices (empty list if none).
4. Let `duplicates` = the list of *all* collected indices (empty list if none).

Also return the full name of the NEW ENTITY (whether it is the name of the NEW ENTITY, a node it
is a duplicate of, or a combination of the two).
@@ -16,50 +16,40 @@ limitations under the License.

import logging
import typing
from collections import defaultdict
from datetime import datetime
from math import ceil

from numpy import dot, sqrt
from pydantic import BaseModel
import numpy as np
from pydantic import BaseModel, Field
from typing_extensions import Any

from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
from graphiti_core.embedder import EmbedderClient
from graphiti_core.graph_queries import (
get_entity_edge_save_bulk_query,
get_entity_node_save_bulk_query,
)
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
from graphiti_core.llm_client import LLMClient
from graphiti_core.helpers import DEFAULT_DATABASE, normalize_l2, semaphore_gather
from graphiti_core.models.edges.edge_db_queries import (
EPISODIC_EDGE_SAVE_BULK,
)
from graphiti_core.models.nodes.node_db_queries import (
EPISODIC_NODE_SAVE_BULK,
)
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
from graphiti_core.utils.maintenance.edge_operations import (
build_episodic_edges,
dedupe_edge_list,
dedupe_extracted_edges,
extract_edges,
resolve_extracted_edge,
)
from graphiti_core.utils.maintenance.graph_data_operations import (
EPISODE_WINDOW_LEN,
retrieve_episodes,
)
from graphiti_core.utils.maintenance.node_operations import (
dedupe_extracted_nodes,
dedupe_node_list,
extract_nodes,
resolve_extracted_nodes,
)
from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates

logger = logging.getLogger(__name__)
@@ -68,6 +58,7 @@ CHUNK_SIZE = 10

class RawEpisode(BaseModel):
name: str
uuid: str | None = Field(default=None)
content: str
source_description: str
source: EpisodeType
@@ -179,233 +170,258 @@ async def add_nodes_and_edges_bulk_tx(
async def extract_nodes_and_edges_bulk(
clients: GraphitiClients,
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
edge_type_map: dict[tuple[str, str], list[str]],
entity_types: dict[str, BaseModel] | None = None,
excluded_entity_types: list[str] | None = None,
) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
extracted_nodes_bulk = await semaphore_gather(
edge_types: dict[str, BaseModel] | None = None,
) -> tuple[list[list[EntityNode]], list[list[EntityEdge]]]:
extracted_nodes_bulk: list[list[EntityNode]] = await semaphore_gather(
*[
extract_nodes(clients, episode, previous_episodes, entity_types, excluded_entity_types)
for episode, previous_episodes in episode_tuples
]
)

episodes, previous_episodes_list = (
[episode[0] for episode in episode_tuples],
[episode[1] for episode in episode_tuples],
)

extracted_edges_bulk = await semaphore_gather(
extracted_edges_bulk: list[list[EntityEdge]] = await semaphore_gather(
*[
extract_edges(
clients,
episode,
extracted_nodes_bulk[i],
previous_episodes_list[i],
{},
episode.group_id,
previous_episodes,
edge_type_map=edge_type_map,
group_id=episode.group_id,
edge_types=edge_types,
)
for i, episode in enumerate(episodes)
for i, (episode, previous_episodes) in enumerate(episode_tuples)
]
)

episodic_edges: list[EpisodicEdge] = []
for i, episode in enumerate(episodes):
episodic_edges += build_episodic_edges(extracted_nodes_bulk[i], episode, episode.created_at)

nodes: list[EntityNode] = []
for extracted_nodes in extracted_nodes_bulk:
nodes += extracted_nodes

edges: list[EntityEdge] = []
for extracted_edges in extracted_edges_bulk:
edges += extracted_edges

return nodes, edges, episodic_edges
return extracted_nodes_bulk, extracted_edges_bulk
async def dedupe_nodes_bulk(
driver: GraphDriver,
llm_client: LLMClient,
extracted_nodes: list[EntityNode],
) -> tuple[list[EntityNode], dict[str, str]]:
# Compress nodes
nodes, uuid_map = node_name_match(extracted_nodes)
clients: GraphitiClients,
extracted_nodes: list[list[EntityNode]],
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
entity_types: dict[str, BaseModel] | None = None,
) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
embedder = clients.embedder
min_score = 0.8

compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map)

node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]

existing_nodes_chunks: list[list[EntityNode]] = list(
await semaphore_gather(
*[get_relevant_nodes(driver, node_chunk, SearchFilters()) for node_chunk in node_chunks]
)
# generate embeddings
await semaphore_gather(
*[create_entity_node_embeddings(embedder, nodes) for nodes in extracted_nodes]
)

results: list[tuple[list[EntityNode], dict[str, str]]] = list(
await semaphore_gather(
*[
dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i])
for i, node_chunk in enumerate(node_chunks)
]
)
# Find similar results
dedupe_tuples: list[tuple[list[EntityNode], list[EntityNode]]] = []
for i, nodes_i in enumerate(extracted_nodes):
existing_nodes: list[EntityNode] = []
for j, nodes_j in enumerate(extracted_nodes):
if i == j:
continue
existing_nodes += nodes_j

candidates_i: list[EntityNode] = []
for node in nodes_i:
for existing_node in existing_nodes:
# Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
# This approach will cast a wider net than BM25, which is ideal for this use case
node_words = set(node.name.lower().split())
existing_node_words = set(existing_node.name.lower().split())
has_overlap = not node_words.isdisjoint(existing_node_words)
if has_overlap:
candidates_i.append(existing_node)
continue

# Check for semantic similarity even if there is no overlap
similarity = np.dot(
normalize_l2(node.name_embedding or []),
normalize_l2(existing_node.name_embedding or []),
)
if similarity >= min_score:
candidates_i.append(existing_node)

dedupe_tuples.append((nodes_i, candidates_i))

# Determine Node Resolutions
bulk_node_resolutions: list[
tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]
] = await semaphore_gather(
*[
resolve_extracted_nodes(
clients,
dedupe_tuple[0],
episode_tuples[i][0],
episode_tuples[i][1],
entity_types,
existing_nodes_override=dedupe_tuples[i][1],
)
for i, dedupe_tuple in enumerate(dedupe_tuples)
]
)

final_nodes: list[EntityNode] = []
for result in results:
final_nodes.extend(result[0])
partial_uuid_map = result[1]
compressed_map.update(partial_uuid_map)
# Collect all duplicate pairs sorted by uuid
duplicate_pairs: list[tuple[EntityNode, EntityNode]] = []
for _, _, duplicates in bulk_node_resolutions:
for duplicate in duplicates:
n, m = duplicate
if n.uuid < m.uuid:
duplicate_pairs.append((n, m))
else:
duplicate_pairs.append((m, n))

return final_nodes, compressed_map
# Build full deduplication map
duplicate_map: dict[str, str] = {}
for value, key in duplicate_pairs:
if key.uuid in duplicate_map:
existing_value = duplicate_map[key.uuid]
duplicate_map[key.uuid] = value.uuid if value.uuid < existing_value else existing_value
else:
duplicate_map[key.uuid] = value.uuid

# Now we compress the duplicate_map, so that 3 -> 2 and 2 -> 1 becomes 3 -> 1 (sorted by uuid)
compressed_map: dict[str, str] = compress_uuid_map(duplicate_map)

node_uuid_map: dict[str, EntityNode] = {
node.uuid: node for nodes in extracted_nodes for node in nodes
}

nodes_by_episode: dict[str, list[EntityNode]] = {}
for i, nodes in enumerate(extracted_nodes):
episode = episode_tuples[i][0]

nodes_by_episode[episode.uuid] = [
node_uuid_map[compressed_map.get(node.uuid, node.uuid)] for node in nodes
]

return nodes_by_episode, compressed_map
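A standalone sketch of the candidate filter used in dedupe_nodes_bulk above: a word-overlap check that approximates BM25, with a dot product of L2-normalized name embeddings as the fallback. The normalize_l2 helper below is a local stand-in for the one imported from graphiti_core.helpers, and the names and vectors are made up.

import numpy as np

def normalize_l2(vec: list[float]) -> np.ndarray:
    arr = np.array(vec, dtype=float)
    norm = np.linalg.norm(arr)
    return arr / norm if norm > 0 else arr

def is_candidate(name_a: str, emb_a: list[float], name_b: str, emb_b: list[float], min_score: float = 0.8) -> bool:
    # word overlap approximates BM25-style lexical matching
    if not set(name_a.lower().split()).isdisjoint(set(name_b.lower().split())):
        return True
    # otherwise fall back to cosine similarity of the (normalized) name embeddings
    return float(np.dot(normalize_l2(emb_a), normalize_l2(emb_b))) >= min_score

assert is_candidate('Jimmy Carter', [1.0, 0.0], 'President Carter', [0.0, 1.0])
assert not is_candidate('Jimmy Carter', [1.0, 0.0], 'Ronald Reagan', [0.0, 1.0])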
async def dedupe_edges_bulk(
driver: GraphDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
) -> list[EntityEdge]:
# First compress edges
compressed_edges = await compress_edges(llm_client, extracted_edges)
clients: GraphitiClients,
extracted_edges: list[list[EntityEdge]],
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
_entities: list[EntityNode],
edge_types: dict[str, BaseModel],
_edge_type_map: dict[tuple[str, str], list[str]],
) -> dict[str, list[EntityEdge]]:
embedder = clients.embedder
min_score = 0.6

edge_chunks = [
compressed_edges[i : i + CHUNK_SIZE] for i in range(0, len(compressed_edges), CHUNK_SIZE)
]

relevant_edges_chunks: list[list[EntityEdge]] = list(
await semaphore_gather(
*[get_relevant_edges(driver, edge_chunk, SearchFilters()) for edge_chunk in edge_chunks]
)
# generate embeddings
await semaphore_gather(
*[create_entity_edge_embeddings(embedder, edges) for edges in extracted_edges]
)

resolved_edge_chunks: list[list[EntityEdge]] = list(
await semaphore_gather(
*[
dedupe_extracted_edges(llm_client, edge_chunk, relevant_edges_chunks[i])
for i, edge_chunk in enumerate(edge_chunks)
]
)
# Find similar results
dedupe_tuples: list[tuple[EpisodicNode, EntityEdge, list[EntityEdge]]] = []
for i, edges_i in enumerate(extracted_edges):
existing_edges: list[EntityEdge] = []
for j, edges_j in enumerate(extracted_edges):
if i == j:
continue
existing_edges += edges_j

for edge in edges_i:
candidates: list[EntityEdge] = []
for existing_edge in existing_edges:
# Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
# This approach will cast a wider net than BM25, which is ideal for this use case
edge_words = set(edge.fact.lower().split())
existing_edge_words = set(existing_edge.fact.lower().split())
has_overlap = not edge_words.isdisjoint(existing_edge_words)
if has_overlap:
candidates.append(existing_edge)
continue

# Check for semantic similarity even if there is no overlap
similarity = np.dot(
normalize_l2(edge.fact_embedding or []),
normalize_l2(existing_edge.fact_embedding or []),
)
if similarity >= min_score:
candidates.append(existing_edge)

dedupe_tuples.append((episode_tuples[i][0], edge, candidates))

bulk_edge_resolutions: list[
tuple[EntityEdge, EntityEdge, list[EntityEdge]]
] = await semaphore_gather(
*[
resolve_extracted_edge(
clients.llm_client, edge, candidates, candidates, episode, edge_types
)
for episode, edge, candidates in dedupe_tuples
]
)

edges = [edge for edge_chunk in resolved_edge_chunks for edge in edge_chunk]
return edges
duplicate_pairs: list[tuple[EntityEdge, EntityEdge]] = []
for i, (_, _, duplicates) in enumerate(bulk_edge_resolutions):
episode, edge, candidates = dedupe_tuples[i]
for duplicate in duplicates:
if edge.uuid < duplicate.uuid:
duplicate_pairs.append((edge, duplicate))
else:
duplicate_pairs.append((duplicate, edge))
def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str, str]]:
uuid_map: dict[str, str] = {}
name_map: dict[str, EntityNode] = {}
for node in nodes:
if node.name in name_map:
uuid_map[node.uuid] = name_map[node.name].uuid
continue

name_map[node.name] = node

return [node for node in name_map.values()], uuid_map


async def compress_nodes(
llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
) -> tuple[list[EntityNode], dict[str, str]]:
# We want to first compress the nodes by deduplicating nodes across each of the episodes added in bulk
if len(nodes) == 0:
return nodes, uuid_map

# Our approach involves us deduplicating chunks of nodes in parallel.
# We want n chunks of size n so that n ** 2 == len(nodes).
# We want chunk sizes to be at least 10 for optimizing LLM processing time
chunk_size = max(int(sqrt(len(nodes))), CHUNK_SIZE)

# First calculate similarity scores between nodes
similarity_scores: list[tuple[int, int, float]] = [
(i, j, dot(n.name_embedding or [], m.name_embedding or []))
for i, n in enumerate(nodes)
for j, m in enumerate(nodes[:i])
]

# We now sort by semantic similarity
similarity_scores.sort(key=lambda score_tuple: score_tuple[2])

# initialize our chunks based on chunk size
node_chunks: list[list[EntityNode]] = [[] for _ in range(ceil(len(nodes) / chunk_size))]

# Draft the most similar nodes into the same chunk
while len(similarity_scores) > 0:
i, j, _ = similarity_scores.pop()
# determine if any of the nodes have already been drafted into a chunk
n = nodes[i]
m = nodes[j]
# make sure the shortest chunks get preference
node_chunks.sort(reverse=True, key=lambda chunk: len(chunk))

n_chunk = max([i if n in chunk else -1 for i, chunk in enumerate(node_chunks)])
m_chunk = max([i if m in chunk else -1 for i, chunk in enumerate(node_chunks)])

# both nodes already in a chunk
if n_chunk > -1 and m_chunk > -1:
continue

# n has a chunk and that chunk is not full
elif n_chunk > -1 and len(node_chunks[n_chunk]) < chunk_size:
# put m in the same chunk as n
node_chunks[n_chunk].append(m)

# m has a chunk and that chunk is not full
elif m_chunk > -1 and len(node_chunks[m_chunk]) < chunk_size:
# put n in the same chunk as m
node_chunks[m_chunk].append(n)

# neither node has a chunk or the chunk is full
# Build full deduplication map
duplicate_map: dict[str, str] = {}
for value, key in duplicate_pairs:
if key.uuid in duplicate_map:
existing_value = duplicate_map[key.uuid]
duplicate_map[key.uuid] = value.uuid if value.uuid < existing_value else existing_value
else:
# add both nodes to the shortest chunk
node_chunks[-1].extend([n, m])
duplicate_map[key.uuid] = value.uuid

results = await semaphore_gather(
*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks]
)
# Now we compress the duplicate_map, so that 3 -> 2 and 2 -> 1 becomes 3 -> 1 (sorted by uuid)
compressed_map: dict[str, str] = compress_uuid_map(duplicate_map)

extended_map = dict(uuid_map)
compressed_nodes: list[EntityNode] = []
for node_chunk, uuid_map_chunk in results:
compressed_nodes += node_chunk
extended_map.update(uuid_map_chunk)
edge_uuid_map: dict[str, EntityEdge] = {
edge.uuid: edge for edges in extracted_edges for edge in edges
}

# Check if we have removed all duplicates
if len(compressed_nodes) == len(nodes):
compressed_uuid_map = compress_uuid_map(extended_map)
return compressed_nodes, compressed_uuid_map
edges_by_episode: dict[str, list[EntityEdge]] = {}
for i, edges in enumerate(extracted_edges):
episode = episode_tuples[i][0]

return await compress_nodes(llm_client, compressed_nodes, extended_map)
edges_by_episode[episode.uuid] = [
edge_uuid_map[compressed_map.get(edge.uuid, edge.uuid)] for edge in edges
]


async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list[EntityEdge]:
if len(edges) == 0:
return edges
# We only want to dedupe edges that are between the same pair of nodes
# We build a map of the edges based on their source and target nodes.
edge_chunks = chunk_edges_by_nodes(edges)

results = await semaphore_gather(
*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks]
)

compressed_edges: list[EntityEdge] = []
for edge_chunk in results:
compressed_edges += edge_chunk

# Check if we have removed all duplicates
if len(compressed_edges) == len(edges):
return compressed_edges

return await compress_edges(llm_client, compressed_edges)
return edges_by_episode
def compress_uuid_map(uuid_map: dict[str, str]) -> dict[str, str]:
# make sure all uuid values aren't mapped to other uuids
compressed_map = {}
for key, uuid in uuid_map.items():
curr_value = uuid
while curr_value in uuid_map:
curr_value = uuid_map[curr_value]

compressed_map[key] = curr_value
def find_min_uuid(start: str) -> str:
path = []
visited = set()
curr = start

while curr in uuid_map and curr not in visited:
visited.add(curr)
path.append(curr)
curr = uuid_map[curr]

# Also include the last resolved value (could be outside the map)
path.append(curr)

# Resolve to lex smallest UUID in the path
min_uuid = min(path)

# Assign all UUIDs in the path to the min_uuid
for node in path:
compressed_map[node] = min_uuid

return min_uuid

for key in uuid_map:
if key not in compressed_map:
find_min_uuid(key)

return compressed_map
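A self-contained re-implementation, for illustration only, of the path compression that compress_uuid_map performs: every uuid in a duplicate chain ends up mapped to the lexicographically smallest uuid on its path.

def compress(uuid_map: dict[str, str]) -> dict[str, str]:
    # mirrors find_min_uuid above, using short letter keys in place of real uuids
    compressed: dict[str, str] = {}
    for start in uuid_map:
        if start in compressed:
            continue
        path, curr = [], start
        while curr in uuid_map and curr not in path:
            path.append(curr)
            curr = uuid_map[curr]
        path.append(curr)
        smallest = min(path)
        for uuid in path:
            compressed[uuid] = smallest
    return compressed

assert compress({'c': 'b', 'b': 'a', 'e': 'd'}) == {'c': 'a', 'b': 'a', 'a': 'a', 'e': 'd', 'd': 'd'}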
@@ -420,63 +436,3 @@ def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)

return edges


async def extract_edge_dates_bulk(
llm_client: LLMClient,
extracted_edges: list[EntityEdge],
episode_pairs: list[tuple[EpisodicNode, list[EpisodicNode]]],
) -> list[EntityEdge]:
edges: list[EntityEdge] = []
# confirm that all of our edges have at least one episode
for edge in extracted_edges:
if edge.episodes is not None and len(edge.episodes) > 0:
edges.append(edge)

episode_uuid_map: dict[str, tuple[EpisodicNode, list[EpisodicNode]]] = {
episode.uuid: (episode, previous_episodes) for episode, previous_episodes in episode_pairs
}

results = await semaphore_gather(
*[
extract_edge_dates(
llm_client,
edge,
episode_uuid_map[edge.episodes[0]][0],  # type: ignore
episode_uuid_map[edge.episodes[0]][1],  # type: ignore
)
for edge in edges
]
)

for i, result in enumerate(results):
valid_at = result[0]
invalid_at = result[1]
edge = edges[i]

edge.valid_at = valid_at
edge.invalid_at = invalid_at
if edge.invalid_at:
edge.expired_at = utc_now()

return edges


def chunk_edges_by_nodes(edges: list[EntityEdge]) -> list[list[EntityEdge]]:
# We only want to dedupe edges that are between the same pair of nodes
# We build a map of the edges based on their source and target nodes.
edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
for edge in edges:
# We drop loop edges
if edge.source_node_uuid == edge.target_node_uuid:
continue

# Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution
pointers = [edge.source_node_uuid, edge.target_node_uuid]
pointers.sort()

edge_chunk_map[pointers[0] + pointers[1]].append(edge)

edge_chunks = [chunk for chunk in edge_chunk_map.values()]

return edge_chunks
@@ -45,15 +45,15 @@ logger = logging.getLogger(__name__)

def build_episodic_edges(
entity_nodes: list[EntityNode],
episode: EpisodicNode,
episode_uuid: str,
created_at: datetime,
) -> list[EpisodicEdge]:
episodic_edges: list[EpisodicEdge] = [
EpisodicEdge(
source_node_uuid=episode.uuid,
source_node_uuid=episode_uuid,
target_node_uuid=node.uuid,
created_at=created_at,
group_id=episode.group_id,
group_id=node.group_id,
)
for node in entity_nodes
]
@@ -68,19 +68,23 @@ def build_duplicate_of_edges(
created_at: datetime,
duplicate_nodes: list[tuple[EntityNode, EntityNode]],
) -> list[EntityEdge]:
is_duplicate_of_edges: list[EntityEdge] = [
EntityEdge(
source_node_uuid=source_node.uuid,
target_node_uuid=target_node.uuid,
name='IS_DUPLICATE_OF',
group_id=episode.group_id,
fact=f'{source_node.name} is a duplicate of {target_node.name}',
episodes=[episode.uuid],
created_at=created_at,
valid_at=created_at,
is_duplicate_of_edges: list[EntityEdge] = []
for source_node, target_node in duplicate_nodes:
if source_node.uuid == target_node.uuid:
continue

is_duplicate_of_edges.append(
EntityEdge(
source_node_uuid=source_node.uuid,
target_node_uuid=target_node.uuid,
name='IS_DUPLICATE_OF',
group_id=episode.group_id,
fact=f'{source_node.name} is a duplicate of {target_node.name}',
episodes=[episode.uuid],
created_at=created_at,
valid_at=created_at,
)
)
for source_node, target_node in duplicate_nodes
]

return is_duplicate_of_edges
@@ -240,50 +244,6 @@ async def extract_edges(
return edges


async def dedupe_extracted_edges(
llm_client: LLMClient,
extracted_edges: list[EntityEdge],
existing_edges: list[EntityEdge],
) -> list[EntityEdge]:
# Create edge map
edge_map: dict[str, EntityEdge] = {}
for edge in existing_edges:
edge_map[edge.uuid] = edge

# Prepare context for LLM
context = {
'extracted_edges': [
{'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in extracted_edges
],
'existing_edges': [
{'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges
],
}

llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.edge(context))
duplicate_data = llm_response.get('duplicates', [])
logger.debug(f'Extracted unique edges: {duplicate_data}')

duplicate_uuid_map: dict[str, str] = {}
for duplicate in duplicate_data:
uuid_value = duplicate['duplicate_of']
duplicate_uuid_map[duplicate['uuid']] = uuid_value

# Get full edge data
edges: list[EntityEdge] = []
for edge in extracted_edges:
if edge.uuid in duplicate_uuid_map:
existing_uuid = duplicate_uuid_map[edge.uuid]
existing_edge = edge_map[existing_uuid]
# Add current episode to the episodes list
existing_edge.episodes += edge.episodes
edges.append(existing_edge)
else:
edges.append(edge)

return edges


async def resolve_extracted_edges(
clients: GraphitiClients,
extracted_edges: list[EntityEdge],
@@ -335,7 +295,7 @@ async def resolve_extracted_edges(
edge_types_lst.append(extracted_edge_types)

# resolve edges with related edges in the graph and find invalidation candidates
results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
results: list[tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]] = list(
await semaphore_gather(
*[
resolve_extracted_edge(
@@ -416,9 +376,9 @@ async def resolve_extracted_edge(
existing_edges: list[EntityEdge],
episode: EpisodicNode,
edge_types: dict[str, BaseModel] | None = None,
) -> tuple[EntityEdge, list[EntityEdge]]:
) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
if len(related_edges) == 0 and len(existing_edges) == 0:
return extracted_edge, []
return extracted_edge, [], []

start = time()
@@ -457,15 +417,16 @@ async def resolve_extracted_edge(
model_size=ModelSize.small,
)

duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1)

resolved_edge = (
related_edges[duplicate_fact_id]
if 0 <= duplicate_fact_id < len(related_edges)
else extracted_edge
duplicate_fact_ids: list[int] = list(
filter(lambda i: 0 <= i < len(related_edges), llm_response.get('duplicate_facts', []))
)

if duplicate_fact_id >= 0 and episode is not None:
resolved_edge = extracted_edge
for duplicate_fact_id in duplicate_fact_ids:
resolved_edge = related_edges[duplicate_fact_id]
break

if duplicate_fact_ids and episode is not None:
resolved_edge.episodes.append(episode.uuid)

contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])
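A plain-data sketch (not part of the diff) of how the filtered duplicate_facts ids are used above: the first in-range id selects an existing related edge as the resolved edge, otherwise the freshly extracted edge is kept. The facts and ids are invented for illustration.

related_facts = ['Alice is employed by Acme Corp', 'Alice works for Acme']
extracted_fact = 'Alice works at Acme'
duplicate_facts = [7, 1]  # raw ids returned by the LLM, possibly out of range

valid_ids = [i for i in duplicate_facts if 0 <= i < len(related_facts)]
resolved = related_facts[valid_ids[0]] if valid_ids else extracted_fact
assert resolved == 'Alice works for Acme'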
@@ -519,59 +480,12 @@ async def resolve_extracted_edge(
break

# Determine which contradictory edges need to be expired
invalidated_edges = resolve_edge_contradictions(resolved_edge, invalidation_candidates)

return resolved_edge, invalidated_edges


async def dedupe_extracted_edge(
llm_client: LLMClient,
extracted_edge: EntityEdge,
related_edges: list[EntityEdge],
episode: EpisodicNode | None = None,
) -> EntityEdge:
if len(related_edges) == 0:
return extracted_edge

start = time()

# Prepare context for LLM
related_edges_context = [
{'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
]

extracted_edge_context = {
'fact': extracted_edge.fact,
}

context = {
'related_edges': related_edges_context,
'extracted_edges': extracted_edge_context,
}

llm_response = await llm_client.generate_response(
prompt_library.dedupe_edges.edge(context),
response_model=EdgeDuplicate,
model_size=ModelSize.small,
invalidated_edges: list[EntityEdge] = resolve_edge_contradictions(
resolved_edge, invalidation_candidates
)
duplicate_edges: list[EntityEdge] = [related_edges[idx] for idx in duplicate_fact_ids]

duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1)

edge = (
related_edges[duplicate_fact_id]
if 0 <= duplicate_fact_id < len(related_edges)
else extracted_edge
)

if duplicate_fact_id >= 0 and episode is not None:
edge.episodes.append(episode.uuid)

end = time()
logger.debug(
f'Resolved Edge: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms'
)

return edge
return resolved_edge, invalidated_edges, duplicate_edges


async def dedupe_edge_list(
@@ -176,62 +176,13 @@ async def extract_nodes(
return extracted_nodes


async def dedupe_extracted_nodes(
llm_client: LLMClient,
extracted_nodes: list[EntityNode],
existing_nodes: list[EntityNode],
) -> tuple[list[EntityNode], dict[str, str]]:
start = time()

# build existing node map
node_map: dict[str, EntityNode] = {}
for node in existing_nodes:
node_map[node.uuid] = node

# Prepare context for LLM
existing_nodes_context = [
{'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in existing_nodes
]

extracted_nodes_context = [
{'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in extracted_nodes
]

context = {
'existing_nodes': existing_nodes_context,
'extracted_nodes': extracted_nodes_context,
}

llm_response = await llm_client.generate_response(prompt_library.dedupe_nodes.node(context))

duplicate_data = llm_response.get('duplicates', [])

end = time()
logger.debug(f'Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms')

uuid_map: dict[str, str] = {}
for duplicate in duplicate_data:
uuid_value = duplicate['duplicate_of']
uuid_map[duplicate['uuid']] = uuid_value

nodes: list[EntityNode] = []
for node in extracted_nodes:
if node.uuid in uuid_map:
existing_uuid = uuid_map[node.uuid]
existing_node = node_map[existing_uuid]
nodes.append(existing_node)
else:
nodes.append(node)

return nodes, uuid_map


async def resolve_extracted_nodes(
clients: GraphitiClients,
extracted_nodes: list[EntityNode],
episode: EpisodicNode | None = None,
previous_episodes: list[EpisodicNode] | None = None,
entity_types: dict[str, BaseModel] | None = None,
existing_nodes_override: list[EntityNode] | None = None,
) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
llm_client = clients.llm_client
driver = clients.driver
@@ -249,9 +200,13 @@ async def resolve_extracted_nodes(
]
)

existing_nodes_dict: dict[str, EntityNode] = {
node.uuid: node for result in search_results for node in result.nodes
}
candidate_nodes: list[EntityNode] = (
[node for result in search_results for node in result.nodes]
if existing_nodes_override is None
else existing_nodes_override
)

existing_nodes_dict: dict[str, EntityNode] = {node.uuid: node for node in candidate_nodes}

existing_nodes: list[EntityNode] = list(existing_nodes_dict.values())
@@ -321,13 +276,11 @@ async def resolve_extracted_nodes(
resolved_nodes.append(resolved_node)
uuid_map[extracted_node.uuid] = resolved_node.uuid

additional_duplicates: list[int] = resolution.get('additional_duplicates', [])
for idx in additional_duplicates:
duplicates: list[int] = resolution.get('duplicates', [])
for idx in duplicates:
existing_node = existing_nodes[idx] if idx < len(existing_nodes) else resolved_node
if existing_node == resolved_node:
continue

node_duplicates.append((resolved_node, existing_nodes[idx]))
node_duplicates.append((resolved_node, existing_node))

logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')