fix: add error handling

parent f8f1bb3576
commit 685d282f5c

10 changed files with 265 additions and 120 deletions

@@ -296,13 +296,13 @@ async def get_default_tasks(  # TODO: Find out a better way to do this (Boris's
             extract_graph_from_data,
             graph_model=graph_model,
             ontology_adapter=OntologyResolver(ontology_file=ontology_file_path),
-            task_config={"batch_size": 10},
+            task_config={"batch_size": 25},
         ),  # Generate knowledge graphs from the document chunks.
         Task(
             summarize_text,
-            task_config={"batch_size": 10},
+            task_config={"batch_size": 25},
         ),
-        Task(add_data_points, task_config={"batch_size": 10}),
+        Task(add_data_points, task_config={"batch_size": 25}),
     ]

     return default_tasks

@@ -5,7 +5,7 @@ from sqlalchemy.orm import Mapped, mapped_column
 from sqlalchemy.dialects.postgresql import insert
 from sqlalchemy import JSON, Column, Table, select, delete, MetaData
 from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
-from tenacity import retry, retry_if_exception_type, stop_after_attempt
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
 from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationError

 from cognee.exceptions import InvalidValueError

@@ -113,9 +113,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         return False

     @retry(
-        retry=retry_if_exception_type(Union[DuplicateTableError, UniqueViolationError]),
+        retry=retry_if_exception_type((DuplicateTableError, UniqueViolationError)),
         stop=stop_after_attempt(3),
-        sleep=1,
+        wait=wait_exponential(multiplier=2, min=1, max=6),
     )
     async def create_collection(self, collection_name: str, payload_schema=None):
         data_point_types = get_type_hints(DataPoint)

@@ -159,7 +159,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
     @retry(
         retry=retry_if_exception_type(DeadlockDetectedError),
         stop=stop_after_attempt(3),
-        sleep=1,
+        wait=wait_exponential(multiplier=2, min=1, max=6),
     )
     @override_distributed(queued_add_data_points)
     async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):

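Note on the two PGVectorAdapter retry hunks above: the exception filter is now passed as a plain tuple and the back-off is expressed through tenacity's wait= argument instead of sleep=. A minimal, self-contained sketch of that pattern (the exception classes below are stand-ins for the asyncpg errors, not the adapter's code):

import asyncio

from tenacity import RetryError, retry, retry_if_exception_type, stop_after_attempt, wait_exponential


class DuplicateTableError(Exception):
    pass


class UniqueViolationError(Exception):
    pass


@retry(
    retry=retry_if_exception_type((DuplicateTableError, UniqueViolationError)),
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=2, min=1, max=6),  # back-off clamped between 1s and 6s
)
async def create_collection_stub():
    # Always failing here, so the decorator retries twice more and then gives up.
    raise UniqueViolationError("collection already exists")


async def main():
    try:
        await create_collection_stub()
    except RetryError:
        print("gave up after 3 attempts")


if __name__ == "__main__":
    asyncio.run(main())
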
@@ -204,26 +204,26 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):

             for data_index, data_point in enumerate(data_points):
                 # Check to see if data should be updated or a new data item should be created
-                data_point_db = (
-                    await session.execute(
-                        select(PGVectorDataPoint).filter(PGVectorDataPoint.id == data_point.id)
-                    )
-                ).scalar_one_or_none()
+                # data_point_db = (
+                #     await session.execute(
+                #         select(PGVectorDataPoint).filter(PGVectorDataPoint.id == data_point.id)
+                #     )
+                # ).scalar_one_or_none()

                 # If data point exists update it, if not create a new one
-                if data_point_db:
-                    data_point_db.id = data_point.id
-                    data_point_db.vector = data_vectors[data_index]
-                    data_point_db.payload = serialize_data(data_point.model_dump())
-                    pgvector_data_points.append(data_point_db)
-                else:
-                    pgvector_data_points.append(
-                        PGVectorDataPoint(
-                            id=data_point.id,
-                            vector=data_vectors[data_index],
-                            payload=serialize_data(data_point.model_dump()),
-                        )
-                    )
+                # if data_point_db:
+                #     data_point_db.id = data_point.id
+                #     data_point_db.vector = data_vectors[data_index]
+                #     data_point_db.payload = serialize_data(data_point.model_dump())
+                #     pgvector_data_points.append(data_point_db)
+                # else:
+                pgvector_data_points.append(
+                    PGVectorDataPoint(
+                        id=data_point.id,
+                        vector=data_vectors[data_index],
+                        payload=serialize_data(data_point.model_dump()),
+                    )
+                )

         def to_dict(obj):
             return {

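The hunk above drops the per-point existence check and always appends a fresh PGVectorDataPoint. If the same id can be ingested twice, one way to keep that write path idempotent is a PostgreSQL upsert through the insert construct this module already imports; a hedged sketch only, with an assumed table layout rather than the adapter's real model:

from sqlalchemy.dialects.postgresql import insert


async def upsert_data_points(session, table, rows):
    # rows: list of dicts shaped like {"id": ..., "vector": ..., "payload": ...} (assumed schema)
    statement = insert(table).values(rows)
    statement = statement.on_conflict_do_update(
        index_elements=[table.c.id],
        set_={"vector": statement.excluded.vector, "payload": statement.excluded.payload},
    )
    await session.execute(statement)
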
@@ -1,11 +1,16 @@
 from typing import List, Optional

+from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
+
 from cognee.shared.logging_utils import get_logger
 from cognee.exceptions import InvalidValueError
 from cognee.infrastructure.engine import DataPoint
 from cognee.infrastructure.engine.utils import parse_id
+from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError

+from distributed.utils import override_distributed
+from distributed.tasks.queued_add_data_points import queued_add_data_points

 from ..embeddings.EmbeddingEngine import EmbeddingEngine
 from ..models.ScoredResult import ScoredResult
 from ..vector_db_interface import VectorDBInterface

@@ -13,6 +18,18 @@ from ..vector_db_interface import VectorDBInterface
 logger = get_logger("WeaviateAdapter")


+def is_retryable_request(error):
+    from weaviate.exceptions import UnexpectedStatusCodeException
+    from requests.exceptions import RequestException
+
+    if isinstance(error, UnexpectedStatusCodeException):
+        # Retry on conflict, service unavailable, internal error
+        return error.status_code in {409, 503, 500}
+    if isinstance(error, RequestException):
+        return True  # Includes timeout, connection error, etc.
+    return False
+
+
 class IndexSchema(DataPoint):
     """
     Define a schema for indexing data points with textual content.

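The helper added above feeds tenacity's retry_if_exception, which takes a predicate receiving the raised exception; that is what allows retrying on specific Weaviate status codes rather than on exception types alone. A standalone sketch of the same shape, with an invented error class instead of the weaviate exceptions:

from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential


class FlakyServiceError(Exception):
    def __init__(self, status_code):
        super().__init__(f"status {status_code}")
        self.status_code = status_code


def is_retryable(error):
    # Retry on conflict, service unavailable, internal error; anything else fails fast.
    return isinstance(error, FlakyServiceError) and error.status_code in {409, 503, 500}


@retry(
    retry=retry_if_exception(is_retryable),
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=2, min=1, max=6),
)
def call_service(status_code):
    raise FlakyServiceError(status_code)  # 409/503/500 get retried, 404 raises immediately
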
@@ -124,6 +141,11 @@ class WeaviateAdapter(VectorDBInterface):
         client = await self.get_client()
         return await client.collections.exists(collection_name)

+    @retry(
+        retry=retry_if_exception(is_retryable_request),
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=2, min=1, max=6),
+    )
     async def create_collection(
         self,
         collection_name: str,

@@ -184,6 +206,12 @@ class WeaviateAdapter(VectorDBInterface):
         client = await self.get_client()
         return client.collections.get(collection_name)

+    @retry(
+        retry=retry_if_exception(is_retryable_request),
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=2, min=1, max=6),
+    )
+    @override_distributed(queued_add_data_points)
     async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
         """
         Create or update data points in the specified collection in the Weaviate database.

@@ -6,6 +6,7 @@ from cognee.modules.data.models import Dataset
 from cognee.modules.data.methods import create_dataset
 from cognee.modules.data.methods import get_unique_dataset_id
 from cognee.modules.data.exceptions import DatasetNotFoundError
+from cognee.modules.users.permissions.methods import give_permission_on_dataset


 async def load_or_create_datasets(

@@ -45,6 +46,11 @@ async def load_or_create_datasets(
             async with db_engine.get_async_session() as session:
                 await create_dataset(identifier, user, session)

+            await give_permission_on_dataset(user, new_dataset.id, "read")
+            await give_permission_on_dataset(user, new_dataset.id, "write")
+            await give_permission_on_dataset(user, new_dataset.id, "delete")
+            await give_permission_on_dataset(user, new_dataset.id, "share")
+
             result.append(new_dataset)

     return result

@@ -27,7 +27,7 @@ if modal:
     @app.function(
         image=image,
         timeout=86400,
-        max_containers=100,
+        max_containers=50,
         secrets=[modal.Secret.from_name("distributed_cognee")],
     )
     async def run_tasks_on_modal(tasks, data_item, user, pipeline_name, context):

@@ -1,7 +1,7 @@
 from uuid import UUID
 from sqlalchemy.future import select
-from asyncpg import UniqueViolationError
-from tenacity import retry, retry_if_exception_type, stop_after_attempt
+from sqlalchemy.exc import IntegrityError
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

 from cognee.infrastructure.databases.relational import get_relational_engine
 from cognee.modules.users.permissions import PERMISSION_TYPES

@@ -10,10 +10,14 @@ from cognee.modules.users.exceptions import PermissionNotFoundError
 from ...models import Principal, ACL, Permission


+class GivePermissionOnDatasetError(Exception):
+    message: str = "Failed to give permission on dataset"
+
+
 @retry(
-    retry=retry_if_exception_type(UniqueViolationError),
+    retry=retry_if_exception_type(GivePermissionOnDatasetError),
     stop=stop_after_attempt(3),
-    sleep=1,
+    wait=wait_exponential(multiplier=2, min=1, max=6),
 )
 async def give_permission_on_dataset(
     principal: Principal,

@@ -50,6 +54,11 @@ async def give_permission_on_dataset(

         # If no existing ACL entry is found, proceed to add a new one
         if existing_acl is None:
-            acl = ACL(principal_id=principal.id, dataset_id=dataset_id, permission=permission)
-            session.add(acl)
-            await session.commit()
+            try:
+                acl = ACL(principal_id=principal.id, dataset_id=dataset_id, permission=permission)
+                session.add(acl)
+                await session.commit()
+            except IntegrityError:
+                session.rollback()
+
+                raise GivePermissionOnDatasetError()

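In the two hunks above, a concurrent insert of the same ACL row surfaces as sqlalchemy's IntegrityError, is translated into the new GivePermissionOnDatasetError, and tenacity then re-runs the whole function; on the next attempt the existing-ACL query should see the row the competing writer created and return without inserting. A compact sketch of that translate-and-retry shape, using generic names rather than the cognee models:

from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential


class AlreadyExistsError(Exception):
    pass


SEEN = set()  # stands in for the unique constraint plus the existing-ACL check


@retry(
    retry=retry_if_exception_type(AlreadyExistsError),
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=2, min=1, max=6),
)
def grant(key):
    if key in SEEN:
        return "already granted"  # the retry sees what the first attempt (or a competitor) wrote
    SEEN.add(key)  # pretend a competing writer inserted the same row first
    raise AlreadyExistsError(key)  # the role IntegrityError is translated into above


print(grant("user:dataset:read"))  # first attempt raises, the retry returns "already granted"
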
@@ -2,14 +2,19 @@ import json
 import inspect
 from uuid import UUID
 from typing import Union, BinaryIO, Any, List, Optional

 import cognee.modules.ingestion as ingestion
 from cognee.infrastructure.databases.relational import get_relational_engine
-from cognee.modules.data.methods import create_dataset, get_dataset_data, get_datasets_by_name
-from cognee.modules.users.methods import get_default_user
 from cognee.modules.data.models.DatasetData import DatasetData
 from cognee.modules.data.models import Data
 from cognee.modules.users.models import User
-from cognee.modules.users.permissions.methods import give_permission_on_dataset
+from cognee.modules.users.methods import get_default_user
+from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets
+from cognee.modules.data.methods import (
+    get_authorized_existing_datasets,
+    get_dataset_data,
+    load_or_create_datasets,
+)

 from .save_data_item_to_storage import save_data_item_to_storage

@@ -56,20 +61,40 @@ async def ingest_data(
         node_set: Optional[List[str]] = None,
         dataset_id: UUID = None,
     ):
+        new_datapoints = []
+        existing_data_points = []
+        dataset_new_data_points = []
+
         if not isinstance(data, list):
             # Convert data to a list as we work with lists further down.
             data = [data]

         file_paths = []
+        if dataset_id:
+            # Retrieve existing dataset
+            dataset = await get_specific_user_permission_datasets(user.id, "write", [dataset_id])
+            # Convert from list to Dataset element
+            if isinstance(dataset, list):
+                dataset = dataset[0]
+        else:
+            # Find existing dataset or create a new one
+            existing_datasets = await get_authorized_existing_datasets(
+                user=user, permission_type="write", datasets=[dataset_name]
+            )
+            dataset = await load_or_create_datasets(
+                dataset_names=[dataset_name],
+                existing_datasets=existing_datasets,
+                user=user,
+            )
+            if isinstance(dataset, list):
+                dataset = dataset[0]
+
+        dataset_data: list[Data] = await get_dataset_data(dataset.id)
+        dataset_data_map = {str(data.id): True for data in dataset_data}

         # Process data
         for data_item in data:
             file_path = await save_data_item_to_storage(data_item, dataset_name)

             file_paths.append(file_path)

             # Ingest data and add metadata
             # with open(file_path.replace("file://", ""), mode="rb") as file:
             with open_data_file(file_path) as file:
                 classified_data = ingestion.classify(file, s3fs=fs)

@@ -80,90 +105,76 @@ async def ingest_data(

                 from sqlalchemy import select

                 from cognee.modules.data.models import Data

                 db_engine = get_relational_engine()

                 # Check to see if data should be updated
                 async with db_engine.get_async_session() as session:
-                    if dataset_id:
-                        # Retrieve existing dataset
-                        dataset = await get_specific_user_permission_datasets(
-                            user.id, "write", [dataset_id]
-                        )
-                        # Convert from list to Dataset element
-                        if isinstance(dataset, list):
-                            dataset = dataset[0]
-
-                        dataset = await session.merge(
-                            dataset
-                        )  # Add found dataset object into current session
-                    else:
-                        # Create new one
-                        dataset = await create_dataset(dataset_name, user, session)
-
                     # Check to see if data should be updated
                     data_point = (
                         await session.execute(select(Data).filter(Data.id == data_id))
                     ).scalar_one_or_none()

-                    ext_metadata = get_external_metadata_dict(data_item)
-                    if node_set:
-                        ext_metadata["node_set"] = node_set
-
-                    if data_point is not None:
-                        data_point.name = file_metadata["name"]
-                        data_point.raw_data_location = file_metadata["file_path"]
-                        data_point.extension = file_metadata["extension"]
-                        data_point.mime_type = file_metadata["mime_type"]
-                        data_point.owner_id = user.id
-                        data_point.content_hash = file_metadata["content_hash"]
-                        data_point.external_metadata = ext_metadata
-                        data_point.node_set = json.dumps(node_set) if node_set else None
-                        await session.merge(data_point)
-                    else:
-                        data_point = Data(
-                            id=data_id,
-                            name=file_metadata["name"],
-                            raw_data_location=file_metadata["file_path"],
-                            extension=file_metadata["extension"],
-                            mime_type=file_metadata["mime_type"],
-                            owner_id=user.id,
-                            content_hash=file_metadata["content_hash"],
-                            external_metadata=ext_metadata,
-                            node_set=json.dumps(node_set) if node_set else None,
-                            token_count=-1,
-                        )
-                        session.add(data_point)
-
-                    # Check if data is already in dataset
-                    dataset_data = (
-                        await session.execute(
-                            select(DatasetData).filter(
-                                DatasetData.data_id == data_id, DatasetData.dataset_id == dataset.id
-                            )
-                        )
-                    ).scalar_one_or_none()
-                    # If data is not present in dataset add it
-                    if dataset_data is None:
-                        dataset.data.append(data_point)
-                        await session.merge(dataset)
-
-                    await session.commit()
-
-                    await give_permission_on_dataset(user, dataset.id, "read")
-                    await give_permission_on_dataset(user, dataset.id, "write")
-                    await give_permission_on_dataset(user, dataset.id, "delete")
-                    await give_permission_on_dataset(user, dataset.id, "share")
-
-        return file_paths
-
-    await store_data_to_dataset(data, dataset_name, user, node_set, dataset_id)
-
-    datasets = await get_datasets_by_name(dataset_name, user.id)
-
-    # In case no files were processed no dataset will be created
-    if datasets:
-        dataset = datasets[0]
-        data_documents = await get_dataset_data(dataset_id=dataset.id)
-        return data_documents
-    return []
+                ext_metadata = get_external_metadata_dict(data_item)
+
+                if node_set:
+                    ext_metadata["node_set"] = node_set
+
+                if data_point is not None:
+                    data_point.name = file_metadata["name"]
+                    data_point.raw_data_location = file_metadata["file_path"]
+                    data_point.extension = file_metadata["extension"]
+                    data_point.mime_type = file_metadata["mime_type"]
+                    data_point.owner_id = user.id
+                    data_point.content_hash = file_metadata["content_hash"]
+                    data_point.external_metadata = ext_metadata
+                    data_point.node_set = json.dumps(node_set) if node_set else None
+
+                    if str(data_point.id) in dataset_data_map:
+                        existing_data_points.append(data_point)
+                    else:
+                        dataset_new_data_points.append(data_point)
+                        dataset_data_map[str(data_point.id)] = True
+                else:
+                    if str(data_id) in dataset_data_map:
+                        continue
+
+                    data_point = Data(
+                        id=data_id,
+                        name=file_metadata["name"],
+                        raw_data_location=file_metadata["file_path"],
+                        extension=file_metadata["extension"],
+                        mime_type=file_metadata["mime_type"],
+                        owner_id=user.id,
+                        content_hash=file_metadata["content_hash"],
+                        external_metadata=ext_metadata,
+                        node_set=json.dumps(node_set) if node_set else None,
+                        token_count=-1,
+                    )
+
+                    new_datapoints.append(data_point)
+                    dataset_data_map[str(data_point.id)] = True
+
+        async with db_engine.get_async_session() as session:
+            if dataset not in session:
+                session.add(dataset)
+
+            if len(new_datapoints) > 0:
+                session.add_all(new_datapoints)
+                dataset.data.extend(new_datapoints)
+
+            if len(existing_data_points) > 0:
+                for data_point in existing_data_points:
+                    await session.merge(data_point)
+
+            if len(dataset_new_data_points) > 0:
+                for data_point in dataset_new_data_points:
+                    await session.merge(data_point)
+                dataset.data.extend(dataset_new_data_points)
+
+            await session.merge(dataset)
+
+            await session.commit()
+
+        return existing_data_points + dataset_new_data_points + new_datapoints
+
+    return await store_data_to_dataset(data, dataset_name, user, node_set, dataset_id)

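The restructured ingest flow above stops committing once per file: every item is sorted into existing_data_points, dataset_new_data_points, or new_datapoints using a seen-map keyed by data id, and a single session at the end adds, merges, and commits everything. The bookkeeping itself is plain dictionary logic; a small sketch of just that classification step (names mirror the diff, everything else is simplified and no database is involved):

def classify(items, dataset_data_map):
    new_datapoints, existing_data_points, dataset_new_data_points = [], [], []
    for item in items:
        item_id = str(item["id"])
        if item.get("already_in_db"):            # stands in for the Data row lookup
            if item_id in dataset_data_map:      # row exists and is already linked to the dataset
                existing_data_points.append(item)
            else:                                # row exists but is new to this dataset
                dataset_new_data_points.append(item)
                dataset_data_map[item_id] = True
        else:
            if item_id in dataset_data_map:      # duplicate within this batch
                continue
            new_datapoints.append(item)          # brand new row and dataset link
            dataset_data_map[item_id] = True
    return existing_data_points, dataset_new_data_points, new_datapoints
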
@@ -23,7 +23,7 @@ async def main():
     await add_nodes_and_edges_queue.clear.aio()

     number_of_graph_saving_workers = 1  # Total number of graph_saving_worker to spawn
-    number_of_data_point_saving_workers = 2  # Total number of graph_saving_worker to spawn
+    number_of_data_point_saving_workers = 5  # Total number of graph_saving_worker to spawn

     results = []
     consumer_futures = []

@@ -44,7 +44,8 @@ async def main():
     worker_future = data_point_saving_worker.spawn()
     consumer_futures.append(worker_future)

-    s3_bucket_name = "s3://s3-test-laszlo/Database for KG v1"
+    # s3_bucket_name = "s3://s3-test-laszlo/Database for KG v1"
+    s3_bucket_name = "s3://s3-test-laszlo/Pdf"

     await cognee.add(s3_bucket_name, dataset_name="s3-files")

@@ -1,6 +1,7 @@
 import modal
 import asyncio

+from sqlalchemy.exc import OperationalError, DBAPIError
 from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

 from distributed.app import app
 from distributed.modal_image import image

@@ -13,6 +14,31 @@ from cognee.infrastructure.databases.vector import get_vector_engine
 logger = get_logger("data_point_saving_worker")


+class VectorDatabaseDeadlockError(Exception):
+    message = "A deadlock occurred while trying to add data points to the vector database."
+
+
+def is_deadlock_error(error):
+    # SQLAlchemy-wrapped asyncpg
+    try:
+        import asyncpg
+
+        if isinstance(error.orig, asyncpg.exceptions.DeadlockDetectedError):
+            return True
+    except ImportError:
+        pass
+
+    # PostgreSQL: SQLSTATE 40P01 = deadlock_detected
+    if hasattr(error.orig, "pgcode") and error.orig.pgcode == "40P01":
+        return True
+
+    # SQLite: It doesn't support real deadlocks but may simulate them as "database is locked"
+    if "database is locked" in str(error.orig).lower():
+        return True
+
+    return False
+
+
 @app.function(
     image=image,
     timeout=86400,

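SQLAlchemy's DBAPIError and OperationalError keep the original driver exception under .orig, which is what is_deadlock_error inspects (asyncpg's DeadlockDetectedError, the 40P01 SQLSTATE, or SQLite's "database is locked" message). A tiny illustration of that .orig convention with faked error classes, not a real session:

class FakeDriverError(Exception):
    pgcode = "40P01"  # PostgreSQL SQLSTATE for deadlock_detected


class FakeDBAPIError(Exception):
    def __init__(self, orig):
        super().__init__(str(orig))
        self.orig = orig  # SQLAlchemy stores the underlying driver exception here


error = FakeDBAPIError(FakeDriverError("deadlock detected"))
print(hasattr(error.orig, "pgcode") and error.orig.pgcode == "40P01")  # prints True
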
@@ -40,9 +66,24 @@ async def data_point_saving_worker():

             print(f"Adding {len(data_points)} data points to '{collection_name}' collection.")

-            await vector_engine.create_data_points(
-                collection_name, data_points, distributed=False
-            )
+            @retry(
+                retry=retry_if_exception_type(VectorDatabaseDeadlockError),
+                stop=stop_after_attempt(3),
+                wait=wait_exponential(multiplier=2, min=1, max=6),
+            )
+            async def add_data_points():
+                try:
+                    await vector_engine.create_data_points(
+                        collection_name, data_points, distributed=False
+                    )
+                except DBAPIError as error:
+                    if is_deadlock_error(error):
+                        raise VectorDatabaseDeadlockError()
+                except OperationalError as error:
+                    if is_deadlock_error(error):
+                        raise VectorDatabaseDeadlockError()
+
+            await add_data_points()

             print("Finished adding data points.")

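Defining add_data_points as a decorated inner coroutine keeps the retry scoped to the one call that can deadlock, while collection_name and data_points are captured from the enclosing scope. The same shape in isolation, with a fabricated flaky engine standing in for vector_engine.create_data_points:

import asyncio

from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential


class VectorDatabaseDeadlockError(Exception):
    pass


async def save_batch(collection_name, data_points, flaky_engine):
    @retry(
        retry=retry_if_exception_type(VectorDatabaseDeadlockError),
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=2, min=1, max=6),
    )
    async def add_data_points():
        try:
            await flaky_engine(collection_name, data_points)
        except RuntimeError as error:  # stands in for DBAPIError / OperationalError
            raise VectorDatabaseDeadlockError() from error

    await add_data_points()


async def main():
    calls = {"count": 0}

    async def flaky_engine(collection_name, data_points):
        calls["count"] += 1
        if calls["count"] < 2:
            raise RuntimeError("deadlock detected")  # first call fails, second succeeds

    await save_batch("test_collection", [1, 2, 3], flaky_engine)
    print(f"succeeded after {calls['count']} calls")


if __name__ == "__main__":
    asyncio.run(main())
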
@@ -1,6 +1,6 @@
 import modal
 import asyncio

 from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

 from distributed.app import app
 from distributed.modal_image import image

@@ -8,11 +8,35 @@ from distributed.queues import add_nodes_and_edges_queue

 from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.graph import get_graph_engine
+from cognee.infrastructure.databases.graph.config import get_graph_config


 logger = get_logger("graph_saving_worker")


+class GraphDatabaseDeadlockError(Exception):
+    message = "A deadlock occurred while trying to add data points to the vector database."
+
+
+def is_deadlock_error(error):
+    graph_config = get_graph_config()
+
+    if graph_config.graph_database_provider == "neo4j":
+        # Neo4j
+        from neo4j.exceptions import TransientError
+
+        if isinstance(error, TransientError) and (
+            error.code == "Neo.TransientError.Transaction.DeadlockDetected"
+        ):
+            return True
+
+    # Kuzu
+    if "deadlock" in str(error).lower() or "cannot acquire lock" in str(error).lower():
+        return True
+
+    return False
+
+
 @app.function(
     image=image,
     timeout=86400,

@@ -42,11 +66,36 @@ async def graph_saving_worker():
             nodes = nodes_and_edges[0]
             edges = nodes_and_edges[1]

+            @retry(
+                retry=retry_if_exception_type(GraphDatabaseDeadlockError),
+                stop=stop_after_attempt(3),
+                wait=wait_exponential(multiplier=2, min=1, max=6),
+            )
+            async def save_graph_nodes(new_nodes):
+                try:
+                    await graph_engine.add_nodes(new_nodes, distributed=False)
+                except Exception as error:
+                    if is_deadlock_error(error):
+                        raise GraphDatabaseDeadlockError()
+
+            @retry(
+                retry=retry_if_exception_type(GraphDatabaseDeadlockError),
+                stop=stop_after_attempt(3),
+                wait=wait_exponential(multiplier=2, min=1, max=6),
+            )
+            async def save_graph_edges(new_edges):
+                try:
+                    await graph_engine.add_edges(new_edges, distributed=False)
+                except Exception as error:
+                    if is_deadlock_error(error):
+                        raise GraphDatabaseDeadlockError()
+
             if nodes:
-                await graph_engine.add_nodes(nodes, distributed=False)
+                await save_graph_nodes(nodes)

             if edges:
-                await graph_engine.add_edges(edges, distributed=False)
+                await save_graph_edges(edges)

             print("Finished adding nodes and edges.")

         else: