Create Bulk Add Episode for faster processing (#9)

* benchmark logging

* load schema updates

* add extract bulk nodes and edges

* updated bulk calls

* compression updates

* bulk updates

* bulk logic first pass

* updated bulk process

* debug

* remove exact names first

* cleaned up prompt

* fix bad merge

* update

* fix merge issues
Author: Preston Rasmussen, 2024-08-21 12:03:32 -04:00 (committed by GitHub)
parent a6fd0ddb75 · commit d6add504bd
11 changed files with 675 additions and 109 deletions
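
In short, this change adds a bulk ingestion path alongside the existing per-episode flow. A minimal sketch of how the new entry point is invoked, lifted from the runner changes at the end of this diff; the Neo4j credentials and the parsed podcast messages are assumed to come from the existing runner setup, and the call sits inside an async function as in the runner's main():

    from core import Graphiti
    from core.utils.bulk_utils import BulkEpisode

    # neo4j_uri / neo4j_user / neo4j_password and `messages` assumed from the runner
    client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
    episodes = [
        BulkEpisode(
            name=f"Message {i}",
            content=f"{message.speaker_name} ({message.role}): {message.content}",
            source_description="Podcast Transcript",
            episode_type="string",
            reference_time=message.actual_timestamp,
        )
        for i, message in enumerate(messages[3:7])
    ]
    await client.add_episode_bulk(episodes)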

View file

@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 from pydantic import BaseModel, Field
 from datetime import datetime
+from time import time
 from neo4j import AsyncDriver
 from uuid import uuid4
 import logging

@@ -76,10 +77,15 @@ class EntityEdge(Edge):
     )

     async def generate_embedding(self, embedder, model="text-embedding-3-small"):
+        start = time()
         text = self.fact.replace("\n", " ")
         embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
         self.fact_embedding = embedding[:EMBEDDING_DIM]
+        end = time()
+        logger.info(f"embedded {text} in {end-start} ms")
         return embedding

     async def save(self, driver: AsyncDriver):

@@ -105,6 +111,6 @@ class EntityEdge(Edge):
             invalid_at=self.invalid_at,
         )
-        logger.info(f"Saved Node to neo4j: {self.uuid}")
+        logger.info(f"Saved edge to neo4j: {self.uuid}")
         return result

View file

@@ -4,25 +4,32 @@ import logging
 from typing import Callable, LiteralString
 from neo4j import AsyncGraphDatabase
 from dotenv import load_dotenv
+from time import time
 import os

 from core.llm_client.config import EMBEDDING_DIM
 from core.nodes import EntityNode, EpisodicNode, Node
-from core.edges import EntityEdge, EpisodicEdge
+from core.edges import EntityEdge, Edge, EpisodicEdge
 from core.utils import (
     build_episodic_edges,
     retrieve_episodes,
 )
 from core.llm_client import LLMClient, OpenAIClient, LLMConfig
-from core.utils.maintenance.edge_operations import (
-    extract_edges,
-    dedupe_extracted_edges,
+from core.utils.bulk_utils import (
+    BulkEpisode,
+    extract_nodes_and_edges_bulk,
+    retrieve_previous_episodes_bulk,
+    compress_nodes,
+    dedupe_nodes_bulk,
+    resolve_edge_pointers,
+    dedupe_edges_bulk,
 )
+from core.utils.maintenance.edge_operations import extract_edges, dedupe_extracted_edges
+from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
 from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
 from core.utils.maintenance.temporal_operations import (
-    prepare_edges_for_invalidation,
     invalidate_edges,
+    prepare_edges_for_invalidation,
 )
 from core.utils.search.search_utils import (
     edge_similarity_search,
@@ -58,30 +65,47 @@ class Graphiti:
         self.driver.close()

     async def retrieve_episodes(
-        self, last_n: int, sources: list[str] | None = "messages"
+        self,
+        reference_time: datetime,
+        last_n: int,
+        sources: list[str] | None = "messages",
     ) -> list[EpisodicNode]:
         """Retrieve the last n episodic nodes from the graph"""
-        return await retrieve_episodes(self.driver, last_n, sources)
+        return await retrieve_episodes(self.driver, reference_time, last_n, sources)
+
+    # Invalidate edges that are no longer valid
+    async def invalidate_edges(
+        self,
+        episode: EpisodicNode,
+        new_nodes: list[EntityNode],
+        new_edges: list[EntityEdge],
+        relevant_schema: dict[str, any],
+        previous_episodes: list[EpisodicNode],
+    ): ...

     async def add_episode(
         self,
         name: str,
         episode_body: str,
         source_description: str,
-        reference_time: datetime = None,
+        reference_time: datetime,
         episode_type="string",
         success_callback: Callable | None = None,
         error_callback: Callable | None = None,
     ):
         """Process an episode and update the graph"""
         try:
+            start = time()
             nodes: list[EntityNode] = []
             entity_edges: list[EntityEdge] = []
             episodic_edges: list[EpisodicEdge] = []
             embedder = self.llm_client.client.embeddings
             now = datetime.now()

-            previous_episodes = await self.retrieve_episodes(last_n=3)
+            previous_episodes = await self.retrieve_episodes(
+                reference_time, last_n=EPISODE_WINDOW_LEN
+            )

             episode = EpisodicNode(
                 name=name,
                 labels=[],
@@ -105,7 +129,7 @@ class Graphiti:
             logger.info(
                 f"Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}"
             )
-            new_nodes = await dedupe_extracted_nodes(
+            new_nodes, _ = await dedupe_extracted_nodes(
                 self.llm_client, extracted_nodes, existing_nodes
             )
             logger.info(

@@ -151,8 +175,15 @@
             )
             logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}")
             entity_edges.extend(deduped_edges)
+            new_edges = await dedupe_extracted_edges(
+                self.llm_client, extracted_edges, existing_edges
+            )
+            logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in new_edges]}")
+            entity_edges.extend(new_edges)

             episodic_edges.extend(
                 build_episodic_edges(
                     # There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them

@@ -175,6 +206,9 @@
             await asyncio.gather(*[node.save(self.driver) for node in nodes])
             await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
             await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
+            end = time()
+            logger.info(f"Completed add_episode in {(end-start) * 1000} ms")
             # for node in nodes:
             #     if isinstance(node, EntityNode):
             #         await node.update_summary(self.driver)
@@ -190,36 +224,19 @@ class Graphiti:
         index_queries: list[LiteralString] = [
             "CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)",
             "CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)",
-            "CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.uuid)",
-            "CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[r:MENTIONS]-() ON (r.uuid)",
+            "CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)",
+            "CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)",
             "CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
             "CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
             "CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)",
             "CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)",
-            "CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.name)",
-            "CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.created_at)",
-            "CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.expired_at)",
-            "CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.valid_at)",
-            "CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.invalid_at)",
-        ]
-        # Add the range indices
-        for query in index_queries:
-            await self.driver.execute_query(query)
-        # Add the semantic indices
-        await self.driver.execute_query(
-            """
-            CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]
-            """
-        )
-        await self.driver.execute_query(
-            """
-            CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON EACH [r.name, r.fact]
-            """
-        )
-        await self.driver.execute_query(
+            "CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)",
+            "CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)",
+            "CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)",
+            "CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)",
+            "CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)",
+            "CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]",
+            "CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]",
             """
             CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
             FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding)

@@ -227,10 +244,7 @@ class Graphiti:
                 `vector.dimensions`: 1024,
                 `vector.similarity_function`: 'cosine'
             }}
-            """
-        )
-        await self.driver.execute_query(
+            """,
             """
             CREATE VECTOR INDEX name_embedding IF NOT EXISTS
             FOR (n:Entity) ON (n.name_embedding)

@@ -238,7 +252,19 @@ class Graphiti:
                 `vector.dimensions`: 1024,
                 `vector.similarity_function`: 'cosine'
             }}
+            """,
             """
+            CREATE CONSTRAINT entity_name IF NOT EXISTS
+            FOR (n:Entity) REQUIRE n.name IS UNIQUE
+            """,
+            """
+            CREATE CONSTRAINT edge_facts IF NOT EXISTS
+            FOR ()-[e:RELATES_TO]-() REQUIRE e.fact IS UNIQUE
+            """,
+        ]
+
+        await asyncio.gather(
+            *[self.driver.execute_query(query) for query in index_queries]
         )

     async def search(self, query: str) -> list[tuple[EntityNode, list[EntityEdge]]]:
@@ -267,3 +293,78 @@
         context = await bfs(node_ids, self.driver)
         return context

+    async def add_episode_bulk(
+        self,
+        bulk_episodes: list[BulkEpisode],
+    ):
+        try:
+            start = time()
+            embedder = self.llm_client.client.embeddings
+            now = datetime.now()
+
+            episodes = [
+                EpisodicNode(
+                    name=episode.name,
+                    labels=[],
+                    source="messages",
+                    content=episode.content,
+                    source_description=episode.source_description,
+                    created_at=now,
+                    valid_at=episode.reference_time,
+                )
+                for episode in bulk_episodes
+            ]
+
+            # Save all the episodes
+            await asyncio.gather(*[episode.save(self.driver) for episode in episodes])
+
+            # Get previous episode context for each episode
+            episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
+
+            # Extract all nodes and edges
+            extracted_nodes, extracted_edges, episodic_edges = (
+                await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs)
+            )
+
+            # Generate embeddings
+            await asyncio.gather(
+                *[node.generate_name_embedding(embedder) for node in extracted_nodes],
+                *[edge.generate_embedding(embedder) for edge in extracted_edges],
+            )
+
+            # Dedupe extracted nodes
+            nodes, uuid_map = await dedupe_nodes_bulk(
+                self.driver, self.llm_client, extracted_nodes
+            )
+
+            # save nodes to KG
+            await asyncio.gather(*[node.save(self.driver) for node in nodes])
+
+            # re-map edge pointers so that they don't point to discarded dupe nodes
+            extracted_edges: list[EntityEdge] = resolve_edge_pointers(
+                extracted_edges, uuid_map
+            )
+            episodic_edges: list[EpisodicEdge] = resolve_edge_pointers(
+                episodic_edges, uuid_map
+            )
+
+            # save episodic edges to KG
+            await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
+
+            # Dedupe extracted edges
+            edges = await dedupe_edges_bulk(
+                self.driver, self.llm_client, extracted_edges
+            )
+            logger.info(f"extracted edge length: {len(edges)}")
+
+            # invalidate edges
+
+            # save edges to KG
+            await asyncio.gather(*[edge.save(self.driver) for edge in edges])
+
+            end = time()
+            logger.info(f"Completed add_episode_bulk in {(end-start) * 1000} ms")
+
+        except Exception as e:
+            raise e

View file

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from pydantic import Field
+from time import time
 from datetime import datetime
 from uuid import uuid4

@@ -35,14 +35,13 @@ class EpisodicNode(Node):
     source: str = Field(description="source type")
     source_description: str = Field(description="description of the data source")
     content: str = Field(description="raw episode data")
-    valid_at: datetime = Field(
-        description="datetime of when the original document was created",
-    )
     entity_edges: list[str] = Field(
         description="list of entity edges referenced in this episode",
         default_factory=list,
     )
+    valid_at: datetime | None = Field(
+        description="datetime of when the original document was created",
+        default=None,
+    )

     async def save(self, driver: AsyncDriver):
         result = await driver.execute_query(

@@ -80,9 +79,12 @@ class EntityNode(Node):
     async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...

     async def generate_name_embedding(self, embedder, model="text-embedding-3-small"):
+        start = time()
         text = self.name.replace("\n", " ")
         embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
         self.name_embedding = embedding[:EMBEDDING_DIM]
+        end = time()
+        logger.info(f"embedded {text} in {end-start} ms")
         return embedding

View file

@@ -6,10 +6,12 @@ from .models import Message, PromptVersion, PromptFunction

 class Prompt(Protocol):
     v1: PromptVersion
+    edge_list: PromptVersion


 class Versions(TypedDict):
     v1: PromptFunction
+    edge_list: PromptFunction


 def v1(context: dict[str, any]) -> list[Message]:

@@ -43,7 +45,6 @@ def v1(context: dict[str, any]) -> list[Message]:
         {{
             "new_edges": [
                 {{
-                    "name": "Unique identifier for the edge",
                     "fact": "one sentence description of the fact"
                 }}
             ]

@@ -53,4 +54,40 @@
         ]

-versions: Versions = {"v1": v1}
+def edge_list(context: dict[str, any]) -> list[Message]:
+    return [
+        Message(
+            role="system",
+            content="You are a helpful assistant that de-duplicates edges from edge lists.",
+        ),
+        Message(
+            role="user",
+            content=f"""
+        Given the following context, find all of the duplicates in a list of edges:
+
+        Edges:
+        {json.dumps(context['edges'], indent=2)}
+
+        Task:
+        If any edge in Edges is a duplicate of another edge, return the fact of only one of the duplicate edges
+
+        Guidelines:
+        1. Use both the name and fact of edges to determine if they are duplicates,
+            edges with the same name may not be duplicates
+        2. The final list should have only unique facts. If 3 edges are all duplicates of each other, only one of their
+            facts should be in the response
+
+        Respond with a JSON object in the following format:
+        {{
+            "unique_edges": [
+                {{
+                    "fact": "fact of a unique edge",
+                }}
+            ]
+        }}
+        """,
+        ),
+    ]
+
+
+versions: Versions = {"v1": v1, "edge_list": edge_list}

View file

@@ -6,10 +6,14 @@ from .models import Message, PromptVersion, PromptFunction

 class Prompt(Protocol):
     v1: PromptVersion
+    v2: PromptVersion
+    node_list: PromptVersion


 class Versions(TypedDict):
     v1: PromptFunction
+    v2: PromptFunction
+    node_list: PromptVersion


 def v1(context: dict[str, any]) -> list[Message]:

@@ -44,7 +48,6 @@ def v1(context: dict[str, any]) -> list[Message]:
            "new_nodes": [
                {{
                    "name": "Unique identifier for the node",
-                    "summary": "Brief summary of the node's role or significance"
                }}
            ]
        }}

@@ -53,4 +56,79 @@ def v1(context: dict[str, any]) -> list[Message]:
        ]

-versions: Versions = {"v1": v1}
+def v2(context: dict[str, any]) -> list[Message]:
+    return [
+        Message(
+            role="system",
+            content="You are a helpful assistant that de-duplicates nodes from node lists.",
+        ),
+        Message(
+            role="user",
+            content=f"""
+        Given the following context, deduplicate nodes from a list of new nodes given a list of existing nodes:
+
+        Existing Nodes:
+        {json.dumps(context['existing_nodes'], indent=2)}
+
+        New Nodes:
+        {json.dumps(context['extracted_nodes'], indent=2)}
+
+        Task:
+        If any node in New Nodes is a duplicate of a node in Existing Nodes, add their names to the output list
+
+        Guidelines:
+        1. Use both the name and summary of nodes to determine if they are duplicates,
+            duplicate nodes may have different names
+        2. In the output, name should always be the name of the New Node that is a duplicate. duplicate_of should be
+            the name of the Existing Node.
+
+        Respond with a JSON object in the following format:
+        {{
+            "duplicates": [
+                {{
+                    "name": "name of the new node",
+                    "duplicate_of": "name of the existing node"
+                }}
+            ]
+        }}
+        """,
+        ),
+    ]
+
+
+def node_list(context: dict[str, any]) -> list[Message]:
+    return [
+        Message(
+            role="system",
+            content="You are a helpful assistant that de-duplicates nodes from node lists.",
+        ),
+        Message(
+            role="user",
+            content=f"""
+        Given the following context, deduplicate a list of nodes:
+
+        Nodes:
+        {json.dumps(context['nodes'], indent=2)}
+
+        Task:
+        1. Group nodes together such that all duplicate nodes are in the same list of names
+        2. All duplicate names should be grouped together in the same list
+
+        Guidelines:
+        1. Each name from the list of nodes should appear EXACTLY once in your response
+        2. If a node has no duplicates, it should appear in the response in a list of only one name
+
+        Respond with a JSON object in the following format:
+        {{
+            "nodes": [
+                {{
+                    "names": ["myNode", "node that is a duplicate of myNode"],
+                }}
+            ]
+        }}
+        """,
+        ),
+    ]
+
+
+versions: Versions = {"v1": v1, "v2": v2, "node_list": node_list}

core/utils/bulk_utils.py (new file, 206 lines)
View file

@@ -0,0 +1,206 @@
import asyncio
from collections import defaultdict
from datetime import datetime

from neo4j import AsyncDriver
from pydantic import BaseModel

from core.edges import EpisodicEdge, EntityEdge, Edge
from core.llm_client import LLMClient
from core.nodes import EpisodicNode, EntityNode
from core.utils import retrieve_episodes
from core.utils.maintenance.edge_operations import (
    extract_edges,
    build_episodic_edges,
    dedupe_edge_list,
    dedupe_extracted_edges,
)
from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
from core.utils.maintenance.node_operations import (
    extract_nodes,
    dedupe_node_list,
    dedupe_extracted_nodes,
)
from core.utils.search.search_utils import get_relevant_nodes, get_relevant_edges

CHUNK_SIZE = 10


class BulkEpisode(BaseModel):
    name: str
    content: str
    source_description: str
    episode_type: str
    reference_time: datetime


async def retrieve_previous_episodes_bulk(
    driver: AsyncDriver, episodes: list[EpisodicNode]
) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
    previous_episodes_list = await asyncio.gather(
        *[
            retrieve_episodes(driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN)
            for episode in episodes
        ]
    )
    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] = [
        (episode, previous_episodes_list[i]) for i, episode in enumerate(episodes)
    ]

    return episode_tuples


async def extract_nodes_and_edges_bulk(
    llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
    extracted_nodes_bulk = await asyncio.gather(
        *[
            extract_nodes(llm_client, episode, previous_episodes)
            for episode, previous_episodes in episode_tuples
        ]
    )

    episodes, previous_episodes_list = [episode[0] for episode in episode_tuples], [
        episode[1] for episode in episode_tuples
    ]

    extracted_edges_bulk = await asyncio.gather(
        *[
            extract_edges(
                llm_client, episode, extracted_nodes_bulk[i], previous_episodes_list[i]
            )
            for i, episode in enumerate(episodes)
        ]
    )

    episodic_edges: list[EpisodicEdge] = []
    for i, episode in enumerate(episodes):
        episodic_edges += build_episodic_edges(
            extracted_nodes_bulk[i], episode, episode.created_at
        )

    nodes: list[EntityNode] = []
    for extracted_nodes in extracted_nodes_bulk:
        nodes += extracted_nodes

    edges: list[EntityEdge] = []
    for extracted_edges in extracted_edges_bulk:
        edges += extracted_edges

    return nodes, edges, episodic_edges


async def dedupe_nodes_bulk(
    driver: AsyncDriver,
    llm_client: LLMClient,
    extracted_nodes: list[EntityNode],
) -> tuple[list[EntityNode], dict[str, str]]:
    # Compress nodes
    nodes, uuid_map = node_name_match(extracted_nodes)

    compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map)

    existing_nodes = await get_relevant_nodes(compressed_nodes, driver)

    nodes, partial_uuid_map = await dedupe_extracted_nodes(
        llm_client, compressed_nodes, existing_nodes
    )

    compressed_map.update(partial_uuid_map)

    return nodes, compressed_map


async def dedupe_edges_bulk(
    driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
) -> list[EntityEdge]:
    # Compress edges
    compressed_edges = await compress_edges(llm_client, extracted_edges)

    existing_edges = await get_relevant_edges(compressed_edges, driver)

    edges = await dedupe_extracted_edges(llm_client, compressed_edges, existing_edges)

    return edges


def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str, str]]:
    uuid_map = {}
    name_map = {}
    for node in nodes:
        if node.name in name_map:
            uuid_map[node.uuid] = name_map[node.name].uuid
            continue

        name_map[node.name] = node

    return [node for node in name_map.values()], uuid_map


async def compress_nodes(
    llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
) -> tuple[list[EntityNode], dict[str, str]]:
    node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]

    results = await asyncio.gather(
        *[dedupe_node_list(llm_client, chunk) for chunk in node_chunks]
    )

    extended_map = dict(uuid_map)
    compressed_nodes: list[EntityNode] = []
    for node_chunk, uuid_map_chunk in results:
        compressed_nodes += node_chunk
        extended_map.update(uuid_map_chunk)

    # Check if we have removed all duplicates
    if len(compressed_nodes) == len(nodes):
        compressed_uuid_map = compress_uuid_map(extended_map)
        return compressed_nodes, compressed_uuid_map

    return await compress_nodes(llm_client, compressed_nodes, extended_map)


async def compress_edges(
    llm_client: LLMClient, edges: list[EntityEdge]
) -> list[EntityEdge]:
    edge_chunks = [edges[i : i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)]

    results = await asyncio.gather(
        *[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks]
    )

    compressed_edges: list[EntityEdge] = []
    for edge_chunk in results:
        compressed_edges += edge_chunk

    # Check if we have removed all duplicates
    if len(compressed_edges) == len(edges):
        return compressed_edges

    return await compress_edges(llm_client, compressed_edges)


def compress_uuid_map(uuid_map: dict[str, str]) -> dict[str, str]:
    # make sure all uuid values aren't mapped to other uuids
    compressed_map = {}
    for key, uuid in uuid_map.items():
        curr_value = uuid
        while curr_value in uuid_map.keys():
            curr_value = uuid_map[curr_value]

        compressed_map[key] = curr_value
    return compressed_map


def resolve_edge_pointers(edges: list[Edge], uuid_map: dict[str, str]):
    for edge in edges:
        source_uuid = edge.source_node_uuid
        target_uuid = edge.target_node_uuid
        edge.source_node_uuid = (
            uuid_map[source_uuid] if source_uuid in uuid_map else source_uuid
        )
        edge.target_node_uuid = (
            uuid_map[target_uuid] if target_uuid in uuid_map else target_uuid
        )

    return edges
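
A hand-worked example of the last two helpers above, with invented uuid strings: compress_uuid_map flattens the chains produced by successive dedupe passes, and resolve_edge_pointers then re-points edges at the surviving node uuids. A duck-typed stand-in edge is used here because only the two uuid attributes are touched:

    from core.utils.bulk_utils import compress_uuid_map, resolve_edge_pointers

    class _StubEdge:  # stand-in: resolve_edge_pointers only reads/writes these two attributes
        def __init__(self, source_node_uuid, target_node_uuid):
            self.source_node_uuid = source_node_uuid
            self.target_node_uuid = target_node_uuid

    uuid_map = {"n2": "n1", "n3": "n2", "n5": "n4"}  # n3 -> n2 -> n1 is a chain
    flat = compress_uuid_map(uuid_map)
    assert flat == {"n2": "n1", "n3": "n1", "n5": "n4"}

    edge = _StubEdge("n3", "n6")
    resolve_edge_pointers([edge], flat)
    assert (edge.source_node_uuid, edge.target_node_uuid) == ("n1", "n6")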

View file

@@ -1,6 +1,7 @@
 import json
 from typing import List
 from datetime import datetime
+from time import time

 from pydantic import BaseModel

@@ -17,7 +18,7 @@ logger = logging.getLogger(__name__)
 def build_episodic_edges(
     entity_nodes: List[EntityNode],
     episode: EpisodicNode,
-    transaction_from: datetime,
+    created_at: datetime,
 ) -> List[EpisodicEdge]:
     edges: List[EpisodicEdge] = []

@@ -25,7 +26,7 @@ def build_episodic_edges(
         edge = EpisodicEdge(
             source_node_uuid=episode.uuid,
             target_node_uuid=node.uuid,
-            created_at=transaction_from,
+            created_at=created_at,
         )
         edges.append(edge)

@@ -144,6 +145,8 @@ async def extract_edges(
     nodes: list[EntityNode],
     previous_episodes: list[EpisodicNode],
 ) -> list[EntityEdge]:
+    start = time()
     # Prepare context for LLM
     context = {
         "episode_content": episode.content,

@@ -167,7 +170,9 @@
         prompt_library.extract_edges.v2(context)
     )
     edges_data = llm_response.get("edges", [])
-    logger.info(f"Extracted new edges: {edges_data}")
+    end = time()
+    logger.info(f"Extracted new edges: {edges_data} in {(end - start) * 1000} ms")

     # Convert the extracted data into EntityEdge objects
     edges = []

@@ -199,11 +204,11 @@ async def dedupe_extracted_edges(
     # Create edge map
     edge_map = {}
     for edge in existing_edges:
-        edge_map[edge.name] = edge
+        edge_map[edge.fact] = edge
     for edge in extracted_edges:
-        if edge.name in edge_map.keys():
+        if edge.fact in edge_map.keys():
             continue
-        edge_map[edge.name] = edge
+        edge_map[edge.fact] = edge

     # Prepare context for LLM
     context = {

@@ -224,7 +229,40 @@
     # Get full edge data
     edges = []
     for edge_data in new_edges_data:
-        edge = edge_map[edge_data["name"]]
+        edge = edge_map[edge_data["fact"]]
         edges.append(edge)

     return edges


+async def dedupe_edge_list(
+    llm_client: LLMClient,
+    edges: list[EntityEdge],
+) -> list[EntityEdge]:
+    start = time()
+
+    # Create edge map
+    edge_map = {}
+    for edge in edges:
+        edge_map[edge.fact] = edge
+
+    # Prepare context for LLM
+    context = {"edges": [{"name": edge.name, "fact": edge.fact} for edge in edges]}
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.dedupe_edges.edge_list(context)
+    )
+    unique_edges_data = llm_response.get("unique_edges", [])
+
+    end = time()
+    logger.info(
+        f"Extracted edge duplicates: {unique_edges_data} in {(end - start)*1000} ms "
+    )
+
+    # Get full edge data
+    unique_edges = []
+    for edge_data in unique_edges_data:
+        fact = edge_data["fact"]
+        unique_edges.append(edge_map[fact])
+
+    return unique_edges

View file

@@ -4,6 +4,7 @@ from core.nodes import EpisodicNode
 from neo4j import AsyncDriver
 import logging

+EPISODE_WINDOW_LEN = 3

 logger = logging.getLogger(__name__)

@@ -18,11 +19,15 @@ async def clear_data(driver: AsyncDriver):
 async def retrieve_episodes(
-    driver: AsyncDriver, last_n: int, sources: list[str] | None = "messages"
+    driver: AsyncDriver,
+    reference_time: datetime,
+    last_n: int,
+    sources: list[str] | None = "messages",
 ) -> list[EpisodicNode]:
     """Retrieve the last n episodic nodes from the graph"""
-    query = """
-        MATCH (e:Episodic)
+    result = await driver.execute_query(
+        """
+        MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
         RETURN e.content as content,
             e.created_at as created_at,
             e.valid_at as valid_at,

@@ -32,8 +37,10 @@
             e.source as source
         ORDER BY e.created_at DESC
         LIMIT $num_episodes
-        """
-    result = await driver.execute_query(query, num_episodes=last_n)
+        """,
+        reference_time=reference_time,
+        num_episodes=last_n,
+    )
     episodes = [
         EpisodicNode(
             content=record["content"],

View file

@@ -1,4 +1,5 @@
 from datetime import datetime
+from time import time

 from core.nodes import EntityNode, EpisodicNode
 import logging

@@ -68,6 +69,8 @@ async def extract_nodes(
     episode: EpisodicNode,
     previous_episodes: list[EpisodicNode],
 ) -> list[EntityNode]:
+    start = time()
     # Prepare context for LLM
     context = {
         "episode_content": episode.content,

@@ -87,7 +90,9 @@
         prompt_library.extract_nodes.v3(context)
     )
     new_nodes_data = llm_response.get("new_nodes", [])
-    logger.info(f"Extracted new nodes: {new_nodes_data}")
+    end = time()
+    logger.info(f"Extracted new nodes: {new_nodes_data} in {(end - start) * 1000} ms")

     # Convert the extracted data into EntityNode objects
     new_nodes = []
     for node_data in new_nodes_data:

@@ -107,15 +112,13 @@ async def dedupe_extracted_nodes(
     llm_client: LLMClient,
     extracted_nodes: list[EntityNode],
     existing_nodes: list[EntityNode],
-) -> list[EntityNode]:
-    # build node map
+) -> tuple[list[EntityNode], dict[str, str]]:
+    start = time()
+
+    # build existing node map
     node_map = {}
     for node in existing_nodes:
         node_map[node.name] = node
-    for node in extracted_nodes:
-        if node.name in node_map.keys():
-            continue
-        node_map[node.name] = node

     # Prepare context for LLM
     existing_nodes_context = [

@@ -132,16 +135,69 @@
     }

     llm_response = await llm_client.generate_response(
-        prompt_library.dedupe_nodes.v1(context)
+        prompt_library.dedupe_nodes.v2(context)
     )

-    new_nodes_data = llm_response.get("new_nodes", [])
-    logger.info(f"Deduplicated nodes: {new_nodes_data}")
+    duplicate_data = llm_response.get("duplicates", [])
+
+    end = time()
+    logger.info(f"Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms")
+
+    uuid_map = {}
+    for duplicate in duplicate_data:
+        uuid = node_map[duplicate["name"]].uuid
+        uuid_value = node_map[duplicate["duplicate_of"]].uuid
+        uuid_map[uuid] = uuid_value

-    # Get full node data
     nodes = []
-    for node_data in new_nodes_data:
-        node = node_map[node_data["name"]]
+    for node in extracted_nodes:
+        if node.uuid in uuid_map:
+            existing_name = uuid_map[node.name]
+            existing_node = node_map[existing_name]
+            nodes.append(existing_node)
+            continue
         nodes.append(node)

-    return nodes
+    return nodes, uuid_map
+
+
+async def dedupe_node_list(
+    llm_client: LLMClient,
+    nodes: list[EntityNode],
+) -> tuple[list[EntityNode], dict[str, str]]:
+    start = time()
+
+    # build node map
+    node_map = {}
+    for node in nodes:
+        node_map[node.name] = node
+
+    # Prepare context for LLM
+    nodes_context = [{"name": node.name, "summary": node.summary} for node in nodes]
+
+    context = {
+        "nodes": nodes_context,
+    }
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.dedupe_nodes.node_list(context)
+    )
+
+    nodes_data = llm_response.get("nodes", [])
+
+    end = time()
+    logger.info(f"Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms")
+
+    # Get full node data
+    unique_nodes = []
+    uuid_map: dict[str, str] = {}
+    for node_data in nodes_data:
+        node = node_map[node_data["names"][0]]
+        unique_nodes.append(node)
+
+        for name in node_data["names"][1:]:
+            uuid = node_map[name].uuid
+            uuid_value = node_map[node_data["names"][0]].uuid
+            uuid_map[uuid] = uuid_value
+
+    return unique_nodes, uuid_map

View file

@@ -1,6 +1,7 @@
 import asyncio
 import logging
 from datetime import datetime
+from time import time

 from neo4j import AsyncDriver

@@ -9,6 +10,8 @@ from core.nodes import EntityNode

 logger = logging.getLogger(__name__)

+RELEVANT_SCHEMA_LIMIT = 3
+

 async def bfs(node_ids: list[str], driver: AsyncDriver):
     records, _, _ = await driver.execute_query(

@@ -60,7 +63,7 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):
 async def edge_similarity_search(
-    search_vector: list[float], driver: AsyncDriver
+    search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
 ) -> list[EntityEdge]:
     # vector similarity search over embedded facts
     records, _, _ = await driver.execute_query(

@@ -80,9 +83,10 @@
             r.expired_at AS expired_at,
             r.valid_at AS valid_at,
             r.invalid_at AS invalid_at
-        ORDER BY score DESC LIMIT 10
+        ORDER BY score DESC LIMIT $limit
         """,
         search_vector=search_vector,
+        limit=limit,
     )

     edges: list[EntityEdge] = []

@@ -106,18 +110,16 @@
         edges.append(edge)

-    logger.info(f"similarity search results. RESULT: {[edge.uuid for edge in edges]}")
-
     return edges


 async def entity_similarity_search(
-    search_vector: list[float], driver: AsyncDriver
+    search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
 ) -> list[EntityNode]:
     # vector similarity search over entity names
     records, _, _ = await driver.execute_query(
         """
-        CALL db.index.vector.queryNodes("name_embedding", 5, $search_vector)
+        CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
         YIELD node AS n, score
         RETURN
             n.uuid As uuid,

@@ -127,6 +129,7 @@
         ORDER BY score DESC
         """,
         search_vector=search_vector,
+        limit=limit,
     )

     nodes: list[EntityNode] = []

@@ -141,12 +144,12 @@
             )
         )

-    logger.info(f"name semantic search results. RESULT: {nodes}")
-
     return nodes


-async def entity_fulltext_search(query: str, driver: AsyncDriver) -> list[EntityNode]:
+async def entity_fulltext_search(
+    query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
+) -> list[EntityNode]:
     # BM25 search to get top nodes
     fuzzy_query = query + "~"
     records, _, _ = await driver.execute_query(

@@ -158,9 +161,10 @@
             node.created_at AS created_at,
             node.summary AS summary
         ORDER BY score DESC
-        LIMIT 10
+        LIMIT $limit
         """,
         query=fuzzy_query,
+        limit=limit,
     )
     nodes: list[EntityNode] = []

@@ -175,12 +179,12 @@
             )
         )

-    logger.info(f"fulltext search results. QUERY:{query}. RESULT: {nodes}")
-
     return nodes


-async def edge_fulltext_search(query: str, driver: AsyncDriver) -> list[EntityEdge]:
+async def edge_fulltext_search(
+    query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
+) -> list[EntityEdge]:
     # fulltext search over facts
     fuzzy_query = query + "~"

@@ -201,9 +205,10 @@
             r.expired_at AS expired_at,
             r.valid_at AS valid_at,
             r.invalid_at AS invalid_at
-        ORDER BY score DESC LIMIT 10
+        ORDER BY score DESC LIMIT $limit
         """,
         query=fuzzy_query,
+        limit=limit,
     )

     edges: list[EntityEdge] = []

@@ -227,10 +232,6 @@
         edges.append(edge)

-    logger.info(
-        f"similarity search results. QUERY:{query}. RESULT: {[edge.uuid for edge in edges]}"
-    )
-
     return edges


@@ -238,7 +239,9 @@ async def get_relevant_nodes(
     nodes: list[EntityNode],
     driver: AsyncDriver,
 ) -> list[EntityNode]:
-    relevant_nodes: dict[str, EntityNode] = {}
+    start = time()
+    relevant_nodes: list[EntityNode] = []
+    relevant_node_uuids = set()

     results = await asyncio.gather(
         *[entity_fulltext_search(node.name, driver) for node in nodes],

@@ -247,18 +250,27 @@
     for result in results:
         for node in result:
-            relevant_nodes[node.uuid] = node
-
-    logger.info(f"Found relevant nodes: {relevant_nodes.keys()}")
-
-    return relevant_nodes.values()
+            if node.uuid in relevant_node_uuids:
+                continue
+
+            relevant_node_uuids.add(node.uuid)
+            relevant_nodes.append(node)
+
+    end = time()
+    logger.info(
+        f"Found relevant nodes: {relevant_node_uuids} in {(end - start) * 1000} ms"
+    )
+
+    return relevant_nodes


 async def get_relevant_edges(
     edges: list[EntityEdge],
     driver: AsyncDriver,
 ) -> list[EntityEdge]:
-    relevant_edges: dict[str, EntityEdge] = {}
+    start = time()
+    relevant_edges: list[EntityEdge] = []
+    relevant_edge_uuids = set()

     results = await asyncio.gather(
         *[edge_similarity_search(edge.fact_embedding, driver) for edge in edges],

@@ -267,8 +279,15 @@
     for result in results:
         for edge in result:
-            relevant_edges[edge.uuid] = edge
-
-    logger.info(f"Found relevant nodes: {relevant_edges.keys()}")
-
-    return list(relevant_edges.values())
+            if edge.uuid in relevant_edge_uuids:
+                continue
+
+            relevant_edge_uuids.add(edge.uuid)
+            relevant_edges.append(edge)
+
+    end = time()
+    logger.info(
+        f"Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms"
+    )
+
+    return relevant_edges

View file

@@ -1,4 +1,5 @@
 from core import Graphiti
+from core.utils.bulk_utils import BulkEpisode
 from core.utils.maintenance.graph_data_operations import clear_data
 from dotenv import load_dotenv
 import os

@@ -37,18 +38,33 @@ def setup_logging():
     return logger


-async def main():
+async def main(use_bulk: bool = True):
     setup_logging()
     client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
     await clear_data(client.driver)
     messages = parse_podcast_messages()

-    for i, message in enumerate(messages[3:50]):
-        await client.add_episode(
+    if not use_bulk:
+        for i, message in enumerate(messages[3:14]):
+            await client.add_episode(
+                name=f"Message {i}",
+                episode_body=f"{message.speaker_name} ({message.role}): {message.content}",
+                reference_time=message.actual_timestamp,
+                source_description="Podcast Transcript",
+            )
+
+    episodes: list[BulkEpisode] = [
+        BulkEpisode(
             name=f"Message {i}",
-            episode_body=f"{message.speaker_name} ({message.role}): {message.content}",
-            reference_time=message.actual_timestamp,
+            content=f"{message.speaker_name} ({message.role}): {message.content}",
             source_description="Podcast Transcript",
+            episode_type="string",
+            reference_time=message.actual_timestamp,
         )
+        for i, message in enumerate(messages[3:7])
+    ]
+
+    await client.add_episode_bulk(episodes)


 asyncio.run(main())