282 lines
9.4 KiB
Python
282 lines
9.4 KiB
Python
import os
|
|
import json
|
|
import pathlib
|
|
import asyncio
|
|
from typing import Optional
|
|
from pydantic import BaseModel
|
|
from dotenv import dotenv_values
|
|
from cognee.modules.chunking.models import DocumentChunk
|
|
from modal import App, Queue, Image
|
|
|
|
import cognee
|
|
from cognee.shared.logging_utils import get_logger
|
|
from cognee.shared.data_models import KnowledgeGraph
|
|
from cognee.modules.pipelines.operations import run_tasks
|
|
from cognee.modules.users.methods import get_default_user
|
|
from cognee.infrastructure.llm import get_max_chunk_tokens
|
|
from cognee.modules.chunking.TextChunker import TextChunker
|
|
from cognee.modules.data.methods import get_datasets_by_name
|
|
from cognee.modules.cognify.config import get_cognify_config
|
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
|
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
|
|
from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver
|
|
from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_from_model
|
|
from cognee.modules.pipelines.tasks import Task
|
|
|
|
|
|
# Global tasks
|
|
from cognee.tasks.documents import (
|
|
classify_documents,
|
|
extract_chunks_from_documents,
|
|
)
|
|
from cognee.tasks.storage.index_data_points import index_data_points
|
|
|
|
# Local tasks
|
|
from .tasks.extract_graph_from_data import extract_graph_from_data
|
|
from .tasks.summarize_text import summarize_text
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
# App and Queue Initialization
|
|
# ------------------------------------------------------------------------------
|
|
|
|
# Initialize the Modal application
app = App("cognee_modal_distributed")
logger = get_logger("cognee_modal_distributed")

# Load the local .env file so its values can be forwarded to the container image.
# dotenv_values() yields None for keys declared without a value; those entries are
# dropped because container environment values are expected to be strings.
local_env_vars = {
    name: value for name, value in dotenv_values(".env").items() if value is not None
}

logger.info("Modal deployment started with the following environmental variables:")
# Only the variable names are logged; values are masked because .env files
# typically hold API keys and database credentials.
logger.info(json.dumps({name: "***" for name in local_env_vars}, indent=4))
|
|
|
|
# Container image used by both the producer and consumer functions below.
# The build steps are chained in order: start from the local Dockerfile, copy the
# Poetry project files into /root, export the local .env values as environment
# variables, install the Poetry dependencies, then attach the local cognee source.
image = (
    Image.from_dockerfile(path="Dockerfile_modal", force_build=False)
    .add_local_file("pyproject.toml", remote_path="/root/pyproject.toml", copy=True)
    .add_local_file("poetry.lock", remote_path="/root/poetry.lock", copy=True)
    .env(local_env_vars)
    .poetry_install_from_file(poetry_pyproject_toml="pyproject.toml")
    # .pip_install("protobuf", "h2", "neo4j", "asyncpg", "pgvector")
    # NOTE(review): "../cognee" looks like a filesystem path, while this method is
    # documented as taking importable module names — confirm this resolves as intended.
    .add_local_python_source("../cognee")
)
|
|
|
|
|
|
# Create (or get) two named Modal queues:
# - graph_nodes_and_edges: receives the (nodes, edges) tuples that save_data_points
#   puts at the end of each producer pipeline; the consumer function drains it.
# - finished_producers: carries the running count that main() publishes after each
#   producer batch; the consumer compares it with its expected total to know when
#   it can stop.
graph_nodes_and_edges = Queue.from_name("graph_nodes_and_edges", create_if_missing=True)

finished_producers = Queue.from_name("finished_producers", create_if_missing=True)
|
|
|
|
# ------------------------------------------------------------------------------
|
|
# Cognee pipeline steps
|
|
# ------------------------------------------------------------------------------
|
|
|
|
|
|
def add_data_to_save_queue(document_chunks: list[DocumentChunk]):
    """Spawn a remote producer job for one batch of document chunks.

    The previous body referenced `document_name`, `event` and `futures`, none of
    which exist in this scope, so every call raised NameError and the chunks that
    were passed in were never used. The batch received as an argument is now
    forwarded to the producer directly.

    Args:
        document_chunks: Chunks emitted by the preceding pipeline task.

    Returns:
        The handle returned by `producer.spawn`, or None when the batch is empty.
    """
    if not document_chunks:
        # Nothing to process; avoid starting a remote job for an empty batch.
        return None

    # assumes each chunk exposes its source document as `is_part_of` with a `name`
    # attribute — TODO confirm against DocumentChunk; falls back to a placeholder.
    source_document = getattr(document_chunks[0], "is_part_of", None)
    document_name = getattr(source_document, "name", None) or "unknown_document"

    return producer.spawn(file_name=document_name, chunk_list=document_chunks)
|
|
|
|
# Preprocessing steps. This gets called in the entrypoint
|
|
async def get_preprocessing_steps(chunker=TextChunker) -> list[Task]:
    """Build the preprocessing pipeline that the entrypoint runs locally.

    Args:
        chunker: Chunker forwarded to `extract_chunks_from_documents`.

    Returns:
        Three tasks, in execution order: classify the documents, split them into
        chunks capped at `get_max_chunk_tokens()`, and hand the chunks to
        `add_data_to_save_queue` in batches of 50.
    """
    classify_step = Task(classify_documents)

    # Extract text chunks based on the document type.
    chunking_step = Task(
        extract_chunks_from_documents,
        max_chunk_size=get_max_chunk_tokens(),
        chunker=chunker,
    )

    enqueue_step = Task(add_data_to_save_queue, task_config={"batch_size": 50})

    return [classify_step, chunking_step, enqueue_step]
|
|
|
|
|
|
# This is the last step of the pipeline that gets executed on modal executors (functions)
|
|
async def save_data_points(data_points: list = None, data_point_connections: list = None):
    """Convert data points to graph nodes/edges, index them and queue them for saving.

    Each data point is expanded with `get_graph_from_model`; the combined nodes and
    edges are deduplicated, the nodes are passed to `index_data_points`, and a
    `(nodes, edges + data_point_connections)` tuple is put on the
    `graph_nodes_and_edges` queue.

    Args:
        data_points: Data points to expand. None is treated as an empty list
            (previously the default value raised TypeError when iterated).
        data_point_connections: Extra edges appended to the extracted edges.
            None is treated as an empty list.
    """
    data_points = data_points or []
    data_point_connections = data_point_connections or []

    # These dicts are shared by every get_graph_from_model call in this batch;
    # presumably they let the helper skip items it has already emitted — verify
    # against get_graph_from_model.
    added_nodes = {}
    added_edges = {}
    visited_properties = {}

    results = await asyncio.gather(
        *[
            get_graph_from_model(
                data_point,
                added_nodes=added_nodes,
                added_edges=added_edges,
                visited_properties=visited_properties,
            )
            for data_point in data_points
        ]
    )

    nodes = []
    edges = []
    for result_nodes, result_edges in results:
        nodes.extend(result_nodes)
        edges.extend(result_edges)

    nodes, edges = deduplicate_nodes_and_edges(nodes, edges)

    await index_data_points(nodes)

    # Hand the graph payload to the consumer function via the shared queue.
    graph_nodes_and_edges.put((nodes, edges + data_point_connections))
|
|
|
|
|
|
# This is the pipeline for the modal executors
|
|
async def get_graph_tasks(
    graph_model: BaseModel = KnowledgeGraph,
    ontology_file_path: Optional[str] = None,
) -> list[Task]:
    """Build the task pipeline that runs inside each Modal producer.

    Args:
        graph_model: Model forwarded to `extract_graph_from_data`.
        ontology_file_path: Optional ontology file handed to `OntologyResolver`.

    Returns:
        Three tasks, in execution order: graph extraction, text summarization
        (using the summarization model from the cognify config), and
        `save_data_points`.
    """
    config = get_cognify_config()
    resolver = OntologyResolver(ontology_file=ontology_file_path)

    return [
        Task(extract_graph_from_data, graph_model=graph_model, ontology_adapter=resolver),
        Task(summarize_text, summarization_model=config.summarization_model),
        Task(save_data_points),
    ]
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
# Producer Function
|
|
# ------------------------------------------------------------------------------
|
|
|
|
|
|
@app.function(image=image, timeout=86400, max_containers=100)
async def producer(file_name: str, chunk_list: list):
    """Run the graph pipeline over one file's chunks on a Modal container.

    Args:
        file_name: Name of the source file; used in the pipeline name and the log line.
        chunk_list: Chunks passed to the pipeline as its input data.

    Returns:
        `file_name`, once the pipeline has been fully consumed.
    """
    pipeline_run = run_tasks(
        await get_graph_tasks(),
        data=chunk_list,
        pipeline_name=f"modal_execution_file_{file_name}",
    )

    # The yielded values are not needed here; iterating drives the pipeline to completion.
    async for _progress in pipeline_run:
        continue

    print(f"File execution finished: {file_name}")

    return file_name
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
# Consumer Function
|
|
# ------------------------------------------------------------------------------
|
|
|
|
|
|
@app.function(image=image, timeout=86400, max_containers=100)
async def consumer(number_of_files: int):
    """Drain the graph queue into the graph engine until all producers are done.

    Loops forever: while `graph_nodes_and_edges` holds items, each `(nodes, edges)`
    payload is written through the graph engine. When the queue is empty the
    function sleeps 5 seconds and then polls `finished_producers`; a value equal to
    `number_of_files` is the signal to stop.

    Args:
        number_of_files: Count that the finish signal must match before stopping.

    Returns:
        True once the finish signal has been seen and the graph queue is empty.
    """
    graph_engine = await get_graph_engine()

    while True:
        if graph_nodes_and_edges.len() != 0:
            nodes_and_edges = graph_nodes_and_edges.get(block=False)

            if nodes_and_edges is None:
                # The queue reported items but none was returned by the
                # non-blocking get.
                print(f"Nodes and edges are: {nodes_and_edges}")
                continue

            nodes, edges = nodes_and_edges[0], nodes_and_edges[1]
            if nodes is not None:
                await graph_engine.add_nodes(nodes)
            if edges is not None:
                await graph_engine.add_edges(edges)
            continue

        await asyncio.sleep(5)

        number_of_finished_jobs = finished_producers.get(block=False)
        if number_of_finished_jobs != number_of_files:
            continue

        # We put it back for the other consumers to see that we finished
        finished_producers.put(number_of_finished_jobs)

        # Payloads may have arrived while this consumer was sleeping; returning now
        # would leave them unsaved, so go around again and drain them first.
        if graph_nodes_and_edges.len() != 0:
            continue

        print("Finished processing all nodes and edges; stopping graph engine queue.")
        return True
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
# Entrypoint
|
|
# ------------------------------------------------------------------------------
|
|
|
|
|
|
@app.local_entrypoint()
async def main():
    """Run the distributed pipeline end to end from the local machine.

    Clears both queues, prunes existing cognee data and system metadata, adds the
    files under `.data` (next to this module) to the "main" dataset, spawns the
    consumer function(s), then preprocesses the documents in batches of 50 and
    spawns a remote producer for every completed chunk-extraction event. After
    each batch the number of processed documents is published on
    `finished_producers`; finally the consumers are awaited and the collected
    producer results are printed.
    """
    # Clear queues
    graph_nodes_and_edges.clear()
    finished_producers.clear()

    dataset_name = "main"
    data_directory_name = ".data"
    data_directory_path = os.path.join(pathlib.Path(__file__).parent, data_directory_name)

    number_of_consumers = 1  # Total number of consumer functions to spawn
    batch_size = 50  # Batch size for producers

    results = []
    consumer_futures = []

    # Delete DBs and saved files from metastore
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    # Add files to the metastore
    await cognee.add(data=data_directory_path, dataset_name=dataset_name)

    user = await get_default_user()
    datasets = await get_datasets_by_name(dataset_name, user.id)
    documents = await get_dataset_data(dataset_id=datasets[0].id)

    print(f"We have {len(documents)} documents in the dataset.")

    if not documents:
        # With no documents the batch loop never publishes a finish signal, so any
        # spawned consumer would wait until its timeout. Stop before spawning.
        print(results)
        return

    # The only parameter of get_preprocessing_steps is the chunker; passing the
    # user object (as before) replaced the chunker with a User instance.
    preprocessing_tasks = await get_preprocessing_steps()

    # Start consumer functions
    for _ in range(number_of_consumers):
        consumer_future = consumer.spawn(number_of_files=len(documents))
        consumer_futures.append(consumer_future)

    # Process producer jobs in batches
    for batch_start in range(0, len(documents), batch_size):
        batch = documents[batch_start : batch_start + batch_size]
        futures = []
        for item in batch:
            document_name = item.name
            async for event in run_tasks(
                preprocessing_tasks, data=[item], pipeline_name="preprocessing_steps"
            ):
                # NOTE(review): TaskExecutionCompleted is not among the imports visible
                # in this file — confirm where it comes from and import it.
                # NOTE(review): add_data_to_save_queue, the last preprocessing task,
                # also calls producer.spawn — confirm which spawn site is intended.
                if (
                    isinstance(event, TaskExecutionCompleted)
                    and event.task is extract_chunks_from_documents
                ):
                    future = producer.spawn(file_name=document_name, chunk_list=event.result)
                    futures.append(future)

        batch_results = []
        for future in futures:
            try:
                result = future.get()
            except Exception as e:
                # A failed producer is recorded as its exception so the remaining
                # futures in the batch are still collected.
                result = e
            batch_results.append(result)

        results.extend(batch_results)

        # Consumers stop when this value equals len(documents), so publish the
        # number of documents processed so far. Counting futures instead could
        # miss that total whenever a document yields zero or several producers.
        finished_producers.put(batch_start + len(batch))

    for consumer_future in consumer_futures:
        try:
            print("Finished but waiting")
            consumer_final = consumer_future.get()
            print(f"We got all futures{consumer_final}")
        except Exception as e:
            print(e)

    print(results)
|