fix: fixes distributed pipeline (#1454)

<!-- .github/pull_request_template.md -->

## Description

This PR fixes the distributed pipeline and updates core parts of the distributed-run logic.

## Type of Change
<!-- Please check the relevant option -->
- [x] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation update
- [x] Code refactoring
- [x] Performance improvement
- [ ] Other (please specify):

## Changes Made
Fixes the distributed pipeline:
- Changed the spawning logic and added incremental loading to run_tasks_distributed
- Added batching to the consumer nodes
- Fixed the consumer stopping criteria by adding a stop signal and its handling (see the illustrative sketch after this description)
- Changed the edge-embedding approach to avoid heavy network load in multi-container environments

## Testing
Tested by running 1 GB of data on Modal, plus manual verification.

## Screenshots/Videos (if applicable)
None

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the issue/feature**
- [ ] My code follows the project's coding standards and style guidelines
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## Related Issues
None

## Additional Notes
None

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

---------

Co-authored-by: Boris <boris@topoteretes.com>
Co-authored-by: Boris Arzentar <borisarzentar@gmail.com>
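The stop-signal and batching behavior described above is implemented with Modal queues later in this diff (distributed/signal.py and the two saving workers). For orientation, here is a minimal, self-contained asyncio sketch of the same pattern: drain the queue in batches and treat a re-posted STOP sentinel as the shutdown criterion. The local asyncio queue, batch size, and item shapes are illustrative assumptions, not the PR's actual Modal code.

```python
import asyncio
from enum import Enum


class QueueSignal(str, Enum):
    # Mirrors distributed/signal.py: a string enum survives queue serialization.
    STOP = "STOP"


async def consumer(queue: asyncio.Queue, batch_size: int = 25) -> int:
    """Drain up to `batch_size` items per iteration; stop when the STOP sentinel is seen."""
    saved = 0
    while True:
        batch = []
        for _ in range(batch_size):
            try:
                item = queue.get_nowait()
            except asyncio.QueueEmpty:
                break
            if item == QueueSignal.STOP:
                # Re-post the sentinel so sibling consumers also shut down.
                queue.put_nowait(QueueSignal.STOP)
                saved += len(batch)  # flush what was already collected
                return saved
            batch.append(item)
        if batch:
            saved += len(batch)  # stand-in for one bulk write to the graph/vector store
        else:
            await asyncio.sleep(0.1)  # no jobs, go to sleep


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    for i in range(103):
        queue.put_nowait(("collection", f"data_point_{i}"))
    queue.put_nowait(QueueSignal.STOP)

    workers = [asyncio.create_task(consumer(queue)) for _ in range(3)]
    print(await asyncio.gather(*workers))  # per-worker counts depend on scheduling


asyncio.run(main())
```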
This commit is contained in:
parent 01632988fe
commit faeca138d9

23 changed files with 654 additions and 12760 deletions
.github/workflows/distributed_test.yml (vendored, new file, 73 lines)

@@ -0,0 +1,73 @@
name: Distributed Cognee test with modal

permissions:
  contents: read

on:
  workflow_call:
    inputs:
      python-version:
        required: false
        type: string
        default: '3.11.x'
    secrets:
      LLM_MODEL:
        required: true
      LLM_ENDPOINT:
        required: true
      LLM_API_KEY:
        required: true
      LLM_API_VERSION:
        required: true
      EMBEDDING_MODEL:
        required: true
      EMBEDDING_ENDPOINT:
        required: true
      EMBEDDING_API_KEY:
        required: true
      EMBEDDING_API_VERSION:
        required: true
      OPENAI_API_KEY:
        required: true

jobs:
  run-server-start-test:
    name: Distributed Cognee test (Modal)
    runs-on: ubuntu-22.04
    steps:
      - name: Check out
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Cognee Setup
        uses: ./.github/actions/cognee_setup
        with:
          python-version: '3.11.x'
          extra-dependencies: "distributed postgres"

      - name: Run Distributed Cognee (Modal)
        env:
          ENV: 'dev'
          LLM_MODEL: ${{ secrets.LLM_MODEL }}
          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
          MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
          MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
          MODAL_SECRET_NAME: ${{ secrets.MODAL_SECRET_NAME }}
          GRAPH_DATABASE_PROVIDER: "neo4j"
          GRAPH_DATABASE_URL: ${{ secrets.AZURE_NEO4j_URL }}
          GRAPH_DATABASE_USERNAME: ${{ secrets.AZURE_NEO4J_USERNAME }}
          GRAPH_DATABASE_PASSWORD: ${{ secrets.AZURE_NEO4J_PW }}
          DB_PROVIDER: "postgres"
          DB_NAME: ${{ secrets.AZURE_POSTGRES_DB_NAME }}
          DB_HOST: ${{ secrets.AZURE_POSTGRES_HOST }}
          DB_PORT: ${{ secrets.AZURE_POSTGRES_PORT }}
          DB_USERNAME: ${{ secrets.AZURE_POSTGRES_USERNAME }}
          DB_PASSWORD: ${{ secrets.AZURE_POSTGRES_PW }}
          VECTOR_DB_PROVIDER: "pgvector"
          COGNEE_DISTRIBUTED: "true"
        run: uv run modal run ./distributed/entrypoint.py
.github/workflows/test_suites.yml (vendored, 8 lines changed)

@@ -27,6 +27,12 @@ jobs:
    uses: ./.github/workflows/e2e_tests.yml
    secrets: inherit

  distributed-tests:
    name: Distributed Cognee Test
    needs: [ basic-tests, e2e-tests, graph-db-tests ]
    uses: ./.github/workflows/distributed_test.yml
    secrets: inherit

  cli-tests:
    name: CLI Tests
    uses: ./.github/workflows/cli_tests.yml

@@ -104,7 +110,7 @@ jobs:

  db-examples-tests:
    name: DB Examples Tests
    needs: [vector-db-tests, graph-db-tests, relational-db-migration-tests]
    needs: [vector-db-tests, graph-db-tests, relational-db-migration-tests, distributed-tests]
    uses: ./.github/workflows/db_examples_tests.yml
    secrets: inherit
@@ -117,5 +117,4 @@ async def add_rule_associations(data: str, rules_nodeset_name: str):

    if len(edges_to_save) > 0:
        await graph_engine.add_edges(edges_to_save)

        await index_graph_edges()
        await index_graph_edges(edges_to_save)
@@ -68,6 +68,7 @@ class Neo4jAdapter(GraphDBInterface):
            auth=auth,
            max_connection_lifetime=120,
            notifications_min_severity="OFF",
            keep_alive=True,
        )

    async def initialize(self) -> None:

@@ -205,7 +206,7 @@ class Neo4jAdapter(GraphDBInterface):
            {
                "node_id": str(node.id),
                "label": type(node).__name__,
                "properties": self.serialize_properties(node.model_dump()),
                "properties": self.serialize_properties(dict(node)),
            }
            for node in nodes
        ]
@@ -8,7 +8,7 @@ from cognee.infrastructure.utils.calculate_backoff import calculate_backoff
logger = get_logger("deadlock_retry")


def deadlock_retry(max_retries=5):
def deadlock_retry(max_retries=10):
    """
    Decorator that automatically retries an asynchronous function when rate limit errors occur.
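The hunk above only raises the default retry count. For readers unfamiliar with the pattern, here is a minimal sketch of a deadlock-retry decorator for async functions with exponential backoff. The helper name is_deadlock_error and the base delay are assumptions for illustration, not cognee's exact implementation (which uses its own calculate_backoff utility).

```python
import asyncio
import functools
import random


def is_deadlock_error(error: Exception) -> bool:
    # Illustrative check only; the real helper inspects driver-specific error codes.
    return "deadlock" in str(error).lower()


def deadlock_retry(max_retries: int = 10, base_delay: float = 0.5):
    """Retry an async function when a (suspected) deadlock error is raised."""

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            for attempt in range(max_retries):
                try:
                    return await func(*args, **kwargs)
                except Exception as error:
                    if not is_deadlock_error(error) or attempt == max_retries - 1:
                        raise
                    # Exponential backoff with jitter before the next attempt.
                    await asyncio.sleep(base_delay * (2**attempt) + random.random() * 0.1)

        return wrapper

    return decorator
```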
@@ -9,7 +9,10 @@ async def get_dataset_data(dataset_id: UUID) -> list[Data]:

    async with db_engine.get_async_session() as session:
        result = await session.execute(
            select(Data).join(Data.datasets).filter((Dataset.id == dataset_id))
            select(Data)
            .join(Data.datasets)
            .filter((Dataset.id == dataset_id))
            .order_by(Data.data_size.desc())
        )

        data = list(result.scalars().all())
@@ -5,7 +5,6 @@ from typing import Optional

class TableRow(DataPoint):
    name: str
    is_a: Optional[TableType] = None
    description: str
    properties: str
@@ -4,35 +4,28 @@ import asyncio
from uuid import UUID
from typing import Any, List
from functools import wraps
from sqlalchemy import select

import cognee.modules.ingestion as ingestion
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.pipelines.operations.run_tasks_distributed import run_tasks_distributed
from cognee.modules.users.models import User
from cognee.modules.data.models import Data
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.shared.logging_utils import get_logger
from cognee.modules.users.methods import get_default_user
from cognee.modules.pipelines.utils import generate_pipeline_id
from cognee.modules.pipelines.exceptions import PipelineRunFailedError
from cognee.tasks.ingestion import save_data_item_to_storage, resolve_data_directories
from cognee.tasks.ingestion import resolve_data_directories
from cognee.modules.pipelines.models.PipelineRunInfo import (
    PipelineRunCompleted,
    PipelineRunErrored,
    PipelineRunStarted,
    PipelineRunYield,
    PipelineRunAlreadyCompleted,
)
from cognee.modules.pipelines.models.DataItemStatus import DataItemStatus

from cognee.modules.pipelines.operations import (
    log_pipeline_run_start,
    log_pipeline_run_complete,
    log_pipeline_run_error,
)
from .run_tasks_with_telemetry import run_tasks_with_telemetry
from .run_tasks_data_item import run_tasks_data_item
from ..tasks.task import Task


@@ -68,176 +61,6 @@ async def run_tasks(
    context: dict = None,
    incremental_loading: bool = False,
):
    async def _run_tasks_data_item_incremental(
        data_item,
        dataset,
        tasks,
        pipeline_name,
        pipeline_id,
        pipeline_run_id,
        context,
        user,
    ):
        db_engine = get_relational_engine()
        # If incremental_loading of data is set to True don't process documents already processed by pipeline
        # If data is being added to Cognee for the first time calculate the id of the data
        if not isinstance(data_item, Data):
            file_path = await save_data_item_to_storage(data_item)
            # Ingest data and add metadata
            async with open_data_file(file_path) as file:
                classified_data = ingestion.classify(file)
                # data_id is the hash of file contents + owner id to avoid duplicate data
                data_id = ingestion.identify(classified_data, user)
        else:
            # If data was already processed by Cognee get data id
            data_id = data_item.id

        # Check pipeline status, if Data already processed for pipeline before skip current processing
        async with db_engine.get_async_session() as session:
            data_point = (
                await session.execute(select(Data).filter(Data.id == data_id))
            ).scalar_one_or_none()
            if data_point:
                if (
                    data_point.pipeline_status.get(pipeline_name, {}).get(str(dataset.id))
                    == DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED
                ):
                    yield {
                        "run_info": PipelineRunAlreadyCompleted(
                            pipeline_run_id=pipeline_run_id,
                            dataset_id=dataset.id,
                            dataset_name=dataset.name,
                        ),
                        "data_id": data_id,
                    }
                    return

        try:
            # Process data based on data_item and list of tasks
            async for result in run_tasks_with_telemetry(
                tasks=tasks,
                data=[data_item],
                user=user,
                pipeline_name=pipeline_id,
                context=context,
            ):
                yield PipelineRunYield(
                    pipeline_run_id=pipeline_run_id,
                    dataset_id=dataset.id,
                    dataset_name=dataset.name,
                    payload=result,
                )

            # Update pipeline status for Data element
            async with db_engine.get_async_session() as session:
                data_point = (
                    await session.execute(select(Data).filter(Data.id == data_id))
                ).scalar_one_or_none()
                data_point.pipeline_status[pipeline_name] = {
                    str(dataset.id): DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED
                }
                await session.merge(data_point)
                await session.commit()

            yield {
                "run_info": PipelineRunCompleted(
                    pipeline_run_id=pipeline_run_id,
                    dataset_id=dataset.id,
                    dataset_name=dataset.name,
                ),
                "data_id": data_id,
            }

        except Exception as error:
            # Temporarily swallow error and try to process rest of documents first, then re-raise error at end of data ingestion pipeline
            logger.error(
                f"Exception caught while processing data: {error}.\n Data processing failed for data item: {data_item}."
            )
            yield {
                "run_info": PipelineRunErrored(
                    pipeline_run_id=pipeline_run_id,
                    payload=repr(error),
                    dataset_id=dataset.id,
                    dataset_name=dataset.name,
                ),
                "data_id": data_id,
            }

            if os.getenv("RAISE_INCREMENTAL_LOADING_ERRORS", "true").lower() == "true":
                raise error

    async def _run_tasks_data_item_regular(
        data_item,
        dataset,
        tasks,
        pipeline_id,
        pipeline_run_id,
        context,
        user,
    ):
        # Process data based on data_item and list of tasks
        async for result in run_tasks_with_telemetry(
            tasks=tasks,
            data=[data_item],
            user=user,
            pipeline_name=pipeline_id,
            context=context,
        ):
            yield PipelineRunYield(
                pipeline_run_id=pipeline_run_id,
                dataset_id=dataset.id,
                dataset_name=dataset.name,
                payload=result,
            )

        yield {
            "run_info": PipelineRunCompleted(
                pipeline_run_id=pipeline_run_id,
                dataset_id=dataset.id,
                dataset_name=dataset.name,
            )
        }

    async def _run_tasks_data_item(
        data_item,
        dataset,
        tasks,
        pipeline_name,
        pipeline_id,
        pipeline_run_id,
        context,
        user,
        incremental_loading,
    ):
        # Go through async generator and return data item processing result. Result can be PipelineRunAlreadyCompleted when data item is skipped,
        # PipelineRunCompleted when processing was successful and PipelineRunErrored if there were issues
        result = None
        if incremental_loading:
            async for result in _run_tasks_data_item_incremental(
                data_item=data_item,
                dataset=dataset,
                tasks=tasks,
                pipeline_name=pipeline_name,
                pipeline_id=pipeline_id,
                pipeline_run_id=pipeline_run_id,
                context=context,
                user=user,
            ):
                pass
        else:
            async for result in _run_tasks_data_item_regular(
                data_item=data_item,
                dataset=dataset,
                tasks=tasks,
                pipeline_id=pipeline_id,
                pipeline_run_id=pipeline_run_id,
                context=context,
                user=user,
            ):
                pass

        return result

    if not user:
        user = await get_default_user()

@@ -269,7 +92,7 @@ async def run_tasks(
        # Create async tasks per data item that will run the pipeline for the data item
        data_item_tasks = [
            asyncio.create_task(
                _run_tasks_data_item(
                run_tasks_data_item(
                    data_item,
                    dataset,
                    tasks,
cognee/modules/pipelines/operations/run_tasks_data_item.py (new file, 261 lines)

@@ -0,0 +1,261 @@
"""
Data item processing functions for pipeline operations.

This module contains reusable functions for processing individual data items
within pipeline operations, supporting both incremental and regular processing modes.
"""

import os
from typing import Any, Dict, AsyncGenerator, Optional
from sqlalchemy import select

import cognee.modules.ingestion as ingestion
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.shared.logging_utils import get_logger
from cognee.modules.users.models import User
from cognee.modules.data.models import Data, Dataset
from cognee.tasks.ingestion import save_data_item_to_storage
from cognee.modules.pipelines.models.PipelineRunInfo import (
    PipelineRunCompleted,
    PipelineRunErrored,
    PipelineRunYield,
    PipelineRunAlreadyCompleted,
)
from cognee.modules.pipelines.models.DataItemStatus import DataItemStatus
from cognee.modules.pipelines.operations.run_tasks_with_telemetry import run_tasks_with_telemetry
from ..tasks.task import Task

logger = get_logger("run_tasks_data_item")


async def run_tasks_data_item_incremental(
    data_item: Any,
    dataset: Dataset,
    tasks: list[Task],
    pipeline_name: str,
    pipeline_id: str,
    pipeline_run_id: str,
    context: Optional[Dict[str, Any]],
    user: User,
) -> AsyncGenerator[Dict[str, Any], None]:
    """
    Process a single data item with incremental loading support.

    This function handles incremental processing by checking if the data item
    has already been processed for the given pipeline and dataset. If it has,
    it skips processing and returns a completion status.

    Args:
        data_item: The data item to process
        dataset: The dataset containing the data item
        tasks: List of tasks to execute on the data item
        pipeline_name: Name of the pipeline
        pipeline_id: Unique identifier for the pipeline
        pipeline_run_id: Unique identifier for this pipeline run
        context: Optional context dictionary
        user: User performing the operation

    Yields:
        Dict containing run_info and data_id for each processing step
    """
    db_engine = get_relational_engine()

    # If incremental_loading of data is set to True don't process documents already processed by pipeline
    # If data is being added to Cognee for the first time calculate the id of the data
    if not isinstance(data_item, Data):
        file_path = await save_data_item_to_storage(data_item)
        # Ingest data and add metadata
        async with open_data_file(file_path) as file:
            classified_data = ingestion.classify(file)
            # data_id is the hash of file contents + owner id to avoid duplicate data
            data_id = ingestion.identify(classified_data, user)
    else:
        # If data was already processed by Cognee get data id
        data_id = data_item.id

    # Check pipeline status, if Data already processed for pipeline before skip current processing
    async with db_engine.get_async_session() as session:
        data_point = (
            await session.execute(select(Data).filter(Data.id == data_id))
        ).scalar_one_or_none()
        if data_point:
            if (
                data_point.pipeline_status.get(pipeline_name, {}).get(str(dataset.id))
                == DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED
            ):
                yield {
                    "run_info": PipelineRunAlreadyCompleted(
                        pipeline_run_id=pipeline_run_id,
                        dataset_id=dataset.id,
                        dataset_name=dataset.name,
                    ),
                    "data_id": data_id,
                }
                return

    try:
        # Process data based on data_item and list of tasks
        async for result in run_tasks_with_telemetry(
            tasks=tasks,
            data=[data_item],
            user=user,
            pipeline_name=pipeline_id,
            context=context,
        ):
            yield PipelineRunYield(
                pipeline_run_id=pipeline_run_id,
                dataset_id=dataset.id,
                dataset_name=dataset.name,
                payload=result,
            )

        # Update pipeline status for Data element
        async with db_engine.get_async_session() as session:
            data_point = (
                await session.execute(select(Data).filter(Data.id == data_id))
            ).scalar_one_or_none()
            data_point.pipeline_status[pipeline_name] = {
                str(dataset.id): DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED
            }
            await session.merge(data_point)
            await session.commit()

        yield {
            "run_info": PipelineRunCompleted(
                pipeline_run_id=pipeline_run_id,
                dataset_id=dataset.id,
                dataset_name=dataset.name,
            ),
            "data_id": data_id,
        }

    except Exception as error:
        # Temporarily swallow error and try to process rest of documents first, then re-raise error at end of data ingestion pipeline
        logger.error(
            f"Exception caught while processing data: {error}.\n Data processing failed for data item: {data_item}."
        )
        yield {
            "run_info": PipelineRunErrored(
                pipeline_run_id=pipeline_run_id,
                payload=repr(error),
                dataset_id=dataset.id,
                dataset_name=dataset.name,
            ),
            "data_id": data_id,
        }

        if os.getenv("RAISE_INCREMENTAL_LOADING_ERRORS", "true").lower() == "true":
            raise error


async def run_tasks_data_item_regular(
    data_item: Any,
    dataset: Dataset,
    tasks: list[Task],
    pipeline_id: str,
    pipeline_run_id: str,
    context: Optional[Dict[str, Any]],
    user: User,
) -> AsyncGenerator[Dict[str, Any], None]:
    """
    Process a single data item in regular (non-incremental) mode.

    This function processes a data item without checking for previous processing
    status, executing all tasks on the data item.

    Args:
        data_item: The data item to process
        dataset: The dataset containing the data item
        tasks: List of tasks to execute on the data item
        pipeline_id: Unique identifier for the pipeline
        pipeline_run_id: Unique identifier for this pipeline run
        context: Optional context dictionary
        user: User performing the operation

    Yields:
        Dict containing run_info for each processing step
    """
    # Process data based on data_item and list of tasks
    async for result in run_tasks_with_telemetry(
        tasks=tasks,
        data=[data_item],
        user=user,
        pipeline_name=pipeline_id,
        context=context,
    ):
        yield PipelineRunYield(
            pipeline_run_id=pipeline_run_id,
            dataset_id=dataset.id,
            dataset_name=dataset.name,
            payload=result,
        )

    yield {
        "run_info": PipelineRunCompleted(
            pipeline_run_id=pipeline_run_id,
            dataset_id=dataset.id,
            dataset_name=dataset.name,
        )
    }


async def run_tasks_data_item(
    data_item: Any,
    dataset: Dataset,
    tasks: list[Task],
    pipeline_name: str,
    pipeline_id: str,
    pipeline_run_id: str,
    context: Optional[Dict[str, Any]],
    user: User,
    incremental_loading: bool,
) -> Optional[Dict[str, Any]]:
    """
    Process a single data item, choosing between incremental and regular processing.

    This is the main entry point for data item processing that delegates to either
    incremental or regular processing based on the incremental_loading flag.

    Args:
        data_item: The data item to process
        dataset: The dataset containing the data item
        tasks: List of tasks to execute on the data item
        pipeline_name: Name of the pipeline
        pipeline_id: Unique identifier for the pipeline
        pipeline_run_id: Unique identifier for this pipeline run
        context: Optional context dictionary
        user: User performing the operation
        incremental_loading: Whether to use incremental processing

    Returns:
        Dict containing the final processing result, or None if processing was skipped
    """
    # Go through async generator and return data item processing result. Result can be PipelineRunAlreadyCompleted when data item is skipped,
    # PipelineRunCompleted when processing was successful and PipelineRunErrored if there were issues
    result = None
    if incremental_loading:
        async for result in run_tasks_data_item_incremental(
            data_item=data_item,
            dataset=dataset,
            tasks=tasks,
            pipeline_name=pipeline_name,
            pipeline_id=pipeline_id,
            pipeline_run_id=pipeline_run_id,
            context=context,
            user=user,
        ):
            pass
    else:
        async for result in run_tasks_data_item_regular(
            data_item=data_item,
            dataset=dataset,
            tasks=tasks,
            pipeline_id=pipeline_id,
            pipeline_run_id=pipeline_run_id,
            context=context,
            user=user,
        ):
            pass

    return result
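The delegating entry point above exhausts an async generator and keeps only its final yield as the item's result. A small, self-contained sketch of that same idiom with placeholder items (the generator body here is illustrative, not cognee's pipeline):

```python
import asyncio
from typing import AsyncGenerator, Optional


async def item_pipeline(item: str) -> AsyncGenerator[dict, None]:
    # Stand-in for run_tasks_data_item_incremental/_regular: intermediate
    # yields report progress, the final yield carries the run summary.
    yield {"progress": f"processing {item}"}
    yield {"run_info": "PipelineRunCompleted", "data_id": item}


async def run_item(item: str) -> Optional[dict]:
    # Same idiom as run_tasks_data_item: exhaust the generator and
    # keep only the last yielded value as the item's result.
    result = None
    async for result in item_pipeline(item):
        pass
    return result


async def main() -> None:
    results = await asyncio.gather(*(run_item(i) for i in ["a.txt", "b.txt"]))
    print(results)  # one final run_info/data_id dict per item


asyncio.run(main())
```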
@@ -3,49 +3,96 @@ try:
except ModuleNotFoundError:
    modal = None

from typing import Any, List, Optional
from uuid import UUID

from cognee.modules.pipelines.tasks.task import Task
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.pipelines.models import (
    PipelineRunStarted,
    PipelineRunYield,
    PipelineRunCompleted,
    PipelineRunErrored,
)
from cognee.modules.pipelines.operations import log_pipeline_run_start, log_pipeline_run_complete
from cognee.modules.pipelines.utils.generate_pipeline_id import generate_pipeline_id
from cognee.modules.pipelines.operations import (
    log_pipeline_run_start,
    log_pipeline_run_complete,
    log_pipeline_run_error,
)
from cognee.modules.pipelines.utils import generate_pipeline_id
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger

from .run_tasks_with_telemetry import run_tasks_with_telemetry

from cognee.modules.users.models import User
from cognee.modules.pipelines.exceptions import PipelineRunFailedError
from cognee.tasks.ingestion import resolve_data_directories
from .run_tasks_data_item import run_tasks_data_item

logger = get_logger("run_tasks_distributed()")


if modal:
    import os
    from distributed.app import app
    from distributed.modal_image import image

    secret_name = os.environ.get("MODAL_SECRET_NAME", "distributed_cognee")

    @app.function(
        retries=3,
        image=image,
        timeout=86400,
        max_containers=50,
        secrets=[modal.Secret.from_name("distributed_cognee")],
        secrets=[modal.Secret.from_name(secret_name)],
    )
    async def run_tasks_on_modal(tasks, data_item, user, pipeline_name, context):
        pipeline_run = run_tasks_with_telemetry(tasks, data_item, user, pipeline_name, context)
    async def run_tasks_on_modal(
        data_item,
        dataset_id: UUID,
        tasks: List[Task],
        pipeline_name: str,
        pipeline_id: str,
        pipeline_run_id: str,
        context: Optional[dict],
        user: User,
        incremental_loading: bool,
    ):
        """
        Wrapper that runs the run_tasks_data_item function.
        This is the function/code that runs on modal executor and produces the graph/vector db objects
        """
        from cognee.infrastructure.databases.relational import get_relational_engine

        run_info = None
        async with get_relational_engine().get_async_session() as session:
            from cognee.modules.data.models import Dataset

            async for pipeline_run_info in pipeline_run:
                run_info = pipeline_run_info
            dataset = await session.get(Dataset, dataset_id)

        return run_info
        result = await run_tasks_data_item(
            data_item=data_item,
            dataset=dataset,
            tasks=tasks,
            pipeline_name=pipeline_name,
            pipeline_id=pipeline_id,
            pipeline_run_id=pipeline_run_id,
            context=context,
            user=user,
            incremental_loading=incremental_loading,
        )

        return result


async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, context):
async def run_tasks_distributed(
    tasks: List[Task],
    dataset_id: UUID,
    data: List[Any] = None,
    user: User = None,
    pipeline_name: str = "unknown_pipeline",
    context: dict = None,
    incremental_loading: bool = False,
):
    if not user:
        user = await get_default_user()

    # Get dataset object
    db_engine = get_relational_engine()
    async with db_engine.get_async_session() as session:
        from cognee.modules.data.models import Dataset

@@ -53,9 +100,7 @@ async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, co
        dataset = await session.get(Dataset, dataset_id)

    pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name)

    pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data)

    pipeline_run_id = pipeline_run.pipeline_run_id

    yield PipelineRunStarted(

@@ -65,30 +110,67 @@ async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, co
        payload=data,
    )

    data_count = len(data) if isinstance(data, list) else 1
    try:
        if not isinstance(data, list):
            data = [data]

        arguments = [
            [tasks] * data_count,
            [[data_item] for data_item in data[:data_count]] if data_count > 1 else [data],
            [user] * data_count,
            [pipeline_name] * data_count,
            [context] * data_count,
        ]
        data = await resolve_data_directories(data)

        async for result in run_tasks_on_modal.map.aio(*arguments):
            logger.info(f"Received result: {result}")
        number_of_data_items = len(data) if isinstance(data, list) else 1

            yield PipelineRunYield(
        data_item_tasks = [
            data,
            [dataset.id] * number_of_data_items,
            [tasks] * number_of_data_items,
            [pipeline_name] * number_of_data_items,
            [pipeline_id] * number_of_data_items,
            [pipeline_run_id] * number_of_data_items,
            [context] * number_of_data_items,
            [user] * number_of_data_items,
            [incremental_loading] * number_of_data_items,
        ]

        results = []
        async for result in run_tasks_on_modal.map.aio(*data_item_tasks):
            if not result:
                continue
            results.append(result)

        # Remove skipped results
        results = [r for r in results if r]

        # If any data item failed, raise PipelineRunFailedError
        errored = [
            r
            for r in results
            if r and r.get("run_info") and isinstance(r["run_info"], PipelineRunErrored)
        ]
        if errored:
            raise PipelineRunFailedError("Pipeline run failed. Data item could not be processed.")

        await log_pipeline_run_complete(
            pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data
        )

        yield PipelineRunCompleted(
            pipeline_run_id=pipeline_run_id,
            dataset_id=dataset.id,
            dataset_name=dataset.name,
            payload=result,
            data_ingestion_info=results,
        )

        await log_pipeline_run_complete(pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data)
    except Exception as error:
        await log_pipeline_run_error(
            pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data, error
        )

        yield PipelineRunCompleted(
            pipeline_run_id=pipeline_run_id,
            dataset_id=dataset.id,
            dataset_name=dataset.name,
        )
        yield PipelineRunErrored(
            pipeline_run_id=pipeline_run_id,
            payload=repr(error),
            dataset_id=dataset.id,
            dataset_name=dataset.name,
            data_ingestion_info=locals().get("results"),
        )

        if not isinstance(error, PipelineRunFailedError):
            raise
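In the new run_tasks_distributed, each positional parameter of run_tasks_on_modal gets its own list inside data_item_tasks, and the lists are combined element-wise across data items when the call is fanned out to Modal workers. A small standalone illustration of that broadcasting convention, with a plain zip standing in for the mapped call and placeholder values throughout:

```python
data = ["a.txt", "b.txt", "c.txt"]
n = len(data)

# One list per parameter, mirroring data_item_tasks in the hunk above;
# scalar arguments are simply repeated once per data item.
argument_lists = [
    data,                              # data_item
    ["dataset-123"] * n,               # dataset_id (placeholder)
    [["task_a", "task_b"]] * n,        # tasks (placeholder task names)
    ["cognify_pipeline"] * n,          # pipeline_name (placeholder)
]

# Element-wise combination is what each worker invocation ends up receiving.
for call_args in zip(*argument_lists):
    print(call_args)
# ('a.txt', 'dataset-123', ['task_a', 'task_b'], 'cognify_pipeline') ... and so on.
```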
@@ -194,7 +194,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
            belongs_to_set=interactions_node_set,
        )

        await add_data_points(data_points=[cognee_user_interaction], update_edge_collection=False)
        await add_data_points(data_points=[cognee_user_interaction])

        relationships = []
        relationship_name = "used_graph_element_to_answer"
@@ -8,7 +8,7 @@ from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.base_feedback import BaseFeedback
from cognee.modules.retrieval.utils.models import CogneeUserFeedback
from cognee.modules.retrieval.utils.models import UserFeedbackEvaluation
from cognee.tasks.storage import add_data_points
from cognee.tasks.storage import add_data_points, index_graph_edges

logger = get_logger("CompletionRetriever")

@@ -47,7 +47,7 @@ class UserQAFeedback(BaseFeedback):
            belongs_to_set=feedbacks_node_set,
        )

        await add_data_points(data_points=[cognee_user_feedback], update_edge_collection=False)
        await add_data_points(data_points=[cognee_user_feedback])

        relationships = []
        relationship_name = "gives_feedback_to"

@@ -76,6 +76,7 @@ class UserQAFeedback(BaseFeedback):
        if len(relationships) > 0:
            graph_engine = await get_graph_engine()
            await graph_engine.add_edges(relationships)
            await index_graph_edges(relationships)
            await graph_engine.apply_feedback_weight(
                node_ids=to_node_ids, weight=feedback_sentiment.score
            )
@@ -124,5 +124,4 @@ async def add_rule_associations(

    if len(edges_to_save) > 0:
        await graph_engine.add_edges(edges_to_save)

        await index_graph_edges()
        await index_graph_edges(edges_to_save)
@@ -4,6 +4,7 @@ from pydantic import BaseModel

from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.ontology.ontology_env_config import get_ontology_env_config
from cognee.tasks.storage import index_graph_edges
from cognee.tasks.storage.add_data_points import add_data_points
from cognee.modules.ontology.ontology_config import Config
from cognee.modules.ontology.get_default_ontology_resolver import (

@@ -88,6 +89,7 @@ async def integrate_chunk_graphs(

    if len(graph_edges) > 0:
        await graph_engine.add_edges(graph_edges)
        await index_graph_edges(graph_edges)

    return data_chunks
@@ -10,9 +10,7 @@ from cognee.tasks.storage.exceptions import (
)


async def add_data_points(
    data_points: List[DataPoint], update_edge_collection: bool = True
) -> List[DataPoint]:
async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
    """
    Add a batch of data points to the graph database by extracting nodes and edges,
    deduplicating them, and indexing them for retrieval.

@@ -25,9 +23,6 @@ async def add_data_points(
    Args:
        data_points (List[DataPoint]):
            A list of data points to process and insert into the graph.
        update_edge_collection (bool, optional):
            Whether to update the edge index after adding edges.
            Defaults to True.

    Returns:
        List[DataPoint]:

@@ -73,12 +68,10 @@ async def add_data_points(

    graph_engine = await get_graph_engine()

    await graph_engine.add_nodes(nodes)
    await index_data_points(nodes)

    await graph_engine.add_nodes(nodes)
    await graph_engine.add_edges(edges)

    if update_edge_collection:
        await index_graph_edges()
    await index_graph_edges(edges)

    return data_points
@@ -1,15 +1,18 @@
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
from cognee.shared.logging_utils import get_logger, ERROR
from cognee.shared.logging_utils import get_logger
from collections import Counter

from typing import Optional, Dict, Any, List, Tuple, Union
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.graph.models.EdgeType import EdgeType
from cognee.infrastructure.databases.graph.graph_db_interface import EdgeData

logger = get_logger(level=ERROR)
logger = get_logger()


async def index_graph_edges():
async def index_graph_edges(
    edges_data: Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]] = None,
):
    """
    Indexes graph edges by creating and managing vector indexes for relationship types.


@@ -35,13 +38,17 @@ async def index_graph_edges():
        index_points = {}

        vector_engine = get_vector_engine()
        graph_engine = await get_graph_engine()

        if edges_data is None:
            graph_engine = await get_graph_engine()
            _, edges_data = await graph_engine.get_graph_data()
            logger.warning(
                "Your graph edge embedding is deprecated, please pass edges to the index_graph_edges directly."
            )
    except Exception as e:
        logger.error("Failed to initialize engines: %s", e)
        raise RuntimeError("Initialization error") from e

    _, edges_data = await graph_engine.get_graph_data()

    edge_types = Counter(
        item.get("relationship_name")
        for edge in edges_data
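index_graph_edges now accepts the freshly written edges, so callers no longer need a full get_graph_data() round trip per call; the no-argument form still works but logs a deprecation warning and reloads the whole graph, which is the network-heavy behavior in multi-container runs that this PR moves away from. A minimal calling sketch; the edge tuples follow the tuple shape named in the new signature, and the values are placeholders:

```python
from cognee.tasks.storage import index_graph_edges


async def embed_new_edges() -> None:
    # Preferred after this PR: pass just-created edges directly as
    # (source_id, target_id, relationship_name, properties) tuples so only
    # the new relationship names are embedded and indexed.
    edges = [
        ("node-1", "node-2", "is_part_of", {"weight": 1.0}),
        ("node-2", "node-3", "located_in", None),
    ]
    await index_graph_edges(edges)


async def embed_all_edges_legacy() -> None:
    # Legacy form: no arguments. Falls back to graph_engine.get_graph_data()
    # and logs the deprecation warning shown in the hunk above.
    await index_graph_edges()
```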
@@ -29,6 +29,3 @@ RUN poetry install --extras neo4j --extras postgres --extras aws --extras distri

COPY cognee/ /app/cognee
COPY distributed/ /app/distributed
RUN chmod +x /app/distributed/entrypoint.sh

ENTRYPOINT ["/app/distributed/entrypoint.sh"]
@@ -10,6 +10,7 @@ from distributed.app import app
from distributed.queues import add_nodes_and_edges_queue, add_data_points_queue
from distributed.workers.graph_saving_worker import graph_saving_worker
from distributed.workers.data_point_saving_worker import data_point_saving_worker
from distributed.signal import QueueSignal

logger = get_logger()

@@ -23,13 +24,14 @@ async def main():
    await add_nodes_and_edges_queue.clear.aio()
    await add_data_points_queue.clear.aio()

    number_of_graph_saving_workers = 1  # Total number of graph_saving_worker to spawn
    number_of_data_point_saving_workers = 5  # Total number of graph_saving_worker to spawn
    number_of_graph_saving_workers = 1  # Total number of graph_saving_worker to spawn (MAX 1)
    number_of_data_point_saving_workers = (
        10  # Total number of graph_saving_worker to spawn (MAX 10)
    )

    results = []
    consumer_futures = []

    # await prune.prune_data()  # We don't want to delete files on s3
    await prune.prune_data()  # This prunes the data from the file storage
    # Delete DBs and saved files from metastore
    await prune.prune_system(metadata=True)

@@ -45,16 +47,28 @@ async def main():
        worker_future = data_point_saving_worker.spawn()
        consumer_futures.append(worker_future)

    """ Example: Setting and adding S3 path as input
    s3_bucket_path = os.getenv("S3_BUCKET_PATH")
    s3_data_path = "s3://" + s3_bucket_path

    await cognee.add(s3_data_path, dataset_name="s3-files")
    """
    await cognee.add(
        [
            "Audi is a German car manufacturer",
            "The Netherlands is next to Germany",
            "Berlin is the capital of Germany",
            "The Rhine is a major European river",
            "BMW produces luxury vehicles",
        ],
        dataset_name="s3-files",
    )

    await cognee.cognify(datasets=["s3-files"])

    # Push empty tuple into the queue to signal the end of data.
    await add_nodes_and_edges_queue.put.aio(())
    await add_data_points_queue.put.aio(())
    # Put Processing end signal into the queues to stop the consumers
    await add_nodes_and_edges_queue.put.aio(QueueSignal.STOP)
    await add_data_points_queue.put.aio(QueueSignal.STOP)

    for consumer_future in consumer_futures:
        try:

@@ -64,8 +78,6 @@ async def main():
        except Exception as e:
            logger.error(e)

    print(results)


if __name__ == "__main__":
    asyncio.run(main())
distributed/poetry.lock (generated, 12238 lines changed): file diff suppressed because it is too large.
@@ -1,185 +0,0 @@
[project]
name = "cognee"
version = "0.2.2.dev0"
description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
authors = [
    { name = "Vasilije Markovic" },
    { name = "Boris Arzentar" },
]
requires-python = ">=3.10,<=3.13"
readme = "README.md"
license = "Apache-2.0"
classifiers = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Developers",
    "License :: OSI Approved :: Apache Software License",
    "Topic :: Software Development :: Libraries",
    "Operating System :: MacOS :: MacOS X",
    "Operating System :: POSIX :: Linux",
    "Operating System :: Microsoft :: Windows",
]
dependencies = [
    "openai>=1.80.1,<2",
    "python-dotenv>=1.0.1,<2.0.0",
    "pydantic>=2.11.7,<3.0.0",
    "pydantic-settings>=2.10.1,<3",
    "typing_extensions>=4.12.2,<5.0.0",
    "nltk>=3.9.1,<4.0.0",
    "numpy>=1.26.4, <=4.0.0",
    "pandas>=2.2.2,<3.0.0",
    # Note: New s3fs and boto3 versions don't work well together
    # Always use comaptible fixed versions of these two dependencies
    "s3fs[boto3]==2025.3.2",
    "sqlalchemy>=2.0.39,<3.0.0",
    "aiosqlite>=0.20.0,<1.0.0",
    "tiktoken>=0.8.0,<1.0.0",
    "litellm>=1.57.4, <1.71.0",
    "instructor>=1.9.1,<2.0.0",
    "langfuse>=2.32.0,<3",
    "filetype>=1.2.0,<2.0.0",
    "aiohttp>=3.11.14,<4.0.0",
    "aiofiles>=23.2.1,<24.0.0",
    "rdflib>=7.1.4,<7.2.0",
    "pypdf>=4.1.0,<7.0.0",
    "jinja2>=3.1.3,<4",
    "matplotlib>=3.8.3,<4",
    "networkx>=3.4.2,<4",
    "lancedb>=0.24.0,<1.0.0",
    "alembic>=1.13.3,<2",
    "pre-commit>=4.0.1,<5",
    "scikit-learn>=1.6.1,<2",
    "limits>=4.4.1,<5",
    "fastapi>=0.115.7,<1.0.0",
    "python-multipart>=0.0.20,<1.0.0",
    "fastapi-users[sqlalchemy]>=14.0.1,<15.0.0",
    "dlt[sqlalchemy]>=1.9.0,<2",
    "sentry-sdk[fastapi]>=2.9.0,<3",
    "structlog>=25.2.0,<26",
    "pympler>=1.1,<2.0.0",
    "onnxruntime>=1.0.0,<2.0.0",
    "pylance>=0.22.0,<1.0.0",
    "kuzu (==0.11.0)"
]

[project.optional-dependencies]
api = [
    "uvicorn>=0.34.0,<1.0.0",
    "gunicorn>=20.1.0,<24",
    "websockets>=15.0.1,<16.0.0"
]
distributed = [
    "modal>=1.0.5,<2.0.0",
]

neo4j = ["neo4j>=5.28.0,<6"]
postgres = [
    "psycopg2>=2.9.10,<3",
    "pgvector>=0.3.5,<0.4",
    "asyncpg>=0.30.0,<1.0.0",
]
postgres-binary = [
    "psycopg2-binary>=2.9.10,<3.0.0",
    "pgvector>=0.3.5,<0.4",
    "asyncpg>=0.30.0,<1.0.0",
]
notebook = ["notebook>=7.1.0,<8"]
langchain = [
    "langsmith>=0.2.3,<1.0.0",
    "langchain_text_splitters>=0.3.2,<1.0.0",
]
llama-index = ["llama-index-core>=0.12.11,<0.13"]
gemini = ["google-generativeai>=0.8.4,<0.9"]
huggingface = ["transformers>=4.46.3,<5"]
ollama = ["transformers>=4.46.3,<5"]
mistral = ["mistral-common>=1.5.2,<2"]
anthropic = ["anthropic>=0.26.1,<0.27"]
deepeval = ["deepeval>=2.0.1,<3"]
posthog = ["posthog>=3.5.0,<4"]
groq = ["groq>=0.8.0,<1.0.0"]
chromadb = [
    "chromadb>=0.3.0,<0.7",
    "pypika==0.48.8",
]
docs = ["unstructured[csv, doc, docx, epub, md, odt, org, ppt, pptx, rst, rtf, tsv, xlsx]>=0.18.1,<19"]
codegraph = [
    "fastembed<=0.6.0 ; python_version < '3.13'",
    "transformers>=4.46.3,<5",
    "tree-sitter>=0.24.0,<0.25",
    "tree-sitter-python>=0.23.6,<0.24",
]
evals = [
    "plotly>=6.0.0,<7",
    "gdown>=5.2.0,<6",
]
gui = [
    "pyside6>=6.8.3,<7",
    "qasync>=0.27.1,<0.28",
]
graphiti = ["graphiti-core>=0.7.0,<0.8"]
# Note: New s3fs and boto3 versions don't work well together
# Always use comaptible fixed versions of these two dependencies
aws = ["s3fs[boto3]==2025.3.2"]
dev = [
    "pytest>=7.4.0,<8",
    "pytest-cov>=6.1.1,<7.0.0",
    "pytest-asyncio>=0.21.1,<0.22",
    "coverage>=7.3.2,<8",
    "mypy>=1.7.1,<2",
    "notebook>=7.1.0,<8",
    "deptry>=0.20.0,<0.21",
    "pylint>=3.0.3,<4",
    "ruff>=0.9.2,<1.0.0",
    "tweepy>=4.14.0,<5.0.0",
    "gitpython>=3.1.43,<4",
    "mkdocs-material>=9.5.42,<10",
    "mkdocs-minify-plugin>=0.8.0,<0.9",
    "mkdocstrings[python]>=0.26.2,<0.27",
]
debug = ["debugpy>=1.8.9,<2.0.0"]

[project.urls]
Homepage = "https://www.cognee.ai"
Repository = "https://github.com/topoteretes/cognee"

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build]
exclude = [
    "/bin",
    "/dist",
    "/.data",
    "/.github",
    "/alembic",
    "/deployment",
    "/cognee-mcp",
    "/cognee-frontend",
    "/examples",
    "/helm",
    "/licenses",
    "/logs",
    "/notebooks",
    "/profiling",
    "/tests",
    "/tools",
]

[tool.hatch.build.targets.wheel]
packages = ["cognee", "distributed"]

[tool.ruff]
line-length = 100
exclude = [
    "migrations/",  # Ignore migrations directory
    "notebooks/",  # Ignore notebook files
    "build/",  # Ignore build directory
    "cognee/pipelines.py",
    "cognee/modules/users/models/Group.py",
    "cognee/modules/users/models/ACL.py",
    "cognee/modules/pipelines/models/Task.py",
    "cognee/modules/data/models/Dataset.py"
]

[tool.ruff.lint]
ignore = ["F401"]
distributed/signal.py (new file, 5 lines)

@@ -0,0 +1,5 @@
from enum import Enum


class QueueSignal(str, Enum):
    STOP = "STOP"
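QueueSignal is a str-based Enum, so the sentinel compares equal to its string value and is designed to survive being serialized onto the queue and back. A quick standard-library check of both properties (the JSON round trip here simply stands in for any string-level serialization):

```python
import json
from enum import Enum


class QueueSignal(str, Enum):
    # Same definition as distributed/signal.py.
    STOP = "STOP"


# The str mixin means the member round-trips through plain-string
# serialization and still compares equal to the enum sentinel.
wire_value = json.loads(json.dumps(QueueSignal.STOP))
print(wire_value == QueueSignal.STOP)     # True
print(QueueSignal.STOP == "STOP")         # True
print(isinstance(QueueSignal.STOP, str))  # True
```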
@@ -1,16 +1,17 @@
import os
import modal
import asyncio
from sqlalchemy.exc import OperationalError, DBAPIError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

from distributed.app import app
from distributed.signal import QueueSignal
from distributed.modal_image import image
from distributed.queues import add_data_points_queue

from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine


logger = get_logger("data_point_saving_worker")


@@ -39,55 +40,84 @@ def is_deadlock_error(error):
    return False


secret_name = os.environ.get("MODAL_SECRET_NAME", "distributed_cognee")


@app.function(
    retries=3,
    image=image,
    timeout=86400,
    max_containers=5,
    secrets=[modal.Secret.from_name("distributed_cognee")],
    max_containers=10,
    secrets=[modal.Secret.from_name(secret_name)],
)
async def data_point_saving_worker():
    print("Started processing of data points; starting vector engine queue.")
    vector_engine = get_vector_engine()
    # Defines how many data packets do we glue together from the modal queue before embedding call and ingestion
    BATCH_SIZE = 25
    stop_seen = False

    while True:
        if stop_seen:
            print("Finished processing all data points; stopping vector engine queue consumer.")
            return True

        if await add_data_points_queue.len.aio() != 0:
            try:
                add_data_points_request = await add_data_points_queue.get.aio(block=False)
                print("Remaining elements in queue:")
                print(await add_data_points_queue.len.aio())

                # collect batched requests
                batched_points = {}
                for _ in range(min(BATCH_SIZE, await add_data_points_queue.len.aio())):
                    add_data_points_request = await add_data_points_queue.get.aio(block=False)

                    if not add_data_points_request:
                        continue

                    if add_data_points_request == QueueSignal.STOP:
                        await add_data_points_queue.put.aio(QueueSignal.STOP)
                        stop_seen = True
                        break

                    if len(add_data_points_request) == 2:
                        collection_name, data_points = add_data_points_request
                        if collection_name not in batched_points:
                            batched_points[collection_name] = []
                        batched_points[collection_name].extend(data_points)
                    else:
                        print("NoneType or invalid request detected.")

                if batched_points:
                    for collection_name, data_points in batched_points.items():
                        print(
                            f"Adding {len(data_points)} data points to '{collection_name}' collection."
                        )

                        @retry(
                            retry=retry_if_exception_type(VectorDatabaseDeadlockError),
                            stop=stop_after_attempt(3),
                            wait=wait_exponential(multiplier=2, min=1, max=6),
                        )
                        async def add_data_points():
                            try:
                                await vector_engine.create_data_points(
                                    collection_name, data_points, distributed=False
                                )
                            except DBAPIError as error:
                                if is_deadlock_error(error):
                                    raise VectorDatabaseDeadlockError()
                            except OperationalError as error:
                                if is_deadlock_error(error):
                                    raise VectorDatabaseDeadlockError()

                        await add_data_points()
                        print(f"Finished adding data points to '{collection_name}'.")

            except modal.exception.DeserializationError as error:
                logger.error(f"Deserialization error: {str(error)}")
                continue

                if len(add_data_points_request) == 0:
                    print("Finished processing all data points; stopping vector engine queue.")
                    return True

                if len(add_data_points_request) == 2:
                    (collection_name, data_points) = add_data_points_request

                    print(f"Adding {len(data_points)} data points to '{collection_name}' collection.")

                    @retry(
                        retry=retry_if_exception_type(VectorDatabaseDeadlockError),
                        stop=stop_after_attempt(3),
                        wait=wait_exponential(multiplier=2, min=1, max=6),
                    )
                    async def add_data_points():
                        try:
                            await vector_engine.create_data_points(
                                collection_name, data_points, distributed=False
                            )
                        except DBAPIError as error:
                            if is_deadlock_error(error):
                                raise VectorDatabaseDeadlockError()
                        except OperationalError as error:
                            if is_deadlock_error(error):
                                raise VectorDatabaseDeadlockError()

                    await add_data_points()

                    print("Finished adding data points.")

        else:
            print("No jobs, go to sleep.")
            await asyncio.sleep(5)
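The worker above groups up to BATCH_SIZE queue entries by collection name before making a single create_data_points call per collection. The grouping itself is plain dictionary merging; here is a minimal standalone sketch of that step, where the collection names and (collection_name, data_points) entries are illustrative placeholders rather than real queue payloads:

```python
from collections import defaultdict

BATCH_SIZE = 25

# Illustrative entries as they might come off the queue.
queue_entries = [
    ("document_chunks", ["chunk-1", "chunk-2"]),
    ("entity_names", ["entity-1"]),
    ("document_chunks", ["chunk-3"]),
]

batched_points: dict[str, list] = defaultdict(list)
for collection_name, data_points in queue_entries[:BATCH_SIZE]:
    # Same merge the worker performs: one bucket per collection, so each
    # collection gets a single bulk write instead of one write per message.
    batched_points[collection_name].extend(data_points)

for collection_name, data_points in batched_points.items():
    print(f"Adding {len(data_points)} data points to '{collection_name}' collection.")
```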
@@ -1,8 +1,10 @@
import os
import modal
import asyncio
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

from distributed.app import app
from distributed.signal import QueueSignal
from distributed.modal_image import image
from distributed.queues import add_nodes_and_edges_queue


@@ -10,7 +12,6 @@ from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.graph.config import get_graph_config


logger = get_logger("graph_saving_worker")


@@ -37,68 +38,91 @@ def is_deadlock_error(error):
    return False


secret_name = os.environ.get("MODAL_SECRET_NAME", "distributed_cognee")


@app.function(
    retries=3,
    image=image,
    timeout=86400,
    max_containers=5,
    secrets=[modal.Secret.from_name("distributed_cognee")],
    max_containers=1,
    secrets=[modal.Secret.from_name(secret_name)],
)
async def graph_saving_worker():
    print("Started processing of nodes and edges; starting graph engine queue.")
    graph_engine = await get_graph_engine()
    # Defines how many data packets do we glue together from the queue before ingesting them into the graph database
    BATCH_SIZE = 25
    stop_seen = False

    while True:
        if stop_seen:
            print("Finished processing all data points; stopping graph engine queue consumer.")
            return True

        if await add_nodes_and_edges_queue.len.aio() != 0:
            try:
                nodes_and_edges = await add_nodes_and_edges_queue.get.aio(block=False)
                print("Remaining elements in queue:")
                print(await add_nodes_and_edges_queue.len.aio())

                all_nodes, all_edges = [], []
                for _ in range(min(BATCH_SIZE, await add_nodes_and_edges_queue.len.aio())):
                    nodes_and_edges = await add_nodes_and_edges_queue.get.aio(block=False)

                    if not nodes_and_edges:
                        continue

                    if nodes_and_edges == QueueSignal.STOP:
                        await add_nodes_and_edges_queue.put.aio(QueueSignal.STOP)
                        stop_seen = True
                        break

                    if len(nodes_and_edges) == 2:
                        nodes, edges = nodes_and_edges
                        all_nodes.extend(nodes)
                        all_edges.extend(edges)
                    else:
                        print("None Type detected.")

                if all_nodes or all_edges:
                    print(f"Adding {len(all_nodes)} nodes and {len(all_edges)} edges.")

                    @retry(
                        retry=retry_if_exception_type(GraphDatabaseDeadlockError),
                        stop=stop_after_attempt(3),
                        wait=wait_exponential(multiplier=2, min=1, max=6),
                    )
                    async def save_graph_nodes(new_nodes):
                        try:
                            await graph_engine.add_nodes(new_nodes, distributed=False)
                        except Exception as error:
                            if is_deadlock_error(error):
                                raise GraphDatabaseDeadlockError()

                    @retry(
                        retry=retry_if_exception_type(GraphDatabaseDeadlockError),
                        stop=stop_after_attempt(3),
                        wait=wait_exponential(multiplier=2, min=1, max=6),
                    )
                    async def save_graph_edges(new_edges):
                        try:
                            await graph_engine.add_edges(new_edges, distributed=False)
                        except Exception as error:
                            if is_deadlock_error(error):
                                raise GraphDatabaseDeadlockError()

                    if all_nodes:
                        await save_graph_nodes(all_nodes)

                    if all_edges:
                        await save_graph_edges(all_edges)

                    print("Finished adding nodes and edges.")

            except modal.exception.DeserializationError as error:
                logger.error(f"Deserialization error: {str(error)}")
                continue

                if len(nodes_and_edges) == 0:
                    print("Finished processing all nodes and edges; stopping graph engine queue.")
                    return True

                if len(nodes_and_edges) == 2:
                    print(
                        f"Adding {len(nodes_and_edges[0])} nodes and {len(nodes_and_edges[1])} edges."
                    )
                    nodes = nodes_and_edges[0]
                    edges = nodes_and_edges[1]

                    @retry(
                        retry=retry_if_exception_type(GraphDatabaseDeadlockError),
                        stop=stop_after_attempt(3),
                        wait=wait_exponential(multiplier=2, min=1, max=6),
                    )
                    async def save_graph_nodes(new_nodes):
                        try:
                            await graph_engine.add_nodes(new_nodes, distributed=False)
                        except Exception as error:
                            if is_deadlock_error(error):
                                raise GraphDatabaseDeadlockError()

                    @retry(
                        retry=retry_if_exception_type(GraphDatabaseDeadlockError),
                        stop=stop_after_attempt(3),
                        wait=wait_exponential(multiplier=2, min=1, max=6),
                    )
                    async def save_graph_edges(new_edges):
                        try:
                            await graph_engine.add_edges(new_edges, distributed=False)
                        except Exception as error:
                            if is_deadlock_error(error):
                                raise GraphDatabaseDeadlockError()

                    if nodes:
                        await save_graph_nodes(nodes)

                    if edges:
                        await save_graph_edges(edges)

                    print("Finished adding nodes and edges.")

        else:
            print("No jobs, go to sleep.")
            await asyncio.sleep(5)