Remove Episode by uuid (#261)
* add remove_episode * delete episodes * update * bump version
This commit is contained in:
parent
104516bd89
commit
a99aad59de
3 changed files with 40 additions and 4 deletions
|
|
@ -21,6 +21,7 @@ from time import time
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from neo4j import AsyncGraphDatabase
|
from neo4j import AsyncGraphDatabase
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
||||||
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
||||||
|
|
@ -747,3 +748,34 @@ class Graphiti:
|
||||||
await add_nodes_and_edges_bulk(
|
await add_nodes_and_edges_bulk(
|
||||||
self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges
|
self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def remove_episode(self, episode_uuid: str):
|
||||||
|
# Find the episode to be deleted
|
||||||
|
episode = await EpisodicNode.get_by_uuid(self.driver, episode_uuid)
|
||||||
|
|
||||||
|
# Find edges mentioned by the episode
|
||||||
|
edges = await EntityEdge.get_by_uuids(self.driver, episode.entity_edges)
|
||||||
|
|
||||||
|
# We should only delete edges created by the episode
|
||||||
|
edges_to_delete: list[EntityEdge] = []
|
||||||
|
for edge in edges:
|
||||||
|
if edge.episodes[0] == episode.uuid:
|
||||||
|
edges_to_delete.append(edge)
|
||||||
|
|
||||||
|
# Find nodes mentioned by the episode
|
||||||
|
nodes = await get_mentioned_nodes(self.driver, [episode])
|
||||||
|
# We should delete all nodes that are only mentioned in the deleted episode
|
||||||
|
nodes_to_delete: list[EntityNode] = []
|
||||||
|
for node in nodes:
|
||||||
|
query: LiteralString = 'MATCH (e:Episodic)-[:MENTIONS]->(n:Entity {uuid: $uuid}) RETURN count(*) AS episode_count'
|
||||||
|
records, _, _ = await self.driver.execute_query(
|
||||||
|
query, uuid=node.uuid, database_=DEFAULT_DATABASE, routing_='r'
|
||||||
|
)
|
||||||
|
|
||||||
|
for record in records:
|
||||||
|
if record['episode_count'] == 1:
|
||||||
|
nodes_to_delete.append(node)
|
||||||
|
|
||||||
|
await semaphore_gather(*[node.delete(self.driver) for node in nodes_to_delete])
|
||||||
|
await semaphore_gather(*[edge.delete(self.driver) for edge in edges_to_delete])
|
||||||
|
await episode.delete(self.driver)
|
||||||
|
|
|
||||||
|
|
@ -170,7 +170,8 @@ class EpisodicNode(Node):
|
||||||
e.name AS name,
|
e.name AS name,
|
||||||
e.group_id AS group_id,
|
e.group_id AS group_id,
|
||||||
e.source_description AS source_description,
|
e.source_description AS source_description,
|
||||||
e.source AS source
|
e.source AS source,
|
||||||
|
e.entity_edges AS entity_edges
|
||||||
""",
|
""",
|
||||||
uuid=uuid,
|
uuid=uuid,
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
|
|
@ -197,7 +198,8 @@ class EpisodicNode(Node):
|
||||||
e.name AS name,
|
e.name AS name,
|
||||||
e.group_id AS group_id,
|
e.group_id AS group_id,
|
||||||
e.source_description AS source_description,
|
e.source_description AS source_description,
|
||||||
e.source AS source
|
e.source AS source,
|
||||||
|
e.entity_edges AS entity_edges
|
||||||
""",
|
""",
|
||||||
uuids=uuids,
|
uuids=uuids,
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
|
|
@ -233,7 +235,8 @@ class EpisodicNode(Node):
|
||||||
e.name AS name,
|
e.name AS name,
|
||||||
e.group_id AS group_id,
|
e.group_id AS group_id,
|
||||||
e.source_description AS source_description,
|
e.source_description AS source_description,
|
||||||
e.source AS source
|
e.source AS source,
|
||||||
|
e.entity_edges AS entity_edges
|
||||||
ORDER BY e.uuid DESC
|
ORDER BY e.uuid DESC
|
||||||
"""
|
"""
|
||||||
+ limit_query,
|
+ limit_query,
|
||||||
|
|
@ -490,6 +493,7 @@ def get_episodic_node_from_record(record: Any) -> EpisodicNode:
|
||||||
source=EpisodeType.from_str(record['source']),
|
source=EpisodeType.from_str(record['source']),
|
||||||
name=record['name'],
|
name=record['name'],
|
||||||
source_description=record['source_description'],
|
source_description=record['source_description'],
|
||||||
|
entity_edges=record['entity_edges'],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
version = "0.5.3"
|
version = "0.6.0"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
authors = [
|
authors = [
|
||||||
"Paul Paliychuk <paul@getzep.com>",
|
"Paul Paliychuk <paul@getzep.com>",
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue