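"""Distributed ingestion entrypoint.

Prunes any previously stored data, ingests the files under the local `.data`
directory into a dataset, chunks each document locally, fans the chunks out to
remote graph-extraction (producer) workers, and waits for the data-point saver
(consumer) workers to finish before printing the collected results.
"""
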
import pathlib
from os import path

from cognee.api.v1.add import add
from cognee.api.v1.prune import prune
from cognee.infrastructure.llm.utils import get_max_chunk_tokens
from cognee.modules.chunking.TextChunker import TextChunker
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.modules.data.processing.document_types import Document
from cognee.modules.pipelines.operations.run_tasks import run_tasks
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.users.methods.get_default_user import get_default_user
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
from cognee.modules.data.methods.get_datasets_by_name import get_datasets_by_name

from cognee.shared.logging_utils import get_logger
from cognee.tasks.documents.classify_documents import classify_documents
from cognee.tasks.documents.extract_chunks_from_documents import extract_chunks_from_documents

from distributed.app import app
from distributed.queues import finished_jobs_queue, save_data_points_queue
from distributed.workers.data_point_saver_worker import data_point_saver_worker
from distributed.workers.graph_extraction_worker import graph_extraction_worker

logger = get_logger()


@app.local_entrypoint()
async def main():
    # Clear queues
    finished_jobs_queue.clear()
    save_data_points_queue.clear()

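    # Ingestion configuration: target dataset and the local directory of input files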
    dataset_name = "main"
    data_directory_name = ".data"
    data_directory_path = path.join(pathlib.Path(__file__).parent, data_directory_name)

    number_of_data_saving_workers = 1  # Total number of data_point_saver_worker functions to spawn
    document_batch_size = 50  # Batch size for producers

    results = []
    consumer_futures = []

    # Delete DBs and saved files from metastore
    await prune.prune_data()
    await prune.prune_system(metadata=True)

    # Add files to the metastore
    await add(data=data_directory_path, dataset_name=dataset_name)

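    # Resolve the default user and fetch the documents that were just ingested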
    user = await get_default_user()
    datasets = await get_datasets_by_name(dataset_name, user.id)
    documents = await get_dataset_data(dataset_id=datasets[0].id)

    print(f"We have {len(documents)} documents in the dataset.")

    # Start data_point_saver_worker functions
    for _ in range(number_of_data_saving_workers):
        worker_future = data_point_saver_worker.spawn(total_number_of_workers=len(documents))
        consumer_futures.append(worker_future)

    producer_futures = []

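    # Runs locally as the final pipeline task: spawns a remote graph-extraction
    # job for each batch of chunks and records its future for later collection.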
    def process_chunks_remotely(document_chunks: list[DocumentChunk], document: Document):
        producer_future = graph_extraction_worker.spawn(
            user=user, document_name=document.name, document_chunks=document_chunks
        )
        producer_futures.append(producer_future)
        return producer_future

    # Produce chunks and spawn a graph_extraction_worker job for each batch of chunks
    for i in range(0, len(documents), document_batch_size):
        batch = documents[i : i + document_batch_size]

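        # Run the local pipeline for each document: classify it, split it into
        # chunks, and hand the chunks to process_chunks_remotely (batched via task_config).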
        for item in batch:
            async for _ in run_tasks(
                [
                    Task(classify_documents),
                    Task(
                        extract_chunks_from_documents,
                        max_chunk_size=get_max_chunk_tokens(),
                        chunker=TextChunker,
                    ),
                    Task(
                        process_chunks_remotely,
                        document=item,
                        task_config={"batch_size": 50},
                    ),
                ],
                data=[item],
                user=user,
                pipeline_name="chunk_processing",
            ):
                pass

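        # Block until every producer job spawned for this batch has returned,
        # recording either its result or the exception it raised.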
        batch_results = []
        for producer_future in producer_futures:
            try:
                result = producer_future.get()
            except Exception as e:
                result = e
            batch_results.append(result)

        results.extend(batch_results)
        finished_jobs_queue.put(len(results))
        # Reset so the next batch only waits on its own producer jobs
        producer_futures.clear()

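    # Wait for the data_point_saver_worker (consumer) jobs to finish before printing results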
    for consumer_future in consumer_futures:
        try:
            print("Producer jobs finished; waiting for the saver worker to complete.")
            consumer_final = consumer_future.get()
            print(f"Saver worker finished with result: {consumer_final}")
        except Exception as e:
            logger.error(e)

    print(results)