cognee/distributed/entrypoint.py
hajdul88 faeca138d9
fix: fixes distributed pipeline (#1454)
<!-- .github/pull_request_template.md -->

## Description
This PR fixes the distributed pipeline and updates core changes in the
distributed-pipeline logic.

## Type of Change
<!-- Please check the relevant option -->
- [x] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [x] Code refactoring
- [x] Performance improvement
- [ ] Other (please specify):

## Changes Made
Fixes the distributed pipeline:
- Changed spawning logic and added incremental loading to
`run_tasks_distributed`
- Added batching to consumer nodes
- Fixed the consumer stopping criteria by adding a stop signal and its handling
- Changed the edge-embedding solution to avoid heavy network load in the case
of a multi-container environment

## Testing
Tested by running a 1 GB dataset on Modal, plus manual verification.

## Screenshots/Videos (if applicable)
None

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [ ] My code follows the project's coding standards and style
guidelines
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## Related Issues
None

## Additional Notes
None

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.

---------

Co-authored-by: Boris <boris@topoteretes.com>
Co-authored-by: Boris Arzentar <borisarzentar@gmail.com>
2025-10-09 14:06:25 +02:00

83 lines
2.6 KiB
Python

import os
import asyncio
import cognee
from cognee.api.v1.prune import prune
from cognee.shared.logging_utils import get_logger
from cognee.modules.engine.operations.setup import setup
from distributed.app import app
from distributed.queues import add_nodes_and_edges_queue, add_data_points_queue
from distributed.workers.graph_saving_worker import graph_saving_worker
from distributed.workers.data_point_saving_worker import data_point_saving_worker
from distributed.signal import QueueSignal
logger = get_logger()
os.environ["COGNEE_DISTRIBUTED"] = "True"
@app.local_entrypoint()
async def main():
# Clear queues
await add_nodes_and_edges_queue.clear.aio()
await add_data_points_queue.clear.aio()
number_of_graph_saving_workers = 1 # Total number of graph_saving_worker to spawn (MAX 1)
number_of_data_point_saving_workers = (
10 # Total number of graph_saving_worker to spawn (MAX 10)
)
consumer_futures = []
await prune.prune_data() # This prunes the data from the file storage
# Delete DBs and saved files from metastore
await prune.prune_system(metadata=True)
await setup()
# Start graph_saving_worker functions
for _ in range(number_of_graph_saving_workers):
worker_future = graph_saving_worker.spawn()
consumer_futures.append(worker_future)
# Start data_point_saving_worker functions
for _ in range(number_of_data_point_saving_workers):
worker_future = data_point_saving_worker.spawn()
consumer_futures.append(worker_future)
""" Example: Setting and adding S3 path as input
s3_bucket_path = os.getenv("S3_BUCKET_PATH")
s3_data_path = "s3://" + s3_bucket_path
await cognee.add(s3_data_path, dataset_name="s3-files")
"""
await cognee.add(
[
"Audi is a German car manufacturer",
"The Netherlands is next to Germany",
"Berlin is the capital of Germany",
"The Rhine is a major European river",
"BMW produces luxury vehicles",
],
dataset_name="s3-files",
)
await cognee.cognify(datasets=["s3-files"])
# Put Processing end signal into the queues to stop the consumers
await add_nodes_and_edges_queue.put.aio(QueueSignal.STOP)
await add_data_points_queue.put.aio(QueueSignal.STOP)
for consumer_future in consumer_futures:
try:
print("Finished but waiting for saving workers to finish.")
consumer_final = consumer_future.get()
print(f"All workers are done: {consumer_final}")
except Exception as e:
logger.error(e)
if __name__ == "__main__":
asyncio.run(main())