fix: add queue for data points saving

Boris Arzentar 2025-07-04 18:26:22 +02:00
parent 4eba76ca1f
commit f8f1bb3576
No known key found for this signature in database
GPG key ID: D5CC274C784807B7
12 changed files with 105 additions and 38 deletions

View file

@@ -12,13 +12,13 @@ from typing import Optional, Any, List, Dict, Type, Tuple
 from cognee.infrastructure.engine import DataPoint
 from cognee.shared.logging_utils import get_logger, ERROR
-from cognee.infrastructure.databases.graph.utils import override_distributed
 from cognee.infrastructure.databases.graph.graph_db_interface import (
     GraphDBInterface,
     record_graph_changes,
 )
 from cognee.modules.storage.utils import JSONEncoder
+from distributed.utils import override_distributed
 from distributed.tasks.queued_add_nodes import queued_add_nodes
 from distributed.tasks.queued_add_edges import queued_add_edges

View file

@@ -1,14 +1,12 @@
 import asyncio
-from uuid import UUID, uuid4
-from sqlalchemy.inspection import inspect
 from typing import List, Optional, Union, get_type_hints
+from sqlalchemy.inspection import inspect
 from sqlalchemy.orm import Mapped, mapped_column
 from sqlalchemy.dialects.postgresql import insert
-from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationError
 from sqlalchemy import JSON, Column, Table, select, delete, MetaData
 from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
 from tenacity import retry, retry_if_exception_type, stop_after_attempt
+from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationError
 from cognee.exceptions import InvalidValueError
 from cognee.shared.logging_utils import get_logger
@@ -16,6 +14,9 @@ from cognee.infrastructure.engine import DataPoint
 from cognee.infrastructure.engine.utils import parse_id
 from cognee.infrastructure.databases.relational import get_relational_engine
+from distributed.utils import override_distributed
+from distributed.tasks.queued_add_data_points import queued_add_data_points
 from ...relational.ModelBase import Base
 from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
 from ..utils import normalize_distances
@@ -160,6 +161,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         stop=stop_after_attempt(3),
         sleep=1,
     )
+    @override_distributed(queued_add_data_points)
     async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
         data_point_types = get_type_hints(DataPoint)
         if not await self.has_collection(collection_name):
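The body of override_distributed is not shown in this diff. As a rough sketch of what it plausibly does, based only on how it is used here (it wraps create_data_points with queued_add_data_points), on the distributed=False keyword the workers pass further down, and on the COGNEE_DISTRIBUTED flag read in override_run_tasks below:

# Hedged sketch only; the real distributed.utils.override_distributed may differ.
import os
import functools


def override_distributed(queued_fn):
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(self, *args, distributed=None, **kwargs):
            if distributed is None:
                distributed = os.getenv("COGNEE_DISTRIBUTED", "False").lower() == "true"
            if distributed:
                # Divert the write into the Modal queue instead of touching the database.
                return await queued_fn(*args, **kwargs)
            return await func(self, *args, **kwargs)

        return wrapper

    return decorator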

View file

@@ -35,7 +35,6 @@ def override_run_tasks(new_gen):
         default_distributed_value = os.getenv("COGNEE_DISTRIBUTED", "False").lower() == "true"
         distributed = default_distributed_value if distributed is None else distributed

-        print(f"run_tasks_distributed: {distributed}")
         if distributed:
             async for run_info in new_gen(*args, **kwargs):
                 yield run_info

View file

@@ -66,9 +66,6 @@ async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, co
     data_count = len(data) if isinstance(data, list) else 1

-    print(f"Data count: {data_count}")
-    print(f"Data item to process: {type(data)} - {data}")
-
     arguments = [
         [tasks] * data_count,
         [[data_item] for data_item in data[:data_count]] if data_count > 1 else [data],

View file

@@ -7,8 +7,9 @@ from cognee.shared.logging_utils import get_logger
 from cognee.modules.engine.operations.setup import setup
 from distributed.app import app
-from distributed.queues import save_data_points_queue
-from distributed.workers.data_point_saver_worker import data_point_saver_worker
+from distributed.queues import add_nodes_and_edges_queue, add_data_points_queue
+from distributed.workers.graph_saving_worker import graph_saving_worker
+from distributed.workers.data_point_saving_worker import data_point_saving_worker

 logger = get_logger()
@@ -19,9 +20,10 @@ os.environ["COGNEE_DISTRIBUTED"] = "True"
 @app.local_entrypoint()
 async def main():
     # Clear queues
-    await save_data_points_queue.clear.aio()
+    await add_nodes_and_edges_queue.clear.aio()

-    number_of_data_saving_workers = 1  # Total number of data_point_saver_worker functions to spawn
+    number_of_graph_saving_workers = 1  # Total number of graph_saving_worker to spawn
+    number_of_data_point_saving_workers = 2  # Total number of graph_saving_worker to spawn

     results = []
     consumer_futures = []
@@ -32,9 +34,14 @@ async def main():
     await setup()

-    # Start data_point_saver_worker functions
-    for _ in range(number_of_data_saving_workers):
-        worker_future = data_point_saver_worker.spawn()
+    # Start graph_saving_worker functions
+    for _ in range(number_of_graph_saving_workers):
+        worker_future = graph_saving_worker.spawn()
+        consumer_futures.append(worker_future)
+
+    # Start data_point_saving_worker functions
+    for _ in range(number_of_data_point_saving_workers):
+        worker_future = data_point_saving_worker.spawn()
         consumer_futures.append(worker_future)

     s3_bucket_name = "s3://s3-test-laszlo/Database for KG v1"
@@ -44,11 +51,12 @@ async def main():
     await cognee.cognify(datasets=["s3-files"])

     # Push empty tuple into the queue to signal the end of data.
-    await save_data_points_queue.put.aio(())
+    await add_nodes_and_edges_queue.put.aio(())
+    await add_data_points_queue.put.aio(())

     for consumer_future in consumer_futures:
         try:
-            print("Finished but waiting for saving worker to finish.")
+            print("Finished but waiting for saving workers to finish.")
             consumer_final = consumer_future.get()
             print(f"All workers are done: {consumer_final}")
         except Exception as e:

View file

@@ -1,7 +1,5 @@
 from modal import Queue

-# Create (or get) queues:
-# - save_data_points_queue: Stores messages produced by the producer functions.
-save_data_points_queue = Queue.from_name("save_data_points_queue", create_if_missing=True)
+add_nodes_and_edges_queue = Queue.from_name("add_nodes_and_edges_queue", create_if_missing=True)
+add_data_points_queue = Queue.from_name("add_data_points_queue", create_if_missing=True)

View file

@@ -0,0 +1,12 @@
+async def queued_add_data_points(collection_name, data_points_batch):
+    from ..queues import add_data_points_queue
+
+    try:
+        await add_data_points_queue.put.aio((collection_name, data_points_batch))
+    except Exception:
+        first_half, second_half = (
+            data_points_batch[: len(data_points_batch) // 2],
+            data_points_batch[len(data_points_batch) // 2 :],
+        )
+        await add_data_points_queue.put.aio((collection_name, first_half))
+        await add_data_points_queue.put.aio((collection_name, second_half))
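The except branch assumes a failed put means the batch was too large for a single queue item, so it re-enqueues the batch as two halves. A minimal, self-contained way to see that behaviour, with the Modal queue replaced by a stub that rejects oversized payloads (the stub classes, the size threshold, and the demo function are illustrative and not part of this commit):

# Illustrative stub only; mimics a queue whose put.aio rejects large batches.
import asyncio


class _StubPut:
    def __init__(self, max_items):
        self.max_items = max_items
        self.accepted = []

    async def aio(self, payload):
        collection_name, batch = payload
        if len(batch) > self.max_items:
            raise ValueError("payload too large")
        self.accepted.append((collection_name, len(batch)))


class _StubQueue:
    def __init__(self, max_items):
        self.put = _StubPut(max_items)


async def queued_add_data_points_demo(queue, collection_name, data_points_batch):
    # Same halving fallback as the task above, with the queue passed in explicitly.
    try:
        await queue.put.aio((collection_name, data_points_batch))
    except Exception:
        half = len(data_points_batch) // 2
        await queue.put.aio((collection_name, data_points_batch[:half]))
        await queue.put.aio((collection_name, data_points_batch[half:]))


queue = _StubQueue(max_items=3)
asyncio.run(queued_add_data_points_demo(queue, "chunks", list(range(5))))
print(queue.put.accepted)  # [('chunks', 2), ('chunks', 3)]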

View file

@@ -1,12 +1,12 @@
 async def queued_add_edges(edge_batch):
-    from ..queues import save_data_points_queue
+    from ..queues import add_nodes_and_edges_queue

     try:
-        await save_data_points_queue.put.aio(([], edge_batch))
+        await add_nodes_and_edges_queue.put.aio(([], edge_batch))
     except Exception:
         first_half, second_half = (
             edge_batch[: len(edge_batch) // 2],
             edge_batch[len(edge_batch) // 2 :],
         )
-        await save_data_points_queue.put.aio(([], first_half))
-        await save_data_points_queue.put.aio(([], second_half))
+        await add_nodes_and_edges_queue.put.aio(([], first_half))
+        await add_nodes_and_edges_queue.put.aio(([], second_half))

View file

@@ -1,12 +1,12 @@
 async def queued_add_nodes(node_batch):
-    from ..queues import save_data_points_queue
+    from ..queues import add_nodes_and_edges_queue

     try:
-        await save_data_points_queue.put.aio((node_batch, []))
+        await add_nodes_and_edges_queue.put.aio((node_batch, []))
     except Exception:
         first_half, second_half = (
             node_batch[: len(node_batch) // 2],
             node_batch[len(node_batch) // 2 :],
         )
-        await save_data_points_queue.put.aio((first_half, []))
-        await save_data_points_queue.put.aio((second_half, []))
+        await add_nodes_and_edges_queue.put.aio((first_half, []))
+        await add_nodes_and_edges_queue.put.aio((second_half, []))

View file

@@ -0,0 +1,51 @@
+import modal
+import asyncio
+
+from distributed.app import app
+from distributed.modal_image import image
+from distributed.queues import add_data_points_queue
+
+from cognee.shared.logging_utils import get_logger
+from cognee.infrastructure.databases.vector import get_vector_engine
+
+logger = get_logger("data_point_saving_worker")
+
+
+@app.function(
+    image=image,
+    timeout=86400,
+    max_containers=5,
+    secrets=[modal.Secret.from_name("distributed_cognee")],
+)
+async def data_point_saving_worker():
+    print("Started processing of data points; starting vector engine queue.")
+    vector_engine = get_vector_engine()
+
+    while True:
+        if await add_data_points_queue.len.aio() != 0:
+            try:
+                add_data_points_request = await add_data_points_queue.get.aio(block=False)
+            except modal.exception.DeserializationError as error:
+                logger.error(f"Deserialization error: {str(error)}")
+                continue
+
+            if len(add_data_points_request) == 0:
+                print("Finished processing all data points; stopping vector engine queue.")
+                return True
+
+            if len(add_data_points_request) == 2:
+                (collection_name, data_points) = add_data_points_request
+
+                print(f"Adding {len(data_points)} data points to '{collection_name}' collection.")
+                await vector_engine.create_data_points(
+                    collection_name, data_points, distributed=False
+                )
+                print("Finished adding data points.")
+        else:
+            print("No jobs, go to sleep.")
+            await asyncio.sleep(5)

View file

@@ -4,29 +4,29 @@ import asyncio
 from distributed.app import app
 from distributed.modal_image import image
-from distributed.queues import save_data_points_queue
+from distributed.queues import add_nodes_and_edges_queue

 from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.graph import get_graph_engine

-logger = get_logger("data_point_saver_worker")
+logger = get_logger("graph_saving_worker")


 @app.function(
     image=image,
     timeout=86400,
-    max_containers=100,
+    max_containers=5,
     secrets=[modal.Secret.from_name("distributed_cognee")],
 )
-async def data_point_saver_worker():
+async def graph_saving_worker():
     print("Started processing of nodes and edges; starting graph engine queue.")
     graph_engine = await get_graph_engine()

     while True:
-        if await save_data_points_queue.len.aio() != 0:
+        if await add_nodes_and_edges_queue.len.aio() != 0:
             try:
-                nodes_and_edges = await save_data_points_queue.get.aio(block=False)
+                nodes_and_edges = await add_nodes_and_edges_queue.get.aio(block=False)
             except modal.exception.DeserializationError as error:
                 logger.error(f"Deserialization error: {str(error)}")
                 continue
@@ -37,7 +37,7 @@ async def data_point_saver_worker():
             if len(nodes_and_edges) == 2:
                 print(
-                    f"Processing {len(nodes_and_edges[0])} nodes and {len(nodes_and_edges[1])} edges."
+                    f"Adding {len(nodes_and_edges[0])} nodes and {len(nodes_and_edges[1])} edges."
                 )
                 nodes = nodes_and_edges[0]
                 edges = nodes_and_edges[1]
@@ -47,7 +47,7 @@ async def data_point_saver_worker():
                 if edges:
                     await graph_engine.add_edges(edges, distributed=False)

-                print("Finished processing nodes and edges.")
+                print("Finished adding nodes and edges.")
         else:
             print("No jobs, go to sleep.")