diff --git a/core/edges.py b/core/edges.py
index fee458f9..1f89ec29 100644
--- a/core/edges.py
+++ b/core/edges.py
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 from pydantic import BaseModel, Field
 from datetime import datetime
+from time import time
 from neo4j import AsyncDriver
 from uuid import uuid4
 import logging
@@ -76,10 +77,15 @@ class EntityEdge(Edge):
     )
 
     async def generate_embedding(self, embedder, model="text-embedding-3-small"):
+        start = time()
+
         text = self.fact.replace("\n", " ")
         embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
         self.fact_embedding = embedding[:EMBEDDING_DIM]
+        end = time()
+        logger.info(f"embedded {text} in {(end - start) * 1000} ms")
+
         return embedding
 
     async def save(self, driver: AsyncDriver):
@@ -105,6 +111,6 @@ class EntityEdge(Edge):
             invalid_at=self.invalid_at,
         )
 
-        logger.info(f"Saved Node to neo4j: {self.uuid}")
+        logger.info(f"Saved edge to neo4j: {self.uuid}")
 
         return result
diff --git a/core/graphiti.py b/core/graphiti.py
index 14b20641..14f5130b 100644
--- a/core/graphiti.py
+++ b/core/graphiti.py
@@ -4,25 +4,32 @@ import logging
 from typing import Callable, LiteralString
 from neo4j import AsyncGraphDatabase
 from dotenv import load_dotenv
+from time import time
 import os
 
 from core.llm_client.config import EMBEDDING_DIM
 from core.nodes import EntityNode, EpisodicNode, Node
-from core.edges import EntityEdge, EpisodicEdge
+from core.edges import EntityEdge, Edge, EpisodicEdge
 from core.utils import (
     build_episodic_edges,
     retrieve_episodes,
 )
 from core.llm_client import LLMClient, OpenAIClient, LLMConfig
-from core.utils.maintenance.edge_operations import (
-    extract_edges,
-    dedupe_extracted_edges,
+from core.utils.bulk_utils import (
+    BulkEpisode,
+    extract_nodes_and_edges_bulk,
+    retrieve_previous_episodes_bulk,
+    compress_nodes,
+    dedupe_nodes_bulk,
+    resolve_edge_pointers,
+    dedupe_edges_bulk,
 )
-
+from core.utils.maintenance.edge_operations import extract_edges, dedupe_extracted_edges
+from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
 from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
 from core.utils.maintenance.temporal_operations import (
-    prepare_edges_for_invalidation,
     invalidate_edges,
+    prepare_edges_for_invalidation,
 )
 from core.utils.search.search_utils import (
     edge_similarity_search,
@@ -58,30 +65,47 @@ class Graphiti:
         self.driver.close()
 
     async def retrieve_episodes(
-        self, last_n: int, sources: list[str] | None = "messages"
+        self,
+        reference_time: datetime,
+        last_n: int,
+        sources: list[str] | None = "messages",
     ) -> list[EpisodicNode]:
         """Retrieve the last n episodic nodes from the graph"""
-        return await retrieve_episodes(self.driver, last_n, sources)
+        return await retrieve_episodes(self.driver, reference_time, last_n, sources)
+
+    # Invalidate edges that are no longer valid
+    async def invalidate_edges(
+        self,
+        episode: EpisodicNode,
+        new_nodes: list[EntityNode],
+        new_edges: list[EntityEdge],
+        relevant_schema: dict[str, any],
+        previous_episodes: list[EpisodicNode],
+    ): ...

     async def add_episode(
         self,
         name: str,
         episode_body: str,
         source_description: str,
-        reference_time: datetime = None,
+        reference_time: datetime,
         episode_type="string",
         success_callback: Callable | None = None,
         error_callback: Callable | None = None,
     ):
         """Process an episode and update the graph"""
         try:
+            start = time()
+
             nodes: list[EntityNode] = []
             entity_edges: list[EntityEdge] = []
             episodic_edges: list[EpisodicEdge] = []
             embedder = self.llm_client.client.embeddings
             now = datetime.now()
 
-            previous_episodes = await self.retrieve_episodes(last_n=3)
+            previous_episodes = await self.retrieve_episodes(
+                reference_time, last_n=EPISODE_WINDOW_LEN
+            )
             episode = EpisodicNode(
                 name=name,
                 labels=[],
@@ -105,7 +129,7 @@ class Graphiti:
             logger.info(
                 f"Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}"
             )
-            new_nodes = await dedupe_extracted_nodes(
+            new_nodes, _ = await dedupe_extracted_nodes(
                 self.llm_client, extracted_nodes, existing_nodes
             )
             logger.info(
@@ -151,8 +175,13 @@ class Graphiti:
             )
 
             logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}")
-            entity_edges.extend(deduped_edges)
+
+            new_edges = await dedupe_extracted_edges(
+                self.llm_client, extracted_edges, existing_edges
+            )
+
+            entity_edges.extend(new_edges)
             episodic_edges.extend(
                 build_episodic_edges(
                     # There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
@@ -175,6 +206,9 @@ class Graphiti:
             await asyncio.gather(*[node.save(self.driver) for node in nodes])
             await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
             await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
+
+            end = time()
+            logger.info(f"Completed add_episode in {(end-start) * 1000} ms")
             # for node in nodes:
             #     if isinstance(node, EntityNode):
             #         await node.update_summary(self.driver)
@@ -190,36 +224,19 @@ class Graphiti:
         index_queries: list[LiteralString] = [
             "CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)",
             "CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)",
-            "CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.uuid)",
-            "CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[r:MENTIONS]-() ON (r.uuid)",
+            "CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)",
+            "CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)",
             "CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
             "CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
             "CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)",
             "CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)",
-            "CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.name)",
-            "CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.created_at)",
-            "CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.expired_at)",
-            "CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.valid_at)",
-            "CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.invalid_at)",
-        ]
-        # Add the range indices
-        for query in index_queries:
-            await self.driver.execute_query(query)
-
-        # Add the semantic indices
-        await self.driver.execute_query(
-            """
-            CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]
-            """
-        )
-
-        await self.driver.execute_query(
-            """
-            CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON EACH [r.name, r.fact]
-            """
-        )
-
-        await self.driver.execute_query(
+            "CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)",
+            "CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)",
+            "CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)",
+            "CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)",
+            "CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)",
+            "CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]",
+            "CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]",
             """
             CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
             FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding)
@@ -227,10 +244,7 @@ class Graphiti:
             OPTIONS {indexConfig: {
                 `vector.dimensions`: 1024,
                 `vector.similarity_function`: 'cosine'
             }}
-            """
-        )
-
-        await self.driver.execute_query(
+            """,
             """
             CREATE VECTOR INDEX name_embedding IF NOT EXISTS
             FOR (n:Entity) ON (n.name_embedding)
@@ -238,7 +252,19 @@ class Graphiti:
             OPTIONS {indexConfig: {
                 `vector.dimensions`: 1024,
                 `vector.similarity_function`: 'cosine'
             }}
+            """,
             """
+            CREATE CONSTRAINT entity_name IF NOT EXISTS
+            FOR (n:Entity) REQUIRE n.name IS UNIQUE
+            """,
+            """
+            CREATE CONSTRAINT edge_facts IF NOT EXISTS
+            FOR ()-[e:RELATES_TO]-() REQUIRE e.fact IS UNIQUE
+            """,
+        ]
+
+        await asyncio.gather(
+            *[self.driver.execute_query(query) for query in index_queries]
         )
 
     async def search(self, query: str) -> list[tuple[EntityNode, list[EntityEdge]]]:
@@ -267,3 +293,78 @@ class Graphiti:
         context = await bfs(node_ids, self.driver)
 
         return context
+
+    async def add_episode_bulk(
+        self,
+        bulk_episodes: list[BulkEpisode],
+    ):
+        try:
+            start = time()
+            embedder = self.llm_client.client.embeddings
+            now = datetime.now()
+
+            episodes = [
+                EpisodicNode(
+                    name=episode.name,
+                    labels=[],
+                    source="messages",
+                    content=episode.content,
+                    source_description=episode.source_description,
+                    created_at=now,
+                    valid_at=episode.reference_time,
+                )
+                for episode in bulk_episodes
+            ]
+
+            # Save all the episodes
+            await asyncio.gather(*[episode.save(self.driver) for episode in episodes])
+
+            # Get previous episode context for each episode
+            episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
+
+            # Extract all nodes and edges
+            extracted_nodes, extracted_edges, episodic_edges = (
+                await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs)
+            )
+
+            # Generate embeddings
+            await asyncio.gather(
+                *[node.generate_name_embedding(embedder) for node in extracted_nodes],
+                *[edge.generate_embedding(embedder) for edge in extracted_edges],
+            )
+
+            # Dedupe extracted nodes
+            nodes, uuid_map = await dedupe_nodes_bulk(
+                self.driver, self.llm_client, extracted_nodes
+            )
+
+            # save nodes to KG
+            await asyncio.gather(*[node.save(self.driver) for node in nodes])
+
+            # re-map edge pointers so that they don't point to discarded duplicate nodes
+            extracted_edges: list[EntityEdge] = resolve_edge_pointers(
+                extracted_edges, uuid_map
+            )
+            episodic_edges: list[EpisodicEdge] = resolve_edge_pointers(
+                episodic_edges, uuid_map
+            )
+
+            # save episodic edges to KG
+            await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
+
+            # Dedupe extracted edges
+            edges = await dedupe_edges_bulk(
+                self.driver, self.llm_client, extracted_edges
+            )
+            logger.info(f"deduped edge count: {len(edges)}")
edge length: {len(edges)}") + + # invalidate edges + + # save edges to KG + await asyncio.gather(*[edge.save(self.driver) for edge in edges]) + + end = time() + logger.info(f"Completed add_episode_bulk in {(end-start) * 1000} ms") + + except Exception as e: + raise e diff --git a/core/nodes.py b/core/nodes.py index 2f269d71..5d974448 100644 --- a/core/nodes.py +++ b/core/nodes.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from pydantic import Field +from time import time from datetime import datetime from uuid import uuid4 @@ -35,14 +35,13 @@ class EpisodicNode(Node): source: str = Field(description="source type") source_description: str = Field(description="description of the data source") content: str = Field(description="raw episode data") + valid_at: datetime = Field( + description="datetime of when the original document was created", + ) entity_edges: list[str] = Field( description="list of entity edges referenced in this episode", default_factory=list, ) - valid_at: datetime | None = Field( - description="datetime of when the original document was created", - default=None, - ) async def save(self, driver: AsyncDriver): result = await driver.execute_query( @@ -80,9 +79,12 @@ class EntityNode(Node): async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ... async def generate_name_embedding(self, embedder, model="text-embedding-3-small"): + start = time() text = self.name.replace("\n", " ") embedding = (await embedder.create(input=[text], model=model)).data[0].embedding self.name_embedding = embedding[:EMBEDDING_DIM] + end = time() + logger.info(f"embedded {text} in {end-start} ms") return embedding diff --git a/core/prompts/dedupe_edges.py b/core/prompts/dedupe_edges.py index 3506a3b7..adae8bda 100644 --- a/core/prompts/dedupe_edges.py +++ b/core/prompts/dedupe_edges.py @@ -6,10 +6,12 @@ from .models import Message, PromptVersion, PromptFunction class Prompt(Protocol): v1: PromptVersion + edge_list: PromptVersion class Versions(TypedDict): v1: PromptFunction + edge_list: PromptFunction def v1(context: dict[str, any]) -> list[Message]: @@ -43,7 +45,6 @@ def v1(context: dict[str, any]) -> list[Message]: {{ "new_edges": [ {{ - "name": "Unique identifier for the edge", "fact": "one sentence description of the fact" }} ] @@ -53,4 +54,40 @@ def v1(context: dict[str, any]) -> list[Message]: ] -versions: Versions = {"v1": v1} +def edge_list(context: dict[str, any]) -> list[Message]: + return [ + Message( + role="system", + content="You are a helpful assistant that de-duplicates edges from edge lists.", + ), + Message( + role="user", + content=f""" + Given the following context, find all of the duplicates in a list of edges: + + Edges: + {json.dumps(context['edges'], indent=2)} + + Task: + If any edge in Edges is a duplicate of another edge, return the fact of only one of the duplicate edges + + Guidelines: + 1. Use both the name and fact of edges to determine if they are duplicates, + edges with the same name may not be duplicates + 2. The final list should have only unique facts. 
+        facts should be in the response
+
+        Respond with a JSON object in the following format:
+        {{
+            "unique_edges": [
+                {{
+                    "fact": "fact of a unique edge"
+                }}
+            ]
+        }}
+        """,
+        ),
+    ]
+
+
+versions: Versions = {"v1": v1, "edge_list": edge_list}
diff --git a/core/prompts/dedupe_nodes.py b/core/prompts/dedupe_nodes.py
index 798942b5..3f54ef39 100644
--- a/core/prompts/dedupe_nodes.py
+++ b/core/prompts/dedupe_nodes.py
@@ -6,10 +6,14 @@ from .models import Message, PromptVersion, PromptFunction
 
 class Prompt(Protocol):
     v1: PromptVersion
+    v2: PromptVersion
+    node_list: PromptVersion
 
 
 class Versions(TypedDict):
     v1: PromptFunction
+    v2: PromptFunction
+    node_list: PromptFunction
 
 
 def v1(context: dict[str, any]) -> list[Message]:
@@ -44,7 +48,6 @@ def v1(context: dict[str, any]) -> list[Message]:
             "new_nodes": [
                 {{
                     "name": "Unique identifier for the node",
-                    "summary": "Brief summary of the node's role or significance"
                 }}
             ]
         }}
@@ -53,4 +56,79 @@ def v1(context: dict[str, any]) -> list[Message]:
     ]
 
 
-versions: Versions = {"v1": v1}
+def v2(context: dict[str, any]) -> list[Message]:
+    return [
+        Message(
+            role="system",
+            content="You are a helpful assistant that de-duplicates nodes from node lists.",
+        ),
+        Message(
+            role="user",
+            content=f"""
+        Given the following context, deduplicate nodes from a list of new nodes given a list of existing nodes:
+
+        Existing Nodes:
+        {json.dumps(context['existing_nodes'], indent=2)}
+
+        New Nodes:
+        {json.dumps(context['extracted_nodes'], indent=2)}
+
+        Task:
+        If any node in New Nodes is a duplicate of a node in Existing Nodes, add their names to the output list
+
+        Guidelines:
+        1. Use both the name and summary of nodes to determine if they are duplicates;
+        duplicate nodes may have different names
+        2. In the output, name should always be the name of the New Node that is a duplicate. duplicate_of should be
+        the name of the Existing Node.
+
+        Respond with a JSON object in the following format:
+        {{
+            "duplicates": [
+                {{
+                    "name": "name of the new node",
+                    "duplicate_of": "name of the existing node"
+                }}
+            ]
+        }}
+        """,
+        ),
+    ]
+
+
+def node_list(context: dict[str, any]) -> list[Message]:
+    return [
+        Message(
+            role="system",
+            content="You are a helpful assistant that de-duplicates nodes from node lists.",
+        ),
+        Message(
+            role="user",
+            content=f"""
+        Given the following context, deduplicate a list of nodes:
+
+        Nodes:
+        {json.dumps(context['nodes'], indent=2)}
+
+        Task:
+        1. Group nodes together such that all duplicate nodes are in the same list of names
+        2. All duplicate names should be grouped together in the same list
+
+        Guidelines:
+        1. Each name from the list of nodes should appear EXACTLY once in your response
+        2. If a node has no duplicates, it should appear in the response in a list of only one name
+
+        Respond with a JSON object in the following format:
+        {{
+            "nodes": [
+                {{
+                    "names": ["myNode", "node that is a duplicate of myNode"]
+                }}
+            ]
+        }}
+        """,
+        ),
+    ]
+
+
+versions: Versions = {"v1": v1, "v2": v2, "node_list": node_list}
diff --git a/core/utils/bulk_utils.py b/core/utils/bulk_utils.py
new file mode 100644
index 00000000..a5b361ef
--- /dev/null
+++ b/core/utils/bulk_utils.py
@@ -0,0 +1,206 @@
+import asyncio
+from collections import defaultdict
+from datetime import datetime
+
+from neo4j import AsyncDriver
+from pydantic import BaseModel
+
+from core.edges import EpisodicEdge, EntityEdge, Edge
+from core.llm_client import LLMClient
+from core.nodes import EpisodicNode, EntityNode
+from core.utils import retrieve_episodes
+from core.utils.maintenance.edge_operations import (
+    extract_edges,
+    build_episodic_edges,
+    dedupe_edge_list,
+    dedupe_extracted_edges,
+)
+from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
+from core.utils.maintenance.node_operations import (
+    extract_nodes,
+    dedupe_node_list,
+    dedupe_extracted_nodes,
+)
+from core.utils.search.search_utils import get_relevant_nodes, get_relevant_edges
+
+CHUNK_SIZE = 10
+
+
+class BulkEpisode(BaseModel):
+    name: str
+    content: str
+    source_description: str
+    episode_type: str
+    reference_time: datetime
+
+
+async def retrieve_previous_episodes_bulk(
+    driver: AsyncDriver, episodes: list[EpisodicNode]
+) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
+    previous_episodes_list = await asyncio.gather(
+        *[
+            retrieve_episodes(driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN)
+            for episode in episodes
+        ]
+    )
+    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] = [
+        (episode, previous_episodes_list[i]) for i, episode in enumerate(episodes)
+    ]
+
+    return episode_tuples
+
+
+async def extract_nodes_and_edges_bulk(
+    llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
+) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
+    extracted_nodes_bulk = await asyncio.gather(
+        *[
+            extract_nodes(llm_client, episode, previous_episodes)
+            for episode, previous_episodes in episode_tuples
+        ]
+    )
+
+    episodes, previous_episodes_list = [episode[0] for episode in episode_tuples], [
+        episode[1] for episode in episode_tuples
+    ]
+
+    extracted_edges_bulk = await asyncio.gather(
+        *[
+            extract_edges(
+                llm_client, episode, extracted_nodes_bulk[i], previous_episodes_list[i]
+            )
+            for i, episode in enumerate(episodes)
+        ]
+    )
+
+    episodic_edges: list[EpisodicEdge] = []
+    for i, episode in enumerate(episodes):
+        episodic_edges += build_episodic_edges(
+            extracted_nodes_bulk[i], episode, episode.created_at
+        )
+
+    nodes: list[EntityNode] = []
+    for extracted_nodes in extracted_nodes_bulk:
+        nodes += extracted_nodes
+
+    edges: list[EntityEdge] = []
+    for extracted_edges in extracted_edges_bulk:
+        edges += extracted_edges
+
+    return nodes, edges, episodic_edges
+
+
+async def dedupe_nodes_bulk(
+    driver: AsyncDriver,
+    llm_client: LLMClient,
+    extracted_nodes: list[EntityNode],
+) -> tuple[list[EntityNode], dict[str, str]]:
+    # Compress nodes
+    nodes, uuid_map = node_name_match(extracted_nodes)
+
+    compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map)
+
+    existing_nodes = await get_relevant_nodes(compressed_nodes, driver)
+
+    nodes, partial_uuid_map = await dedupe_extracted_nodes(
+        llm_client, compressed_nodes, existing_nodes
+    )
+
+    compressed_map.update(partial_uuid_map)
+
+    return nodes, compressed_map
+
+
+async def dedupe_edges_bulk(
+    driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
+) -> list[EntityEdge]:
+    # Compress edges
+    compressed_edges = await compress_edges(llm_client, extracted_edges)
+
+    existing_edges = await get_relevant_edges(compressed_edges, driver)
+
+    edges = await dedupe_extracted_edges(llm_client, compressed_edges, existing_edges)
+
+    return edges
+
+
+def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str, str]]:
+    uuid_map = {}
+    name_map = {}
+    for node in nodes:
+        if node.name in name_map:
+            uuid_map[node.uuid] = name_map[node.name].uuid
+            continue
+
+        name_map[node.name] = node
+
+    return list(name_map.values()), uuid_map
+
+
+async def compress_nodes(
+    llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
+) -> tuple[list[EntityNode], dict[str, str]]:
+    node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
+
+    results = await asyncio.gather(
+        *[dedupe_node_list(llm_client, chunk) for chunk in node_chunks]
+    )
+
+    extended_map = dict(uuid_map)
+    compressed_nodes: list[EntityNode] = []
+    for node_chunk, uuid_map_chunk in results:
+        compressed_nodes += node_chunk
+        extended_map.update(uuid_map_chunk)
+
+    # Check if we have removed all duplicates
+    if len(compressed_nodes) == len(nodes):
+        compressed_uuid_map = compress_uuid_map(extended_map)
+        return compressed_nodes, compressed_uuid_map
+
+    return await compress_nodes(llm_client, compressed_nodes, extended_map)
+
+
+async def compress_edges(
+    llm_client: LLMClient, edges: list[EntityEdge]
+) -> list[EntityEdge]:
+    edge_chunks = [edges[i : i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)]
+
+    results = await asyncio.gather(
+        *[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks]
+    )
+
+    compressed_edges: list[EntityEdge] = []
+    for edge_chunk in results:
+        compressed_edges += edge_chunk
+
+    # Check if we have removed all duplicates
+    if len(compressed_edges) == len(edges):
+        return compressed_edges
+
+    return await compress_edges(llm_client, compressed_edges)
+
+
+def compress_uuid_map(uuid_map: dict[str, str]) -> dict[str, str]:
+    # make sure all uuid values aren't mapped to other uuids
+    compressed_map = {}
+    for key, uuid in uuid_map.items():
+        curr_value = uuid
+        while curr_value in uuid_map:
+            curr_value = uuid_map[curr_value]
+
+        compressed_map[key] = curr_value
+    return compressed_map
+
+
+def resolve_edge_pointers(edges: list[Edge], uuid_map: dict[str, str]):
+    for edge in edges:
+        source_uuid = edge.source_node_uuid
+        target_uuid = edge.target_node_uuid
+        edge.source_node_uuid = (
+            uuid_map[source_uuid] if source_uuid in uuid_map else source_uuid
+        )
+        edge.target_node_uuid = (
+            uuid_map[target_uuid] if target_uuid in uuid_map else target_uuid
+        )
+
+    return edges
diff --git a/core/utils/maintenance/edge_operations.py b/core/utils/maintenance/edge_operations.py
index f275ed96..ec505cb5 100644
--- a/core/utils/maintenance/edge_operations.py
+++ b/core/utils/maintenance/edge_operations.py
@@ -1,6 +1,7 @@
 import json
 from typing import List
 from datetime import datetime
+from time import time
 
 from pydantic import BaseModel
 
@@ -17,7 +18,7 @@ logger = logging.getLogger(__name__)
 def build_episodic_edges(
     entity_nodes: List[EntityNode],
     episode: EpisodicNode,
-    transaction_from: datetime,
+    created_at: datetime,
 ) -> List[EpisodicEdge]:
     edges: List[EpisodicEdge] = []
 
@@ -25,7 +26,7 @@ def build_episodic_edges(
         edge = EpisodicEdge(
             source_node_uuid=episode.uuid,
             target_node_uuid=node.uuid,
-            created_at=transaction_from,
+            created_at=created_at,
         )
         edges.append(edge)
 
@@ -144,6 +145,8 @@ async def extract_edges(
     nodes: list[EntityNode],
     previous_episodes: list[EpisodicNode],
 ) -> list[EntityEdge]:
+    start = time()
+
     # Prepare context for LLM
     context = {
         "episode_content": episode.content,
@@ -167,7 +170,9 @@ async def extract_edges(
         prompt_library.extract_edges.v2(context)
     )
     edges_data = llm_response.get("edges", [])
-    logger.info(f"Extracted new edges: {edges_data}")
+
+    end = time()
+    logger.info(f"Extracted new edges: {edges_data} in {(end - start) * 1000} ms")
 
     # Convert the extracted data into EntityEdge objects
     edges = []
@@ -199,11 +204,11 @@ async def dedupe_extracted_edges(
     # Create edge map
     edge_map = {}
     for edge in existing_edges:
-        edge_map[edge.name] = edge
+        edge_map[edge.fact] = edge
     for edge in extracted_edges:
-        if edge.name in edge_map.keys():
+        if edge.fact in edge_map:
             continue
-        edge_map[edge.name] = edge
+        edge_map[edge.fact] = edge
 
     # Prepare context for LLM
     context = {
@@ -224,7 +229,40 @@ async def dedupe_extracted_edges(
     # Get full edge data
     edges = []
     for edge_data in new_edges_data:
-        edge = edge_map[edge_data["name"]]
+        edge = edge_map[edge_data["fact"]]
         edges.append(edge)
 
     return edges
+
+
+async def dedupe_edge_list(
+    llm_client: LLMClient,
+    edges: list[EntityEdge],
+) -> list[EntityEdge]:
+    start = time()
+
+    # Create edge map
+    edge_map = {}
+    for edge in edges:
+        edge_map[edge.fact] = edge
+
+    # Prepare context for LLM
+    context = {"edges": [{"name": edge.name, "fact": edge.fact} for edge in edges]}
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.dedupe_edges.edge_list(context)
+    )
+    unique_edges_data = llm_response.get("unique_edges", [])
+
+    end = time()
+    logger.info(
+        f"Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms"
+    )
+
+    # Get full edge data
+    unique_edges = []
+    for edge_data in unique_edges_data:
+        fact = edge_data["fact"]
+        unique_edges.append(edge_map[fact])
+
+    return unique_edges
diff --git a/core/utils/maintenance/graph_data_operations.py b/core/utils/maintenance/graph_data_operations.py
index 62beb5b1..790400f3 100644
--- a/core/utils/maintenance/graph_data_operations.py
+++ b/core/utils/maintenance/graph_data_operations.py
@@ -4,6 +4,7 @@ from core.nodes import EpisodicNode
 from neo4j import AsyncDriver
 import logging
 
+EPISODE_WINDOW_LEN = 3
 
 logger = logging.getLogger(__name__)
 
@@ -18,11 +19,15 @@ async def clear_data(driver: AsyncDriver):
 
 
 async def retrieve_episodes(
-    driver: AsyncDriver, last_n: int, sources: list[str] | None = "messages"
+    driver: AsyncDriver,
+    reference_time: datetime,
+    last_n: int,
+    sources: list[str] | None = "messages",
 ) -> list[EpisodicNode]:
     """Retrieve the last n episodic nodes from the graph"""
-    query = """
-        MATCH (e:Episodic)
+    result = await driver.execute_query(
+        """
+        MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
         RETURN e.content as content,
             e.created_at as created_at,
             e.valid_at as valid_at,
@@ -32,8 +37,10 @@ async def retrieve_episodes(
             e.source as source
         ORDER BY e.created_at DESC
         LIMIT $num_episodes
-        """
-    result = await driver.execute_query(query, num_episodes=last_n)
+        """,
+        reference_time=reference_time,
+        num_episodes=last_n,
+    )
     episodes = [
         EpisodicNode(
             content=record["content"],
diff --git a/core/utils/maintenance/node_operations.py b/core/utils/maintenance/node_operations.py
index 1b67bf56..3397ee72 100644
--- a/core/utils/maintenance/node_operations.py
+++ b/core/utils/maintenance/node_operations.py
@@ -1,4 +1,5 @@
 from datetime import datetime
+from time import time
 
 from core.nodes import EntityNode, EpisodicNode
 import logging
@@ -68,6 +69,8 @@ async def extract_nodes(
     episode: EpisodicNode,
     previous_episodes: list[EpisodicNode],
 ) -> list[EntityNode]:
+    start = time()
+
     # Prepare context for LLM
     context = {
         "episode_content": episode.content,
@@ -87,7 +90,9 @@ async def extract_nodes(
         prompt_library.extract_nodes.v3(context)
     )
     new_nodes_data = llm_response.get("new_nodes", [])
-    logger.info(f"Extracted new nodes: {new_nodes_data}")
+
+    end = time()
+    logger.info(f"Extracted new nodes: {new_nodes_data} in {(end - start) * 1000} ms")
 
     # Convert the extracted data into EntityNode objects
     new_nodes = []
     for node_data in new_nodes_data:
@@ -107,15 +112,16 @@ async def dedupe_extracted_nodes(
     llm_client: LLMClient,
     extracted_nodes: list[EntityNode],
     existing_nodes: list[EntityNode],
-) -> list[EntityNode]:
-    # build node map
+) -> tuple[list[EntityNode], dict[str, str]]:
+    start = time()
+
+    # build existing node map
     node_map = {}
     for node in existing_nodes:
         node_map[node.name] = node
-    for node in extracted_nodes:
-        if node.name in node_map.keys():
-            continue
-        node_map[node.name] = node
+    # maps for resolving the LLM's duplicate pairs back to nodes:
+    # new-node names -> nodes, and existing-node uuids -> nodes
+    extracted_node_map = {node.name: node for node in extracted_nodes}
+    existing_uuid_map = {node.uuid: node for node in existing_nodes}
 
     # Prepare context for LLM
     existing_nodes_context = [
@@ -132,16 +135,66 @@ async def dedupe_extracted_nodes(
     }
 
     llm_response = await llm_client.generate_response(
-        prompt_library.dedupe_nodes.v1(context)
+        prompt_library.dedupe_nodes.v2(context)
     )
-    new_nodes_data = llm_response.get("new_nodes", [])
-    logger.info(f"Deduplicated nodes: {new_nodes_data}")
+    duplicate_data = llm_response.get("duplicates", [])
+
+    end = time()
+    logger.info(f"Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms")
+
+    uuid_map = {}
+    for duplicate in duplicate_data:
+        uuid = extracted_node_map[duplicate["name"]].uuid
+        uuid_value = node_map[duplicate["duplicate_of"]].uuid
+        uuid_map[uuid] = uuid_value
 
-    # Get full node data
     nodes = []
-    for node_data in new_nodes_data:
-        node = node_map[node_data["name"]]
+    for node in extracted_nodes:
+        if node.uuid in uuid_map:
+            existing_node = existing_uuid_map[uuid_map[node.uuid]]
+            nodes.append(existing_node)
+            continue
         nodes.append(node)
 
-    return nodes
+    return nodes, uuid_map
+
+
+async def dedupe_node_list(
+    llm_client: LLMClient,
+    nodes: list[EntityNode],
+) -> tuple[list[EntityNode], dict[str, str]]:
+    start = time()
+
+    # build node map
+    node_map = {}
+    for node in nodes:
+        node_map[node.name] = node
+
+    # Prepare context for LLM
+    nodes_context = [{"name": node.name, "summary": node.summary} for node in nodes]
+
+    context = {
+        "nodes": nodes_context,
+    }
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.dedupe_nodes.node_list(context)
+    )
+
+    nodes_data = llm_response.get("nodes", [])
+
+    end = time()
+    logger.info(f"Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms")
+
+    # Get full node data
+    unique_nodes = []
+    uuid_map: dict[str, str] = {}
+    for node_data in nodes_data:
+        unique_node = node_map[node_data["names"][0]]
+        unique_nodes.append(unique_node)
+
+        for name in node_data["names"][1:]:
+            uuid_map[node_map[name].uuid] = unique_node.uuid
+
+    return unique_nodes, uuid_map
diff --git a/core/utils/search/search_utils.py b/core/utils/search/search_utils.py
index c57a6a35..110b7a21 100644
--- a/core/utils/search/search_utils.py
+++ b/core/utils/search/search_utils.py
@@ -1,6 +1,7 @@
 import asyncio
 import logging
 from datetime import datetime
+from time import time
 
 from neo4j import AsyncDriver
 
@@ -9,6 +10,8 @@ from core.nodes import EntityNode
 
 logger = logging.getLogger(__name__)
 
+RELEVANT_SCHEMA_LIMIT = 3
+
 
 async def bfs(node_ids: list[str], driver: AsyncDriver):
     records, _, _ = await driver.execute_query(
@@ -60,7 +63,7 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):
 
 
 async def edge_similarity_search(
-    search_vector: list[float], driver: AsyncDriver
+    search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
 ) -> list[EntityEdge]:
     # vector similarity search over embedded facts
     records, _, _ = await driver.execute_query(
@@ -80,9 +83,10 @@ async def edge_similarity_search(
             r.expired_at AS expired_at,
             r.valid_at AS valid_at,
             r.invalid_at AS invalid_at
-        ORDER BY score DESC LIMIT 10
+        ORDER BY score DESC LIMIT $limit
         """,
         search_vector=search_vector,
+        limit=limit,
     )
 
     edges: list[EntityEdge] = []
@@ -106,18 +110,16 @@ async def edge_similarity_search(
 
         edges.append(edge)
 
-    logger.info(f"similarity search results. RESULT: {[edge.uuid for edge in edges]}")
-
     return edges
 
 
 async def entity_similarity_search(
-    search_vector: list[float], driver: AsyncDriver
+    search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
 ) -> list[EntityNode]:
     # vector similarity search over entity names
     records, _, _ = await driver.execute_query(
         """
-        CALL db.index.vector.queryNodes("name_embedding", 5, $search_vector)
+        CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
         YIELD node AS n, score
         RETURN
             n.uuid As uuid,
@@ -127,6 +129,7 @@ async def entity_similarity_search(
         ORDER BY score DESC
         """,
         search_vector=search_vector,
+        limit=limit,
     )
 
     nodes: list[EntityNode] = []
@@ -141,12 +144,12 @@ async def entity_similarity_search(
             )
         )
 
-    logger.info(f"name semantic search results. RESULT: {nodes}")
-
     return nodes
 
 
-async def entity_fulltext_search(query: str, driver: AsyncDriver) -> list[EntityNode]:
+async def entity_fulltext_search(
+    query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
+) -> list[EntityNode]:
     # BM25 search to get top nodes
     fuzzy_query = query + "~"
     records, _, _ = await driver.execute_query(
@@ -158,9 +161,10 @@ async def entity_fulltext_search(query: str, driver: AsyncDriver) -> list[Entity
             node.created_at AS created_at,
             node.summary AS summary
         ORDER BY score DESC
-        LIMIT 10
+        LIMIT $limit
         """,
         query=fuzzy_query,
+        limit=limit,
    )
 
     nodes: list[EntityNode] = []
@@ -175,12 +179,12 @@ async def entity_fulltext_search(query: str, driver: AsyncDriver) -> list[Entity
             )
         )
 
-    logger.info(f"fulltext search results. QUERY:{query}. RESULT: {nodes}")
-
     return nodes
 
 
-async def edge_fulltext_search(query: str, driver: AsyncDriver) -> list[EntityEdge]:
+async def edge_fulltext_search(
+    query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
+) -> list[EntityEdge]:
     # fulltext search over facts
     fuzzy_query = query + "~"
 
@@ -201,9 +205,10 @@ async def edge_fulltext_search(query: str, driver: AsyncDriver) -> list[EntityEd
             r.expired_at AS expired_at,
             r.valid_at AS valid_at,
             r.invalid_at AS invalid_at
-        ORDER BY score DESC LIMIT 10
+        ORDER BY score DESC LIMIT $limit
         """,
         query=fuzzy_query,
+        limit=limit,
     )
 
     edges: list[EntityEdge] = []
@@ -227,10 +232,6 @@ async def edge_fulltext_search(query: str, driver: AsyncDriver) -> list[EntityEd
 
         edges.append(edge)
 
-    logger.info(
-        f"similarity search results. QUERY:{query}. RESULT: {[edge.uuid for edge in edges]}"
-    )
-
     return edges
 
 
@@ -238,7 +239,9 @@ async def get_relevant_nodes(
     nodes: list[EntityNode],
     driver: AsyncDriver,
 ) -> list[EntityNode]:
-    relevant_nodes: dict[str, EntityNode] = {}
+    start = time()
+    relevant_nodes: list[EntityNode] = []
+    relevant_node_uuids = set()
 
     results = await asyncio.gather(
         *[entity_fulltext_search(node.name, driver) for node in nodes],
@@ -247,18 +250,27 @@ async def get_relevant_nodes(
 
     for result in results:
         for node in result:
-            relevant_nodes[node.uuid] = node
+            if node.uuid in relevant_node_uuids:
+                continue
 
-    logger.info(f"Found relevant nodes: {relevant_nodes.keys()}")
+            relevant_node_uuids.add(node.uuid)
+            relevant_nodes.append(node)
 
-    return relevant_nodes.values()
+    end = time()
+    logger.info(
+        f"Found relevant nodes: {relevant_node_uuids} in {(end - start) * 1000} ms"
+    )
+
+    return relevant_nodes
 
 
 async def get_relevant_edges(
     edges: list[EntityEdge],
     driver: AsyncDriver,
 ) -> list[EntityEdge]:
-    relevant_edges: dict[str, EntityEdge] = {}
+    start = time()
+    relevant_edges: list[EntityEdge] = []
+    relevant_edge_uuids = set()
 
     results = await asyncio.gather(
         *[edge_similarity_search(edge.fact_embedding, driver) for edge in edges],
@@ -267,8 +279,15 @@ async def get_relevant_edges(
 
     for result in results:
         for edge in result:
-            relevant_edges[edge.uuid] = edge
+            if edge.uuid in relevant_edge_uuids:
+                continue
 
-    logger.info(f"Found relevant nodes: {relevant_edges.keys()}")
+            relevant_edge_uuids.add(edge.uuid)
+            relevant_edges.append(edge)
 
-    return list(relevant_edges.values())
+    end = time()
+    logger.info(
+        f"Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms"
+    )
+
+    return relevant_edges
diff --git a/podcast_runner.py b/podcast_runner.py
index a22ccb6d..2fff8285 100644
--- a/podcast_runner.py
+++ b/podcast_runner.py
@@ -1,4 +1,5 @@
 from core import Graphiti
+from core.utils.bulk_utils import BulkEpisode
 from core.utils.maintenance.graph_data_operations import clear_data
 from dotenv import load_dotenv
 import os
@@ -37,18 +38,35 @@ def setup_logging():
     return logger
 
 
-async def main():
+async def main(use_bulk: bool = True):
     setup_logging()
     client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
     await clear_data(client.driver)
     messages = parse_podcast_messages()
-    for i, message in enumerate(messages[3:50]):
-        await client.add_episode(
+
+    if not use_bulk:
+        for i, message in enumerate(messages[3:14]):
+            await client.add_episode(
+                name=f"Message {i}",
+                episode_body=f"{message.speaker_name} ({message.role}): {message.content}",
+                reference_time=message.actual_timestamp,
+                source_description="Podcast Transcript",
+            )
+
+        return
+
+    episodes: list[BulkEpisode] = [
+        BulkEpisode(
             name=f"Message {i}",
-            episode_body=f"{message.speaker_name} ({message.role}): {message.content}",
-            reference_time=message.actual_timestamp,
+            content=f"{message.speaker_name} ({message.role}): {message.content}",
             source_description="Podcast Transcript",
+            episode_type="string",
+            reference_time=message.actual_timestamp,
        )
+        for i, message in enumerate(messages[3:7])
+    ]
+
+    await client.add_episode_bulk(episodes)
 
 
 asyncio.run(main())
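
Note on the bulk dedupe flow above: the uuid_map built up by node_name_match, dedupe_node_list, and dedupe_extracted_nodes can contain chains (a node deduped into a node that was itself deduped in a later pass), which is why add_episode_bulk only applies the map to edges after compress_uuid_map has flattened it. Below is a minimal, self-contained sketch of that behavior; compress_uuid_map and resolve_edge_pointers mirror core/utils/bulk_utils.py above, while the SimpleEdge stand-in and the demo uuids are hypothetical.

# sketch_uuid_remap.py -- illustrative only, not part of the patch
class SimpleEdge:
    # Hypothetical stand-in for core.edges.Edge; only the two pointer fields matter here.
    def __init__(self, source_node_uuid: str, target_node_uuid: str):
        self.source_node_uuid = source_node_uuid
        self.target_node_uuid = target_node_uuid


def compress_uuid_map(uuid_map: dict[str, str]) -> dict[str, str]:
    # Follow each uuid -> uuid chain to its terminal value so no value in the
    # returned map is itself a key (assumes the map is acyclic, as in bulk_utils).
    compressed_map = {}
    for key, uuid in uuid_map.items():
        curr_value = uuid
        while curr_value in uuid_map:
            curr_value = uuid_map[curr_value]
        compressed_map[key] = curr_value
    return compressed_map


def resolve_edge_pointers(edges, uuid_map: dict[str, str]):
    # Re-point edges at the surviving node for any deduplicated endpoint;
    # behaviorally equivalent to the if/else form in bulk_utils.
    for edge in edges:
        edge.source_node_uuid = uuid_map.get(edge.source_node_uuid, edge.source_node_uuid)
        edge.target_node_uuid = uuid_map.get(edge.target_node_uuid, edge.target_node_uuid)
    return edges


if __name__ == "__main__":
    # "a" was deduped into "b" in one chunk, and "b" into "c" in a later pass.
    uuid_map = compress_uuid_map({"a": "b", "b": "c"})
    assert uuid_map == {"a": "c", "b": "c"}

    edge = SimpleEdge(source_node_uuid="a", target_node_uuid="d")
    resolve_edge_pointers([edge], uuid_map)
    assert edge.source_node_uuid == "c"  # chain flattened to the terminal node
    assert edge.target_node_uuid == "d"  # untouched: "d" was never deduplicated

One caveat worth flagging for a follow-up: if the LLM ever emitted a cyclic pair of duplicates (a -> b and b -> a), the while loop in compress_uuid_map would not terminate, so a visited-set guard may be worth adding.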