diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index d1c4f1d2..952a4671 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -64,6 +64,21 @@ class Edge(BaseModel, ABC): return result + @classmethod + async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str]): + result = await driver.execute_query( + """ + MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER]->(m) + WHERE e.uuid IN $uuids + DELETE e + """, + uuids=uuids, + ) + + logger.debug(f'Deleted Edges: {uuids}') + + return result + def __hash__(self): return hash(self.uuid) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 86265f9d..67f830fa 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -28,6 +28,7 @@ from graphiti_core.driver.driver import GraphDriver from graphiti_core.driver.neo4j_driver import Neo4jDriver from graphiti_core.edges import ( CommunityEdge, + Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings, @@ -46,6 +47,7 @@ from graphiti_core.nodes import ( EntityNode, EpisodeType, EpisodicNode, + Node, create_entity_node_embeddings, ) from graphiti_core.search.search import SearchConfig, search @@ -1066,12 +1068,7 @@ class Graphiti: if record['episode_count'] == 1: nodes_to_delete.append(node) - await semaphore_gather( - *[node.delete(self.driver) for node in nodes_to_delete], - max_coroutines=self.max_coroutines, - ) - await semaphore_gather( - *[edge.delete(self.driver) for edge in edges_to_delete], - max_coroutines=self.max_coroutines, - ) + await Node.delete_by_uuids(self.driver, [node.uuid for node in nodes_to_delete]) + + await Edge.delete_by_uuids(self.driver, [edge.uuid for edge in edges_to_delete]) await episode.delete(self.driver) diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 0ec7c730..bbe360f9 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -142,6 +142,33 @@ class Node(BaseModel, ABC): batch_size=batch_size, ) + @classmethod + async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100): + if driver.provider == GraphProvider.FALKORDB: + for label in ['Entity', 'Episodic', 'Community']: + await driver.execute_query( + f""" + MATCH (n:{label}) + WHERE n.uuid IN $uuids + DETACH DELETE n + """, + uuids=uuids, + ) + else: + async with driver.session() as session: + await session.run( + """ + MATCH (n:Entity|Episodic|Community) + WHERE n.uuid IN $uuids + CALL { + WITH n + DETACH DELETE n + } IN TRANSACTIONS OF $batch_size ROWS + """, + uuids=uuids, + batch_size=batch_size, + ) + @classmethod async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ... diff --git a/pyproject.toml b/pyproject.toml index cf476ef6..c76f2f11 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "graphiti-core" description = "A temporal graph building library" -version = "0.18.6" +version = "0.18.7" authors = [ { name = "Paul Paliychuk", email = "paul@getzep.com" }, { name = "Preston Rasmussen", email = "preston@getzep.com" }, diff --git a/uv.lock b/uv.lock index 0d3fcadb..5f8407a6 100644 --- a/uv.lock +++ b/uv.lock @@ -746,7 +746,7 @@ wheels = [ [[package]] name = "graphiti-core" -version = "0.18.6" +version = "0.18.7" source = { editable = "." } dependencies = [ { name = "diskcache" },