Create Bulk Add Episode for faster processing (#9)

* benchmark logging

* load schema updates

* add extract bulk nodes and edges

* updated bulk calls

* compression updates

* bulk updates

* bulk logic first pass

* updated bulk process

* debug

* remove exact names first

* cleaned up prompt

* fix bad merge

* update

* fix merge issues
Preston Rasmussen 2024-08-21 12:03:32 -04:00 committed by GitHub
parent a6fd0ddb75
commit d6add504bd
11 changed files with 675 additions and 109 deletions


@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel, Field
from datetime import datetime
from time import time
from neo4j import AsyncDriver
from uuid import uuid4
import logging
@ -76,10 +77,15 @@ class EntityEdge(Edge):
)
async def generate_embedding(self, embedder, model="text-embedding-3-small"):
start = time()
text = self.fact.replace("\n", " ")
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
self.fact_embedding = embedding[:EMBEDDING_DIM]
end = time()
logger.info(f"embedded {text} in {end-start} ms")
return embedding
async def save(self, driver: AsyncDriver):
@ -105,6 +111,6 @@ class EntityEdge(Edge):
invalid_at=self.invalid_at,
)
logger.info(f"Saved Node to neo4j: {self.uuid}")
logger.info(f"Saved edge to neo4j: {self.uuid}")
return result
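The timing added to generate_embedding matters mostly in the bulk path, where every extracted edge is embedded concurrently. A minimal sketch of that pattern, assuming the OpenAI async client that OpenAIClient wraps; embed_edges is a hypothetical helper, not part of this commit:

import asyncio
from openai import AsyncOpenAI  # assumed client; the commit only ever passes an `embedder`

async def embed_edges(edges):  # hypothetical helper for illustration
    embedder = AsyncOpenAI().embeddings
    # each call stores a fact embedding truncated to EMBEDDING_DIM on its edge
    await asyncio.gather(*[edge.generate_embedding(embedder) for edge in edges])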


@ -4,25 +4,32 @@ import logging
from typing import Callable, LiteralString
from neo4j import AsyncGraphDatabase
from dotenv import load_dotenv
from time import time
import os
from core.llm_client.config import EMBEDDING_DIM
from core.nodes import EntityNode, EpisodicNode, Node
from core.edges import EntityEdge, EpisodicEdge
from core.edges import EntityEdge, Edge, EpisodicEdge
from core.utils import (
build_episodic_edges,
retrieve_episodes,
)
from core.llm_client import LLMClient, OpenAIClient, LLMConfig
from core.utils.maintenance.edge_operations import (
extract_edges,
dedupe_extracted_edges,
from core.utils.bulk_utils import (
BulkEpisode,
extract_nodes_and_edges_bulk,
retrieve_previous_episodes_bulk,
compress_nodes,
dedupe_nodes_bulk,
resolve_edge_pointers,
dedupe_edges_bulk,
)
from core.utils.maintenance.edge_operations import extract_edges, dedupe_extracted_edges
from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
from core.utils.maintenance.temporal_operations import (
prepare_edges_for_invalidation,
invalidate_edges,
prepare_edges_for_invalidation,
)
from core.utils.search.search_utils import (
edge_similarity_search,
@ -58,30 +65,47 @@ class Graphiti:
self.driver.close()
async def retrieve_episodes(
self, last_n: int, sources: list[str] | None = "messages"
self,
reference_time: datetime,
last_n: int,
sources: list[str] | None = "messages",
) -> list[EpisodicNode]:
"""Retrieve the last n episodic nodes from the graph"""
return await retrieve_episodes(self.driver, last_n, sources)
return await retrieve_episodes(self.driver, reference_time, last_n, sources)
# Invalidate edges that are no longer valid
async def invalidate_edges(
self,
episode: EpisodicNode,
new_nodes: list[EntityNode],
new_edges: list[EntityEdge],
relevant_schema: dict[str, any],
previous_episodes: list[EpisodicNode],
): ...
async def add_episode(
self,
name: str,
episode_body: str,
source_description: str,
reference_time: datetime = None,
reference_time: datetime,
episode_type="string",
success_callback: Callable | None = None,
error_callback: Callable | None = None,
):
"""Process an episode and update the graph"""
try:
start = time()
nodes: list[EntityNode] = []
entity_edges: list[EntityEdge] = []
episodic_edges: list[EpisodicEdge] = []
embedder = self.llm_client.client.embeddings
now = datetime.now()
previous_episodes = await self.retrieve_episodes(last_n=3)
previous_episodes = await self.retrieve_episodes(
reference_time, last_n=EPISODE_WINDOW_LEN
)
episode = EpisodicNode(
name=name,
labels=[],
@ -105,7 +129,7 @@ class Graphiti:
logger.info(
f"Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}"
)
new_nodes = await dedupe_extracted_nodes(
new_nodes, _ = await dedupe_extracted_nodes(
self.llm_client, extracted_nodes, existing_nodes
)
logger.info(
@ -151,8 +175,15 @@ class Graphiti:
)
logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}")
entity_edges.extend(deduped_edges)
new_edges = await dedupe_extracted_edges(
self.llm_client, extracted_edges, existing_edges
)
logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in new_edges]}")
entity_edges.extend(new_edges)
episodic_edges.extend(
build_episodic_edges(
# There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
@ -175,6 +206,9 @@ class Graphiti:
await asyncio.gather(*[node.save(self.driver) for node in nodes])
await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
end = time()
logger.info(f"Completed add_episode in {(end-start) * 1000} ms")
# for node in nodes:
# if isinstance(node, EntityNode):
# await node.update_summary(self.driver)
@ -190,36 +224,19 @@ class Graphiti:
index_queries: list[LiteralString] = [
"CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)",
"CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)",
"CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.uuid)",
"CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[r:MENTIONS]-() ON (r.uuid)",
"CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)",
"CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)",
"CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
"CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
"CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)",
"CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)",
"CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.name)",
"CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.created_at)",
"CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.expired_at)",
"CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.valid_at)",
"CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.invalid_at)",
]
# Add the range indices
for query in index_queries:
await self.driver.execute_query(query)
# Add the semantic indices
await self.driver.execute_query(
"""
CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]
"""
)
await self.driver.execute_query(
"""
CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON EACH [r.name, r.fact]
"""
)
await self.driver.execute_query(
"CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)",
"CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)",
"CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)",
"CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)",
"CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)",
"CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]",
"CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]",
"""
CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding)
@ -227,10 +244,7 @@ class Graphiti:
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
"""
)
await self.driver.execute_query(
""",
"""
CREATE VECTOR INDEX name_embedding IF NOT EXISTS
FOR (n:Entity) ON (n.name_embedding)
@ -238,7 +252,19 @@ class Graphiti:
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""",
"""
CREATE CONSTRAINT entity_name IF NOT EXISTS
FOR (n:Entity) REQUIRE n.name IS UNIQUE
""",
"""
CREATE CONSTRAINT edge_facts IF NOT EXISTS
FOR ()-[e:RELATES_TO]-() REQUIRE e.fact IS UNIQUE
""",
]
await asyncio.gather(
*[self.driver.execute_query(query) for query in index_queries]
)
async def search(self, query: str) -> list[tuple[EntityNode, list[EntityEdge]]]:
@ -267,3 +293,78 @@ class Graphiti:
context = await bfs(node_ids, self.driver)
return context
async def add_episode_bulk(
self,
bulk_episodes: list[BulkEpisode],
):
try:
start = time()
embedder = self.llm_client.client.embeddings
now = datetime.now()
episodes = [
EpisodicNode(
name=episode.name,
labels=[],
source="messages",
content=episode.content,
source_description=episode.source_description,
created_at=now,
valid_at=episode.reference_time,
)
for episode in bulk_episodes
]
# Save all the episodes
await asyncio.gather(*[episode.save(self.driver) for episode in episodes])
# Get previous episode context for each episode
episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
# Extract all nodes and edges
extracted_nodes, extracted_edges, episodic_edges = (
await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs)
)
# Generate embeddings
await asyncio.gather(
*[node.generate_name_embedding(embedder) for node in extracted_nodes],
*[edge.generate_embedding(embedder) for edge in extracted_edges],
)
# Dedupe extracted nodes
nodes, uuid_map = await dedupe_nodes_bulk(
self.driver, self.llm_client, extracted_nodes
)
# save nodes to KG
await asyncio.gather(*[node.save(self.driver) for node in nodes])
# re-map edge pointers so that they don't point to discarded duplicate nodes
extracted_edges: list[EntityEdge] = resolve_edge_pointers(
extracted_edges, uuid_map
)
episodic_edges: list[EpisodicEdge] = resolve_edge_pointers(
episodic_edges, uuid_map
)
# save episodic edges to KG
await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
# Dedupe extracted edges
edges = await dedupe_edges_bulk(
self.driver, self.llm_client, extracted_edges
)
logger.info(f"extracted edge length: {len(edges)}")
# invalidate edges
# save edges to KG
await asyncio.gather(*[edge.save(self.driver) for edge in edges])
end = time()
logger.info(f"Completed add_episode_bulk in {(end-start) * 1000} ms")
except Exception as e:
raise e
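A minimal usage sketch of the new bulk path, mirroring the runner changes at the end of this commit; the wrapper function and placeholder texts are illustrative only:

from datetime import datetime, timezone
from core import Graphiti
from core.utils.bulk_utils import BulkEpisode

async def ingest(client: Graphiti, texts: list[str]):  # illustrative wrapper
    episodes = [
        BulkEpisode(
            name=f"Message {i}",
            content=text,
            source_description="Podcast Transcript",
            episode_type="string",
            reference_time=datetime.now(timezone.utc),  # placeholder timestamp
        )
        for i, text in enumerate(texts)
    ]
    await client.add_episode_bulk(episodes)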


@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from pydantic import Field
from time import time
from datetime import datetime
from uuid import uuid4
@ -35,14 +35,13 @@ class EpisodicNode(Node):
source: str = Field(description="source type")
source_description: str = Field(description="description of the data source")
content: str = Field(description="raw episode data")
valid_at: datetime = Field(
description="datetime of when the original document was created",
)
entity_edges: list[str] = Field(
description="list of entity edges referenced in this episode",
default_factory=list,
)
valid_at: datetime | None = Field(
description="datetime of when the original document was created",
default=None,
)
async def save(self, driver: AsyncDriver):
result = await driver.execute_query(
@ -80,9 +79,12 @@ class EntityNode(Node):
async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...
async def generate_name_embedding(self, embedder, model="text-embedding-3-small"):
start = time()
text = self.name.replace("\n", " ")
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
self.name_embedding = embedding[:EMBEDDING_DIM]
end = time()
logger.info(f"embedded {text} in {end-start} ms")
return embedding


@ -6,10 +6,12 @@ from .models import Message, PromptVersion, PromptFunction
class Prompt(Protocol):
v1: PromptVersion
edge_list: PromptVersion
class Versions(TypedDict):
v1: PromptFunction
edge_list: PromptFunction
def v1(context: dict[str, any]) -> list[Message]:
@ -43,7 +45,6 @@ def v1(context: dict[str, any]) -> list[Message]:
{{
"new_edges": [
{{
"name": "Unique identifier for the edge",
"fact": "one sentence description of the fact"
}}
]
@ -53,4 +54,40 @@ def v1(context: dict[str, any]) -> list[Message]:
]
versions: Versions = {"v1": v1}
def edge_list(context: dict[str, any]) -> list[Message]:
return [
Message(
role="system",
content="You are a helpful assistant that de-duplicates edges from edge lists.",
),
Message(
role="user",
content=f"""
Given the following context, find all of the duplicates in a list of edges:
Edges:
{json.dumps(context['edges'], indent=2)}
Task:
If any edge in Edges is a duplicate of another edge, return the fact of only one of the duplicate edges
Guidelines:
1. Use both the name and fact of edges to determine if they are duplicates,
edges with the same name may not be duplicates
2. The final list should have only unique facts. If 3 edges are all duplicates of each other, only one of their
facts should be in the response
Respond with a JSON object in the following format:
{{
"unique_edges": [
{{
"fact": "fact of a unique edge",
}}
]
}}
""",
),
]
versions: Versions = {"v1": v1, "edge_list": edge_list}
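To make the expected shape concrete, here is a hypothetical edge_list exchange; the facts are invented for illustration:

context = {
    "edges": [
        {"name": "WORKS_AT", "fact": "Alice works at Acme"},
        {"name": "EMPLOYED_BY", "fact": "Alice is employed by Acme"},
        {"name": "LIVES_IN", "fact": "Alice lives in Berlin"},
    ]
}
# a well-formed response keeps exactly one fact per duplicate group
expected = {
    "unique_edges": [
        {"fact": "Alice works at Acme"},
        {"fact": "Alice lives in Berlin"},
    ]
}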


@ -6,10 +6,14 @@ from .models import Message, PromptVersion, PromptFunction
class Prompt(Protocol):
v1: PromptVersion
v2: PromptVersion
node_list: PromptVersion
class Versions(TypedDict):
v1: PromptFunction
v2: PromptFunction
node_list: PromptFunction
def v1(context: dict[str, any]) -> list[Message]:
@ -44,7 +48,6 @@ def v1(context: dict[str, any]) -> list[Message]:
"new_nodes": [
{{
"name": "Unique identifier for the node",
"summary": "Brief summary of the node's role or significance"
}}
]
}}
@ -53,4 +56,79 @@ def v1(context: dict[str, any]) -> list[Message]:
]
versions: Versions = {"v1": v1}
def v2(context: dict[str, any]) -> list[Message]:
return [
Message(
role="system",
content="You are a helpful assistant that de-duplicates nodes from node lists.",
),
Message(
role="user",
content=f"""
Given the following context, deduplicate nodes from a list of new nodes given a list of existing nodes:
Existing Nodes:
{json.dumps(context['existing_nodes'], indent=2)}
New Nodes:
{json.dumps(context['extracted_nodes'], indent=2)}
Task:
If any node in New Nodes is a duplicate of a node in Existing Nodes, add their names to the output list
Guidelines:
1. Use both the name and summary of nodes to determine if they are duplicates,
duplicate nodes may have different names
2. In the output, name should always be the name of the New Node that is a duplicate. duplicate_of should be
the name of the Existing Node.
Respond with a JSON object in the following format:
{{
"duplicates": [
{{
"name": "name of the new node",
"duplicate_of": "name of the existing node"
}}
]
}}
""",
),
]
def node_list(context: dict[str, any]) -> list[Message]:
return [
Message(
role="system",
content="You are a helpful assistant that de-duplicates nodes from node lists.",
),
Message(
role="user",
content=f"""
Given the following context, deduplicate a list of nodes:
Nodes:
{json.dumps(context['nodes'], indent=2)}
Task:
1. Group nodes together such that all duplicate nodes are in the same list of names
2. All duplicate names should be grouped together in the same list
Guidelines:
1. Each name from the list of nodes should appear EXACTLY once in your response
2. If a node has no duplicates, it should appear in the response in a list of only one name
Respond with a JSON object in the following format:
{{
"nodes": [
{{
"names": ["myNode", "node that is a duplicate of myNode"],
}}
]
}}
""",
),
]
versions: Versions = {"v1": v1, "v2": v2, "node_list": node_list}
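Likewise, a hypothetical node_list exchange; every input name appears exactly once across the output groups:

context = {
    "nodes": [
        {"name": "Alice Smith", "summary": "Podcast host"},
        {"name": "A. Smith", "summary": "Host of the podcast"},
        {"name": "Acme Corp", "summary": "Company discussed in the episode"},
    ]
}
expected = {
    "nodes": [
        {"names": ["Alice Smith", "A. Smith"]},
        {"names": ["Acme Corp"]},
    ]
}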

core/utils/bulk_utils.py (new file, 206 lines)

@ -0,0 +1,206 @@
import asyncio
from collections import defaultdict
from datetime import datetime
from neo4j import AsyncDriver
from pydantic import BaseModel
from core.edges import EpisodicEdge, EntityEdge, Edge
from core.llm_client import LLMClient
from core.nodes import EpisodicNode, EntityNode
from core.utils import retrieve_episodes
from core.utils.maintenance.edge_operations import (
extract_edges,
build_episodic_edges,
dedupe_edge_list,
dedupe_extracted_edges,
)
from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
from core.utils.maintenance.node_operations import (
extract_nodes,
dedupe_node_list,
dedupe_extracted_nodes,
)
from core.utils.search.search_utils import get_relevant_nodes, get_relevant_edges
CHUNK_SIZE = 10
class BulkEpisode(BaseModel):
name: str
content: str
source_description: str
episode_type: str
reference_time: datetime
async def retrieve_previous_episodes_bulk(
driver: AsyncDriver, episodes: list[EpisodicNode]
) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
previous_episodes_list = await asyncio.gather(
*[
retrieve_episodes(driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN)
for episode in episodes
]
)
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] = [
(episode, previous_episodes_list[i]) for i, episode in enumerate(episodes)
]
return episode_tuples
async def extract_nodes_and_edges_bulk(
llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
extracted_nodes_bulk = await asyncio.gather(
*[
extract_nodes(llm_client, episode, previous_episodes)
for episode, previous_episodes in episode_tuples
]
)
episodes, previous_episodes_list = [episode[0] for episode in episode_tuples], [
episode[1] for episode in episode_tuples
]
extracted_edges_bulk = await asyncio.gather(
*[
extract_edges(
llm_client, episode, extracted_nodes_bulk[i], previous_episodes_list[i]
)
for i, episode in enumerate(episodes)
]
)
episodic_edges: list[EpisodicEdge] = []
for i, episode in enumerate(episodes):
episodic_edges += build_episodic_edges(
extracted_nodes_bulk[i], episode, episode.created_at
)
nodes: list[EntityNode] = []
for extracted_nodes in extracted_nodes_bulk:
nodes += extracted_nodes
edges: list[EntityEdge] = []
for extracted_edges in extracted_edges_bulk:
edges += extracted_edges
return nodes, edges, episodic_edges
async def dedupe_nodes_bulk(
driver: AsyncDriver,
llm_client: LLMClient,
extracted_nodes: list[EntityNode],
) -> tuple[list[EntityNode], dict[str, str]]:
# Compress nodes
nodes, uuid_map = node_name_match(extracted_nodes)
compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map)
existing_nodes = await get_relevant_nodes(compressed_nodes, driver)
nodes, partial_uuid_map = await dedupe_extracted_nodes(
llm_client, compressed_nodes, existing_nodes
)
compressed_map.update(partial_uuid_map)
return nodes, compressed_map
async def dedupe_edges_bulk(
driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
) -> list[EntityEdge]:
# Compress edges
compressed_edges = await compress_edges(llm_client, extracted_edges)
existing_edges = await get_relevant_edges(compressed_edges, driver)
edges = await dedupe_extracted_edges(llm_client, compressed_edges, existing_edges)
return edges
def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str, str]]:
uuid_map = {}
name_map = {}
for node in nodes:
if node.name in name_map:
uuid_map[node.uuid] = name_map[node.name].uuid
continue
name_map[node.name] = node
return [node for node in name_map.values()], uuid_map
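node_name_match is the cheap first pass: exact name collisions are merged before any LLM call. A small trace with two hypothetical nodes that share a name:

# nodes = [EntityNode(uuid="n1", name="Alice", ...), EntityNode(uuid="n2", name="Alice", ...)]
# node_name_match(nodes) -> ([<node n1>], {"n2": "n1"})
# the second node is dropped and its uuid is redirected to the surviving node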
async def compress_nodes(
llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
) -> tuple[list[EntityNode], dict[str, str]]:
node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
results = await asyncio.gather(
*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks]
)
extended_map = dict(uuid_map)
compressed_nodes: list[EntityNode] = []
for node_chunk, uuid_map_chunk in results:
compressed_nodes += node_chunk
extended_map.update(uuid_map_chunk)
# Check if we have removed all duplicates
if len(compressed_nodes) == len(nodes):
compressed_uuid_map = compress_uuid_map(extended_map)
return compressed_nodes, compressed_uuid_map
return await compress_nodes(llm_client, compressed_nodes, extended_map)
async def compress_edges(
llm_client: LLMClient, edges: list[EntityEdge]
) -> list[EntityEdge]:
edge_chunks = [edges[i : i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)]
results = await asyncio.gather(
*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks]
)
compressed_edges: list[EntityEdge] = []
for edge_chunk in results:
compressed_edges += edge_chunk
# Check if we have removed all duplicates
if len(compressed_edges) == len(edges):
return compressed_edges
return await compress_edges(llm_client, compressed_edges)
def compress_uuid_map(uuid_map: dict[str, str]) -> dict[str, str]:
# make sure all uuid values aren't mapped to other uuids
compressed_map = {}
for key, uuid in uuid_map.items():
curr_value = uuid
while curr_value in uuid_map.keys():
curr_value = uuid_map[curr_value]
compressed_map[key] = curr_value
return compressed_map
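compress_uuid_map collapses alias chains so every duplicate uuid points directly at its final survivor. A worked example, assuming plain string uuids:

uuid_map = {"b": "a", "c": "b", "e": "d"}
# "c" -> "b" -> "a", so the chain is flattened to "c" -> "a"
assert compress_uuid_map(uuid_map) == {"b": "a", "c": "a", "e": "d"}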
def resolve_edge_pointers(edges: list[Edge], uuid_map: dict[str, str]):
for edge in edges:
source_uuid = edge.source_node_uuid
target_uuid = edge.target_node_uuid
edge.source_node_uuid = (
uuid_map[source_uuid] if source_uuid in uuid_map else source_uuid
)
edge.target_node_uuid = (
uuid_map[target_uuid] if target_uuid in uuid_map else target_uuid
)
return edges
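resolve_edge_pointers then rewrites edge endpoints through that same map, so edges extracted against a discarded duplicate end up attached to the surviving node. Continuing the hypothetical map above:

# edge.source_node_uuid == "c" and uuid_map == {"c": "a", ...}
# resolve_edge_pointers([edge], uuid_map) leaves the edge with source_node_uuid == "a"
# target_node_uuid is remapped the same way; uuids not in the map are left untouched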


@ -1,6 +1,7 @@
import json
from typing import List
from datetime import datetime
from time import time
from pydantic import BaseModel
@ -17,7 +18,7 @@ logger = logging.getLogger(__name__)
def build_episodic_edges(
entity_nodes: List[EntityNode],
episode: EpisodicNode,
transaction_from: datetime,
created_at: datetime,
) -> List[EpisodicEdge]:
edges: List[EpisodicEdge] = []
@ -25,7 +26,7 @@ def build_episodic_edges(
edge = EpisodicEdge(
source_node_uuid=episode.uuid,
target_node_uuid=node.uuid,
created_at=transaction_from,
created_at=created_at,
)
edges.append(edge)
@ -144,6 +145,8 @@ async def extract_edges(
nodes: list[EntityNode],
previous_episodes: list[EpisodicNode],
) -> list[EntityEdge]:
start = time()
# Prepare context for LLM
context = {
"episode_content": episode.content,
@ -167,7 +170,9 @@ async def extract_edges(
prompt_library.extract_edges.v2(context)
)
edges_data = llm_response.get("edges", [])
logger.info(f"Extracted new edges: {edges_data}")
end = time()
logger.info(f"Extracted new edges: {edges_data} in {(end - start) * 1000} ms")
# Convert the extracted data into EntityEdge objects
edges = []
@ -199,11 +204,11 @@ async def dedupe_extracted_edges(
# Create edge map
edge_map = {}
for edge in existing_edges:
edge_map[edge.name] = edge
edge_map[edge.fact] = edge
for edge in extracted_edges:
if edge.name in edge_map.keys():
if edge.fact in edge_map.keys():
continue
edge_map[edge.name] = edge
edge_map[edge.fact] = edge
# Prepare context for LLM
context = {
@ -224,7 +229,40 @@ async def dedupe_extracted_edges(
# Get full edge data
edges = []
for edge_data in new_edges_data:
edge = edge_map[edge_data["name"]]
edge = edge_map[edge_data["fact"]]
edges.append(edge)
return edges
async def dedupe_edge_list(
llm_client: LLMClient,
edges: list[EntityEdge],
) -> list[EntityEdge]:
start = time()
# Create edge map
edge_map = {}
for edge in edges:
edge_map[edge.fact] = edge
# Prepare context for LLM
context = {"edges": [{"name": edge.name, "fact": edge.fact} for edge in edges]}
llm_response = await llm_client.generate_response(
prompt_library.dedupe_edges.edge_list(context)
)
unique_edges_data = llm_response.get("unique_edges", [])
end = time()
logger.info(
f"Extracted edge duplicates: {unique_edges_data} in {(end - start)*1000} ms "
)
# Get full edge data
unique_edges = []
for edge_data in unique_edges_data:
fact = edge_data["fact"]
unique_edges.append(edge_map[fact])
return unique_edges


@ -4,6 +4,7 @@ from core.nodes import EpisodicNode
from neo4j import AsyncDriver
import logging
EPISODE_WINDOW_LEN = 3
logger = logging.getLogger(__name__)
@ -18,11 +19,15 @@ async def clear_data(driver: AsyncDriver):
async def retrieve_episodes(
driver: AsyncDriver, last_n: int, sources: list[str] | None = "messages"
driver: AsyncDriver,
reference_time: datetime,
last_n: int,
sources: list[str] | None = "messages",
) -> list[EpisodicNode]:
"""Retrieve the last n episodic nodes from the graph"""
query = """
MATCH (e:Episodic)
result = await driver.execute_query(
"""
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
RETURN e.content as content,
e.created_at as created_at,
e.valid_at as valid_at,
@ -32,8 +37,10 @@ async def retrieve_episodes(
e.source as source
ORDER BY e.created_at DESC
LIMIT $num_episodes
"""
result = await driver.execute_query(query, num_episodes=last_n)
""",
reference_time=reference_time,
num_episodes=last_n,
)
episodes = [
EpisodicNode(
content=record["content"],
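With the new signature above, callers pass the episode's own timestamp so that only episodes valid at or before that point count as context; this is what both add_episode and retrieve_previous_episodes_bulk rely on. A minimal sketch, with previous_context as an illustrative helper:

from core.utils import retrieve_episodes
from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN

async def previous_context(driver, episode):  # illustrative helper
    # only episodes with valid_at <= episode.valid_at are returned, newest first
    return await retrieve_episodes(driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN)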


@ -1,4 +1,5 @@
from datetime import datetime
from time import time
from core.nodes import EntityNode, EpisodicNode
import logging
@ -68,6 +69,8 @@ async def extract_nodes(
episode: EpisodicNode,
previous_episodes: list[EpisodicNode],
) -> list[EntityNode]:
start = time()
# Prepare context for LLM
context = {
"episode_content": episode.content,
@ -87,7 +90,9 @@ async def extract_nodes(
prompt_library.extract_nodes.v3(context)
)
new_nodes_data = llm_response.get("new_nodes", [])
logger.info(f"Extracted new nodes: {new_nodes_data}")
end = time()
logger.info(f"Extracted new nodes: {new_nodes_data} in {(end - start) * 1000} ms")
# Convert the extracted data into EntityNode objects
new_nodes = []
for node_data in new_nodes_data:
@ -107,15 +112,13 @@ async def dedupe_extracted_nodes(
llm_client: LLMClient,
extracted_nodes: list[EntityNode],
existing_nodes: list[EntityNode],
) -> list[EntityNode]:
# build node map
) -> tuple[list[EntityNode], dict[str, str]]:
start = time()
# build existing node map
node_map = {}
for node in existing_nodes:
node_map[node.name] = node
for node in extracted_nodes:
if node.name in node_map.keys():
continue
node_map[node.name] = node
# Prepare context for LLM
existing_nodes_context = [
@ -132,16 +135,69 @@ async def dedupe_extracted_nodes(
}
llm_response = await llm_client.generate_response(
prompt_library.dedupe_nodes.v1(context)
prompt_library.dedupe_nodes.v2(context)
)
new_nodes_data = llm_response.get("new_nodes", [])
logger.info(f"Deduplicated nodes: {new_nodes_data}")
duplicate_data = llm_response.get("duplicates", [])
end = time()
logger.info(f"Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms")
uuid_map = {}
for duplicate in duplicate_data:
uuid = node_map[duplicate["name"]].uuid
uuid_value = node_map[duplicate["duplicate_of"]].uuid
uuid_map[uuid] = uuid_value
# Get full node data
nodes = []
for node_data in new_nodes_data:
node = node_map[node_data["name"]]
for node in extracted_nodes:
if node.uuid in uuid_map:
existing_name = uuid_map[node.name]
existing_node = node_map[existing_name]
nodes.append(existing_node)
continue
nodes.append(node)
return nodes
return nodes, uuid_map
async def dedupe_node_list(
llm_client: LLMClient,
nodes: list[EntityNode],
) -> tuple[list[EntityNode], dict[str, str]]:
start = time()
# build node map
node_map = {}
for node in nodes:
node_map[node.name] = node
# Prepare context for LLM
nodes_context = [{"name": node.name, "summary": node.summary} for node in nodes]
context = {
"nodes": nodes_context,
}
llm_response = await llm_client.generate_response(
prompt_library.dedupe_nodes.node_list(context)
)
nodes_data = llm_response.get("nodes", [])
end = time()
logger.info(f"Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms")
# Get full node data
unique_nodes = []
uuid_map: dict[str, str] = {}
for node_data in nodes_data:
node = node_map[node_data["names"][0]]
unique_nodes.append(node)
for name in node_data["names"][1:]:
uuid = node_map[name].uuid
uuid_value = node_map[node_data["names"][0]].uuid
uuid_map[uuid] = uuid_value
return unique_nodes, uuid_map
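The grouped response maps directly onto the uuid_map convention used throughout bulk_utils: the first name in each group is the survivor, and every other name's uuid is pointed at it. A hypothetical trace:

# nodes_data = [{"names": ["Alice Smith", "A. Smith"]}, {"names": ["Acme Corp"]}]
# unique_nodes -> [<node "Alice Smith">, <node "Acme Corp">]
# uuid_map    -> {uuid of "A. Smith": uuid of "Alice Smith"}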


@ -1,6 +1,7 @@
import asyncio
import logging
from datetime import datetime
from time import time
from neo4j import AsyncDriver
@ -9,6 +10,8 @@ from core.nodes import EntityNode
logger = logging.getLogger(__name__)
RELEVANT_SCHEMA_LIMIT = 3
async def bfs(node_ids: list[str], driver: AsyncDriver):
records, _, _ = await driver.execute_query(
@ -60,7 +63,7 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):
async def edge_similarity_search(
search_vector: list[float], driver: AsyncDriver
search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityEdge]:
# vector similarity search over embedded facts
records, _, _ = await driver.execute_query(
@ -80,9 +83,10 @@ async def edge_similarity_search(
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT 10
ORDER BY score DESC LIMIT $limit
""",
search_vector=search_vector,
limit=limit,
)
edges: list[EntityEdge] = []
@ -106,18 +110,16 @@ async def edge_similarity_search(
edges.append(edge)
logger.info(f"similarity search results. RESULT: {[edge.uuid for edge in edges]}")
return edges
async def entity_similarity_search(
search_vector: list[float], driver: AsyncDriver
search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
# vector similarity search over entity names
records, _, _ = await driver.execute_query(
"""
CALL db.index.vector.queryNodes("name_embedding", 5, $search_vector)
CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
YIELD node AS n, score
RETURN
n.uuid As uuid,
@ -127,6 +129,7 @@ async def entity_similarity_search(
ORDER BY score DESC
""",
search_vector=search_vector,
limit=limit,
)
nodes: list[EntityNode] = []
@ -141,12 +144,12 @@ async def entity_similarity_search(
)
)
logger.info(f"name semantic search results. RESULT: {nodes}")
return nodes
async def entity_fulltext_search(query: str, driver: AsyncDriver) -> list[EntityNode]:
async def entity_fulltext_search(
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
# BM25 search to get top nodes
fuzzy_query = query + "~"
records, _, _ = await driver.execute_query(
@ -158,9 +161,10 @@ async def entity_fulltext_search(query: str, driver: AsyncDriver) -> list[Entity
node.created_at AS created_at,
node.summary AS summary
ORDER BY score DESC
LIMIT 10
LIMIT $limit
""",
query=fuzzy_query,
limit=limit,
)
nodes: list[EntityNode] = []
@ -175,12 +179,12 @@ async def entity_fulltext_search(query: str, driver: AsyncDriver) -> list[Entity
)
)
logger.info(f"fulltext search results. QUERY:{query}. RESULT: {nodes}")
return nodes
async def edge_fulltext_search(query: str, driver: AsyncDriver) -> list[EntityEdge]:
async def edge_fulltext_search(
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityEdge]:
# fulltext search over facts
fuzzy_query = query + "~"
@ -201,9 +205,10 @@ async def edge_fulltext_search(query: str, driver: AsyncDriver) -> list[EntityEd
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT 10
ORDER BY score DESC LIMIT $limit
""",
query=fuzzy_query,
limit=limit,
)
edges: list[EntityEdge] = []
@ -227,10 +232,6 @@ async def edge_fulltext_search(query: str, driver: AsyncDriver) -> list[EntityEd
edges.append(edge)
logger.info(
f"similarity search results. QUERY:{query}. RESULT: {[edge.uuid for edge in edges]}"
)
return edges
@ -238,7 +239,9 @@ async def get_relevant_nodes(
nodes: list[EntityNode],
driver: AsyncDriver,
) -> list[EntityNode]:
relevant_nodes: dict[str, EntityNode] = {}
start = time()
relevant_nodes: list[EntityNode] = []
relevant_node_uuids = set()
results = await asyncio.gather(
*[entity_fulltext_search(node.name, driver) for node in nodes],
@ -247,18 +250,27 @@ async def get_relevant_nodes(
for result in results:
for node in result:
relevant_nodes[node.uuid] = node
if node.uuid in relevant_node_uuids:
continue
logger.info(f"Found relevant nodes: {relevant_nodes.keys()}")
relevant_node_uuids.add(node.uuid)
relevant_nodes.append(node)
return relevant_nodes.values()
end = time()
logger.info(
f"Found relevant nodes: {relevant_node_uuids} in {(end - start) * 1000} ms"
)
return relevant_nodes
async def get_relevant_edges(
edges: list[EntityEdge],
driver: AsyncDriver,
) -> list[EntityEdge]:
relevant_edges: dict[str, EntityEdge] = {}
start = time()
relevant_edges: list[EntityEdge] = []
relevant_edge_uuids = set()
results = await asyncio.gather(
*[edge_similarity_search(edge.fact_embedding, driver) for edge in edges],
@ -267,8 +279,15 @@ async def get_relevant_edges(
for result in results:
for edge in result:
relevant_edges[edge.uuid] = edge
if edge.uuid in relevant_edge_uuids:
continue
logger.info(f"Found relevant nodes: {relevant_edges.keys()}")
relevant_edge_uuids.add(edge.uuid)
relevant_edges.append(edge)
return list(relevant_edges.values())
end = time()
logger.info(
f"Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms"
)
return relevant_edges


@ -1,4 +1,5 @@
from core import Graphiti
from core.utils.bulk_utils import BulkEpisode
from core.utils.maintenance.graph_data_operations import clear_data
from dotenv import load_dotenv
import os
@ -37,18 +38,33 @@ def setup_logging():
return logger
async def main():
async def main(use_bulk: bool = True):
setup_logging()
client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
await clear_data(client.driver)
messages = parse_podcast_messages()
for i, message in enumerate(messages[3:50]):
await client.add_episode(
if not use_bulk:
for i, message in enumerate(messages[3:14]):
await client.add_episode(
name=f"Message {i}",
episode_body=f"{message.speaker_name} ({message.role}): {message.content}",
reference_time=message.actual_timestamp,
source_description="Podcast Transcript",
)
episodes: list[BulkEpisode] = [
BulkEpisode(
name=f"Message {i}",
content=f"{message.speaker_name} ({message.role}): {message.content}",
source_description="Podcast Transcript",
episode_type="string",
reference_time=message.actual_timestamp,
)
for i, message in enumerate(messages[3:7])
]
await client.add_episode_bulk(episodes)
asyncio.run(main())