fix: add queue for saving data points
This commit is contained in:
parent 4eba76ca1f
commit f8f1bb3576
12 changed files with 105 additions and 38 deletions

@@ -12,13 +12,13 @@ from typing import Optional, Any, List, Dict, Type, Tuple
 from cognee.infrastructure.engine import DataPoint
 from cognee.shared.logging_utils import get_logger, ERROR
-from cognee.infrastructure.databases.graph.utils import override_distributed
 from cognee.infrastructure.databases.graph.graph_db_interface import (
     GraphDBInterface,
     record_graph_changes,
 )
 from cognee.modules.storage.utils import JSONEncoder

+from distributed.utils import override_distributed
 from distributed.tasks.queued_add_nodes import queued_add_nodes
 from distributed.tasks.queued_add_edges import queued_add_edges


@@ -1,14 +1,12 @@
 import asyncio
-from uuid import UUID, uuid4
-from sqlalchemy.inspection import inspect
 from typing import List, Optional, Union, get_type_hints
+from sqlalchemy.inspection import inspect
 from sqlalchemy.orm import Mapped, mapped_column
 from sqlalchemy.dialects.postgresql import insert
-from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationError
 from sqlalchemy import JSON, Column, Table, select, delete, MetaData
 from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
 from tenacity import retry, retry_if_exception_type, stop_after_attempt
+from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationError

 from cognee.exceptions import InvalidValueError
 from cognee.shared.logging_utils import get_logger
@@ -16,6 +14,9 @@ from cognee.infrastructure.engine import DataPoint
 from cognee.infrastructure.engine.utils import parse_id
 from cognee.infrastructure.databases.relational import get_relational_engine

+from distributed.utils import override_distributed
+from distributed.tasks.queued_add_data_points import queued_add_data_points
+
 from ...relational.ModelBase import Base
 from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
 from ..utils import normalize_distances
@@ -160,6 +161,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         stop=stop_after_attempt(3),
         sleep=1,
     )
+    @override_distributed(queued_add_data_points)
     async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
         data_point_types = get_type_hints(DataPoint)
         if not await self.has_collection(collection_name):
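
Note: the override_distributed decorator applied above is imported from distributed.utils, and its implementation is not part of this diff. Below is a minimal sketch of how such a decorator could behave, inferred from the COGNEE_DISTRIBUTED handling in the override_run_tasks hunk that follows and from the distributed=False calls the workers make; the wrapper body and its handling of self are assumptions, not the repository's actual code.

import os
from functools import wraps


def override_distributed(queued_task):
    # Sketch only: route the call to a queue-backed task when distributed mode is on.
    def decorator(func):
        @wraps(func)
        async def wrapper(self, *args, distributed=None, **kwargs):
            default = os.getenv("COGNEE_DISTRIBUTED", "False").lower() == "true"
            distributed = default if distributed is None else distributed
            if distributed:
                # Hand the batch to the queue-backed task instead of writing locally;
                # workers later call the real method with distributed=False.
                return await queued_task(*args, **kwargs)
            return await func(self, *args, **kwargs)
        return wrapper
    return decorator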

@@ -35,7 +35,6 @@ def override_run_tasks(new_gen):
        default_distributed_value = os.getenv("COGNEE_DISTRIBUTED", "False").lower() == "true"
        distributed = default_distributed_value if distributed is None else distributed

-       print(f"run_tasks_distributed: {distributed}")
        if distributed:
            async for run_info in new_gen(*args, **kwargs):
                yield run_info
@@ -66,9 +66,6 @@ async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, co

     data_count = len(data) if isinstance(data, list) else 1

-    print(f"Data count: {data_count}")
-    print(f"Data item to process: {type(data)} - {data}")
-
     arguments = [
         [tasks] * data_count,
         [[data_item] for data_item in data[:data_count]] if data_count > 1 else [data],

@@ -7,8 +7,9 @@ from cognee.shared.logging_utils import get_logger
 from cognee.modules.engine.operations.setup import setup

 from distributed.app import app
-from distributed.queues import save_data_points_queue
-from distributed.workers.data_point_saver_worker import data_point_saver_worker
+from distributed.queues import add_nodes_and_edges_queue, add_data_points_queue
+from distributed.workers.graph_saving_worker import graph_saving_worker
+from distributed.workers.data_point_saving_worker import data_point_saving_worker


 logger = get_logger()
@@ -19,9 +20,10 @@ os.environ["COGNEE_DISTRIBUTED"] = "True"
 @app.local_entrypoint()
 async def main():
     # Clear queues
-    await save_data_points_queue.clear.aio()
+    await add_nodes_and_edges_queue.clear.aio()

-    number_of_data_saving_workers = 1  # Total number of data_point_saver_worker functions to spawn
+    number_of_graph_saving_workers = 1  # Total number of graph_saving_worker functions to spawn
+    number_of_data_point_saving_workers = 2  # Total number of data_point_saving_worker functions to spawn

     results = []
     consumer_futures = []
@@ -32,9 +34,14 @@ async def main():

     await setup()

-    # Start data_point_saver_worker functions
-    for _ in range(number_of_data_saving_workers):
-        worker_future = data_point_saver_worker.spawn()
+    # Start graph_saving_worker functions
+    for _ in range(number_of_graph_saving_workers):
+        worker_future = graph_saving_worker.spawn()
+        consumer_futures.append(worker_future)
+
+    # Start data_point_saving_worker functions
+    for _ in range(number_of_data_point_saving_workers):
+        worker_future = data_point_saving_worker.spawn()
         consumer_futures.append(worker_future)

     s3_bucket_name = "s3://s3-test-laszlo/Database for KG v1"
@@ -44,11 +51,12 @@ async def main():
     await cognee.cognify(datasets=["s3-files"])

     # Push empty tuple into the queue to signal the end of data.
-    await save_data_points_queue.put.aio(())
+    await add_nodes_and_edges_queue.put.aio(())
+    await add_data_points_queue.put.aio(())

     for consumer_future in consumer_futures:
         try:
-            print("Finished but waiting for saving worker to finish.")
+            print("Finished but waiting for saving workers to finish.")
             consumer_final = consumer_future.get()
             print(f"All workers are done: {consumer_final}")
         except Exception as e:
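
The empty tuples pushed above act as shutdown sentinels: data_point_saving_worker returns True when it pops a zero-length message, and graph_saving_worker presumably does the same for its queue, which is what lets the consumer_future.get() calls in this hunk complete.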

@@ -1,7 +1,5 @@
 from modal import Queue


-# Create (or get) queues:
-# - save_data_points_queue: Stores messages produced by the producer functions.
-
-save_data_points_queue = Queue.from_name("save_data_points_queue", create_if_missing=True)
+add_nodes_and_edges_queue = Queue.from_name("add_nodes_and_edges_queue", create_if_missing=True)
+add_data_points_queue = Queue.from_name("add_data_points_queue", create_if_missing=True)
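
For orientation, a minimal usage sketch of the Modal queues defined above, using only queue calls that appear elsewhere in this commit; the collection name and the demo wrapper are illustrative assumptions:

import asyncio

from distributed.queues import add_data_points_queue


async def demo(data_points_batch):
    # Producer side: push a (collection_name, batch) tuple; an empty tuple is the shutdown sentinel.
    await add_data_points_queue.put.aio(("example_collection", data_points_batch))
    await add_data_points_queue.put.aio(())

    # Worker side: poll without blocking and drain whatever is queued.
    while await add_data_points_queue.len.aio() != 0:
        request = await add_data_points_queue.get.aio(block=False)
        if len(request) == 0:
            break  # sentinel reached: no more data is coming
        collection_name, batch = request


asyncio.run(demo([]))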
|
distributed/tasks/queued_add_data_points.py (new file, +12)

@@ -0,0 +1,12 @@
+async def queued_add_data_points(collection_name, data_points_batch):
+    from ..queues import add_data_points_queue
+
+    try:
+        await add_data_points_queue.put.aio((collection_name, data_points_batch))
+    except Exception:
+        first_half, second_half = (
+            data_points_batch[: len(data_points_batch) // 2],
+            data_points_batch[len(data_points_batch) // 2 :],
+        )
+        await add_data_points_queue.put.aio((collection_name, first_half))
+        await add_data_points_queue.put.aio((collection_name, second_half))
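
This task mirrors queued_add_nodes and queued_add_edges below: if enqueueing the whole batch fails, presumably because the item exceeds the queue's per-item size limit, the batch is split in half and each half is put on the queue separately. The split happens only once, so the assumption is that a single halving is enough to get under the limit.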
|

@@ -1,12 +1,12 @@
 async def queued_add_edges(edge_batch):
-    from ..queues import save_data_points_queue
+    from ..queues import add_nodes_and_edges_queue

     try:
-        await save_data_points_queue.put.aio(([], edge_batch))
+        await add_nodes_and_edges_queue.put.aio(([], edge_batch))
     except Exception:
         first_half, second_half = (
             edge_batch[: len(edge_batch) // 2],
             edge_batch[len(edge_batch) // 2 :],
         )
-        await save_data_points_queue.put.aio(([], first_half))
-        await save_data_points_queue.put.aio(([], second_half))
+        await add_nodes_and_edges_queue.put.aio(([], first_half))
+        await add_nodes_and_edges_queue.put.aio(([], second_half))

@@ -1,12 +1,12 @@
 async def queued_add_nodes(node_batch):
-    from ..queues import save_data_points_queue
+    from ..queues import add_nodes_and_edges_queue

     try:
-        await save_data_points_queue.put.aio((node_batch, []))
+        await add_nodes_and_edges_queue.put.aio((node_batch, []))
     except Exception:
         first_half, second_half = (
             node_batch[: len(node_batch) // 2],
             node_batch[len(node_batch) // 2 :],
         )
-        await save_data_points_queue.put.aio((first_half, []))
-        await save_data_points_queue.put.aio((second_half, []))
+        await add_nodes_and_edges_queue.put.aio((first_half, []))
+        await add_nodes_and_edges_queue.put.aio((second_half, []))

distributed/workers/data_point_saving_worker.py (new file, +51)

@@ -0,0 +1,51 @@
+import modal
+import asyncio
+
+
+from distributed.app import app
+from distributed.modal_image import image
+from distributed.queues import add_data_points_queue
+
+from cognee.shared.logging_utils import get_logger
+from cognee.infrastructure.databases.vector import get_vector_engine
+
+
+logger = get_logger("data_point_saving_worker")
+
+
+@app.function(
+    image=image,
+    timeout=86400,
+    max_containers=5,
+    secrets=[modal.Secret.from_name("distributed_cognee")],
+)
+async def data_point_saving_worker():
+    print("Started processing of data points; starting vector engine queue.")
+    vector_engine = get_vector_engine()
+
+    while True:
+        if await add_data_points_queue.len.aio() != 0:
+            try:
+                add_data_points_request = await add_data_points_queue.get.aio(block=False)
+            except modal.exception.DeserializationError as error:
+                logger.error(f"Deserialization error: {str(error)}")
+                continue
+
+            if len(add_data_points_request) == 0:
+                print("Finished processing all data points; stopping vector engine queue.")
+                return True
+
+            if len(add_data_points_request) == 2:
+                (collection_name, data_points) = add_data_points_request
+
+                print(f"Adding {len(data_points)} data points to '{collection_name}' collection.")
+
+                await vector_engine.create_data_points(
+                    collection_name, data_points, distributed=False
+                )
+
+                print("Finished adding data points.")
+
+        else:
+            print("No jobs, go to sleep.")
+            await asyncio.sleep(5)
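
Note that the worker calls vector_engine.create_data_points(..., distributed=False), the same method that producers reach through @override_distributed(queued_add_data_points); passing distributed=False presumably forces the local code path, so consuming a batch does not put it straight back on the queue.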
|

@@ -4,29 +4,29 @@ import asyncio

 from distributed.app import app
 from distributed.modal_image import image
-from distributed.queues import save_data_points_queue
+from distributed.queues import add_nodes_and_edges_queue

 from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.graph import get_graph_engine


-logger = get_logger("data_point_saver_worker")
+logger = get_logger("graph_saving_worker")


 @app.function(
     image=image,
     timeout=86400,
-    max_containers=100,
+    max_containers=5,
     secrets=[modal.Secret.from_name("distributed_cognee")],
 )
-async def data_point_saver_worker():
+async def graph_saving_worker():
     print("Started processing of nodes and edges; starting graph engine queue.")
     graph_engine = await get_graph_engine()

     while True:
-        if await save_data_points_queue.len.aio() != 0:
+        if await add_nodes_and_edges_queue.len.aio() != 0:
             try:
-                nodes_and_edges = await save_data_points_queue.get.aio(block=False)
+                nodes_and_edges = await add_nodes_and_edges_queue.get.aio(block=False)
             except modal.exception.DeserializationError as error:
                 logger.error(f"Deserialization error: {str(error)}")
                 continue
@@ -37,7 +37,7 @@ async def data_point_saver_worker():

             if len(nodes_and_edges) == 2:
                 print(
-                    f"Processing {len(nodes_and_edges[0])} nodes and {len(nodes_and_edges[1])} edges."
+                    f"Adding {len(nodes_and_edges[0])} nodes and {len(nodes_and_edges[1])} edges."
                 )
                 nodes = nodes_and_edges[0]
                 edges = nodes_and_edges[1]
@@ -47,7 +47,7 @@ async def data_point_saver_worker():

                 if edges:
                     await graph_engine.add_edges(edges, distributed=False)
-                print("Finished processing nodes and edges.")
+                print("Finished adding nodes and edges.")

         else:
             print("No jobs, go to sleep.")