diff --git a/cognee/api/v1/cognify/cognify.py b/cognee/api/v1/cognify/cognify.py index bed200e13..1df9560ae 100644 --- a/cognee/api/v1/cognify/cognify.py +++ b/cognee/api/v1/cognify/cognify.py @@ -296,13 +296,13 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's extract_graph_from_data, graph_model=graph_model, ontology_adapter=OntologyResolver(ontology_file=ontology_file_path), - task_config={"batch_size": 10}, + task_config={"batch_size": 25}, ), # Generate knowledge graphs from the document chunks. Task( summarize_text, - task_config={"batch_size": 10}, + task_config={"batch_size": 25}, ), - Task(add_data_points, task_config={"batch_size": 10}), + Task(add_data_points, task_config={"batch_size": 25}), ] return default_tasks diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index 530fceb66..b5ab3ca40 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.dialects.postgresql import insert from sqlalchemy import JSON, Column, Table, select, delete, MetaData from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker -from tenacity import retry, retry_if_exception_type, stop_after_attempt +from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationError from cognee.exceptions import InvalidValueError @@ -113,9 +113,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): return False @retry( - retry=retry_if_exception_type(Union[DuplicateTableError, UniqueViolationError]), + retry=retry_if_exception_type((DuplicateTableError, UniqueViolationError)), stop=stop_after_attempt(3), - sleep=1, + wait=wait_exponential(multiplier=2, min=1, max=6), ) async def create_collection(self, collection_name: str, payload_schema=None): data_point_types = get_type_hints(DataPoint) @@ -159,7 +159,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): @retry( retry=retry_if_exception_type(DeadlockDetectedError), stop=stop_after_attempt(3), - sleep=1, + wait=wait_exponential(multiplier=2, min=1, max=6), ) @override_distributed(queued_add_data_points) async def create_data_points(self, collection_name: str, data_points: List[DataPoint]): @@ -204,26 +204,26 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): for data_index, data_point in enumerate(data_points): # Check to see if data should be updated or a new data item should be created - data_point_db = ( - await session.execute( - select(PGVectorDataPoint).filter(PGVectorDataPoint.id == data_point.id) - ) - ).scalar_one_or_none() + # data_point_db = ( + # await session.execute( + # select(PGVectorDataPoint).filter(PGVectorDataPoint.id == data_point.id) + # ) + # ).scalar_one_or_none() # If data point exists update it, if not create a new one - if data_point_db: - data_point_db.id = data_point.id - data_point_db.vector = data_vectors[data_index] - data_point_db.payload = serialize_data(data_point.model_dump()) - pgvector_data_points.append(data_point_db) - else: - pgvector_data_points.append( - PGVectorDataPoint( - id=data_point.id, - vector=data_vectors[data_index], - payload=serialize_data(data_point.model_dump()), - ) + # if data_point_db: + # data_point_db.id = data_point.id + # data_point_db.vector = 
data_vectors[data_index] + # data_point_db.payload = serialize_data(data_point.model_dump()) + # pgvector_data_points.append(data_point_db) + # else: + pgvector_data_points.append( + PGVectorDataPoint( + id=data_point.id, + vector=data_vectors[data_index], + payload=serialize_data(data_point.model_dump()), ) + ) def to_dict(obj): return { diff --git a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py index 8b009bcea..00a6a0411 100644 --- a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py +++ b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py @@ -1,11 +1,16 @@ from typing import List, Optional +from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential + from cognee.shared.logging_utils import get_logger from cognee.exceptions import InvalidValueError from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine.utils import parse_id from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError +from distributed.utils import override_distributed +from distributed.tasks.queued_add_data_points import queued_add_data_points + from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..models.ScoredResult import ScoredResult from ..vector_db_interface import VectorDBInterface @@ -13,6 +18,18 @@ from ..vector_db_interface import VectorDBInterface logger = get_logger("WeaviateAdapter") +def is_retryable_request(error): + from weaviate.exceptions import UnexpectedStatusCodeException + from requests.exceptions import RequestException + + if isinstance(error, UnexpectedStatusCodeException): + # Retry on conflict, service unavailable, internal error + return error.status_code in {409, 503, 500} + if isinstance(error, RequestException): + return True # Includes timeout, connection error, etc. + return False + + class IndexSchema(DataPoint): """ Define a schema for indexing data points with textual content. @@ -124,6 +141,11 @@ class WeaviateAdapter(VectorDBInterface): client = await self.get_client() return await client.collections.exists(collection_name) + @retry( + retry=retry_if_exception(is_retryable_request), + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=2, min=1, max=6), + ) async def create_collection( self, collection_name: str, @@ -184,6 +206,12 @@ class WeaviateAdapter(VectorDBInterface): client = await self.get_client() return client.collections.get(collection_name) + @retry( + retry=retry_if_exception(is_retryable_request), + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=2, min=1, max=6), + ) + @override_distributed(queued_add_data_points) async def create_data_points(self, collection_name: str, data_points: List[DataPoint]): """ Create or update data points in the specified collection in the Weaviate database. 
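A quick note on the retry changes above: tenacity's `retry_if_exception_type` expects an exception class or a tuple of classes (a `typing.Union` fails the `isinstance` check at retry time), and the `retry` decorator takes its delay through `wait=` rather than a `sleep=` argument, which is why the decorators are rewritten here. For Weaviate, where transient failures surface as HTTP status codes instead of distinct exception classes, the patch switches to `retry_if_exception` with a predicate. Since the same stop/wait policy now appears in several adapters, it could be centralized; a minimal sketch of such a helper (hypothetical, not part of the patch):

```python
from asyncpg import DuplicateTableError, UniqueViolationError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential


# Hypothetical shared helper so the adapters keep an identical backoff policy:
# up to 3 attempts with a short exponential wait capped at 6 seconds.
def db_retry(*exception_types):
    return retry(
        retry=retry_if_exception_type(tuple(exception_types)),
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=2, min=1, max=6),
    )


# Usage, mirroring PGVectorAdapter.create_collection:
# @db_retry(DuplicateTableError, UniqueViolationError)
# async def create_collection(self, collection_name: str, payload_schema=None): ...
```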
diff --git a/cognee/modules/data/methods/load_or_create_datasets.py b/cognee/modules/data/methods/load_or_create_datasets.py index c2dfc201d..43cd0513c 100644 --- a/cognee/modules/data/methods/load_or_create_datasets.py +++ b/cognee/modules/data/methods/load_or_create_datasets.py @@ -6,6 +6,7 @@ from cognee.modules.data.models import Dataset from cognee.modules.data.methods import create_dataset from cognee.modules.data.methods import get_unique_dataset_id from cognee.modules.data.exceptions import DatasetNotFoundError +from cognee.modules.users.permissions.methods import give_permission_on_dataset async def load_or_create_datasets( @@ -45,6 +46,11 @@ async def load_or_create_datasets( async with db_engine.get_async_session() as session: await create_dataset(identifier, user, session) + await give_permission_on_dataset(user, new_dataset.id, "read") + await give_permission_on_dataset(user, new_dataset.id, "write") + await give_permission_on_dataset(user, new_dataset.id, "delete") + await give_permission_on_dataset(user, new_dataset.id, "share") + result.append(new_dataset) return result diff --git a/cognee/modules/pipelines/operations/run_tasks_distributed.py b/cognee/modules/pipelines/operations/run_tasks_distributed.py index c0ec3cc93..62ba2184f 100644 --- a/cognee/modules/pipelines/operations/run_tasks_distributed.py +++ b/cognee/modules/pipelines/operations/run_tasks_distributed.py @@ -27,7 +27,7 @@ if modal: @app.function( image=image, timeout=86400, - max_containers=100, + max_containers=50, secrets=[modal.Secret.from_name("distributed_cognee")], ) async def run_tasks_on_modal(tasks, data_item, user, pipeline_name, context): diff --git a/cognee/modules/users/permissions/methods/give_permission_on_dataset.py b/cognee/modules/users/permissions/methods/give_permission_on_dataset.py index daa6aae6c..0ed536981 100644 --- a/cognee/modules/users/permissions/methods/give_permission_on_dataset.py +++ b/cognee/modules/users/permissions/methods/give_permission_on_dataset.py @@ -1,7 +1,7 @@ from uuid import UUID from sqlalchemy.future import select -from asyncpg import UniqueViolationError -from tenacity import retry, retry_if_exception_type, stop_after_attempt +from sqlalchemy.exc import IntegrityError +from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential from cognee.infrastructure.databases.relational import get_relational_engine from cognee.modules.users.permissions import PERMISSION_TYPES @@ -10,10 +10,14 @@ from cognee.modules.users.exceptions import PermissionNotFoundError from ...models import Principal, ACL, Permission +class GivePermissionOnDatasetError(Exception): + message: str = "Failed to give permission on dataset" + + @retry( - retry=retry_if_exception_type(UniqueViolationError), + retry=retry_if_exception_type(GivePermissionOnDatasetError), stop=stop_after_attempt(3), - sleep=1, + wait=wait_exponential(multiplier=2, min=1, max=6), ) async def give_permission_on_dataset( principal: Principal, @@ -50,6 +54,11 @@ async def give_permission_on_dataset( # If no existing ACL entry is found, proceed to add a new one if existing_acl is None: - acl = ACL(principal_id=principal.id, dataset_id=dataset_id, permission=permission) - session.add(acl) - await session.commit() + try: + acl = ACL(principal_id=principal.id, dataset_id=dataset_id, permission=permission) + session.add(acl) + await session.commit() + except IntegrityError: + session.rollback() + + raise GivePermissionOnDatasetError() diff --git a/cognee/tasks/ingestion/ingest_data.py 
b/cognee/tasks/ingestion/ingest_data.py index 9fd45434a..0dae5412b 100644 --- a/cognee/tasks/ingestion/ingest_data.py +++ b/cognee/tasks/ingestion/ingest_data.py @@ -2,14 +2,19 @@ import json import inspect from uuid import UUID from typing import Union, BinaryIO, Any, List, Optional + import cognee.modules.ingestion as ingestion from cognee.infrastructure.databases.relational import get_relational_engine -from cognee.modules.data.methods import create_dataset, get_dataset_data, get_datasets_by_name -from cognee.modules.users.methods import get_default_user -from cognee.modules.data.models.DatasetData import DatasetData +from cognee.modules.data.models import Data from cognee.modules.users.models import User -from cognee.modules.users.permissions.methods import give_permission_on_dataset +from cognee.modules.users.methods import get_default_user from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets +from cognee.modules.data.methods import ( + get_authorized_existing_datasets, + get_dataset_data, + load_or_create_datasets, +) + from .save_data_item_to_storage import save_data_item_to_storage @@ -56,20 +61,40 @@ async def ingest_data( node_set: Optional[List[str]] = None, dataset_id: UUID = None, ): + new_datapoints = [] + existing_data_points = [] + dataset_new_data_points = [] + if not isinstance(data, list): # Convert data to a list as we work with lists further down. data = [data] - file_paths = [] + if dataset_id: + # Retrieve existing dataset + dataset = await get_specific_user_permission_datasets(user.id, "write", [dataset_id]) + # Convert from list to Dataset element + if isinstance(dataset, list): + dataset = dataset[0] + else: + # Find existing dataset or create a new one + existing_datasets = await get_authorized_existing_datasets( + user=user, permission_type="write", datasets=[dataset_name] + ) + dataset = await load_or_create_datasets( + dataset_names=[dataset_name], + existing_datasets=existing_datasets, + user=user, + ) + if isinstance(dataset, list): + dataset = dataset[0] + + dataset_data: list[Data] = await get_dataset_data(dataset.id) + dataset_data_map = {str(data.id): True for data in dataset_data} - # Process data for data_item in data: file_path = await save_data_item_to_storage(data_item, dataset_name) - file_paths.append(file_path) - # Ingest data and add metadata - # with open(file_path.replace("file://", ""), mode="rb") as file: with open_data_file(file_path) as file: classified_data = ingestion.classify(file, s3fs=fs) @@ -80,90 +105,76 @@ async def ingest_data( from sqlalchemy import select - from cognee.modules.data.models import Data - db_engine = get_relational_engine() + # Check to see if data should be updated async with db_engine.get_async_session() as session: - if dataset_id: - # Retrieve existing dataset - dataset = await get_specific_user_permission_datasets( - user.id, "write", [dataset_id] - ) - # Convert from list to Dataset element - if isinstance(dataset, list): - dataset = dataset[0] - - dataset = await session.merge( - dataset - ) # Add found dataset object into current session - else: - # Create new one - dataset = await create_dataset(dataset_name, user, session) - - # Check to see if data should be updated data_point = ( await session.execute(select(Data).filter(Data.id == data_id)) ).scalar_one_or_none() - ext_metadata = get_external_metadata_dict(data_item) - if node_set: - ext_metadata["node_set"] = node_set + ext_metadata = get_external_metadata_dict(data_item) - if data_point is not None: - 
data_point.name = file_metadata["name"] - data_point.raw_data_location = file_metadata["file_path"] - data_point.extension = file_metadata["extension"] - data_point.mime_type = file_metadata["mime_type"] - data_point.owner_id = user.id - data_point.content_hash = file_metadata["content_hash"] - data_point.external_metadata = ext_metadata - data_point.node_set = json.dumps(node_set) if node_set else None - await session.merge(data_point) - else: - data_point = Data( - id=data_id, - name=file_metadata["name"], - raw_data_location=file_metadata["file_path"], - extension=file_metadata["extension"], - mime_type=file_metadata["mime_type"], - owner_id=user.id, - content_hash=file_metadata["content_hash"], - external_metadata=ext_metadata, - node_set=json.dumps(node_set) if node_set else None, - token_count=-1, - ) - session.add(data_point) + if node_set: + ext_metadata["node_set"] = node_set + + if data_point is not None: + data_point.name = file_metadata["name"] + data_point.raw_data_location = file_metadata["file_path"] + data_point.extension = file_metadata["extension"] + data_point.mime_type = file_metadata["mime_type"] + data_point.owner_id = user.id + data_point.content_hash = file_metadata["content_hash"] + data_point.external_metadata = ext_metadata + data_point.node_set = json.dumps(node_set) if node_set else None # Check if data is already in dataset - dataset_data = ( - await session.execute( - select(DatasetData).filter( - DatasetData.data_id == data_id, DatasetData.dataset_id == dataset.id - ) - ) - ).scalar_one_or_none() - # If data is not present in dataset add it - if dataset_data is None: - dataset.data.append(data_point) - await session.merge(dataset) + if str(data_point.id) in dataset_data_map: + existing_data_points.append(data_point) + else: + dataset_new_data_points.append(data_point) + dataset_data_map[str(data_point.id)] = True + else: + if str(data_id) in dataset_data_map: + continue - await session.commit() + data_point = Data( + id=data_id, + name=file_metadata["name"], + raw_data_location=file_metadata["file_path"], + extension=file_metadata["extension"], + mime_type=file_metadata["mime_type"], + owner_id=user.id, + content_hash=file_metadata["content_hash"], + external_metadata=ext_metadata, + node_set=json.dumps(node_set) if node_set else None, + token_count=-1, + ) - await give_permission_on_dataset(user, dataset.id, "read") - await give_permission_on_dataset(user, dataset.id, "write") - await give_permission_on_dataset(user, dataset.id, "delete") - await give_permission_on_dataset(user, dataset.id, "share") + new_datapoints.append(data_point) + dataset_data_map[str(data_point.id)] = True - return file_paths + async with db_engine.get_async_session() as session: + if dataset not in session: + session.add(dataset) - await store_data_to_dataset(data, dataset_name, user, node_set, dataset_id) + if len(new_datapoints) > 0: + session.add_all(new_datapoints) + dataset.data.extend(new_datapoints) - datasets = await get_datasets_by_name(dataset_name, user.id) + if len(existing_data_points) > 0: + for data_point in existing_data_points: + await session.merge(data_point) - # In case no files were processed no dataset will be created - if datasets: - dataset = datasets[0] - data_documents = await get_dataset_data(dataset_id=dataset.id) - return data_documents - return [] + if len(dataset_new_data_points) > 0: + for data_point in dataset_new_data_points: + await session.merge(data_point) + dataset.data.extend(dataset_new_data_points) + + await session.merge(dataset) + + await 
session.commit() + + return existing_data_points + dataset_new_data_points + new_datapoints + + return await store_data_to_dataset(data, dataset_name, user, node_set, dataset_id) diff --git a/distributed/entrypoint.py b/distributed/entrypoint.py index 7b6f56911..38b78d9dc 100644 --- a/distributed/entrypoint.py +++ b/distributed/entrypoint.py @@ -23,7 +23,7 @@ async def main(): await add_nodes_and_edges_queue.clear.aio() number_of_graph_saving_workers = 1 # Total number of graph_saving_worker to spawn - number_of_data_point_saving_workers = 2 # Total number of graph_saving_worker to spawn + number_of_data_point_saving_workers = 5 # Total number of graph_saving_worker to spawn results = [] consumer_futures = [] @@ -44,7 +44,8 @@ async def main(): worker_future = data_point_saving_worker.spawn() consumer_futures.append(worker_future) - s3_bucket_name = "s3://s3-test-laszlo/Database for KG v1" + # s3_bucket_name = "s3://s3-test-laszlo/Database for KG v1" + s3_bucket_name = "s3://s3-test-laszlo/Pdf" await cognee.add(s3_bucket_name, dataset_name="s3-files") diff --git a/distributed/workers/data_point_saving_worker.py b/distributed/workers/data_point_saving_worker.py index 4c85f7f32..5d8935b95 100644 --- a/distributed/workers/data_point_saving_worker.py +++ b/distributed/workers/data_point_saving_worker.py @@ -1,6 +1,7 @@ import modal import asyncio - +from sqlalchemy.exc import OperationalError, DBAPIError +from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential from distributed.app import app from distributed.modal_image import image @@ -13,6 +14,31 @@ from cognee.infrastructure.databases.vector import get_vector_engine logger = get_logger("data_point_saving_worker") +class VectorDatabaseDeadlockError(Exception): + message = "A deadlock occurred while trying to add data points to the vector database." 
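One small observation about the exception classes this patch introduces (`GivePermissionOnDatasetError` earlier, `VectorDatabaseDeadlockError` above, and `GraphDatabaseDeadlockError` below): a bare `message` class attribute is never passed to `Exception.__init__`, so it does not show up in `str(error)` or in logged tracebacks. If the text is meant to be visible, it can be wired through the constructor while keeping the attribute; a minimal sketch:

```python
class VectorDatabaseDeadlockError(Exception):
    """Raised when the vector store reports a deadlock while saving data points."""

    def __init__(
        self,
        message: str = "A deadlock occurred while trying to add data points to the vector database.",
    ):
        # Keep the attribute for callers that read it, but also feed it to Exception
        # so it appears in str(error) and in tracebacks.
        self.message = message
        super().__init__(message)
```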
+
+
+def is_deadlock_error(error):
+    # SQLAlchemy-wrapped asyncpg
+    try:
+        import asyncpg
+
+        if isinstance(error.orig, asyncpg.exceptions.DeadlockDetectedError):
+            return True
+    except ImportError:
+        pass
+
+    # PostgreSQL: SQLSTATE 40P01 = deadlock_detected
+    if hasattr(error.orig, "pgcode") and error.orig.pgcode == "40P01":
+        return True
+
+    # SQLite: It doesn't support real deadlocks but may simulate them as "database is locked"
+    if "database is locked" in str(error.orig).lower():
+        return True
+
+    return False
+
+
 @app.function(
     image=image,
     timeout=86400,
@@ -40,9 +66,24 @@ async def data_point_saving_worker():
 
             print(f"Adding {len(data_points)} data points to '{collection_name}' collection.")
 
-            await vector_engine.create_data_points(
-                collection_name, data_points, distributed=False
+            @retry(
+                retry=retry_if_exception_type(VectorDatabaseDeadlockError),
+                stop=stop_after_attempt(3),
+                wait=wait_exponential(multiplier=2, min=1, max=6),
             )
+            async def add_data_points():
+                try:
+                    await vector_engine.create_data_points(
+                        collection_name, data_points, distributed=False
+                    )
+                except DBAPIError as error:
+                    if is_deadlock_error(error):
+                        raise VectorDatabaseDeadlockError()
+                except OperationalError as error:
+                    if is_deadlock_error(error):
+                        raise VectorDatabaseDeadlockError()
+
+            await add_data_points()
 
             print("Finished adding data points.")
 
diff --git a/distributed/workers/graph_saving_worker.py b/distributed/workers/graph_saving_worker.py
index 958ba4727..6525595cf 100644
--- a/distributed/workers/graph_saving_worker.py
+++ b/distributed/workers/graph_saving_worker.py
@@ -1,6 +1,6 @@
 import modal
 import asyncio
-
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
 from distributed.app import app
 from distributed.modal_image import image
@@ -8,11 +8,35 @@
 from distributed.queues import add_nodes_and_edges_queue
 
 from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.graph import get_graph_engine
+from cognee.infrastructure.databases.graph.config import get_graph_config
 
 logger = get_logger("graph_saving_worker")
 
 
+class GraphDatabaseDeadlockError(Exception):
+    message = "A deadlock occurred while trying to add nodes and edges to the graph database."
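Both workers now follow the same shape: call the engine, classify the caught error with a backend-specific deadlock predicate, translate it into a dedicated exception, and let tenacity retry with exponential backoff. If the pattern keeps spreading, it could be factored into one shared decorator; a hypothetical sketch with simplified names (not part of this patch):

```python
import functools

from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential


class RetryableDeadlockError(Exception):
    """Marker for a backend deadlock that is safe to retry."""


def with_deadlock_retry(is_deadlock, attempts: int = 3):
    """Retry an async callable whenever `is_deadlock(error)` says the failure is a deadlock."""

    def decorator(func):
        @retry(
            retry=retry_if_exception_type(RetryableDeadlockError),
            stop=stop_after_attempt(attempts),
            wait=wait_exponential(multiplier=2, min=1, max=6),
            reraise=True,
        )
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                return await func(*args, **kwargs)
            except Exception as error:
                if is_deadlock(error):
                    raise RetryableDeadlockError() from error
                raise  # non-deadlock failures propagate unchanged

        return wrapper

    return decorator


# Usage sketch:
# save_nodes = with_deadlock_retry(is_deadlock_error)(graph_engine.add_nodes)
# await save_nodes(nodes, distributed=False)
```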
+
+
+def is_deadlock_error(error):
+    graph_config = get_graph_config()
+
+    if graph_config.graph_database_provider == "neo4j":
+        # Neo4j
+        from neo4j.exceptions import TransientError
+
+        if isinstance(error, TransientError) and (
+            error.code == "Neo.TransientError.Transaction.DeadlockDetected"
+        ):
+            return True
+
+    # Kuzu
+    if "deadlock" in str(error).lower() or "cannot acquire lock" in str(error).lower():
+        return True
+
+    return False
+
+
 @app.function(
     image=image,
     timeout=86400,
@@ -42,11 +66,36 @@ async def graph_saving_worker():
 
             nodes = nodes_and_edges[0]
             edges = nodes_and_edges[1]
 
+            @retry(
+                retry=retry_if_exception_type(GraphDatabaseDeadlockError),
+                stop=stop_after_attempt(3),
+                wait=wait_exponential(multiplier=2, min=1, max=6),
+            )
+            async def save_graph_nodes(new_nodes):
+                try:
+                    await graph_engine.add_nodes(new_nodes, distributed=False)
+                except Exception as error:
+                    if is_deadlock_error(error):
+                        raise GraphDatabaseDeadlockError()
+
+            @retry(
+                retry=retry_if_exception_type(GraphDatabaseDeadlockError),
+                stop=stop_after_attempt(3),
+                wait=wait_exponential(multiplier=2, min=1, max=6),
+            )
+            async def save_graph_edges(new_edges):
+                try:
+                    await graph_engine.add_edges(new_edges, distributed=False)
+                except Exception as error:
+                    if is_deadlock_error(error):
+                        raise GraphDatabaseDeadlockError()
+
             if nodes:
-                await graph_engine.add_nodes(nodes, distributed=False)
+                await save_graph_nodes(nodes)
 
             if edges:
-                await graph_engine.add_edges(edges, distributed=False)
+                await save_graph_edges(edges)
+
+            print("Finished adding nodes and edges.")
         else:
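One behavioural detail in these wrappers is worth double-checking: when the caught exception is not classified as a deadlock, the `except` block falls through without re-raising, so a genuine failure is swallowed and the worker still prints a success message. (In `data_point_saving_worker` the separate `except OperationalError` branch is also effectively unreachable, since SQLAlchemy's `OperationalError` is a subclass of `DBAPIError`.) If the swallowing is not intentional, re-raising keeps the deadlock retry while letting other errors surface; a minimal variant of the node-saving wrapper:

```python
@retry(
    retry=retry_if_exception_type(GraphDatabaseDeadlockError),
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=2, min=1, max=6),
)
async def save_graph_nodes(new_nodes):
    try:
        await graph_engine.add_nodes(new_nodes, distributed=False)
    except Exception as error:
        if is_deadlock_error(error):
            raise GraphDatabaseDeadlockError() from error
        raise  # anything that is not a deadlock should still fail the batch
```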