Create Bulk Add Episode for faster processing (#9)
* benchmark logging
* load schema updates
* add extract bulk nodes and edges
* updated bulk calls
* compression updates
* bulk updates
* bulk logic first pass
* updated bulk process
* debug
* remove exact names first
* cleaned up prompt
* fix bad merge
* update
* fix merge issues
parent a6fd0ddb75
commit d6add504bd
11 changed files with 675 additions and 109 deletions
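A minimal sketch of how the new bulk path introduced here is presumably driven end to end, based on the BulkEpisode model and Graphiti.add_episode_bulk added in this commit; the Neo4j connection details and message texts below are placeholders, not values from the repo.

import asyncio
from datetime import datetime, timezone

from core import Graphiti
from core.utils.bulk_utils import BulkEpisode


async def demo():
    # Placeholder connection details; point this at your own Neo4j instance.
    client = Graphiti("bolt://localhost:7687", "neo4j", "password")

    episodes = [
        BulkEpisode(
            name=f"Message {i}",
            content=text,
            source_description="Podcast Transcript",
            episode_type="string",
            reference_time=datetime.now(timezone.utc),
        )
        for i, text in enumerate(["alice: hi there", "bob: hello alice"])
    ]

    # One call extracts, dedupes, and saves nodes and edges for every episode.
    await client.add_episode_bulk(episodes)


asyncio.run(demo())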
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel, Field
from datetime import datetime
+from time import time
from neo4j import AsyncDriver
from uuid import uuid4
import logging
@@ -76,10 +77,15 @@ class EntityEdge(Edge):
        )

    async def generate_embedding(self, embedder, model="text-embedding-3-small"):
+        start = time()
+
        text = self.fact.replace("\n", " ")
        embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
        self.fact_embedding = embedding[:EMBEDDING_DIM]

+        end = time()
+        logger.info(f"embedded {text} in {end-start} ms")
+
        return embedding

    async def save(self, driver: AsyncDriver):
@@ -105,6 +111,6 @@ class EntityEdge(Edge):
            invalid_at=self.invalid_at,
        )

-        logger.info(f"Saved Node to neo4j: {self.uuid}")
+        logger.info(f"Saved edge to neo4j: {self.uuid}")

        return result
core/graphiti.py | 185
@@ -4,25 +4,32 @@ import logging
from typing import Callable, LiteralString
from neo4j import AsyncGraphDatabase
from dotenv import load_dotenv
+from time import time
import os

from core.llm_client.config import EMBEDDING_DIM
from core.nodes import EntityNode, EpisodicNode, Node
-from core.edges import EntityEdge, EpisodicEdge
+from core.edges import EntityEdge, Edge, EpisodicEdge
from core.utils import (
    build_episodic_edges,
    retrieve_episodes,
)
from core.llm_client import LLMClient, OpenAIClient, LLMConfig
-from core.utils.maintenance.edge_operations import (
+from core.utils.bulk_utils import (
-    extract_edges,
+    BulkEpisode,
-    dedupe_extracted_edges,
+    extract_nodes_and_edges_bulk,
+    retrieve_previous_episodes_bulk,
+    compress_nodes,
+    dedupe_nodes_bulk,
+    resolve_edge_pointers,
+    dedupe_edges_bulk,
)
+from core.utils.maintenance.edge_operations import extract_edges, dedupe_extracted_edges
+from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
from core.utils.maintenance.temporal_operations import (
-    prepare_edges_for_invalidation,
    invalidate_edges,
+    prepare_edges_for_invalidation,
)
from core.utils.search.search_utils import (
    edge_similarity_search,
@@ -58,30 +65,47 @@ class Graphiti:
        self.driver.close()

    async def retrieve_episodes(
-        self, last_n: int, sources: list[str] | None = "messages"
+        self,
+        reference_time: datetime,
+        last_n: int,
+        sources: list[str] | None = "messages",
    ) -> list[EpisodicNode]:
        """Retrieve the last n episodic nodes from the graph"""
-        return await retrieve_episodes(self.driver, last_n, sources)
+        return await retrieve_episodes(self.driver, reference_time, last_n, sources)

+    # Invalidate edges that are no longer valid
+    async def invalidate_edges(
+        self,
+        episode: EpisodicNode,
+        new_nodes: list[EntityNode],
+        new_edges: list[EntityEdge],
+        relevant_schema: dict[str, any],
+        previous_episodes: list[EpisodicNode],
+    ): ...
+
    async def add_episode(
        self,
        name: str,
        episode_body: str,
        source_description: str,
-        reference_time: datetime = None,
+        reference_time: datetime,
        episode_type="string",
        success_callback: Callable | None = None,
        error_callback: Callable | None = None,
    ):
        """Process an episode and update the graph"""
        try:
+            start = time()
+
            nodes: list[EntityNode] = []
            entity_edges: list[EntityEdge] = []
            episodic_edges: list[EpisodicEdge] = []
            embedder = self.llm_client.client.embeddings
            now = datetime.now()

-            previous_episodes = await self.retrieve_episodes(last_n=3)
+            previous_episodes = await self.retrieve_episodes(
+                reference_time, last_n=EPISODE_WINDOW_LEN
+            )
            episode = EpisodicNode(
                name=name,
                labels=[],
@@ -105,7 +129,7 @@ class Graphiti:
            logger.info(
                f"Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}"
            )
-            new_nodes = await dedupe_extracted_nodes(
+            new_nodes, _ = await dedupe_extracted_nodes(
                self.llm_client, extracted_nodes, existing_nodes
            )
            logger.info(
@@ -151,8 +175,15 @@ class Graphiti:
            )

            logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}")

            entity_edges.extend(deduped_edges)

+            new_edges = await dedupe_extracted_edges(
+                self.llm_client, extracted_edges, existing_edges
+            )
+
+            logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in new_edges]}")
+
+            entity_edges.extend(new_edges)
            episodic_edges.extend(
                build_episodic_edges(
                    # There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
@@ -175,6 +206,9 @@ class Graphiti:
            await asyncio.gather(*[node.save(self.driver) for node in nodes])
            await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
            await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
+
+            end = time()
+            logger.info(f"Completed add_episode in {(end-start) * 1000} ms")
            # for node in nodes:
            #     if isinstance(node, EntityNode):
            #         await node.update_summary(self.driver)
@@ -190,36 +224,19 @@ class Graphiti:
        index_queries: list[LiteralString] = [
            "CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)",
            "CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)",
-            "CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.uuid)",
+            "CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)",
-            "CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[r:MENTIONS]-() ON (r.uuid)",
+            "CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)",
            "CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
            "CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
            "CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)",
            "CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)",
-            "CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.name)",
+            "CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)",
-            "CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.created_at)",
+            "CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)",
-            "CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.expired_at)",
+            "CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)",
-            "CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.valid_at)",
+            "CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)",
-            "CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.invalid_at)",
+            "CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)",
-        ]
+            "CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]",
-        # Add the range indices
+            "CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]",
-        for query in index_queries:
-            await self.driver.execute_query(query)
-
-        # Add the semantic indices
-        await self.driver.execute_query(
-            """
-            CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]
-            """
-        )
-
-        await self.driver.execute_query(
-            """
-            CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON EACH [r.name, r.fact]
-            """
-        )
-
-        await self.driver.execute_query(
            """
            CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
            FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding)
@@ -227,10 +244,7 @@ class Graphiti:
            `vector.dimensions`: 1024,
            `vector.similarity_function`: 'cosine'
            }}
-            """
+            """,
-        )
-
-        await self.driver.execute_query(
            """
            CREATE VECTOR INDEX name_embedding IF NOT EXISTS
            FOR (n:Entity) ON (n.name_embedding)
@@ -238,7 +252,19 @@ class Graphiti:
            `vector.dimensions`: 1024,
            `vector.similarity_function`: 'cosine'
            }}
+            """,
            """
+            CREATE CONSTRAINT entity_name IF NOT EXISTS
+            FOR (n:Entity) REQUIRE n.name IS UNIQUE
+            """,
+            """
+            CREATE CONSTRAINT edge_facts IF NOT EXISTS
+            FOR ()-[e:RELATES_TO]-() REQUIRE e.fact IS UNIQUE
+            """,
+        ]
+
+        await asyncio.gather(
+            *[self.driver.execute_query(query) for query in index_queries]
        )

    async def search(self, query: str) -> list[tuple[EntityNode, list[EntityEdge]]]:
@@ -267,3 +293,78 @@ class Graphiti:
        context = await bfs(node_ids, self.driver)

        return context
+
+    async def add_episode_bulk(
+        self,
+        bulk_episodes: list[BulkEpisode],
+    ):
+        try:
+            start = time()
+            embedder = self.llm_client.client.embeddings
+            now = datetime.now()
+
+            episodes = [
+                EpisodicNode(
+                    name=episode.name,
+                    labels=[],
+                    source="messages",
+                    content=episode.content,
+                    source_description=episode.source_description,
+                    created_at=now,
+                    valid_at=episode.reference_time,
+                )
+                for episode in bulk_episodes
+            ]
+
+            # Save all the episodes
+            await asyncio.gather(*[episode.save(self.driver) for episode in episodes])
+
+            # Get previous episode context for each episode
+            episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
+
+            # Extract all nodes and edges
+            extracted_nodes, extracted_edges, episodic_edges = (
+                await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs)
+            )
+
+            # Generate embeddings
+            await asyncio.gather(
+                *[node.generate_name_embedding(embedder) for node in extracted_nodes],
+                *[edge.generate_embedding(embedder) for edge in extracted_edges],
+            )
+
+            # Dedupe extracted nodes
+            nodes, uuid_map = await dedupe_nodes_bulk(
+                self.driver, self.llm_client, extracted_nodes
+            )
+
+            # save nodes to KG
+            await asyncio.gather(*[node.save(self.driver) for node in nodes])
+
+            # re-map edge pointers so that they don't point to discarded dupe nodes
+            extracted_edges: list[EntityEdge] = resolve_edge_pointers(
+                extracted_edges, uuid_map
+            )
+            episodic_edges: list[EpisodicEdge] = resolve_edge_pointers(
+                episodic_edges, uuid_map
+            )
+
+            # save episodic edges to KG
+            await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
+
+            # Dedupe extracted edges
+            edges = await dedupe_edges_bulk(
+                self.driver, self.llm_client, extracted_edges
+            )
+            logger.info(f"extracted edge length: {len(edges)}")
+
+            # invalidate edges
+
+            # save edges to KG
+            await asyncio.gather(*[edge.save(self.driver) for edge in edges])
+
+            end = time()
+            logger.info(f"Completed add_episode_bulk in {(end-start) * 1000} ms")
+
+        except Exception as e:
+            raise e
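For context, the index-creation change above collects every CREATE INDEX / CREATE CONSTRAINT statement into one list and runs them concurrently instead of awaiting each execute_query call in turn. A minimal standalone sketch of that pattern is below; the function name is mine, the statements are a subset of the ones in the diff, and the connection arguments are placeholders.

import asyncio

from neo4j import AsyncGraphDatabase


async def build_indices_concurrently(uri: str, user: str, password: str) -> None:
    driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
    index_queries = [
        "CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)",
        "CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)",
    ]
    # Fire all DDL statements at once; each one is independent of the others.
    await asyncio.gather(*[driver.execute_query(query) for query in index_queries])
    await driver.close()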
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
-from pydantic import Field
+from time import time
from datetime import datetime
from uuid import uuid4

@@ -35,14 +35,13 @@ class EpisodicNode(Node):
    source: str = Field(description="source type")
    source_description: str = Field(description="description of the data source")
    content: str = Field(description="raw episode data")
+    valid_at: datetime = Field(
+        description="datetime of when the original document was created",
+    )
    entity_edges: list[str] = Field(
        description="list of entity edges referenced in this episode",
        default_factory=list,
    )
-    valid_at: datetime | None = Field(
-        description="datetime of when the original document was created",
-        default=None,
-    )

    async def save(self, driver: AsyncDriver):
        result = await driver.execute_query(
@@ -80,9 +79,12 @@ class EntityNode(Node):
    async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...

    async def generate_name_embedding(self, embedder, model="text-embedding-3-small"):
+        start = time()
        text = self.name.replace("\n", " ")
        embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
        self.name_embedding = embedding[:EMBEDDING_DIM]
+        end = time()
+        logger.info(f"embedded {text} in {end-start} ms")

        return embedding
@@ -6,10 +6,12 @@ from .models import Message, PromptVersion, PromptFunction

class Prompt(Protocol):
    v1: PromptVersion
+    edge_list: PromptVersion


class Versions(TypedDict):
    v1: PromptFunction
+    edge_list: PromptFunction


def v1(context: dict[str, any]) -> list[Message]:
@@ -43,7 +45,6 @@ def v1(context: dict[str, any]) -> list[Message]:
        {{
            "new_edges": [
                {{
-                    "name": "Unique identifier for the edge",
                    "fact": "one sentence description of the fact"
                }}
            ]
@@ -53,4 +54,40 @@ def v1(context: dict[str, any]) -> list[Message]:
    ]


-versions: Versions = {"v1": v1}
+def edge_list(context: dict[str, any]) -> list[Message]:
+    return [
+        Message(
+            role="system",
+            content="You are a helpful assistant that de-duplicates edges from edge lists.",
+        ),
+        Message(
+            role="user",
+            content=f"""
+        Given the following context, find all of the duplicates in a list of edges:
+
+        Edges:
+        {json.dumps(context['edges'], indent=2)}
+
+        Task:
+        If any edge in Edges is a duplicate of another edge, return the fact of only one of the duplicate edges
+
+        Guidelines:
+        1. Use both the name and fact of edges to determine if they are duplicates,
+            edges with the same name may not be duplicates
+        2. The final list should have only unique facts. If 3 edges are all duplicates of each other, only one of their
+            facts should be in the response
+
+        Respond with a JSON object in the following format:
+        {{
+            "unique_edges": [
+                {{
+                    "fact": "fact of a unique edge",
+                }}
+            ]
+        }}
+        """,
+        ),
+    ]
+
+
+versions: Versions = {"v1": v1, "edge_list": edge_list}
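For reference, a hypothetical call to the new edge_list prompt; the context shape mirrors what dedupe_edge_list builds in core/utils/maintenance/edge_operations.py, but the import path for prompt_library and the edge values are assumptions, not taken from this diff.

from core.prompts import prompt_library  # assumed import path for the prompt registry

context = {
    "edges": [
        {"name": "WORKS_AT", "fact": "Alice works at Acme"},
        {"name": "EMPLOYED_BY", "fact": "Alice is employed by Acme"},
    ]
}
# Produces the system/user Message pair defined by edge_list above.
messages = prompt_library.dedupe_edges.edge_list(context)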
@@ -6,10 +6,14 @@ from .models import Message, PromptVersion, PromptFunction

class Prompt(Protocol):
    v1: PromptVersion
+    v2: PromptVersion
+    node_list: PromptVersion


class Versions(TypedDict):
    v1: PromptFunction
+    v2: PromptFunction
+    node_list: PromptVersion


def v1(context: dict[str, any]) -> list[Message]:
@@ -44,7 +48,6 @@ def v1(context: dict[str, any]) -> list[Message]:
            "new_nodes": [
                {{
                    "name": "Unique identifier for the node",
-                    "summary": "Brief summary of the node's role or significance"
                }}
            ]
        }}
@@ -53,4 +56,79 @@ def v1(context: dict[str, any]) -> list[Message]:
    ]


-versions: Versions = {"v1": v1}
+def v2(context: dict[str, any]) -> list[Message]:
+    return [
+        Message(
+            role="system",
+            content="You are a helpful assistant that de-duplicates nodes from node lists.",
+        ),
+        Message(
+            role="user",
+            content=f"""
+        Given the following context, deduplicate nodes from a list of new nodes given a list of existing nodes:
+
+        Existing Nodes:
+        {json.dumps(context['existing_nodes'], indent=2)}
+
+        New Nodes:
+        {json.dumps(context['extracted_nodes'], indent=2)}
+
+        Task:
+        If any node in New Nodes is a duplicate of a node in Existing Nodes, add their names to the output list
+
+        Guidelines:
+        1. Use both the name and summary of nodes to determine if they are duplicates,
+            duplicate nodes may have different names
+        2. In the output, name should always be the name of the New Node that is a duplicate. duplicate_of should be
+            the name of the Existing Node.
+
+        Respond with a JSON object in the following format:
+        {{
+            "duplicates": [
+                {{
+                    "name": "name of the new node",
+                    "duplicate_of": "name of the existing node"
+                }}
+            ]
+        }}
+        """,
+        ),
+    ]
+
+
+def node_list(context: dict[str, any]) -> list[Message]:
+    return [
+        Message(
+            role="system",
+            content="You are a helpful assistant that de-duplicates nodes from node lists.",
+        ),
+        Message(
+            role="user",
+            content=f"""
+        Given the following context, deduplicate a list of nodes:
+
+        Nodes:
+        {json.dumps(context['nodes'], indent=2)}
+
+        Task:
+        1. Group nodes together such that all duplicate nodes are in the same list of names
+        2. All duplicate names should be grouped together in the same list
+
+        Guidelines:
+        1. Each name from the list of nodes should appear EXACTLY once in your response
+        2. If a node has no duplicates, it should appear in the response in a list of only one name
+
+        Respond with a JSON object in the following format:
+        {{
+            "nodes": [
+                {{
+                    "names": ["myNode", "node that is a duplicate of myNode"],
+                }}
+            ]
+        }}
+        """,
+        ),
+    ]
+
+
+versions: Versions = {"v1": v1, "v2": v2, "node_list": node_list}
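Similarly, a hypothetical call to the new node_list prompt; dedupe_node_list in core/utils/maintenance/node_operations.py builds the same {"name", "summary"} records, while the import path and node values here are assumptions.

from core.prompts import prompt_library  # assumed import path for the prompt registry

context = {
    "nodes": [
        {"name": "Acme", "summary": "A company discussed in the podcast"},
        {"name": "Acme Corp", "summary": "Company that employs Alice"},
    ]
}
# Produces the system/user Message pair defined by node_list above.
messages = prompt_library.dedupe_nodes.node_list(context)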
core/utils/bulk_utils.py | 206 (new file)
@@ -0,0 +1,206 @@
import asyncio
from collections import defaultdict
from datetime import datetime

from neo4j import AsyncDriver
from pydantic import BaseModel

from core.edges import EpisodicEdge, EntityEdge, Edge
from core.llm_client import LLMClient
from core.nodes import EpisodicNode, EntityNode
from core.utils import retrieve_episodes
from core.utils.maintenance.edge_operations import (
    extract_edges,
    build_episodic_edges,
    dedupe_edge_list,
    dedupe_extracted_edges,
)
from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
from core.utils.maintenance.node_operations import (
    extract_nodes,
    dedupe_node_list,
    dedupe_extracted_nodes,
)
from core.utils.search.search_utils import get_relevant_nodes, get_relevant_edges

CHUNK_SIZE = 10


class BulkEpisode(BaseModel):
    name: str
    content: str
    source_description: str
    episode_type: str
    reference_time: datetime


async def retrieve_previous_episodes_bulk(
    driver: AsyncDriver, episodes: list[EpisodicNode]
) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
    previous_episodes_list = await asyncio.gather(
        *[
            retrieve_episodes(driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN)
            for episode in episodes
        ]
    )
    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] = [
        (episode, previous_episodes_list[i]) for i, episode in enumerate(episodes)
    ]

    return episode_tuples


async def extract_nodes_and_edges_bulk(
    llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
    extracted_nodes_bulk = await asyncio.gather(
        *[
            extract_nodes(llm_client, episode, previous_episodes)
            for episode, previous_episodes in episode_tuples
        ]
    )

    episodes, previous_episodes_list = [episode[0] for episode in episode_tuples], [
        episode[1] for episode in episode_tuples
    ]

    extracted_edges_bulk = await asyncio.gather(
        *[
            extract_edges(
                llm_client, episode, extracted_nodes_bulk[i], previous_episodes_list[i]
            )
            for i, episode in enumerate(episodes)
        ]
    )

    episodic_edges: list[EpisodicEdge] = []
    for i, episode in enumerate(episodes):
        episodic_edges += build_episodic_edges(
            extracted_nodes_bulk[i], episode, episode.created_at
        )

    nodes: list[EntityNode] = []
    for extracted_nodes in extracted_nodes_bulk:
        nodes += extracted_nodes

    edges: list[EntityEdge] = []
    for extracted_edges in extracted_edges_bulk:
        edges += extracted_edges

    return nodes, edges, episodic_edges


async def dedupe_nodes_bulk(
    driver: AsyncDriver,
    llm_client: LLMClient,
    extracted_nodes: list[EntityNode],
) -> tuple[list[EntityNode], dict[str, str]]:
    # Compress nodes
    nodes, uuid_map = node_name_match(extracted_nodes)

    compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map)

    existing_nodes = await get_relevant_nodes(compressed_nodes, driver)

    nodes, partial_uuid_map = await dedupe_extracted_nodes(
        llm_client, compressed_nodes, existing_nodes
    )

    compressed_map.update(partial_uuid_map)

    return nodes, compressed_map


async def dedupe_edges_bulk(
    driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
) -> list[EntityEdge]:
    # Compress edges
    compressed_edges = await compress_edges(llm_client, extracted_edges)

    existing_edges = await get_relevant_edges(compressed_edges, driver)

    edges = await dedupe_extracted_edges(llm_client, compressed_edges, existing_edges)

    return edges


def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str, str]]:
    uuid_map = {}
    name_map = {}
    for node in nodes:
        if node.name in name_map:
            uuid_map[node.uuid] = name_map[node.name].uuid
            continue

        name_map[node.name] = node

    return [node for node in name_map.values()], uuid_map


async def compress_nodes(
    llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
) -> tuple[list[EntityNode], dict[str, str]]:
    node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]

    results = await asyncio.gather(
        *[dedupe_node_list(llm_client, chunk) for chunk in node_chunks]
    )

    extended_map = dict(uuid_map)
    compressed_nodes: list[EntityNode] = []
    for node_chunk, uuid_map_chunk in results:
        compressed_nodes += node_chunk
        extended_map.update(uuid_map_chunk)

    # Check if we have removed all duplicates
    if len(compressed_nodes) == len(nodes):
        compressed_uuid_map = compress_uuid_map(extended_map)
        return compressed_nodes, compressed_uuid_map

    return await compress_nodes(llm_client, compressed_nodes, extended_map)


async def compress_edges(
    llm_client: LLMClient, edges: list[EntityEdge]
) -> list[EntityEdge]:
    edge_chunks = [edges[i : i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)]

    results = await asyncio.gather(
        *[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks]
    )

    compressed_edges: list[EntityEdge] = []
    for edge_chunk in results:
        compressed_edges += edge_chunk

    # Check if we have removed all duplicates
    if len(compressed_edges) == len(edges):
        return compressed_edges

    return await compress_edges(llm_client, compressed_edges)


def compress_uuid_map(uuid_map: dict[str, str]) -> dict[str, str]:
    # make sure all uuid values aren't mapped to other uuids
    compressed_map = {}
    for key, uuid in uuid_map.items():
        curr_value = uuid
        while curr_value in uuid_map.keys():
            curr_value = uuid_map[curr_value]

        compressed_map[key] = curr_value
    return compressed_map


def resolve_edge_pointers(edges: list[Edge], uuid_map: dict[str, str]):
    for edge in edges:
        source_uuid = edge.source_node_uuid
        target_uuid = edge.target_node_uuid
        edge.source_node_uuid = (
            uuid_map[source_uuid] if source_uuid in uuid_map else source_uuid
        )
        edge.target_node_uuid = (
            uuid_map[target_uuid] if target_uuid in uuid_map else target_uuid
        )

    return edges
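A small illustration of compress_uuid_map from the new file above, which flattens chains of duplicate mappings so every stale uuid points directly at its final surviving node before resolve_edge_pointers re-targets edges; the uuid strings are placeholders.

from core.utils.bulk_utils import compress_uuid_map

# "c" was merged into "b", and "b" into "a"; after compression both map straight to "a".
uuid_map = {"b": "a", "c": "b", "e": "d"}
print(compress_uuid_map(uuid_map))  # {'b': 'a', 'c': 'a', 'e': 'd'}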
@@ -1,6 +1,7 @@
import json
from typing import List
from datetime import datetime
+from time import time

from pydantic import BaseModel

@@ -17,7 +18,7 @@ logger = logging.getLogger(__name__)
def build_episodic_edges(
    entity_nodes: List[EntityNode],
    episode: EpisodicNode,
-    transaction_from: datetime,
+    created_at: datetime,
) -> List[EpisodicEdge]:
    edges: List[EpisodicEdge] = []

@@ -25,7 +26,7 @@ def build_episodic_edges(
        edge = EpisodicEdge(
            source_node_uuid=episode.uuid,
            target_node_uuid=node.uuid,
-            created_at=transaction_from,
+            created_at=created_at,
        )
        edges.append(edge)

@@ -144,6 +145,8 @@ async def extract_edges(
    nodes: list[EntityNode],
    previous_episodes: list[EpisodicNode],
) -> list[EntityEdge]:
+    start = time()
+
    # Prepare context for LLM
    context = {
        "episode_content": episode.content,
@@ -167,7 +170,9 @@ async def extract_edges(
        prompt_library.extract_edges.v2(context)
    )
    edges_data = llm_response.get("edges", [])
-    logger.info(f"Extracted new edges: {edges_data}")
+
+    end = time()
+    logger.info(f"Extracted new edges: {edges_data} in {(end - start) * 1000} ms")

    # Convert the extracted data into EntityEdge objects
    edges = []
@@ -199,11 +204,11 @@ async def dedupe_extracted_edges(
    # Create edge map
    edge_map = {}
    for edge in existing_edges:
-        edge_map[edge.name] = edge
+        edge_map[edge.fact] = edge
    for edge in extracted_edges:
-        if edge.name in edge_map.keys():
+        if edge.fact in edge_map.keys():
            continue
-        edge_map[edge.name] = edge
+        edge_map[edge.fact] = edge

    # Prepare context for LLM
    context = {
@@ -224,7 +229,40 @@ async def dedupe_extracted_edges(
    # Get full edge data
    edges = []
    for edge_data in new_edges_data:
-        edge = edge_map[edge_data["name"]]
+        edge = edge_map[edge_data["fact"]]
        edges.append(edge)

    return edges
+
+
+async def dedupe_edge_list(
+    llm_client: LLMClient,
+    edges: list[EntityEdge],
+) -> list[EntityEdge]:
+    start = time()
+
+    # Create edge map
+    edge_map = {}
+    for edge in edges:
+        edge_map[edge.fact] = edge
+
+    # Prepare context for LLM
+    context = {"edges": [{"name": edge.name, "fact": edge.fact} for edge in edges]}
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.dedupe_edges.edge_list(context)
+    )
+    unique_edges_data = llm_response.get("unique_edges", [])
+
+    end = time()
+    logger.info(
+        f"Extracted edge duplicates: {unique_edges_data} in {(end - start)*1000} ms "
+    )
+
+    # Get full edge data
+    unique_edges = []
+    for edge_data in unique_edges_data:
+        fact = edge_data["fact"]
+        unique_edges.append(edge_map[fact])
+
+    return unique_edges
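A toy illustration (plain dicts instead of EntityEdge objects, with made-up values) of the mapping step that dedupe_extracted_edges and dedupe_edge_list now share: edges are keyed by fact, and the facts returned by the LLM select the surviving edge records.

edges = [
    {"name": "WORKS_AT", "fact": "Alice works at Acme"},
    {"name": "EMPLOYED_BY", "fact": "Alice is employed by Acme"},
]
edge_map = {edge["fact"]: edge for edge in edges}

# Shape of the JSON the edge_list prompt asks the model to return.
llm_response = {"unique_edges": [{"fact": "Alice works at Acme"}]}
unique_edges = [edge_map[d["fact"]] for d in llm_response.get("unique_edges", [])]
print(unique_edges)  # only the surviving edge record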
@@ -4,6 +4,7 @@ from core.nodes import EpisodicNode
from neo4j import AsyncDriver
import logging

+EPISODE_WINDOW_LEN = 3

logger = logging.getLogger(__name__)

@@ -18,11 +19,15 @@ async def clear_data(driver: AsyncDriver):


async def retrieve_episodes(
-    driver: AsyncDriver, last_n: int, sources: list[str] | None = "messages"
+    driver: AsyncDriver,
+    reference_time: datetime,
+    last_n: int,
+    sources: list[str] | None = "messages",
) -> list[EpisodicNode]:
    """Retrieve the last n episodic nodes from the graph"""
-    query = """
+    result = await driver.execute_query(
-        MATCH (e:Episodic)
+        """
+        MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
        RETURN e.content as content,
            e.created_at as created_at,
            e.valid_at as valid_at,
@@ -32,8 +37,10 @@ async def retrieve_episodes(
            e.source as source
        ORDER BY e.created_at DESC
        LIMIT $num_episodes
-        """
+        """,
-    result = await driver.execute_query(query, num_episodes=last_n)
+        reference_time=reference_time,
+        num_episodes=last_n,
+    )
    episodes = [
        EpisodicNode(
            content=record["content"],
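A sketch of the new call shape: retrieve_episodes now takes a reference_time and only returns episodes whose valid_at is at or before it. The driver argument is assumed to be an open neo4j.AsyncDriver, and the wrapper function name is mine.

from datetime import datetime

from core.utils import retrieve_episodes
from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN


async def previous_context(driver, reference_time: datetime):
    # Fetch the EPISODE_WINDOW_LEN most recent episodes valid at reference_time.
    return await retrieve_episodes(driver, reference_time, last_n=EPISODE_WINDOW_LEN)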
@@ -1,4 +1,5 @@
from datetime import datetime
+from time import time

from core.nodes import EntityNode, EpisodicNode
import logging
@@ -68,6 +69,8 @@ async def extract_nodes(
    episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
) -> list[EntityNode]:
+    start = time()
+
    # Prepare context for LLM
    context = {
        "episode_content": episode.content,
@@ -87,7 +90,9 @@ async def extract_nodes(
        prompt_library.extract_nodes.v3(context)
    )
    new_nodes_data = llm_response.get("new_nodes", [])
-    logger.info(f"Extracted new nodes: {new_nodes_data}")
+
+    end = time()
+    logger.info(f"Extracted new nodes: {new_nodes_data} in {(end - start) * 1000} ms")
    # Convert the extracted data into EntityNode objects
    new_nodes = []
    for node_data in new_nodes_data:
@@ -107,15 +112,13 @@ async def dedupe_extracted_nodes(
    llm_client: LLMClient,
    extracted_nodes: list[EntityNode],
    existing_nodes: list[EntityNode],
-) -> list[EntityNode]:
+) -> tuple[list[EntityNode], dict[str, str]]:
-    # build node map
+    start = time()
+
+    # build existing node map
    node_map = {}
    for node in existing_nodes:
        node_map[node.name] = node
-    for node in extracted_nodes:
-        if node.name in node_map.keys():
-            continue
-        node_map[node.name] = node

    # Prepare context for LLM
    existing_nodes_context = [
@@ -132,16 +135,69 @@ async def dedupe_extracted_nodes(
    }

    llm_response = await llm_client.generate_response(
-        prompt_library.dedupe_nodes.v1(context)
+        prompt_library.dedupe_nodes.v2(context)
    )

-    new_nodes_data = llm_response.get("new_nodes", [])
+    duplicate_data = llm_response.get("duplicates", [])
-    logger.info(f"Deduplicated nodes: {new_nodes_data}")
+
+    end = time()
+    logger.info(f"Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms")
+
+    uuid_map = {}
+    for duplicate in duplicate_data:
+        uuid = node_map[duplicate["name"]].uuid
+        uuid_value = node_map[duplicate["duplicate_of"]].uuid
+        uuid_map[uuid] = uuid_value
+
-    # Get full node data
    nodes = []
-    for node_data in new_nodes_data:
+    for node in extracted_nodes:
-        node = node_map[node_data["name"]]
+        if node.uuid in uuid_map:
+            existing_name = uuid_map[node.name]
+            existing_node = node_map[existing_name]
+            nodes.append(existing_node)
+            continue
        nodes.append(node)

-    return nodes
+    return nodes, uuid_map
+
+
+async def dedupe_node_list(
+    llm_client: LLMClient,
+    nodes: list[EntityNode],
+) -> tuple[list[EntityNode], dict[str, str]]:
+    start = time()
+
+    # build node map
+    node_map = {}
+    for node in nodes:
+        node_map[node.name] = node
+
+    # Prepare context for LLM
+    nodes_context = [{"name": node.name, "summary": node.summary} for node in nodes]
+
+    context = {
+        "nodes": nodes_context,
+    }
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.dedupe_nodes.node_list(context)
+    )
+
+    nodes_data = llm_response.get("nodes", [])
+
+    end = time()
+    logger.info(f"Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms")
+
+    # Get full node data
+    unique_nodes = []
+    uuid_map: dict[str, str] = {}
+    for node_data in nodes_data:
+        node = node_map[node_data["names"][0]]
+        unique_nodes.append(node)
+
+        for name in node_data["names"][1:]:
+            uuid = node_map[name].uuid
+            uuid_value = node_map[node_data["names"][0]].uuid
+            uuid_map[uuid] = uuid_value
+
+    return unique_nodes, uuid_map
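A toy version of the bookkeeping dedupe_node_list does with the LLM's grouped names (plain uuid strings stand in for EntityNode objects, and the names are made up): every non-first name in a group is mapped to the first name's uuid, which is what later feeds compress_uuid_map and resolve_edge_pointers.

node_map = {"Acme": "uuid-1", "Acme Corp": "uuid-2", "Alice": "uuid-3"}
nodes_data = [{"names": ["Acme", "Acme Corp"]}, {"names": ["Alice"]}]

uuid_map = {}
for group in nodes_data:
    survivor_uuid = node_map[group["names"][0]]
    for name in group["names"][1:]:
        # Point every duplicate at the uuid of the group's first (kept) node.
        uuid_map[node_map[name]] = survivor_uuid

print(uuid_map)  # {'uuid-2': 'uuid-1'}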
@@ -1,6 +1,7 @@
import asyncio
import logging
from datetime import datetime
+from time import time

from neo4j import AsyncDriver

@@ -9,6 +10,8 @@ from core.nodes import EntityNode

logger = logging.getLogger(__name__)

+RELEVANT_SCHEMA_LIMIT = 3
+
async def bfs(node_ids: list[str], driver: AsyncDriver):
    records, _, _ = await driver.execute_query(
@@ -60,7 +63,7 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):


async def edge_similarity_search(
-    search_vector: list[float], driver: AsyncDriver
+    search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityEdge]:
    # vector similarity search over embedded facts
    records, _, _ = await driver.execute_query(
@@ -80,9 +83,10 @@ async def edge_similarity_search(
        r.expired_at AS expired_at,
        r.valid_at AS valid_at,
        r.invalid_at AS invalid_at
-        ORDER BY score DESC LIMIT 10
+        ORDER BY score DESC LIMIT $limit
        """,
        search_vector=search_vector,
+        limit=limit,
    )

    edges: list[EntityEdge] = []
@@ -106,18 +110,16 @@ async def edge_similarity_search(

        edges.append(edge)

-    logger.info(f"similarity search results. RESULT: {[edge.uuid for edge in edges]}")
-
    return edges


async def entity_similarity_search(
-    search_vector: list[float], driver: AsyncDriver
+    search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
    # vector similarity search over entity names
    records, _, _ = await driver.execute_query(
        """
-        CALL db.index.vector.queryNodes("name_embedding", 5, $search_vector)
+        CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
        YIELD node AS n, score
        RETURN
            n.uuid As uuid,
@@ -127,6 +129,7 @@ async def entity_similarity_search(
        ORDER BY score DESC
        """,
        search_vector=search_vector,
+        limit=limit,
    )
    nodes: list[EntityNode] = []

@@ -141,12 +144,12 @@ async def entity_similarity_search(
            )
        )

-    logger.info(f"name semantic search results. RESULT: {nodes}")

    return nodes


-async def entity_fulltext_search(query: str, driver: AsyncDriver) -> list[EntityNode]:
+async def entity_fulltext_search(
+    query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
+) -> list[EntityNode]:
    # BM25 search to get top nodes
    fuzzy_query = query + "~"
    records, _, _ = await driver.execute_query(
@@ -158,9 +161,10 @@ async def entity_fulltext_search(
        node.created_at AS created_at,
        node.summary AS summary
        ORDER BY score DESC
-        LIMIT 10
+        LIMIT $limit
        """,
        query=fuzzy_query,
+        limit=limit,
    )
    nodes: list[EntityNode] = []

@@ -175,12 +179,12 @@ async def entity_fulltext_search(
            )
        )

-    logger.info(f"fulltext search results. QUERY:{query}. RESULT: {nodes}")

    return nodes


-async def edge_fulltext_search(query: str, driver: AsyncDriver) -> list[EntityEdge]:
+async def edge_fulltext_search(
+    query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
+) -> list[EntityEdge]:
    # fulltext search over facts
    fuzzy_query = query + "~"

@@ -201,9 +205,10 @@ async def edge_fulltext_search(
        r.expired_at AS expired_at,
        r.valid_at AS valid_at,
        r.invalid_at AS invalid_at
-        ORDER BY score DESC LIMIT 10
+        ORDER BY score DESC LIMIT $limit
        """,
        query=fuzzy_query,
+        limit=limit,
    )

    edges: list[EntityEdge] = []
@@ -227,10 +232,6 @@ async def edge_fulltext_search(

        edges.append(edge)

-    logger.info(
-        f"similarity search results. QUERY:{query}. RESULT: {[edge.uuid for edge in edges]}"
-    )
-
    return edges


@@ -238,7 +239,9 @@ async def get_relevant_nodes(
    nodes: list[EntityNode],
    driver: AsyncDriver,
) -> list[EntityNode]:
-    relevant_nodes: dict[str, EntityNode] = {}
+    start = time()
+    relevant_nodes: list[EntityNode] = []
+    relevant_node_uuids = set()

    results = await asyncio.gather(
        *[entity_fulltext_search(node.name, driver) for node in nodes],
@@ -247,18 +250,27 @@ async def get_relevant_nodes(

    for result in results:
        for node in result:
-            relevant_nodes[node.uuid] = node
+            if node.uuid in relevant_node_uuids:
+                continue

-    logger.info(f"Found relevant nodes: {relevant_nodes.keys()}")
+            relevant_node_uuids.add(node.uuid)
+            relevant_nodes.append(node)

-    return relevant_nodes.values()
+    end = time()
+    logger.info(
+        f"Found relevant nodes: {relevant_node_uuids} in {(end - start) * 1000} ms"
+    )
+
+    return relevant_nodes


async def get_relevant_edges(
    edges: list[EntityEdge],
    driver: AsyncDriver,
) -> list[EntityEdge]:
-    relevant_edges: dict[str, EntityEdge] = {}
+    start = time()
+    relevant_edges: list[EntityEdge] = []
+    relevant_edge_uuids = set()

    results = await asyncio.gather(
        *[edge_similarity_search(edge.fact_embedding, driver) for edge in edges],
@@ -267,8 +279,15 @@ async def get_relevant_edges(

    for result in results:
        for edge in result:
-            relevant_edges[edge.uuid] = edge
+            if edge.uuid in relevant_edge_uuids:
+                continue

-    logger.info(f"Found relevant nodes: {relevant_edges.keys()}")
+            relevant_edge_uuids.add(edge.uuid)
+            relevant_edges.append(edge)

-    return list(relevant_edges.values())
+    end = time()
+    logger.info(
+        f"Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms"
+    )
+
+    return relevant_edges
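A sketch of the new limit parameter added to the search helpers above; RELEVANT_SCHEMA_LIMIT (3) is the default, the driver is assumed to be an open neo4j.AsyncDriver, and the query string is a placeholder.

from core.utils.search.search_utils import entity_fulltext_search


async def top_entities(driver):
    # Override the default RELEVANT_SCHEMA_LIMIT of 3 with an explicit limit.
    return await entity_fulltext_search("Acme", driver, limit=5)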
@@ -1,4 +1,5 @@
from core import Graphiti
+from core.utils.bulk_utils import BulkEpisode
from core.utils.maintenance.graph_data_operations import clear_data
from dotenv import load_dotenv
import os
@@ -37,18 +38,33 @@ def setup_logging():
    return logger


-async def main():
+async def main(use_bulk: bool = True):
    setup_logging()
    client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
    await clear_data(client.driver)
    messages = parse_podcast_messages()
-    for i, message in enumerate(messages[3:50]):
-        await client.add_episode(
+    if not use_bulk:
+        for i, message in enumerate(messages[3:14]):
+            await client.add_episode(
+                name=f"Message {i}",
+                episode_body=f"{message.speaker_name} ({message.role}): {message.content}",
+                reference_time=message.actual_timestamp,
+                source_description="Podcast Transcript",
+            )
+
+    episodes: list[BulkEpisode] = [
+        BulkEpisode(
            name=f"Message {i}",
-            episode_body=f"{message.speaker_name} ({message.role}): {message.content}",
+            content=f"{message.speaker_name} ({message.role}): {message.content}",
-            reference_time=message.actual_timestamp,
            source_description="Podcast Transcript",
+            episode_type="string",
+            reference_time=message.actual_timestamp,
        )
+        for i, message in enumerate(messages[3:7])
+    ]
+
+    await client.add_episode_bulk(episodes)


asyncio.run(main())