chore: Make deleting groups safer (#155)

* chore: Make deleting groups safer

* chore: Use appropriate errors in delete group checks

* chore: Add GroupsEdgesNotFound error type
This commit is contained in:
Pavlo Paliychuk 2024-09-24 20:08:09 -04:00 committed by GitHub
parent bca838f61d
commit b537cf56e5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 32 additions and 19 deletions

View file

@ -24,7 +24,7 @@ from uuid import uuid4
from neo4j import AsyncDriver
from pydantic import BaseModel, Field
from graphiti_core.errors import EdgeNotFoundError
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
from graphiti_core.helpers import parse_db_date
from graphiti_core.llm_client.config import EMBEDDING_DIM
from graphiti_core.nodes import Node
@ -147,10 +147,9 @@ class EpisodicEdge(Edge):
)
edges = [get_episodic_edge_from_record(record) for record in records]
uuids = [edge.uuid for edge in edges]
if len(edges) == 0:
raise EdgeNotFoundError(uuids[0])
raise GroupsEdgesNotFoundError(group_ids)
return edges
@ -293,10 +292,9 @@ class EntityEdge(Edge):
)
edges = [get_entity_edge_from_record(record) for record in records]
uuids = [edge.uuid for edge in edges]
if len(edges) == 0:
raise EdgeNotFoundError(uuids[0])
raise GroupsEdgesNotFoundError(group_ids)
return edges

View file

@ -27,6 +27,14 @@ class EdgeNotFoundError(GraphitiError):
super().__init__(self.message)
class GroupsEdgesNotFoundError(GraphitiError):
"""Raised when no edges are found for a list of group ids."""
def __init__(self, group_ids: list[str]):
self.message = f'no edges found for group ids {group_ids}'
super().__init__(self.message)
class NodeNotFoundError(GraphitiError):
"""Raised when a node is not found."""

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "graphiti-core"
version = "0.3.5"
version = "0.3.6"
description = "A temporal graph building library"
authors = [
"Paul Paliychuk <paul@getzep.com>",

View file

@ -1,15 +1,18 @@
import logging
from typing import Annotated
from fastapi import Depends, HTTPException
from graphiti_core import Graphiti # type: ignore
from graphiti_core.edges import EntityEdge # type: ignore
from graphiti_core.errors import EdgeNotFoundError, NodeNotFoundError # type: ignore
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError, NodeNotFoundError
from graphiti_core.llm_client import LLMClient # type: ignore
from graphiti_core.nodes import EntityNode, EpisodicNode # type: ignore
from graph_service.config import ZepEnvDep
from graph_service.dto import FactResult
logger = logging.getLogger(__name__)
class ZepGraphiti(Graphiti):
def __init__(self, uri: str, user: str, password: str, llm_client: LLMClient | None = None):
@ -36,18 +39,22 @@ class ZepGraphiti(Graphiti):
async def delete_group(self, group_id: str):
try:
edges = await EntityEdge.get_by_group_ids(self.driver, [group_id])
nodes = await EntityNode.get_by_group_ids(self.driver, [group_id])
episodes = await EpisodicNode.get_by_group_ids(self.driver, [group_id])
for edge in edges:
await edge.delete(self.driver)
for node in nodes:
await node.delete(self.driver)
for episode in episodes:
await episode.delete(self.driver)
except EdgeNotFoundError as e:
raise HTTPException(status_code=404, detail=e.message) from e
except NodeNotFoundError as e:
raise HTTPException(status_code=404, detail=e.message) from e
except GroupsEdgesNotFoundError:
logger.warning(f'No edges found for group {group_id}')
edges = []
nodes = await EntityNode.get_by_group_ids(self.driver, [group_id])
episodes = await EpisodicNode.get_by_group_ids(self.driver, [group_id])
for edge in edges:
await edge.delete(self.driver)
for node in nodes:
await node.delete(self.driver)
for episode in episodes:
await episode.delete(self.driver)
async def delete_entity_edge(self, uuid: str):
try: