Create Bulk Add Episode for faster processing (#9)
* benchmark logging
* load schema updates
* add extract bulk nodes and edges
* updated bulk calls
* compression updates
* bulk updates
* bulk logic first pass
* updated bulk process
* debug
* remove exact names first
* cleaned up prompt
* fix bad merge
* update
* fix merge issues
This commit is contained in:
parent a6fd0ddb75
commit d6add504bd
11 changed files with 675 additions and 109 deletions
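
For reviewers, a minimal sketch of how the new bulk path added in this commit is meant to be driven, based on the BulkEpisode model and the runner changes below; the connection URI, credentials, and episode content are placeholders:

import asyncio
from datetime import datetime, timezone

from core import Graphiti
from core.utils.bulk_utils import BulkEpisode


async def demo():
    # placeholder connection details
    client = Graphiti("bolt://localhost:7687", "neo4j", "password")
    episodes = [
        BulkEpisode(
            name="Message 0",
            content="Alice (host): Welcome to the show.",  # illustrative content
            source_description="Podcast Transcript",
            episode_type="string",
            reference_time=datetime.now(timezone.utc),
        )
    ]
    # saves the episodes, then extracts, embeds, dedupes, and persists nodes and edges in bulk
    await client.add_episode_bulk(episodes)


asyncio.run(demo())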
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel, Field
from datetime import datetime
from time import time
from neo4j import AsyncDriver
from uuid import uuid4
import logging

@@ -76,10 +77,15 @@ class EntityEdge(Edge):
    )

    async def generate_embedding(self, embedder, model="text-embedding-3-small"):
        start = time()

        text = self.fact.replace("\n", " ")
        embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
        self.fact_embedding = embedding[:EMBEDDING_DIM]

        end = time()
        logger.info(f"embedded {text} in {end-start} ms")

        return embedding

    async def save(self, driver: AsyncDriver):

@@ -105,6 +111,6 @@ class EntityEdge(Edge):
            invalid_at=self.invalid_at,
        )

        logger.info(f"Saved Node to neo4j: {self.uuid}")
        logger.info(f"Saved edge to neo4j: {self.uuid}")

        return result

core/graphiti.py (185 changed lines)
@@ -4,25 +4,32 @@ import logging
from typing import Callable, LiteralString
from neo4j import AsyncGraphDatabase
from dotenv import load_dotenv
from time import time
import os

from core.llm_client.config import EMBEDDING_DIM
from core.nodes import EntityNode, EpisodicNode, Node
from core.edges import EntityEdge, EpisodicEdge
from core.edges import EntityEdge, Edge, EpisodicEdge
from core.utils import (
    build_episodic_edges,
    retrieve_episodes,
)
from core.llm_client import LLMClient, OpenAIClient, LLMConfig
from core.utils.maintenance.edge_operations import (
    extract_edges,
    dedupe_extracted_edges,
from core.utils.bulk_utils import (
    BulkEpisode,
    extract_nodes_and_edges_bulk,
    retrieve_previous_episodes_bulk,
    compress_nodes,
    dedupe_nodes_bulk,
    resolve_edge_pointers,
    dedupe_edges_bulk,
)

from core.utils.maintenance.edge_operations import extract_edges, dedupe_extracted_edges
from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
from core.utils.maintenance.temporal_operations import (
    prepare_edges_for_invalidation,
    invalidate_edges,
    prepare_edges_for_invalidation,
)
from core.utils.search.search_utils import (
    edge_similarity_search,
@@ -58,30 +65,47 @@ class Graphiti:
        self.driver.close()

    async def retrieve_episodes(
        self, last_n: int, sources: list[str] | None = "messages"
        self,
        reference_time: datetime,
        last_n: int,
        sources: list[str] | None = "messages",
    ) -> list[EpisodicNode]:
        """Retrieve the last n episodic nodes from the graph"""
        return await retrieve_episodes(self.driver, last_n, sources)
        return await retrieve_episodes(self.driver, reference_time, last_n, sources)

    # Invalidate edges that are no longer valid
    async def invalidate_edges(
        self,
        episode: EpisodicNode,
        new_nodes: list[EntityNode],
        new_edges: list[EntityEdge],
        relevant_schema: dict[str, any],
        previous_episodes: list[EpisodicNode],
    ): ...

    async def add_episode(
        self,
        name: str,
        episode_body: str,
        source_description: str,
        reference_time: datetime = None,
        reference_time: datetime,
        episode_type="string",
        success_callback: Callable | None = None,
        error_callback: Callable | None = None,
    ):
        """Process an episode and update the graph"""
        try:
            start = time()

            nodes: list[EntityNode] = []
            entity_edges: list[EntityEdge] = []
            episodic_edges: list[EpisodicEdge] = []
            embedder = self.llm_client.client.embeddings
            now = datetime.now()

            previous_episodes = await self.retrieve_episodes(last_n=3)
            previous_episodes = await self.retrieve_episodes(
                reference_time, last_n=EPISODE_WINDOW_LEN
            )
            episode = EpisodicNode(
                name=name,
                labels=[],

@@ -105,7 +129,7 @@ class Graphiti:
            logger.info(
                f"Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}"
            )
            new_nodes = await dedupe_extracted_nodes(
            new_nodes, _ = await dedupe_extracted_nodes(
                self.llm_client, extracted_nodes, existing_nodes
            )
            logger.info(

@@ -151,8 +175,15 @@ class Graphiti:
            )

            logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}")

            entity_edges.extend(deduped_edges)

            new_edges = await dedupe_extracted_edges(
                self.llm_client, extracted_edges, existing_edges
            )

            logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in new_edges]}")

            entity_edges.extend(new_edges)
            episodic_edges.extend(
                build_episodic_edges(
            # There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
@@ -175,6 +206,9 @@ class Graphiti:
            await asyncio.gather(*[node.save(self.driver) for node in nodes])
            await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
            await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])

            end = time()
            logger.info(f"Completed add_episode in {(end-start) * 1000} ms")
            # for node in nodes:
            #     if isinstance(node, EntityNode):
            #         await node.update_summary(self.driver)
@@ -190,36 +224,19 @@ class Graphiti:
        index_queries: list[LiteralString] = [
            "CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)",
            "CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)",
            "CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.uuid)",
            "CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[r:MENTIONS]-() ON (r.uuid)",
            "CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)",
            "CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)",
            "CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
            "CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
            "CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)",
            "CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)",
            "CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.name)",
            "CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.created_at)",
            "CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.expired_at)",
            "CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.valid_at)",
            "CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.invalid_at)",
        ]
        # Add the range indices
        for query in index_queries:
            await self.driver.execute_query(query)

        # Add the semantic indices
        await self.driver.execute_query(
            """
            CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]
            """
        )

        await self.driver.execute_query(
            """
            CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON EACH [r.name, r.fact]
            """
        )

        await self.driver.execute_query(
            "CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)",
            "CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)",
            "CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)",
            "CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)",
            "CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)",
            "CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]",
            "CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]",
            """
            CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
            FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding)

@@ -227,10 +244,7 @@ class Graphiti:
            `vector.dimensions`: 1024,
            `vector.similarity_function`: 'cosine'
            }}
            """
        )

        await self.driver.execute_query(
            """,
            """
            CREATE VECTOR INDEX name_embedding IF NOT EXISTS
            FOR (n:Entity) ON (n.name_embedding)

@@ -238,7 +252,19 @@ class Graphiti:
            `vector.dimensions`: 1024,
            `vector.similarity_function`: 'cosine'
            }}
            """,
            """
            CREATE CONSTRAINT entity_name IF NOT EXISTS
            FOR (n:Entity) REQUIRE n.name IS UNIQUE
            """,
            """
            CREATE CONSTRAINT edge_facts IF NOT EXISTS
            FOR ()-[e:RELATES_TO]-() REQUIRE e.fact IS UNIQUE
            """,
        ]

        await asyncio.gather(
            *[self.driver.execute_query(query) for query in index_queries]
        )

    async def search(self, query: str) -> list[tuple[EntityNode, list[EntityEdge]]]:
|||
context = await bfs(node_ids, self.driver)
|
||||
|
||||
return context
|
||||
|
||||
async def add_episode_bulk(
|
||||
self,
|
||||
bulk_episodes: list[BulkEpisode],
|
||||
):
|
||||
try:
|
||||
start = time()
|
||||
embedder = self.llm_client.client.embeddings
|
||||
now = datetime.now()
|
||||
|
||||
episodes = [
|
||||
EpisodicNode(
|
||||
name=episode.name,
|
||||
labels=[],
|
||||
source="messages",
|
||||
content=episode.content,
|
||||
source_description=episode.source_description,
|
||||
created_at=now,
|
||||
valid_at=episode.reference_time,
|
||||
)
|
||||
for episode in bulk_episodes
|
||||
]
|
||||
|
||||
# Save all the episodes
|
||||
await asyncio.gather(*[episode.save(self.driver) for episode in episodes])
|
||||
|
||||
# Get previous episode context for each episode
|
||||
episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
|
||||
|
||||
# Extract all nodes and edges
|
||||
extracted_nodes, extracted_edges, episodic_edges = (
|
||||
await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs)
|
||||
)
|
||||
|
||||
# Generate embeddings
|
||||
await asyncio.gather(
|
||||
*[node.generate_name_embedding(embedder) for node in extracted_nodes],
|
||||
*[edge.generate_embedding(embedder) for edge in extracted_edges],
|
||||
)
|
||||
|
||||
# Dedupe extracted nodes
|
||||
nodes, uuid_map = await dedupe_nodes_bulk(
|
||||
self.driver, self.llm_client, extracted_nodes
|
||||
)
|
||||
|
||||
# save nodes to KG
|
||||
await asyncio.gather(*[node.save(self.driver) for node in nodes])
|
||||
|
||||
# re-map edge pointers so that they don't point to discard dupe nodes
|
||||
extracted_edges: list[EntityEdge] = resolve_edge_pointers(
|
||||
extracted_edges, uuid_map
|
||||
)
|
||||
episodic_edges: list[EpisodicEdge] = resolve_edge_pointers(
|
||||
episodic_edges, uuid_map
|
||||
)
|
||||
|
||||
# save episodic edges to KG
|
||||
await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
|
||||
|
||||
# Dedupe extracted edges
|
||||
edges = await dedupe_edges_bulk(
|
||||
self.driver, self.llm_client, extracted_edges
|
||||
)
|
||||
logger.info(f"extracted edge length: {len(edges)}")
|
||||
|
||||
# invalidate edges
|
||||
|
||||
# save edges to KG
|
||||
await asyncio.gather(*[edge.save(self.driver) for edge in edges])
|
||||
|
||||
end = time()
|
||||
logger.info(f"Completed add_episode_bulk in {(end-start) * 1000} ms")
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from pydantic import Field
from time import time
from datetime import datetime
from uuid import uuid4


@@ -35,14 +35,13 @@ class EpisodicNode(Node):
    source: str = Field(description="source type")
    source_description: str = Field(description="description of the data source")
    content: str = Field(description="raw episode data")
    valid_at: datetime = Field(
        description="datetime of when the original document was created",
    )
    entity_edges: list[str] = Field(
        description="list of entity edges referenced in this episode",
        default_factory=list,
    )
    valid_at: datetime | None = Field(
        description="datetime of when the original document was created",
        default=None,
    )

    async def save(self, driver: AsyncDriver):
        result = await driver.execute_query(

@@ -80,9 +79,12 @@ class EntityNode(Node):
    async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...

    async def generate_name_embedding(self, embedder, model="text-embedding-3-small"):
        start = time()
        text = self.name.replace("\n", " ")
        embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
        self.name_embedding = embedding[:EMBEDDING_DIM]
        end = time()
        logger.info(f"embedded {text} in {end-start} ms")

        return embedding
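
A minimal sketch of how these embedding helpers are driven; Graphiti passes self.llm_client.client.embeddings in as the embedder, and the AsyncOpenAI client below is an assumption standing in for that object:

import asyncio
from openai import AsyncOpenAI  # assumed client; matches the awaited embedder.create(...) call above

async def embed_extracted(nodes, edges):
    embedder = AsyncOpenAI().embeddings  # stand-in for self.llm_client.client.embeddings
    # name and fact embeddings are generated concurrently, as add_episode_bulk does
    await asyncio.gather(
        *[node.generate_name_embedding(embedder) for node in nodes],
        *[edge.generate_embedding(embedder) for edge in edges],
    )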

@@ -6,10 +6,12 @@ from .models import Message, PromptVersion, PromptFunction

class Prompt(Protocol):
    v1: PromptVersion
    edge_list: PromptVersion


class Versions(TypedDict):
    v1: PromptFunction
    edge_list: PromptFunction


def v1(context: dict[str, any]) -> list[Message]:

@@ -43,7 +45,6 @@ def v1(context: dict[str, any]) -> list[Message]:
        {{
            "new_edges": [
                {{
                    "name": "Unique identifier for the edge",
                    "fact": "one sentence description of the fact"
                }}
            ]

@@ -53,4 +54,40 @@ def v1(context: dict[str, any]) -> list[Message]:
    ]


versions: Versions = {"v1": v1}
def edge_list(context: dict[str, any]) -> list[Message]:
    return [
        Message(
            role="system",
            content="You are a helpful assistant that de-duplicates edges from edge lists.",
        ),
        Message(
            role="user",
            content=f"""
        Given the following context, find all of the duplicates in a list of edges:

        Edges:
        {json.dumps(context['edges'], indent=2)}

        Task:
        If any edge in Edges is a duplicate of another edge, return the fact of only one of the duplicate edges

        Guidelines:
        1. Use both the name and fact of edges to determine if they are duplicates,
        edges with the same name may not be duplicates
        2. The final list should have only unique facts. If 3 edges are all duplicates of each other, only one of their
        facts should be in the response

        Respond with a JSON object in the following format:
        {{
            "unique_edges": [
                {{
                    "fact": "fact of a unique edge",
                }}
            ]
        }}
        """,
        ),
    ]


versions: Versions = {"v1": v1, "edge_list": edge_list}
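
For orientation, a minimal sketch of the context this new edge_list prompt consumes and the response shape dedupe_edge_list expects back; the facts below are made up:

context = {
    "edges": [
        {"name": "LIKES", "fact": "Alice likes coffee"},
        {"name": "ENJOYS", "fact": "Alice enjoys coffee"},  # duplicate fact, phrased differently
    ]
}
messages = prompt_library.dedupe_edges.edge_list(context)
# expected LLM response shape:
# {"unique_edges": [{"fact": "Alice likes coffee"}]}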

@@ -6,10 +6,14 @@ from .models import Message, PromptVersion, PromptFunction

class Prompt(Protocol):
    v1: PromptVersion
    v2: PromptVersion
    node_list: PromptVersion


class Versions(TypedDict):
    v1: PromptFunction
    v2: PromptFunction
    node_list: PromptVersion


def v1(context: dict[str, any]) -> list[Message]:

@@ -44,7 +48,6 @@ def v1(context: dict[str, any]) -> list[Message]:
            "new_nodes": [
                {{
                    "name": "Unique identifier for the node",
                    "summary": "Brief summary of the node's role or significance"
                }}
            ]
        }}

@@ -53,4 +56,79 @@ def v1(context: dict[str, any]) -> list[Message]:
    ]


versions: Versions = {"v1": v1}
def v2(context: dict[str, any]) -> list[Message]:
    return [
        Message(
            role="system",
            content="You are a helpful assistant that de-duplicates nodes from node lists.",
        ),
        Message(
            role="user",
            content=f"""
        Given the following context, deduplicate nodes from a list of new nodes given a list of existing nodes:

        Existing Nodes:
        {json.dumps(context['existing_nodes'], indent=2)}

        New Nodes:
        {json.dumps(context['extracted_nodes'], indent=2)}

        Task:
        If any node in New Nodes is a duplicate of a node in Existing Nodes, add their names to the output list

        Guidelines:
        1. Use both the name and summary of nodes to determine if they are duplicates,
        duplicate nodes may have different names
        2. In the output, name should always be the name of the New Node that is a duplicate. duplicate_of should be
        the name of the Existing Node.

        Respond with a JSON object in the following format:
        {{
            "duplicates": [
                {{
                    "name": "name of the new node",
                    "duplicate_of": "name of the existing node"
                }}
            ]
        }}
        """,
        ),
    ]


def node_list(context: dict[str, any]) -> list[Message]:
    return [
        Message(
            role="system",
            content="You are a helpful assistant that de-duplicates nodes from node lists.",
        ),
        Message(
            role="user",
            content=f"""
        Given the following context, deduplicate a list of nodes:

        Nodes:
        {json.dumps(context['nodes'], indent=2)}

        Task:
        1. Group nodes together such that all duplicate nodes are in the same list of names
        2. All duplicate names should be grouped together in the same list

        Guidelines:
        1. Each name from the list of nodes should appear EXACTLY once in your response
        2. If a node has no duplicates, it should appear in the response in a list of only one name

        Respond with a JSON object in the following format:
        {{
            "nodes": [
                {{
                    "names": ["myNode", "node that is a duplicate of myNode"],
                }}
            ]
        }}
        """,
        ),
    ]


versions: Versions = {"v1": v1, "v2": v2, "node_list": node_list}

core/utils/bulk_utils.py (new file, 206 lines)
@@ -0,0 +1,206 @@
import asyncio
from collections import defaultdict
from datetime import datetime

from neo4j import AsyncDriver
from pydantic import BaseModel

from core.edges import EpisodicEdge, EntityEdge, Edge
from core.llm_client import LLMClient
from core.nodes import EpisodicNode, EntityNode
from core.utils import retrieve_episodes
from core.utils.maintenance.edge_operations import (
    extract_edges,
    build_episodic_edges,
    dedupe_edge_list,
    dedupe_extracted_edges,
)
from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
from core.utils.maintenance.node_operations import (
    extract_nodes,
    dedupe_node_list,
    dedupe_extracted_nodes,
)
from core.utils.search.search_utils import get_relevant_nodes, get_relevant_edges

CHUNK_SIZE = 10


class BulkEpisode(BaseModel):
    name: str
    content: str
    source_description: str
    episode_type: str
    reference_time: datetime


async def retrieve_previous_episodes_bulk(
    driver: AsyncDriver, episodes: list[EpisodicNode]
) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
    previous_episodes_list = await asyncio.gather(
        *[
            retrieve_episodes(driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN)
            for episode in episodes
        ]
    )
    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] = [
        (episode, previous_episodes_list[i]) for i, episode in enumerate(episodes)
    ]

    return episode_tuples


async def extract_nodes_and_edges_bulk(
    llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
    extracted_nodes_bulk = await asyncio.gather(
        *[
            extract_nodes(llm_client, episode, previous_episodes)
            for episode, previous_episodes in episode_tuples
        ]
    )

    episodes, previous_episodes_list = [episode[0] for episode in episode_tuples], [
        episode[1] for episode in episode_tuples
    ]

    extracted_edges_bulk = await asyncio.gather(
        *[
            extract_edges(
                llm_client, episode, extracted_nodes_bulk[i], previous_episodes_list[i]
            )
            for i, episode in enumerate(episodes)
        ]
    )

    episodic_edges: list[EpisodicEdge] = []
    for i, episode in enumerate(episodes):
        episodic_edges += build_episodic_edges(
            extracted_nodes_bulk[i], episode, episode.created_at
        )

    nodes: list[EntityNode] = []
    for extracted_nodes in extracted_nodes_bulk:
        nodes += extracted_nodes

    edges: list[EntityEdge] = []
    for extracted_edges in extracted_edges_bulk:
        edges += extracted_edges

    return nodes, edges, episodic_edges


async def dedupe_nodes_bulk(
    driver: AsyncDriver,
    llm_client: LLMClient,
    extracted_nodes: list[EntityNode],
) -> tuple[list[EntityNode], dict[str, str]]:
    # Compress nodes
    nodes, uuid_map = node_name_match(extracted_nodes)

    compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map)

    existing_nodes = await get_relevant_nodes(compressed_nodes, driver)

    nodes, partial_uuid_map = await dedupe_extracted_nodes(
        llm_client, compressed_nodes, existing_nodes
    )

    compressed_map.update(partial_uuid_map)

    return nodes, compressed_map


async def dedupe_edges_bulk(
    driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
) -> list[EntityEdge]:
    # Compress edges
    compressed_edges = await compress_edges(llm_client, extracted_edges)

    existing_edges = await get_relevant_edges(compressed_edges, driver)

    edges = await dedupe_extracted_edges(llm_client, compressed_edges, existing_edges)

    return edges


def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str, str]]:
    uuid_map = {}
    name_map = {}
    for node in nodes:
        if node.name in name_map:
            uuid_map[node.uuid] = name_map[node.name].uuid
            continue

        name_map[node.name] = node

    return [node for node in name_map.values()], uuid_map
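
A small illustration of what node_name_match returns; the names and uuids are made up. Two extracted nodes sharing the exact same name collapse to the first one seen, and uuid_map records the redirect:

# input:  EntityNode(name="Alice", uuid="a1"), EntityNode(name="Alice", uuid="a2"), EntityNode(name="Bob", uuid="b1")
# output: nodes    -> [Alice (uuid "a1"), Bob (uuid "b1")]
#         uuid_map -> {"a2": "a1"}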

async def compress_nodes(
    llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
) -> tuple[list[EntityNode], dict[str, str]]:
    node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]

    results = await asyncio.gather(
        *[dedupe_node_list(llm_client, chunk) for chunk in node_chunks]
    )

    extended_map = dict(uuid_map)
    compressed_nodes: list[EntityNode] = []
    for node_chunk, uuid_map_chunk in results:
        compressed_nodes += node_chunk
        extended_map.update(uuid_map_chunk)

    # Check if we have removed all duplicates
    if len(compressed_nodes) == len(nodes):
        compressed_uuid_map = compress_uuid_map(extended_map)
        return compressed_nodes, compressed_uuid_map

    return await compress_nodes(llm_client, compressed_nodes, extended_map)


async def compress_edges(
    llm_client: LLMClient, edges: list[EntityEdge]
) -> list[EntityEdge]:
    edge_chunks = [edges[i : i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)]

    results = await asyncio.gather(
        *[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks]
    )

    compressed_edges: list[EntityEdge] = []
    for edge_chunk in results:
        compressed_edges += edge_chunk

    # Check if we have removed all duplicates
    if len(compressed_edges) == len(edges):
        return compressed_edges

    return await compress_edges(llm_client, compressed_edges)


def compress_uuid_map(uuid_map: dict[str, str]) -> dict[str, str]:
    # make sure all uuid values aren't mapped to other uuids
    compressed_map = {}
    for key, uuid in uuid_map.items():
        curr_value = uuid
        while curr_value in uuid_map.keys():
            curr_value = uuid_map[curr_value]

        compressed_map[key] = curr_value
    return compressed_map
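
A quick illustration of the flattening compress_uuid_map performs; the uuids are shortened for readability:

# input:  {"c": "b", "b": "a"}   # c was deduped onto b, and b onto a
# output: {"c": "a", "b": "a"}   # every key now points directly at the surviving uuid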

def resolve_edge_pointers(edges: list[Edge], uuid_map: dict[str, str]):
    for edge in edges:
        source_uuid = edge.source_node_uuid
        target_uuid = edge.target_node_uuid
        edge.source_node_uuid = (
            uuid_map[source_uuid] if source_uuid in uuid_map else source_uuid
        )
        edge.target_node_uuid = (
            uuid_map[target_uuid] if target_uuid in uuid_map else target_uuid
        )

    return edges
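
Continuing the toy map above, a sketch of how the bulk path uses this to repoint extracted edges at surviving nodes; the uuid_map here is hypothetical:

edges = resolve_edge_pointers(extracted_edges, {"c": "a", "b": "a"})
# any edge whose source_node_uuid or target_node_uuid was "c" or "b" now points at "a";
# uuids that are not keys of the map are left untouched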

@@ -1,6 +1,7 @@
import json
from typing import List
from datetime import datetime
from time import time

from pydantic import BaseModel

@@ -17,7 +18,7 @@ logger = logging.getLogger(__name__)
def build_episodic_edges(
    entity_nodes: List[EntityNode],
    episode: EpisodicNode,
    transaction_from: datetime,
    created_at: datetime,
) -> List[EpisodicEdge]:
    edges: List[EpisodicEdge] = []

@@ -25,7 +26,7 @@ def build_episodic_edges(
        edge = EpisodicEdge(
            source_node_uuid=episode.uuid,
            target_node_uuid=node.uuid,
            created_at=transaction_from,
            created_at=created_at,
        )
        edges.append(edge)

@@ -144,6 +145,8 @@ async def extract_edges(
    nodes: list[EntityNode],
    previous_episodes: list[EpisodicNode],
) -> list[EntityEdge]:
    start = time()

    # Prepare context for LLM
    context = {
        "episode_content": episode.content,

@@ -167,7 +170,9 @@ async def extract_edges(
        prompt_library.extract_edges.v2(context)
    )
    edges_data = llm_response.get("edges", [])
    logger.info(f"Extracted new edges: {edges_data}")

    end = time()
    logger.info(f"Extracted new edges: {edges_data} in {(end - start) * 1000} ms")

    # Convert the extracted data into EntityEdge objects
    edges = []

@@ -199,11 +204,11 @@ async def dedupe_extracted_edges(
    # Create edge map
    edge_map = {}
    for edge in existing_edges:
        edge_map[edge.name] = edge
        edge_map[edge.fact] = edge
    for edge in extracted_edges:
        if edge.name in edge_map.keys():
        if edge.fact in edge_map.keys():
            continue
        edge_map[edge.name] = edge
        edge_map[edge.fact] = edge

    # Prepare context for LLM
    context = {

@@ -224,7 +229,40 @@ async def dedupe_extracted_edges(
    # Get full edge data
    edges = []
    for edge_data in new_edges_data:
        edge = edge_map[edge_data["name"]]
        edge = edge_map[edge_data["fact"]]
        edges.append(edge)

    return edges


async def dedupe_edge_list(
    llm_client: LLMClient,
    edges: list[EntityEdge],
) -> list[EntityEdge]:
    start = time()

    # Create edge map
    edge_map = {}
    for edge in edges:
        edge_map[edge.fact] = edge

    # Prepare context for LLM
    context = {"edges": [{"name": edge.name, "fact": edge.fact} for edge in edges]}

    llm_response = await llm_client.generate_response(
        prompt_library.dedupe_edges.edge_list(context)
    )
    unique_edges_data = llm_response.get("unique_edges", [])

    end = time()
    logger.info(
        f"Extracted edge duplicates: {unique_edges_data} in {(end - start)*1000} ms "
    )

    # Get full edge data
    unique_edges = []
    for edge_data in unique_edges_data:
        fact = edge_data["fact"]
        unique_edges.append(edge_map[fact])

    return unique_edges
@@ -4,6 +4,7 @@ from core.nodes import EpisodicNode
from neo4j import AsyncDriver
import logging

EPISODE_WINDOW_LEN = 3

logger = logging.getLogger(__name__)

@@ -18,11 +19,15 @@ async def clear_data(driver: AsyncDriver):


async def retrieve_episodes(
    driver: AsyncDriver, last_n: int, sources: list[str] | None = "messages"
    driver: AsyncDriver,
    reference_time: datetime,
    last_n: int,
    sources: list[str] | None = "messages",
) -> list[EpisodicNode]:
    """Retrieve the last n episodic nodes from the graph"""
    query = """
        MATCH (e:Episodic)
    result = await driver.execute_query(
        """
        MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
        RETURN e.content as content,
            e.created_at as created_at,
            e.valid_at as valid_at,

@@ -32,8 +37,10 @@ async def retrieve_episodes(
            e.source as source
        ORDER BY e.created_at DESC
        LIMIT $num_episodes
        """
    result = await driver.execute_query(query, num_episodes=last_n)
        """,
        reference_time=reference_time,
        num_episodes=last_n,
    )
    episodes = [
        EpisodicNode(
            content=record["content"],
@@ -1,4 +1,5 @@
from datetime import datetime
from time import time

from core.nodes import EntityNode, EpisodicNode
import logging

@@ -68,6 +69,8 @@ async def extract_nodes(
    episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
) -> list[EntityNode]:
    start = time()

    # Prepare context for LLM
    context = {
        "episode_content": episode.content,

@@ -87,7 +90,9 @@ async def extract_nodes(
        prompt_library.extract_nodes.v3(context)
    )
    new_nodes_data = llm_response.get("new_nodes", [])
    logger.info(f"Extracted new nodes: {new_nodes_data}")

    end = time()
    logger.info(f"Extracted new nodes: {new_nodes_data} in {(end - start) * 1000} ms")
    # Convert the extracted data into EntityNode objects
    new_nodes = []
    for node_data in new_nodes_data:

@@ -107,15 +112,13 @@ async def dedupe_extracted_nodes(
    llm_client: LLMClient,
    extracted_nodes: list[EntityNode],
    existing_nodes: list[EntityNode],
) -> list[EntityNode]:
    # build node map
) -> tuple[list[EntityNode], dict[str, str]]:
    start = time()

    # build existing node map
    node_map = {}
    for node in existing_nodes:
        node_map[node.name] = node
    for node in extracted_nodes:
        if node.name in node_map.keys():
            continue
        node_map[node.name] = node

    # Prepare context for LLM
    existing_nodes_context = [

@@ -132,16 +135,69 @@ async def dedupe_extracted_nodes(
    }

    llm_response = await llm_client.generate_response(
        prompt_library.dedupe_nodes.v1(context)
        prompt_library.dedupe_nodes.v2(context)
    )

    new_nodes_data = llm_response.get("new_nodes", [])
    logger.info(f"Deduplicated nodes: {new_nodes_data}")
    duplicate_data = llm_response.get("duplicates", [])

    end = time()
    logger.info(f"Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms")

    uuid_map = {}
    for duplicate in duplicate_data:
        uuid = node_map[duplicate["name"]].uuid
        uuid_value = node_map[duplicate["duplicate_of"]].uuid
        uuid_map[uuid] = uuid_value

    # Get full node data
    nodes = []
    for node_data in new_nodes_data:
        node = node_map[node_data["name"]]
    for node in extracted_nodes:
        if node.uuid in uuid_map:
            existing_name = uuid_map[node.name]
            existing_node = node_map[existing_name]
            nodes.append(existing_node)
            continue
        nodes.append(node)

    return nodes
    return nodes, uuid_map


async def dedupe_node_list(
    llm_client: LLMClient,
    nodes: list[EntityNode],
) -> tuple[list[EntityNode], dict[str, str]]:
    start = time()

    # build node map
    node_map = {}
    for node in nodes:
        node_map[node.name] = node

    # Prepare context for LLM
    nodes_context = [{"name": node.name, "summary": node.summary} for node in nodes]

    context = {
        "nodes": nodes_context,
    }

    llm_response = await llm_client.generate_response(
        prompt_library.dedupe_nodes.node_list(context)
    )

    nodes_data = llm_response.get("nodes", [])

    end = time()
    logger.info(f"Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms")

    # Get full node data
    unique_nodes = []
    uuid_map: dict[str, str] = {}
    for node_data in nodes_data:
        node = node_map[node_data["names"][0]]
        unique_nodes.append(node)

        for name in node_data["names"][1:]:
            uuid = node_map[name].uuid
            uuid_value = node_map[node_data["names"][0]].uuid
            uuid_map[uuid] = uuid_value

    return unique_nodes, uuid_map
@@ -1,6 +1,7 @@
import asyncio
import logging
from datetime import datetime
from time import time

from neo4j import AsyncDriver

@@ -9,6 +10,8 @@ from core.nodes import EntityNode

logger = logging.getLogger(__name__)

RELEVANT_SCHEMA_LIMIT = 3


async def bfs(node_ids: list[str], driver: AsyncDriver):
    records, _, _ = await driver.execute_query(

@@ -60,7 +63,7 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):


async def edge_similarity_search(
    search_vector: list[float], driver: AsyncDriver
    search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityEdge]:
    # vector similarity search over embedded facts
    records, _, _ = await driver.execute_query(

@@ -80,9 +83,10 @@ async def edge_similarity_search(
            r.expired_at AS expired_at,
            r.valid_at AS valid_at,
            r.invalid_at AS invalid_at
        ORDER BY score DESC LIMIT 10
        ORDER BY score DESC LIMIT $limit
        """,
        search_vector=search_vector,
        limit=limit,
    )

    edges: list[EntityEdge] = []

@@ -106,18 +110,16 @@ async def edge_similarity_search(

        edges.append(edge)

    logger.info(f"similarity search results. RESULT: {[edge.uuid for edge in edges]}")

    return edges


async def entity_similarity_search(
    search_vector: list[float], driver: AsyncDriver
    search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
    # vector similarity search over entity names
    records, _, _ = await driver.execute_query(
        """
        CALL db.index.vector.queryNodes("name_embedding", 5, $search_vector)
        CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
        YIELD node AS n, score
        RETURN
            n.uuid As uuid,

@@ -127,6 +129,7 @@ async def entity_similarity_search(
        ORDER BY score DESC
        """,
        search_vector=search_vector,
        limit=limit,
    )
    nodes: list[EntityNode] = []

@@ -141,12 +144,12 @@ async def entity_similarity_search(
            )
        )

    logger.info(f"name semantic search results. RESULT: {nodes}")

    return nodes


async def entity_fulltext_search(query: str, driver: AsyncDriver) -> list[EntityNode]:
async def entity_fulltext_search(
    query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
    # BM25 search to get top nodes
    fuzzy_query = query + "~"
    records, _, _ = await driver.execute_query(

@@ -158,9 +161,10 @@ async def entity_fulltext_search(query: str, driver: AsyncDriver) -> list[Entity
            node.created_at AS created_at,
            node.summary AS summary
        ORDER BY score DESC
        LIMIT 10
        LIMIT $limit
        """,
        query=fuzzy_query,
        limit=limit,
    )
    nodes: list[EntityNode] = []

@@ -175,12 +179,12 @@ async def entity_fulltext_search(query: str, driver: AsyncDriver) -> list[Entity
            )
        )

    logger.info(f"fulltext search results. QUERY:{query}. RESULT: {nodes}")

    return nodes


async def edge_fulltext_search(query: str, driver: AsyncDriver) -> list[EntityEdge]:
async def edge_fulltext_search(
    query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityEdge]:
    # fulltext search over facts
    fuzzy_query = query + "~"

@@ -201,9 +205,10 @@ async def edge_fulltext_search(query: str, driver: AsyncDriver) -> list[EntityEd
            r.expired_at AS expired_at,
            r.valid_at AS valid_at,
            r.invalid_at AS invalid_at
        ORDER BY score DESC LIMIT 10
        ORDER BY score DESC LIMIT $limit
        """,
        query=fuzzy_query,
        limit=limit,
    )

    edges: list[EntityEdge] = []

@@ -227,10 +232,6 @@ async def edge_fulltext_search(query: str, driver: AsyncDriver) -> list[EntityEd

        edges.append(edge)

    logger.info(
        f"similarity search results. QUERY:{query}. RESULT: {[edge.uuid for edge in edges]}"
    )

    return edges


@@ -238,7 +239,9 @@ async def get_relevant_nodes(
    nodes: list[EntityNode],
    driver: AsyncDriver,
) -> list[EntityNode]:
    relevant_nodes: dict[str, EntityNode] = {}
    start = time()
    relevant_nodes: list[EntityNode] = []
    relevant_node_uuids = set()

    results = await asyncio.gather(
        *[entity_fulltext_search(node.name, driver) for node in nodes],

@@ -247,18 +250,27 @@ async def get_relevant_nodes(

    for result in results:
        for node in result:
            relevant_nodes[node.uuid] = node
            if node.uuid in relevant_node_uuids:
                continue

    logger.info(f"Found relevant nodes: {relevant_nodes.keys()}")
            relevant_node_uuids.add(node.uuid)
            relevant_nodes.append(node)

    return relevant_nodes.values()
    end = time()
    logger.info(
        f"Found relevant nodes: {relevant_node_uuids} in {(end - start) * 1000} ms"
    )

    return relevant_nodes


async def get_relevant_edges(
    edges: list[EntityEdge],
    driver: AsyncDriver,
) -> list[EntityEdge]:
    relevant_edges: dict[str, EntityEdge] = {}
    start = time()
    relevant_edges: list[EntityEdge] = []
    relevant_edge_uuids = set()

    results = await asyncio.gather(
        *[edge_similarity_search(edge.fact_embedding, driver) for edge in edges],

@@ -267,8 +279,15 @@ async def get_relevant_edges(

    for result in results:
        for edge in result:
            relevant_edges[edge.uuid] = edge
            if edge.uuid in relevant_edge_uuids:
                continue

    logger.info(f"Found relevant nodes: {relevant_edges.keys()}")
            relevant_edge_uuids.add(edge.uuid)
            relevant_edges.append(edge)

    return list(relevant_edges.values())
    end = time()
    logger.info(
        f"Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms"
    )

    return relevant_edges
@@ -1,4 +1,5 @@
from core import Graphiti
from core.utils.bulk_utils import BulkEpisode
from core.utils.maintenance.graph_data_operations import clear_data
from dotenv import load_dotenv
import os

@@ -37,18 +38,33 @@ def setup_logging():
    return logger


async def main():
async def main(use_bulk: bool = True):
    setup_logging()
    client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
    await clear_data(client.driver)
    messages = parse_podcast_messages()
    for i, message in enumerate(messages[3:50]):
        await client.add_episode(

    if not use_bulk:
        for i, message in enumerate(messages[3:14]):
            await client.add_episode(
                name=f"Message {i}",
                episode_body=f"{message.speaker_name} ({message.role}): {message.content}",
                reference_time=message.actual_timestamp,
                source_description="Podcast Transcript",
            )

    episodes: list[BulkEpisode] = [
        BulkEpisode(
            name=f"Message {i}",
            episode_body=f"{message.speaker_name} ({message.role}): {message.content}",
            reference_time=message.actual_timestamp,
            content=f"{message.speaker_name} ({message.role}): {message.content}",
            source_description="Podcast Transcript",
            episode_type="string",
            reference_time=message.actual_timestamp,
        )
        for i, message in enumerate(messages[3:7])
    ]

    await client.add_episode_bulk(episodes)


asyncio.run(main())
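
The runner now defaults to the bulk path; to exercise the original per-episode path instead, the same entry point can be called with the flag added above:

asyncio.run(main(use_bulk=False))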