fix: add error handling

Boris Arzentar 2025-07-06 21:03:02 +02:00
parent f8f1bb3576
commit 685d282f5c
10 changed files with 265 additions and 120 deletions

View file

@@ -296,13 +296,13 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
extract_graph_from_data,
graph_model=graph_model,
ontology_adapter=OntologyResolver(ontology_file=ontology_file_path),
task_config={"batch_size": 10},
task_config={"batch_size": 25},
), # Generate knowledge graphs from the document chunks.
Task(
summarize_text,
task_config={"batch_size": 10},
task_config={"batch_size": 25},
),
Task(add_data_points, task_config={"batch_size": 10}),
Task(add_data_points, task_config={"batch_size": 25}),
]
return default_tasks

View file

@@ -5,7 +5,7 @@ from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy import JSON, Column, Table, select, delete, MetaData
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from tenacity import retry, retry_if_exception_type, stop_after_attempt
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationError
from cognee.exceptions import InvalidValueError
@@ -113,9 +113,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
return False
@retry(
retry=retry_if_exception_type(Union[DuplicateTableError, UniqueViolationError]),
retry=retry_if_exception_type((DuplicateTableError, UniqueViolationError)),
stop=stop_after_attempt(3),
sleep=1,
wait=wait_exponential(multiplier=2, min=1, max=6),
)
async def create_collection(self, collection_name: str, payload_schema=None):
data_point_types = get_type_hints(DataPoint)
@@ -159,7 +159,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
@retry(
retry=retry_if_exception_type(DeadlockDetectedError),
stop=stop_after_attempt(3),
sleep=1,
wait=wait_exponential(multiplier=2, min=1, max=6),
)
@override_distributed(queued_add_data_points)
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
@@ -204,26 +204,26 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
for data_index, data_point in enumerate(data_points):
# Check to see if data should be updated or a new data item should be created
data_point_db = (
await session.execute(
select(PGVectorDataPoint).filter(PGVectorDataPoint.id == data_point.id)
)
).scalar_one_or_none()
# data_point_db = (
# await session.execute(
# select(PGVectorDataPoint).filter(PGVectorDataPoint.id == data_point.id)
# )
# ).scalar_one_or_none()
# If data point exists update it, if not create a new one
if data_point_db:
data_point_db.id = data_point.id
data_point_db.vector = data_vectors[data_index]
data_point_db.payload = serialize_data(data_point.model_dump())
pgvector_data_points.append(data_point_db)
else:
pgvector_data_points.append(
PGVectorDataPoint(
id=data_point.id,
vector=data_vectors[data_index],
payload=serialize_data(data_point.model_dump()),
)
# if data_point_db:
# data_point_db.id = data_point.id
# data_point_db.vector = data_vectors[data_index]
# data_point_db.payload = serialize_data(data_point.model_dump())
# pgvector_data_points.append(data_point_db)
# else:
pgvector_data_points.append(
PGVectorDataPoint(
id=data_point.id,
vector=data_vectors[data_index],
payload=serialize_data(data_point.model_dump()),
)
)
def to_dict(obj):
return {
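
For reference, a minimal standalone sketch of the retry pattern this adapter now uses: tenacity re-runs the call for up to three attempts and waits with exponential backoff clamped between 1 and 6 seconds, replacing the previous sleep=1 argument. The function and exception below are illustrative, not from the commit.

import asyncio
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

class TransientStorageError(Exception):
    """Illustrative stand-in for a retryable driver error."""

@retry(
    retry=retry_if_exception_type(TransientStorageError),  # only this error type triggers a retry
    stop=stop_after_attempt(3),                             # give up after three attempts
    wait=wait_exponential(multiplier=2, min=1, max=6),      # exponential backoff, clamped to 1-6 seconds
)
async def flaky_insert():
    # Replace with a real database call; raising the error above makes tenacity retry.
    raise TransientStorageError("simulated transient failure")

# asyncio.run(flaky_insert())  # raises tenacity.RetryError after the third failed attempt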

View file

@@ -1,11 +1,16 @@
from typing import List, Optional
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
from cognee.shared.logging_utils import get_logger
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from distributed.utils import override_distributed
from distributed.tasks.queued_add_data_points import queued_add_data_points
from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..models.ScoredResult import ScoredResult
from ..vector_db_interface import VectorDBInterface
@@ -13,6 +18,18 @@ from ..vector_db_interface import VectorDBInterface
logger = get_logger("WeaviateAdapter")
def is_retryable_request(error):
from weaviate.exceptions import UnexpectedStatusCodeException
from requests.exceptions import RequestException
if isinstance(error, UnexpectedStatusCodeException):
# Retry on conflict, service unavailable, internal error
return error.status_code in {409, 503, 500}
if isinstance(error, RequestException):
return True # Includes timeout, connection error, etc.
return False
class IndexSchema(DataPoint):
"""
Define a schema for indexing data points with textual content.
@@ -124,6 +141,11 @@ class WeaviateAdapter(VectorDBInterface):
client = await self.get_client()
return await client.collections.exists(collection_name)
@retry(
retry=retry_if_exception(is_retryable_request),
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=2, min=1, max=6),
)
async def create_collection(
self,
collection_name: str,
@@ -184,6 +206,12 @@ class WeaviateAdapter(VectorDBInterface):
client = await self.get_client()
return client.collections.get(collection_name)
@retry(
retry=retry_if_exception(is_retryable_request),
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=2, min=1, max=6),
)
@override_distributed(queued_add_data_points)
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
"""
Create or update data points in the specified collection in the Weaviate database.
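
The Weaviate adapter gates retries on a predicate rather than an exception type; a minimal sketch of that retry_if_exception pattern, with an illustrative error class standing in for weaviate's UnexpectedStatusCodeException:

from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential

class FakeHttpError(Exception):
    """Illustrative error carrying an HTTP status code."""
    def __init__(self, status_code: int):
        super().__init__(f"HTTP {status_code}")
        self.status_code = status_code

def should_retry(error: BaseException) -> bool:
    # Same idea as is_retryable_request: retry conflicts and server-side failures only.
    return isinstance(error, FakeHttpError) and error.status_code in {409, 500, 503}

@retry(
    retry=retry_if_exception(should_retry),
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=2, min=1, max=6),
)
def call_remote():
    raise FakeHttpError(503)  # retried; a 404 here would fail on the first attempt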

View file

@@ -6,6 +6,7 @@ from cognee.modules.data.models import Dataset
from cognee.modules.data.methods import create_dataset
from cognee.modules.data.methods import get_unique_dataset_id
from cognee.modules.data.exceptions import DatasetNotFoundError
from cognee.modules.users.permissions.methods import give_permission_on_dataset
async def load_or_create_datasets(
@@ -45,6 +46,11 @@ async def load_or_create_datasets(
async with db_engine.get_async_session() as session:
await create_dataset(identifier, user, session)
await give_permission_on_dataset(user, new_dataset.id, "read")
await give_permission_on_dataset(user, new_dataset.id, "write")
await give_permission_on_dataset(user, new_dataset.id, "delete")
await give_permission_on_dataset(user, new_dataset.id, "share")
result.append(new_dataset)
return result

View file

@@ -27,7 +27,7 @@ if modal:
@app.function(
image=image,
timeout=86400,
max_containers=100,
max_containers=50,
secrets=[modal.Secret.from_name("distributed_cognee")],
)
async def run_tasks_on_modal(tasks, data_item, user, pipeline_name, context):

View file

@@ -1,7 +1,7 @@
from uuid import UUID
from sqlalchemy.future import select
from asyncpg import UniqueViolationError
from tenacity import retry, retry_if_exception_type, stop_after_attempt
from sqlalchemy.exc import IntegrityError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.users.permissions import PERMISSION_TYPES
@@ -10,10 +10,14 @@ from cognee.modules.users.exceptions import PermissionNotFoundError
from ...models import Principal, ACL, Permission
class GivePermissionOnDatasetError(Exception):
message: str = "Failed to give permission on dataset"
@retry(
retry=retry_if_exception_type(UniqueViolationError),
retry=retry_if_exception_type(GivePermissionOnDatasetError),
stop=stop_after_attempt(3),
sleep=1,
wait=wait_exponential(multiplier=2, min=1, max=6),
)
async def give_permission_on_dataset(
principal: Principal,
@@ -50,6 +54,11 @@ async def give_permission_on_dataset(
# If no existing ACL entry is found, proceed to add a new one
if existing_acl is None:
acl = ACL(principal_id=principal.id, dataset_id=dataset_id, permission=permission)
session.add(acl)
await session.commit()
try:
acl = ACL(principal_id=principal.id, dataset_id=dataset_id, permission=permission)
session.add(acl)
await session.commit()
except IntegrityError:
session.rollback()
raise GivePermissionOnDatasetError()
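
Both of this function's call sites in the commit grant the same four permission types in sequence; a small helper of that shape (illustrative, not part of the commit) could keep the call sites shorter:

from uuid import UUID

from cognee.modules.users.permissions.methods import give_permission_on_dataset

async def grant_full_dataset_access(user, dataset_id: UUID) -> None:
    # Mirrors the read/write/delete/share grants added in this commit.
    for permission in ("read", "write", "delete", "share"):
        await give_permission_on_dataset(user, dataset_id, permission)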

View file

@@ -2,14 +2,19 @@ import json
import inspect
from uuid import UUID
from typing import Union, BinaryIO, Any, List, Optional
import cognee.modules.ingestion as ingestion
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.methods import create_dataset, get_dataset_data, get_datasets_by_name
from cognee.modules.users.methods import get_default_user
from cognee.modules.data.models.DatasetData import DatasetData
from cognee.modules.data.models import Data
from cognee.modules.users.models import User
from cognee.modules.users.permissions.methods import give_permission_on_dataset
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets
from cognee.modules.data.methods import (
get_authorized_existing_datasets,
get_dataset_data,
load_or_create_datasets,
)
from .save_data_item_to_storage import save_data_item_to_storage
@@ -56,20 +61,40 @@ async def ingest_data(
node_set: Optional[List[str]] = None,
dataset_id: UUID = None,
):
new_datapoints = []
existing_data_points = []
dataset_new_data_points = []
if not isinstance(data, list):
# Convert data to a list as we work with lists further down.
data = [data]
file_paths = []
if dataset_id:
# Retrieve existing dataset
dataset = await get_specific_user_permission_datasets(user.id, "write", [dataset_id])
# Convert from list to Dataset element
if isinstance(dataset, list):
dataset = dataset[0]
else:
# Find existing dataset or create a new one
existing_datasets = await get_authorized_existing_datasets(
user=user, permission_type="write", datasets=[dataset_name]
)
dataset = await load_or_create_datasets(
dataset_names=[dataset_name],
existing_datasets=existing_datasets,
user=user,
)
if isinstance(dataset, list):
dataset = dataset[0]
dataset_data: list[Data] = await get_dataset_data(dataset.id)
dataset_data_map = {str(data.id): True for data in dataset_data}
# Process data
for data_item in data:
file_path = await save_data_item_to_storage(data_item, dataset_name)
file_paths.append(file_path)
# Ingest data and add metadata
# with open(file_path.replace("file://", ""), mode="rb") as file:
with open_data_file(file_path) as file:
classified_data = ingestion.classify(file, s3fs=fs)
@@ -80,90 +105,76 @@ async def ingest_data(
from sqlalchemy import select
from cognee.modules.data.models import Data
db_engine = get_relational_engine()
# Check to see if data should be updated
async with db_engine.get_async_session() as session:
if dataset_id:
# Retrieve existing dataset
dataset = await get_specific_user_permission_datasets(
user.id, "write", [dataset_id]
)
# Convert from list to Dataset element
if isinstance(dataset, list):
dataset = dataset[0]
dataset = await session.merge(
dataset
) # Add found dataset object into current session
else:
# Create new one
dataset = await create_dataset(dataset_name, user, session)
# Check to see if data should be updated
data_point = (
await session.execute(select(Data).filter(Data.id == data_id))
).scalar_one_or_none()
ext_metadata = get_external_metadata_dict(data_item)
if node_set:
ext_metadata["node_set"] = node_set
ext_metadata = get_external_metadata_dict(data_item)
if data_point is not None:
data_point.name = file_metadata["name"]
data_point.raw_data_location = file_metadata["file_path"]
data_point.extension = file_metadata["extension"]
data_point.mime_type = file_metadata["mime_type"]
data_point.owner_id = user.id
data_point.content_hash = file_metadata["content_hash"]
data_point.external_metadata = ext_metadata
data_point.node_set = json.dumps(node_set) if node_set else None
await session.merge(data_point)
else:
data_point = Data(
id=data_id,
name=file_metadata["name"],
raw_data_location=file_metadata["file_path"],
extension=file_metadata["extension"],
mime_type=file_metadata["mime_type"],
owner_id=user.id,
content_hash=file_metadata["content_hash"],
external_metadata=ext_metadata,
node_set=json.dumps(node_set) if node_set else None,
token_count=-1,
)
session.add(data_point)
if node_set:
ext_metadata["node_set"] = node_set
if data_point is not None:
data_point.name = file_metadata["name"]
data_point.raw_data_location = file_metadata["file_path"]
data_point.extension = file_metadata["extension"]
data_point.mime_type = file_metadata["mime_type"]
data_point.owner_id = user.id
data_point.content_hash = file_metadata["content_hash"]
data_point.external_metadata = ext_metadata
data_point.node_set = json.dumps(node_set) if node_set else None
# Check if data is already in dataset
dataset_data = (
await session.execute(
select(DatasetData).filter(
DatasetData.data_id == data_id, DatasetData.dataset_id == dataset.id
)
)
).scalar_one_or_none()
# If data is not present in dataset add it
if dataset_data is None:
dataset.data.append(data_point)
await session.merge(dataset)
if str(data_point.id) in dataset_data_map:
existing_data_points.append(data_point)
else:
dataset_new_data_points.append(data_point)
dataset_data_map[str(data_point.id)] = True
else:
if str(data_id) in dataset_data_map:
continue
await session.commit()
data_point = Data(
id=data_id,
name=file_metadata["name"],
raw_data_location=file_metadata["file_path"],
extension=file_metadata["extension"],
mime_type=file_metadata["mime_type"],
owner_id=user.id,
content_hash=file_metadata["content_hash"],
external_metadata=ext_metadata,
node_set=json.dumps(node_set) if node_set else None,
token_count=-1,
)
await give_permission_on_dataset(user, dataset.id, "read")
await give_permission_on_dataset(user, dataset.id, "write")
await give_permission_on_dataset(user, dataset.id, "delete")
await give_permission_on_dataset(user, dataset.id, "share")
new_datapoints.append(data_point)
dataset_data_map[str(data_point.id)] = True
return file_paths
async with db_engine.get_async_session() as session:
if dataset not in session:
session.add(dataset)
await store_data_to_dataset(data, dataset_name, user, node_set, dataset_id)
if len(new_datapoints) > 0:
session.add_all(new_datapoints)
dataset.data.extend(new_datapoints)
datasets = await get_datasets_by_name(dataset_name, user.id)
if len(existing_data_points) > 0:
for data_point in existing_data_points:
await session.merge(data_point)
# In case no files were processed no dataset will be created
if datasets:
dataset = datasets[0]
data_documents = await get_dataset_data(dataset_id=dataset.id)
return data_documents
return []
if len(dataset_new_data_points) > 0:
for data_point in dataset_new_data_points:
await session.merge(data_point)
dataset.data.extend(dataset_new_data_points)
await session.merge(dataset)
await session.commit()
return existing_data_points + dataset_new_data_points + new_datapoints
return await store_data_to_dataset(data, dataset_name, user, node_set, dataset_id)
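
At the core of the reworked store_data_to_dataset is a membership map over the dataset's existing data ids: each classified item is either merged as an existing data point or collected as a new one, and everything is committed in a single session at the end. A self-contained sketch of that bucketing, reduced to plain ids instead of Data objects:

def bucket_data_points(incoming_ids, known_dataset_ids):
    # Map of ids already linked to the dataset, as in dataset_data_map above.
    dataset_data_map = {str(i): True for i in known_dataset_ids}
    existing, newly_linked = [], []
    for data_id in incoming_ids:
        if str(data_id) in dataset_data_map:
            existing.append(data_id)
        else:
            newly_linked.append(data_id)
            dataset_data_map[str(data_id)] = True  # avoid double-linking within one run
    return existing, newly_linked

# "a" is already in the dataset, "c" is not.
print(bucket_data_points(["a", "c"], ["a", "b"]))  # (['a'], ['c'])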

View file

@@ -23,7 +23,7 @@ async def main():
await add_nodes_and_edges_queue.clear.aio()
number_of_graph_saving_workers = 1 # Total number of graph_saving_worker to spawn
number_of_data_point_saving_workers = 2 # Total number of data_point_saving_worker to spawn
number_of_data_point_saving_workers = 5 # Total number of data_point_saving_worker to spawn
results = []
consumer_futures = []
@@ -44,7 +44,8 @@ async def main():
worker_future = data_point_saving_worker.spawn()
consumer_futures.append(worker_future)
s3_bucket_name = "s3://s3-test-laszlo/Database for KG v1"
# s3_bucket_name = "s3://s3-test-laszlo/Database for KG v1"
s3_bucket_name = "s3://s3-test-laszlo/Pdf"
await cognee.add(s3_bucket_name, dataset_name="s3-files")

View file

@@ -1,6 +1,7 @@
import modal
import asyncio
from sqlalchemy.exc import OperationalError, DBAPIError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
from distributed.app import app
from distributed.modal_image import image
@@ -13,6 +14,31 @@ from cognee.infrastructure.databases.vector import get_vector_engine
logger = get_logger("data_point_saving_worker")
class VectorDatabaseDeadlockError(Exception):
message = "A deadlock occurred while trying to add data points to the vector database."
def is_deadlock_error(error):
# SQLAlchemy-wrapped asyncpg
try:
import asyncpg
if isinstance(error.orig, asyncpg.exceptions.DeadlockDetectedError):
return True
except ImportError:
pass
# PostgreSQL: SQLSTATE 40P01 = deadlock_detected
if hasattr(error.orig, "pgcode") and error.orig.pgcode == "40P01":
return True
# SQLite: It doesn't support real deadlocks but may simulate them as "database is locked"
if "database is locked" in str(error.orig).lower():
return True
return False
@app.function(
image=image,
timeout=86400,
@@ -40,9 +66,24 @@ async def data_point_saving_worker():
print(f"Adding {len(data_points)} data points to '{collection_name}' collection.")
await vector_engine.create_data_points(
collection_name, data_points, distributed=False
@retry(
retry=retry_if_exception_type(VectorDatabaseDeadlockError),
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=2, min=1, max=6),
)
async def add_data_points():
try:
await vector_engine.create_data_points(
collection_name, data_points, distributed=False
)
except DBAPIError as error:
if is_deadlock_error(error):
raise VectorDatabaseDeadlockError()
except OperationalError as error:
if is_deadlock_error(error):
raise VectorDatabaseDeadlockError()
await add_data_points()
print("Finished adding data points.")

View file

@@ -1,6 +1,6 @@
import modal
import asyncio
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
from distributed.app import app
from distributed.modal_image import image
@@ -8,11 +8,35 @@ from distributed.queues import add_nodes_and_edges_queue
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.graph.config import get_graph_config
logger = get_logger("graph_saving_worker")
class GraphDatabaseDeadlockError(Exception):
message = "A deadlock occurred while trying to add data points to the vector database."
def is_deadlock_error(error):
graph_config = get_graph_config()
if graph_config.graph_database_provider == "neo4j":
# Neo4j
from neo4j.exceptions import TransientError
if isinstance(error, TransientError) and (
error.code == "Neo.TransientError.Transaction.DeadlockDetected"
):
return True
# Kuzu
if "deadlock" in str(error).lower() or "cannot acquire lock" in str(error).lower():
return True
return False
@app.function(
image=image,
timeout=86400,
@@ -42,11 +66,36 @@ async def graph_saving_worker():
nodes = nodes_and_edges[0]
edges = nodes_and_edges[1]
@retry(
retry=retry_if_exception_type(GraphDatabaseDeadlockError),
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=2, min=1, max=6),
)
async def save_graph_nodes(new_nodes):
try:
await graph_engine.add_nodes(new_nodes, distributed=False)
except Exception as error:
if is_deadlock_error(error):
raise GraphDatabaseDeadlockError()
@retry(
retry=retry_if_exception_type(GraphDatabaseDeadlockError),
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=2, min=1, max=6),
)
async def save_graph_edges(new_edges):
try:
await graph_engine.add_edges(new_edges, distributed=False)
except Exception as error:
if is_deadlock_error(error):
raise GraphDatabaseDeadlockError()
if nodes:
await graph_engine.add_nodes(nodes, distributed=False)
await save_graph_nodes(nodes)
if edges:
await graph_engine.add_edges(edges, distributed=False)
await save_graph_edges(edges)
print("Finished adding nodes and edges.")
else: