fix: add queue for data points saving

Boris Arzentar 2025-07-04 18:26:22 +02:00
parent 4eba76ca1f
commit f8f1bb3576
12 changed files with 105 additions and 38 deletions


@@ -12,13 +12,13 @@ from typing import Optional, Any, List, Dict, Type, Tuple
from cognee.infrastructure.engine import DataPoint
from cognee.shared.logging_utils import get_logger, ERROR
from cognee.infrastructure.databases.graph.utils import override_distributed
from cognee.infrastructure.databases.graph.graph_db_interface import (
GraphDBInterface,
record_graph_changes,
)
from cognee.modules.storage.utils import JSONEncoder
from distributed.utils import override_distributed
from distributed.tasks.queued_add_nodes import queued_add_nodes
from distributed.tasks.queued_add_edges import queued_add_edges


@@ -1,14 +1,12 @@
import asyncio
from uuid import UUID, uuid4
from sqlalchemy.inspection import inspect
from typing import List, Optional, Union, get_type_hints
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.dialects.postgresql import insert
from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationError
from sqlalchemy import JSON, Column, Table, select, delete, MetaData
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from tenacity import retry, retry_if_exception_type, stop_after_attempt
from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationError
from cognee.exceptions import InvalidValueError
from cognee.shared.logging_utils import get_logger
@@ -16,6 +14,9 @@ from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.relational import get_relational_engine
from distributed.utils import override_distributed
from distributed.tasks.queued_add_data_points import queued_add_data_points
from ...relational.ModelBase import Base
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
from ..utils import normalize_distances
@@ -160,6 +161,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
stop=stop_after_attempt(3),
sleep=1,
)
@override_distributed(queued_add_data_points)
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
data_point_types = get_type_hints(DataPoint)
if not await self.has_collection(collection_name):
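
The override_distributed decorator itself is not part of this diff. Based on how it is applied above and on the distributed=False bypass used by the saving worker further down, a minimal sketch of what it presumably does (the exact signature and kwarg handling are assumptions):

import os
from functools import wraps

def override_distributed(queued_task):
    # Assumed behaviour: in distributed mode, push the write onto a Modal queue
    # via the given queued task instead of executing the adapter method locally.
    def decorator(func):
        @wraps(func)
        async def wrapper(self, *args, distributed=None, **kwargs):
            if distributed is None:
                distributed = os.getenv("COGNEE_DISTRIBUTED", "False").lower() == "true"
            if distributed:
                # e.g. queued_add_data_points(collection_name, data_points)
                return await queued_task(*args, **kwargs)
            return await func(self, *args, **kwargs)
        return wrapper
    return decorator

This would let the worker call create_data_points(collection_name, data_points, distributed=False) to force a direct database write and avoid re-queuing its own work.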


@@ -35,7 +35,6 @@ def override_run_tasks(new_gen):
default_distributed_value = os.getenv("COGNEE_DISTRIBUTED", "False").lower() == "true"
distributed = default_distributed_value if distributed is None else distributed
print(f"run_tasks_distributed: {distributed}")
if distributed:
async for run_info in new_gen(*args, **kwargs):
yield run_info
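
Only a fragment of override_run_tasks appears in this hunk. A hedged reconstruction of the surrounding wrapper, assuming it swaps the local pipeline generator for the distributed one when COGNEE_DISTRIBUTED is enabled:

import os
from functools import wraps

def override_run_tasks(new_gen):
    # Hedged reconstruction: pick the distributed generator (new_gen) or the
    # original one at call time, defaulting to the COGNEE_DISTRIBUTED env var.
    def decorator(original_gen):
        @wraps(original_gen)
        async def wrapper(*args, distributed=None, **kwargs):
            default_distributed_value = os.getenv("COGNEE_DISTRIBUTED", "False").lower() == "true"
            distributed = default_distributed_value if distributed is None else distributed
            if distributed:
                async for run_info in new_gen(*args, **kwargs):
                    yield run_info
            else:
                async for run_info in original_gen(*args, **kwargs):
                    yield run_info
        return wrapper
    return decorator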


@@ -66,9 +66,6 @@ async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, co
data_count = len(data) if isinstance(data, list) else 1
print(f"Data count: {data_count}")
print(f"Data item to process: {type(data)} - {data}")
arguments = [
[tasks] * data_count,
[[data_item] for data_item in data[:data_count]] if data_count > 1 else [data],


@@ -7,8 +7,9 @@ from cognee.shared.logging_utils import get_logger
from cognee.modules.engine.operations.setup import setup
from distributed.app import app
from distributed.queues import save_data_points_queue
from distributed.workers.data_point_saver_worker import data_point_saver_worker
from distributed.queues import add_nodes_and_edges_queue, add_data_points_queue
from distributed.workers.graph_saving_worker import graph_saving_worker
from distributed.workers.data_point_saving_worker import data_point_saving_worker
logger = get_logger()
@@ -19,9 +20,10 @@ os.environ["COGNEE_DISTRIBUTED"] = "True"
@app.local_entrypoint()
async def main():
# Clear queues
await save_data_points_queue.clear.aio()
await add_nodes_and_edges_queue.clear.aio()
number_of_data_saving_workers = 1 # Total number of data_point_saver_worker functions to spawn
number_of_graph_saving_workers = 1 # Total number of graph_saving_worker to spawn
number_of_data_point_saving_workers = 2 # Total number of data_point_saving_worker functions to spawn
results = []
consumer_futures = []
@@ -32,9 +34,14 @@ async def main():
await setup()
# Start data_point_saver_worker functions
for _ in range(number_of_data_saving_workers):
worker_future = data_point_saver_worker.spawn()
# Start graph_saving_worker functions
for _ in range(number_of_graph_saving_workers):
worker_future = graph_saving_worker.spawn()
consumer_futures.append(worker_future)
# Start data_point_saving_worker functions
for _ in range(number_of_data_point_saving_workers):
worker_future = data_point_saving_worker.spawn()
consumer_futures.append(worker_future)
s3_bucket_name = "s3://s3-test-laszlo/Database for KG v1"
@@ -44,11 +51,12 @@
await cognee.cognify(datasets=["s3-files"])
# Push empty tuple into the queue to signal the end of data.
await save_data_points_queue.put.aio(())
await add_nodes_and_edges_queue.put.aio(())
await add_data_points_queue.put.aio(())
for consumer_future in consumer_futures:
try:
print("Finished but waiting for saving worker to finish.")
print("Finished but waiting for saving workers to finish.")
consumer_final = consumer_future.get()
print(f"All workers are done: {consumer_final}")
except Exception as e:


@@ -1,7 +1,5 @@
from modal import Queue
# Create (or get) queues:
# - save_data_points_queue: Stores messages produced by the producer functions.
save_data_points_queue = Queue.from_name("save_data_points_queue", create_if_missing=True)
add_nodes_and_edges_queue = Queue.from_name("add_nodes_and_edges_queue", create_if_missing=True)
add_data_points_queue = Queue.from_name("add_data_points_queue", create_if_missing=True)
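
Both queues are named, shared Modal queues, so any container in the app can produce to or consume from them. A small hedged usage sketch built only from the queue calls that appear elsewhere in this commit (put.aio, get.aio, len.aio):

from distributed.queues import add_data_points_queue

async def enqueue_batch(collection_name, data_points_batch):
    # Producer side: the queued task pushes a (collection, batch) tuple.
    await add_data_points_queue.put.aio((collection_name, data_points_batch))

async def try_dequeue_batch():
    # Consumer side: non-blocking read, as used by data_point_saving_worker;
    # an empty tuple is the shutdown sentinel pushed by the entrypoint.
    if await add_data_points_queue.len.aio() != 0:
        return await add_data_points_queue.get.aio(block=False)
    return None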


@@ -0,0 +1,12 @@
async def queued_add_data_points(collection_name, data_points_batch):
from ..queues import add_data_points_queue
try:
await add_data_points_queue.put.aio((collection_name, data_points_batch))
except Exception:
first_half, second_half = (
data_points_batch[: len(data_points_batch) // 2],
data_points_batch[len(data_points_batch) // 2 :],
)
await add_data_points_queue.put.aio((collection_name, first_half))
await add_data_points_queue.put.aio((collection_name, second_half))
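
The except branch halves the batch and re-enqueues the two halves, presumably to get an oversized payload under the queue's per-item limit. A single split may still not be enough for very large batches; a hypothetical recursive variant (not part of this commit) would keep halving until each chunk fits:

from distributed.queues import add_data_points_queue

async def queued_add_data_points_recursive(collection_name, data_points_batch):
    # Hypothetical sketch: retry with progressively smaller halves on failure.
    try:
        await add_data_points_queue.put.aio((collection_name, data_points_batch))
    except Exception:
        if len(data_points_batch) <= 1:
            raise
        middle = len(data_points_batch) // 2
        await queued_add_data_points_recursive(collection_name, data_points_batch[:middle])
        await queued_add_data_points_recursive(collection_name, data_points_batch[middle:])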


@@ -1,12 +1,12 @@
async def queued_add_edges(edge_batch):
from ..queues import save_data_points_queue
from ..queues import add_nodes_and_edges_queue
try:
await save_data_points_queue.put.aio(([], edge_batch))
await add_nodes_and_edges_queue.put.aio(([], edge_batch))
except Exception:
first_half, second_half = (
edge_batch[: len(edge_batch) // 2],
edge_batch[len(edge_batch) // 2 :],
)
await save_data_points_queue.put.aio(([], first_half))
await save_data_points_queue.put.aio(([], second_half))
await add_nodes_and_edges_queue.put.aio(([], first_half))
await add_nodes_and_edges_queue.put.aio(([], second_half))


@@ -1,12 +1,12 @@
async def queued_add_nodes(node_batch):
from ..queues import save_data_points_queue
from ..queues import add_nodes_and_edges_queue
try:
await save_data_points_queue.put.aio((node_batch, []))
await add_nodes_and_edges_queue.put.aio((node_batch, []))
except Exception:
first_half, second_half = (
node_batch[: len(node_batch) // 2],
node_batch[len(node_batch) // 2 :],
)
await save_data_points_queue.put.aio((first_half, []))
await save_data_points_queue.put.aio((second_half, []))
await add_nodes_and_edges_queue.put.aio((first_half, []))
await add_nodes_and_edges_queue.put.aio((second_half, []))


@@ -0,0 +1,51 @@
import modal
import asyncio
from distributed.app import app
from distributed.modal_image import image
from distributed.queues import add_data_points_queue
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine
logger = get_logger("data_point_saving_worker")
@app.function(
image=image,
timeout=86400,
max_containers=5,
secrets=[modal.Secret.from_name("distributed_cognee")],
)
async def data_point_saving_worker():
print("Started processing of data points; starting vector engine queue.")
vector_engine = get_vector_engine()
while True:
if await add_data_points_queue.len.aio() != 0:
try:
add_data_points_request = await add_data_points_queue.get.aio(block=False)
except modal.exception.DeserializationError as error:
logger.error(f"Deserialization error: {str(error)}")
continue
if len(add_data_points_request) == 0:
print("Finished processing all data points; stopping vector engine queue.")
return True
if len(add_data_points_request) == 2:
(collection_name, data_points) = add_data_points_request
print(f"Adding {len(data_points)} data points to '{collection_name}' collection.")
await vector_engine.create_data_points(
collection_name, data_points, distributed=False
)
print("Finished adding data points.")
else:
print("No jobs, go to sleep.")
await asyncio.sleep(5)


@@ -4,29 +4,29 @@ import asyncio
from distributed.app import app
from distributed.modal_image import image
from distributed.queues import save_data_points_queue
from distributed.queues import add_nodes_and_edges_queue
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph import get_graph_engine
logger = get_logger("data_point_saver_worker")
logger = get_logger("graph_saving_worker")
@app.function(
image=image,
timeout=86400,
max_containers=100,
max_containers=5,
secrets=[modal.Secret.from_name("distributed_cognee")],
)
async def data_point_saver_worker():
async def graph_saving_worker():
print("Started processing of nodes and edges; starting graph engine queue.")
graph_engine = await get_graph_engine()
while True:
if await save_data_points_queue.len.aio() != 0:
if await add_nodes_and_edges_queue.len.aio() != 0:
try:
nodes_and_edges = await save_data_points_queue.get.aio(block=False)
nodes_and_edges = await add_nodes_and_edges_queue.get.aio(block=False)
except modal.exception.DeserializationError as error:
logger.error(f"Deserialization error: {str(error)}")
continue
@@ -37,7 +37,7 @@ async def data_point_saver_worker():
if len(nodes_and_edges) == 2:
print(
f"Processing {len(nodes_and_edges[0])} nodes and {len(nodes_and_edges[1])} edges."
f"Adding {len(nodes_and_edges[0])} nodes and {len(nodes_and_edges[1])} edges."
)
nodes = nodes_and_edges[0]
edges = nodes_and_edges[1]
@@ -47,7 +47,7 @@ async def data_point_saver_worker():
if edges:
await graph_engine.add_edges(edges, distributed=False)
print("Finished processing nodes and edges.")
print("Finished adding nodes and edges.")
else:
print("No jobs, go to sleep.")