From a99aad59de3163a31a9891c6b2190c76e4cb498c Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Wed, 5 Feb 2025 15:17:08 -0500 Subject: [PATCH] Remove Episode by uuid (#261) * add remove_episode * delete episodes * update * bump version --- graphiti_core/graphiti.py | 32 ++++++++++++++++++++++++++++++++ graphiti_core/nodes.py | 10 +++++++--- pyproject.toml | 2 +- 3 files changed, 40 insertions(+), 4 deletions(-) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 83688392..438d31a6 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -21,6 +21,7 @@ from time import time from dotenv import load_dotenv from neo4j import AsyncGraphDatabase from pydantic import BaseModel +from typing_extensions import LiteralString from graphiti_core.cross_encoder.client import CrossEncoderClient from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient @@ -747,3 +748,34 @@ class Graphiti: await add_nodes_and_edges_bulk( self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges ) + + async def remove_episode(self, episode_uuid: str): + # Find the episode to be deleted + episode = await EpisodicNode.get_by_uuid(self.driver, episode_uuid) + + # Find edges mentioned by the episode + edges = await EntityEdge.get_by_uuids(self.driver, episode.entity_edges) + + # We should only delete edges created by the episode + edges_to_delete: list[EntityEdge] = [] + for edge in edges: + if edge.episodes[0] == episode.uuid: + edges_to_delete.append(edge) + + # Find nodes mentioned by the episode + nodes = await get_mentioned_nodes(self.driver, [episode]) + # We should delete all nodes that are only mentioned in the deleted episode + nodes_to_delete: list[EntityNode] = [] + for node in nodes: + query: LiteralString = 'MATCH (e:Episodic)-[:MENTIONS]->(n:Entity {uuid: $uuid}) RETURN count(*) AS episode_count' + records, _, _ = await self.driver.execute_query( + query, uuid=node.uuid, database_=DEFAULT_DATABASE, routing_='r' + ) + + for record in records: + if record['episode_count'] == 1: + nodes_to_delete.append(node) + + await semaphore_gather(*[node.delete(self.driver) for node in nodes_to_delete]) + await semaphore_gather(*[edge.delete(self.driver) for edge in edges_to_delete]) + await episode.delete(self.driver) diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 3d635060..6a490c8c 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -170,7 +170,8 @@ class EpisodicNode(Node): e.name AS name, e.group_id AS group_id, e.source_description AS source_description, - e.source AS source + e.source AS source, + e.entity_edges AS entity_edges """, uuid=uuid, database_=DEFAULT_DATABASE, @@ -197,7 +198,8 @@ class EpisodicNode(Node): e.name AS name, e.group_id AS group_id, e.source_description AS source_description, - e.source AS source + e.source AS source, + e.entity_edges AS entity_edges """, uuids=uuids, database_=DEFAULT_DATABASE, @@ -233,7 +235,8 @@ class EpisodicNode(Node): e.name AS name, e.group_id AS group_id, e.source_description AS source_description, - e.source AS source + e.source AS source, + e.entity_edges AS entity_edges ORDER BY e.uuid DESC """ + limit_query, @@ -490,6 +493,7 @@ def get_episodic_node_from_record(record: Any) -> EpisodicNode: source=EpisodeType.from_str(record['source']), name=record['name'], source_description=record['source_description'], + entity_edges=record['entity_edges'], ) diff --git a/pyproject.toml b/pyproject.toml index bca5595c..90649789 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "graphiti-core" -version = "0.5.3" +version = "0.6.0" description = "A temporal graph building library" authors = [ "Paul Paliychuk ",