chore: Make deleting groups safer (#155)
* chore: Make deleting groups safer * chore: Use appropriate errors in delete group checks * chore: Add GroupsEdgesNotFound error type
This commit is contained in:
parent
bca838f61d
commit
b537cf56e5
4 changed files with 32 additions and 19 deletions
|
|
@ -24,7 +24,7 @@ from uuid import uuid4
|
|||
from neo4j import AsyncDriver
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphiti_core.errors import EdgeNotFoundError
|
||||
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
||||
from graphiti_core.helpers import parse_db_date
|
||||
from graphiti_core.llm_client.config import EMBEDDING_DIM
|
||||
from graphiti_core.nodes import Node
|
||||
|
|
@ -147,10 +147,9 @@ class EpisodicEdge(Edge):
|
|||
)
|
||||
|
||||
edges = [get_episodic_edge_from_record(record) for record in records]
|
||||
uuids = [edge.uuid for edge in edges]
|
||||
|
||||
if len(edges) == 0:
|
||||
raise EdgeNotFoundError(uuids[0])
|
||||
raise GroupsEdgesNotFoundError(group_ids)
|
||||
return edges
|
||||
|
||||
|
||||
|
|
@ -293,10 +292,9 @@ class EntityEdge(Edge):
|
|||
)
|
||||
|
||||
edges = [get_entity_edge_from_record(record) for record in records]
|
||||
uuids = [edge.uuid for edge in edges]
|
||||
|
||||
if len(edges) == 0:
|
||||
raise EdgeNotFoundError(uuids[0])
|
||||
raise GroupsEdgesNotFoundError(group_ids)
|
||||
return edges
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -27,6 +27,14 @@ class EdgeNotFoundError(GraphitiError):
|
|||
super().__init__(self.message)
|
||||
|
||||
|
||||
class GroupsEdgesNotFoundError(GraphitiError):
|
||||
"""Raised when no edges are found for a list of group ids."""
|
||||
|
||||
def __init__(self, group_ids: list[str]):
|
||||
self.message = f'no edges found for group ids {group_ids}'
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class NodeNotFoundError(GraphitiError):
|
||||
"""Raised when a node is not found."""
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "graphiti-core"
|
||||
version = "0.3.5"
|
||||
version = "0.3.6"
|
||||
description = "A temporal graph building library"
|
||||
authors = [
|
||||
"Paul Paliychuk <paul@getzep.com>",
|
||||
|
|
|
|||
|
|
@ -1,15 +1,18 @@
|
|||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from graphiti_core import Graphiti # type: ignore
|
||||
from graphiti_core.edges import EntityEdge # type: ignore
|
||||
from graphiti_core.errors import EdgeNotFoundError, NodeNotFoundError # type: ignore
|
||||
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError, NodeNotFoundError
|
||||
from graphiti_core.llm_client import LLMClient # type: ignore
|
||||
from graphiti_core.nodes import EntityNode, EpisodicNode # type: ignore
|
||||
|
||||
from graph_service.config import ZepEnvDep
|
||||
from graph_service.dto import FactResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ZepGraphiti(Graphiti):
|
||||
def __init__(self, uri: str, user: str, password: str, llm_client: LLMClient | None = None):
|
||||
|
|
@ -36,18 +39,22 @@ class ZepGraphiti(Graphiti):
|
|||
async def delete_group(self, group_id: str):
|
||||
try:
|
||||
edges = await EntityEdge.get_by_group_ids(self.driver, [group_id])
|
||||
nodes = await EntityNode.get_by_group_ids(self.driver, [group_id])
|
||||
episodes = await EpisodicNode.get_by_group_ids(self.driver, [group_id])
|
||||
for edge in edges:
|
||||
await edge.delete(self.driver)
|
||||
for node in nodes:
|
||||
await node.delete(self.driver)
|
||||
for episode in episodes:
|
||||
await episode.delete(self.driver)
|
||||
except EdgeNotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=e.message) from e
|
||||
except NodeNotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=e.message) from e
|
||||
except GroupsEdgesNotFoundError:
|
||||
logger.warning(f'No edges found for group {group_id}')
|
||||
edges = []
|
||||
|
||||
nodes = await EntityNode.get_by_group_ids(self.driver, [group_id])
|
||||
|
||||
episodes = await EpisodicNode.get_by_group_ids(self.driver, [group_id])
|
||||
|
||||
for edge in edges:
|
||||
await edge.delete(self.driver)
|
||||
|
||||
for node in nodes:
|
||||
await node.delete(self.driver)
|
||||
|
||||
for episode in episodes:
|
||||
await episode.delete(self.driver)
|
||||
|
||||
async def delete_entity_edge(self, uuid: str):
|
||||
try:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue