Merge remote-tracking branch 'origin/main' into code-graph
This commit is contained in:
commit
e07364fc25
46 changed files with 455 additions and 181 deletions
|
|
@ -3,10 +3,13 @@ import os
|
|||
import uvicorn
|
||||
import logging
|
||||
import sentry_sdk
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, status
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from cognee.exceptions import CogneeApiError
|
||||
from traceback import format_exc
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, # Set the logging level (e.g., DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
||||
|
|
@ -76,6 +79,26 @@ async def request_validation_exception_handler(request: Request, exc: RequestVal
|
|||
content = jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
|
||||
)
|
||||
|
||||
@app.exception_handler(CogneeApiError)
|
||||
async def exception_handler(_: Request, exc: CogneeApiError) -> JSONResponse:
|
||||
detail = {}
|
||||
|
||||
if exc.name and exc.message and exc.status_code:
|
||||
status_code = exc.status_code
|
||||
detail["message"] = f"{exc.message} [{exc.name}]"
|
||||
else:
|
||||
# Log an error indicating the exception is improperly defined
|
||||
logger.error("Improperly defined exception: %s", exc)
|
||||
# Provide a default error response
|
||||
detail["message"] = "An unexpected error occurred."
|
||||
status_code = status.HTTP_418_IM_A_TEAPOT
|
||||
|
||||
# log the stack trace for easier serverside debugging
|
||||
logger.error(format_exc())
|
||||
return JSONResponse(
|
||||
status_code=status_code, content={"detail": detail["message"]}
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
get_auth_router(),
|
||||
prefix = "/api/v1/auth",
|
||||
|
|
|
|||
|
|
@ -22,11 +22,6 @@ logger = logging.getLogger("code_graph_pipeline")
|
|||
|
||||
update_status_lock = asyncio.Lock()
|
||||
|
||||
class PermissionDeniedException(Exception):
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
async def code_graph_pipeline(datasets: Union[str, list[str]] = None, user: User = None):
|
||||
if user is None:
|
||||
user = await get_default_user()
|
||||
|
|
|
|||
|
|
@ -24,11 +24,6 @@ logger = logging.getLogger("cognify.v2")
|
|||
|
||||
update_status_lock = asyncio.Lock()
|
||||
|
||||
class PermissionDeniedException(Exception):
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
async def cognify(datasets: Union[str, list[str]] = None, user: User = None):
|
||||
if user is None:
|
||||
user = await get_default_user()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
""" This module is used to set the configuration of the system."""
|
||||
import os
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.exceptions import InvalidValueError, InvalidAttributeError
|
||||
from cognee.modules.cognify.config import get_cognify_config
|
||||
from cognee.infrastructure.data.chunking.config import get_chunk_config
|
||||
from cognee.infrastructure.databases.vector import get_vectordb_config
|
||||
|
|
@ -85,7 +86,7 @@ class config():
|
|||
if hasattr(llm_config, key):
|
||||
object.__setattr__(llm_config, key, value)
|
||||
else:
|
||||
raise AttributeError(f"'{key}' is not a valid attribute of the config.")
|
||||
raise InvalidAttributeError(message=f"'{key}' is not a valid attribute of the config.")
|
||||
|
||||
@staticmethod
|
||||
def set_chunk_strategy(chunk_strategy: object):
|
||||
|
|
@ -123,7 +124,7 @@ class config():
|
|||
if hasattr(relational_db_config, key):
|
||||
object.__setattr__(relational_db_config, key, value)
|
||||
else:
|
||||
raise AttributeError(f"'{key}' is not a valid attribute of the config.")
|
||||
raise InvalidAttributeError(message=f"'{key}' is not a valid attribute of the config.")
|
||||
|
||||
@staticmethod
|
||||
def set_vector_db_config(config_dict: dict):
|
||||
|
|
@ -135,7 +136,7 @@ class config():
|
|||
if hasattr(vector_db_config, key):
|
||||
object.__setattr__(vector_db_config, key, value)
|
||||
else:
|
||||
raise AttributeError(f"'{key}' is not a valid attribute of the config.")
|
||||
raise InvalidAttributeError(message=f"'{key}' is not a valid attribute of the config.")
|
||||
|
||||
@staticmethod
|
||||
def set_vector_db_key(db_key: str):
|
||||
|
|
@ -153,7 +154,7 @@ class config():
|
|||
base_config = get_base_config()
|
||||
|
||||
if "username" not in graphistry_config or "password" not in graphistry_config:
|
||||
raise ValueError("graphistry_config dictionary must contain 'username' and 'password' keys.")
|
||||
raise InvalidValueError(message="graphistry_config dictionary must contain 'username' and 'password' keys.")
|
||||
|
||||
base_config.graphistry_username = graphistry_config.get("username")
|
||||
base_config.graphistry_password = graphistry_config.get("password")
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from fastapi.responses import JSONResponse, FileResponse
|
|||
from pydantic import BaseModel
|
||||
|
||||
from cognee.api.DTO import OutDTO
|
||||
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_authenticated_user
|
||||
from cognee.modules.pipelines.models import PipelineRunStatus
|
||||
|
|
@ -55,9 +56,8 @@ def get_datasets_router() -> APIRouter:
|
|||
dataset = await get_dataset(user.id, dataset_id)
|
||||
|
||||
if dataset is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Dataset ({dataset_id}) not found."
|
||||
raise EntityNotFoundError(
|
||||
message=f"Dataset ({dataset_id}) not found."
|
||||
)
|
||||
|
||||
await delete_dataset(dataset)
|
||||
|
|
@ -72,17 +72,15 @@ def get_datasets_router() -> APIRouter:
|
|||
|
||||
#TODO: Handle situation differently if user doesn't have permission to access data?
|
||||
if dataset is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Dataset ({dataset_id}) not found."
|
||||
raise EntityNotFoundError(
|
||||
message=f"Dataset ({dataset_id}) not found."
|
||||
)
|
||||
|
||||
data = await get_data(data_id)
|
||||
|
||||
if data is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Dataset ({data_id}) not found."
|
||||
raise EntityNotFoundError(
|
||||
message=f"Data ({data_id}) not found."
|
||||
)
|
||||
|
||||
await delete_data(data)
|
||||
|
|
@ -158,18 +156,13 @@ def get_datasets_router() -> APIRouter:
|
|||
dataset_data = await get_dataset_data(dataset.id)
|
||||
|
||||
if dataset_data is None:
|
||||
raise HTTPException(status_code=404, detail=f"No data found in dataset ({dataset_id}).")
|
||||
raise EntityNotFoundError(message=f"No data found in dataset ({dataset_id}).")
|
||||
|
||||
matching_data = [data for data in dataset_data if str(data.id) == data_id]
|
||||
|
||||
# Check if matching_data contains an element
|
||||
if len(matching_data) == 0:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={
|
||||
"detail": f"Data ({data_id}) not found in dataset ({dataset_id})."
|
||||
}
|
||||
)
|
||||
raise EntityNotFoundError(message= f"Data ({data_id}) not found in dataset ({dataset_id}).")
|
||||
|
||||
data = matching_data[0]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from cognee.modules.users.exceptions import UserNotFoundError, GroupNotFoundError
|
||||
from cognee.modules.users import get_user_db
|
||||
from cognee.modules.users.models import User, Group, Permission
|
||||
|
||||
|
|
@ -12,7 +14,7 @@ def get_permissions_router() -> APIRouter:
|
|||
group = db.query(Group).filter(Group.id == group_id).first()
|
||||
|
||||
if not group:
|
||||
raise HTTPException(status_code = 404, detail = "Group not found")
|
||||
raise GroupNotFoundError
|
||||
|
||||
permission = db.query(Permission).filter(Permission.name == permission).first()
|
||||
|
||||
|
|
@ -31,8 +33,10 @@ def get_permissions_router() -> APIRouter:
|
|||
user = db.query(User).filter(User.id == user_id).first()
|
||||
group = db.query(Group).filter(Group.id == group_id).first()
|
||||
|
||||
if not user or not group:
|
||||
raise HTTPException(status_code = 404, detail = "User or group not found")
|
||||
if not user:
|
||||
raise UserNotFoundError
|
||||
elif not group:
|
||||
raise GroupNotFoundError
|
||||
|
||||
user.groups.append(group)
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ from cognee.modules.search.graph.search_adjacent import search_adjacent
|
|||
from cognee.modules.search.vector.search_traverse import search_traverse
|
||||
from cognee.modules.search.graph.search_summary import search_summary
|
||||
from cognee.modules.search.graph.search_similarity import search_similarity
|
||||
|
||||
from cognee.exceptions import UserNotFoundError
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.modules.users.permissions.methods import get_document_ids_for_user
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
|
|
@ -47,7 +49,7 @@ async def search(search_type: str, params: Dict[str, Any], user: User = None) ->
|
|||
user = await get_default_user()
|
||||
|
||||
if user is None:
|
||||
raise PermissionError("No user found in the system. Please create a user.")
|
||||
raise UserNotFoundError
|
||||
|
||||
own_document_ids = await get_document_ids_for_user(user.id)
|
||||
search_params = SearchParameters(search_type = search_type, params = params)
|
||||
|
|
|
|||
|
|
@ -2,9 +2,12 @@ import json
|
|||
from uuid import UUID
|
||||
from enum import Enum
|
||||
from typing import Callable, Dict
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.modules.search.operations import log_query, log_result
|
||||
from cognee.modules.storage.utils import JSONEncoder
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.modules.users.exceptions import UserNotFoundError
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.users.permissions.methods import get_document_ids_for_user
|
||||
|
|
@ -22,7 +25,7 @@ async def search(query_type: SearchType, query_text: str, user: User = None) ->
|
|||
user = await get_default_user()
|
||||
|
||||
if user is None:
|
||||
raise PermissionError("No user found in the system. Please create a user.")
|
||||
raise UserNotFoundError
|
||||
|
||||
query = await log_query(query_text, str(query_type), user.id)
|
||||
|
||||
|
|
@ -52,7 +55,7 @@ async def specific_search(query_type: SearchType, query: str, user) -> list:
|
|||
search_task = search_tasks.get(query_type)
|
||||
|
||||
if search_task is None:
|
||||
raise ValueError(f"Unsupported search type: {query_type}")
|
||||
raise InvalidValueError(message=f"Unsupported search type: {query_type}")
|
||||
|
||||
send_telemetry("cognee.search EXECUTION STARTED", user.id)
|
||||
|
||||
|
|
|
|||
13
cognee/exceptions/__init__.py
Normal file
13
cognee/exceptions/__init__.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
"""
|
||||
Custom exceptions for the Cognee API.
|
||||
|
||||
This module defines a set of exceptions for handling various application errors,
|
||||
such as service failures, resource conflicts, and invalid operations.
|
||||
"""
|
||||
|
||||
from .exceptions import (
|
||||
CogneeApiError,
|
||||
ServiceError,
|
||||
InvalidValueError,
|
||||
InvalidAttributeError,
|
||||
)
|
||||
54
cognee/exceptions/exceptions.py
Normal file
54
cognee/exceptions/exceptions.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
from fastapi import status
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CogneeApiError(Exception):
|
||||
"""Base exception class"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Service is unavailable.",
|
||||
name: str = "Cognee",
|
||||
status_code=status.HTTP_418_IM_A_TEAPOT,
|
||||
):
|
||||
self.message = message
|
||||
self.name = name
|
||||
self.status_code = status_code
|
||||
|
||||
# Automatically log the exception details
|
||||
logger.error(f"{self.name}: {self.message} (Status code: {self.status_code})")
|
||||
|
||||
super().__init__(self.message, self.name)
|
||||
|
||||
|
||||
class ServiceError(CogneeApiError):
|
||||
"""Failures in external services or APIs, like a database or a third-party service"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Service is unavailable.",
|
||||
name: str = "ServiceError",
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
||||
|
||||
class InvalidValueError(CogneeApiError):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Invalid Value.",
|
||||
name: str = "InvalidValueError",
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
||||
|
||||
class InvalidAttributeError(CogneeApiError):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Invalid attribute.",
|
||||
name: str = "InvalidAttributeError",
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
|
@ -1,9 +1,11 @@
|
|||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.shared.utils import extract_pos_tags
|
||||
|
||||
def extract_keywords(text: str) -> list[str]:
|
||||
if len(text) == 0:
|
||||
raise ValueError("extract_keywords cannot extract keywords from empty text.")
|
||||
raise InvalidValueError(message="extract_keywords cannot extract keywords from empty text.")
|
||||
|
||||
tags = extract_pos_tags(text)
|
||||
nouns = [word for (word, tag) in tags if tag == "NN"]
|
||||
|
|
|
|||
10
cognee/infrastructure/databases/exceptions/__init__.py
Normal file
10
cognee/infrastructure/databases/exceptions/__init__.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
"""
|
||||
Custom exceptions for the Cognee API.
|
||||
|
||||
This module defines a set of exceptions for handling various database errors
|
||||
"""
|
||||
|
||||
from .exceptions import (
|
||||
EntityNotFoundError,
|
||||
EntityAlreadyExistsError,
|
||||
)
|
||||
25
cognee/infrastructure/databases/exceptions/exceptions.py
Normal file
25
cognee/infrastructure/databases/exceptions/exceptions.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
from cognee.exceptions import CogneeApiError
|
||||
from fastapi import status
|
||||
|
||||
class EntityNotFoundError(CogneeApiError):
|
||||
"""Database returns nothing"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "The requested entity does not exist.",
|
||||
name: str = "EntityNotFoundError",
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
||||
|
||||
class EntityAlreadyExistsError(CogneeApiError):
|
||||
"""Conflict detected, like trying to create a resource that already exists"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "The entity already exists.",
|
||||
name: str = "EntityAlreadyExistsError",
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
|
@ -5,6 +5,7 @@ from uuid import UUID
|
|||
from textwrap import dedent
|
||||
from falkordb import FalkorDB
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
||||
from cognee.infrastructure.databases.vector.embeddings import EmbeddingEngine
|
||||
|
|
@ -243,7 +244,7 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
|||
with_vector: bool = False,
|
||||
):
|
||||
if query_text is None and query_vector is None:
|
||||
raise ValueError("One of query_text or query_vector must be provided!")
|
||||
raise InvalidValueError(message="One of query_text or query_vector must be provided!")
|
||||
|
||||
if query_text and not query_vector:
|
||||
query_vector = (await self.embed_data([query_text]))[0]
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
from functools import lru_cache
|
||||
|
||||
from .config import get_relational_config
|
||||
from .create_relational_engine import create_relational_engine
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_relational_engine():
|
||||
relational_config = get_relational_config()
|
||||
|
||||
return create_relational_engine(**relational_config.to_dict())
|
||||
return create_relational_engine(**relational_config.to_dict())
|
||||
|
|
@ -7,6 +7,7 @@ from sqlalchemy import text, select, MetaData, Table
|
|||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
|
||||
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
|
||||
from ..ModelBase import Base
|
||||
|
||||
class SQLAlchemyAdapter():
|
||||
|
|
@ -117,7 +118,7 @@ class SQLAlchemyAdapter():
|
|||
if table_name in Base.metadata.tables:
|
||||
return Base.metadata.tables[table_name]
|
||||
else:
|
||||
raise ValueError(f"Table '{table_name}' not found.")
|
||||
raise EntityNotFoundError(message=f"Table '{table_name}' not found.")
|
||||
else:
|
||||
# Create a MetaData instance to load table information
|
||||
metadata = MetaData()
|
||||
|
|
@ -128,7 +129,7 @@ class SQLAlchemyAdapter():
|
|||
# Check if table is in list of tables for the given schema
|
||||
if full_table_name in metadata.tables:
|
||||
return metadata.tables[full_table_name]
|
||||
raise ValueError(f"Table '{full_table_name}' not found.")
|
||||
raise EntityNotFoundError(message=f"Table '{full_table_name}' not found.")
|
||||
|
||||
async def get_table_names(self) -> List[str]:
|
||||
"""
|
||||
|
|
@ -171,6 +172,27 @@ class SQLAlchemyAdapter():
|
|||
results = await connection.execute(query)
|
||||
return {result["data_id"]: result["status"] for result in results}
|
||||
|
||||
async def get_all_data_from_table(self, table_name: str, schema: str = "public"):
|
||||
async with self.get_async_session() as session:
|
||||
# Validate inputs to prevent SQL injection
|
||||
if not table_name.isidentifier():
|
||||
raise ValueError("Invalid table name")
|
||||
if schema and not schema.isidentifier():
|
||||
raise ValueError("Invalid schema name")
|
||||
|
||||
if self.engine.dialect.name == "sqlite":
|
||||
table = await self.get_table(table_name)
|
||||
else:
|
||||
table = await self.get_table(table_name, schema)
|
||||
|
||||
# Query all data from the table
|
||||
query = select(table)
|
||||
result = await session.execute(query)
|
||||
|
||||
# Fetch all rows as a list of dictionaries
|
||||
rows = result.mappings().all()
|
||||
return rows
|
||||
|
||||
async def execute_query(self, query):
|
||||
async with self.engine.begin() as connection:
|
||||
result = await connection.execute(text(query))
|
||||
|
|
@ -205,7 +227,6 @@ class SQLAlchemyAdapter():
|
|||
from cognee.infrastructure.files.storage import LocalStorage
|
||||
|
||||
LocalStorage.remove(self.db_path)
|
||||
self.db_path = None
|
||||
else:
|
||||
async with self.engine.begin() as connection:
|
||||
schema_list = await self.get_schema_list()
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ from uuid import UUID
|
|||
import lancedb
|
||||
from pydantic import BaseModel
|
||||
from lancedb.pydantic import Vector, LanceModel
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.files.storage import LocalStorage
|
||||
from cognee.modules.storage.utils import copy_model, get_own_properties
|
||||
|
|
@ -122,7 +124,7 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
new_size = await collection.count_rows()
|
||||
|
||||
if new_size <= original_size:
|
||||
raise ValueError(
|
||||
raise InvalidValueError(message=
|
||||
"LanceDB create_datapoints error: data points did not get added.")
|
||||
|
||||
|
||||
|
|
@ -148,7 +150,7 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
query_vector: List[float] = None
|
||||
):
|
||||
if query_text is None and query_vector is None:
|
||||
raise ValueError("One of query_text or query_vector must be provided!")
|
||||
raise InvalidValueError(message="One of query_text or query_vector must be provided!")
|
||||
|
||||
if query_text and not query_vector:
|
||||
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
|
||||
|
|
@ -178,7 +180,7 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
normalized: bool = True
|
||||
):
|
||||
if query_text is None and query_vector is None:
|
||||
raise ValueError("One of query_text or query_vector must be provided!")
|
||||
raise InvalidValueError(message="One of query_text or query_vector must be provided!")
|
||||
|
||||
if query_text and not query_vector:
|
||||
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ from sqlalchemy.orm import Mapped, mapped_column
|
|||
from sqlalchemy import JSON, Column, Table, select, delete
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
from .serialize_data import serialize_data
|
||||
|
|
@ -156,7 +158,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
if collection_name in Base.metadata.tables:
|
||||
return Base.metadata.tables[collection_name]
|
||||
else:
|
||||
raise ValueError(f"Table '{collection_name}' not found.")
|
||||
raise EntityNotFoundError(message=f"Table '{collection_name}' not found.")
|
||||
|
||||
async def retrieve(self, collection_name: str, data_point_ids: List[str]):
|
||||
# Get PGVectorDataPoint Table from database
|
||||
|
|
@ -230,7 +232,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
with_vector: bool = False,
|
||||
) -> List[ScoredResult]:
|
||||
if query_text is None and query_vector is None:
|
||||
raise ValueError("One of query_text or query_vector must be provided!")
|
||||
raise InvalidValueError(message="One of query_text or query_vector must be provided!")
|
||||
|
||||
if query_text and not query_vector:
|
||||
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
from ...relational.ModelBase import Base
|
||||
from ..get_vector_engine import get_vector_engine, get_vectordb_config
|
||||
from sqlalchemy import text
|
||||
|
||||
|
||||
async def create_db_and_tables():
|
||||
vector_config = get_vectordb_config()
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
if vector_config.vector_db_provider == "pgvector":
|
||||
await vector_engine.create_database()
|
||||
async with vector_engine.engine.begin() as connection:
|
||||
await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from uuid import UUID
|
|||
from typing import List, Dict, Optional
|
||||
from qdrant_client import AsyncQdrantClient, models
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
|
|
@ -186,7 +187,7 @@ class QDrantAdapter(VectorDBInterface):
|
|||
with_vector: bool = False
|
||||
):
|
||||
if query_text is None and query_vector is None:
|
||||
raise ValueError("One of query_text or query_vector must be provided!")
|
||||
raise InvalidValueError(message="One of query_text or query_vector must be provided!")
|
||||
|
||||
client = self.get_qdrant_client()
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import logging
|
|||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
from ..models.ScoredResult import ScoredResult
|
||||
|
|
@ -194,7 +195,7 @@ class WeaviateAdapter(VectorDBInterface):
|
|||
import weaviate.classes as wvc
|
||||
|
||||
if query_text is None and query_vector is None:
|
||||
raise ValueError("One of query_text or query_vector must be provided!")
|
||||
raise InvalidValueError(message="One of query_text or query_vector must be provided!")
|
||||
|
||||
if query_vector is None:
|
||||
query_vector = (await self.embed_data([query_text]))[0]
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
from typing import Type
|
||||
from pydantic import BaseModel
|
||||
import instructor
|
||||
from tenacity import retry, stop_after_attempt
|
||||
import anthropic
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.infrastructure.llm.llm_interface import LLMInterface
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||
|
||||
|
||||
class AnthropicAdapter(LLMInterface):
|
||||
|
|
@ -37,3 +39,17 @@ class AnthropicAdapter(LLMInterface):
|
|||
}],
|
||||
response_model = response_model,
|
||||
)
|
||||
|
||||
def show_prompt(self, text_input: str, system_prompt: str) -> str:
|
||||
"""Format and display the prompt for a user query."""
|
||||
|
||||
if not text_input:
|
||||
text_input = "No user input provided."
|
||||
if not system_prompt:
|
||||
raise InvalidValueError(message="No system prompt path provided.")
|
||||
|
||||
system_prompt = read_query_prompt(system_prompt)
|
||||
|
||||
formatted_prompt = f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n""" if system_prompt else None
|
||||
|
||||
return formatted_prompt
|
||||
|
|
|
|||
|
|
@ -3,9 +3,11 @@ import asyncio
|
|||
from typing import List, Type
|
||||
from pydantic import BaseModel
|
||||
import instructor
|
||||
from tenacity import retry, stop_after_attempt
|
||||
import openai
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.infrastructure.llm.llm_interface import LLMInterface
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||
from cognee.shared.data_models import MonitoringTool
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
|
|
@ -52,60 +54,6 @@ class GenericAPIAdapter(LLMInterface):
|
|||
mode = instructor.Mode.JSON,
|
||||
)
|
||||
|
||||
@retry(stop = stop_after_attempt(5))
|
||||
def completions_with_backoff(self, **kwargs):
|
||||
"""Wrapper around ChatCompletion.create w/ backoff"""
|
||||
# Local model
|
||||
return openai.chat.completions.create(**kwargs)
|
||||
|
||||
@retry(stop = stop_after_attempt(5))
|
||||
async def acompletions_with_backoff(self, **kwargs):
|
||||
"""Wrapper around ChatCompletion.acreate w/ backoff"""
|
||||
return await openai.chat.completions.acreate(**kwargs)
|
||||
|
||||
@retry(stop = stop_after_attempt(5))
|
||||
async def acreate_embedding_with_backoff(self, input: List[str], model: str = "text-embedding-3-large"):
|
||||
"""Wrapper around Embedding.acreate w/ backoff"""
|
||||
|
||||
return await self.aclient.embeddings.create(input = input, model = model)
|
||||
|
||||
async def async_get_embedding_with_backoff(self, text, model="text-embedding-3-large"):
|
||||
"""To get text embeddings, import/call this function
|
||||
It specifies defaults + handles rate-limiting + is async"""
|
||||
text = text.replace("\n", " ")
|
||||
response = await self.aclient.embeddings.create(input = text, model = model)
|
||||
embedding = response.data[0].embedding
|
||||
return embedding
|
||||
|
||||
@retry(stop = stop_after_attempt(5))
|
||||
def create_embedding_with_backoff(self, **kwargs):
|
||||
"""Wrapper around Embedding.create w/ backoff"""
|
||||
return openai.embeddings.create(**kwargs)
|
||||
|
||||
def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-3-large"):
|
||||
"""To get text embeddings, import/call this function
|
||||
It specifies defaults + handles rate-limiting
|
||||
:param text: str
|
||||
:param model: str
|
||||
"""
|
||||
text = text.replace("\n", " ")
|
||||
response = self.create_embedding_with_backoff(input=[text], model=model)
|
||||
embedding = response.data[0].embedding
|
||||
return embedding
|
||||
|
||||
async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]):
|
||||
"""To get multiple text embeddings in parallel, import/call this function
|
||||
It specifies defaults + handles rate-limiting + is async"""
|
||||
# Collect all coroutines
|
||||
coroutines = (self.async_get_embedding_with_backoff(text, model)
|
||||
for text, model in zip(texts, models))
|
||||
|
||||
# Run the coroutines in parallel and gather the results
|
||||
embeddings = await asyncio.gather(*coroutines)
|
||||
|
||||
return embeddings
|
||||
|
||||
@retry(stop = stop_after_attempt(5))
|
||||
async def acreate_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
|
||||
"""Generate a response from a user query."""
|
||||
|
||||
|
|
@ -122,3 +70,13 @@ class GenericAPIAdapter(LLMInterface):
|
|||
response_model = response_model,
|
||||
)
|
||||
|
||||
def show_prompt(self, text_input: str, system_prompt: str) -> str:
|
||||
"""Format and display the prompt for a user query."""
|
||||
if not text_input:
|
||||
text_input = "No user input provided."
|
||||
if not system_prompt:
|
||||
raise InvalidValueError(message="No system prompt path provided.")
|
||||
system_prompt = read_query_prompt(system_prompt)
|
||||
|
||||
formatted_prompt = f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n""" if system_prompt else None
|
||||
return formatted_prompt
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
"""Get the LLM client."""
|
||||
from enum import Enum
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.infrastructure.llm import get_llm_config
|
||||
|
||||
# Define an Enum for LLM Providers
|
||||
|
|
@ -17,7 +19,7 @@ def get_llm_client():
|
|||
|
||||
if provider == LLMProvider.OPENAI:
|
||||
if llm_config.llm_api_key is None:
|
||||
raise ValueError("LLM API key is not set.")
|
||||
raise InvalidValueError(message="LLM API key is not set.")
|
||||
|
||||
from .openai.adapter import OpenAIAdapter
|
||||
|
||||
|
|
@ -32,7 +34,7 @@ def get_llm_client():
|
|||
|
||||
elif provider == LLMProvider.OLLAMA:
|
||||
if llm_config.llm_api_key is None:
|
||||
raise ValueError("LLM API key is not set.")
|
||||
raise InvalidValueError(message="LLM API key is not set.")
|
||||
|
||||
from .generic_llm_api.adapter import GenericAPIAdapter
|
||||
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama")
|
||||
|
|
@ -43,10 +45,10 @@ def get_llm_client():
|
|||
|
||||
elif provider == LLMProvider.CUSTOM:
|
||||
if llm_config.llm_api_key is None:
|
||||
raise ValueError("LLM API key is not set.")
|
||||
raise InvalidValueError(message="LLM API key is not set.")
|
||||
|
||||
from .generic_llm_api.adapter import GenericAPIAdapter
|
||||
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom")
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
raise InvalidValueError(message=f"Unsupported LLM provider: {provider}")
|
||||
|
|
|
|||
|
|
@ -7,7 +7,9 @@ import litellm
|
|||
import instructor
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.infrastructure.llm.llm_interface import LLMInterface
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||
|
||||
class OpenAIAdapter(LLMInterface):
|
||||
name = "OpenAI"
|
||||
|
|
@ -120,3 +122,14 @@ class OpenAIAdapter(LLMInterface):
|
|||
max_tokens = 300,
|
||||
max_retries = 5,
|
||||
)
|
||||
|
||||
def show_prompt(self, text_input: str, system_prompt: str) -> str:
|
||||
"""Format and display the prompt for a user query."""
|
||||
if not text_input:
|
||||
text_input = "No user input provided."
|
||||
if not system_prompt:
|
||||
raise InvalidValueError(message="No system prompt path provided.")
|
||||
system_prompt = read_query_prompt(system_prompt)
|
||||
|
||||
formatted_prompt = f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n""" if system_prompt else None
|
||||
return formatted_prompt
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from cognee.exceptions import InvalidAttributeError
|
||||
from cognee.modules.data.models import Data
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
||||
|
|
@ -12,7 +13,7 @@ async def delete_data(data: Data):
|
|||
ValueError: If the data object is invalid.
|
||||
"""
|
||||
if not hasattr(data, '__tablename__'):
|
||||
raise ValueError("The provided data object is missing the required '__tablename__' attribute.")
|
||||
raise InvalidAttributeError(message="The provided data object is missing the required '__tablename__' attribute.")
|
||||
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import logging
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def translate_text(text, source_language: str = "sr", target_language: str = "en", region_name = "eu-west-1"):
|
||||
|
|
@ -18,10 +20,10 @@ async def translate_text(text, source_language: str = "sr", target_language: str
|
|||
from botocore.exceptions import BotoCoreError, ClientError
|
||||
|
||||
if not text:
|
||||
raise ValueError("No text to translate.")
|
||||
raise InvalidValueError(message="No text to translate.")
|
||||
|
||||
if not source_language or not target_language:
|
||||
raise ValueError("Source and target language codes are required.")
|
||||
raise InvalidValueError(message="Source and target language codes are required.")
|
||||
|
||||
try:
|
||||
translate = boto3.client(service_name = "translate", region_name = region_name, use_ssl = True)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
import numpy as np
|
||||
|
||||
from typing import List, Dict, Union
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
|
||||
from cognee.modules.graph.cognee_graph.CogneeAbstractGraph import CogneeAbstractGraph
|
||||
|
|
@ -29,7 +32,7 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
if node.id not in self.nodes:
|
||||
self.nodes[node.id] = node
|
||||
else:
|
||||
raise ValueError(f"Node with id {node.id} already exists.")
|
||||
raise EntityAlreadyExistsError(message=f"Node with id {node.id} already exists.")
|
||||
|
||||
def add_edge(self, edge: Edge) -> None:
|
||||
if edge not in self.edges:
|
||||
|
|
@ -37,7 +40,7 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
edge.node1.add_skeleton_edge(edge)
|
||||
edge.node2.add_skeleton_edge(edge)
|
||||
else:
|
||||
raise ValueError(f"Edge {edge} already exists in the graph.")
|
||||
raise EntityAlreadyExistsError(message=f"Edge {edge} already exists in the graph.")
|
||||
|
||||
def get_node(self, node_id: str) -> Node:
|
||||
return self.nodes.get(node_id, None)
|
||||
|
|
@ -47,7 +50,7 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
if node:
|
||||
return node.skeleton_edges
|
||||
else:
|
||||
raise ValueError(f"Node with id {node_id} does not exist.")
|
||||
raise EntityNotFoundError(message=f"Node with id {node_id} does not exist.")
|
||||
|
||||
def get_edges(self)-> List[Edge]:
|
||||
return self.edges
|
||||
|
|
@ -64,7 +67,7 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
) -> None:
|
||||
|
||||
if node_dimension < 1 or edge_dimension < 1:
|
||||
raise ValueError("Dimensions must be positive integers")
|
||||
raise InvalidValueError(message="Dimensions must be positive integers")
|
||||
|
||||
try:
|
||||
if len(memory_fragment_filter) == 0:
|
||||
|
|
@ -73,9 +76,9 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
nodes_data, edges_data = await adapter.get_filtered_graph_data(attribute_filters = memory_fragment_filter)
|
||||
|
||||
if not nodes_data:
|
||||
raise ValueError("No node data retrieved from the database.")
|
||||
raise EntityNotFoundError(message="No node data retrieved from the database.")
|
||||
if not edges_data:
|
||||
raise ValueError("No edge data retrieved from the database.")
|
||||
raise EntityNotFoundError(message="No edge data retrieved from the database.")
|
||||
|
||||
for node_id, properties in nodes_data:
|
||||
node_attributes = {key: properties.get(key) for key in node_properties_to_project}
|
||||
|
|
@ -95,7 +98,7 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
target_node.add_skeleton_edge(edge)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Edge references nonexistent nodes: {source_id} -> {target_id}")
|
||||
raise EntityNotFoundError(message=f"Edge references nonexistent nodes: {source_id} -> {target_id}")
|
||||
|
||||
except (ValueError, TypeError) as e:
|
||||
print(f"Error projecting graph: {e}")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
import numpy as np
|
||||
from typing import List, Dict, Optional, Any, Union
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
|
||||
|
||||
class Node:
|
||||
"""
|
||||
Represents a node in a graph.
|
||||
|
|
@ -18,7 +21,7 @@ class Node:
|
|||
|
||||
def __init__(self, node_id: str, attributes: Optional[Dict[str, Any]] = None, dimension: int = 1):
|
||||
if dimension <= 0:
|
||||
raise ValueError("Dimension must be a positive integer")
|
||||
raise InvalidValueError(message="Dimension must be a positive integer")
|
||||
self.id = node_id
|
||||
self.attributes = attributes if attributes is not None else {}
|
||||
self.attributes["vector_distance"] = float('inf')
|
||||
|
|
@ -53,7 +56,7 @@ class Node:
|
|||
|
||||
def is_node_alive_in_dimension(self, dimension: int) -> bool:
|
||||
if dimension < 0 or dimension >= len(self.status):
|
||||
raise ValueError(f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.")
|
||||
raise InvalidValueError(message=f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.")
|
||||
return self.status[dimension] == 1
|
||||
|
||||
def add_attribute(self, key: str, value: Any) -> None:
|
||||
|
|
@ -90,7 +93,7 @@ class Edge:
|
|||
|
||||
def __init__(self, node1: "Node", node2: "Node", attributes: Optional[Dict[str, Any]] = None, directed: bool = True, dimension: int = 1):
|
||||
if dimension <= 0:
|
||||
raise ValueError("Dimensions must be a positive integer.")
|
||||
raise InvalidValueError(message="Dimensions must be a positive integer.")
|
||||
self.node1 = node1
|
||||
self.node2 = node2
|
||||
self.attributes = attributes if attributes is not None else {}
|
||||
|
|
@ -100,7 +103,7 @@ class Edge:
|
|||
|
||||
def is_edge_alive_in_dimension(self, dimension: int) -> bool:
|
||||
if dimension < 0 or dimension >= len(self.status):
|
||||
raise ValueError(f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.")
|
||||
raise InvalidValueError(message=f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.")
|
||||
return self.status[dimension] == 1
|
||||
|
||||
def add_attribute(self, key: str, value: Any) -> None:
|
||||
|
|
|
|||
10
cognee/modules/graph/exceptions/__init__.py
Normal file
10
cognee/modules/graph/exceptions/__init__.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
"""
|
||||
Custom exceptions for the Cognee API.
|
||||
|
||||
This module defines a set of exceptions for handling various graph errors
|
||||
"""
|
||||
|
||||
from .exceptions import (
|
||||
EntityNotFoundError,
|
||||
EntityAlreadyExistsError,
|
||||
)
|
||||
25
cognee/modules/graph/exceptions/exceptions.py
Normal file
25
cognee/modules/graph/exceptions/exceptions.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
from cognee.exceptions import CogneeApiError
|
||||
from fastapi import status
|
||||
|
||||
class EntityNotFoundError(CogneeApiError):
|
||||
"""Database returns nothing"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "The requested entity does not exist.",
|
||||
name: str = "EntityNotFoundError",
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
||||
|
||||
class EntityAlreadyExistsError(CogneeApiError):
|
||||
"""Conflict detected, like trying to create a resource that already exists"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "The entity already exists.",
|
||||
name: str = "EntityAlreadyExistsError",
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
|
@ -1,9 +1,11 @@
|
|||
from io import BufferedReader
|
||||
from typing import Union, BinaryIO
|
||||
from .exceptions import IngestionException
|
||||
from .data_types import TextData, BinaryData
|
||||
from tempfile import SpooledTemporaryFile
|
||||
|
||||
from cognee.modules.ingestion.exceptions import IngestionError
|
||||
|
||||
|
||||
def classify(data: Union[str, BinaryIO], filename: str = None):
|
||||
if isinstance(data, str):
|
||||
return TextData(data)
|
||||
|
|
@ -11,4 +13,4 @@ def classify(data: Union[str, BinaryIO], filename: str = None):
|
|||
if isinstance(data, BufferedReader) or isinstance(data, SpooledTemporaryFile):
|
||||
return BinaryData(data, data.name.split("/")[-1] if data.name else filename)
|
||||
|
||||
raise IngestionException(f"Type of data sent to classify(data: Union[str, BinaryIO) not supported: {type(data)}")
|
||||
raise IngestionError(message=f"Type of data sent to classify(data: Union[str, BinaryIO) not supported: {type(data)}")
|
||||
|
|
|
|||
|
|
@ -1,6 +0,0 @@
|
|||
|
||||
class IngestionException(Exception):
|
||||
message: str
|
||||
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
9
cognee/modules/ingestion/exceptions/__init__.py
Normal file
9
cognee/modules/ingestion/exceptions/__init__.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
"""
|
||||
Custom exceptions for the Cognee API.
|
||||
|
||||
This module defines a set of exceptions for handling various ingestion errors
|
||||
"""
|
||||
|
||||
from .exceptions import (
|
||||
IngestionError,
|
||||
)
|
||||
11
cognee/modules/ingestion/exceptions/exceptions.py
Normal file
11
cognee/modules/ingestion/exceptions/exceptions.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
from cognee.exceptions import CogneeApiError
|
||||
from fastapi import status
|
||||
|
||||
class IngestionError(CogneeApiError):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Type of data sent to classify not supported.",
|
||||
name: str = "IngestionError",
|
||||
status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
11
cognee/modules/users/exceptions/__init__.py
Normal file
11
cognee/modules/users/exceptions/__init__.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
"""
|
||||
Custom exceptions for the Cognee API.
|
||||
|
||||
This module defines a set of exceptions for handling various user errors
|
||||
"""
|
||||
|
||||
from .exceptions import (
|
||||
GroupNotFoundError,
|
||||
UserNotFoundError,
|
||||
PermissionDeniedError,
|
||||
)
|
||||
36
cognee/modules/users/exceptions/exceptions.py
Normal file
36
cognee/modules/users/exceptions/exceptions.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
from cognee.exceptions import CogneeApiError
|
||||
from fastapi import status
|
||||
|
||||
|
||||
class GroupNotFoundError(CogneeApiError):
|
||||
"""User group not found"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "User group not found.",
|
||||
name: str = "GroupNotFoundError",
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
||||
|
||||
class UserNotFoundError(CogneeApiError):
|
||||
"""User not found"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "No user found in the system. Please create a user.",
|
||||
name: str = "UserNotFoundError",
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
||||
|
||||
class PermissionDeniedError(CogneeApiError):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "User does not have permission on documents.",
|
||||
name: str = "PermissionDeniedError",
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
|
@ -2,13 +2,14 @@ import os
|
|||
import uuid
|
||||
from typing import Optional
|
||||
from fastapi import Depends, Request
|
||||
from fastapi_users.exceptions import UserNotExists
|
||||
from fastapi_users import BaseUserManager, UUIDIDMixin, models
|
||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||
|
||||
from .get_user_db import get_user_db
|
||||
from .models import User
|
||||
from .methods import get_user
|
||||
from fastapi_users.exceptions import UserNotExists
|
||||
|
||||
|
||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reset_password_token_secret = os.getenv("FASTAPI_USERS_RESET_PASSWORD_TOKEN_SECRET", "super_secret")
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ import logging
|
|||
from uuid import UUID
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from cognee.modules.users.exceptions import PermissionDeniedError
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
||||
from ...models.User import User
|
||||
|
|
@ -9,11 +11,6 @@ from ...models.ACL import ACL
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class PermissionDeniedException(Exception):
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
async def check_permission_on_documents(user: User, permission_type: str, document_ids: list[UUID]):
|
||||
user_group_ids = [group.id for group in user.groups]
|
||||
|
|
@ -33,4 +30,4 @@ async def check_permission_on_documents(user: User, permission_type: str, docume
|
|||
has_permissions = all(document_id in resource_ids for document_id in document_ids)
|
||||
|
||||
if not has_permissions:
|
||||
raise PermissionDeniedException(f"User {user.email} does not have {permission_type} permission on documents")
|
||||
raise PermissionDeniedError(message=f"User {user.email} does not have {permission_type} permission on documents")
|
||||
|
|
|
|||
|
|
@ -4,12 +4,15 @@ import csv
|
|||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from fastapi import status
|
||||
from typing import Any, Dict, List, Optional, Union, Type
|
||||
|
||||
import aiofiles
|
||||
import pandas as pd
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
|
||||
from cognee.modules.ingestion.exceptions import IngestionError
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
from cognee.infrastructure.data.chunking.config import get_chunk_config
|
||||
|
|
@ -75,9 +78,10 @@ class OntologyEngine:
|
|||
reader = csv.DictReader(content.splitlines())
|
||||
return list(reader)
|
||||
else:
|
||||
raise ValueError("Unsupported file format")
|
||||
raise IngestionError(message="Unsupported file format")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load data from {file_path}: {e}")
|
||||
raise IngestionError(message=f"Failed to load data from {file_path}: {e}",
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
|
||||
|
||||
async def add_graph_ontology(self, file_path: str = None, documents: list = None):
|
||||
"""Add graph ontology from a JSON or CSV file or infer from documents content."""
|
||||
|
|
@ -148,7 +152,7 @@ class OntologyEngine:
|
|||
if node_id in valid_ids:
|
||||
await graph_client.add_node(node_id, node_data)
|
||||
if node_id not in valid_ids:
|
||||
raise ValueError(f"Node ID {node_id} not found in the dataset")
|
||||
raise EntityNotFoundError(message=f"Node ID {node_id} not found in the dataset")
|
||||
if pd.notna(row.get("relationship_source")) and pd.notna(row.get("relationship_target")):
|
||||
await graph_client.add_edge(
|
||||
row["relationship_source"],
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
import os
|
||||
from functools import lru_cache
|
||||
|
||||
import dlt
|
||||
from typing import Union
|
||||
|
||||
from cognee.infrastructure.databases.relational import get_relational_config
|
||||
|
||||
@lru_cache
|
||||
def get_dlt_destination() -> Union[type[dlt.destinations.sqlalchemy], None]:
|
||||
"""
|
||||
Handles propagation of the cognee database configuration to the dlt library
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import dlt
|
||||
import cognee.modules.ingestion as ingestion
|
||||
|
||||
from uuid import UUID
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
|
@ -17,25 +18,33 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
|
|||
)
|
||||
|
||||
@dlt.resource(standalone = True, merge_key = "id")
|
||||
async def data_resources(file_paths: str, user: User):
|
||||
async def data_resources(file_paths: str):
|
||||
for file_path in file_paths:
|
||||
with open(file_path.replace("file://", ""), mode = "rb") as file:
|
||||
classified_data = ingestion.classify(file)
|
||||
|
||||
data_id = ingestion.identify(classified_data)
|
||||
|
||||
file_metadata = classified_data.get_metadata()
|
||||
yield {
|
||||
"id": data_id,
|
||||
"name": file_metadata["name"],
|
||||
"file_path": file_metadata["file_path"],
|
||||
"extension": file_metadata["extension"],
|
||||
"mime_type": file_metadata["mime_type"],
|
||||
}
|
||||
|
||||
from sqlalchemy import select
|
||||
from cognee.modules.data.models import Data
|
||||
async def data_storing(table_name, dataset_name, user: User):
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
async with db_engine.get_async_session() as session:
|
||||
async with db_engine.get_async_session() as session:
|
||||
# Read metadata stored with dlt
|
||||
files_metadata = await db_engine.get_all_data_from_table(table_name, dataset_name)
|
||||
for file_metadata in files_metadata:
|
||||
from sqlalchemy import select
|
||||
from cognee.modules.data.models import Data
|
||||
dataset = await create_dataset(dataset_name, user.id, session)
|
||||
|
||||
data = (await session.execute(
|
||||
select(Data).filter(Data.id == data_id)
|
||||
select(Data).filter(Data.id == UUID(file_metadata["id"]))
|
||||
)).scalar_one_or_none()
|
||||
|
||||
if data is not None:
|
||||
|
|
@ -48,7 +57,7 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
|
|||
await session.commit()
|
||||
else:
|
||||
data = Data(
|
||||
id = data_id,
|
||||
id = UUID(file_metadata["id"]),
|
||||
name = file_metadata["name"],
|
||||
raw_data_location = file_metadata["file_path"],
|
||||
extension = file_metadata["extension"],
|
||||
|
|
@ -58,25 +67,34 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
|
|||
dataset.data.append(data)
|
||||
await session.commit()
|
||||
|
||||
yield {
|
||||
"id": data_id,
|
||||
"name": file_metadata["name"],
|
||||
"file_path": file_metadata["file_path"],
|
||||
"extension": file_metadata["extension"],
|
||||
"mime_type": file_metadata["mime_type"],
|
||||
}
|
||||
|
||||
await give_permission_on_document(user, data_id, "read")
|
||||
await give_permission_on_document(user, data_id, "write")
|
||||
await give_permission_on_document(user, UUID(file_metadata["id"]), "read")
|
||||
await give_permission_on_document(user, UUID(file_metadata["id"]), "write")
|
||||
|
||||
|
||||
send_telemetry("cognee.add EXECUTION STARTED", user_id = user.id)
|
||||
run_info = pipeline.run(
|
||||
data_resources(file_paths, user),
|
||||
table_name = "file_metadata",
|
||||
dataset_name = dataset_name,
|
||||
write_disposition = "merge",
|
||||
)
|
||||
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
# Note: DLT pipeline has its own event loop, therefore objects created in another event loop
|
||||
# can't be used inside the pipeline
|
||||
if db_engine.engine.dialect.name == "sqlite":
|
||||
# To use sqlite with dlt dataset_name must be set to "main".
|
||||
# Sqlite doesn't support schemas
|
||||
run_info = pipeline.run(
|
||||
data_resources(file_paths),
|
||||
table_name = "file_metadata",
|
||||
dataset_name = "main",
|
||||
write_disposition = "merge",
|
||||
)
|
||||
else:
|
||||
run_info = pipeline.run(
|
||||
data_resources(file_paths),
|
||||
table_name="file_metadata",
|
||||
dataset_name=dataset_name,
|
||||
write_disposition="merge",
|
||||
)
|
||||
|
||||
await data_storing("file_metadata", dataset_name, user)
|
||||
send_telemetry("cognee.add EXECUTION COMPLETED", user_id = user.id)
|
||||
|
||||
return run_info
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
from typing import Union, BinaryIO
|
||||
|
||||
from cognee.modules.ingestion.exceptions import IngestionError
|
||||
from cognee.modules.ingestion import save_data_to_file
|
||||
|
||||
def save_data_item_to_storage(data_item: Union[BinaryIO, str], dataset_name: str) -> str:
|
||||
|
|
@ -15,6 +17,6 @@ def save_data_item_to_storage(data_item: Union[BinaryIO, str], dataset_name: str
|
|||
else:
|
||||
file_path = save_data_to_file(data_item, dataset_name)
|
||||
else:
|
||||
raise ValueError(f"Data type not supported: {type(data_item)}")
|
||||
raise IngestionError(message=f"Data type not supported: {type(data_item)}")
|
||||
|
||||
return file_path
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
from typing import Any, BinaryIO, Union
|
||||
from typing import Union, BinaryIO, Any
|
||||
|
||||
from cognee.modules.ingestion.exceptions import IngestionError
|
||||
from cognee.modules.ingestion import save_data_to_file
|
||||
|
||||
|
||||
|
|
@ -28,6 +29,6 @@ async def save_data_item_with_metadata_to_storage(
|
|||
else:
|
||||
file_path = save_data_to_file(data_item, dataset_name)
|
||||
else:
|
||||
raise ValueError(f"Data type not supported: {type(data_item)}")
|
||||
raise IngestionError(message=f"Data type not supported: {type(data_item)}")
|
||||
|
||||
return file_path
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
|
||||
|
||||
|
||||
|
|
@ -15,7 +16,7 @@ def test_node_initialization():
|
|||
|
||||
def test_node_invalid_dimension():
|
||||
"""Test that initializing a Node with a non-positive dimension raises an error."""
|
||||
with pytest.raises(ValueError, match="Dimension must be a positive integer"):
|
||||
with pytest.raises(InvalidValueError, match="Dimension must be a positive integer"):
|
||||
Node("node1", dimension=0)
|
||||
|
||||
|
||||
|
|
@ -68,7 +69,7 @@ def test_is_node_alive_in_dimension():
|
|||
def test_node_alive_invalid_dimension():
|
||||
"""Test that checking alive status with an invalid dimension raises an error."""
|
||||
node = Node("node1", dimension=1)
|
||||
with pytest.raises(ValueError, match="Dimension 1 is out of range"):
|
||||
with pytest.raises(InvalidValueError, match="Dimension 1 is out of range"):
|
||||
node.is_node_alive_in_dimension(1)
|
||||
|
||||
|
||||
|
|
@ -105,7 +106,7 @@ def test_edge_invalid_dimension():
|
|||
"""Test that initializing an Edge with a non-positive dimension raises an error."""
|
||||
node1 = Node("node1")
|
||||
node2 = Node("node2")
|
||||
with pytest.raises(ValueError, match="Dimensions must be a positive integer."):
|
||||
with pytest.raises(InvalidValueError, match="Dimensions must be a positive integer."):
|
||||
Edge(node1, node2, dimension=0)
|
||||
|
||||
|
||||
|
|
@ -124,7 +125,7 @@ def test_edge_alive_invalid_dimension():
|
|||
node1 = Node("node1")
|
||||
node2 = Node("node2")
|
||||
edge = Edge(node1, node2, dimension=1)
|
||||
with pytest.raises(ValueError, match="Dimension 1 is out of range"):
|
||||
with pytest.raises(InvalidValueError, match="Dimension 1 is out of range"):
|
||||
edge.is_edge_alive_in_dimension(1)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import pytest
|
||||
|
||||
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
|
||||
|
||||
|
|
@ -23,7 +24,7 @@ def test_add_duplicate_node(setup_graph):
|
|||
graph = setup_graph
|
||||
node = Node("node1")
|
||||
graph.add_node(node)
|
||||
with pytest.raises(ValueError, match="Node with id node1 already exists."):
|
||||
with pytest.raises(EntityAlreadyExistsError, match="Node with id node1 already exists."):
|
||||
graph.add_node(node)
|
||||
|
||||
|
||||
|
|
@ -50,7 +51,7 @@ def test_add_duplicate_edge(setup_graph):
|
|||
graph.add_node(node2)
|
||||
edge = Edge(node1, node2)
|
||||
graph.add_edge(edge)
|
||||
with pytest.raises(ValueError, match="Edge .* already exists in the graph."):
|
||||
with pytest.raises(EntityAlreadyExistsError, match="Edge .* already exists in the graph."):
|
||||
graph.add_edge(edge)
|
||||
|
||||
|
||||
|
|
@ -83,5 +84,5 @@ def test_get_edges_success(setup_graph):
|
|||
def test_get_edges_nonexistent_node(setup_graph):
|
||||
"""Test retrieving edges for a nonexistent node raises an exception."""
|
||||
graph = setup_graph
|
||||
with pytest.raises(ValueError, match="Node with id nonexistent does not exist."):
|
||||
with pytest.raises(EntityNotFoundError, match="Node with id nonexistent does not exist."):
|
||||
graph.get_edges_from_node("nonexistent")
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue