fix: add error handling
parent f8f1bb3576
commit 685d282f5c
10 changed files with 265 additions and 120 deletions
@@ -296,13 +296,13 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
             extract_graph_from_data,
             graph_model=graph_model,
             ontology_adapter=OntologyResolver(ontology_file=ontology_file_path),
-            task_config={"batch_size": 10},
+            task_config={"batch_size": 25},
         ), # Generate knowledge graphs from the document chunks.
         Task(
             summarize_text,
-            task_config={"batch_size": 10},
+            task_config={"batch_size": 25},
         ),
-        Task(add_data_points, task_config={"batch_size": 10}),
+        Task(add_data_points, task_config={"batch_size": 25}),
     ]

     return default_tasks

@@ -5,7 +5,7 @@ from sqlalchemy.orm import Mapped, mapped_column
 from sqlalchemy.dialects.postgresql import insert
 from sqlalchemy import JSON, Column, Table, select, delete, MetaData
 from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
-from tenacity import retry, retry_if_exception_type, stop_after_attempt
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
 from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationError

 from cognee.exceptions import InvalidValueError

@@ -113,9 +113,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         return False

     @retry(
-        retry=retry_if_exception_type(Union[DuplicateTableError, UniqueViolationError]),
+        retry=retry_if_exception_type((DuplicateTableError, UniqueViolationError)),
         stop=stop_after_attempt(3),
-        sleep=1,
+        wait=wait_exponential(multiplier=2, min=1, max=6),
     )
     async def create_collection(self, collection_name: str, payload_schema=None):
         data_point_types = get_type_hints(DataPoint)

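A note on this hunk: tenacity's `retry_if_exception_type` performs `isinstance` checks, so it expects an exception class or a tuple of classes rather than a `typing.Union`, and back-off is configured through `wait=` rather than a numeric `sleep=`. A minimal, self-contained sketch of the corrected pattern, using a stand-in exception instead of asyncpg's so it runs without a database:

```python
import asyncio

from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential


class DuplicateTableError(Exception):
    """Stand-in for asyncpg's DuplicateTableError so the sketch needs no database."""


attempts = 0


@retry(
    retry=retry_if_exception_type((DuplicateTableError,)),  # a tuple of classes, not typing.Union
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=2, min=1, max=6),  # exponential back-off, capped at 6 seconds
)
async def create_collection() -> str:
    global attempts
    attempts += 1
    if attempts < 3:
        raise DuplicateTableError("simulated transient failure")
    return "collection created"


print(asyncio.run(create_collection()))  # succeeds on the third attempt, after short waits
```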
@@ -159,7 +159,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
     @retry(
         retry=retry_if_exception_type(DeadlockDetectedError),
         stop=stop_after_attempt(3),
-        sleep=1,
+        wait=wait_exponential(multiplier=2, min=1, max=6),
     )
     @override_distributed(queued_add_data_points)
     async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):

@@ -204,26 +204,26 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):

             for data_index, data_point in enumerate(data_points):
                 # Check to see if data should be updated or a new data item should be created
-                data_point_db = (
-                    await session.execute(
-                        select(PGVectorDataPoint).filter(PGVectorDataPoint.id == data_point.id)
-                    )
-                ).scalar_one_or_none()
+                # data_point_db = (
+                #     await session.execute(
+                #         select(PGVectorDataPoint).filter(PGVectorDataPoint.id == data_point.id)
+                #     )
+                # ).scalar_one_or_none()

                 # If data point exists update it, if not create a new one
-                if data_point_db:
-                    data_point_db.id = data_point.id
-                    data_point_db.vector = data_vectors[data_index]
-                    data_point_db.payload = serialize_data(data_point.model_dump())
-                    pgvector_data_points.append(data_point_db)
-                else:
-                    pgvector_data_points.append(
-                        PGVectorDataPoint(
-                            id=data_point.id,
-                            vector=data_vectors[data_index],
-                            payload=serialize_data(data_point.model_dump()),
-                        )
-                    )
+                # if data_point_db:
+                #     data_point_db.id = data_point.id
+                #     data_point_db.vector = data_vectors[data_index]
+                #     data_point_db.payload = serialize_data(data_point.model_dump())
+                #     pgvector_data_points.append(data_point_db)
+                # else:
+                pgvector_data_points.append(
+                    PGVectorDataPoint(
+                        id=data_point.id,
+                        vector=data_vectors[data_index],
+                        payload=serialize_data(data_point.model_dump()),
+                    )
+                )

         def to_dict(obj):
             return {

@@ -1,11 +1,16 @@
 from typing import List, Optional

+from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
+
 from cognee.shared.logging_utils import get_logger
 from cognee.exceptions import InvalidValueError
 from cognee.infrastructure.engine import DataPoint
 from cognee.infrastructure.engine.utils import parse_id
 from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError

+from distributed.utils import override_distributed
+from distributed.tasks.queued_add_data_points import queued_add_data_points
+
 from ..embeddings.EmbeddingEngine import EmbeddingEngine
 from ..models.ScoredResult import ScoredResult
 from ..vector_db_interface import VectorDBInterface

@@ -13,6 +18,18 @@ from ..vector_db_interface import VectorDBInterface
 logger = get_logger("WeaviateAdapter")


+def is_retryable_request(error):
+    from weaviate.exceptions import UnexpectedStatusCodeException
+    from requests.exceptions import RequestException
+
+    if isinstance(error, UnexpectedStatusCodeException):
+        # Retry on conflict, service unavailable, internal error
+        return error.status_code in {409, 503, 500}
+    if isinstance(error, RequestException):
+        return True  # Includes timeout, connection error, etc.
+    return False
+
+
 class IndexSchema(DataPoint):
     """
     Define a schema for indexing data points with textual content.

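`is_retryable_request` is wired into tenacity below via `retry_if_exception`, which takes a predicate instead of an exception type. A small sketch of that mechanism, with a stand-in exception so it runs without Weaviate installed:

```python
from tenacity import retry, retry_if_exception, stop_after_attempt


class FakeStatusError(Exception):
    """Stand-in for weaviate's UnexpectedStatusCodeException."""

    def __init__(self, status_code: int):
        super().__init__(f"status {status_code}")
        self.status_code = status_code


def is_retryable(error: BaseException) -> bool:
    # Same shape as is_retryable_request above: retry only transient status codes.
    return isinstance(error, FakeStatusError) and error.status_code in {409, 500, 503}


calls = 0


@retry(retry=retry_if_exception(is_retryable), stop=stop_after_attempt(3))
def flaky_call() -> str:
    global calls
    calls += 1
    if calls == 1:
        raise FakeStatusError(503)  # matches the predicate, so it is retried
    return "ok"


print(flaky_call())  # "ok" on the second attempt; a 401 would surface immediately
```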
@@ -124,6 +141,11 @@ class WeaviateAdapter(VectorDBInterface):
         client = await self.get_client()
         return await client.collections.exists(collection_name)

+    @retry(
+        retry=retry_if_exception(is_retryable_request),
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=2, min=1, max=6),
+    )
     async def create_collection(
         self,
         collection_name: str,

@@ -184,6 +206,12 @@ class WeaviateAdapter(VectorDBInterface):
         client = await self.get_client()
         return client.collections.get(collection_name)

+    @retry(
+        retry=retry_if_exception(is_retryable_request),
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=2, min=1, max=6),
+    )
+    @override_distributed(queued_add_data_points)
     async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
         """
         Create or update data points in the specified collection in the Weaviate database.

@@ -6,6 +6,7 @@ from cognee.modules.data.models import Dataset
 from cognee.modules.data.methods import create_dataset
 from cognee.modules.data.methods import get_unique_dataset_id
 from cognee.modules.data.exceptions import DatasetNotFoundError
+from cognee.modules.users.permissions.methods import give_permission_on_dataset


 async def load_or_create_datasets(

@@ -45,6 +46,11 @@ async def load_or_create_datasets(
             async with db_engine.get_async_session() as session:
                 await create_dataset(identifier, user, session)

+            await give_permission_on_dataset(user, new_dataset.id, "read")
+            await give_permission_on_dataset(user, new_dataset.id, "write")
+            await give_permission_on_dataset(user, new_dataset.id, "delete")
+            await give_permission_on_dataset(user, new_dataset.id, "share")
+
             result.append(new_dataset)

     return result

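The four grants above give the dataset creator the full permission set. An equivalent loop-based formulation, sketched with a stub coroutine in place of cognee's `give_permission_on_dataset`:

```python
import asyncio
from uuid import UUID, uuid4

OWNER_PERMISSIONS = ("read", "write", "delete", "share")


async def give_permission_on_dataset(user: str, dataset_id: UUID, permission: str) -> None:
    # Stub standing in for the imported coroutine of the same name.
    print(f"granting {permission} on {dataset_id} to {user}")


async def grant_owner_permissions(user: str, dataset_id: UUID) -> None:
    # Same effect as the four sequential awaits in the hunk above.
    for permission in OWNER_PERMISSIONS:
        await give_permission_on_dataset(user, dataset_id, permission)


asyncio.run(grant_owner_permissions("default_user", uuid4()))
```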
@@ -27,7 +27,7 @@ if modal:
     @app.function(
         image=image,
         timeout=86400,
-        max_containers=100,
+        max_containers=50,
         secrets=[modal.Secret.from_name("distributed_cognee")],
     )
     async def run_tasks_on_modal(tasks, data_item, user, pipeline_name, context):

@@ -1,7 +1,7 @@
 from uuid import UUID
 from sqlalchemy.future import select
-from asyncpg import UniqueViolationError
-from tenacity import retry, retry_if_exception_type, stop_after_attempt
+from sqlalchemy.exc import IntegrityError
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

 from cognee.infrastructure.databases.relational import get_relational_engine
 from cognee.modules.users.permissions import PERMISSION_TYPES

@@ -10,10 +10,14 @@ from cognee.modules.users.exceptions import PermissionNotFoundError
 from ...models import Principal, ACL, Permission


+class GivePermissionOnDatasetError(Exception):
+    message: str = "Failed to give permission on dataset"
+
+
 @retry(
-    retry=retry_if_exception_type(UniqueViolationError),
+    retry=retry_if_exception_type(GivePermissionOnDatasetError),
     stop=stop_after_attempt(3),
-    sleep=1,
+    wait=wait_exponential(multiplier=2, min=1, max=6),
 )
 async def give_permission_on_dataset(
     principal: Principal,

@@ -50,6 +54,11 @@ async def give_permission_on_dataset(

         # If no existing ACL entry is found, proceed to add a new one
         if existing_acl is None:
-            acl = ACL(principal_id=principal.id, dataset_id=dataset_id, permission=permission)
-            session.add(acl)
-            await session.commit()
+            try:
+                acl = ACL(principal_id=principal.id, dataset_id=dataset_id, permission=permission)
+                session.add(acl)
+                await session.commit()
+            except IntegrityError:
+                session.rollback()
+
+                raise GivePermissionOnDatasetError()

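Retrying on `GivePermissionOnDatasetError` works because each retry re-enters `give_permission_on_dataset` from the top, so the next attempt re-runs the existing-ACL check and simply returns once the row that caused the `IntegrityError` is visible. A deliberately contrived, self-contained simulation of that idempotent-retry behaviour:

```python
import asyncio

from tenacity import retry, retry_if_exception_type, stop_after_attempt

acl_rows = set()              # stand-in for the ACL table
race_already_happened = False


class GivePermissionOnDatasetError(Exception):
    message: str = "Failed to give permission on dataset"


@retry(retry=retry_if_exception_type(GivePermissionOnDatasetError), stop=stop_after_attempt(3))
async def give_permission(principal_id: str, dataset_id: str, permission: str) -> None:
    global race_already_happened
    key = (principal_id, dataset_id, permission)
    if key in acl_rows:       # the existing-ACL check at the top of the real function
        return
    if not race_already_happened:
        # Simulate another worker inserting the same row first: the unique
        # constraint fires, which the real code translates from IntegrityError.
        acl_rows.add(key)
        race_already_happened = True
        raise GivePermissionOnDatasetError()
    acl_rows.add(key)


asyncio.run(give_permission("user-1", "dataset-1", "read"))
print(acl_rows)  # exactly one row; the retry found it already present and returned
```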
@@ -2,14 +2,19 @@ import json
 import inspect
 from uuid import UUID
 from typing import Union, BinaryIO, Any, List, Optional

 import cognee.modules.ingestion as ingestion
 from cognee.infrastructure.databases.relational import get_relational_engine
-from cognee.modules.data.methods import create_dataset, get_dataset_data, get_datasets_by_name
-from cognee.modules.users.methods import get_default_user
-from cognee.modules.data.models.DatasetData import DatasetData
+from cognee.modules.data.models import Data
 from cognee.modules.users.models import User
-from cognee.modules.users.permissions.methods import give_permission_on_dataset
+from cognee.modules.users.methods import get_default_user
 from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets
+from cognee.modules.data.methods import (
+    get_authorized_existing_datasets,
+    get_dataset_data,
+    load_or_create_datasets,
+)

 from .save_data_item_to_storage import save_data_item_to_storage

@@ -56,20 +61,40 @@ async def ingest_data(
         node_set: Optional[List[str]] = None,
         dataset_id: UUID = None,
     ):
+        new_datapoints = []
+        existing_data_points = []
+        dataset_new_data_points = []
+
         if not isinstance(data, list):
            # Convert data to a list as we work with lists further down.
            data = [data]

-        file_paths = []
+        if dataset_id:
+            # Retrieve existing dataset
+            dataset = await get_specific_user_permission_datasets(user.id, "write", [dataset_id])
+            # Convert from list to Dataset element
+            if isinstance(dataset, list):
+                dataset = dataset[0]
+        else:
+            # Find existing dataset or create a new one
+            existing_datasets = await get_authorized_existing_datasets(
+                user=user, permission_type="write", datasets=[dataset_name]
+            )
+            dataset = await load_or_create_datasets(
+                dataset_names=[dataset_name],
+                existing_datasets=existing_datasets,
+                user=user,
+            )
+            if isinstance(dataset, list):
+                dataset = dataset[0]
+
+        dataset_data: list[Data] = await get_dataset_data(dataset.id)
+        dataset_data_map = {str(data.id): True for data in dataset_data}
+
-        # Process data
         for data_item in data:
             file_path = await save_data_item_to_storage(data_item, dataset_name)

-            file_paths.append(file_path)
-
             # Ingest data and add metadata
-            # with open(file_path.replace("file://", ""), mode="rb") as file:
             with open_data_file(file_path) as file:
                 classified_data = ingestion.classify(file, s3fs=fs)

@@ -80,90 +105,76 @@ async def ingest_data(

                 from sqlalchemy import select

-                from cognee.modules.data.models import Data
-
                 db_engine = get_relational_engine()

+                # Check to see if data should be updated
                 async with db_engine.get_async_session() as session:
-                    if dataset_id:
-                        # Retrieve existing dataset
-                        dataset = await get_specific_user_permission_datasets(
-                            user.id, "write", [dataset_id]
-                        )
-                        # Convert from list to Dataset element
-                        if isinstance(dataset, list):
-                            dataset = dataset[0]
-
-                        dataset = await session.merge(
-                            dataset
-                        )  # Add found dataset object into current session
-                    else:
-                        # Create new one
-                        dataset = await create_dataset(dataset_name, user, session)
-
-                    # Check to see if data should be updated
                     data_point = (
                         await session.execute(select(Data).filter(Data.id == data_id))
                     ).scalar_one_or_none()

                     ext_metadata = get_external_metadata_dict(data_item)
+
                     if node_set:
                         ext_metadata["node_set"] = node_set

                     if data_point is not None:
                         data_point.name = file_metadata["name"]
                         data_point.raw_data_location = file_metadata["file_path"]
                         data_point.extension = file_metadata["extension"]
                         data_point.mime_type = file_metadata["mime_type"]
                         data_point.owner_id = user.id
                         data_point.content_hash = file_metadata["content_hash"]
                         data_point.external_metadata = ext_metadata
                         data_point.node_set = json.dumps(node_set) if node_set else None
-                        await session.merge(data_point)
-                    else:
-                        data_point = Data(
-                            id=data_id,
-                            name=file_metadata["name"],
-                            raw_data_location=file_metadata["file_path"],
-                            extension=file_metadata["extension"],
-                            mime_type=file_metadata["mime_type"],
-                            owner_id=user.id,
-                            content_hash=file_metadata["content_hash"],
-                            external_metadata=ext_metadata,
-                            node_set=json.dumps(node_set) if node_set else None,
-                            token_count=-1,
-                        )
-                        session.add(data_point)

-                    # Check if data is already in dataset
-                    dataset_data = (
-                        await session.execute(
-                            select(DatasetData).filter(
-                                DatasetData.data_id == data_id, DatasetData.dataset_id == dataset.id
-                            )
-                        )
-                    ).scalar_one_or_none()
-                    # If data is not present in dataset add it
-                    if dataset_data is None:
-                        dataset.data.append(data_point)
-                        await session.merge(dataset)
-
-                    await session.commit()
-
-                    await give_permission_on_dataset(user, dataset.id, "read")
-                    await give_permission_on_dataset(user, dataset.id, "write")
-                    await give_permission_on_dataset(user, dataset.id, "delete")
-                    await give_permission_on_dataset(user, dataset.id, "share")
-
-        return file_paths
-
-    await store_data_to_dataset(data, dataset_name, user, node_set, dataset_id)
-
-    datasets = await get_datasets_by_name(dataset_name, user.id)
-
-    # In case no files were processed no dataset will be created
-    if datasets:
-        dataset = datasets[0]
-        data_documents = await get_dataset_data(dataset_id=dataset.id)
-        return data_documents
-    return []
+                        # Check if data is already in dataset
+                        if str(data_point.id) in dataset_data_map:
+                            existing_data_points.append(data_point)
+                        else:
+                            dataset_new_data_points.append(data_point)
+                            dataset_data_map[str(data_point.id)] = True
+                    else:
+                        if str(data_id) in dataset_data_map:
+                            continue
+
+                        data_point = Data(
+                            id=data_id,
+                            name=file_metadata["name"],
+                            raw_data_location=file_metadata["file_path"],
+                            extension=file_metadata["extension"],
+                            mime_type=file_metadata["mime_type"],
+                            owner_id=user.id,
+                            content_hash=file_metadata["content_hash"],
+                            external_metadata=ext_metadata,
+                            node_set=json.dumps(node_set) if node_set else None,
+                            token_count=-1,
+                        )
+
+                        new_datapoints.append(data_point)
+                        dataset_data_map[str(data_point.id)] = True
+
+        async with db_engine.get_async_session() as session:
+            if dataset not in session:
+                session.add(dataset)
+
+            if len(new_datapoints) > 0:
+                session.add_all(new_datapoints)
+                dataset.data.extend(new_datapoints)
+
+            if len(existing_data_points) > 0:
+                for data_point in existing_data_points:
+                    await session.merge(data_point)
+
+            if len(dataset_new_data_points) > 0:
+                for data_point in dataset_new_data_points:
+                    await session.merge(data_point)
+                dataset.data.extend(dataset_new_data_points)
+
+            await session.merge(dataset)
+
+            await session.commit()
+
+        return existing_data_points + dataset_new_data_points + new_datapoints
+
+    return await store_data_to_dataset(data, dataset_name, user, node_set, dataset_id)

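The rewrite above replaces the per-item `DatasetData` SELECT with an in-memory map of the ids already attached to the dataset, sorts incoming points into new/existing buckets, and defers all inserts and merges to one final session. The bucketing idea, reduced to two buckets for brevity, looks roughly like this (illustrative sketch, not code from the repo):

```python
from typing import Dict, List, Set, Tuple
from uuid import uuid4


def split_data_points(
    incoming: List[Dict], known_ids: Set[str]
) -> Tuple[List[Dict], List[Dict]]:
    """Split points into ones already in the dataset and ones still to attach."""
    existing, new = [], []
    seen = set(known_ids)
    for point in incoming:
        key = str(point["id"])
        if key in seen:
            existing.append(point)
        else:
            new.append(point)
            seen.add(key)  # also deduplicates repeats inside the same batch
    return existing, new


known = str(uuid4())
batch = [{"id": known}, {"id": str(uuid4())}, {"id": known}]
existing, new = split_data_points(batch, {known})
print(len(existing), len(new))  # 2 1
```

Deferring the merges to one session also means a single commit per ingest call instead of one per file.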
@@ -23,7 +23,7 @@ async def main():
     await add_nodes_and_edges_queue.clear.aio()

     number_of_graph_saving_workers = 1 # Total number of graph_saving_worker to spawn
-    number_of_data_point_saving_workers = 2 # Total number of graph_saving_worker to spawn
+    number_of_data_point_saving_workers = 5 # Total number of graph_saving_worker to spawn

     results = []
     consumer_futures = []

@@ -44,7 +44,8 @@ async def main():
         worker_future = data_point_saving_worker.spawn()
         consumer_futures.append(worker_future)

-    s3_bucket_name = "s3://s3-test-laszlo/Database for KG v1"
+    # s3_bucket_name = "s3://s3-test-laszlo/Database for KG v1"
+    s3_bucket_name = "s3://s3-test-laszlo/Pdf"

     await cognee.add(s3_bucket_name, dataset_name="s3-files")

@@ -1,6 +1,7 @@
 import modal
 import asyncio
+from sqlalchemy.exc import OperationalError, DBAPIError
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

 from distributed.app import app
 from distributed.modal_image import image

@@ -13,6 +14,31 @@ from cognee.infrastructure.databases.vector import get_vector_engine
 logger = get_logger("data_point_saving_worker")


+class VectorDatabaseDeadlockError(Exception):
+    message = "A deadlock occurred while trying to add data points to the vector database."
+
+
+def is_deadlock_error(error):
+    # SQLAlchemy-wrapped asyncpg
+    try:
+        import asyncpg
+
+        if isinstance(error.orig, asyncpg.exceptions.DeadlockDetectedError):
+            return True
+    except ImportError:
+        pass
+
+    # PostgreSQL: SQLSTATE 40P01 = deadlock_detected
+    if hasattr(error.orig, "pgcode") and error.orig.pgcode == "40P01":
+        return True
+
+    # SQLite: It doesn't support real deadlocks but may simulate them as "database is locked"
+    if "database is locked" in str(error.orig).lower():
+        return True
+
+    return False
+
+
 @app.function(
     image=image,
     timeout=86400,

@@ -40,9 +66,24 @@ async def data_point_saving_worker():

                 print(f"Adding {len(data_points)} data points to '{collection_name}' collection.")

-                await vector_engine.create_data_points(
-                    collection_name, data_points, distributed=False
-                )
+                @retry(
+                    retry=retry_if_exception_type(VectorDatabaseDeadlockError),
+                    stop=stop_after_attempt(3),
+                    wait=wait_exponential(multiplier=2, min=1, max=6),
+                )
+                async def add_data_points():
+                    try:
+                        await vector_engine.create_data_points(
+                            collection_name, data_points, distributed=False
+                        )
+                    except DBAPIError as error:
+                        if is_deadlock_error(error):
+                            raise VectorDatabaseDeadlockError()
+                    except OperationalError as error:
+                        if is_deadlock_error(error):
+                            raise VectorDatabaseDeadlockError()
+
+                await add_data_points()

                 print("Finished adding data points.")

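The worker now funnels driver-specific deadlock signals into one `VectorDatabaseDeadlockError` that tenacity retries on. A self-contained sketch of the same translate-and-retry shape, with stand-ins for the SQLAlchemy error types (note that, unlike the hunk above, this sketch re-raises errors that are not deadlocks rather than letting them pass silently):

```python
import asyncio

from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential


class FakeDBAPIError(Exception):
    """Stand-in for sqlalchemy.exc.DBAPIError; `orig` carries the driver-level exception."""

    def __init__(self, orig: BaseException):
        super().__init__(str(orig))
        self.orig = orig


class VectorDatabaseDeadlockError(Exception):
    pass


def is_deadlock_error(error: FakeDBAPIError) -> bool:
    # Same idea as above: PostgreSQL's deadlock SQLSTATE or SQLite's lock message.
    pgcode = getattr(error.orig, "pgcode", None)
    return pgcode == "40P01" or "database is locked" in str(error.orig).lower()


attempts = 0


async def create_data_points() -> None:
    global attempts
    attempts += 1
    if attempts == 1:
        deadlock = Exception("deadlock detected")
        deadlock.pgcode = "40P01"  # the SQLSTATE the check above looks for
        raise FakeDBAPIError(deadlock)


@retry(
    retry=retry_if_exception_type(VectorDatabaseDeadlockError),
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=0.1, max=1),  # shortened waits for the demo
)
async def add_data_points() -> None:
    try:
        await create_data_points()
    except FakeDBAPIError as error:
        if is_deadlock_error(error):
            raise VectorDatabaseDeadlockError() from error
        raise  # non-deadlock errors should surface, not be swallowed


asyncio.run(add_data_points())
print(f"succeeded after {attempts} attempts")
```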
@@ -1,6 +1,6 @@
 import modal
 import asyncio
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

 from distributed.app import app
 from distributed.modal_image import image

@@ -8,11 +8,35 @@ from distributed.queues import add_nodes_and_edges_queue

 from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.graph import get_graph_engine
+from cognee.infrastructure.databases.graph.config import get_graph_config


 logger = get_logger("graph_saving_worker")


+class GraphDatabaseDeadlockError(Exception):
+    message = "A deadlock occurred while trying to add data points to the vector database."
+
+
+def is_deadlock_error(error):
+    graph_config = get_graph_config()
+
+    if graph_config.graph_database_provider == "neo4j":
+        # Neo4j
+        from neo4j.exceptions import TransientError
+
+        if isinstance(error, TransientError) and (
+            error.code == "Neo.TransientError.Transaction.DeadlockDetected"
+        ):
+            return True
+
+    # Kuzu
+    if "deadlock" in str(error).lower() or "cannot acquire lock" in str(error).lower():
+        return True
+
+    return False
+
+
 @app.function(
     image=image,
     timeout=86400,

@@ -42,11 +66,36 @@ async def graph_saving_worker():
                 nodes = nodes_and_edges[0]
                 edges = nodes_and_edges[1]

+                @retry(
+                    retry=retry_if_exception_type(GraphDatabaseDeadlockError),
+                    stop=stop_after_attempt(3),
+                    wait=wait_exponential(multiplier=2, min=1, max=6),
+                )
+                async def save_graph_nodes(new_nodes):
+                    try:
+                        await graph_engine.add_nodes(new_nodes, distributed=False)
+                    except Exception as error:
+                        if is_deadlock_error(error):
+                            raise GraphDatabaseDeadlockError()
+
+                @retry(
+                    retry=retry_if_exception_type(GraphDatabaseDeadlockError),
+                    stop=stop_after_attempt(3),
+                    wait=wait_exponential(multiplier=2, min=1, max=6),
+                )
+                async def save_graph_edges(new_edges):
+                    try:
+                        await graph_engine.add_edges(new_edges, distributed=False)
+                    except Exception as error:
+                        if is_deadlock_error(error):
+                            raise GraphDatabaseDeadlockError()
+
                 if nodes:
-                    await graph_engine.add_nodes(nodes, distributed=False)
+                    await save_graph_nodes(nodes)

                 if edges:
-                    await graph_engine.add_edges(edges, distributed=False)
+                    await save_graph_edges(edges)

                 print("Finished adding nodes and edges.")

             else:
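`save_graph_nodes` and `save_graph_edges` differ only in which coroutine they call, so the pair could be collapsed into one retry-wrapped helper that takes the operation as a parameter. A sketch of that variant, with a stub in place of the graph engine:

```python
import asyncio
from typing import Awaitable, Callable, List

from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential


class GraphDatabaseDeadlockError(Exception):
    pass


def is_deadlock_error(error: BaseException) -> bool:
    # Simplified stand-in for the provider-aware check above.
    return "deadlock" in str(error).lower() or "cannot acquire lock" in str(error).lower()


@retry(
    retry=retry_if_exception_type(GraphDatabaseDeadlockError),
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=0.1, max=1),  # shortened waits for the demo
)
async def save_with_deadlock_retry(
    operation: Callable[[List], Awaitable[None]], items: List
) -> None:
    try:
        await operation(items)
    except Exception as error:
        if is_deadlock_error(error):
            raise GraphDatabaseDeadlockError() from error
        raise


async def add_nodes(nodes: List) -> None:
    # Stub for graph_engine.add_nodes(nodes, distributed=False).
    print(f"added {len(nodes)} nodes")


asyncio.run(save_with_deadlock_retry(add_nodes, [{"id": "n1"}, {"id": "n2"}]))
```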