Remove Episode by uuid (#261)
* add remove_episode * delete episodes * update * bump version
This commit is contained in:
parent
104516bd89
commit
a99aad59de
3 changed files with 40 additions and 4 deletions
|
|
@ -21,6 +21,7 @@ from time import time
|
|||
from dotenv import load_dotenv
|
||||
from neo4j import AsyncGraphDatabase
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
||||
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
||||
|
|
@ -747,3 +748,34 @@ class Graphiti:
|
|||
await add_nodes_and_edges_bulk(
|
||||
self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges
|
||||
)
|
||||
|
||||
async def remove_episode(self, episode_uuid: str):
|
||||
# Find the episode to be deleted
|
||||
episode = await EpisodicNode.get_by_uuid(self.driver, episode_uuid)
|
||||
|
||||
# Find edges mentioned by the episode
|
||||
edges = await EntityEdge.get_by_uuids(self.driver, episode.entity_edges)
|
||||
|
||||
# We should only delete edges created by the episode
|
||||
edges_to_delete: list[EntityEdge] = []
|
||||
for edge in edges:
|
||||
if edge.episodes[0] == episode.uuid:
|
||||
edges_to_delete.append(edge)
|
||||
|
||||
# Find nodes mentioned by the episode
|
||||
nodes = await get_mentioned_nodes(self.driver, [episode])
|
||||
# We should delete all nodes that are only mentioned in the deleted episode
|
||||
nodes_to_delete: list[EntityNode] = []
|
||||
for node in nodes:
|
||||
query: LiteralString = 'MATCH (e:Episodic)-[:MENTIONS]->(n:Entity {uuid: $uuid}) RETURN count(*) AS episode_count'
|
||||
records, _, _ = await self.driver.execute_query(
|
||||
query, uuid=node.uuid, database_=DEFAULT_DATABASE, routing_='r'
|
||||
)
|
||||
|
||||
for record in records:
|
||||
if record['episode_count'] == 1:
|
||||
nodes_to_delete.append(node)
|
||||
|
||||
await semaphore_gather(*[node.delete(self.driver) for node in nodes_to_delete])
|
||||
await semaphore_gather(*[edge.delete(self.driver) for edge in edges_to_delete])
|
||||
await episode.delete(self.driver)
|
||||
|
|
|
|||
|
|
@ -170,7 +170,8 @@ class EpisodicNode(Node):
|
|||
e.name AS name,
|
||||
e.group_id AS group_id,
|
||||
e.source_description AS source_description,
|
||||
e.source AS source
|
||||
e.source AS source,
|
||||
e.entity_edges AS entity_edges
|
||||
""",
|
||||
uuid=uuid,
|
||||
database_=DEFAULT_DATABASE,
|
||||
|
|
@ -197,7 +198,8 @@ class EpisodicNode(Node):
|
|||
e.name AS name,
|
||||
e.group_id AS group_id,
|
||||
e.source_description AS source_description,
|
||||
e.source AS source
|
||||
e.source AS source,
|
||||
e.entity_edges AS entity_edges
|
||||
""",
|
||||
uuids=uuids,
|
||||
database_=DEFAULT_DATABASE,
|
||||
|
|
@ -233,7 +235,8 @@ class EpisodicNode(Node):
|
|||
e.name AS name,
|
||||
e.group_id AS group_id,
|
||||
e.source_description AS source_description,
|
||||
e.source AS source
|
||||
e.source AS source,
|
||||
e.entity_edges AS entity_edges
|
||||
ORDER BY e.uuid DESC
|
||||
"""
|
||||
+ limit_query,
|
||||
|
|
@ -490,6 +493,7 @@ def get_episodic_node_from_record(record: Any) -> EpisodicNode:
|
|||
source=EpisodeType.from_str(record['source']),
|
||||
name=record['name'],
|
||||
source_description=record['source_description'],
|
||||
entity_edges=record['entity_edges'],
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "graphiti-core"
|
||||
version = "0.5.3"
|
||||
version = "0.6.0"
|
||||
description = "A temporal graph building library"
|
||||
authors = [
|
||||
"Paul Paliychuk <paul@getzep.com>",
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue