<!-- .github/pull_request_template.md -->

## Description

This PR fixes the distributed pipeline and makes core changes to the distributed logic.

## Type of Change

<!-- Please check the relevant option -->

- [x] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation update
- [x] Code refactoring
- [x] Performance improvement
- [ ] Other (please specify):

## Changes Made

Fixes the distributed pipeline:

- Changed the spawning logic and added incremental loading to `run_tasks_distributed`
- Added batching to the consumer nodes
- Fixed the consumer stopping criteria by adding a stop signal and its handling (see the sketch below)
- Changed the edge embedding solution to avoid heavy network load in multi-container environments

## Testing

Tested by running a 1 GB dataset on Modal, plus manual verification.

## Screenshots/Videos (if applicable)

None

## Pre-submission Checklist

<!-- Please check all boxes that apply before submitting your PR -->

- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the issue/feature**
- [ ] My code follows the project's coding standards and style guidelines
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## Related Issues

None

## Additional Notes

None

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

---------

Co-authored-by: Boris <boris@topoteretes.com>
Co-authored-by: Boris Arzentar <borisarzentar@gmail.com>
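For context, here is a minimal sketch of the consumer-side stop-signal handling and batching described under Changes Made. It is an assumption of what the worker side might look like, not code from this PR: `consume_data_points`, `save_data_points`, and `BATCH_SIZE` are illustrative names; only `add_data_points_queue` and `QueueSignal` appear in the actual changes.

```python
# Hedged sketch, not this PR's implementation: a queue consumer that batches
# items and shuts down when it sees QueueSignal.STOP.
from distributed.queues import add_data_points_queue  # queue used in this PR
from distributed.signal import QueueSignal  # stop signal added in this PR

BATCH_SIZE = 100  # illustrative; the PR does not state its batch size


async def save_data_points(batch):
    """Stand-in for the real persistence call (hypothetical)."""
    ...


async def consume_data_points():
    batch = []
    while True:
        item = await add_data_points_queue.get.aio()
        if item == QueueSignal.STOP:
            # Re-enqueue the signal so sibling consumers on the same queue
            # also see it and stop (an assumption for the multi-worker case).
            await add_data_points_queue.put.aio(QueueSignal.STOP)
            break
        batch.append(item)
        if len(batch) >= BATCH_SIZE:
            await save_data_points(batch)
            batch = []
    if batch:
        # Flush the final partial batch before shutting down.
        await save_data_points(batch)
```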
Python · 83 lines · 2.6 KiB
```python
import os
import asyncio

import cognee
from cognee.api.v1.prune import prune
from cognee.shared.logging_utils import get_logger
from cognee.modules.engine.operations.setup import setup

from distributed.app import app
from distributed.queues import add_nodes_and_edges_queue, add_data_points_queue
from distributed.workers.graph_saving_worker import graph_saving_worker
from distributed.workers.data_point_saving_worker import data_point_saving_worker
from distributed.signal import QueueSignal

logger = get_logger()

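# Enable cognee's distributed execution path; this must be set before any
# pipeline code runs.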
os.environ["COGNEE_DISTRIBUTED"] = "True"

@app.local_entrypoint()
async def main():
    # Clear queues
    await add_nodes_and_edges_queue.clear.aio()
    await add_data_points_queue.clear.aio()
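    # The queues are shared between runs, so clearing prevents the new workers
    # from consuming stale items left over by a previous run (assumption).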

    number_of_graph_saving_workers = 1  # Total number of graph_saving_worker instances to spawn (MAX 1)
    number_of_data_point_saving_workers = 10  # Total number of data_point_saving_worker instances to spawn (MAX 10)
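    # The cap of one graph_saving_worker presumably keeps graph writes
    # serialized and conflict-free (an assumption; only the MAX values above
    # come from the original code).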

    consumer_futures = []

    await prune.prune_data()  # This prunes the data from the file storage
    # Delete DBs and saved files from the metastore
    await prune.prune_system(metadata=True)

    await setup()

    # Start graph_saving_worker functions
    for _ in range(number_of_graph_saving_workers):
        worker_future = graph_saving_worker.spawn()
        consumer_futures.append(worker_future)

    # Start data_point_saving_worker functions
    for _ in range(number_of_data_point_saving_workers):
        worker_future = data_point_saving_worker.spawn()
        consumer_futures.append(worker_future)
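    # Each .spawn() returns a Modal FunctionCall handle; calling .get() on it
    # later blocks until that worker returns, which is how the consumers are
    # joined at the end of this entrypoint.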

    """Example: Setting and adding an S3 path as input

    s3_bucket_path = os.getenv("S3_BUCKET_PATH")
    s3_data_path = "s3://" + s3_bucket_path

    await cognee.add(s3_data_path, dataset_name="s3-files")
    """
    await cognee.add(
        [
            "Audi is a German car manufacturer",
            "The Netherlands is next to Germany",
            "Berlin is the capital of Germany",
            "The Rhine is a major European river",
            "BMW produces luxury vehicles",
        ],
        dataset_name="s3-files",
    )

    await cognee.cognify(datasets=["s3-files"])
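    # In distributed mode, cognify's tasks push graph nodes/edges and data
    # points onto the two queues above, where the spawned workers consume and
    # persist them (per this PR's description).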

    # Put a processing-end signal into the queues to stop the consumers
    await add_nodes_and_edges_queue.put.aio(QueueSignal.STOP)
    await add_data_points_queue.put.aio(QueueSignal.STOP)
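    # A single STOP is enqueued per queue; with ten data_point_saving_worker
    # consumers, the workers presumably propagate the signal onward so that
    # all of them shut down (an assumption about the worker internals).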

    for consumer_future in consumer_futures:
        try:
            print("Pipeline finished; waiting for the saving workers to finish.")
            consumer_final = consumer_future.get()
            print(f"All workers are done: {consumer_final}")
        except Exception as e:
            logger.error(e)


if __name__ == "__main__":
    asyncio.run(main())
```
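Since `main` is registered with `@app.local_entrypoint()`, the script is normally launched with `modal run`, which executes the entrypoint inside the Modal app's context; the trailing `asyncio.run(main())` guard is there as a fallback for direct invocation.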