diff --git a/cognee/api/client.py b/cognee/api/client.py index 53c9f9762..95a273351 100644 --- a/cognee/api/client.py +++ b/cognee/api/client.py @@ -3,10 +3,13 @@ import os import uvicorn import logging import sentry_sdk -from fastapi import FastAPI +from fastapi import FastAPI, status from fastapi.responses import JSONResponse, Response from fastapi.middleware.cors import CORSMiddleware +from cognee.exceptions import CogneeApiError +from traceback import format_exc + # Set up logging logging.basicConfig( level=logging.INFO, # Set the logging level (e.g., DEBUG, INFO, WARNING, ERROR, CRITICAL) @@ -76,6 +79,26 @@ async def request_validation_exception_handler(request: Request, exc: RequestVal content = jsonable_encoder({"detail": exc.errors(), "body": exc.body}), ) +@app.exception_handler(CogneeApiError) +async def exception_handler(_: Request, exc: CogneeApiError) -> JSONResponse: + detail = {} + + if exc.name and exc.message and exc.status_code: + status_code = exc.status_code + detail["message"] = f"{exc.message} [{exc.name}]" + else: + # Log an error indicating the exception is improperly defined + logger.error("Improperly defined exception: %s", exc) + # Provide a default error response + detail["message"] = "An unexpected error occurred." + status_code = status.HTTP_418_IM_A_TEAPOT + + # log the stack trace for easier serverside debugging + logger.error(format_exc()) + return JSONResponse( + status_code=status_code, content={"detail": detail["message"]} + ) + app.include_router( get_auth_router(), prefix = "/api/v1/auth", diff --git a/cognee/api/v1/cognify/code_graph_pipeline.py b/cognee/api/v1/cognify/code_graph_pipeline.py index 2cbb606c1..59c658300 100644 --- a/cognee/api/v1/cognify/code_graph_pipeline.py +++ b/cognee/api/v1/cognify/code_graph_pipeline.py @@ -22,11 +22,6 @@ logger = logging.getLogger("code_graph_pipeline") update_status_lock = asyncio.Lock() -class PermissionDeniedException(Exception): - def __init__(self, message: str): - self.message = message - super().__init__(self.message) - async def code_graph_pipeline(datasets: Union[str, list[str]] = None, user: User = None): if user is None: user = await get_default_user() diff --git a/cognee/api/v1/cognify/cognify_v2.py b/cognee/api/v1/cognify/cognify_v2.py index be9ecd1ce..791ab516a 100644 --- a/cognee/api/v1/cognify/cognify_v2.py +++ b/cognee/api/v1/cognify/cognify_v2.py @@ -24,11 +24,6 @@ logger = logging.getLogger("cognify.v2") update_status_lock = asyncio.Lock() -class PermissionDeniedException(Exception): - def __init__(self, message: str): - self.message = message - super().__init__(self.message) - async def cognify(datasets: Union[str, list[str]] = None, user: User = None): if user is None: user = await get_default_user() diff --git a/cognee/api/v1/config/config.py b/cognee/api/v1/config/config.py index 1fbed9bdc..1347fcba8 100644 --- a/cognee/api/v1/config/config.py +++ b/cognee/api/v1/config/config.py @@ -1,6 +1,7 @@ """ This module is used to set the configuration of the system.""" import os from cognee.base_config import get_base_config +from cognee.exceptions import InvalidValueError, InvalidAttributeError from cognee.modules.cognify.config import get_cognify_config from cognee.infrastructure.data.chunking.config import get_chunk_config from cognee.infrastructure.databases.vector import get_vectordb_config @@ -85,7 +86,7 @@ class config(): if hasattr(llm_config, key): object.__setattr__(llm_config, key, value) else: - raise AttributeError(f"'{key}' is not a valid attribute of the config.") + raise InvalidAttributeError(message=f"'{key}' is not a valid attribute of the config.") @staticmethod def set_chunk_strategy(chunk_strategy: object): @@ -123,7 +124,7 @@ class config(): if hasattr(relational_db_config, key): object.__setattr__(relational_db_config, key, value) else: - raise AttributeError(f"'{key}' is not a valid attribute of the config.") + raise InvalidAttributeError(message=f"'{key}' is not a valid attribute of the config.") @staticmethod def set_vector_db_config(config_dict: dict): @@ -135,7 +136,7 @@ class config(): if hasattr(vector_db_config, key): object.__setattr__(vector_db_config, key, value) else: - raise AttributeError(f"'{key}' is not a valid attribute of the config.") + raise InvalidAttributeError(message=f"'{key}' is not a valid attribute of the config.") @staticmethod def set_vector_db_key(db_key: str): @@ -153,7 +154,7 @@ class config(): base_config = get_base_config() if "username" not in graphistry_config or "password" not in graphistry_config: - raise ValueError("graphistry_config dictionary must contain 'username' and 'password' keys.") + raise InvalidValueError(message="graphistry_config dictionary must contain 'username' and 'password' keys.") base_config.graphistry_username = graphistry_config.get("username") base_config.graphistry_password = graphistry_config.get("password") diff --git a/cognee/api/v1/datasets/routers/get_datasets_router.py b/cognee/api/v1/datasets/routers/get_datasets_router.py index f27c6c2ad..31e3fa67d 100644 --- a/cognee/api/v1/datasets/routers/get_datasets_router.py +++ b/cognee/api/v1/datasets/routers/get_datasets_router.py @@ -9,6 +9,7 @@ from fastapi.responses import JSONResponse, FileResponse from pydantic import BaseModel from cognee.api.DTO import OutDTO +from cognee.infrastructure.databases.exceptions import EntityNotFoundError from cognee.modules.users.models import User from cognee.modules.users.methods import get_authenticated_user from cognee.modules.pipelines.models import PipelineRunStatus @@ -55,9 +56,8 @@ def get_datasets_router() -> APIRouter: dataset = await get_dataset(user.id, dataset_id) if dataset is None: - raise HTTPException( - status_code=404, - detail=f"Dataset ({dataset_id}) not found." + raise EntityNotFoundError( + message=f"Dataset ({dataset_id}) not found." ) await delete_dataset(dataset) @@ -72,17 +72,15 @@ def get_datasets_router() -> APIRouter: #TODO: Handle situation differently if user doesn't have permission to access data? if dataset is None: - raise HTTPException( - status_code=404, - detail=f"Dataset ({dataset_id}) not found." + raise EntityNotFoundError( + message=f"Dataset ({dataset_id}) not found." ) data = await get_data(data_id) if data is None: - raise HTTPException( - status_code=404, - detail=f"Dataset ({data_id}) not found." + raise EntityNotFoundError( + message=f"Data ({data_id}) not found." ) await delete_data(data) @@ -158,18 +156,13 @@ def get_datasets_router() -> APIRouter: dataset_data = await get_dataset_data(dataset.id) if dataset_data is None: - raise HTTPException(status_code=404, detail=f"No data found in dataset ({dataset_id}).") + raise EntityNotFoundError(message=f"No data found in dataset ({dataset_id}).") matching_data = [data for data in dataset_data if str(data.id) == data_id] # Check if matching_data contains an element if len(matching_data) == 0: - return JSONResponse( - status_code=404, - content={ - "detail": f"Data ({data_id}) not found in dataset ({dataset_id})." - } - ) + raise EntityNotFoundError(message= f"Data ({data_id}) not found in dataset ({dataset_id}).") data = matching_data[0] diff --git a/cognee/api/v1/permissions/routers/get_permissions_router.py b/cognee/api/v1/permissions/routers/get_permissions_router.py index ab20fb1a2..8d012d600 100644 --- a/cognee/api/v1/permissions/routers/get_permissions_router.py +++ b/cognee/api/v1/permissions/routers/get_permissions_router.py @@ -1,6 +1,8 @@ from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import JSONResponse from sqlalchemy.orm import Session + +from cognee.modules.users.exceptions import UserNotFoundError, GroupNotFoundError from cognee.modules.users import get_user_db from cognee.modules.users.models import User, Group, Permission @@ -12,7 +14,7 @@ def get_permissions_router() -> APIRouter: group = db.query(Group).filter(Group.id == group_id).first() if not group: - raise HTTPException(status_code = 404, detail = "Group not found") + raise GroupNotFoundError permission = db.query(Permission).filter(Permission.name == permission).first() @@ -31,8 +33,10 @@ def get_permissions_router() -> APIRouter: user = db.query(User).filter(User.id == user_id).first() group = db.query(Group).filter(Group.id == group_id).first() - if not user or not group: - raise HTTPException(status_code = 404, detail = "User or group not found") + if not user: + raise UserNotFoundError + elif not group: + raise GroupNotFoundError user.groups.append(group) diff --git a/cognee/api/v1/search/search.legacy.py b/cognee/api/v1/search/search.legacy.py index aaa22fd62..cea3b3874 100644 --- a/cognee/api/v1/search/search.legacy.py +++ b/cognee/api/v1/search/search.legacy.py @@ -9,6 +9,8 @@ from cognee.modules.search.graph.search_adjacent import search_adjacent from cognee.modules.search.vector.search_traverse import search_traverse from cognee.modules.search.graph.search_summary import search_summary from cognee.modules.search.graph.search_similarity import search_similarity + +from cognee.exceptions import UserNotFoundError from cognee.shared.utils import send_telemetry from cognee.modules.users.permissions.methods import get_document_ids_for_user from cognee.modules.users.methods import get_default_user @@ -47,7 +49,7 @@ async def search(search_type: str, params: Dict[str, Any], user: User = None) -> user = await get_default_user() if user is None: - raise PermissionError("No user found in the system. Please create a user.") + raise UserNotFoundError own_document_ids = await get_document_ids_for_user(user.id) search_params = SearchParameters(search_type = search_type, params = params) diff --git a/cognee/api/v1/search/search_v2.py b/cognee/api/v1/search/search_v2.py index c1bc0ee4d..d77aa5fa8 100644 --- a/cognee/api/v1/search/search_v2.py +++ b/cognee/api/v1/search/search_v2.py @@ -2,9 +2,12 @@ import json from uuid import UUID from enum import Enum from typing import Callable, Dict + +from cognee.exceptions import InvalidValueError from cognee.modules.search.operations import log_query, log_result from cognee.modules.storage.utils import JSONEncoder from cognee.shared.utils import send_telemetry +from cognee.modules.users.exceptions import UserNotFoundError from cognee.modules.users.models import User from cognee.modules.users.methods import get_default_user from cognee.modules.users.permissions.methods import get_document_ids_for_user @@ -22,7 +25,7 @@ async def search(query_type: SearchType, query_text: str, user: User = None) -> user = await get_default_user() if user is None: - raise PermissionError("No user found in the system. Please create a user.") + raise UserNotFoundError query = await log_query(query_text, str(query_type), user.id) @@ -52,7 +55,7 @@ async def specific_search(query_type: SearchType, query: str, user) -> list: search_task = search_tasks.get(query_type) if search_task is None: - raise ValueError(f"Unsupported search type: {query_type}") + raise InvalidValueError(message=f"Unsupported search type: {query_type}") send_telemetry("cognee.search EXECUTION STARTED", user.id) diff --git a/cognee/exceptions/__init__.py b/cognee/exceptions/__init__.py new file mode 100644 index 000000000..40120e0e1 --- /dev/null +++ b/cognee/exceptions/__init__.py @@ -0,0 +1,13 @@ +""" +Custom exceptions for the Cognee API. + +This module defines a set of exceptions for handling various application errors, +such as service failures, resource conflicts, and invalid operations. +""" + +from .exceptions import ( + CogneeApiError, + ServiceError, + InvalidValueError, + InvalidAttributeError, +) \ No newline at end of file diff --git a/cognee/exceptions/exceptions.py b/cognee/exceptions/exceptions.py new file mode 100644 index 000000000..f94daf8c9 --- /dev/null +++ b/cognee/exceptions/exceptions.py @@ -0,0 +1,54 @@ +from fastapi import status +import logging + +logger = logging.getLogger(__name__) + +class CogneeApiError(Exception): + """Base exception class""" + + def __init__( + self, + message: str = "Service is unavailable.", + name: str = "Cognee", + status_code=status.HTTP_418_IM_A_TEAPOT, + ): + self.message = message + self.name = name + self.status_code = status_code + + # Automatically log the exception details + logger.error(f"{self.name}: {self.message} (Status code: {self.status_code})") + + super().__init__(self.message, self.name) + + +class ServiceError(CogneeApiError): + """Failures in external services or APIs, like a database or a third-party service""" + + def __init__( + self, + message: str = "Service is unavailable.", + name: str = "ServiceError", + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + ): + super().__init__(message, name, status_code) + + +class InvalidValueError(CogneeApiError): + def __init__( + self, + message: str = "Invalid Value.", + name: str = "InvalidValueError", + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + ): + super().__init__(message, name, status_code) + + +class InvalidAttributeError(CogneeApiError): + def __init__( + self, + message: str = "Invalid attribute.", + name: str = "InvalidAttributeError", + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + ): + super().__init__(message, name, status_code) \ No newline at end of file diff --git a/cognee/infrastructure/data/utils/extract_keywords.py b/cognee/infrastructure/data/utils/extract_keywords.py index ab32ddefb..11f061889 100644 --- a/cognee/infrastructure/data/utils/extract_keywords.py +++ b/cognee/infrastructure/data/utils/extract_keywords.py @@ -1,9 +1,11 @@ from sklearn.feature_extraction.text import TfidfVectorizer + +from cognee.exceptions import InvalidValueError from cognee.shared.utils import extract_pos_tags def extract_keywords(text: str) -> list[str]: if len(text) == 0: - raise ValueError("extract_keywords cannot extract keywords from empty text.") + raise InvalidValueError(message="extract_keywords cannot extract keywords from empty text.") tags = extract_pos_tags(text) nouns = [word for (word, tag) in tags if tag == "NN"] diff --git a/cognee/infrastructure/databases/exceptions/__init__.py b/cognee/infrastructure/databases/exceptions/__init__.py new file mode 100644 index 000000000..5836e7d11 --- /dev/null +++ b/cognee/infrastructure/databases/exceptions/__init__.py @@ -0,0 +1,10 @@ +""" +Custom exceptions for the Cognee API. + +This module defines a set of exceptions for handling various database errors +""" + +from .exceptions import ( + EntityNotFoundError, + EntityAlreadyExistsError, +) \ No newline at end of file diff --git a/cognee/infrastructure/databases/exceptions/exceptions.py b/cognee/infrastructure/databases/exceptions/exceptions.py new file mode 100644 index 000000000..af15bb616 --- /dev/null +++ b/cognee/infrastructure/databases/exceptions/exceptions.py @@ -0,0 +1,25 @@ +from cognee.exceptions import CogneeApiError +from fastapi import status + +class EntityNotFoundError(CogneeApiError): + """Database returns nothing""" + + def __init__( + self, + message: str = "The requested entity does not exist.", + name: str = "EntityNotFoundError", + status_code=status.HTTP_404_NOT_FOUND, + ): + super().__init__(message, name, status_code) + + +class EntityAlreadyExistsError(CogneeApiError): + """Conflict detected, like trying to create a resource that already exists""" + + def __init__( + self, + message: str = "The entity already exists.", + name: str = "EntityAlreadyExistsError", + status_code=status.HTTP_409_CONFLICT, + ): + super().__init__(message, name, status_code) \ No newline at end of file diff --git a/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py b/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py index ea5a75088..a28d827a1 100644 --- a/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py +++ b/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py @@ -4,6 +4,7 @@ from typing import Any from uuid import UUID from falkordb import FalkorDB +from cognee.exceptions import InvalidValueError from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface from cognee.infrastructure.databases.vector.embeddings import EmbeddingEngine @@ -200,7 +201,7 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface): with_vector: bool = False, ): if query_text is None and query_vector is None: - raise ValueError("One of query_text or query_vector must be provided!") + raise InvalidValueError(message="One of query_text or query_vector must be provided!") if query_text and not query_vector: query_vector = (await self.embed_data([query_text]))[0] diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py index a5733967e..8041aeaea 100644 --- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py @@ -6,6 +6,8 @@ from contextlib import asynccontextmanager from sqlalchemy import text, select, MetaData, Table from sqlalchemy.orm import joinedload from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker + +from cognee.infrastructure.databases.exceptions import EntityNotFoundError from ..ModelBase import Base class SQLAlchemyAdapter(): @@ -116,7 +118,7 @@ class SQLAlchemyAdapter(): if table_name in Base.metadata.tables: return Base.metadata.tables[table_name] else: - raise ValueError(f"Table '{table_name}' not found.") + raise EntityNotFoundError(message=f"Table '{table_name}' not found.") else: # Create a MetaData instance to load table information metadata = MetaData() @@ -127,7 +129,7 @@ class SQLAlchemyAdapter(): # Check if table is in list of tables for the given schema if full_table_name in metadata.tables: return metadata.tables[full_table_name] - raise ValueError(f"Table '{full_table_name}' not found.") + raise EntityNotFoundError(message=f"Table '{full_table_name}' not found.") async def get_table_names(self) -> List[str]: """ diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py index 5204c1bad..3cea3bc27 100644 --- a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py @@ -5,6 +5,8 @@ from uuid import UUID import lancedb from pydantic import BaseModel from lancedb.pydantic import Vector, LanceModel + +from cognee.exceptions import InvalidValueError from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.files.storage import LocalStorage from cognee.modules.storage.utils import copy_model, get_own_properties @@ -123,7 +125,7 @@ class LanceDBAdapter(VectorDBInterface): new_size = await collection.count_rows() if new_size <= original_size: - raise ValueError( + raise InvalidValueError(message= "LanceDB create_datapoints error: data points did not get added.") @@ -149,7 +151,7 @@ class LanceDBAdapter(VectorDBInterface): query_vector: List[float] = None ): if query_text is None and query_vector is None: - raise ValueError("One of query_text or query_vector must be provided!") + raise InvalidValueError(message="One of query_text or query_vector must be provided!") if query_text and not query_vector: query_vector = (await self.embedding_engine.embed_text([query_text]))[0] @@ -179,7 +181,7 @@ class LanceDBAdapter(VectorDBInterface): normalized: bool = True ): if query_text is None and query_vector is None: - raise ValueError("One of query_text or query_vector must be provided!") + raise InvalidValueError(message="One of query_text or query_vector must be provided!") if query_text and not query_vector: query_vector = (await self.embedding_engine.embed_text([query_text]))[0] diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index fd0fd493c..27db2c276 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -6,6 +6,8 @@ from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy import JSON, Column, Table, select, delete from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker +from cognee.exceptions import InvalidValueError +from cognee.infrastructure.databases.exceptions import EntityNotFoundError from cognee.infrastructure.engine import DataPoint from .serialize_data import serialize_data @@ -156,7 +158,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): if collection_name in Base.metadata.tables: return Base.metadata.tables[collection_name] else: - raise ValueError(f"Table '{collection_name}' not found.") + raise EntityNotFoundError(message=f"Table '{collection_name}' not found.") async def retrieve(self, collection_name: str, data_point_ids: List[str]): # Get PGVectorDataPoint Table from database @@ -230,7 +232,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): with_vector: bool = False, ) -> List[ScoredResult]: if query_text is None and query_vector is None: - raise ValueError("One of query_text or query_vector must be provided!") + raise InvalidValueError(message="One of query_text or query_vector must be provided!") if query_text and not query_vector: query_vector = (await self.embedding_engine.embed_text([query_text]))[0] diff --git a/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py b/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py index c340928f4..dc33e98ae 100644 --- a/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py +++ b/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py @@ -3,6 +3,7 @@ from uuid import UUID from typing import List, Dict, Optional from qdrant_client import AsyncQdrantClient, models +from cognee.exceptions import InvalidValueError from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult from cognee.infrastructure.engine import DataPoint from ..vector_db_interface import VectorDBInterface @@ -186,7 +187,7 @@ class QDrantAdapter(VectorDBInterface): with_vector: bool = False ): if query_text is None and query_vector is None: - raise ValueError("One of query_text or query_vector must be provided!") + raise InvalidValueError(message="One of query_text or query_vector must be provided!") client = self.get_qdrant_client() diff --git a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py index c9848e02c..0c97dc9a8 100644 --- a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py +++ b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py @@ -3,6 +3,7 @@ import logging from typing import List, Optional from uuid import UUID +from cognee.exceptions import InvalidValueError from cognee.infrastructure.engine import DataPoint from ..vector_db_interface import VectorDBInterface from ..models.ScoredResult import ScoredResult @@ -194,7 +195,7 @@ class WeaviateAdapter(VectorDBInterface): import weaviate.classes as wvc if query_text is None and query_vector is None: - raise ValueError("One of query_text or query_vector must be provided!") + raise InvalidValueError(message="One of query_text or query_vector must be provided!") if query_vector is None: query_vector = (await self.embed_data([query_text]))[0] diff --git a/cognee/infrastructure/llm/anthropic/adapter.py b/cognee/infrastructure/llm/anthropic/adapter.py index 8df59e3e5..b74beaf6e 100644 --- a/cognee/infrastructure/llm/anthropic/adapter.py +++ b/cognee/infrastructure/llm/anthropic/adapter.py @@ -3,6 +3,8 @@ from pydantic import BaseModel import instructor from tenacity import retry, stop_after_attempt import anthropic + +from cognee.exceptions import InvalidValueError from cognee.infrastructure.llm.llm_interface import LLMInterface from cognee.infrastructure.llm.prompts import read_query_prompt @@ -45,7 +47,7 @@ class AnthropicAdapter(LLMInterface): if not text_input: text_input = "No user input provided." if not system_prompt: - raise ValueError("No system prompt path provided.") + raise InvalidValueError(message="No system prompt path provided.") system_prompt = read_query_prompt(system_prompt) diff --git a/cognee/infrastructure/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/generic_llm_api/adapter.py index f65d559d5..7d3d97ebb 100644 --- a/cognee/infrastructure/llm/generic_llm_api/adapter.py +++ b/cognee/infrastructure/llm/generic_llm_api/adapter.py @@ -5,6 +5,8 @@ from pydantic import BaseModel import instructor from tenacity import retry, stop_after_attempt import openai + +from cognee.exceptions import InvalidValueError from cognee.infrastructure.llm.llm_interface import LLMInterface from cognee.infrastructure.llm.prompts import read_query_prompt from cognee.shared.data_models import MonitoringTool @@ -128,7 +130,7 @@ class GenericAPIAdapter(LLMInterface): if not text_input: text_input = "No user input provided." if not system_prompt: - raise ValueError("No system prompt path provided.") + raise InvalidValueError(message="No system prompt path provided.") system_prompt = read_query_prompt(system_prompt) formatted_prompt = f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n""" if system_prompt else None diff --git a/cognee/infrastructure/llm/get_llm_client.py b/cognee/infrastructure/llm/get_llm_client.py index 1449d33b3..9a23892f2 100644 --- a/cognee/infrastructure/llm/get_llm_client.py +++ b/cognee/infrastructure/llm/get_llm_client.py @@ -1,5 +1,7 @@ """Get the LLM client.""" from enum import Enum + +from cognee.exceptions import InvalidValueError from cognee.infrastructure.llm import get_llm_config # Define an Enum for LLM Providers @@ -17,7 +19,7 @@ def get_llm_client(): if provider == LLMProvider.OPENAI: if llm_config.llm_api_key is None: - raise ValueError("LLM API key is not set.") + raise InvalidValueError(message="LLM API key is not set.") from .openai.adapter import OpenAIAdapter @@ -32,7 +34,7 @@ def get_llm_client(): elif provider == LLMProvider.OLLAMA: if llm_config.llm_api_key is None: - raise ValueError("LLM API key is not set.") + raise InvalidValueError(message="LLM API key is not set.") from .generic_llm_api.adapter import GenericAPIAdapter return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama") @@ -43,10 +45,10 @@ def get_llm_client(): elif provider == LLMProvider.CUSTOM: if llm_config.llm_api_key is None: - raise ValueError("LLM API key is not set.") + raise InvalidValueError(message="LLM API key is not set.") from .generic_llm_api.adapter import GenericAPIAdapter return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom") else: - raise ValueError(f"Unsupported LLM provider: {provider}") + raise InvalidValueError(message=f"Unsupported LLM provider: {provider}") diff --git a/cognee/infrastructure/llm/openai/adapter.py b/cognee/infrastructure/llm/openai/adapter.py index 1dc9b70f5..b2929c6c0 100644 --- a/cognee/infrastructure/llm/openai/adapter.py +++ b/cognee/infrastructure/llm/openai/adapter.py @@ -7,6 +7,7 @@ import litellm import instructor from pydantic import BaseModel +from cognee.exceptions import InvalidValueError from cognee.infrastructure.llm.llm_interface import LLMInterface from cognee.infrastructure.llm.prompts import read_query_prompt @@ -127,7 +128,7 @@ class OpenAIAdapter(LLMInterface): if not text_input: text_input = "No user input provided." if not system_prompt: - raise ValueError("No system prompt path provided.") + raise InvalidValueError(message="No system prompt path provided.") system_prompt = read_query_prompt(system_prompt) formatted_prompt = f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n""" if system_prompt else None diff --git a/cognee/modules/data/methods/delete_data.py b/cognee/modules/data/methods/delete_data.py index 7560762e1..c0493a606 100644 --- a/cognee/modules/data/methods/delete_data.py +++ b/cognee/modules/data/methods/delete_data.py @@ -1,3 +1,4 @@ +from cognee.exceptions import InvalidAttributeError from cognee.modules.data.models import Data from cognee.infrastructure.databases.relational import get_relational_engine @@ -12,7 +13,7 @@ async def delete_data(data: Data): ValueError: If the data object is invalid. """ if not hasattr(data, '__tablename__'): - raise ValueError("The provided data object is missing the required '__tablename__' attribute.") + raise InvalidAttributeError(message="The provided data object is missing the required '__tablename__' attribute.") db_engine = get_relational_engine() diff --git a/cognee/modules/data/operations/translate_text.py b/cognee/modules/data/operations/translate_text.py index 411712648..d8c27e42a 100644 --- a/cognee/modules/data/operations/translate_text.py +++ b/cognee/modules/data/operations/translate_text.py @@ -1,5 +1,7 @@ import logging +from cognee.exceptions import InvalidValueError + logger = logging.getLogger(__name__) async def translate_text(text, source_language: str = "sr", target_language: str = "en", region_name = "eu-west-1"): @@ -18,10 +20,10 @@ async def translate_text(text, source_language: str = "sr", target_language: str from botocore.exceptions import BotoCoreError, ClientError if not text: - raise ValueError("No text to translate.") + raise InvalidValueError(message="No text to translate.") if not source_language or not target_language: - raise ValueError("Source and target language codes are required.") + raise InvalidValueError(message="Source and target language codes are required.") try: translate = boto3.client(service_name = "translate", region_name = region_name, use_ssl = True) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index 21d095f3d..3bf8e3d83 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -1,6 +1,9 @@ import numpy as np from typing import List, Dict, Union + +from cognee.exceptions import InvalidValueError +from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge from cognee.modules.graph.cognee_graph.CogneeAbstractGraph import CogneeAbstractGraph @@ -29,7 +32,7 @@ class CogneeGraph(CogneeAbstractGraph): if node.id not in self.nodes: self.nodes[node.id] = node else: - raise ValueError(f"Node with id {node.id} already exists.") + raise EntityAlreadyExistsError(message=f"Node with id {node.id} already exists.") def add_edge(self, edge: Edge) -> None: if edge not in self.edges: @@ -37,7 +40,7 @@ class CogneeGraph(CogneeAbstractGraph): edge.node1.add_skeleton_edge(edge) edge.node2.add_skeleton_edge(edge) else: - raise ValueError(f"Edge {edge} already exists in the graph.") + raise EntityAlreadyExistsError(message=f"Edge {edge} already exists in the graph.") def get_node(self, node_id: str) -> Node: return self.nodes.get(node_id, None) @@ -47,7 +50,7 @@ class CogneeGraph(CogneeAbstractGraph): if node: return node.skeleton_edges else: - raise ValueError(f"Node with id {node_id} does not exist.") + raise EntityNotFoundError(message=f"Node with id {node_id} does not exist.") def get_edges(self)-> List[Edge]: return self.edges @@ -62,7 +65,7 @@ class CogneeGraph(CogneeAbstractGraph): memory_fragment_filter = []) -> None: if node_dimension < 1 or edge_dimension < 1: - raise ValueError("Dimensions must be positive integers") + raise InvalidValueError(message="Dimensions must be positive integers") try: if len(memory_fragment_filter) == 0: @@ -71,9 +74,9 @@ class CogneeGraph(CogneeAbstractGraph): nodes_data, edges_data = await adapter.get_filtered_graph_data(attribute_filters = memory_fragment_filter) if not nodes_data: - raise ValueError("No node data retrieved from the database.") + raise EntityNotFoundError(message="No node data retrieved from the database.") if not edges_data: - raise ValueError("No edge data retrieved from the database.") + raise EntityNotFoundError(message="No edge data retrieved from the database.") for node_id, properties in nodes_data: node_attributes = {key: properties.get(key) for key in node_properties_to_project} @@ -93,7 +96,7 @@ class CogneeGraph(CogneeAbstractGraph): target_node.add_skeleton_edge(edge) else: - raise ValueError(f"Edge references nonexistent nodes: {source_id} -> {target_id}") + raise EntityNotFoundError(message=f"Edge references nonexistent nodes: {source_id} -> {target_id}") except (ValueError, TypeError) as e: print(f"Error projecting graph: {e}") diff --git a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py index cecb0a272..09d1e84cf 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py @@ -1,6 +1,9 @@ import numpy as np from typing import List, Dict, Optional, Any, Union +from cognee.exceptions import InvalidValueError + + class Node: """ Represents a node in a graph. @@ -18,7 +21,7 @@ class Node: def __init__(self, node_id: str, attributes: Optional[Dict[str, Any]] = None, dimension: int = 1): if dimension <= 0: - raise ValueError("Dimension must be a positive integer") + raise InvalidValueError(message="Dimension must be a positive integer") self.id = node_id self.attributes = attributes if attributes is not None else {} self.attributes["vector_distance"] = float('inf') @@ -53,7 +56,7 @@ class Node: def is_node_alive_in_dimension(self, dimension: int) -> bool: if dimension < 0 or dimension >= len(self.status): - raise ValueError(f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.") + raise InvalidValueError(message=f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.") return self.status[dimension] == 1 def add_attribute(self, key: str, value: Any) -> None: @@ -90,7 +93,7 @@ class Edge: def __init__(self, node1: "Node", node2: "Node", attributes: Optional[Dict[str, Any]] = None, directed: bool = True, dimension: int = 1): if dimension <= 0: - raise ValueError("Dimensions must be a positive integer.") + raise InvalidValueError(message="Dimensions must be a positive integer.") self.node1 = node1 self.node2 = node2 self.attributes = attributes if attributes is not None else {} @@ -100,7 +103,7 @@ class Edge: def is_edge_alive_in_dimension(self, dimension: int) -> bool: if dimension < 0 or dimension >= len(self.status): - raise ValueError(f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.") + raise InvalidValueError(message=f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.") return self.status[dimension] == 1 def add_attribute(self, key: str, value: Any) -> None: diff --git a/cognee/modules/graph/exceptions/__init__.py b/cognee/modules/graph/exceptions/__init__.py new file mode 100644 index 000000000..e8330caf3 --- /dev/null +++ b/cognee/modules/graph/exceptions/__init__.py @@ -0,0 +1,10 @@ +""" +Custom exceptions for the Cognee API. + +This module defines a set of exceptions for handling various graph errors +""" + +from .exceptions import ( + EntityNotFoundError, + EntityAlreadyExistsError, +) \ No newline at end of file diff --git a/cognee/modules/graph/exceptions/exceptions.py b/cognee/modules/graph/exceptions/exceptions.py new file mode 100644 index 000000000..af15bb616 --- /dev/null +++ b/cognee/modules/graph/exceptions/exceptions.py @@ -0,0 +1,25 @@ +from cognee.exceptions import CogneeApiError +from fastapi import status + +class EntityNotFoundError(CogneeApiError): + """Database returns nothing""" + + def __init__( + self, + message: str = "The requested entity does not exist.", + name: str = "EntityNotFoundError", + status_code=status.HTTP_404_NOT_FOUND, + ): + super().__init__(message, name, status_code) + + +class EntityAlreadyExistsError(CogneeApiError): + """Conflict detected, like trying to create a resource that already exists""" + + def __init__( + self, + message: str = "The entity already exists.", + name: str = "EntityAlreadyExistsError", + status_code=status.HTTP_409_CONFLICT, + ): + super().__init__(message, name, status_code) \ No newline at end of file diff --git a/cognee/modules/ingestion/classify.py b/cognee/modules/ingestion/classify.py index 8e8c9fb00..dbb191cc3 100644 --- a/cognee/modules/ingestion/classify.py +++ b/cognee/modules/ingestion/classify.py @@ -1,9 +1,11 @@ from io import BufferedReader from typing import Union, BinaryIO -from .exceptions import IngestionException from .data_types import TextData, BinaryData from tempfile import SpooledTemporaryFile +from cognee.modules.ingestion.exceptions import IngestionError + + def classify(data: Union[str, BinaryIO], filename: str = None): if isinstance(data, str): return TextData(data) @@ -11,4 +13,4 @@ def classify(data: Union[str, BinaryIO], filename: str = None): if isinstance(data, BufferedReader) or isinstance(data, SpooledTemporaryFile): return BinaryData(data, data.name.split("/")[-1] if data.name else filename) - raise IngestionException(f"Type of data sent to classify(data: Union[str, BinaryIO) not supported: {type(data)}") + raise IngestionError(message=f"Type of data sent to classify(data: Union[str, BinaryIO) not supported: {type(data)}") diff --git a/cognee/modules/ingestion/exceptions.py b/cognee/modules/ingestion/exceptions.py deleted file mode 100644 index 0a189fb81..000000000 --- a/cognee/modules/ingestion/exceptions.py +++ /dev/null @@ -1,6 +0,0 @@ - -class IngestionException(Exception): - message: str - - def __init__(self, message: str): - self.message = message diff --git a/cognee/modules/ingestion/exceptions/__init__.py b/cognee/modules/ingestion/exceptions/__init__.py new file mode 100644 index 000000000..33d59e113 --- /dev/null +++ b/cognee/modules/ingestion/exceptions/__init__.py @@ -0,0 +1,9 @@ +""" +Custom exceptions for the Cognee API. + +This module defines a set of exceptions for handling various ingestion errors +""" + +from .exceptions import ( + IngestionError, +) \ No newline at end of file diff --git a/cognee/modules/ingestion/exceptions/exceptions.py b/cognee/modules/ingestion/exceptions/exceptions.py new file mode 100644 index 000000000..4901be110 --- /dev/null +++ b/cognee/modules/ingestion/exceptions/exceptions.py @@ -0,0 +1,11 @@ +from cognee.exceptions import CogneeApiError +from fastapi import status + +class IngestionError(CogneeApiError): + def __init__( + self, + message: str = "Type of data sent to classify not supported.", + name: str = "IngestionError", + status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE, + ): + super().__init__(message, name, status_code) \ No newline at end of file diff --git a/cognee/modules/users/exceptions/__init__.py b/cognee/modules/users/exceptions/__init__.py new file mode 100644 index 000000000..70e6e9d2a --- /dev/null +++ b/cognee/modules/users/exceptions/__init__.py @@ -0,0 +1,11 @@ +""" +Custom exceptions for the Cognee API. + +This module defines a set of exceptions for handling various user errors +""" + +from .exceptions import ( + GroupNotFoundError, + UserNotFoundError, + PermissionDeniedError, +) \ No newline at end of file diff --git a/cognee/modules/users/exceptions/exceptions.py b/cognee/modules/users/exceptions/exceptions.py new file mode 100644 index 000000000..7dda702db --- /dev/null +++ b/cognee/modules/users/exceptions/exceptions.py @@ -0,0 +1,36 @@ +from cognee.exceptions import CogneeApiError +from fastapi import status + + +class GroupNotFoundError(CogneeApiError): + """User group not found""" + + def __init__( + self, + message: str = "User group not found.", + name: str = "GroupNotFoundError", + status_code=status.HTTP_404_NOT_FOUND, + ): + super().__init__(message, name, status_code) + + +class UserNotFoundError(CogneeApiError): + """User not found""" + + def __init__( + self, + message: str = "No user found in the system. Please create a user.", + name: str = "UserNotFoundError", + status_code=status.HTTP_404_NOT_FOUND, + ): + super().__init__(message, name, status_code) + + +class PermissionDeniedError(CogneeApiError): + def __init__( + self, + message: str = "User does not have permission on documents.", + name: str = "PermissionDeniedError", + status_code=status.HTTP_403_FORBIDDEN, + ): + super().__init__(message, name, status_code) diff --git a/cognee/modules/users/get_user_manager.py b/cognee/modules/users/get_user_manager.py index b538535ca..30410a985 100644 --- a/cognee/modules/users/get_user_manager.py +++ b/cognee/modules/users/get_user_manager.py @@ -2,13 +2,14 @@ import os import uuid from typing import Optional from fastapi import Depends, Request -from fastapi_users.exceptions import UserNotExists from fastapi_users import BaseUserManager, UUIDIDMixin, models from fastapi_users.db import SQLAlchemyUserDatabase from .get_user_db import get_user_db from .models import User from .methods import get_user +from fastapi_users.exceptions import UserNotExists + class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): reset_password_token_secret = os.getenv("FASTAPI_USERS_RESET_PASSWORD_TOKEN_SECRET", "super_secret") diff --git a/cognee/modules/users/permissions/methods/check_permission_on_documents.py b/cognee/modules/users/permissions/methods/check_permission_on_documents.py index c8c283e4a..f9c5a2258 100644 --- a/cognee/modules/users/permissions/methods/check_permission_on_documents.py +++ b/cognee/modules/users/permissions/methods/check_permission_on_documents.py @@ -2,6 +2,8 @@ import logging from uuid import UUID from sqlalchemy import select from sqlalchemy.orm import joinedload + +from cognee.modules.users.exceptions import PermissionDeniedError from cognee.infrastructure.databases.relational import get_relational_engine from ...models.User import User @@ -9,11 +11,6 @@ from ...models.ACL import ACL logger = logging.getLogger(__name__) -class PermissionDeniedException(Exception): - def __init__(self, message: str): - self.message = message - super().__init__(self.message) - async def check_permission_on_documents(user: User, permission_type: str, document_ids: list[UUID]): user_group_ids = [group.id for group in user.groups] @@ -33,4 +30,4 @@ async def check_permission_on_documents(user: User, permission_type: str, docume has_permissions = all(document_id in resource_ids for document_id in document_ids) if not has_permissions: - raise PermissionDeniedException(f"User {user.email} does not have {permission_type} permission on documents") + raise PermissionDeniedError(message=f"User {user.email} does not have {permission_type} permission on documents") diff --git a/cognee/tasks/graph/infer_data_ontology.py b/cognee/tasks/graph/infer_data_ontology.py index eea378eb1..4e11cd9af 100644 --- a/cognee/tasks/graph/infer_data_ontology.py +++ b/cognee/tasks/graph/infer_data_ontology.py @@ -4,12 +4,15 @@ import csv import json import logging from datetime import datetime, timezone +from fastapi import status from typing import Any, Dict, List, Optional, Union, Type import aiofiles import pandas as pd from pydantic import BaseModel +from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError +from cognee.modules.ingestion.exceptions import IngestionError from cognee.infrastructure.llm.prompts import read_query_prompt from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.data.chunking.config import get_chunk_config @@ -75,9 +78,10 @@ class OntologyEngine: reader = csv.DictReader(content.splitlines()) return list(reader) else: - raise ValueError("Unsupported file format") + raise IngestionError(message="Unsupported file format") except Exception as e: - raise RuntimeError(f"Failed to load data from {file_path}: {e}") + raise IngestionError(message=f"Failed to load data from {file_path}: {e}", + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) async def add_graph_ontology(self, file_path: str = None, documents: list = None): """Add graph ontology from a JSON or CSV file or infer from documents content.""" @@ -148,7 +152,7 @@ class OntologyEngine: if node_id in valid_ids: await graph_client.add_node(node_id, node_data) if node_id not in valid_ids: - raise ValueError(f"Node ID {node_id} not found in the dataset") + raise EntityNotFoundError(message=f"Node ID {node_id} not found in the dataset") if pd.notna(row.get("relationship_source")) and pd.notna(row.get("relationship_target")): await graph_client.add_edge( row["relationship_source"], diff --git a/cognee/tasks/ingestion/save_data_item_to_storage.py b/cognee/tasks/ingestion/save_data_item_to_storage.py index 4782f271f..e2a7c8ee7 100644 --- a/cognee/tasks/ingestion/save_data_item_to_storage.py +++ b/cognee/tasks/ingestion/save_data_item_to_storage.py @@ -1,4 +1,6 @@ from typing import Union, BinaryIO + +from cognee.modules.ingestion.exceptions import IngestionError from cognee.modules.ingestion import save_data_to_file def save_data_item_to_storage(data_item: Union[BinaryIO, str], dataset_name: str) -> str: @@ -15,6 +17,6 @@ def save_data_item_to_storage(data_item: Union[BinaryIO, str], dataset_name: str else: file_path = save_data_to_file(data_item, dataset_name) else: - raise ValueError(f"Data type not supported: {type(data_item)}") + raise IngestionError(message=f"Data type not supported: {type(data_item)}") return file_path \ No newline at end of file diff --git a/cognee/tasks/ingestion/save_data_item_with_metadata_to_storage.py b/cognee/tasks/ingestion/save_data_item_with_metadata_to_storage.py index ec29edb89..bf5a1f093 100644 --- a/cognee/tasks/ingestion/save_data_item_with_metadata_to_storage.py +++ b/cognee/tasks/ingestion/save_data_item_with_metadata_to_storage.py @@ -1,4 +1,6 @@ from typing import Union, BinaryIO, Any + +from cognee.modules.ingestion.exceptions import IngestionError from cognee.modules.ingestion import save_data_to_file def save_data_item_with_metadata_to_storage(data_item: Union[BinaryIO, str, Any], dataset_name: str) -> str: @@ -23,6 +25,6 @@ def save_data_item_with_metadata_to_storage(data_item: Union[BinaryIO, str, Any] else: file_path = save_data_to_file(data_item, dataset_name) else: - raise ValueError(f"Data type not supported: {type(data_item)}") + raise IngestionError(message=f"Data type not supported: {type(data_item)}") return file_path \ No newline at end of file diff --git a/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py b/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py index a3755a58f..1d9bad07c 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py @@ -1,6 +1,7 @@ import numpy as np import pytest +from cognee.exceptions import InvalidValueError from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node @@ -15,7 +16,7 @@ def test_node_initialization(): def test_node_invalid_dimension(): """Test that initializing a Node with a non-positive dimension raises an error.""" - with pytest.raises(ValueError, match="Dimension must be a positive integer"): + with pytest.raises(InvalidValueError, match="Dimension must be a positive integer"): Node("node1", dimension=0) @@ -68,7 +69,7 @@ def test_is_node_alive_in_dimension(): def test_node_alive_invalid_dimension(): """Test that checking alive status with an invalid dimension raises an error.""" node = Node("node1", dimension=1) - with pytest.raises(ValueError, match="Dimension 1 is out of range"): + with pytest.raises(InvalidValueError, match="Dimension 1 is out of range"): node.is_node_alive_in_dimension(1) @@ -105,7 +106,7 @@ def test_edge_invalid_dimension(): """Test that initializing an Edge with a non-positive dimension raises an error.""" node1 = Node("node1") node2 = Node("node2") - with pytest.raises(ValueError, match="Dimensions must be a positive integer."): + with pytest.raises(InvalidValueError, match="Dimensions must be a positive integer."): Edge(node1, node2, dimension=0) @@ -124,7 +125,7 @@ def test_edge_alive_invalid_dimension(): node1 = Node("node1") node2 = Node("node2") edge = Edge(node1, node2, dimension=1) - with pytest.raises(ValueError, match="Dimension 1 is out of range"): + with pytest.raises(InvalidValueError, match="Dimension 1 is out of range"): edge.is_edge_alive_in_dimension(1) diff --git a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py index e3b748dab..6f6165202 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py @@ -1,5 +1,6 @@ import pytest +from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node @@ -23,7 +24,7 @@ def test_add_duplicate_node(setup_graph): graph = setup_graph node = Node("node1") graph.add_node(node) - with pytest.raises(ValueError, match="Node with id node1 already exists."): + with pytest.raises(EntityAlreadyExistsError, match="Node with id node1 already exists."): graph.add_node(node) @@ -50,7 +51,7 @@ def test_add_duplicate_edge(setup_graph): graph.add_node(node2) edge = Edge(node1, node2) graph.add_edge(edge) - with pytest.raises(ValueError, match="Edge .* already exists in the graph."): + with pytest.raises(EntityAlreadyExistsError, match="Edge .* already exists in the graph."): graph.add_edge(edge) @@ -83,5 +84,5 @@ def test_get_edges_success(setup_graph): def test_get_edges_nonexistent_node(setup_graph): """Test retrieving edges for a nonexistent node raises an exception.""" graph = setup_graph - with pytest.raises(ValueError, match="Node with id nonexistent does not exist."): + with pytest.raises(EntityNotFoundError, match="Node with id nonexistent does not exist."): graph.get_edges_from_node("nonexistent")