diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
index 74cae62eb..338541efd 100644
--- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
+++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
@@ -12,13 +12,13 @@ from typing import Optional, Any, List, Dict, Type, Tuple
 from cognee.infrastructure.engine import DataPoint
 from cognee.shared.logging_utils import get_logger, ERROR
-from cognee.infrastructure.databases.graph.utils import override_distributed
 from cognee.infrastructure.databases.graph.graph_db_interface import (
     GraphDBInterface,
     record_graph_changes,
 )
 from cognee.modules.storage.utils import JSONEncoder
+from distributed.utils import override_distributed
 from distributed.tasks.queued_add_nodes import queued_add_nodes
 from distributed.tasks.queued_add_edges import queued_add_edges
diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
index f22e287c2..530fceb66 100644
--- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
+++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
@@ -1,14 +1,12 @@
 import asyncio
-from uuid import UUID, uuid4
-from sqlalchemy.inspection import inspect
 from typing import List, Optional, Union, get_type_hints
-
+from sqlalchemy.inspection import inspect
 from sqlalchemy.orm import Mapped, mapped_column
 from sqlalchemy.dialects.postgresql import insert
-from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationError
 from sqlalchemy import JSON, Column, Table, select, delete, MetaData
 from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
 from tenacity import retry, retry_if_exception_type, stop_after_attempt
+from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationError
 
 from cognee.exceptions import InvalidValueError
 from cognee.shared.logging_utils import get_logger
@@ -16,6 +14,9 @@ from cognee.infrastructure.engine import DataPoint
 from cognee.infrastructure.engine.utils import parse_id
 from cognee.infrastructure.databases.relational import get_relational_engine
+from distributed.utils import override_distributed
+from distributed.tasks.queued_add_data_points import queued_add_data_points
+
 from ...relational.ModelBase import Base
 from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
 from ..utils import normalize_distances
@@ -160,6 +161,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         stop=stop_after_attempt(3),
         sleep=1,
     )
+    @override_distributed(queued_add_data_points)
     async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
         data_point_types = get_type_hints(DataPoint)
         if not await self.has_collection(collection_name):
diff --git a/cognee/modules/pipelines/operations/run_tasks.py b/cognee/modules/pipelines/operations/run_tasks.py
index 45f5a75d8..a301b2e32 100644
--- a/cognee/modules/pipelines/operations/run_tasks.py
+++ b/cognee/modules/pipelines/operations/run_tasks.py
@@ -35,7 +35,6 @@ def override_run_tasks(new_gen):
        default_distributed_value = os.getenv("COGNEE_DISTRIBUTED", "False").lower() == "true"
        distributed = default_distributed_value if distributed is None else distributed
 
-        print(f"run_tasks_distributed: {distributed}")
        if distributed:
            async for run_info in new_gen(*args, **kwargs):
                yield run_info
diff --git a/cognee/modules/pipelines/operations/run_tasks_distributed.py b/cognee/modules/pipelines/operations/run_tasks_distributed.py
index 012f52233..c0ec3cc93 100644
--- a/cognee/modules/pipelines/operations/run_tasks_distributed.py
+++ b/cognee/modules/pipelines/operations/run_tasks_distributed.py
@@ -66,9 +66,6 @@ async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, co
     data_count = len(data) if isinstance(data, list) else 1
 
-    print(f"Data count: {data_count}")
-    print(f"Data item to process: {type(data)} - {data}")
-
     arguments = [
         [tasks] * data_count,
         [[data_item] for data_item in data[:data_count]] if data_count > 1 else [data],
diff --git a/distributed/entrypoint.py b/distributed/entrypoint.py
index 301867ea5..7b6f56911 100644
--- a/distributed/entrypoint.py
+++ b/distributed/entrypoint.py
@@ -7,8 +7,9 @@ from cognee.shared.logging_utils import get_logger
 from cognee.modules.engine.operations.setup import setup
 
 from distributed.app import app
-from distributed.queues import save_data_points_queue
-from distributed.workers.data_point_saver_worker import data_point_saver_worker
+from distributed.queues import add_nodes_and_edges_queue, add_data_points_queue
+from distributed.workers.graph_saving_worker import graph_saving_worker
+from distributed.workers.data_point_saving_worker import data_point_saving_worker
 
 logger = get_logger()
@@ -19,9 +20,10 @@ os.environ["COGNEE_DISTRIBUTED"] = "True"
 @app.local_entrypoint()
 async def main():
     # Clear queues
-    await save_data_points_queue.clear.aio()
+    await add_nodes_and_edges_queue.clear.aio()
 
-    number_of_data_saving_workers = 1  # Total number of data_point_saver_worker functions to spawn
+    number_of_graph_saving_workers = 1  # Total number of graph_saving_worker functions to spawn
+    number_of_data_point_saving_workers = 2  # Total number of data_point_saving_worker functions to spawn
 
     results = []
     consumer_futures = []
@@ -32,9 +34,14 @@
     await setup()
 
-    # Start data_point_saver_worker functions
-    for _ in range(number_of_data_saving_workers):
-        worker_future = data_point_saver_worker.spawn()
+    # Start graph_saving_worker functions
+    for _ in range(number_of_graph_saving_workers):
+        worker_future = graph_saving_worker.spawn()
+        consumer_futures.append(worker_future)
+
+    # Start data_point_saving_worker functions
+    for _ in range(number_of_data_point_saving_workers):
+        worker_future = data_point_saving_worker.spawn()
         consumer_futures.append(worker_future)
 
     s3_bucket_name = "s3://s3-test-laszlo/Database for KG v1"
@@ -44,11 +51,12 @@
     await cognee.cognify(datasets=["s3-files"])
 
     # Push empty tuple into the queue to signal the end of data.
-    await save_data_points_queue.put.aio(())
+    await add_nodes_and_edges_queue.put.aio(())
+    await add_data_points_queue.put.aio(())
 
     for consumer_future in consumer_futures:
         try:
-            print("Finished but waiting for saving worker to finish.")
+            print("Finished but waiting for saving workers to finish.")
             consumer_final = consumer_future.get()
             print(f"All workers are done: {consumer_final}")
         except Exception as e:
diff --git a/distributed/queues.py b/distributed/queues.py
index 5626f5c90..628d02d73 100644
--- a/distributed/queues.py
+++ b/distributed/queues.py
@@ -1,7 +1,5 @@
 from modal import Queue
 
-# Create (or get) queues:
-# - save_data_points_queue: Stores messages produced by the producer functions.
-
-save_data_points_queue = Queue.from_name("save_data_points_queue", create_if_missing=True)
+add_nodes_and_edges_queue = Queue.from_name("add_nodes_and_edges_queue", create_if_missing=True)
+add_data_points_queue = Queue.from_name("add_data_points_queue", create_if_missing=True)
diff --git a/distributed/tasks/queued_add_data_points.py b/distributed/tasks/queued_add_data_points.py
new file mode 100644
index 000000000..48798111c
--- /dev/null
+++ b/distributed/tasks/queued_add_data_points.py
@@ -0,0 +1,12 @@
+async def queued_add_data_points(collection_name, data_points_batch):
+    from ..queues import add_data_points_queue
+
+    try:
+        await add_data_points_queue.put.aio((collection_name, data_points_batch))
+    except Exception:
+        first_half, second_half = (
+            data_points_batch[: len(data_points_batch) // 2],
+            data_points_batch[len(data_points_batch) // 2 :],
+        )
+        await add_data_points_queue.put.aio((collection_name, first_half))
+        await add_data_points_queue.put.aio((collection_name, second_half))
diff --git a/distributed/tasks/queued_add_edges.py b/distributed/tasks/queued_add_edges.py
index 44a0c4530..e6d3ff53d 100644
--- a/distributed/tasks/queued_add_edges.py
+++ b/distributed/tasks/queued_add_edges.py
@@ -1,12 +1,12 @@
 async def queued_add_edges(edge_batch):
-    from ..queues import save_data_points_queue
+    from ..queues import add_nodes_and_edges_queue
 
     try:
-        await save_data_points_queue.put.aio(([], edge_batch))
+        await add_nodes_and_edges_queue.put.aio(([], edge_batch))
     except Exception:
         first_half, second_half = (
             edge_batch[: len(edge_batch) // 2],
             edge_batch[len(edge_batch) // 2 :],
         )
-        await save_data_points_queue.put.aio(([], first_half))
-        await save_data_points_queue.put.aio(([], second_half))
+        await add_nodes_and_edges_queue.put.aio(([], first_half))
+        await add_nodes_and_edges_queue.put.aio(([], second_half))
diff --git a/distributed/tasks/queued_add_nodes.py b/distributed/tasks/queued_add_nodes.py
index 44278c734..72bdc6fcd 100644
--- a/distributed/tasks/queued_add_nodes.py
+++ b/distributed/tasks/queued_add_nodes.py
@@ -1,12 +1,12 @@
 async def queued_add_nodes(node_batch):
-    from ..queues import save_data_points_queue
+    from ..queues import add_nodes_and_edges_queue
 
     try:
-        await save_data_points_queue.put.aio((node_batch, []))
+        await add_nodes_and_edges_queue.put.aio((node_batch, []))
     except Exception:
         first_half, second_half = (
             node_batch[: len(node_batch) // 2],
             node_batch[len(node_batch) // 2 :],
         )
-        await save_data_points_queue.put.aio((first_half, []))
-        await save_data_points_queue.put.aio((second_half, []))
+        await add_nodes_and_edges_queue.put.aio((first_half, []))
+        await add_nodes_and_edges_queue.put.aio((second_half, []))
diff --git a/cognee/infrastructure/databases/graph/utils.py b/distributed/utils.py
similarity index 100%
rename from cognee/infrastructure/databases/graph/utils.py
rename to distributed/utils.py
diff --git a/distributed/workers/data_point_saving_worker.py b/distributed/workers/data_point_saving_worker.py
new file mode 100644
index 000000000..4c85f7f32
--- /dev/null
+++ b/distributed/workers/data_point_saving_worker.py
@@ -0,0 +1,51 @@
+import modal
+import asyncio
+
+
+from distributed.app import app
+from distributed.modal_image import image
+from distributed.queues import add_data_points_queue
+
+from cognee.shared.logging_utils import get_logger
+from cognee.infrastructure.databases.vector import get_vector_engine
+
+
+logger = get_logger("data_point_saving_worker")
+
+
+@app.function(
+    image=image,
+    timeout=86400,
+    max_containers=5,
+    secrets=[modal.Secret.from_name("distributed_cognee")],
+)
+async def data_point_saving_worker():
+    print("Started processing of data points; starting vector engine queue.")
+    vector_engine = get_vector_engine()
+
+    while True:
+        if await add_data_points_queue.len.aio() != 0:
+            try:
+                add_data_points_request = await add_data_points_queue.get.aio(block=False)
+            except modal.exception.DeserializationError as error:
+                logger.error(f"Deserialization error: {str(error)}")
+                continue
+
+            if len(add_data_points_request) == 0:
+                print("Finished processing all data points; stopping vector engine queue.")
+                return True
+
+            if len(add_data_points_request) == 2:
+                (collection_name, data_points) = add_data_points_request
+
+                print(f"Adding {len(data_points)} data points to '{collection_name}' collection.")
+
+                await vector_engine.create_data_points(
+                    collection_name, data_points, distributed=False
+                )
+
+            print("Finished adding data points.")
+
+        else:
+            print("No jobs, go to sleep.")
+            await asyncio.sleep(5)
diff --git a/distributed/workers/data_point_saver_worker.py b/distributed/workers/graph_saving_worker.py
similarity index 73%
rename from distributed/workers/data_point_saver_worker.py
rename to distributed/workers/graph_saving_worker.py
index 8db8f32ee..958ba4727 100644
--- a/distributed/workers/data_point_saver_worker.py
+++ b/distributed/workers/graph_saving_worker.py
@@ -4,29 +4,29 @@ import asyncio
 
 from distributed.app import app
 from distributed.modal_image import image
-from distributed.queues import save_data_points_queue
+from distributed.queues import add_nodes_and_edges_queue
 
 from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.graph import get_graph_engine
 
 
-logger = get_logger("data_point_saver_worker")
+logger = get_logger("graph_saving_worker")
 
 
 @app.function(
     image=image,
     timeout=86400,
-    max_containers=100,
+    max_containers=5,
     secrets=[modal.Secret.from_name("distributed_cognee")],
 )
-async def data_point_saver_worker():
+async def graph_saving_worker():
     print("Started processing of nodes and edges; starting graph engine queue.")
     graph_engine = await get_graph_engine()
 
     while True:
-        if await save_data_points_queue.len.aio() != 0:
+        if await add_nodes_and_edges_queue.len.aio() != 0:
             try:
-                nodes_and_edges = await save_data_points_queue.get.aio(block=False)
+                nodes_and_edges = await add_nodes_and_edges_queue.get.aio(block=False)
             except modal.exception.DeserializationError as error:
                 logger.error(f"Deserialization error: {str(error)}")
                 continue
@@ -37,7 +37,7 @@
 
             if len(nodes_and_edges) == 2:
                 print(
-                    f"Processing {len(nodes_and_edges[0])} nodes and {len(nodes_and_edges[1])} edges."
+                    f"Adding {len(nodes_and_edges[0])} nodes and {len(nodes_and_edges[1])} edges."
                 )
                 nodes = nodes_and_edges[0]
                 edges = nodes_and_edges[1]
@@ -47,7 +47,7 @@ async def data_point_saver_worker():
                 if edges:
                     await graph_engine.add_edges(edges, distributed=False)
 
-            print("Finished processing nodes and edges.")
+            print("Finished adding nodes and edges.")
 
         else:
             print("No jobs, go to sleep.")