<!-- .github/pull_request_template.md -->

## Description

This PR fixes the distributed pipeline and updates core distributed-processing logic.

## Type of Change

<!-- Please check the relevant option -->

- [x] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation update
- [x] Code refactoring
- [x] Performance improvement
- [ ] Other (please specify):

## Changes Made

Fixes the distributed pipeline:

- Changed the spawning logic and added incremental loading to `run_tasks_distributed`
- Added batching to consumer nodes
- Fixed the consumer stopping criteria by adding a stop signal and its handling
- Changed the edge embedding solution to avoid heavy network load in multi-container environments

## Testing

Tested by running a 1 GB dataset on Modal, plus manual verification.

## Screenshots/Videos (if applicable)

None

## Pre-submission Checklist

<!-- Please check all boxes that apply before submitting your PR -->

- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the issue/feature**
- [ ] My code follows the project's coding standards and style guidelines
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## Related Issues

None

## Additional Notes

None

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

---------

Co-authored-by: Boris <boris@topoteretes.com>
Co-authored-by: Boris Arzentar <borisarzentar@gmail.com>
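The worker below consumes `(collection_name, data_points)` tuples from a shared Modal queue and shuts down when it encounters `QueueSignal.STOP`. For context, here is a minimal sketch of the producer side of that handshake, assuming a driver that spawns a pool of consumers, enqueues work, and then posts a single stop sentinel. Only `add_data_points_queue`, `QueueSignal`, and `data_point_saving_worker` come from this PR; the driver shape itself is illustrative, not the actual `run_tasks_distributed` implementation:

```python
# Illustrative driver-side sketch (not this PR's actual run_tasks_distributed code).
from distributed.queues import add_data_points_queue
from distributed.signal import QueueSignal

from data_point_saving_worker import data_point_saving_worker  # module path assumed


async def run_producer(work_items, worker_count=4):
    # Fan out a pool of consumer containers; each spawn returns a FunctionCall handle.
    handles = [await data_point_saving_worker.spawn.aio() for _ in range(worker_count)]

    # Enqueue (collection_name, data_points) tuples for the consumers to batch.
    for collection_name, data_points in work_items:
        await add_data_points_queue.put.aio((collection_name, data_points))

    # A single STOP sentinel suffices: each consumer re-enqueues it before exiting,
    # so the signal propagates to every sibling container.
    await add_data_points_queue.put.aio(QueueSignal.STOP)

    # Wait for every consumer to drain the queue and return.
    for handle in handles:
        await handle.get.aio()
```

The full worker file under review follows.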
import os
import asyncio

import modal
from sqlalchemy.exc import OperationalError, DBAPIError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

from distributed.app import app
from distributed.signal import QueueSignal
from distributed.modal_image import image
from distributed.queues import add_data_points_queue

from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine

logger = get_logger("data_point_saving_worker")


class VectorDatabaseDeadlockError(Exception):
    message = "A deadlock occurred while trying to add data points to the vector database."


def is_deadlock_error(error):
    """Return True if the given database error is (or wraps) a deadlock."""
    # asyncpg wrapped by SQLAlchemy: a dedicated deadlock exception type.
    try:
        import asyncpg

        if isinstance(error.orig, asyncpg.exceptions.DeadlockDetectedError):
            return True
    except ImportError:
        pass

    # PostgreSQL: SQLSTATE 40P01 = deadlock_detected.
    if hasattr(error.orig, "pgcode") and error.orig.pgcode == "40P01":
        return True

    # SQLite has no true deadlocks but reports lock contention as "database is locked".
    if "database is locked" in str(error.orig).lower():
        return True

    return False


secret_name = os.environ.get("MODAL_SECRET_NAME", "distributed_cognee")


@app.function(
    retries=3,
    image=image,
    timeout=86400,  # 24 hours
    max_containers=10,
    secrets=[modal.Secret.from_name(secret_name)],
)
async def data_point_saving_worker():
    print("Started data point saving worker; consuming from the vector engine queue.")
    vector_engine = get_vector_engine()
    # How many data packets from the Modal queue we glue together
    # before making the embedding call and ingesting the result.
    BATCH_SIZE = 25
    stop_seen = False

    while True:
        if stop_seen:
            print("Finished processing all data points; stopping vector engine queue consumer.")
            return True

        if await add_data_points_queue.len.aio() != 0:
            try:
                print("Remaining elements in queue:")
                print(await add_data_points_queue.len.aio())

                # Collect up to BATCH_SIZE requests, grouped by collection name.
                batched_points = {}
                for _ in range(min(BATCH_SIZE, await add_data_points_queue.len.aio())):
                    add_data_points_request = await add_data_points_queue.get.aio(block=False)

                    if not add_data_points_request:
                        continue

                    if add_data_points_request == QueueSignal.STOP:
                        # Re-enqueue the sentinel so sibling consumers see it too,
                        # then flush whatever this batch has collected so far.
                        await add_data_points_queue.put.aio(QueueSignal.STOP)
                        stop_seen = True
                        break

                    if len(add_data_points_request) == 2:
                        collection_name, data_points = add_data_points_request
                        if collection_name not in batched_points:
                            batched_points[collection_name] = []
                        batched_points[collection_name].extend(data_points)
                    else:
                        print("Invalid request detected; skipping it.")

                if batched_points:
                    for collection_name, data_points in batched_points.items():
                        print(
                            f"Adding {len(data_points)} data points to '{collection_name}' collection."
                        )

                        @retry(
                            retry=retry_if_exception_type(VectorDatabaseDeadlockError),
                            stop=stop_after_attempt(3),
                            wait=wait_exponential(multiplier=2, min=1, max=6),
                        )
                        async def add_data_points():
                            try:
                                await vector_engine.create_data_points(
                                    collection_name, data_points, distributed=False
                                )
                            except (OperationalError, DBAPIError) as error:
                                if is_deadlock_error(error):
                                    # Deadlocks are transient; convert them into the
                                    # retryable error type that tenacity watches for.
                                    raise VectorDatabaseDeadlockError() from error
                                # Anything else is a real failure; don't swallow it.
                                raise

                        await add_data_points()
                        print(f"Finished adding data points to '{collection_name}'.")

            except modal.exception.DeserializationError as error:
                logger.error(f"Deserialization error: {str(error)}")
                continue

        else:
            print("No jobs, go to sleep.")
            await asyncio.sleep(5)
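Design note: the consumer re-enqueues `QueueSignal.STOP` before marking `stop_seen`, so one sentinel reaches every container in the pool (`max_containers=10`), and each worker still flushes the batch it has already collected before exiting. Grouping queue items by collection before calling `create_data_points` is what collapses up to `BATCH_SIZE` small requests into one embedding-and-ingestion call per collection, which is the batching and network-load improvement described in the PR summary.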