<!-- .github/pull_request_template.md --> ## Description This PR fixes the distributed pipeline + updates core changes in distributed logic. ## Type of Change <!-- Please check the relevant option --> - [x] Bug fix (non-breaking change that fixes an issue) - [x] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [x] Code refactoring - [x] Performance improvement - [ ] Other (please specify): ## Changes Made Fixes distributed pipeline: - Changed spawning logic + adds incremental loading to run_tasks_distributed - Adds batching to consumer nodes - Fixes consumer stopping criteria by adding stop signal + handling - Changed edge embedding solution to avoid huge network load in the case of a multi-container environment ## Testing Tested it by running 1GB on Modal + manually ## Screenshots/Videos (if applicable) None ## Pre-submission Checklist <!-- Please check all boxes that apply before submitting your PR --> - [x] **I have tested my changes thoroughly before submitting this PR** - [x] **This PR contains minimal changes necessary to address the issue/feature** - [ ] My code follows the project's coding standards and style guidelines - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I have added necessary documentation (if applicable) - [ ] All new and existing tests pass - [ ] I have searched existing PRs to ensure this change hasn't been submitted already - [ ] I have linked any relevant issues in the description - [ ] My commits have clear and descriptive messages ## Related Issues None ## Additional Notes None ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. --------- Co-authored-by: Boris <boris@topoteretes.com> Co-authored-by: Boris Arzentar <borisarzentar@gmail.com>
128 lines
4.6 KiB
Python
128 lines
4.6 KiB
Python
import os
|
|
import modal
|
|
import asyncio
|
|
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
|
|
|
|
from distributed.app import app
|
|
from distributed.signal import QueueSignal
|
|
from distributed.modal_image import image
|
|
from distributed.queues import add_nodes_and_edges_queue
|
|
|
|
from cognee.shared.logging_utils import get_logger
|
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
|
from cognee.infrastructure.databases.graph.config import get_graph_config
|
|
|
|
# Module-level logger for this worker, named after the file's role.
logger = get_logger("graph_saving_worker")
|
|
|
|
|
|
class GraphDatabaseDeadlockError(Exception):
    """Raised when the graph database reports a deadlock during a write.

    Used as the retry trigger for the tenacity-decorated save helpers in
    ``graph_saving_worker``.
    """

    # Fixed message: it previously said "vector database", but this error is
    # raised by graph-database writes (add_nodes / add_edges).
    message = "A deadlock occurred while trying to add data points to the graph database."
|
|
|
|
|
|
def is_deadlock_error(error):
    """Return True when *error* looks like a database deadlock / lock-acquisition failure."""
    provider = get_graph_config().graph_database_provider

    if provider == "neo4j":
        # Neo4j signals deadlocks through a dedicated TransientError code.
        from neo4j.exceptions import TransientError

        neo4j_deadlock = (
            isinstance(error, TransientError)
            and error.code == "Neo.TransientError.Transaction.DeadlockDetected"
        )
        if neo4j_deadlock:
            return True

    # Kuzu (and other providers) only expose lock contention via the message text.
    text = str(error).lower()
    return "deadlock" in text or "cannot acquire lock" in text
|
|
|
|
|
|
# Name of the Modal secret that supplies the worker's credentials;
# overridable via the MODAL_SECRET_NAME environment variable.
secret_name = os.environ.get("MODAL_SECRET_NAME", "distributed_cognee")
|
|
|
|
|
|
@app.function(
    retries=3,
    image=image,
    timeout=86400,
    max_containers=1,
    secrets=[modal.Secret.from_name(secret_name)],
)
async def graph_saving_worker():
    """Modal consumer that drains ``add_nodes_and_edges_queue`` into the graph database.

    Pulls up to ``BATCH_SIZE`` packets per cycle, merges them into single
    ``add_nodes`` / ``add_edges`` calls, and stops once a ``QueueSignal.STOP``
    is seen (the signal is re-enqueued so sibling consumers also terminate).

    Returns:
        True once the stop signal has been processed.
    """
    print("Started processing of nodes and edges; starting graph engine queue.")
    graph_engine = await get_graph_engine()

    # Defines how many data packets do we glue together from the queue
    # before ingesting them into the graph database.
    BATCH_SIZE = 25

    # The retry helpers are defined once, outside the consume loop (previously
    # they were re-created on every batch iteration). Deadlocks are translated
    # into GraphDatabaseDeadlockError so tenacity retries them with backoff.
    @retry(
        retry=retry_if_exception_type(GraphDatabaseDeadlockError),
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=2, min=1, max=6),
    )
    async def save_graph_nodes(new_nodes):
        try:
            await graph_engine.add_nodes(new_nodes, distributed=False)
        except Exception as error:
            if is_deadlock_error(error):
                raise GraphDatabaseDeadlockError() from error
            # Bug fix: non-deadlock errors were previously swallowed silently,
            # losing data without any trace. Propagate them instead.
            raise

    @retry(
        retry=retry_if_exception_type(GraphDatabaseDeadlockError),
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=2, min=1, max=6),
    )
    async def save_graph_edges(new_edges):
        try:
            await graph_engine.add_edges(new_edges, distributed=False)
        except Exception as error:
            if is_deadlock_error(error):
                raise GraphDatabaseDeadlockError() from error
            # Bug fix: see save_graph_nodes — do not swallow real failures.
            raise

    stop_seen = False

    while True:
        if stop_seen:
            print("Finished processing all data points; stopping graph engine queue consumer.")
            return True

        if await add_nodes_and_edges_queue.len.aio() != 0:
            try:
                print("Remaining elements in queue:")
                print(await add_nodes_and_edges_queue.len.aio())

                all_nodes, all_edges = [], []
                for _ in range(min(BATCH_SIZE, await add_nodes_and_edges_queue.len.aio())):
                    nodes_and_edges = await add_nodes_and_edges_queue.get.aio(block=False)

                    if not nodes_and_edges:
                        continue

                    if nodes_and_edges == QueueSignal.STOP:
                        # Re-enqueue the stop signal so any sibling consumer
                        # also sees it, then drain what we batched so far.
                        await add_nodes_and_edges_queue.put.aio(QueueSignal.STOP)
                        stop_seen = True
                        break

                    if len(nodes_and_edges) == 2:
                        nodes, edges = nodes_and_edges
                        all_nodes.extend(nodes)
                        all_edges.extend(edges)
                    else:
                        print("None Type detected.")

                if all_nodes or all_edges:
                    print(f"Adding {len(all_nodes)} nodes and {len(all_edges)} edges.")

                    if all_nodes:
                        await save_graph_nodes(all_nodes)

                    if all_edges:
                        await save_graph_edges(all_edges)

                    print("Finished adding nodes and edges.")

            except modal.exception.DeserializationError as error:
                # A malformed packet must not kill the consumer; log and move on.
                logger.error(f"Deserialization error: {str(error)}")
                continue
        else:
            print("No jobs, go to sleep.")
            await asyncio.sleep(5)
|