fix: add queue for data points saving

parent 4eba76ca1f
commit f8f1bb3576
12 changed files with 105 additions and 38 deletions
@@ -12,13 +12,13 @@ from typing import Optional, Any, List, Dict, Type, Tuple
 from cognee.infrastructure.engine import DataPoint
 from cognee.shared.logging_utils import get_logger, ERROR
-from cognee.infrastructure.databases.graph.utils import override_distributed
 from cognee.infrastructure.databases.graph.graph_db_interface import (
     GraphDBInterface,
     record_graph_changes,
 )
 from cognee.modules.storage.utils import JSONEncoder

+from distributed.utils import override_distributed
 from distributed.tasks.queued_add_nodes import queued_add_nodes
 from distributed.tasks.queued_add_edges import queued_add_edges

@@ -1,14 +1,12 @@
 import asyncio
 from uuid import UUID, uuid4
-from sqlalchemy.inspection import inspect
 from typing import List, Optional, Union, get_type_hints
-
+from sqlalchemy.inspection import inspect
 from sqlalchemy.orm import Mapped, mapped_column
 from sqlalchemy.dialects.postgresql import insert
-from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationError
 from sqlalchemy import JSON, Column, Table, select, delete, MetaData
 from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
 from tenacity import retry, retry_if_exception_type, stop_after_attempt
-
+from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationError
 from cognee.exceptions import InvalidValueError
 from cognee.shared.logging_utils import get_logger
@@ -16,6 +14,9 @@ from cognee.infrastructure.engine import DataPoint
 from cognee.infrastructure.engine.utils import parse_id
 from cognee.infrastructure.databases.relational import get_relational_engine

+from distributed.utils import override_distributed
+from distributed.tasks.queued_add_data_points import queued_add_data_points
+
 from ...relational.ModelBase import Base
 from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
 from ..utils import normalize_distances
@@ -160,6 +161,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         stop=stop_after_attempt(3),
         sleep=1,
     )
+    @override_distributed(queued_add_data_points)
     async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
         data_point_types = get_type_hints(DataPoint)
         if not await self.has_collection(collection_name):
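Note: override_distributed itself is not touched by this diff; the hunks above only import it from distributed.utils and apply it to create_data_points. A minimal sketch of what such a decorator plausibly looks like, assuming it keys off the COGNEE_DISTRIBUTED flag used by override_run_tasks below and the explicit distributed=False override the workers pass; any name not appearing in the diff is hypothetical:

import os
from functools import wraps


def override_distributed(queued_fn):
    # Sketch only: route a method call to its queued counterpart when running distributed.
    def decorator(method):
        @wraps(method)
        async def wrapper(self, *args, distributed=None, **kwargs):
            # Fall back to the COGNEE_DISTRIBUTED env flag when no explicit override is given.
            if distributed is None:
                distributed = os.getenv("COGNEE_DISTRIBUTED", "False").lower() == "true"

            if distributed:
                # e.g. queued_add_data_points(collection_name, data_points_batch)
                return await queued_fn(*args, **kwargs)

            return await method(self, *args, **kwargs)

        return wrapper

    return decorator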
@@ -35,7 +35,6 @@ def override_run_tasks(new_gen):
         default_distributed_value = os.getenv("COGNEE_DISTRIBUTED", "False").lower() == "true"
         distributed = default_distributed_value if distributed is None else distributed

-        print(f"run_tasks_distributed: {distributed}")
         if distributed:
             async for run_info in new_gen(*args, **kwargs):
                 yield run_info
@@ -66,9 +66,6 @@ async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, co

     data_count = len(data) if isinstance(data, list) else 1

-    print(f"Data count: {data_count}")
-    print(f"Data item to process: {type(data)} - {data}")
-
     arguments = [
         [tasks] * data_count,
         [[data_item] for data_item in data[:data_count]] if data_count > 1 else [data],
@@ -7,8 +7,9 @@ from cognee.shared.logging_utils import get_logger
 from cognee.modules.engine.operations.setup import setup

 from distributed.app import app
-from distributed.queues import save_data_points_queue
-from distributed.workers.data_point_saver_worker import data_point_saver_worker
+from distributed.queues import add_nodes_and_edges_queue, add_data_points_queue
+from distributed.workers.graph_saving_worker import graph_saving_worker
+from distributed.workers.data_point_saving_worker import data_point_saving_worker

 logger = get_logger()

@@ -19,9 +20,10 @@ os.environ["COGNEE_DISTRIBUTED"] = "True"
 @app.local_entrypoint()
 async def main():
     # Clear queues
-    await save_data_points_queue.clear.aio()
+    await add_nodes_and_edges_queue.clear.aio()

-    number_of_data_saving_workers = 1  # Total number of data_point_saver_worker functions to spawn
+    number_of_graph_saving_workers = 1  # Total number of graph_saving_worker functions to spawn
+    number_of_data_point_saving_workers = 2  # Total number of data_point_saving_worker functions to spawn

     results = []
     consumer_futures = []
@@ -32,9 +34,14 @@ async def main():

     await setup()

-    # Start data_point_saver_worker functions
-    for _ in range(number_of_data_saving_workers):
-        worker_future = data_point_saver_worker.spawn()
+    # Start graph_saving_worker functions
+    for _ in range(number_of_graph_saving_workers):
+        worker_future = graph_saving_worker.spawn()
         consumer_futures.append(worker_future)

+    # Start data_point_saving_worker functions
+    for _ in range(number_of_data_point_saving_workers):
+        worker_future = data_point_saving_worker.spawn()
+        consumer_futures.append(worker_future)
+
     s3_bucket_name = "s3://s3-test-laszlo/Database for KG v1"
@@ -44,11 +51,12 @@ async def main():
     await cognee.cognify(datasets=["s3-files"])

     # Push empty tuple into the queue to signal the end of data.
-    await save_data_points_queue.put.aio(())
+    await add_nodes_and_edges_queue.put.aio(())
+    await add_data_points_queue.put.aio(())

     for consumer_future in consumer_futures:
         try:
-            print("Finished but waiting for saving worker to finish.")
+            print("Finished but waiting for saving workers to finish.")
             consumer_final = consumer_future.get()
             print(f"All workers are done: {consumer_final}")
         except Exception as e:
@@ -1,7 +1,5 @@
 from modal import Queue


-# Create (or get) queues:
-# - save_data_points_queue: Stores messages produced by the producer functions.
-
-save_data_points_queue = Queue.from_name("save_data_points_queue", create_if_missing=True)
+add_nodes_and_edges_queue = Queue.from_name("add_nodes_and_edges_queue", create_if_missing=True)
+add_data_points_queue = Queue.from_name("add_data_points_queue", create_if_missing=True)
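The Modal queue calls used throughout this commit (put.aio, get.aio, len.aio, clear.aio) and the empty-tuple sentinel pushed by the entrypoint make the producer/consumer handshake easy to exercise in isolation. A minimal sketch, reusing the queue defined above; the vector_engine argument and batch contents are placeholders rather than anything defined in this commit:

import asyncio

from distributed.queues import add_data_points_queue


async def produce(collection_name, batches):
    # Producer side: one message per batch, then an empty tuple as the end-of-data sentinel.
    for batch in batches:
        await add_data_points_queue.put.aio((collection_name, batch))
    await add_data_points_queue.put.aio(())


async def drain(vector_engine):
    # Consumer side: the same poll-and-sentinel loop data_point_saving_worker uses below.
    while True:
        if await add_data_points_queue.len.aio() == 0:
            await asyncio.sleep(1)
            continue
        message = await add_data_points_queue.get.aio(block=False)
        if len(message) == 0:
            return  # sentinel received, stop consuming
        collection_name, data_points = message
        await vector_engine.create_data_points(collection_name, data_points, distributed=False)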
distributed/tasks/queued_add_data_points.py (new file, 12 lines)
@@ -0,0 +1,12 @@
+async def queued_add_data_points(collection_name, data_points_batch):
+    from ..queues import add_data_points_queue
+
+    try:
+        await add_data_points_queue.put.aio((collection_name, data_points_batch))
+    except Exception:
+        first_half, second_half = (
+            data_points_batch[: len(data_points_batch) // 2],
+            data_points_batch[len(data_points_batch) // 2 :],
+        )
+        await add_data_points_queue.put.aio((collection_name, first_half))
+        await add_data_points_queue.put.aio((collection_name, second_half))
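The except Exception fallback above presumably exists because a whole batch can exceed the queue's per-item size limit, so it is re-enqueued in two halves. One possible hardening, sketched here but not part of the commit, is to split recursively so that a half which is still too large keeps shrinking:

async def put_with_split(queue, collection_name, batch):
    # Sketch only: halve the batch until each piece can be enqueued, instead of splitting once.
    try:
        await queue.put.aio((collection_name, batch))
    except Exception:
        if len(batch) <= 1:
            raise  # a single data point that still cannot be enqueued cannot be split further
        middle = len(batch) // 2
        await put_with_split(queue, collection_name, batch[:middle])
        await put_with_split(queue, collection_name, batch[middle:])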
@@ -1,12 +1,12 @@
 async def queued_add_edges(edge_batch):
-    from ..queues import save_data_points_queue
+    from ..queues import add_nodes_and_edges_queue

     try:
-        await save_data_points_queue.put.aio(([], edge_batch))
+        await add_nodes_and_edges_queue.put.aio(([], edge_batch))
     except Exception:
         first_half, second_half = (
             edge_batch[: len(edge_batch) // 2],
             edge_batch[len(edge_batch) // 2 :],
         )
-        await save_data_points_queue.put.aio(([], first_half))
-        await save_data_points_queue.put.aio(([], second_half))
+        await add_nodes_and_edges_queue.put.aio(([], first_half))
+        await add_nodes_and_edges_queue.put.aio(([], second_half))
@@ -1,12 +1,12 @@
 async def queued_add_nodes(node_batch):
-    from ..queues import save_data_points_queue
+    from ..queues import add_nodes_and_edges_queue

     try:
-        await save_data_points_queue.put.aio((node_batch, []))
+        await add_nodes_and_edges_queue.put.aio((node_batch, []))
     except Exception:
         first_half, second_half = (
             node_batch[: len(node_batch) // 2],
             node_batch[len(node_batch) // 2 :],
         )
-        await save_data_points_queue.put.aio((first_half, []))
-        await save_data_points_queue.put.aio((second_half, []))
+        await add_nodes_and_edges_queue.put.aio((first_half, []))
+        await add_nodes_and_edges_queue.put.aio((second_half, []))
distributed/workers/data_point_saving_worker.py (new file, 51 lines)
@@ -0,0 +1,51 @@
+import modal
+import asyncio
+
+
+from distributed.app import app
+from distributed.modal_image import image
+from distributed.queues import add_data_points_queue
+
+from cognee.shared.logging_utils import get_logger
+from cognee.infrastructure.databases.vector import get_vector_engine
+
+
+logger = get_logger("data_point_saving_worker")
+
+
+@app.function(
+    image=image,
+    timeout=86400,
+    max_containers=5,
+    secrets=[modal.Secret.from_name("distributed_cognee")],
+)
+async def data_point_saving_worker():
+    print("Started processing of data points; starting vector engine queue.")
+    vector_engine = get_vector_engine()
+
+    while True:
+        if await add_data_points_queue.len.aio() != 0:
+            try:
+                add_data_points_request = await add_data_points_queue.get.aio(block=False)
+            except modal.exception.DeserializationError as error:
+                logger.error(f"Deserialization error: {str(error)}")
+                continue
+
+            if len(add_data_points_request) == 0:
+                print("Finished processing all data points; stopping vector engine queue.")
+                return True
+
+            if len(add_data_points_request) == 2:
+                (collection_name, data_points) = add_data_points_request
+
+                print(f"Adding {len(data_points)} data points to '{collection_name}' collection.")
+
+                await vector_engine.create_data_points(
+                    collection_name, data_points, distributed=False
+                )
+
+                print("Finished adding data points.")
+
+        else:
+            print("No jobs, go to sleep.")
+            await asyncio.sleep(5)
@@ -4,29 +4,29 @@ import asyncio

 from distributed.app import app
 from distributed.modal_image import image
-from distributed.queues import save_data_points_queue
+from distributed.queues import add_nodes_and_edges_queue

 from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.graph import get_graph_engine


-logger = get_logger("data_point_saver_worker")
+logger = get_logger("graph_saving_worker")


 @app.function(
     image=image,
     timeout=86400,
-    max_containers=100,
+    max_containers=5,
     secrets=[modal.Secret.from_name("distributed_cognee")],
 )
-async def data_point_saver_worker():
+async def graph_saving_worker():
     print("Started processing of nodes and edges; starting graph engine queue.")
     graph_engine = await get_graph_engine()

     while True:
-        if await save_data_points_queue.len.aio() != 0:
+        if await add_nodes_and_edges_queue.len.aio() != 0:
             try:
-                nodes_and_edges = await save_data_points_queue.get.aio(block=False)
+                nodes_and_edges = await add_nodes_and_edges_queue.get.aio(block=False)
             except modal.exception.DeserializationError as error:
                 logger.error(f"Deserialization error: {str(error)}")
                 continue
@@ -37,7 +37,7 @@ async def data_point_saver_worker():

             if len(nodes_and_edges) == 2:
                 print(
-                    f"Processing {len(nodes_and_edges[0])} nodes and {len(nodes_and_edges[1])} edges."
+                    f"Adding {len(nodes_and_edges[0])} nodes and {len(nodes_and_edges[1])} edges."
                 )
                 nodes = nodes_and_edges[0]
                 edges = nodes_and_edges[1]
@@ -47,7 +47,7 @@ async def data_point_saver_worker():

                 if edges:
                     await graph_engine.add_edges(edges, distributed=False)
-                print("Finished processing nodes and edges.")
+                print("Finished adding nodes and edges.")

         else:
             print("No jobs, go to sleep.")