Merge remote-tracking branch 'origin/main' into code-graph

commit e07364fc25
46 changed files with 455 additions and 181 deletions

@@ -3,10 +3,13 @@ import os
 import uvicorn
 import logging
 import sentry_sdk
-from fastapi import FastAPI
+from fastapi import FastAPI, status
 from fastapi.responses import JSONResponse, Response
 from fastapi.middleware.cors import CORSMiddleware
 
+from cognee.exceptions import CogneeApiError
+from traceback import format_exc
 
 # Set up logging
 logging.basicConfig(
     level=logging.INFO, # Set the logging level (e.g., DEBUG, INFO, WARNING, ERROR, CRITICAL)
@@ -76,6 +79,26 @@ async def request_validation_exception_handler(request: Request, exc: RequestVal
         content = jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
     )
 
+@app.exception_handler(CogneeApiError)
+async def exception_handler(_: Request, exc: CogneeApiError) -> JSONResponse:
+    detail = {}
+
+    if exc.name and exc.message and exc.status_code:
+        status_code = exc.status_code
+        detail["message"] = f"{exc.message} [{exc.name}]"
+    else:
+        # Log an error indicating the exception is improperly defined
+        logger.error("Improperly defined exception: %s", exc)
+        # Provide a default error response
+        detail["message"] = "An unexpected error occurred."
+        status_code = status.HTTP_418_IM_A_TEAPOT
+
+    # log the stack trace for easier serverside debugging
+    logger.error(format_exc())
+    return JSONResponse(
+        status_code=status_code, content={"detail": detail["message"]}
+    )
+
 app.include_router(
     get_auth_router(),
     prefix = "/api/v1/auth",
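Usage sketch of the new global handler (illustrative; the route below is hypothetical and not part of this diff):

from cognee.exceptions import CogneeApiError

@app.get("/api/v1/example")
async def example_route():
    # Any CogneeApiError (or subclass) raised inside a route is caught by the
    # handler registered above and serialized as {"detail": "<message> [<name>]"}.
    raise CogneeApiError(
        message="Something went wrong",
        name="ExampleError",
        status_code=500,
    )

# A client calling GET /api/v1/example receives:
#   HTTP 500, body: {"detail": "Something went wrong [ExampleError]"}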
@@ -22,11 +22,6 @@ logger = logging.getLogger("code_graph_pipeline")
 
 update_status_lock = asyncio.Lock()
 
-class PermissionDeniedException(Exception):
-    def __init__(self, message: str):
-        self.message = message
-        super().__init__(self.message)
-
 async def code_graph_pipeline(datasets: Union[str, list[str]] = None, user: User = None):
     if user is None:
         user = await get_default_user()
@@ -24,11 +24,6 @@ logger = logging.getLogger("cognify.v2")
 
 update_status_lock = asyncio.Lock()
 
-class PermissionDeniedException(Exception):
-    def __init__(self, message: str):
-        self.message = message
-        super().__init__(self.message)
-
 async def cognify(datasets: Union[str, list[str]] = None, user: User = None):
     if user is None:
         user = await get_default_user()
@@ -1,6 +1,7 @@
 """ This module is used to set the configuration of the system."""
 import os
 from cognee.base_config import get_base_config
+from cognee.exceptions import InvalidValueError, InvalidAttributeError
 from cognee.modules.cognify.config import get_cognify_config
 from cognee.infrastructure.data.chunking.config import get_chunk_config
 from cognee.infrastructure.databases.vector import get_vectordb_config

@@ -85,7 +86,7 @@ class config():
         if hasattr(llm_config, key):
             object.__setattr__(llm_config, key, value)
         else:
-            raise AttributeError(f"'{key}' is not a valid attribute of the config.")
+            raise InvalidAttributeError(message=f"'{key}' is not a valid attribute of the config.")
 
     @staticmethod
     def set_chunk_strategy(chunk_strategy: object):

@@ -123,7 +124,7 @@ class config():
         if hasattr(relational_db_config, key):
             object.__setattr__(relational_db_config, key, value)
         else:
-            raise AttributeError(f"'{key}' is not a valid attribute of the config.")
+            raise InvalidAttributeError(message=f"'{key}' is not a valid attribute of the config.")
 
     @staticmethod
     def set_vector_db_config(config_dict: dict):

@@ -135,7 +136,7 @@ class config():
         if hasattr(vector_db_config, key):
             object.__setattr__(vector_db_config, key, value)
         else:
-            raise AttributeError(f"'{key}' is not a valid attribute of the config.")
+            raise InvalidAttributeError(message=f"'{key}' is not a valid attribute of the config.")
 
     @staticmethod
     def set_vector_db_key(db_key: str):

@@ -153,7 +154,7 @@ class config():
         base_config = get_base_config()
 
         if "username" not in graphistry_config or "password" not in graphistry_config:
-            raise ValueError("graphistry_config dictionary must contain 'username' and 'password' keys.")
+            raise InvalidValueError(message="graphistry_config dictionary must contain 'username' and 'password' keys.")
 
         base_config.graphistry_username = graphistry_config.get("username")
         base_config.graphistry_password = graphistry_config.get("password")
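Illustrative effect of the setter changes (assumes these setters are exposed as cognee.config.*; the misspelled key below is deliberate):

import cognee
from cognee.exceptions import InvalidAttributeError

try:
    cognee.config.set_llm_config({"llm_api_kye": "..."})  # misspelled key
except InvalidAttributeError as error:
    # Previously a bare AttributeError; now the error carries an HTTP status
    # that the global handler in client.py can render.
    print(error.message)      # "'llm_api_kye' is not a valid attribute of the config."
    print(error.status_code)  # 422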
@@ -9,6 +9,7 @@ from fastapi.responses import JSONResponse, FileResponse
 from pydantic import BaseModel
 
 from cognee.api.DTO import OutDTO
+from cognee.infrastructure.databases.exceptions import EntityNotFoundError
 from cognee.modules.users.models import User
 from cognee.modules.users.methods import get_authenticated_user
 from cognee.modules.pipelines.models import PipelineRunStatus

@@ -55,9 +56,8 @@ def get_datasets_router() -> APIRouter:
         dataset = await get_dataset(user.id, dataset_id)
 
         if dataset is None:
-            raise HTTPException(
-                status_code=404,
-                detail=f"Dataset ({dataset_id}) not found."
+            raise EntityNotFoundError(
+                message=f"Dataset ({dataset_id}) not found."
             )
 
         await delete_dataset(dataset)

@@ -72,17 +72,15 @@ def get_datasets_router() -> APIRouter:
 
         #TODO: Handle situation differently if user doesn't have permission to access data?
         if dataset is None:
-            raise HTTPException(
-                status_code=404,
-                detail=f"Dataset ({dataset_id}) not found."
+            raise EntityNotFoundError(
+                message=f"Dataset ({dataset_id}) not found."
             )
 
         data = await get_data(data_id)
 
         if data is None:
-            raise HTTPException(
-                status_code=404,
-                detail=f"Dataset ({data_id}) not found."
+            raise EntityNotFoundError(
+                message=f"Data ({data_id}) not found."
             )
 
         await delete_data(data)

@@ -158,18 +156,13 @@ def get_datasets_router() -> APIRouter:
         dataset_data = await get_dataset_data(dataset.id)
 
         if dataset_data is None:
-            raise HTTPException(status_code=404, detail=f"No data found in dataset ({dataset_id}).")
+            raise EntityNotFoundError(message=f"No data found in dataset ({dataset_id}).")
 
         matching_data = [data for data in dataset_data if str(data.id) == data_id]
 
         # Check if matching_data contains an element
         if len(matching_data) == 0:
-            return JSONResponse(
-                status_code=404,
-                content={
-                    "detail": f"Data ({data_id}) not found in dataset ({dataset_id})."
-                }
-            )
+            raise EntityNotFoundError(message=f"Data ({data_id}) not found in dataset ({dataset_id}).")
 
         data = matching_data[0]
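Sketch of what clients now see (the handler in client.py renders the 404; the message below is illustrative):

from cognee.infrastructure.databases.exceptions import EntityNotFoundError

# Raised inside a router function, this propagates to the CogneeApiError handler,
# which responds with:
#   HTTP 404, body: {"detail": "Dataset (<id>) not found. [EntityNotFoundError]"}
raise EntityNotFoundError(message="Dataset (<id>) not found.")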
@@ -1,6 +1,8 @@
 from fastapi import APIRouter, Depends, HTTPException
 from fastapi.responses import JSONResponse
 from sqlalchemy.orm import Session
 
+from cognee.modules.users.exceptions import UserNotFoundError, GroupNotFoundError
+
 from cognee.modules.users import get_user_db
 from cognee.modules.users.models import User, Group, Permission

@@ -12,7 +14,7 @@ def get_permissions_router() -> APIRouter:
         group = db.query(Group).filter(Group.id == group_id).first()
 
         if not group:
-            raise HTTPException(status_code = 404, detail = "Group not found")
+            raise GroupNotFoundError
 
         permission = db.query(Permission).filter(Permission.name == permission).first()

@@ -31,8 +33,10 @@ def get_permissions_router() -> APIRouter:
         user = db.query(User).filter(User.id == user_id).first()
         group = db.query(Group).filter(Group.id == group_id).first()
 
-        if not user or not group:
-            raise HTTPException(status_code = 404, detail = "User or group not found")
+        if not user:
+            raise UserNotFoundError
+        elif not group:
+            raise GroupNotFoundError
 
         user.groups.append(group)
@@ -9,6 +9,8 @@ from cognee.modules.search.graph.search_adjacent import search_adjacent
 from cognee.modules.search.vector.search_traverse import search_traverse
 from cognee.modules.search.graph.search_summary import search_summary
 from cognee.modules.search.graph.search_similarity import search_similarity
 
+from cognee.exceptions import UserNotFoundError
+
 from cognee.shared.utils import send_telemetry
 from cognee.modules.users.permissions.methods import get_document_ids_for_user
 from cognee.modules.users.methods import get_default_user

@@ -47,7 +49,7 @@ async def search(search_type: str, params: Dict[str, Any], user: User = None) ->
         user = await get_default_user()
 
         if user is None:
-            raise PermissionError("No user found in the system. Please create a user.")
+            raise UserNotFoundError
 
     own_document_ids = await get_document_ids_for_user(user.id)
     search_params = SearchParameters(search_type = search_type, params = params)
@@ -2,9 +2,12 @@ import json
 from uuid import UUID
 from enum import Enum
 from typing import Callable, Dict
 
+from cognee.exceptions import InvalidValueError
+
 from cognee.modules.search.operations import log_query, log_result
 from cognee.modules.storage.utils import JSONEncoder
 from cognee.shared.utils import send_telemetry
+from cognee.modules.users.exceptions import UserNotFoundError
 from cognee.modules.users.models import User
 from cognee.modules.users.methods import get_default_user
 from cognee.modules.users.permissions.methods import get_document_ids_for_user

@@ -22,7 +25,7 @@ async def search(query_type: SearchType, query_text: str, user: User = None) ->
         user = await get_default_user()
 
         if user is None:
-            raise PermissionError("No user found in the system. Please create a user.")
+            raise UserNotFoundError
 
     query = await log_query(query_text, str(query_type), user.id)

@@ -52,7 +55,7 @@ async def specific_search(query_type: SearchType, query: str, user) -> list:
     search_task = search_tasks.get(query_type)
 
     if search_task is None:
-        raise ValueError(f"Unsupported search type: {query_type}")
+        raise InvalidValueError(message=f"Unsupported search type: {query_type}")
 
     send_telemetry("cognee.search EXECUTION STARTED", user.id)
cognee/exceptions/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
+"""
+Custom exceptions for the Cognee API.
+
+This module defines a set of exceptions for handling various application errors,
+such as service failures, resource conflicts, and invalid operations.
+"""
+
+from .exceptions import (
+    CogneeApiError,
+    ServiceError,
+    InvalidValueError,
+    InvalidAttributeError,
+)
cognee/exceptions/exceptions.py (new file, 54 lines)
@@ -0,0 +1,54 @@
+from fastapi import status
+import logging
+
+logger = logging.getLogger(__name__)
+
+class CogneeApiError(Exception):
+    """Base exception class"""
+
+    def __init__(
+        self,
+        message: str = "Service is unavailable.",
+        name: str = "Cognee",
+        status_code=status.HTTP_418_IM_A_TEAPOT,
+    ):
+        self.message = message
+        self.name = name
+        self.status_code = status_code
+
+        # Automatically log the exception details
+        logger.error(f"{self.name}: {self.message} (Status code: {self.status_code})")
+
+        super().__init__(self.message, self.name)
+
+
+class ServiceError(CogneeApiError):
+    """Failures in external services or APIs, like a database or a third-party service"""
+
+    def __init__(
+        self,
+        message: str = "Service is unavailable.",
+        name: str = "ServiceError",
+        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+    ):
+        super().__init__(message, name, status_code)
+
+
+class InvalidValueError(CogneeApiError):
+    def __init__(
+        self,
+        message: str = "Invalid Value.",
+        name: str = "InvalidValueError",
+        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+    ):
+        super().__init__(message, name, status_code)
+
+
+class InvalidAttributeError(CogneeApiError):
+    def __init__(
+        self,
+        message: str = "Invalid attribute.",
+        name: str = "InvalidAttributeError",
+        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+    ):
+        super().__init__(message, name, status_code)
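Minimal usage sketch of the new hierarchy (status codes are plain ints from fastapi.status):

from cognee.exceptions import CogneeApiError, InvalidValueError

try:
    raise InvalidValueError(message="Dimension must be a positive integer")
except CogneeApiError as error:
    # The base constructor has already logged:
    #   "InvalidValueError: Dimension must be a positive integer (Status code: 422)"
    assert error.status_code == 422
    assert error.args == ("Dimension must be a positive integer", "InvalidValueError")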
@@ -1,9 +1,11 @@
 from sklearn.feature_extraction.text import TfidfVectorizer
 
+from cognee.exceptions import InvalidValueError
+
 from cognee.shared.utils import extract_pos_tags
 
 def extract_keywords(text: str) -> list[str]:
     if len(text) == 0:
-        raise ValueError("extract_keywords cannot extract keywords from empty text.")
+        raise InvalidValueError(message="extract_keywords cannot extract keywords from empty text.")
 
     tags = extract_pos_tags(text)
     nouns = [word for (word, tag) in tags if tag == "NN"]
cognee/infrastructure/databases/exceptions/__init__.py (new file, 10 lines)
@@ -0,0 +1,10 @@
+"""
+Custom exceptions for the Cognee API.
+
+This module defines a set of exceptions for handling various database errors
+"""
+
+from .exceptions import (
+    EntityNotFoundError,
+    EntityAlreadyExistsError,
+)
cognee/infrastructure/databases/exceptions/exceptions.py (new file, 25 lines)
@@ -0,0 +1,25 @@
+from cognee.exceptions import CogneeApiError
+from fastapi import status
+
+class EntityNotFoundError(CogneeApiError):
+    """Database returns nothing"""
+
+    def __init__(
+        self,
+        message: str = "The requested entity does not exist.",
+        name: str = "EntityNotFoundError",
+        status_code=status.HTTP_404_NOT_FOUND,
+    ):
+        super().__init__(message, name, status_code)
+
+
+class EntityAlreadyExistsError(CogneeApiError):
+    """Conflict detected, like trying to create a resource that already exists"""
+
+    def __init__(
+        self,
+        message: str = "The entity already exists.",
+        name: str = "EntityAlreadyExistsError",
+        status_code=status.HTTP_409_CONFLICT,
+    ):
+        super().__init__(message, name, status_code)
@@ -5,6 +5,7 @@ from uuid import UUID
 from textwrap import dedent
 from falkordb import FalkorDB
 
+from cognee.exceptions import InvalidValueError
 from cognee.infrastructure.engine import DataPoint
 from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
 from cognee.infrastructure.databases.vector.embeddings import EmbeddingEngine

@@ -243,7 +244,7 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
         with_vector: bool = False,
     ):
         if query_text is None and query_vector is None:
-            raise ValueError("One of query_text or query_vector must be provided!")
+            raise InvalidValueError(message="One of query_text or query_vector must be provided!")
 
         if query_text and not query_vector:
             query_vector = (await self.embed_data([query_text]))[0]
@@ -1,8 +1,10 @@
+from functools import lru_cache
+
 from .config import get_relational_config
 from .create_relational_engine import create_relational_engine
 
+@lru_cache
 def get_relational_engine():
     relational_config = get_relational_config()
 
     return create_relational_engine(**relational_config.to_dict())
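Effect of the new @lru_cache (sketch): the factory becomes a per-process singleton, so create_relational_engine runs only once.

from cognee.infrastructure.databases.relational import get_relational_engine

engine_a = get_relational_engine()
engine_b = get_relational_engine()
assert engine_a is engine_b  # cached: the engine is built on the first call and reused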
@@ -7,6 +7,7 @@ from sqlalchemy import text, select, MetaData, Table
 from sqlalchemy.orm import joinedload
 from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
 
+from cognee.infrastructure.databases.exceptions import EntityNotFoundError
 from ..ModelBase import Base
 
 class SQLAlchemyAdapter():

@@ -117,7 +118,7 @@ class SQLAlchemyAdapter():
             if table_name in Base.metadata.tables:
                 return Base.metadata.tables[table_name]
             else:
-                raise ValueError(f"Table '{table_name}' not found.")
+                raise EntityNotFoundError(message=f"Table '{table_name}' not found.")
         else:
             # Create a MetaData instance to load table information
             metadata = MetaData()

@@ -128,7 +129,7 @@ class SQLAlchemyAdapter():
             # Check if table is in list of tables for the given schema
             if full_table_name in metadata.tables:
                 return metadata.tables[full_table_name]
-            raise ValueError(f"Table '{full_table_name}' not found.")
+            raise EntityNotFoundError(message=f"Table '{full_table_name}' not found.")
 
     async def get_table_names(self) -> List[str]:
         """

@@ -171,6 +172,27 @@ class SQLAlchemyAdapter():
             results = await connection.execute(query)
             return {result["data_id"]: result["status"] for result in results}
 
+    async def get_all_data_from_table(self, table_name: str, schema: str = "public"):
+        async with self.get_async_session() as session:
+            # Validate inputs to prevent SQL injection
+            if not table_name.isidentifier():
+                raise ValueError("Invalid table name")
+            if schema and not schema.isidentifier():
+                raise ValueError("Invalid schema name")
+
+            if self.engine.dialect.name == "sqlite":
+                table = await self.get_table(table_name)
+            else:
+                table = await self.get_table(table_name, schema)
+
+            # Query all data from the table
+            query = select(table)
+            result = await session.execute(query)
+
+            # Fetch all rows as a list of dictionaries
+            rows = result.mappings().all()
+            return rows
+
     async def execute_query(self, query):
         async with self.engine.begin() as connection:
             result = await connection.execute(text(query))

@@ -205,7 +227,6 @@ class SQLAlchemyAdapter():
         from cognee.infrastructure.files.storage import LocalStorage
 
             LocalStorage.remove(self.db_path)
-            self.db_path = None
         else:
             async with self.engine.begin() as connection:
                 schema_list = await self.get_schema_list()
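Hypothetical call of the new helper (assumes get_relational_engine() returns this adapter; the table name is illustrative):

from cognee.infrastructure.databases.relational import get_relational_engine

async def dump_table():
    engine = get_relational_engine()
    # Returns the rows as a list of mapping objects (dict-like, keyed by column name).
    rows = await engine.get_all_data_from_table("data")
    for row in rows:
        print(dict(row))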
@@ -4,6 +4,8 @@ from uuid import UUID
 import lancedb
 from pydantic import BaseModel
 from lancedb.pydantic import Vector, LanceModel
 
+from cognee.exceptions import InvalidValueError
+
 from cognee.infrastructure.engine import DataPoint
 from cognee.infrastructure.files.storage import LocalStorage
 from cognee.modules.storage.utils import copy_model, get_own_properties

@@ -122,7 +124,7 @@ class LanceDBAdapter(VectorDBInterface):
         new_size = await collection.count_rows()
 
         if new_size <= original_size:
-            raise ValueError(
+            raise InvalidValueError(message=
                 "LanceDB create_datapoints error: data points did not get added.")
 

@@ -148,7 +150,7 @@ class LanceDBAdapter(VectorDBInterface):
         query_vector: List[float] = None
     ):
         if query_text is None and query_vector is None:
-            raise ValueError("One of query_text or query_vector must be provided!")
+            raise InvalidValueError(message="One of query_text or query_vector must be provided!")
 
         if query_text and not query_vector:
             query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

@@ -178,7 +180,7 @@ class LanceDBAdapter(VectorDBInterface):
         normalized: bool = True
     ):
         if query_text is None and query_vector is None:
-            raise ValueError("One of query_text or query_vector must be provided!")
+            raise InvalidValueError(message="One of query_text or query_vector must be provided!")
 
         if query_text and not query_vector:
             query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
@@ -6,6 +6,8 @@ from sqlalchemy.orm import Mapped, mapped_column
 from sqlalchemy import JSON, Column, Table, select, delete
 from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
 
+from cognee.exceptions import InvalidValueError
+from cognee.infrastructure.databases.exceptions import EntityNotFoundError
 from cognee.infrastructure.engine import DataPoint
 
 from .serialize_data import serialize_data

@@ -156,7 +158,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         if collection_name in Base.metadata.tables:
             return Base.metadata.tables[collection_name]
         else:
-            raise ValueError(f"Table '{collection_name}' not found.")
+            raise EntityNotFoundError(message=f"Table '{collection_name}' not found.")
 
     async def retrieve(self, collection_name: str, data_point_ids: List[str]):
         # Get PGVectorDataPoint Table from database

@@ -230,7 +232,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         with_vector: bool = False,
     ) -> List[ScoredResult]:
         if query_text is None and query_vector is None:
-            raise ValueError("One of query_text or query_vector must be provided!")
+            raise InvalidValueError(message="One of query_text or query_vector must be provided!")
 
         if query_text and not query_vector:
             query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
@@ -1,12 +1,12 @@
-from ...relational.ModelBase import Base
 from ..get_vector_engine import get_vector_engine, get_vectordb_config
 from sqlalchemy import text
 
 
 async def create_db_and_tables():
     vector_config = get_vectordb_config()
     vector_engine = get_vector_engine()
 
     if vector_config.vector_db_provider == "pgvector":
+        await vector_engine.create_database()
         async with vector_engine.engine.begin() as connection:
             await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
@@ -3,6 +3,7 @@ from uuid import UUID
 from typing import List, Dict, Optional
 from qdrant_client import AsyncQdrantClient, models
 
+from cognee.exceptions import InvalidValueError
 from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
 from cognee.infrastructure.engine import DataPoint
 from ..vector_db_interface import VectorDBInterface

@@ -186,7 +187,7 @@ class QDrantAdapter(VectorDBInterface):
         with_vector: bool = False
     ):
         if query_text is None and query_vector is None:
-            raise ValueError("One of query_text or query_vector must be provided!")
+            raise InvalidValueError(message="One of query_text or query_vector must be provided!")
 
         client = self.get_qdrant_client()
@@ -3,6 +3,7 @@ import logging
 from typing import List, Optional
 from uuid import UUID
 
+from cognee.exceptions import InvalidValueError
 from cognee.infrastructure.engine import DataPoint
 from ..vector_db_interface import VectorDBInterface
 from ..models.ScoredResult import ScoredResult

@@ -194,7 +195,7 @@ class WeaviateAdapter(VectorDBInterface):
         import weaviate.classes as wvc
 
         if query_text is None and query_vector is None:
-            raise ValueError("One of query_text or query_vector must be provided!")
+            raise InvalidValueError(message="One of query_text or query_vector must be provided!")
 
         if query_vector is None:
             query_vector = (await self.embed_data([query_text]))[0]
@@ -1,9 +1,11 @@
 from typing import Type
 from pydantic import BaseModel
 import instructor
-from tenacity import retry, stop_after_attempt
 import anthropic
 
+from cognee.exceptions import InvalidValueError
 from cognee.infrastructure.llm.llm_interface import LLMInterface
+from cognee.infrastructure.llm.prompts import read_query_prompt
 
 
 class AnthropicAdapter(LLMInterface):

@@ -37,3 +39,17 @@ class AnthropicAdapter(LLMInterface):
             }],
             response_model = response_model,
         )
+
+    def show_prompt(self, text_input: str, system_prompt: str) -> str:
+        """Format and display the prompt for a user query."""
+
+        if not text_input:
+            text_input = "No user input provided."
+        if not system_prompt:
+            raise InvalidValueError(message="No system prompt path provided.")
+
+        system_prompt = read_query_prompt(system_prompt)
+
+        formatted_prompt = f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n""" if system_prompt else None
+
+        return formatted_prompt
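Illustrative use of the new show_prompt (assumes an already-constructed adapter and a prompt file name that read_query_prompt can resolve):

formatted = adapter.show_prompt(
    text_input="What entities are mentioned in this text?",
    system_prompt="classify_prompt.txt",  # hypothetical file name
)
# -> "System Prompt:\n<file contents>\n\nUser Input:\nWhat entities are mentioned in this text?\n"
# An empty system_prompt raises InvalidValueError (422) instead of failing silently.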
@@ -3,9 +3,11 @@ import asyncio
 from typing import List, Type
 from pydantic import BaseModel
 import instructor
-from tenacity import retry, stop_after_attempt
 import openai
 
+from cognee.exceptions import InvalidValueError
 from cognee.infrastructure.llm.llm_interface import LLMInterface
+from cognee.infrastructure.llm.prompts import read_query_prompt
 from cognee.shared.data_models import MonitoringTool
 from cognee.base_config import get_base_config
 from cognee.infrastructure.llm.config import get_llm_config

@@ -52,60 +54,6 @@ class GenericAPIAdapter(LLMInterface):
             mode = instructor.Mode.JSON,
         )
 
-    @retry(stop = stop_after_attempt(5))
-    def completions_with_backoff(self, **kwargs):
-        """Wrapper around ChatCompletion.create w/ backoff"""
-        # Local model
-        return openai.chat.completions.create(**kwargs)
-
-    @retry(stop = stop_after_attempt(5))
-    async def acompletions_with_backoff(self, **kwargs):
-        """Wrapper around ChatCompletion.acreate w/ backoff"""
-        return await openai.chat.completions.acreate(**kwargs)
-
-    @retry(stop = stop_after_attempt(5))
-    async def acreate_embedding_with_backoff(self, input: List[str], model: str = "text-embedding-3-large"):
-        """Wrapper around Embedding.acreate w/ backoff"""
-
-        return await self.aclient.embeddings.create(input = input, model = model)
-
-    async def async_get_embedding_with_backoff(self, text, model="text-embedding-3-large"):
-        """To get text embeddings, import/call this function
-        It specifies defaults + handles rate-limiting + is async"""
-        text = text.replace("\n", " ")
-        response = await self.aclient.embeddings.create(input = text, model = model)
-        embedding = response.data[0].embedding
-        return embedding
-
-    @retry(stop = stop_after_attempt(5))
-    def create_embedding_with_backoff(self, **kwargs):
-        """Wrapper around Embedding.create w/ backoff"""
-        return openai.embeddings.create(**kwargs)
-
-    def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-3-large"):
-        """To get text embeddings, import/call this function
-        It specifies defaults + handles rate-limiting
-        :param text: str
-        :param model: str
-        """
-        text = text.replace("\n", " ")
-        response = self.create_embedding_with_backoff(input=[text], model=model)
-        embedding = response.data[0].embedding
-        return embedding
-
-    async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]):
-        """To get multiple text embeddings in parallel, import/call this function
-        It specifies defaults + handles rate-limiting + is async"""
-        # Collect all coroutines
-        coroutines = (self.async_get_embedding_with_backoff(text, model)
-                      for text, model in zip(texts, models))
-
-        # Run the coroutines in parallel and gather the results
-        embeddings = await asyncio.gather(*coroutines)
-
-        return embeddings
-
-    @retry(stop = stop_after_attempt(5))
     async def acreate_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
         """Generate a response from a user query."""

@@ -122,3 +70,13 @@ class GenericAPIAdapter(LLMInterface):
             response_model = response_model,
         )
+
+    def show_prompt(self, text_input: str, system_prompt: str) -> str:
+        """Format and display the prompt for a user query."""
+        if not text_input:
+            text_input = "No user input provided."
+        if not system_prompt:
+            raise InvalidValueError(message="No system prompt path provided.")
+        system_prompt = read_query_prompt(system_prompt)
+
+        formatted_prompt = f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n""" if system_prompt else None
+        return formatted_prompt
@@ -1,5 +1,7 @@
 """Get the LLM client."""
 from enum import Enum
 
+from cognee.exceptions import InvalidValueError
+
 from cognee.infrastructure.llm import get_llm_config
 
 # Define an Enum for LLM Providers

@@ -17,7 +19,7 @@ def get_llm_client():
 
     if provider == LLMProvider.OPENAI:
         if llm_config.llm_api_key is None:
-            raise ValueError("LLM API key is not set.")
+            raise InvalidValueError(message="LLM API key is not set.")
 
         from .openai.adapter import OpenAIAdapter
 

@@ -32,7 +34,7 @@ def get_llm_client():
 
     elif provider == LLMProvider.OLLAMA:
         if llm_config.llm_api_key is None:
-            raise ValueError("LLM API key is not set.")
+            raise InvalidValueError(message="LLM API key is not set.")
 
         from .generic_llm_api.adapter import GenericAPIAdapter
         return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama")

@@ -43,10 +45,10 @@ def get_llm_client():
 
     elif provider == LLMProvider.CUSTOM:
         if llm_config.llm_api_key is None:
-            raise ValueError("LLM API key is not set.")
+            raise InvalidValueError(message="LLM API key is not set.")
 
         from .generic_llm_api.adapter import GenericAPIAdapter
         return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom")
 
     else:
-        raise ValueError(f"Unsupported LLM provider: {provider}")
+        raise InvalidValueError(message=f"Unsupported LLM provider: {provider}")
@@ -7,7 +7,9 @@ import litellm
 import instructor
 from pydantic import BaseModel
 
+from cognee.exceptions import InvalidValueError
 from cognee.infrastructure.llm.llm_interface import LLMInterface
+from cognee.infrastructure.llm.prompts import read_query_prompt
 
 class OpenAIAdapter(LLMInterface):
     name = "OpenAI"

@@ -120,3 +122,14 @@ class OpenAIAdapter(LLMInterface):
             max_tokens = 300,
             max_retries = 5,
         )
+
+    def show_prompt(self, text_input: str, system_prompt: str) -> str:
+        """Format and display the prompt for a user query."""
+        if not text_input:
+            text_input = "No user input provided."
+        if not system_prompt:
+            raise InvalidValueError(message="No system prompt path provided.")
+        system_prompt = read_query_prompt(system_prompt)
+
+        formatted_prompt = f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n""" if system_prompt else None
+        return formatted_prompt
@@ -1,3 +1,4 @@
+from cognee.exceptions import InvalidAttributeError
 from cognee.modules.data.models import Data
 from cognee.infrastructure.databases.relational import get_relational_engine
 

@@ -12,7 +13,7 @@ async def delete_data(data: Data):
         ValueError: If the data object is invalid.
     """
     if not hasattr(data, '__tablename__'):
-        raise ValueError("The provided data object is missing the required '__tablename__' attribute.")
+        raise InvalidAttributeError(message="The provided data object is missing the required '__tablename__' attribute.")
 
     db_engine = get_relational_engine()
@@ -1,5 +1,7 @@
 import logging
 
+from cognee.exceptions import InvalidValueError
+
 logger = logging.getLogger(__name__)
 
 async def translate_text(text, source_language: str = "sr", target_language: str = "en", region_name = "eu-west-1"):

@@ -18,10 +20,10 @@ async def translate_text(text, source_language: str = "sr", target_language: str
     from botocore.exceptions import BotoCoreError, ClientError
 
     if not text:
-        raise ValueError("No text to translate.")
+        raise InvalidValueError(message="No text to translate.")
 
     if not source_language or not target_language:
-        raise ValueError("Source and target language codes are required.")
+        raise InvalidValueError(message="Source and target language codes are required.")
 
     try:
         translate = boto3.client(service_name = "translate", region_name = region_name, use_ssl = True)
@@ -1,6 +1,9 @@
 import numpy as np
 
 from typing import List, Dict, Union
 
+from cognee.exceptions import InvalidValueError
+from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
+
 from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
 from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
 from cognee.modules.graph.cognee_graph.CogneeAbstractGraph import CogneeAbstractGraph

@@ -29,7 +32,7 @@ class CogneeGraph(CogneeAbstractGraph):
         if node.id not in self.nodes:
             self.nodes[node.id] = node
         else:
-            raise ValueError(f"Node with id {node.id} already exists.")
+            raise EntityAlreadyExistsError(message=f"Node with id {node.id} already exists.")
 
     def add_edge(self, edge: Edge) -> None:
         if edge not in self.edges:

@@ -37,7 +40,7 @@ class CogneeGraph(CogneeAbstractGraph):
             edge.node1.add_skeleton_edge(edge)
             edge.node2.add_skeleton_edge(edge)
         else:
-            raise ValueError(f"Edge {edge} already exists in the graph.")
+            raise EntityAlreadyExistsError(message=f"Edge {edge} already exists in the graph.")
 
     def get_node(self, node_id: str) -> Node:
         return self.nodes.get(node_id, None)

@@ -47,7 +50,7 @@ class CogneeGraph(CogneeAbstractGraph):
         if node:
             return node.skeleton_edges
         else:
-            raise ValueError(f"Node with id {node_id} does not exist.")
+            raise EntityNotFoundError(message=f"Node with id {node_id} does not exist.")
 
     def get_edges(self)-> List[Edge]:
         return self.edges

@@ -64,7 +67,7 @@ class CogneeGraph(CogneeAbstractGraph):
     ) -> None:
 
         if node_dimension < 1 or edge_dimension < 1:
-            raise ValueError("Dimensions must be positive integers")
+            raise InvalidValueError(message="Dimensions must be positive integers")
 
         try:
             if len(memory_fragment_filter) == 0:

@@ -73,9 +76,9 @@ class CogneeGraph(CogneeAbstractGraph):
                 nodes_data, edges_data = await adapter.get_filtered_graph_data(attribute_filters = memory_fragment_filter)
 
             if not nodes_data:
-                raise ValueError("No node data retrieved from the database.")
+                raise EntityNotFoundError(message="No node data retrieved from the database.")
             if not edges_data:
-                raise ValueError("No edge data retrieved from the database.")
+                raise EntityNotFoundError(message="No edge data retrieved from the database.")
 
             for node_id, properties in nodes_data:
                 node_attributes = {key: properties.get(key) for key in node_properties_to_project}

@@ -95,7 +98,7 @@ class CogneeGraph(CogneeAbstractGraph):
                     target_node.add_skeleton_edge(edge)
 
                 else:
-                    raise ValueError(f"Edge references nonexistent nodes: {source_id} -> {target_id}")
+                    raise EntityNotFoundError(message=f"Edge references nonexistent nodes: {source_id} -> {target_id}")
 
         except (ValueError, TypeError) as e:
             print(f"Error projecting graph: {e}")
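Sketch of the graph-level errors in use (the CogneeGraph import path and no-argument constructor are assumptions, not confirmed by this diff):

from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node
from cognee.modules.graph.exceptions import EntityAlreadyExistsError

graph = CogneeGraph()
graph.add_node(Node("a"))
try:
    graph.add_node(Node("a"))  # duplicate id
except EntityAlreadyExistsError as error:
    assert error.status_code == 409  # conflict, mirrors HTTP semantics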
@@ -1,6 +1,9 @@
 import numpy as np
 from typing import List, Dict, Optional, Any, Union
 
+from cognee.exceptions import InvalidValueError
+
+
 class Node:
     """
     Represents a node in a graph.

@@ -18,7 +21,7 @@ class Node:
 
     def __init__(self, node_id: str, attributes: Optional[Dict[str, Any]] = None, dimension: int = 1):
         if dimension <= 0:
-            raise ValueError("Dimension must be a positive integer")
+            raise InvalidValueError(message="Dimension must be a positive integer")
         self.id = node_id
         self.attributes = attributes if attributes is not None else {}
         self.attributes["vector_distance"] = float('inf')

@@ -53,7 +56,7 @@ class Node:
 
     def is_node_alive_in_dimension(self, dimension: int) -> bool:
         if dimension < 0 or dimension >= len(self.status):
-            raise ValueError(f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.")
+            raise InvalidValueError(message=f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.")
         return self.status[dimension] == 1
 
     def add_attribute(self, key: str, value: Any) -> None:

@@ -90,7 +93,7 @@ class Edge:
 
     def __init__(self, node1: "Node", node2: "Node", attributes: Optional[Dict[str, Any]] = None, directed: bool = True, dimension: int = 1):
         if dimension <= 0:
-            raise ValueError("Dimensions must be a positive integer.")
+            raise InvalidValueError(message="Dimensions must be a positive integer.")
         self.node1 = node1
         self.node2 = node2
         self.attributes = attributes if attributes is not None else {}

@@ -100,7 +103,7 @@ class Edge:
 
     def is_edge_alive_in_dimension(self, dimension: int) -> bool:
         if dimension < 0 or dimension >= len(self.status):
-            raise ValueError(f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.")
+            raise InvalidValueError(message=f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.")
         return self.status[dimension] == 1
 
     def add_attribute(self, key: str, value: Any) -> None:
cognee/modules/graph/exceptions/__init__.py (new file, 10 lines)
@@ -0,0 +1,10 @@
+"""
+Custom exceptions for the Cognee API.
+
+This module defines a set of exceptions for handling various graph errors
+"""
+
+from .exceptions import (
+    EntityNotFoundError,
+    EntityAlreadyExistsError,
+)
cognee/modules/graph/exceptions/exceptions.py (new file, 25 lines)
@@ -0,0 +1,25 @@
+from cognee.exceptions import CogneeApiError
+from fastapi import status
+
+class EntityNotFoundError(CogneeApiError):
+    """Database returns nothing"""
+
+    def __init__(
+        self,
+        message: str = "The requested entity does not exist.",
+        name: str = "EntityNotFoundError",
+        status_code=status.HTTP_404_NOT_FOUND,
+    ):
+        super().__init__(message, name, status_code)
+
+
+class EntityAlreadyExistsError(CogneeApiError):
+    """Conflict detected, like trying to create a resource that already exists"""
+
+    def __init__(
+        self,
+        message: str = "The entity already exists.",
+        name: str = "EntityAlreadyExistsError",
+        status_code=status.HTTP_409_CONFLICT,
+    ):
+        super().__init__(message, name, status_code)
|
|
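With these in place, a graph lookup miss becomes a typed error that carries its own HTTP status. A small usage sketch (constructing CogneeGraph with no arguments is an assumption, and .message relies on CogneeApiError storing the message attribute):

from cognee.modules.graph.exceptions import EntityNotFoundError
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph

graph = CogneeGraph()  # assumed no-arg constructor, as the test fixtures below suggest
try:
    graph.get_edges_from_node("missing")  # no such node in the graph
except EntityNotFoundError as error:
    print(error.message)  # e.g. "Node with id missing does not exist."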
@@ -1,9 +1,11 @@
 from io import BufferedReader
 from typing import Union, BinaryIO
-from .exceptions import IngestionException
 from .data_types import TextData, BinaryData
 from tempfile import SpooledTemporaryFile

+from cognee.modules.ingestion.exceptions import IngestionError
+

 def classify(data: Union[str, BinaryIO], filename: str = None):
     if isinstance(data, str):
         return TextData(data)
@@ -11,4 +13,4 @@ def classify(data: Union[str, BinaryIO], filename: str = None):
     if isinstance(data, BufferedReader) or isinstance(data, SpooledTemporaryFile):
         return BinaryData(data, data.name.split("/")[-1] if data.name else filename)

-    raise IngestionException(f"Type of data sent to classify(data: Union[str, BinaryIO) not supported: {type(data)}")
+    raise IngestionError(message=f"Type of data sent to classify(data: Union[str, BinaryIO) not supported: {type(data)}")
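A quick usage sketch of classify after this change (file names here are illustrative):

import cognee.modules.ingestion as ingestion
from cognee.modules.ingestion.exceptions import IngestionError

text_data = ingestion.classify("raw text")   # -> TextData

with open("notes.txt", "rb") as file:        # BufferedReader
    binary_data = ingestion.classify(file)   # -> BinaryData

try:
    ingestion.classify(12345)                # unsupported type
except IngestionError as error:
    # message and the 415 default status are assumed to be stored by CogneeApiError
    print(error.message)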
@@ -1,6 +0,0 @@
-class IngestionException(Exception):
-    message: str
-
-    def __init__(self, message: str):
-        self.message = message
-
cognee/modules/ingestion/exceptions/__init__.py (new file, 9 lines)
@@ -0,0 +1,9 @@
+"""
+Custom exceptions for the Cognee API.
+
+This module defines a set of exceptions for handling various ingestion errors
+"""
+
+from .exceptions import (
+    IngestionError,
+)
cognee/modules/ingestion/exceptions/exceptions.py (new file, 11 lines)
@@ -0,0 +1,11 @@
+from cognee.exceptions import CogneeApiError
+from fastapi import status
+
+class IngestionError(CogneeApiError):
+    def __init__(
+        self,
+        message: str = "Type of data sent to classify not supported.",
+        name: str = "IngestionError",
+        status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
+    ):
+        super().__init__(message, name, status_code)
cognee/modules/users/exceptions/__init__.py (new file, 11 lines)
@@ -0,0 +1,11 @@
+"""
+Custom exceptions for the Cognee API.
+
+This module defines a set of exceptions for handling various user errors
+"""
+
+from .exceptions import (
+    GroupNotFoundError,
+    UserNotFoundError,
+    PermissionDeniedError,
+)
cognee/modules/users/exceptions/exceptions.py (new file, 36 lines)
@@ -0,0 +1,36 @@
+from cognee.exceptions import CogneeApiError
+from fastapi import status
+
+
+class GroupNotFoundError(CogneeApiError):
+    """User group not found"""
+
+    def __init__(
+        self,
+        message: str = "User group not found.",
+        name: str = "GroupNotFoundError",
+        status_code=status.HTTP_404_NOT_FOUND,
+    ):
+        super().__init__(message, name, status_code)
+
+
+class UserNotFoundError(CogneeApiError):
+    """User not found"""
+
+    def __init__(
+        self,
+        message: str = "No user found in the system. Please create a user.",
+        name: str = "UserNotFoundError",
+        status_code=status.HTTP_404_NOT_FOUND,
+    ):
+        super().__init__(message, name, status_code)
+
+
+class PermissionDeniedError(CogneeApiError):
+    def __init__(
+        self,
+        message: str = "User does not have permission on documents.",
+        name: str = "PermissionDeniedError",
+        status_code=status.HTTP_403_FORBIDDEN,
+    ):
+        super().__init__(message, name, status_code)
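All three user exceptions share the CogneeApiError constructor, so call sites can override just the message while keeping the name and status-code defaults. A minimal sketch (the email address is illustrative):

from cognee.modules.users.exceptions import PermissionDeniedError

# defaults to name="PermissionDeniedError" and a 403 response
raise PermissionDeniedError(
    message="User alice@example.com does not have write permission on documents"
)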
@@ -2,13 +2,14 @@ import os
 import uuid
 from typing import Optional
 from fastapi import Depends, Request
-from fastapi_users.exceptions import UserNotExists
 from fastapi_users import BaseUserManager, UUIDIDMixin, models
 from fastapi_users.db import SQLAlchemyUserDatabase

 from .get_user_db import get_user_db
 from .models import User
 from .methods import get_user
+from fastapi_users.exceptions import UserNotExists
+

 class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
     reset_password_token_secret = os.getenv("FASTAPI_USERS_RESET_PASSWORD_TOKEN_SECRET", "super_secret")
@@ -2,6 +2,8 @@ import logging
 from uuid import UUID
 from sqlalchemy import select
 from sqlalchemy.orm import joinedload

+from cognee.modules.users.exceptions import PermissionDeniedError
 from cognee.infrastructure.databases.relational import get_relational_engine
+
 from ...models.User import User
@@ -9,11 +11,6 @@ from ...models.ACL import ACL

 logger = logging.getLogger(__name__)

-class PermissionDeniedException(Exception):
-    def __init__(self, message: str):
-        self.message = message
-        super().__init__(self.message)
-

 async def check_permission_on_documents(user: User, permission_type: str, document_ids: list[UUID]):
     user_group_ids = [group.id for group in user.groups]
@@ -33,4 +30,4 @@ async def check_permission_on_documents(user: User, permission_type: str, document_ids: list[UUID]):
     has_permissions = all(document_id in resource_ids for document_id in document_ids)

     if not has_permissions:
-        raise PermissionDeniedException(f"User {user.email} does not have {permission_type} permission on documents")
+        raise PermissionDeniedError(message=f"User {user.email} does not have {permission_type} permission on documents")
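Callers that previously caught the module-local PermissionDeniedException now catch the shared type instead. A hedged sketch (the wrapper is hypothetical and the import path of check_permission_on_documents is assumed):

from uuid import UUID

from cognee.modules.users.exceptions import PermissionDeniedError
# import path assumed; check_permission_on_documents is the function patched above
from cognee.modules.users.permissions.methods import check_permission_on_documents


async def can_read(user, document_ids: list[UUID]) -> bool:
    # hypothetical helper: turn the raise into a boolean for non-HTTP callers
    try:
        await check_permission_on_documents(user, "read", document_ids)
        return True
    except PermissionDeniedError:
        return False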
@@ -4,12 +4,15 @@ import csv
 import json
 import logging
 from datetime import datetime, timezone
+from fastapi import status
 from typing import Any, Dict, List, Optional, Union, Type

 import aiofiles
 import pandas as pd
 from pydantic import BaseModel

+from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
+from cognee.modules.ingestion.exceptions import IngestionError
 from cognee.infrastructure.llm.prompts import read_query_prompt
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
 from cognee.infrastructure.data.chunking.config import get_chunk_config
@@ -75,9 +78,10 @@ class OntologyEngine:
                 reader = csv.DictReader(content.splitlines())
                 return list(reader)
             else:
-                raise ValueError("Unsupported file format")
+                raise IngestionError(message="Unsupported file format")
         except Exception as e:
-            raise RuntimeError(f"Failed to load data from {file_path}: {e}")
+            raise IngestionError(message=f"Failed to load data from {file_path}: {e}",
+                                 status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)

     async def add_graph_ontology(self, file_path: str = None, documents: list = None):
         """Add graph ontology from a JSON or CSV file or infer from documents content."""
@@ -148,7 +152,7 @@ class OntologyEngine:
                 if node_id in valid_ids:
                     await graph_client.add_node(node_id, node_data)
                 if node_id not in valid_ids:
-                    raise ValueError(f"Node ID {node_id} not found in the dataset")
+                    raise EntityNotFoundError(message=f"Node ID {node_id} not found in the dataset")
                 if pd.notna(row.get("relationship_source")) and pd.notna(row.get("relationship_target")):
                     await graph_client.add_edge(
                         row["relationship_source"],
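Note that the second raise in the load path overrides IngestionError's default 415 with an explicit 422, since a parse failure is a different condition than an unsupported format. The same pattern works anywhere a subclass default does not fit; a sketch (the file name is illustrative):

from fastapi import status

from cognee.modules.ingestion.exceptions import IngestionError

# keeps the default name, swaps only the status code
raise IngestionError(
    message="Failed to load data from ontology.csv: missing header row",
    status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
)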
@@ -1,9 +1,12 @@
 import os
+from functools import lru_cache
+
 import dlt
 from typing import Union

 from cognee.infrastructure.databases.relational import get_relational_config

+@lru_cache
 def get_dlt_destination() -> Union[type[dlt.destinations.sqlalchemy], None]:
     """
     Handles propagation of the cognee database configuration to the dlt library
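The new @lru_cache means the destination is computed once per process and reused on every later call, which matters when pipeline runs call it repeatedly. A self-contained illustration of the caching behavior (names are illustrative, not from the codebase):

from functools import lru_cache


@lru_cache
def build_destination_config() -> tuple:
    print("computed once")  # printed only on the first call
    return ("sqlalchemy", "sqlite:///cognee.db")


build_destination_config()  # computes and caches
build_destination_config()  # served from the cache, no print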
@@ -1,6 +1,7 @@
 import dlt
 import cognee.modules.ingestion as ingestion

+from uuid import UUID
 from cognee.shared.utils import send_telemetry
 from cognee.modules.users.models import User
 from cognee.infrastructure.databases.relational import get_relational_engine
@@ -17,25 +18,33 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
     )

     @dlt.resource(standalone = True, merge_key = "id")
-    async def data_resources(file_paths: str, user: User):
+    async def data_resources(file_paths: str):
         for file_path in file_paths:
             with open(file_path.replace("file://", ""), mode = "rb") as file:
                 classified_data = ingestion.classify(file)

                 data_id = ingestion.identify(classified_data)

                 file_metadata = classified_data.get_metadata()

+                yield {
+                    "id": data_id,
+                    "name": file_metadata["name"],
+                    "file_path": file_metadata["file_path"],
+                    "extension": file_metadata["extension"],
+                    "mime_type": file_metadata["mime_type"],
+                }

-                from sqlalchemy import select
-                from cognee.modules.data.models import Data
-
-                db_engine = get_relational_engine()
-
-                async with db_engine.get_async_session() as session:
+    async def data_storing(table_name, dataset_name, user: User):
+        db_engine = get_relational_engine()
+
+        async with db_engine.get_async_session() as session:
+            # Read metadata stored with dlt
+            files_metadata = await db_engine.get_all_data_from_table(table_name, dataset_name)
+            for file_metadata in files_metadata:
+                from sqlalchemy import select
+                from cognee.modules.data.models import Data
                 dataset = await create_dataset(dataset_name, user.id, session)

                 data = (await session.execute(
-                    select(Data).filter(Data.id == data_id)
+                    select(Data).filter(Data.id == UUID(file_metadata["id"]))
                 )).scalar_one_or_none()

                 if data is not None:
@@ -48,7 +57,7 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
                     await session.commit()
                 else:
                     data = Data(
-                        id = data_id,
+                        id = UUID(file_metadata["id"]),
                         name = file_metadata["name"],
                         raw_data_location = file_metadata["file_path"],
                         extension = file_metadata["extension"],
@@ -58,25 +67,34 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
                 dataset.data.append(data)
                 await session.commit()

-                yield {
-                    "id": data_id,
-                    "name": file_metadata["name"],
-                    "file_path": file_metadata["file_path"],
-                    "extension": file_metadata["extension"],
-                    "mime_type": file_metadata["mime_type"],
-                }
-
-                await give_permission_on_document(user, data_id, "read")
-                await give_permission_on_document(user, data_id, "write")
-
+                await give_permission_on_document(user, UUID(file_metadata["id"]), "read")
+                await give_permission_on_document(user, UUID(file_metadata["id"]), "write")

     send_telemetry("cognee.add EXECUTION STARTED", user_id = user.id)
-    run_info = pipeline.run(
-        data_resources(file_paths, user),
-        table_name = "file_metadata",
-        dataset_name = dataset_name,
-        write_disposition = "merge",
-    )
+
+    db_engine = get_relational_engine()
+
+    # Note: DLT pipeline has its own event loop, therefore objects created in another event loop
+    # can't be used inside the pipeline
+    if db_engine.engine.dialect.name == "sqlite":
+        # To use sqlite with dlt dataset_name must be set to "main".
+        # Sqlite doesn't support schemas
+        run_info = pipeline.run(
+            data_resources(file_paths),
+            table_name = "file_metadata",
+            dataset_name = "main",
+            write_disposition = "merge",
+        )
+    else:
+        run_info = pipeline.run(
+            data_resources(file_paths),
+            table_name = "file_metadata",
+            dataset_name = dataset_name,
+            write_disposition = "merge",
+        )
+
+    await data_storing("file_metadata", dataset_name, user)
     send_telemetry("cognee.add EXECUTION COMPLETED", user_id = user.id)

     return run_info
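This restructure separates concerns: data_resources now only yields file metadata into the dlt pipeline (which runs in its own event loop), while data_storing reads that metadata back from the relational store and creates Data rows and permissions in the caller's loop. The sqlite branch pins dataset_name to "main" because SQLite has no schema support. A condensed sketch of the dialect switch, as a possible simplification rather than what the diff does verbatim:

# assumes db_engine.engine is a SQLAlchemy engine, as in the pipeline above
dataset = "main" if db_engine.engine.dialect.name == "sqlite" else dataset_name

run_info = pipeline.run(
    data_resources(file_paths),
    table_name = "file_metadata",
    dataset_name = dataset,
    write_disposition = "merge",
)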
@@ -1,4 +1,6 @@
 from typing import Union, BinaryIO
+
+from cognee.modules.ingestion.exceptions import IngestionError
 from cognee.modules.ingestion import save_data_to_file

 def save_data_item_to_storage(data_item: Union[BinaryIO, str], dataset_name: str) -> str:
@@ -15,6 +17,6 @@ def save_data_item_to_storage(data_item: Union[BinaryIO, str], dataset_name: str) -> str:
         else:
             file_path = save_data_to_file(data_item, dataset_name)
     else:
-        raise ValueError(f"Data type not supported: {type(data_item)}")
+        raise IngestionError(message=f"Data type not supported: {type(data_item)}")

     return file_path
@@ -1,5 +1,6 @@
-from typing import Any, BinaryIO, Union
+from typing import Union, BinaryIO, Any
+
+from cognee.modules.ingestion.exceptions import IngestionError
 from cognee.modules.ingestion import save_data_to_file

@@ -28,6 +29,6 @@ async def save_data_item_with_metadata_to_storage(
         else:
             file_path = save_data_to_file(data_item, dataset_name)
     else:
-        raise ValueError(f"Data type not supported: {type(data_item)}")
+        raise IngestionError(message=f"Data type not supported: {type(data_item)}")

     return file_path
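Both storage helpers now raise the same IngestionError for unsupported payloads, so a caller can handle one exception type regardless of which path it took. A small sketch, assuming save_data_item_to_storage is imported from its module as defined above:

from cognee.modules.ingestion.exceptions import IngestionError

try:
    file_path = save_data_item_to_storage(42, "my_dataset")  # int is unsupported
except IngestionError as error:
    # one catch covers both save_data_item_to_storage and the metadata variant
    print(error.message)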
@@ -1,6 +1,7 @@
 import numpy as np
 import pytest

+from cognee.exceptions import InvalidValueError
 from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
@@ -15,7 +16,7 @@ def test_node_initialization():

 def test_node_invalid_dimension():
     """Test that initializing a Node with a non-positive dimension raises an error."""
-    with pytest.raises(ValueError, match="Dimension must be a positive integer"):
+    with pytest.raises(InvalidValueError, match="Dimension must be a positive integer"):
         Node("node1", dimension=0)
@@ -68,7 +69,7 @@ def test_is_node_alive_in_dimension():

 def test_node_alive_invalid_dimension():
     """Test that checking alive status with an invalid dimension raises an error."""
     node = Node("node1", dimension=1)
-    with pytest.raises(ValueError, match="Dimension 1 is out of range"):
+    with pytest.raises(InvalidValueError, match="Dimension 1 is out of range"):
         node.is_node_alive_in_dimension(1)
@@ -105,7 +106,7 @@ def test_edge_invalid_dimension():
     """Test that initializing an Edge with a non-positive dimension raises an error."""
     node1 = Node("node1")
     node2 = Node("node2")
-    with pytest.raises(ValueError, match="Dimensions must be a positive integer."):
+    with pytest.raises(InvalidValueError, match="Dimensions must be a positive integer."):
         Edge(node1, node2, dimension=0)
@@ -124,7 +125,7 @@ def test_edge_alive_invalid_dimension():
     node1 = Node("node1")
     node2 = Node("node2")
     edge = Edge(node1, node2, dimension=1)
-    with pytest.raises(ValueError, match="Dimension 1 is out of range"):
+    with pytest.raises(InvalidValueError, match="Dimension 1 is out of range"):
         edge.is_edge_alive_in_dimension(1)
@@ -1,5 +1,6 @@
 import pytest

+from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
 from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
 from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
@@ -23,7 +24,7 @@ def test_add_duplicate_node(setup_graph):
     graph = setup_graph
     node = Node("node1")
     graph.add_node(node)
-    with pytest.raises(ValueError, match="Node with id node1 already exists."):
+    with pytest.raises(EntityAlreadyExistsError, match="Node with id node1 already exists."):
         graph.add_node(node)
@@ -50,7 +51,7 @@ def test_add_duplicate_edge(setup_graph):
     graph.add_node(node2)
     edge = Edge(node1, node2)
     graph.add_edge(edge)
-    with pytest.raises(ValueError, match="Edge .* already exists in the graph."):
+    with pytest.raises(EntityAlreadyExistsError, match="Edge .* already exists in the graph."):
         graph.add_edge(edge)
@@ -83,5 +84,5 @@ def test_get_edges_success(setup_graph):
 def test_get_edges_nonexistent_node(setup_graph):
     """Test retrieving edges for a nonexistent node raises an exception."""
     graph = setup_graph
-    with pytest.raises(ValueError, match="Node with id nonexistent does not exist."):
+    with pytest.raises(EntityNotFoundError, match="Node with id nonexistent does not exist."):
         graph.get_edges_from_node("nonexistent")
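One detail worth noting about the test updates: pytest.raises(match=...) applies re.search to str(excinfo.value), so the existing message-based assertions keep passing as long as CogneeApiError keeps the message in its string form, which is assumed here. A minimal self-contained check:

import pytest

from cognee.modules.graph.exceptions import EntityNotFoundError


def test_match_survives_type_change():
    # raising directly shows the match= pattern still finds the message
    with pytest.raises(EntityNotFoundError, match="does not exist"):
        raise EntityNotFoundError(message="Node with id nonexistent does not exist.")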