diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index 3b0f3992..e2878391 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -24,6 +24,7 @@ from uuid import uuid4 from neo4j import AsyncDriver from pydantic import BaseModel, Field +from graphiti_core.errors import EdgeNotFoundError from graphiti_core.helpers import parse_db_date from graphiti_core.llm_client.config import EMBEDDING_DIM from graphiti_core.nodes import Node @@ -104,7 +105,8 @@ class EpisodicEdge(Edge): edges = [get_episodic_edge_from_record(record) for record in records] logger.info(f'Found Edge: {uuid}') - + if len(edges) == 0: + raise EdgeNotFoundError(uuid) return edges[0] @@ -191,7 +193,8 @@ class EntityEdge(Edge): edges = [get_entity_edge_from_record(record) for record in records] logger.info(f'Found Edge: {uuid}') - + if len(edges) == 0: + raise EdgeNotFoundError(uuid) return edges[0] diff --git a/graphiti_core/errors.py b/graphiti_core/errors.py new file mode 100644 index 00000000..e6da50d0 --- /dev/null +++ b/graphiti_core/errors.py @@ -0,0 +1,18 @@ +class GraphitiError(Exception): + """Base exception class for Graphiti Core.""" + + +class EdgeNotFoundError(GraphitiError): + """Raised when an edge is not found.""" + + def __init__(self, uuid: str): + self.message = f'edge {uuid} not found' + super().__init__(self.message) + + +class NodeNotFoundError(GraphitiError): + """Raised when a node is not found.""" + + def __init__(self, uuid: str): + self.message = f'node {uuid} not found' + super().__init__(self.message) diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 354c9d67..e924c897 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -25,6 +25,7 @@ from uuid import uuid4 from neo4j import AsyncDriver from pydantic import BaseModel, Field +from graphiti_core.errors import NodeNotFoundError from graphiti_core.llm_client.config import EMBEDDING_DIM logger = logging.getLogger(__name__) @@ -148,7 +149,7 @@ class EpisodicNode(Node): e.valid_at AS valid_at, e.uuid AS uuid, e.name AS name, - e.group_id AS group_id + e.group_id AS group_id, e.source_description AS source_description, e.source AS source """, @@ -159,6 +160,9 @@ class EpisodicNode(Node): logger.info(f'Found Node: {uuid}') + if len(episodes) == 0: + raise NodeNotFoundError(uuid) + return episodes[0] @classmethod diff --git a/server/graph_service/routers/ingest.py b/server/graph_service/routers/ingest.py index 590f2eac..58ff9f35 100644 --- a/server/graph_service/routers/ingest.py +++ b/server/graph_service/routers/ingest.py @@ -83,6 +83,18 @@ async def add_entity_node( return node +@router.delete('/entity-edge/{uuid}', status_code=status.HTTP_200_OK) +async def delete_entity_edge(uuid: str, graphiti: ZepGraphitiDep): + await graphiti.delete_entity_edge(uuid) + return Result(message='Entity Edge deleted', success=True) + + +@router.delete('/episode/{uuid}', status_code=status.HTTP_200_OK) +async def delete_episode(uuid: str, graphiti: ZepGraphitiDep): + await graphiti.delete_episodic_node(uuid) + return Result(message='Episode deleted', success=True) + + @router.post('/clear', status_code=status.HTTP_200_OK) async def clear( graphiti: ZepGraphitiDep, diff --git a/server/graph_service/routers/retrieve.py b/server/graph_service/routers/retrieve.py index 4ee3a4ac..57d697b0 100644 --- a/server/graph_service/routers/retrieve.py +++ b/server/graph_service/routers/retrieve.py @@ -27,6 +27,11 @@ async def search(query: SearchQuery, graphiti: ZepGraphitiDep): ) +@router.get('/entity-edge/{uuid}', status_code=status.HTTP_200_OK) +async def get_entity_edge(uuid: str, graphiti: ZepGraphitiDep): + return await graphiti.get_entity_edge(uuid) + + @router.get('/episodes/{group_id}', status_code=status.HTTP_200_OK) async def get_episodes(group_id: str, last_n: int, graphiti: ZepGraphitiDep): episodes = await graphiti.retrieve_episodes( diff --git a/server/graph_service/zep_graphiti.py b/server/graph_service/zep_graphiti.py index 9b98b35a..236a1b35 100644 --- a/server/graph_service/zep_graphiti.py +++ b/server/graph_service/zep_graphiti.py @@ -1,10 +1,11 @@ from typing import Annotated -from fastapi import Depends +from fastapi import Depends, HTTPException from graphiti_core import Graphiti # type: ignore from graphiti_core.edges import EntityEdge # type: ignore +from graphiti_core.errors import EdgeNotFoundError, NodeNotFoundError # type: ignore from graphiti_core.llm_client import LLMClient # type: ignore -from graphiti_core.nodes import EntityNode # type: ignore +from graphiti_core.nodes import EntityNode, EpisodicNode # type: ignore from graph_service.config import ZepEnvDep from graph_service.dto import FactResult @@ -25,6 +26,27 @@ class ZepGraphiti(Graphiti): await new_node.save(self.driver) return new_node + async def get_entity_edge(self, uuid: str): + try: + edge = await EntityEdge.get_by_uuid(self.driver, uuid) + return edge + except EdgeNotFoundError as e: + raise HTTPException(status_code=404, detail=e.message) from e + + async def delete_entity_edge(self, uuid: str): + try: + edge = await EntityEdge.get_by_uuid(self.driver, uuid) + await edge.delete(self.driver) + except EdgeNotFoundError as e: + raise HTTPException(status_code=404, detail=e.message) from e + + async def delete_episodic_node(self, uuid: str): + try: + episode = await EpisodicNode.get_by_uuid(self.driver, uuid) + await episode.delete(self.driver) + except NodeNotFoundError as e: + raise HTTPException(status_code=404, detail=e.message) from e + async def get_graphiti(settings: ZepEnvDep): client = ZepGraphiti(