fix: add error handling

Boris Arzentar 2025-07-06 21:03:02 +02:00
parent f8f1bb3576
commit 685d282f5c
10 changed files with 265 additions and 120 deletions

View file

@@ -296,13 +296,13 @@ async def get_default_tasks(  # TODO: Find out a better way to do this (Boris's
             extract_graph_from_data,
             graph_model=graph_model,
             ontology_adapter=OntologyResolver(ontology_file=ontology_file_path),
-            task_config={"batch_size": 10},
+            task_config={"batch_size": 25},
         ),  # Generate knowledge graphs from the document chunks.
         Task(
             summarize_text,
-            task_config={"batch_size": 10},
+            task_config={"batch_size": 25},
         ),
-        Task(add_data_points, task_config={"batch_size": 10}),
+        Task(add_data_points, task_config={"batch_size": 25}),
     ]

     return default_tasks

View file

@@ -5,7 +5,7 @@ from sqlalchemy.orm import Mapped, mapped_column
 from sqlalchemy.dialects.postgresql import insert
 from sqlalchemy import JSON, Column, Table, select, delete, MetaData
 from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
-from tenacity import retry, retry_if_exception_type, stop_after_attempt
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
 from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationError
 from cognee.exceptions import InvalidValueError
@@ -113,9 +113,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
             return False

     @retry(
-        retry=retry_if_exception_type(Union[DuplicateTableError, UniqueViolationError]),
+        retry=retry_if_exception_type((DuplicateTableError, UniqueViolationError)),
         stop=stop_after_attempt(3),
-        sleep=1,
+        wait=wait_exponential(multiplier=2, min=1, max=6),
     )
     async def create_collection(self, collection_name: str, payload_schema=None):
         data_point_types = get_type_hints(DataPoint)
@@ -159,7 +159,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
     @retry(
         retry=retry_if_exception_type(DeadlockDetectedError),
         stop=stop_after_attempt(3),
-        sleep=1,
+        wait=wait_exponential(multiplier=2, min=1, max=6),
     )
     @override_distributed(queued_add_data_points)
     async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
@@ -204,26 +204,26 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
             for data_index, data_point in enumerate(data_points):
                 # Check to see if data should be updated or a new data item should be created
-                data_point_db = (
-                    await session.execute(
-                        select(PGVectorDataPoint).filter(PGVectorDataPoint.id == data_point.id)
-                    )
-                ).scalar_one_or_none()
+                # data_point_db = (
+                #     await session.execute(
+                #         select(PGVectorDataPoint).filter(PGVectorDataPoint.id == data_point.id)
+                #     )
+                # ).scalar_one_or_none()

                 # If data point exists update it, if not create a new one
-                if data_point_db:
-                    data_point_db.id = data_point.id
-                    data_point_db.vector = data_vectors[data_index]
-                    data_point_db.payload = serialize_data(data_point.model_dump())
-                    pgvector_data_points.append(data_point_db)
-                else:
-                    pgvector_data_points.append(
-                        PGVectorDataPoint(
-                            id=data_point.id,
-                            vector=data_vectors[data_index],
-                            payload=serialize_data(data_point.model_dump()),
-                        )
-                    )
+                # if data_point_db:
+                #     data_point_db.id = data_point.id
+                #     data_point_db.vector = data_vectors[data_index]
+                #     data_point_db.payload = serialize_data(data_point.model_dump())
+                #     pgvector_data_points.append(data_point_db)
+                # else:
+                pgvector_data_points.append(
+                    PGVectorDataPoint(
+                        id=data_point.id,
+                        vector=data_vectors[data_index],
+                        payload=serialize_data(data_point.model_dump()),
+                    )
+                )

         def to_dict(obj):
             return {
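
The tenacity changes above are the substance of this file's fix: retry_if_exception_type expects an exception class or a plain tuple of classes rather than a typing.Union, and the pause between attempts is configured through wait= (here exponential backoff) rather than sleep=, which tenacity treats as the sleep callable itself, not a duration. A minimal self-contained sketch of the same policy, with a placeholder exception standing in for the asyncpg error types:

import asyncio

from tenacity import RetryError, retry, retry_if_exception_type, stop_after_attempt, wait_exponential


class TransientDBError(Exception):
    """Placeholder for the asyncpg error types used in the real adapter."""


@retry(
    retry=retry_if_exception_type((TransientDBError,)),  # a plain tuple of classes, not typing.Union
    stop=stop_after_attempt(3),                          # give up after three attempts
    wait=wait_exponential(multiplier=2, min=1, max=6),   # exponential backoff capped at 6 seconds
)
async def create_collection_stub():
    # Always fails here, so tenacity re-invokes it with backoff until attempts run out.
    raise TransientDBError("simulated transient conflict")


if __name__ == "__main__":
    try:
        asyncio.run(create_collection_stub())
    except RetryError:
        print("still failing after three attempts")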

View file

@@ -1,11 +1,16 @@
 from typing import List, Optional

+from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
+
 from cognee.shared.logging_utils import get_logger
 from cognee.exceptions import InvalidValueError
 from cognee.infrastructure.engine import DataPoint
 from cognee.infrastructure.engine.utils import parse_id
 from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
+from distributed.utils import override_distributed
+from distributed.tasks.queued_add_data_points import queued_add_data_points

 from ..embeddings.EmbeddingEngine import EmbeddingEngine
 from ..models.ScoredResult import ScoredResult
 from ..vector_db_interface import VectorDBInterface
@@ -13,6 +18,18 @@ from ..vector_db_interface import VectorDBInterface
 logger = get_logger("WeaviateAdapter")

+def is_retryable_request(error):
+    from weaviate.exceptions import UnexpectedStatusCodeException
+    from requests.exceptions import RequestException
+
+    if isinstance(error, UnexpectedStatusCodeException):
+        # Retry on conflict, service unavailable, internal error
+        return error.status_code in {409, 503, 500}
+
+    if isinstance(error, RequestException):
+        return True  # Includes timeout, connection error, etc.
+
+    return False
+

 class IndexSchema(DataPoint):
     """
     Define a schema for indexing data points with textual content.
@@ -124,6 +141,11 @@ class WeaviateAdapter(VectorDBInterface):
         client = await self.get_client()
         return await client.collections.exists(collection_name)

+    @retry(
+        retry=retry_if_exception(is_retryable_request),
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=2, min=1, max=6),
+    )
     async def create_collection(
         self,
         collection_name: str,
@@ -184,6 +206,12 @@ class WeaviateAdapter(VectorDBInterface):
         client = await self.get_client()
         return client.collections.get(collection_name)

+    @retry(
+        retry=retry_if_exception(is_retryable_request),
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=2, min=1, max=6),
+    )
+    @override_distributed(queued_add_data_points)
     async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
         """
         Create or update data points in the specified collection in the Weaviate database.
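
Where retry_if_exception_type matches on exception classes, retry_if_exception takes an arbitrary predicate, which is what lets is_retryable_request inspect the Weaviate status code before deciding whether to retry. A small self-contained sketch of that predicate-based variant; the error class and call_service are hypothetical stand-ins for the Weaviate client:

import asyncio

from tenacity import RetryError, retry, retry_if_exception, stop_after_attempt, wait_exponential


class StatusCodeError(Exception):
    """Hypothetical stand-in for weaviate's UnexpectedStatusCodeException."""

    def __init__(self, status_code):
        self.status_code = status_code


def should_retry(error: BaseException) -> bool:
    # Only retry conflicts and server-side failures; other errors fail fast.
    return isinstance(error, StatusCodeError) and error.status_code in {409, 500, 503}


@retry(
    retry=retry_if_exception(should_retry),
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=2, min=1, max=6),
)
async def call_service():
    raise StatusCodeError(503)  # always "service unavailable" in this sketch


if __name__ == "__main__":
    try:
        asyncio.run(call_service())
    except RetryError:
        print("gave up after three attempts")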

View file

@@ -6,6 +6,7 @@ from cognee.modules.data.models import Dataset
 from cognee.modules.data.methods import create_dataset
 from cognee.modules.data.methods import get_unique_dataset_id
 from cognee.modules.data.exceptions import DatasetNotFoundError
+from cognee.modules.users.permissions.methods import give_permission_on_dataset


 async def load_or_create_datasets(
@@ -45,6 +46,11 @@ async def load_or_create_datasets(
             async with db_engine.get_async_session() as session:
                 await create_dataset(identifier, user, session)

+            await give_permission_on_dataset(user, new_dataset.id, "read")
+            await give_permission_on_dataset(user, new_dataset.id, "write")
+            await give_permission_on_dataset(user, new_dataset.id, "delete")
+            await give_permission_on_dataset(user, new_dataset.id, "share")
+
             result.append(new_dataset)

     return result
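
The four grants above give the creating user the full permission set on a freshly created dataset. A self-contained sketch of the same idea written as a loop over the permission names; grant and fake_grant are illustrative stand-ins for give_permission_on_dataset:

import asyncio

DEFAULT_DATASET_PERMISSIONS = ("read", "write", "delete", "share")


async def grant_default_permissions(user, dataset_id, grant):
    # `grant` is any awaitable callable with the (user, dataset_id, name) shape
    # used above; passed in so the sketch stays self-contained.
    for permission_name in DEFAULT_DATASET_PERMISSIONS:
        await grant(user, dataset_id, permission_name)


if __name__ == "__main__":
    async def fake_grant(user, dataset_id, name):
        print(f"granting {name} on {dataset_id} to {user}")

    asyncio.run(grant_default_permissions("owner", "dataset-1", fake_grant))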

View file

@@ -27,7 +27,7 @@ if modal:
     @app.function(
         image=image,
         timeout=86400,
-        max_containers=100,
+        max_containers=50,
         secrets=[modal.Secret.from_name("distributed_cognee")],
     )
     async def run_tasks_on_modal(tasks, data_item, user, pipeline_name, context):

View file

@@ -1,7 +1,7 @@
 from uuid import UUID
 from sqlalchemy.future import select
-from asyncpg import UniqueViolationError
-from tenacity import retry, retry_if_exception_type, stop_after_attempt
+from sqlalchemy.exc import IntegrityError
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

 from cognee.infrastructure.databases.relational import get_relational_engine
 from cognee.modules.users.permissions import PERMISSION_TYPES
@@ -10,10 +10,14 @@ from cognee.modules.users.exceptions import PermissionNotFoundError
 from ...models import Principal, ACL, Permission


+class GivePermissionOnDatasetError(Exception):
+    message: str = "Failed to give permission on dataset"
+
+
 @retry(
-    retry=retry_if_exception_type(UniqueViolationError),
+    retry=retry_if_exception_type(GivePermissionOnDatasetError),
     stop=stop_after_attempt(3),
-    sleep=1,
+    wait=wait_exponential(multiplier=2, min=1, max=6),
 )
 async def give_permission_on_dataset(
     principal: Principal,
@@ -50,6 +54,11 @@ async def give_permission_on_dataset(
         # If no existing ACL entry is found, proceed to add a new one
         if existing_acl is None:
-            acl = ACL(principal_id=principal.id, dataset_id=dataset_id, permission=permission)
-            session.add(acl)
-            await session.commit()
+            try:
+                acl = ACL(principal_id=principal.id, dataset_id=dataset_id, permission=permission)
+                session.add(acl)
+                await session.commit()
+            except IntegrityError:
+                session.rollback()
+                raise GivePermissionOnDatasetError()
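
The shape of this change is translate-then-retry: the low-level IntegrityError is caught inside the session block, converted into the domain-level GivePermissionOnDatasetError, and only that domain error is retried by the decorator. A minimal sketch of the same shape with an in-memory stand-in for the session; everything other than the tenacity API is hypothetical:

import asyncio

from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential


class GrantConflictError(Exception):
    """Domain error the retry policy targets (stand-in for GivePermissionOnDatasetError)."""


class FlakySession:
    """Fakes a session whose commit hits a uniqueness conflict on the first attempt."""

    def __init__(self):
        self.attempts = 0

    async def commit(self):
        self.attempts += 1
        if self.attempts == 1:
            raise RuntimeError("duplicate key value violates unique constraint")

    async def rollback(self):
        pass


session = FlakySession()


@retry(
    retry=retry_if_exception_type(GrantConflictError),  # only the translated error is retried
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=2, min=1, max=6),
)
async def grant_permission():
    try:
        await session.commit()
    except RuntimeError:          # stands in for sqlalchemy.exc.IntegrityError
        await session.rollback()  # AsyncSession.rollback() is a coroutine
        raise GrantConflictError()


if __name__ == "__main__":
    asyncio.run(grant_permission())
    print(f"succeeded after {session.attempts} attempts")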

View file

@@ -2,14 +2,19 @@ import json
 import inspect
 from uuid import UUID
 from typing import Union, BinaryIO, Any, List, Optional

 import cognee.modules.ingestion as ingestion
 from cognee.infrastructure.databases.relational import get_relational_engine
-from cognee.modules.data.methods import create_dataset, get_dataset_data, get_datasets_by_name
-from cognee.modules.users.methods import get_default_user
-from cognee.modules.data.models.DatasetData import DatasetData
+from cognee.modules.data.models import Data
 from cognee.modules.users.models import User
-from cognee.modules.users.permissions.methods import give_permission_on_dataset
+from cognee.modules.users.methods import get_default_user
 from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets
+from cognee.modules.data.methods import (
+    get_authorized_existing_datasets,
+    get_dataset_data,
+    load_or_create_datasets,
+)

 from .save_data_item_to_storage import save_data_item_to_storage
@@ -56,20 +61,40 @@ async def ingest_data(
         node_set: Optional[List[str]] = None,
         dataset_id: UUID = None,
     ):
+        new_datapoints = []
+        existing_data_points = []
+        dataset_new_data_points = []
+
         if not isinstance(data, list):
             # Convert data to a list as we work with lists further down.
             data = [data]

-        file_paths = []
+        if dataset_id:
+            # Retrieve existing dataset
+            dataset = await get_specific_user_permission_datasets(user.id, "write", [dataset_id])
+
+            # Convert from list to Dataset element
+            if isinstance(dataset, list):
+                dataset = dataset[0]
+        else:
+            # Find existing dataset or create a new one
+            existing_datasets = await get_authorized_existing_datasets(
+                user=user, permission_type="write", datasets=[dataset_name]
+            )
+            dataset = await load_or_create_datasets(
+                dataset_names=[dataset_name],
+                existing_datasets=existing_datasets,
+                user=user,
+            )
+            if isinstance(dataset, list):
+                dataset = dataset[0]
+
+        dataset_data: list[Data] = await get_dataset_data(dataset.id)
+        dataset_data_map = {str(data.id): True for data in dataset_data}

+        # Process data
         for data_item in data:
             file_path = await save_data_item_to_storage(data_item, dataset_name)
-            file_paths.append(file_path)

             # Ingest data and add metadata
-            # with open(file_path.replace("file://", ""), mode="rb") as file:
             with open_data_file(file_path) as file:
                 classified_data = ingestion.classify(file, s3fs=fs)
@@ -80,90 +105,76 @@ async def ingest_data(
             from sqlalchemy import select
-            from cognee.modules.data.models import Data

             db_engine = get_relational_engine()

+            # Check to see if data should be updated
             async with db_engine.get_async_session() as session:
-                if dataset_id:
-                    # Retrieve existing dataset
-                    dataset = await get_specific_user_permission_datasets(
-                        user.id, "write", [dataset_id]
-                    )
-                    # Convert from list to Dataset element
-                    if isinstance(dataset, list):
-                        dataset = dataset[0]
-                        dataset = await session.merge(
-                            dataset
-                        )  # Add found dataset object into current session
-                else:
-                    # Create new one
-                    dataset = await create_dataset(dataset_name, user, session)
-
-                # Check to see if data should be updated
                 data_point = (
                     await session.execute(select(Data).filter(Data.id == data_id))
                 ).scalar_one_or_none()

                 ext_metadata = get_external_metadata_dict(data_item)

                 if node_set:
                     ext_metadata["node_set"] = node_set

                 if data_point is not None:
                     data_point.name = file_metadata["name"]
                     data_point.raw_data_location = file_metadata["file_path"]
                     data_point.extension = file_metadata["extension"]
                     data_point.mime_type = file_metadata["mime_type"]
                     data_point.owner_id = user.id
                     data_point.content_hash = file_metadata["content_hash"]
                     data_point.external_metadata = ext_metadata
                     data_point.node_set = json.dumps(node_set) if node_set else None
-                    await session.merge(data_point)
-                else:
-                    data_point = Data(
-                        id=data_id,
-                        name=file_metadata["name"],
-                        raw_data_location=file_metadata["file_path"],
-                        extension=file_metadata["extension"],
-                        mime_type=file_metadata["mime_type"],
-                        owner_id=user.id,
-                        content_hash=file_metadata["content_hash"],
-                        external_metadata=ext_metadata,
-                        node_set=json.dumps(node_set) if node_set else None,
-                        token_count=-1,
-                    )
-                    session.add(data_point)
-
-                # Check if data is already in dataset
-                dataset_data = (
-                    await session.execute(
-                        select(DatasetData).filter(
-                            DatasetData.data_id == data_id, DatasetData.dataset_id == dataset.id
-                        )
-                    )
-                ).scalar_one_or_none()
-                # If data is not present in dataset add it
-                if dataset_data is None:
-                    dataset.data.append(data_point)
-                    await session.merge(dataset)
-
-                await session.commit()
-
-        await give_permission_on_dataset(user, dataset.id, "read")
-        await give_permission_on_dataset(user, dataset.id, "write")
-        await give_permission_on_dataset(user, dataset.id, "delete")
-        await give_permission_on_dataset(user, dataset.id, "share")
-
-        return file_paths
+
+                    # Check if data is already in dataset
+                    if str(data_point.id) in dataset_data_map:
+                        existing_data_points.append(data_point)
+                    else:
+                        dataset_new_data_points.append(data_point)
+                        dataset_data_map[str(data_point.id)] = True
+                else:
+                    if str(data_id) in dataset_data_map:
+                        continue
+
+                    data_point = Data(
+                        id=data_id,
+                        name=file_metadata["name"],
+                        raw_data_location=file_metadata["file_path"],
+                        extension=file_metadata["extension"],
+                        mime_type=file_metadata["mime_type"],
+                        owner_id=user.id,
+                        content_hash=file_metadata["content_hash"],
+                        external_metadata=ext_metadata,
+                        node_set=json.dumps(node_set) if node_set else None,
+                        token_count=-1,
+                    )
+
+                    new_datapoints.append(data_point)
+                    dataset_data_map[str(data_point.id)] = True
+
+        async with db_engine.get_async_session() as session:
+            if dataset not in session:
+                session.add(dataset)
+
+            if len(new_datapoints) > 0:
+                session.add_all(new_datapoints)
+                dataset.data.extend(new_datapoints)
+
+            if len(existing_data_points) > 0:
+                for data_point in existing_data_points:
+                    await session.merge(data_point)
+
+            if len(dataset_new_data_points) > 0:
+                for data_point in dataset_new_data_points:
+                    await session.merge(data_point)
+                dataset.data.extend(dataset_new_data_points)
+
+            await session.merge(dataset)
+            await session.commit()

-    await store_data_to_dataset(data, dataset_name, user, node_set, dataset_id)
-
-    datasets = await get_datasets_by_name(dataset_name, user.id)
-
-    # In case no files were processed no dataset will be created
-    if datasets:
-        dataset = datasets[0]
-        data_documents = await get_dataset_data(dataset_id=dataset.id)
-        return data_documents
-    return []
+        return existing_data_points + dataset_new_data_points + new_datapoints
+
+    return await store_data_to_dataset(data, dataset_name, user, node_set, dataset_id)
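
The rewrite replaces per-item dataset lookups and per-item commits with one membership map built up front (dataset_data_map), three buckets (brand-new rows, rows already in the dataset, rows that only need linking), and a single session that writes everything at the end. A stripped-down sketch of that classify-then-commit-once idea using plain dicts instead of the ORM models; all names are illustrative:

def classify_items(items, known_rows, dataset_member_ids):
    """Split a batch the way the rewritten ingest_data does.

    known_rows         -- ids that already have a Data row in the database
    dataset_member_ids -- ids already linked to the target dataset
    """
    new_rows = []          # need a Data row created
    update_in_place = []   # row exists and is already in the dataset
    link_to_dataset = []   # row exists but is not yet in the dataset
    member_ids = set(dataset_member_ids)

    for item in items:
        item_id = str(item["id"])
        if item_id in known_rows:
            if item_id in member_ids:
                update_in_place.append(item)
            else:
                link_to_dataset.append(item)
                member_ids.add(item_id)
        else:
            if item_id in member_ids:
                continue  # duplicate within this batch, already handled
            new_rows.append(item)
            member_ids.add(item_id)

    return new_rows, update_in_place, link_to_dataset


if __name__ == "__main__":
    items = [{"id": 1}, {"id": 2}, {"id": 3}]
    new, updated, linked = classify_items(items, known_rows={"1", "2"}, dataset_member_ids={"1"})
    print(new, updated, linked)  # [{'id': 3}] [{'id': 1}] [{'id': 2}]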

View file

@@ -23,7 +23,7 @@ async def main():
     await add_nodes_and_edges_queue.clear.aio()

     number_of_graph_saving_workers = 1  # Total number of graph_saving_worker to spawn
-    number_of_data_point_saving_workers = 2  # Total number of graph_saving_worker to spawn
+    number_of_data_point_saving_workers = 5  # Total number of graph_saving_worker to spawn

     results = []
     consumer_futures = []
@@ -44,7 +44,8 @@ async def main():
         worker_future = data_point_saving_worker.spawn()
         consumer_futures.append(worker_future)

-    s3_bucket_name = "s3://s3-test-laszlo/Database for KG v1"
+    # s3_bucket_name = "s3://s3-test-laszlo/Database for KG v1"
+    s3_bucket_name = "s3://s3-test-laszlo/Pdf"

     await cognee.add(s3_bucket_name, dataset_name="s3-files")

View file

@@ -1,6 +1,7 @@
 import modal
 import asyncio
+from sqlalchemy.exc import OperationalError, DBAPIError
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

 from distributed.app import app
 from distributed.modal_image import image
@@ -13,6 +14,31 @@ from cognee.infrastructure.databases.vector import get_vector_engine

 logger = get_logger("data_point_saving_worker")

+class VectorDatabaseDeadlockError(Exception):
+    message = "A deadlock occurred while trying to add data points to the vector database."
+
+
+def is_deadlock_error(error):
+    # SQLAlchemy-wrapped asyncpg
+    try:
+        import asyncpg
+
+        if isinstance(error.orig, asyncpg.exceptions.DeadlockDetectedError):
+            return True
+    except ImportError:
+        pass
+
+    # PostgreSQL: SQLSTATE 40P01 = deadlock_detected
+    if hasattr(error.orig, "pgcode") and error.orig.pgcode == "40P01":
+        return True
+
+    # SQLite: It doesn't support real deadlocks but may simulate them as "database is locked"
+    if "database is locked" in str(error.orig).lower():
+        return True
+
+    return False
+

 @app.function(
     image=image,
     timeout=86400,
@@ -40,9 +66,24 @@ async def data_point_saving_worker():
                 print(f"Adding {len(data_points)} data points to '{collection_name}' collection.")

-                await vector_engine.create_data_points(
-                    collection_name, data_points, distributed=False
-                )
+                @retry(
+                    retry=retry_if_exception_type(VectorDatabaseDeadlockError),
+                    stop=stop_after_attempt(3),
+                    wait=wait_exponential(multiplier=2, min=1, max=6),
+                )
+                async def add_data_points():
+                    try:
+                        await vector_engine.create_data_points(
+                            collection_name, data_points, distributed=False
+                        )
+                    except DBAPIError as error:
+                        if is_deadlock_error(error):
+                            raise VectorDatabaseDeadlockError()
+                    except OperationalError as error:
+                        if is_deadlock_error(error):
+                            raise VectorDatabaseDeadlockError()
+
+                await add_data_points()

                 print("Finished adding data points.")

View file

@@ -1,6 +1,6 @@
 import modal
 import asyncio
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

 from distributed.app import app
 from distributed.modal_image import image
@@ -8,11 +8,35 @@ from distributed.queues import add_nodes_and_edges_queue
 from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.graph import get_graph_engine
+from cognee.infrastructure.databases.graph.config import get_graph_config

 logger = get_logger("graph_saving_worker")

+class GraphDatabaseDeadlockError(Exception):
+    message = "A deadlock occurred while trying to add data points to the vector database."
+
+
+def is_deadlock_error(error):
+    graph_config = get_graph_config()
+
+    if graph_config.graph_database_provider == "neo4j":
+        # Neo4j
+        from neo4j.exceptions import TransientError
+
+        if isinstance(error, TransientError) and (
+            error.code == "Neo.TransientError.Transaction.DeadlockDetected"
+        ):
+            return True
+
+    # Kuzu
+    if "deadlock" in str(error).lower() or "cannot acquire lock" in str(error).lower():
+        return True
+
+    return False
+

 @app.function(
     image=image,
     timeout=86400,
@@ -42,11 +66,36 @@ async def graph_saving_worker():
                 nodes = nodes_and_edges[0]
                 edges = nodes_and_edges[1]

+                @retry(
+                    retry=retry_if_exception_type(GraphDatabaseDeadlockError),
+                    stop=stop_after_attempt(3),
+                    wait=wait_exponential(multiplier=2, min=1, max=6),
+                )
+                async def save_graph_nodes(new_nodes):
+                    try:
+                        await graph_engine.add_nodes(new_nodes, distributed=False)
+                    except Exception as error:
+                        if is_deadlock_error(error):
+                            raise GraphDatabaseDeadlockError()
+
+                @retry(
+                    retry=retry_if_exception_type(GraphDatabaseDeadlockError),
+                    stop=stop_after_attempt(3),
+                    wait=wait_exponential(multiplier=2, min=1, max=6),
+                )
+                async def save_graph_edges(new_edges):
+                    try:
+                        await graph_engine.add_edges(new_edges, distributed=False)
+                    except Exception as error:
+                        if is_deadlock_error(error):
+                            raise GraphDatabaseDeadlockError()
+
                 if nodes:
-                    await graph_engine.add_nodes(nodes, distributed=False)
+                    await save_graph_nodes(nodes)

                 if edges:
-                    await graph_engine.add_edges(edges, distributed=False)
+                    await save_graph_edges(edges)

                 print("Finished adding nodes and edges.")
             else:
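
The graph worker classifies deadlocks per provider: Neo4j exposes a dedicated transient error code, while other engines such as Kuzu are matched on the message text. A small self-contained sketch of that split; FakeTransientError is a stand-in for neo4j.exceptions.TransientError:

class FakeTransientError(Exception):
    """Stands in for neo4j.exceptions.TransientError, which carries a .code string."""

    def __init__(self, code):
        self.code = code


def classify_deadlock(error, provider: str) -> bool:
    # Neo4j reports deadlocks with a dedicated transient error code.
    if provider == "neo4j":
        return getattr(error, "code", None) == "Neo.TransientError.Transaction.DeadlockDetected"
    # Other providers are matched on the message text.
    return "deadlock" in str(error).lower() or "cannot acquire lock" in str(error).lower()


if __name__ == "__main__":
    neo4j_error = FakeTransientError("Neo.TransientError.Transaction.DeadlockDetected")
    print(classify_deadlock(neo4j_error, "neo4j"))                                # True
    print(classify_deadlock(RuntimeError("Cannot acquire lock on table"), "kuzu"))  # True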