Cog 502 backend error handling (#32)

Rework of the cognee lib exception/error handling.

Added custom exceptions and custom exception handling.

Whenever a custom exception is raised in the cognee fastapi backend it
will be processed by the exception handler making sure exception
information is logged and proper JSONResponse is sent. No need to catch
these exception in endpoints with the goal of logging and responding.

Note: The exception handler is only used for the cognee FastAPI backend
server, using cognee as a library won't utilize this exception handling
This commit is contained in:
Igor Ilic 2024-12-02 13:47:38 +01:00 committed by GitHub
commit dd8af12aa9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
42 changed files with 347 additions and 96 deletions

View file

@ -3,10 +3,13 @@ import os
import uvicorn import uvicorn
import logging import logging
import sentry_sdk import sentry_sdk
from fastapi import FastAPI from fastapi import FastAPI, status
from fastapi.responses import JSONResponse, Response from fastapi.responses import JSONResponse, Response
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from cognee.exceptions import CogneeApiError
from traceback import format_exc
# Set up logging # Set up logging
logging.basicConfig( logging.basicConfig(
level=logging.INFO, # Set the logging level (e.g., DEBUG, INFO, WARNING, ERROR, CRITICAL) level=logging.INFO, # Set the logging level (e.g., DEBUG, INFO, WARNING, ERROR, CRITICAL)
@ -76,6 +79,26 @@ async def request_validation_exception_handler(request: Request, exc: RequestVal
content = jsonable_encoder({"detail": exc.errors(), "body": exc.body}), content = jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
) )
@app.exception_handler(CogneeApiError)
async def exception_handler(_: Request, exc: CogneeApiError) -> JSONResponse:
detail = {}
if exc.name and exc.message and exc.status_code:
status_code = exc.status_code
detail["message"] = f"{exc.message} [{exc.name}]"
else:
# Log an error indicating the exception is improperly defined
logger.error("Improperly defined exception: %s", exc)
# Provide a default error response
detail["message"] = "An unexpected error occurred."
status_code = status.HTTP_418_IM_A_TEAPOT
# log the stack trace for easier serverside debugging
logger.error(format_exc())
return JSONResponse(
status_code=status_code, content={"detail": detail["message"]}
)
app.include_router( app.include_router(
get_auth_router(), get_auth_router(),
prefix = "/api/v1/auth", prefix = "/api/v1/auth",

View file

@ -22,11 +22,6 @@ logger = logging.getLogger("code_graph_pipeline")
update_status_lock = asyncio.Lock() update_status_lock = asyncio.Lock()
class PermissionDeniedException(Exception):
def __init__(self, message: str):
self.message = message
super().__init__(self.message)
async def code_graph_pipeline(datasets: Union[str, list[str]] = None, user: User = None): async def code_graph_pipeline(datasets: Union[str, list[str]] = None, user: User = None):
if user is None: if user is None:
user = await get_default_user() user = await get_default_user()

View file

@ -24,11 +24,6 @@ logger = logging.getLogger("cognify.v2")
update_status_lock = asyncio.Lock() update_status_lock = asyncio.Lock()
class PermissionDeniedException(Exception):
def __init__(self, message: str):
self.message = message
super().__init__(self.message)
async def cognify(datasets: Union[str, list[str]] = None, user: User = None): async def cognify(datasets: Union[str, list[str]] = None, user: User = None):
if user is None: if user is None:
user = await get_default_user() user = await get_default_user()

View file

@ -1,6 +1,7 @@
""" This module is used to set the configuration of the system.""" """ This module is used to set the configuration of the system."""
import os import os
from cognee.base_config import get_base_config from cognee.base_config import get_base_config
from cognee.exceptions import InvalidValueError, InvalidAttributeError
from cognee.modules.cognify.config import get_cognify_config from cognee.modules.cognify.config import get_cognify_config
from cognee.infrastructure.data.chunking.config import get_chunk_config from cognee.infrastructure.data.chunking.config import get_chunk_config
from cognee.infrastructure.databases.vector import get_vectordb_config from cognee.infrastructure.databases.vector import get_vectordb_config
@ -85,7 +86,7 @@ class config():
if hasattr(llm_config, key): if hasattr(llm_config, key):
object.__setattr__(llm_config, key, value) object.__setattr__(llm_config, key, value)
else: else:
raise AttributeError(f"'{key}' is not a valid attribute of the config.") raise InvalidAttributeError(message=f"'{key}' is not a valid attribute of the config.")
@staticmethod @staticmethod
def set_chunk_strategy(chunk_strategy: object): def set_chunk_strategy(chunk_strategy: object):
@ -123,7 +124,7 @@ class config():
if hasattr(relational_db_config, key): if hasattr(relational_db_config, key):
object.__setattr__(relational_db_config, key, value) object.__setattr__(relational_db_config, key, value)
else: else:
raise AttributeError(f"'{key}' is not a valid attribute of the config.") raise InvalidAttributeError(message=f"'{key}' is not a valid attribute of the config.")
@staticmethod @staticmethod
def set_vector_db_config(config_dict: dict): def set_vector_db_config(config_dict: dict):
@ -135,7 +136,7 @@ class config():
if hasattr(vector_db_config, key): if hasattr(vector_db_config, key):
object.__setattr__(vector_db_config, key, value) object.__setattr__(vector_db_config, key, value)
else: else:
raise AttributeError(f"'{key}' is not a valid attribute of the config.") raise InvalidAttributeError(message=f"'{key}' is not a valid attribute of the config.")
@staticmethod @staticmethod
def set_vector_db_key(db_key: str): def set_vector_db_key(db_key: str):
@ -153,7 +154,7 @@ class config():
base_config = get_base_config() base_config = get_base_config()
if "username" not in graphistry_config or "password" not in graphistry_config: if "username" not in graphistry_config or "password" not in graphistry_config:
raise ValueError("graphistry_config dictionary must contain 'username' and 'password' keys.") raise InvalidValueError(message="graphistry_config dictionary must contain 'username' and 'password' keys.")
base_config.graphistry_username = graphistry_config.get("username") base_config.graphistry_username = graphistry_config.get("username")
base_config.graphistry_password = graphistry_config.get("password") base_config.graphistry_password = graphistry_config.get("password")

View file

@ -9,6 +9,7 @@ from fastapi.responses import JSONResponse, FileResponse
from pydantic import BaseModel from pydantic import BaseModel
from cognee.api.DTO import OutDTO from cognee.api.DTO import OutDTO
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
from cognee.modules.users.models import User from cognee.modules.users.models import User
from cognee.modules.users.methods import get_authenticated_user from cognee.modules.users.methods import get_authenticated_user
from cognee.modules.pipelines.models import PipelineRunStatus from cognee.modules.pipelines.models import PipelineRunStatus
@ -55,9 +56,8 @@ def get_datasets_router() -> APIRouter:
dataset = await get_dataset(user.id, dataset_id) dataset = await get_dataset(user.id, dataset_id)
if dataset is None: if dataset is None:
raise HTTPException( raise EntityNotFoundError(
status_code=404, message=f"Dataset ({dataset_id}) not found."
detail=f"Dataset ({dataset_id}) not found."
) )
await delete_dataset(dataset) await delete_dataset(dataset)
@ -72,17 +72,15 @@ def get_datasets_router() -> APIRouter:
#TODO: Handle situation differently if user doesn't have permission to access data? #TODO: Handle situation differently if user doesn't have permission to access data?
if dataset is None: if dataset is None:
raise HTTPException( raise EntityNotFoundError(
status_code=404, message=f"Dataset ({dataset_id}) not found."
detail=f"Dataset ({dataset_id}) not found."
) )
data = await get_data(data_id) data = await get_data(data_id)
if data is None: if data is None:
raise HTTPException( raise EntityNotFoundError(
status_code=404, message=f"Data ({data_id}) not found."
detail=f"Dataset ({data_id}) not found."
) )
await delete_data(data) await delete_data(data)
@ -158,18 +156,13 @@ def get_datasets_router() -> APIRouter:
dataset_data = await get_dataset_data(dataset.id) dataset_data = await get_dataset_data(dataset.id)
if dataset_data is None: if dataset_data is None:
raise HTTPException(status_code=404, detail=f"No data found in dataset ({dataset_id}).") raise EntityNotFoundError(message=f"No data found in dataset ({dataset_id}).")
matching_data = [data for data in dataset_data if str(data.id) == data_id] matching_data = [data for data in dataset_data if str(data.id) == data_id]
# Check if matching_data contains an element # Check if matching_data contains an element
if len(matching_data) == 0: if len(matching_data) == 0:
return JSONResponse( raise EntityNotFoundError(message= f"Data ({data_id}) not found in dataset ({dataset_id}).")
status_code=404,
content={
"detail": f"Data ({data_id}) not found in dataset ({dataset_id})."
}
)
data = matching_data[0] data = matching_data[0]

View file

@ -1,6 +1,8 @@
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from cognee.modules.users.exceptions import UserNotFoundError, GroupNotFoundError
from cognee.modules.users import get_user_db from cognee.modules.users import get_user_db
from cognee.modules.users.models import User, Group, Permission from cognee.modules.users.models import User, Group, Permission
@ -12,7 +14,7 @@ def get_permissions_router() -> APIRouter:
group = db.query(Group).filter(Group.id == group_id).first() group = db.query(Group).filter(Group.id == group_id).first()
if not group: if not group:
raise HTTPException(status_code = 404, detail = "Group not found") raise GroupNotFoundError
permission = db.query(Permission).filter(Permission.name == permission).first() permission = db.query(Permission).filter(Permission.name == permission).first()
@ -31,8 +33,10 @@ def get_permissions_router() -> APIRouter:
user = db.query(User).filter(User.id == user_id).first() user = db.query(User).filter(User.id == user_id).first()
group = db.query(Group).filter(Group.id == group_id).first() group = db.query(Group).filter(Group.id == group_id).first()
if not user or not group: if not user:
raise HTTPException(status_code = 404, detail = "User or group not found") raise UserNotFoundError
elif not group:
raise GroupNotFoundError
user.groups.append(group) user.groups.append(group)

View file

@ -9,6 +9,8 @@ from cognee.modules.search.graph.search_adjacent import search_adjacent
from cognee.modules.search.vector.search_traverse import search_traverse from cognee.modules.search.vector.search_traverse import search_traverse
from cognee.modules.search.graph.search_summary import search_summary from cognee.modules.search.graph.search_summary import search_summary
from cognee.modules.search.graph.search_similarity import search_similarity from cognee.modules.search.graph.search_similarity import search_similarity
from cognee.exceptions import UserNotFoundError
from cognee.shared.utils import send_telemetry from cognee.shared.utils import send_telemetry
from cognee.modules.users.permissions.methods import get_document_ids_for_user from cognee.modules.users.permissions.methods import get_document_ids_for_user
from cognee.modules.users.methods import get_default_user from cognee.modules.users.methods import get_default_user
@ -47,7 +49,7 @@ async def search(search_type: str, params: Dict[str, Any], user: User = None) ->
user = await get_default_user() user = await get_default_user()
if user is None: if user is None:
raise PermissionError("No user found in the system. Please create a user.") raise UserNotFoundError
own_document_ids = await get_document_ids_for_user(user.id) own_document_ids = await get_document_ids_for_user(user.id)
search_params = SearchParameters(search_type = search_type, params = params) search_params = SearchParameters(search_type = search_type, params = params)

View file

@ -2,9 +2,12 @@ import json
from uuid import UUID from uuid import UUID
from enum import Enum from enum import Enum
from typing import Callable, Dict from typing import Callable, Dict
from cognee.exceptions import InvalidValueError
from cognee.modules.search.operations import log_query, log_result from cognee.modules.search.operations import log_query, log_result
from cognee.modules.storage.utils import JSONEncoder from cognee.modules.storage.utils import JSONEncoder
from cognee.shared.utils import send_telemetry from cognee.shared.utils import send_telemetry
from cognee.modules.users.exceptions import UserNotFoundError
from cognee.modules.users.models import User from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user from cognee.modules.users.methods import get_default_user
from cognee.modules.users.permissions.methods import get_document_ids_for_user from cognee.modules.users.permissions.methods import get_document_ids_for_user
@ -22,7 +25,7 @@ async def search(query_type: SearchType, query_text: str, user: User = None) ->
user = await get_default_user() user = await get_default_user()
if user is None: if user is None:
raise PermissionError("No user found in the system. Please create a user.") raise UserNotFoundError
query = await log_query(query_text, str(query_type), user.id) query = await log_query(query_text, str(query_type), user.id)
@ -52,7 +55,7 @@ async def specific_search(query_type: SearchType, query: str, user) -> list:
search_task = search_tasks.get(query_type) search_task = search_tasks.get(query_type)
if search_task is None: if search_task is None:
raise ValueError(f"Unsupported search type: {query_type}") raise InvalidValueError(message=f"Unsupported search type: {query_type}")
send_telemetry("cognee.search EXECUTION STARTED", user.id) send_telemetry("cognee.search EXECUTION STARTED", user.id)

View file

@ -0,0 +1,13 @@
"""
Custom exceptions for the Cognee API.
This module defines a set of exceptions for handling various application errors,
such as service failures, resource conflicts, and invalid operations.
"""
from .exceptions import (
CogneeApiError,
ServiceError,
InvalidValueError,
InvalidAttributeError,
)

View file

@ -0,0 +1,54 @@
from fastapi import status
import logging
logger = logging.getLogger(__name__)
class CogneeApiError(Exception):
"""Base exception class"""
def __init__(
self,
message: str = "Service is unavailable.",
name: str = "Cognee",
status_code=status.HTTP_418_IM_A_TEAPOT,
):
self.message = message
self.name = name
self.status_code = status_code
# Automatically log the exception details
logger.error(f"{self.name}: {self.message} (Status code: {self.status_code})")
super().__init__(self.message, self.name)
class ServiceError(CogneeApiError):
"""Failures in external services or APIs, like a database or a third-party service"""
def __init__(
self,
message: str = "Service is unavailable.",
name: str = "ServiceError",
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
):
super().__init__(message, name, status_code)
class InvalidValueError(CogneeApiError):
def __init__(
self,
message: str = "Invalid Value.",
name: str = "InvalidValueError",
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
):
super().__init__(message, name, status_code)
class InvalidAttributeError(CogneeApiError):
def __init__(
self,
message: str = "Invalid attribute.",
name: str = "InvalidAttributeError",
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
):
super().__init__(message, name, status_code)

View file

@ -1,9 +1,11 @@
from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.feature_extraction.text import TfidfVectorizer
from cognee.exceptions import InvalidValueError
from cognee.shared.utils import extract_pos_tags from cognee.shared.utils import extract_pos_tags
def extract_keywords(text: str) -> list[str]: def extract_keywords(text: str) -> list[str]:
if len(text) == 0: if len(text) == 0:
raise ValueError("extract_keywords cannot extract keywords from empty text.") raise InvalidValueError(message="extract_keywords cannot extract keywords from empty text.")
tags = extract_pos_tags(text) tags = extract_pos_tags(text)
nouns = [word for (word, tag) in tags if tag == "NN"] nouns = [word for (word, tag) in tags if tag == "NN"]

View file

@ -0,0 +1,10 @@
"""
Custom exceptions for the Cognee API.
This module defines a set of exceptions for handling various database errors
"""
from .exceptions import (
EntityNotFoundError,
EntityAlreadyExistsError,
)

View file

@ -0,0 +1,25 @@
from cognee.exceptions import CogneeApiError
from fastapi import status
class EntityNotFoundError(CogneeApiError):
"""Database returns nothing"""
def __init__(
self,
message: str = "The requested entity does not exist.",
name: str = "EntityNotFoundError",
status_code=status.HTTP_404_NOT_FOUND,
):
super().__init__(message, name, status_code)
class EntityAlreadyExistsError(CogneeApiError):
"""Conflict detected, like trying to create a resource that already exists"""
def __init__(
self,
message: str = "The entity already exists.",
name: str = "EntityAlreadyExistsError",
status_code=status.HTTP_409_CONFLICT,
):
super().__init__(message, name, status_code)

View file

@ -4,6 +4,7 @@ from typing import Any
from uuid import UUID from uuid import UUID
from falkordb import FalkorDB from falkordb import FalkorDB
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
from cognee.infrastructure.databases.vector.embeddings import EmbeddingEngine from cognee.infrastructure.databases.vector.embeddings import EmbeddingEngine
@ -200,7 +201,7 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
with_vector: bool = False, with_vector: bool = False,
): ):
if query_text is None and query_vector is None: if query_text is None and query_vector is None:
raise ValueError("One of query_text or query_vector must be provided!") raise InvalidValueError(message="One of query_text or query_vector must be provided!")
if query_text and not query_vector: if query_text and not query_vector:
query_vector = (await self.embed_data([query_text]))[0] query_vector = (await self.embed_data([query_text]))[0]

View file

@ -6,6 +6,8 @@ from contextlib import asynccontextmanager
from sqlalchemy import text, select, MetaData, Table from sqlalchemy import text, select, MetaData, Table
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
from ..ModelBase import Base from ..ModelBase import Base
class SQLAlchemyAdapter(): class SQLAlchemyAdapter():
@ -116,7 +118,7 @@ class SQLAlchemyAdapter():
if table_name in Base.metadata.tables: if table_name in Base.metadata.tables:
return Base.metadata.tables[table_name] return Base.metadata.tables[table_name]
else: else:
raise ValueError(f"Table '{table_name}' not found.") raise EntityNotFoundError(message=f"Table '{table_name}' not found.")
else: else:
# Create a MetaData instance to load table information # Create a MetaData instance to load table information
metadata = MetaData() metadata = MetaData()
@ -127,7 +129,7 @@ class SQLAlchemyAdapter():
# Check if table is in list of tables for the given schema # Check if table is in list of tables for the given schema
if full_table_name in metadata.tables: if full_table_name in metadata.tables:
return metadata.tables[full_table_name] return metadata.tables[full_table_name]
raise ValueError(f"Table '{full_table_name}' not found.") raise EntityNotFoundError(message=f"Table '{full_table_name}' not found.")
async def get_table_names(self) -> List[str]: async def get_table_names(self) -> List[str]:
""" """

View file

@ -5,6 +5,8 @@ from uuid import UUID
import lancedb import lancedb
from pydantic import BaseModel from pydantic import BaseModel
from lancedb.pydantic import Vector, LanceModel from lancedb.pydantic import Vector, LanceModel
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.files.storage import LocalStorage from cognee.infrastructure.files.storage import LocalStorage
from cognee.modules.storage.utils import copy_model, get_own_properties from cognee.modules.storage.utils import copy_model, get_own_properties
@ -123,7 +125,7 @@ class LanceDBAdapter(VectorDBInterface):
new_size = await collection.count_rows() new_size = await collection.count_rows()
if new_size <= original_size: if new_size <= original_size:
raise ValueError( raise InvalidValueError(message=
"LanceDB create_datapoints error: data points did not get added.") "LanceDB create_datapoints error: data points did not get added.")
@ -149,7 +151,7 @@ class LanceDBAdapter(VectorDBInterface):
query_vector: List[float] = None query_vector: List[float] = None
): ):
if query_text is None and query_vector is None: if query_text is None and query_vector is None:
raise ValueError("One of query_text or query_vector must be provided!") raise InvalidValueError(message="One of query_text or query_vector must be provided!")
if query_text and not query_vector: if query_text and not query_vector:
query_vector = (await self.embedding_engine.embed_text([query_text]))[0] query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
@ -179,7 +181,7 @@ class LanceDBAdapter(VectorDBInterface):
normalized: bool = True normalized: bool = True
): ):
if query_text is None and query_vector is None: if query_text is None and query_vector is None:
raise ValueError("One of query_text or query_vector must be provided!") raise InvalidValueError(message="One of query_text or query_vector must be provided!")
if query_text and not query_vector: if query_text and not query_vector:
query_vector = (await self.embedding_engine.embed_text([query_text]))[0] query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

View file

@ -6,6 +6,8 @@ from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import JSON, Column, Table, select, delete from sqlalchemy import JSON, Column, Table, select, delete
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from .serialize_data import serialize_data from .serialize_data import serialize_data
@ -156,7 +158,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
if collection_name in Base.metadata.tables: if collection_name in Base.metadata.tables:
return Base.metadata.tables[collection_name] return Base.metadata.tables[collection_name]
else: else:
raise ValueError(f"Table '{collection_name}' not found.") raise EntityNotFoundError(message=f"Table '{collection_name}' not found.")
async def retrieve(self, collection_name: str, data_point_ids: List[str]): async def retrieve(self, collection_name: str, data_point_ids: List[str]):
# Get PGVectorDataPoint Table from database # Get PGVectorDataPoint Table from database
@ -230,7 +232,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
with_vector: bool = False, with_vector: bool = False,
) -> List[ScoredResult]: ) -> List[ScoredResult]:
if query_text is None and query_vector is None: if query_text is None and query_vector is None:
raise ValueError("One of query_text or query_vector must be provided!") raise InvalidValueError(message="One of query_text or query_vector must be provided!")
if query_text and not query_vector: if query_text and not query_vector:
query_vector = (await self.embedding_engine.embed_text([query_text]))[0] query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

View file

@ -3,6 +3,7 @@ from uuid import UUID
from typing import List, Dict, Optional from typing import List, Dict, Optional
from qdrant_client import AsyncQdrantClient, models from qdrant_client import AsyncQdrantClient, models
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from ..vector_db_interface import VectorDBInterface from ..vector_db_interface import VectorDBInterface
@ -186,7 +187,7 @@ class QDrantAdapter(VectorDBInterface):
with_vector: bool = False with_vector: bool = False
): ):
if query_text is None and query_vector is None: if query_text is None and query_vector is None:
raise ValueError("One of query_text or query_vector must be provided!") raise InvalidValueError(message="One of query_text or query_vector must be provided!")
client = self.get_qdrant_client() client = self.get_qdrant_client()

View file

@ -3,6 +3,7 @@ import logging
from typing import List, Optional from typing import List, Optional
from uuid import UUID from uuid import UUID
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from ..vector_db_interface import VectorDBInterface from ..vector_db_interface import VectorDBInterface
from ..models.ScoredResult import ScoredResult from ..models.ScoredResult import ScoredResult
@ -194,7 +195,7 @@ class WeaviateAdapter(VectorDBInterface):
import weaviate.classes as wvc import weaviate.classes as wvc
if query_text is None and query_vector is None: if query_text is None and query_vector is None:
raise ValueError("One of query_text or query_vector must be provided!") raise InvalidValueError(message="One of query_text or query_vector must be provided!")
if query_vector is None: if query_vector is None:
query_vector = (await self.embed_data([query_text]))[0] query_vector = (await self.embed_data([query_text]))[0]

View file

@ -3,6 +3,8 @@ from pydantic import BaseModel
import instructor import instructor
from tenacity import retry, stop_after_attempt from tenacity import retry, stop_after_attempt
import anthropic import anthropic
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.llm.llm_interface import LLMInterface from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt from cognee.infrastructure.llm.prompts import read_query_prompt
@ -45,7 +47,7 @@ class AnthropicAdapter(LLMInterface):
if not text_input: if not text_input:
text_input = "No user input provided." text_input = "No user input provided."
if not system_prompt: if not system_prompt:
raise ValueError("No system prompt path provided.") raise InvalidValueError(message="No system prompt path provided.")
system_prompt = read_query_prompt(system_prompt) system_prompt = read_query_prompt(system_prompt)

View file

@ -5,6 +5,8 @@ from pydantic import BaseModel
import instructor import instructor
from tenacity import retry, stop_after_attempt from tenacity import retry, stop_after_attempt
import openai import openai
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.llm.llm_interface import LLMInterface from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.shared.data_models import MonitoringTool from cognee.shared.data_models import MonitoringTool
@ -128,7 +130,7 @@ class GenericAPIAdapter(LLMInterface):
if not text_input: if not text_input:
text_input = "No user input provided." text_input = "No user input provided."
if not system_prompt: if not system_prompt:
raise ValueError("No system prompt path provided.") raise InvalidValueError(message="No system prompt path provided.")
system_prompt = read_query_prompt(system_prompt) system_prompt = read_query_prompt(system_prompt)
formatted_prompt = f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n""" if system_prompt else None formatted_prompt = f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n""" if system_prompt else None

View file

@ -1,5 +1,7 @@
"""Get the LLM client.""" """Get the LLM client."""
from enum import Enum from enum import Enum
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.llm import get_llm_config from cognee.infrastructure.llm import get_llm_config
# Define an Enum for LLM Providers # Define an Enum for LLM Providers
@ -17,7 +19,7 @@ def get_llm_client():
if provider == LLMProvider.OPENAI: if provider == LLMProvider.OPENAI:
if llm_config.llm_api_key is None: if llm_config.llm_api_key is None:
raise ValueError("LLM API key is not set.") raise InvalidValueError(message="LLM API key is not set.")
from .openai.adapter import OpenAIAdapter from .openai.adapter import OpenAIAdapter
@ -32,7 +34,7 @@ def get_llm_client():
elif provider == LLMProvider.OLLAMA: elif provider == LLMProvider.OLLAMA:
if llm_config.llm_api_key is None: if llm_config.llm_api_key is None:
raise ValueError("LLM API key is not set.") raise InvalidValueError(message="LLM API key is not set.")
from .generic_llm_api.adapter import GenericAPIAdapter from .generic_llm_api.adapter import GenericAPIAdapter
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama") return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama")
@ -43,10 +45,10 @@ def get_llm_client():
elif provider == LLMProvider.CUSTOM: elif provider == LLMProvider.CUSTOM:
if llm_config.llm_api_key is None: if llm_config.llm_api_key is None:
raise ValueError("LLM API key is not set.") raise InvalidValueError(message="LLM API key is not set.")
from .generic_llm_api.adapter import GenericAPIAdapter from .generic_llm_api.adapter import GenericAPIAdapter
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom") return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom")
else: else:
raise ValueError(f"Unsupported LLM provider: {provider}") raise InvalidValueError(message=f"Unsupported LLM provider: {provider}")

View file

@ -7,6 +7,7 @@ import litellm
import instructor import instructor
from pydantic import BaseModel from pydantic import BaseModel
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.llm.llm_interface import LLMInterface from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt from cognee.infrastructure.llm.prompts import read_query_prompt
@ -127,7 +128,7 @@ class OpenAIAdapter(LLMInterface):
if not text_input: if not text_input:
text_input = "No user input provided." text_input = "No user input provided."
if not system_prompt: if not system_prompt:
raise ValueError("No system prompt path provided.") raise InvalidValueError(message="No system prompt path provided.")
system_prompt = read_query_prompt(system_prompt) system_prompt = read_query_prompt(system_prompt)
formatted_prompt = f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n""" if system_prompt else None formatted_prompt = f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n""" if system_prompt else None

View file

@ -1,3 +1,4 @@
from cognee.exceptions import InvalidAttributeError
from cognee.modules.data.models import Data from cognee.modules.data.models import Data
from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.relational import get_relational_engine
@ -12,7 +13,7 @@ async def delete_data(data: Data):
ValueError: If the data object is invalid. ValueError: If the data object is invalid.
""" """
if not hasattr(data, '__tablename__'): if not hasattr(data, '__tablename__'):
raise ValueError("The provided data object is missing the required '__tablename__' attribute.") raise InvalidAttributeError(message="The provided data object is missing the required '__tablename__' attribute.")
db_engine = get_relational_engine() db_engine = get_relational_engine()

View file

@ -1,5 +1,7 @@
import logging import logging
from cognee.exceptions import InvalidValueError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
async def translate_text(text, source_language: str = "sr", target_language: str = "en", region_name = "eu-west-1"): async def translate_text(text, source_language: str = "sr", target_language: str = "en", region_name = "eu-west-1"):
@ -18,10 +20,10 @@ async def translate_text(text, source_language: str = "sr", target_language: str
from botocore.exceptions import BotoCoreError, ClientError from botocore.exceptions import BotoCoreError, ClientError
if not text: if not text:
raise ValueError("No text to translate.") raise InvalidValueError(message="No text to translate.")
if not source_language or not target_language: if not source_language or not target_language:
raise ValueError("Source and target language codes are required.") raise InvalidValueError(message="Source and target language codes are required.")
try: try:
translate = boto3.client(service_name = "translate", region_name = region_name, use_ssl = True) translate = boto3.client(service_name = "translate", region_name = region_name, use_ssl = True)

View file

@ -1,6 +1,9 @@
import numpy as np import numpy as np
from typing import List, Dict, Union from typing import List, Dict, Union
from cognee.exceptions import InvalidValueError
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
from cognee.modules.graph.cognee_graph.CogneeAbstractGraph import CogneeAbstractGraph from cognee.modules.graph.cognee_graph.CogneeAbstractGraph import CogneeAbstractGraph
@ -29,7 +32,7 @@ class CogneeGraph(CogneeAbstractGraph):
if node.id not in self.nodes: if node.id not in self.nodes:
self.nodes[node.id] = node self.nodes[node.id] = node
else: else:
raise ValueError(f"Node with id {node.id} already exists.") raise EntityAlreadyExistsError(message=f"Node with id {node.id} already exists.")
def add_edge(self, edge: Edge) -> None: def add_edge(self, edge: Edge) -> None:
if edge not in self.edges: if edge not in self.edges:
@ -37,7 +40,7 @@ class CogneeGraph(CogneeAbstractGraph):
edge.node1.add_skeleton_edge(edge) edge.node1.add_skeleton_edge(edge)
edge.node2.add_skeleton_edge(edge) edge.node2.add_skeleton_edge(edge)
else: else:
raise ValueError(f"Edge {edge} already exists in the graph.") raise EntityAlreadyExistsError(message=f"Edge {edge} already exists in the graph.")
def get_node(self, node_id: str) -> Node: def get_node(self, node_id: str) -> Node:
return self.nodes.get(node_id, None) return self.nodes.get(node_id, None)
@ -47,7 +50,7 @@ class CogneeGraph(CogneeAbstractGraph):
if node: if node:
return node.skeleton_edges return node.skeleton_edges
else: else:
raise ValueError(f"Node with id {node_id} does not exist.") raise EntityNotFoundError(message=f"Node with id {node_id} does not exist.")
def get_edges(self)-> List[Edge]: def get_edges(self)-> List[Edge]:
return self.edges return self.edges
@ -62,7 +65,7 @@ class CogneeGraph(CogneeAbstractGraph):
memory_fragment_filter = []) -> None: memory_fragment_filter = []) -> None:
if node_dimension < 1 or edge_dimension < 1: if node_dimension < 1 or edge_dimension < 1:
raise ValueError("Dimensions must be positive integers") raise InvalidValueError(message="Dimensions must be positive integers")
try: try:
if len(memory_fragment_filter) == 0: if len(memory_fragment_filter) == 0:
@ -71,9 +74,9 @@ class CogneeGraph(CogneeAbstractGraph):
nodes_data, edges_data = await adapter.get_filtered_graph_data(attribute_filters = memory_fragment_filter) nodes_data, edges_data = await adapter.get_filtered_graph_data(attribute_filters = memory_fragment_filter)
if not nodes_data: if not nodes_data:
raise ValueError("No node data retrieved from the database.") raise EntityNotFoundError(message="No node data retrieved from the database.")
if not edges_data: if not edges_data:
raise ValueError("No edge data retrieved from the database.") raise EntityNotFoundError(message="No edge data retrieved from the database.")
for node_id, properties in nodes_data: for node_id, properties in nodes_data:
node_attributes = {key: properties.get(key) for key in node_properties_to_project} node_attributes = {key: properties.get(key) for key in node_properties_to_project}
@ -93,7 +96,7 @@ class CogneeGraph(CogneeAbstractGraph):
target_node.add_skeleton_edge(edge) target_node.add_skeleton_edge(edge)
else: else:
raise ValueError(f"Edge references nonexistent nodes: {source_id} -> {target_id}") raise EntityNotFoundError(message=f"Edge references nonexistent nodes: {source_id} -> {target_id}")
except (ValueError, TypeError) as e: except (ValueError, TypeError) as e:
print(f"Error projecting graph: {e}") print(f"Error projecting graph: {e}")

View file

@ -1,6 +1,9 @@
import numpy as np import numpy as np
from typing import List, Dict, Optional, Any, Union from typing import List, Dict, Optional, Any, Union
from cognee.exceptions import InvalidValueError
class Node: class Node:
""" """
Represents a node in a graph. Represents a node in a graph.
@ -18,7 +21,7 @@ class Node:
def __init__(self, node_id: str, attributes: Optional[Dict[str, Any]] = None, dimension: int = 1): def __init__(self, node_id: str, attributes: Optional[Dict[str, Any]] = None, dimension: int = 1):
if dimension <= 0: if dimension <= 0:
raise ValueError("Dimension must be a positive integer") raise InvalidValueError(message="Dimension must be a positive integer")
self.id = node_id self.id = node_id
self.attributes = attributes if attributes is not None else {} self.attributes = attributes if attributes is not None else {}
self.attributes["vector_distance"] = float('inf') self.attributes["vector_distance"] = float('inf')
@ -53,7 +56,7 @@ class Node:
def is_node_alive_in_dimension(self, dimension: int) -> bool: def is_node_alive_in_dimension(self, dimension: int) -> bool:
if dimension < 0 or dimension >= len(self.status): if dimension < 0 or dimension >= len(self.status):
raise ValueError(f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.") raise InvalidValueError(message=f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.")
return self.status[dimension] == 1 return self.status[dimension] == 1
def add_attribute(self, key: str, value: Any) -> None: def add_attribute(self, key: str, value: Any) -> None:
@ -90,7 +93,7 @@ class Edge:
def __init__(self, node1: "Node", node2: "Node", attributes: Optional[Dict[str, Any]] = None, directed: bool = True, dimension: int = 1): def __init__(self, node1: "Node", node2: "Node", attributes: Optional[Dict[str, Any]] = None, directed: bool = True, dimension: int = 1):
if dimension <= 0: if dimension <= 0:
raise ValueError("Dimensions must be a positive integer.") raise InvalidValueError(message="Dimensions must be a positive integer.")
self.node1 = node1 self.node1 = node1
self.node2 = node2 self.node2 = node2
self.attributes = attributes if attributes is not None else {} self.attributes = attributes if attributes is not None else {}
@ -100,7 +103,7 @@ class Edge:
def is_edge_alive_in_dimension(self, dimension: int) -> bool: def is_edge_alive_in_dimension(self, dimension: int) -> bool:
if dimension < 0 or dimension >= len(self.status): if dimension < 0 or dimension >= len(self.status):
raise ValueError(f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.") raise InvalidValueError(message=f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.")
return self.status[dimension] == 1 return self.status[dimension] == 1
def add_attribute(self, key: str, value: Any) -> None: def add_attribute(self, key: str, value: Any) -> None:

View file

@ -0,0 +1,10 @@
"""
Custom exceptions for the Cognee API.
This module defines a set of exceptions for handling various graph errors
"""
from .exceptions import (
EntityNotFoundError,
EntityAlreadyExistsError,
)

View file

@ -0,0 +1,25 @@
from cognee.exceptions import CogneeApiError
from fastapi import status
class EntityNotFoundError(CogneeApiError):
"""Database returns nothing"""
def __init__(
self,
message: str = "The requested entity does not exist.",
name: str = "EntityNotFoundError",
status_code=status.HTTP_404_NOT_FOUND,
):
super().__init__(message, name, status_code)
class EntityAlreadyExistsError(CogneeApiError):
"""Conflict detected, like trying to create a resource that already exists"""
def __init__(
self,
message: str = "The entity already exists.",
name: str = "EntityAlreadyExistsError",
status_code=status.HTTP_409_CONFLICT,
):
super().__init__(message, name, status_code)

View file

@ -1,9 +1,11 @@
from io import BufferedReader from io import BufferedReader
from typing import Union, BinaryIO from typing import Union, BinaryIO
from .exceptions import IngestionException
from .data_types import TextData, BinaryData from .data_types import TextData, BinaryData
from tempfile import SpooledTemporaryFile from tempfile import SpooledTemporaryFile
from cognee.modules.ingestion.exceptions import IngestionError
def classify(data: Union[str, BinaryIO], filename: str = None): def classify(data: Union[str, BinaryIO], filename: str = None):
if isinstance(data, str): if isinstance(data, str):
return TextData(data) return TextData(data)
@ -11,4 +13,4 @@ def classify(data: Union[str, BinaryIO], filename: str = None):
if isinstance(data, BufferedReader) or isinstance(data, SpooledTemporaryFile): if isinstance(data, BufferedReader) or isinstance(data, SpooledTemporaryFile):
return BinaryData(data, data.name.split("/")[-1] if data.name else filename) return BinaryData(data, data.name.split("/")[-1] if data.name else filename)
raise IngestionException(f"Type of data sent to classify(data: Union[str, BinaryIO) not supported: {type(data)}") raise IngestionError(message=f"Type of data sent to classify(data: Union[str, BinaryIO) not supported: {type(data)}")

View file

@ -1,6 +0,0 @@
class IngestionException(Exception):
message: str
def __init__(self, message: str):
self.message = message

View file

@ -0,0 +1,9 @@
"""
Custom exceptions for the Cognee API.
This module defines a set of exceptions for handling various ingestion errors
"""
from .exceptions import (
IngestionError,
)

View file

@ -0,0 +1,11 @@
from cognee.exceptions import CogneeApiError
from fastapi import status
class IngestionError(CogneeApiError):
def __init__(
self,
message: str = "Type of data sent to classify not supported.",
name: str = "IngestionError",
status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
):
super().__init__(message, name, status_code)

View file

@ -0,0 +1,11 @@
"""
Custom exceptions for the Cognee API.
This module defines a set of exceptions for handling various user errors
"""
from .exceptions import (
GroupNotFoundError,
UserNotFoundError,
PermissionDeniedError,
)

View file

@ -0,0 +1,36 @@
from cognee.exceptions import CogneeApiError
from fastapi import status
class GroupNotFoundError(CogneeApiError):
"""User group not found"""
def __init__(
self,
message: str = "User group not found.",
name: str = "GroupNotFoundError",
status_code=status.HTTP_404_NOT_FOUND,
):
super().__init__(message, name, status_code)
class UserNotFoundError(CogneeApiError):
"""User not found"""
def __init__(
self,
message: str = "No user found in the system. Please create a user.",
name: str = "UserNotFoundError",
status_code=status.HTTP_404_NOT_FOUND,
):
super().__init__(message, name, status_code)
class PermissionDeniedError(CogneeApiError):
def __init__(
self,
message: str = "User does not have permission on documents.",
name: str = "PermissionDeniedError",
status_code=status.HTTP_403_FORBIDDEN,
):
super().__init__(message, name, status_code)

View file

@ -2,13 +2,14 @@ import os
import uuid import uuid
from typing import Optional from typing import Optional
from fastapi import Depends, Request from fastapi import Depends, Request
from fastapi_users.exceptions import UserNotExists
from fastapi_users import BaseUserManager, UUIDIDMixin, models from fastapi_users import BaseUserManager, UUIDIDMixin, models
from fastapi_users.db import SQLAlchemyUserDatabase from fastapi_users.db import SQLAlchemyUserDatabase
from .get_user_db import get_user_db from .get_user_db import get_user_db
from .models import User from .models import User
from .methods import get_user from .methods import get_user
from fastapi_users.exceptions import UserNotExists
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = os.getenv("FASTAPI_USERS_RESET_PASSWORD_TOKEN_SECRET", "super_secret") reset_password_token_secret = os.getenv("FASTAPI_USERS_RESET_PASSWORD_TOKEN_SECRET", "super_secret")

View file

@ -2,6 +2,8 @@ import logging
from uuid import UUID from uuid import UUID
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from cognee.modules.users.exceptions import PermissionDeniedError
from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.relational import get_relational_engine
from ...models.User import User from ...models.User import User
@ -9,11 +11,6 @@ from ...models.ACL import ACL
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PermissionDeniedException(Exception):
def __init__(self, message: str):
self.message = message
super().__init__(self.message)
async def check_permission_on_documents(user: User, permission_type: str, document_ids: list[UUID]): async def check_permission_on_documents(user: User, permission_type: str, document_ids: list[UUID]):
user_group_ids = [group.id for group in user.groups] user_group_ids = [group.id for group in user.groups]
@ -33,4 +30,4 @@ async def check_permission_on_documents(user: User, permission_type: str, docume
has_permissions = all(document_id in resource_ids for document_id in document_ids) has_permissions = all(document_id in resource_ids for document_id in document_ids)
if not has_permissions: if not has_permissions:
raise PermissionDeniedException(f"User {user.email} does not have {permission_type} permission on documents") raise PermissionDeniedError(message=f"User {user.email} does not have {permission_type} permission on documents")

View file

@ -4,12 +4,15 @@ import csv
import json import json
import logging import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from fastapi import status
from typing import Any, Dict, List, Optional, Union, Type from typing import Any, Dict, List, Optional, Union, Type
import aiofiles import aiofiles
import pandas as pd import pandas as pd
from pydantic import BaseModel from pydantic import BaseModel
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
from cognee.modules.ingestion.exceptions import IngestionError
from cognee.infrastructure.llm.prompts import read_query_prompt from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.data.chunking.config import get_chunk_config from cognee.infrastructure.data.chunking.config import get_chunk_config
@ -75,9 +78,10 @@ class OntologyEngine:
reader = csv.DictReader(content.splitlines()) reader = csv.DictReader(content.splitlines())
return list(reader) return list(reader)
else: else:
raise ValueError("Unsupported file format") raise IngestionError(message="Unsupported file format")
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to load data from {file_path}: {e}") raise IngestionError(message=f"Failed to load data from {file_path}: {e}",
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
async def add_graph_ontology(self, file_path: str = None, documents: list = None): async def add_graph_ontology(self, file_path: str = None, documents: list = None):
"""Add graph ontology from a JSON or CSV file or infer from documents content.""" """Add graph ontology from a JSON or CSV file or infer from documents content."""
@ -148,7 +152,7 @@ class OntologyEngine:
if node_id in valid_ids: if node_id in valid_ids:
await graph_client.add_node(node_id, node_data) await graph_client.add_node(node_id, node_data)
if node_id not in valid_ids: if node_id not in valid_ids:
raise ValueError(f"Node ID {node_id} not found in the dataset") raise EntityNotFoundError(message=f"Node ID {node_id} not found in the dataset")
if pd.notna(row.get("relationship_source")) and pd.notna(row.get("relationship_target")): if pd.notna(row.get("relationship_source")) and pd.notna(row.get("relationship_target")):
await graph_client.add_edge( await graph_client.add_edge(
row["relationship_source"], row["relationship_source"],

View file

@ -1,4 +1,6 @@
from typing import Union, BinaryIO from typing import Union, BinaryIO
from cognee.modules.ingestion.exceptions import IngestionError
from cognee.modules.ingestion import save_data_to_file from cognee.modules.ingestion import save_data_to_file
def save_data_item_to_storage(data_item: Union[BinaryIO, str], dataset_name: str) -> str: def save_data_item_to_storage(data_item: Union[BinaryIO, str], dataset_name: str) -> str:
@ -15,6 +17,6 @@ def save_data_item_to_storage(data_item: Union[BinaryIO, str], dataset_name: str
else: else:
file_path = save_data_to_file(data_item, dataset_name) file_path = save_data_to_file(data_item, dataset_name)
else: else:
raise ValueError(f"Data type not supported: {type(data_item)}") raise IngestionError(message=f"Data type not supported: {type(data_item)}")
return file_path return file_path

View file

@ -1,4 +1,6 @@
from typing import Union, BinaryIO, Any from typing import Union, BinaryIO, Any
from cognee.modules.ingestion.exceptions import IngestionError
from cognee.modules.ingestion import save_data_to_file from cognee.modules.ingestion import save_data_to_file
def save_data_item_with_metadata_to_storage(data_item: Union[BinaryIO, str, Any], dataset_name: str) -> str: def save_data_item_with_metadata_to_storage(data_item: Union[BinaryIO, str, Any], dataset_name: str) -> str:
@ -23,6 +25,6 @@ def save_data_item_with_metadata_to_storage(data_item: Union[BinaryIO, str, Any]
else: else:
file_path = save_data_to_file(data_item, dataset_name) file_path = save_data_to_file(data_item, dataset_name)
else: else:
raise ValueError(f"Data type not supported: {type(data_item)}") raise IngestionError(message=f"Data type not supported: {type(data_item)}")
return file_path return file_path

View file

@ -1,6 +1,7 @@
import numpy as np import numpy as np
import pytest import pytest
from cognee.exceptions import InvalidValueError
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
@ -15,7 +16,7 @@ def test_node_initialization():
def test_node_invalid_dimension(): def test_node_invalid_dimension():
"""Test that initializing a Node with a non-positive dimension raises an error.""" """Test that initializing a Node with a non-positive dimension raises an error."""
with pytest.raises(ValueError, match="Dimension must be a positive integer"): with pytest.raises(InvalidValueError, match="Dimension must be a positive integer"):
Node("node1", dimension=0) Node("node1", dimension=0)
@ -68,7 +69,7 @@ def test_is_node_alive_in_dimension():
def test_node_alive_invalid_dimension(): def test_node_alive_invalid_dimension():
"""Test that checking alive status with an invalid dimension raises an error.""" """Test that checking alive status with an invalid dimension raises an error."""
node = Node("node1", dimension=1) node = Node("node1", dimension=1)
with pytest.raises(ValueError, match="Dimension 1 is out of range"): with pytest.raises(InvalidValueError, match="Dimension 1 is out of range"):
node.is_node_alive_in_dimension(1) node.is_node_alive_in_dimension(1)
@ -105,7 +106,7 @@ def test_edge_invalid_dimension():
"""Test that initializing an Edge with a non-positive dimension raises an error.""" """Test that initializing an Edge with a non-positive dimension raises an error."""
node1 = Node("node1") node1 = Node("node1")
node2 = Node("node2") node2 = Node("node2")
with pytest.raises(ValueError, match="Dimensions must be a positive integer."): with pytest.raises(InvalidValueError, match="Dimensions must be a positive integer."):
Edge(node1, node2, dimension=0) Edge(node1, node2, dimension=0)
@ -124,7 +125,7 @@ def test_edge_alive_invalid_dimension():
node1 = Node("node1") node1 = Node("node1")
node2 = Node("node2") node2 = Node("node2")
edge = Edge(node1, node2, dimension=1) edge = Edge(node1, node2, dimension=1)
with pytest.raises(ValueError, match="Dimension 1 is out of range"): with pytest.raises(InvalidValueError, match="Dimension 1 is out of range"):
edge.is_edge_alive_in_dimension(1) edge.is_edge_alive_in_dimension(1)

View file

@ -1,5 +1,6 @@
import pytest import pytest
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
@ -23,7 +24,7 @@ def test_add_duplicate_node(setup_graph):
graph = setup_graph graph = setup_graph
node = Node("node1") node = Node("node1")
graph.add_node(node) graph.add_node(node)
with pytest.raises(ValueError, match="Node with id node1 already exists."): with pytest.raises(EntityAlreadyExistsError, match="Node with id node1 already exists."):
graph.add_node(node) graph.add_node(node)
@ -50,7 +51,7 @@ def test_add_duplicate_edge(setup_graph):
graph.add_node(node2) graph.add_node(node2)
edge = Edge(node1, node2) edge = Edge(node1, node2)
graph.add_edge(edge) graph.add_edge(edge)
with pytest.raises(ValueError, match="Edge .* already exists in the graph."): with pytest.raises(EntityAlreadyExistsError, match="Edge .* already exists in the graph."):
graph.add_edge(edge) graph.add_edge(edge)
@ -83,5 +84,5 @@ def test_get_edges_success(setup_graph):
def test_get_edges_nonexistent_node(setup_graph): def test_get_edges_nonexistent_node(setup_graph):
"""Test retrieving edges for a nonexistent node raises an exception.""" """Test retrieving edges for a nonexistent node raises an exception."""
graph = setup_graph graph = setup_graph
with pytest.raises(ValueError, match="Node with id nonexistent does not exist."): with pytest.raises(EntityNotFoundError, match="Node with id nonexistent does not exist."):
graph.get_edges_from_node("nonexistent") graph.get_edges_from_node("nonexistent")