Merge branch 'dev' into feature/web_scraping_connector_task

Geoffrey Robinson authored on 2025-10-10 16:16:36 +05:30; committed by GitHub
commit 4e5c681e62
30 changed files with 719 additions and 12821 deletions


@@ -54,6 +54,10 @@ jobs:
with:
python-version: ${{ inputs.python-version }}
- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j
- name: Run Neo4j Example
env:
ENV: dev
@@ -66,9 +70,9 @@ jobs:
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
GRAPH_DATABASE_PROVIDER: "neo4j"
GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }}
GRAPH_DATABASE_USERNAME: "neo4j"
GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }}
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
run: |
uv run python examples/database_examples/neo4j_example.py

.github/workflows/distributed_test.yml (vendored, new file, +73 lines)

@@ -0,0 +1,73 @@
name: Distributed Cognee test with modal
permissions:
contents: read
on:
workflow_call:
inputs:
python-version:
required: false
type: string
default: '3.11.x'
secrets:
LLM_MODEL:
required: true
LLM_ENDPOINT:
required: true
LLM_API_KEY:
required: true
LLM_API_VERSION:
required: true
EMBEDDING_MODEL:
required: true
EMBEDDING_ENDPOINT:
required: true
EMBEDDING_API_KEY:
required: true
EMBEDDING_API_VERSION:
required: true
OPENAI_API_KEY:
required: true
jobs:
run-server-start-test:
name: Distributed Cognee test (Modal)
runs-on: ubuntu-22.04
steps:
- name: Check out
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
extra-dependencies: "distributed postgres"
- name: Run Distributed Cognee (Modal)
env:
ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
MODAL_SECRET_NAME: ${{ secrets.MODAL_SECRET_NAME }}
GRAPH_DATABASE_PROVIDER: "neo4j"
GRAPH_DATABASE_URL: ${{ secrets.AZURE_NEO4j_URL }}
GRAPH_DATABASE_USERNAME: ${{ secrets.AZURE_NEO4J_USERNAME }}
GRAPH_DATABASE_PASSWORD: ${{ secrets.AZURE_NEO4J_PW }}
DB_PROVIDER: "postgres"
DB_NAME: ${{ secrets.AZURE_POSTGRES_DB_NAME }}
DB_HOST: ${{ secrets.AZURE_POSTGRES_HOST }}
DB_PORT: ${{ secrets.AZURE_POSTGRES_PORT }}
DB_USERNAME: ${{ secrets.AZURE_POSTGRES_USERNAME }}
DB_PASSWORD: ${{ secrets.AZURE_POSTGRES_PW }}
VECTOR_DB_PROVIDER: "pgvector"
COGNEE_DISTRIBUTED: "true"
run: uv run modal run ./distributed/entrypoint.py


@@ -71,6 +71,10 @@ jobs:
with:
python-version: ${{ inputs.python-version }}
- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j
- name: Run default Neo4j
env:
ENV: 'dev'
@@ -83,9 +87,9 @@ jobs:
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
GRAPH_DATABASE_PROVIDER: "neo4j"
GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }}
GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }}
GRAPH_DATABASE_USERNAME: "neo4j"
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
run: uv run python ./cognee/tests/test_neo4j.py
- name: Run Weighted Edges Tests with Neo4j


@@ -186,6 +186,10 @@ jobs:
python-version: '3.11.x'
extra-dependencies: "postgres"
- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j
- name: Install specific db dependency
run: echo "Dependencies already installed in setup"
@@ -206,9 +210,9 @@ jobs:
env:
ENV: 'dev'
GRAPH_DATABASE_PROVIDER: "neo4j"
GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }}
GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }}
GRAPH_DATABASE_USERNAME: "neo4j"
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}


@@ -51,20 +51,6 @@ jobs:
name: Search test for Neo4j/LanceDB/Sqlite
runs-on: ubuntu-22.04
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/lance/sqlite') }}
services:
neo4j:
image: neo4j:5.11
env:
NEO4J_AUTH: neo4j/pleaseletmein
NEO4J_PLUGINS: '["apoc","graph-data-science"]'
ports:
- 7474:7474
- 7687:7687
options: >-
--health-cmd="cypher-shell -u neo4j -p pleaseletmein 'RETURN 1'"
--health-interval=10s
--health-timeout=5s
--health-retries=5
steps:
- name: Check out
@@ -77,6 +63,10 @@ jobs:
with:
python-version: ${{ inputs.python-version }}
- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j
- name: Dependencies already installed
run: echo "Dependencies already installed in setup"
@@ -94,9 +84,9 @@ jobs:
GRAPH_DATABASE_PROVIDER: 'neo4j'
VECTOR_DB_PROVIDER: 'lancedb'
DB_PROVIDER: 'sqlite'
GRAPH_DATABASE_URL: bolt://localhost:7687
GRAPH_DATABASE_USERNAME: neo4j
GRAPH_DATABASE_PASSWORD: pleaseletmein
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
run: uv run python ./cognee/tests/test_search_db.py
run-kuzu-pgvector-postgres-search-tests:
@@ -158,19 +148,6 @@ jobs:
runs-on: ubuntu-22.04
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/pgvector/postgres') }}
services:
neo4j:
image: neo4j:5.11
env:
NEO4J_AUTH: neo4j/pleaseletmein
NEO4J_PLUGINS: '["apoc","graph-data-science"]'
ports:
- 7474:7474
- 7687:7687
options: >-
--health-cmd="cypher-shell -u neo4j -p pleaseletmein 'RETURN 1'"
--health-interval=10s
--health-timeout=5s
--health-retries=5
postgres:
image: pgvector/pgvector:pg17
env:
@@ -196,6 +173,10 @@ jobs:
python-version: ${{ inputs.python-version }}
extra-dependencies: "postgres"
- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j
- name: Dependencies already installed
run: echo "Dependencies already installed in setup"
@@ -213,9 +194,9 @@ jobs:
GRAPH_DATABASE_PROVIDER: 'neo4j'
VECTOR_DB_PROVIDER: 'pgvector'
DB_PROVIDER: 'postgres'
GRAPH_DATABASE_URL: bolt://localhost:7687
GRAPH_DATABASE_USERNAME: neo4j
GRAPH_DATABASE_PASSWORD: pleaseletmein
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
DB_NAME: cognee_db
DB_HOST: 127.0.0.1
DB_PORT: 5432


@@ -51,20 +51,6 @@ jobs:
name: Temporal Graph test Neo4j (lancedb + sqlite)
runs-on: ubuntu-22.04
if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/lance/sqlite') }}
services:
neo4j:
image: neo4j:5.11
env:
NEO4J_AUTH: neo4j/pleaseletmein
NEO4J_PLUGINS: '["apoc","graph-data-science"]'
ports:
- 7474:7474
- 7687:7687
options: >-
--health-cmd="cypher-shell -u neo4j -p pleaseletmein 'RETURN 1'"
--health-interval=10s
--health-timeout=5s
--health-retries=5
steps:
- name: Check out
@@ -77,6 +63,10 @@ jobs:
with:
python-version: ${{ inputs.python-version }}
- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j
- name: Dependencies already installed
run: echo "Dependencies already installed in setup"
@@ -94,9 +84,9 @@ jobs:
GRAPH_DATABASE_PROVIDER: 'neo4j'
VECTOR_DB_PROVIDER: 'lancedb'
DB_PROVIDER: 'sqlite'
GRAPH_DATABASE_URL: bolt://localhost:7687
GRAPH_DATABASE_USERNAME: neo4j
GRAPH_DATABASE_PASSWORD: pleaseletmein
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
run: uv run python ./cognee/tests/test_temporal_graph.py
run_temporal_graph_kuzu_postgres_pgvector:


@@ -27,6 +27,12 @@ jobs:
uses: ./.github/workflows/e2e_tests.yml
secrets: inherit
distributed-tests:
name: Distributed Cognee Test
needs: [ basic-tests, e2e-tests, graph-db-tests ]
uses: ./.github/workflows/distributed_test.yml
secrets: inherit
cli-tests:
name: CLI Tests
uses: ./.github/workflows/cli_tests.yml
@@ -104,7 +110,7 @@ jobs:
db-examples-tests:
name: DB Examples Tests
needs: [vector-db-tests, graph-db-tests, relational-db-migration-tests]
needs: [vector-db-tests, graph-db-tests, relational-db-migration-tests, distributed-tests]
uses: ./.github/workflows/db_examples_tests.yml
secrets: inherit


@@ -86,12 +86,19 @@ jobs:
with:
python-version: '3.11'
- name: Setup Neo4j with GDS
uses: ./.github/actions/setup_neo4j
id: neo4j
- name: Dependencies already installed
run: echo "Dependencies already installed in setup"
- name: Run Weighted Edges Tests
env:
GRAPH_DATABASE_PROVIDER: ${{ matrix.graph_db_provider }}
GRAPH_DATABASE_URL: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-url || '' }}
GRAPH_DATABASE_USERNAME: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-username || '' }}
GRAPH_DATABASE_PASSWORD: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-password || '' }}
run: |
uv run pytest cognee/tests/unit/interfaces/graph/test_weighted_edges.py -v --tb=short


@@ -217,10 +217,24 @@ export default function GraphVisualization({ ref, data, graphControls, className
const [graphShape, setGraphShape] = useState<string>();
const zoomToFit: ForceGraphMethods["zoomToFit"] = (
durationMs?: number,
padding?: number,
nodeFilter?: (node: NodeObject) => boolean
) => {
if (!graphRef.current) {
console.warn("GraphVisualization: graphRef not ready yet");
return undefined as any;
}
return graphRef.current.zoomToFit?.(durationMs, padding, nodeFilter);
};
useImperativeHandle(ref, () => ({
zoomToFit: graphRef.current!.zoomToFit,
setGraphShape: setGraphShape,
zoomToFit,
setGraphShape,
}));
return (
<div ref={containerRef} className={classNames("w-full h-full", className)} id="graph-container">


@@ -117,5 +117,4 @@ async def add_rule_associations(data: str, rules_nodeset_name: str):
if len(edges_to_save) > 0:
await graph_engine.add_edges(edges_to_save)
await index_graph_edges()
await index_graph_edges(edges_to_save)


@@ -68,6 +68,7 @@ class Neo4jAdapter(GraphDBInterface):
auth=auth,
max_connection_lifetime=120,
notifications_min_severity="OFF",
keep_alive=True,
)
async def initialize(self) -> None:
@@ -205,7 +206,7 @@ class Neo4jAdapter(GraphDBInterface):
{
"node_id": str(node.id),
"label": type(node).__name__,
"properties": self.serialize_properties(node.model_dump()),
"properties": self.serialize_properties(dict(node)),
}
for node in nodes
]
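
A side note on the serialize_properties change above (a general pydantic v2 observation, not a claim about cognee internals): dict(model) is shallow and yields the raw field values, while model_dump() recursively converts nested models into plain dicts. A minimal sketch with stand-in Inner/Outer models:

from pydantic import BaseModel

class Inner(BaseModel):
    value: int

class Outer(BaseModel):
    name: str
    inner: Inner

outer = Outer(name="example", inner=Inner(value=1))

# dict(model) is shallow: nested models stay as model instances, raw field values untouched.
print(dict(outer))          # {'name': 'example', 'inner': Inner(value=1)}
# model_dump() is recursive: nested models become plain dicts.
print(outer.model_dump())   # {'name': 'example', 'inner': {'value': 1}}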


@@ -8,7 +8,7 @@ from cognee.infrastructure.utils.calculate_backoff import calculate_backoff
logger = get_logger("deadlock_retry")
def deadlock_retry(max_retries=5):
def deadlock_retry(max_retries=10):
"""
Decorator that automatically retries an asynchronous function when database deadlock errors occur.
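
For context, a minimal sketch of what a deadlock_retry-style decorator can look like. This is not cognee's implementation (which builds its delay with calculate_backoff); the string-based deadlock check and the sleep schedule here are assumptions for illustration.

import asyncio
import functools

def deadlock_retry(max_retries: int = 10):
    """Retry an async function when the underlying database reports a deadlock."""
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            for attempt in range(max_retries):
                try:
                    return await func(*args, **kwargs)
                except Exception as error:
                    # Assumption: deadlocks are recognisable from the error text.
                    if "deadlock" not in str(error).lower() or attempt == max_retries - 1:
                        raise
                    await asyncio.sleep(min(2 ** attempt, 30))  # simple exponential backoff
        return wrapper
    return decorator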


@@ -9,7 +9,10 @@ async def get_dataset_data(dataset_id: UUID) -> list[Data]:
async with db_engine.get_async_session() as session:
result = await session.execute(
select(Data).join(Data.datasets).filter((Dataset.id == dataset_id))
select(Data)
.join(Data.datasets)
.filter((Dataset.id == dataset_id))
.order_by(Data.data_size.desc())
)
data = list(result.scalars().all())


@@ -5,7 +5,6 @@ from typing import Optional
class TableRow(DataPoint):
name: str
is_a: Optional[TableType] = None
description: str
properties: str


@@ -4,35 +4,28 @@ import asyncio
from uuid import UUID
from typing import Any, List
from functools import wraps
from sqlalchemy import select
import cognee.modules.ingestion as ingestion
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.pipelines.operations.run_tasks_distributed import run_tasks_distributed
from cognee.modules.users.models import User
from cognee.modules.data.models import Data
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.shared.logging_utils import get_logger
from cognee.modules.users.methods import get_default_user
from cognee.modules.pipelines.utils import generate_pipeline_id
from cognee.modules.pipelines.exceptions import PipelineRunFailedError
from cognee.tasks.ingestion import save_data_item_to_storage, resolve_data_directories
from cognee.tasks.ingestion import resolve_data_directories
from cognee.modules.pipelines.models.PipelineRunInfo import (
PipelineRunCompleted,
PipelineRunErrored,
PipelineRunStarted,
PipelineRunYield,
PipelineRunAlreadyCompleted,
)
from cognee.modules.pipelines.models.DataItemStatus import DataItemStatus
from cognee.modules.pipelines.operations import (
log_pipeline_run_start,
log_pipeline_run_complete,
log_pipeline_run_error,
)
from .run_tasks_with_telemetry import run_tasks_with_telemetry
from .run_tasks_data_item import run_tasks_data_item
from ..tasks.task import Task
@@ -68,176 +61,6 @@ async def run_tasks(
context: dict = None,
incremental_loading: bool = False,
):
async def _run_tasks_data_item_incremental(
data_item,
dataset,
tasks,
pipeline_name,
pipeline_id,
pipeline_run_id,
context,
user,
):
db_engine = get_relational_engine()
# If incremental_loading of data is set to True don't process documents already processed by pipeline
# If data is being added to Cognee for the first time calculate the id of the data
if not isinstance(data_item, Data):
file_path = await save_data_item_to_storage(data_item)
# Ingest data and add metadata
async with open_data_file(file_path) as file:
classified_data = ingestion.classify(file)
# data_id is the hash of file contents + owner id to avoid duplicate data
data_id = ingestion.identify(classified_data, user)
else:
# If data was already processed by Cognee get data id
data_id = data_item.id
# Check pipeline status, if Data already processed for pipeline before skip current processing
async with db_engine.get_async_session() as session:
data_point = (
await session.execute(select(Data).filter(Data.id == data_id))
).scalar_one_or_none()
if data_point:
if (
data_point.pipeline_status.get(pipeline_name, {}).get(str(dataset.id))
== DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED
):
yield {
"run_info": PipelineRunAlreadyCompleted(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
),
"data_id": data_id,
}
return
try:
# Process data based on data_item and list of tasks
async for result in run_tasks_with_telemetry(
tasks=tasks,
data=[data_item],
user=user,
pipeline_name=pipeline_id,
context=context,
):
yield PipelineRunYield(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
payload=result,
)
# Update pipeline status for Data element
async with db_engine.get_async_session() as session:
data_point = (
await session.execute(select(Data).filter(Data.id == data_id))
).scalar_one_or_none()
data_point.pipeline_status[pipeline_name] = {
str(dataset.id): DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED
}
await session.merge(data_point)
await session.commit()
yield {
"run_info": PipelineRunCompleted(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
),
"data_id": data_id,
}
except Exception as error:
# Temporarily swallow error and try to process rest of documents first, then re-raise error at end of data ingestion pipeline
logger.error(
f"Exception caught while processing data: {error}.\n Data processing failed for data item: {data_item}."
)
yield {
"run_info": PipelineRunErrored(
pipeline_run_id=pipeline_run_id,
payload=repr(error),
dataset_id=dataset.id,
dataset_name=dataset.name,
),
"data_id": data_id,
}
if os.getenv("RAISE_INCREMENTAL_LOADING_ERRORS", "true").lower() == "true":
raise error
async def _run_tasks_data_item_regular(
data_item,
dataset,
tasks,
pipeline_id,
pipeline_run_id,
context,
user,
):
# Process data based on data_item and list of tasks
async for result in run_tasks_with_telemetry(
tasks=tasks,
data=[data_item],
user=user,
pipeline_name=pipeline_id,
context=context,
):
yield PipelineRunYield(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
payload=result,
)
yield {
"run_info": PipelineRunCompleted(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
)
}
async def _run_tasks_data_item(
data_item,
dataset,
tasks,
pipeline_name,
pipeline_id,
pipeline_run_id,
context,
user,
incremental_loading,
):
# Go through async generator and return data item processing result. Result can be PipelineRunAlreadyCompleted when data item is skipped,
# PipelineRunCompleted when processing was successful and PipelineRunErrored if there were issues
result = None
if incremental_loading:
async for result in _run_tasks_data_item_incremental(
data_item=data_item,
dataset=dataset,
tasks=tasks,
pipeline_name=pipeline_name,
pipeline_id=pipeline_id,
pipeline_run_id=pipeline_run_id,
context=context,
user=user,
):
pass
else:
async for result in _run_tasks_data_item_regular(
data_item=data_item,
dataset=dataset,
tasks=tasks,
pipeline_id=pipeline_id,
pipeline_run_id=pipeline_run_id,
context=context,
user=user,
):
pass
return result
if not user:
user = await get_default_user()
@@ -269,7 +92,7 @@ async def run_tasks(
# Create async tasks per data item that will run the pipeline for the data item
data_item_tasks = [
asyncio.create_task(
_run_tasks_data_item(
run_tasks_data_item(
data_item,
dataset,
tasks,


@@ -0,0 +1,261 @@
"""
Data item processing functions for pipeline operations.
This module contains reusable functions for processing individual data items
within pipeline operations, supporting both incremental and regular processing modes.
"""
import os
from typing import Any, Dict, AsyncGenerator, Optional
from sqlalchemy import select
import cognee.modules.ingestion as ingestion
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.shared.logging_utils import get_logger
from cognee.modules.users.models import User
from cognee.modules.data.models import Data, Dataset
from cognee.tasks.ingestion import save_data_item_to_storage
from cognee.modules.pipelines.models.PipelineRunInfo import (
PipelineRunCompleted,
PipelineRunErrored,
PipelineRunYield,
PipelineRunAlreadyCompleted,
)
from cognee.modules.pipelines.models.DataItemStatus import DataItemStatus
from cognee.modules.pipelines.operations.run_tasks_with_telemetry import run_tasks_with_telemetry
from ..tasks.task import Task
logger = get_logger("run_tasks_data_item")
async def run_tasks_data_item_incremental(
data_item: Any,
dataset: Dataset,
tasks: list[Task],
pipeline_name: str,
pipeline_id: str,
pipeline_run_id: str,
context: Optional[Dict[str, Any]],
user: User,
) -> AsyncGenerator[Dict[str, Any], None]:
"""
Process a single data item with incremental loading support.
This function handles incremental processing by checking if the data item
has already been processed for the given pipeline and dataset. If it has,
it skips processing and returns a completion status.
Args:
data_item: The data item to process
dataset: The dataset containing the data item
tasks: List of tasks to execute on the data item
pipeline_name: Name of the pipeline
pipeline_id: Unique identifier for the pipeline
pipeline_run_id: Unique identifier for this pipeline run
context: Optional context dictionary
user: User performing the operation
Yields:
Dict containing run_info and data_id for each processing step
"""
db_engine = get_relational_engine()
# If incremental_loading of data is set to True don't process documents already processed by pipeline
# If data is being added to Cognee for the first time calculate the id of the data
if not isinstance(data_item, Data):
file_path = await save_data_item_to_storage(data_item)
# Ingest data and add metadata
async with open_data_file(file_path) as file:
classified_data = ingestion.classify(file)
# data_id is the hash of file contents + owner id to avoid duplicate data
data_id = ingestion.identify(classified_data, user)
else:
# If data was already processed by Cognee get data id
data_id = data_item.id
# Check pipeline status, if Data already processed for pipeline before skip current processing
async with db_engine.get_async_session() as session:
data_point = (
await session.execute(select(Data).filter(Data.id == data_id))
).scalar_one_or_none()
if data_point:
if (
data_point.pipeline_status.get(pipeline_name, {}).get(str(dataset.id))
== DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED
):
yield {
"run_info": PipelineRunAlreadyCompleted(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
),
"data_id": data_id,
}
return
try:
# Process data based on data_item and list of tasks
async for result in run_tasks_with_telemetry(
tasks=tasks,
data=[data_item],
user=user,
pipeline_name=pipeline_id,
context=context,
):
yield PipelineRunYield(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
payload=result,
)
# Update pipeline status for Data element
async with db_engine.get_async_session() as session:
data_point = (
await session.execute(select(Data).filter(Data.id == data_id))
).scalar_one_or_none()
data_point.pipeline_status[pipeline_name] = {
str(dataset.id): DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED
}
await session.merge(data_point)
await session.commit()
yield {
"run_info": PipelineRunCompleted(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
),
"data_id": data_id,
}
except Exception as error:
# Temporarily swallow error and try to process rest of documents first, then re-raise error at end of data ingestion pipeline
logger.error(
f"Exception caught while processing data: {error}.\n Data processing failed for data item: {data_item}."
)
yield {
"run_info": PipelineRunErrored(
pipeline_run_id=pipeline_run_id,
payload=repr(error),
dataset_id=dataset.id,
dataset_name=dataset.name,
),
"data_id": data_id,
}
if os.getenv("RAISE_INCREMENTAL_LOADING_ERRORS", "true").lower() == "true":
raise error
async def run_tasks_data_item_regular(
data_item: Any,
dataset: Dataset,
tasks: list[Task],
pipeline_id: str,
pipeline_run_id: str,
context: Optional[Dict[str, Any]],
user: User,
) -> AsyncGenerator[Dict[str, Any], None]:
"""
Process a single data item in regular (non-incremental) mode.
This function processes a data item without checking for previous processing
status, executing all tasks on the data item.
Args:
data_item: The data item to process
dataset: The dataset containing the data item
tasks: List of tasks to execute on the data item
pipeline_id: Unique identifier for the pipeline
pipeline_run_id: Unique identifier for this pipeline run
context: Optional context dictionary
user: User performing the operation
Yields:
Dict containing run_info for each processing step
"""
# Process data based on data_item and list of tasks
async for result in run_tasks_with_telemetry(
tasks=tasks,
data=[data_item],
user=user,
pipeline_name=pipeline_id,
context=context,
):
yield PipelineRunYield(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
payload=result,
)
yield {
"run_info": PipelineRunCompleted(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
)
}
async def run_tasks_data_item(
data_item: Any,
dataset: Dataset,
tasks: list[Task],
pipeline_name: str,
pipeline_id: str,
pipeline_run_id: str,
context: Optional[Dict[str, Any]],
user: User,
incremental_loading: bool,
) -> Optional[Dict[str, Any]]:
"""
Process a single data item, choosing between incremental and regular processing.
This is the main entry point for data item processing that delegates to either
incremental or regular processing based on the incremental_loading flag.
Args:
data_item: The data item to process
dataset: The dataset containing the data item
tasks: List of tasks to execute on the data item
pipeline_name: Name of the pipeline
pipeline_id: Unique identifier for the pipeline
pipeline_run_id: Unique identifier for this pipeline run
context: Optional context dictionary
user: User performing the operation
incremental_loading: Whether to use incremental processing
Returns:
Dict containing the final processing result, or None if processing was skipped
"""
# Go through async generator and return data item processing result. Result can be PipelineRunAlreadyCompleted when data item is skipped,
# PipelineRunCompleted when processing was successful and PipelineRunErrored if there were issues
result = None
if incremental_loading:
async for result in run_tasks_data_item_incremental(
data_item=data_item,
dataset=dataset,
tasks=tasks,
pipeline_name=pipeline_name,
pipeline_id=pipeline_id,
pipeline_run_id=pipeline_run_id,
context=context,
user=user,
):
pass
else:
async for result in run_tasks_data_item_regular(
data_item=data_item,
dataset=dataset,
tasks=tasks,
pipeline_id=pipeline_id,
pipeline_run_id=pipeline_run_id,
context=context,
user=user,
):
pass
return result
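
A hedged usage sketch for this new module: the import path is inferred from the relative import in run_tasks.py, and the Dataset, Task list, User and id values are placeholders that would come from an already configured cognee deployment.

from typing import Any

from cognee.modules.pipelines.operations.run_tasks_data_item import run_tasks_data_item

async def process_one(data_item: Any, dataset, tasks, user, pipeline_id, pipeline_run_id):
    # Drive a single data item through the pipeline; incremental_loading=True skips
    # items already marked completed for this pipeline and dataset.
    result = await run_tasks_data_item(
        data_item=data_item,
        dataset=dataset,
        tasks=tasks,
        pipeline_name="cognify_pipeline",  # hypothetical pipeline name
        pipeline_id=pipeline_id,
        pipeline_run_id=pipeline_run_id,
        context=None,
        user=user,
        incremental_loading=True,
    )
    # result is None if nothing was yielded, otherwise a dict with "run_info"
    # (plus "data_id" in incremental mode).
    return result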


@@ -3,49 +3,96 @@ try:
except ModuleNotFoundError:
modal = None
from typing import Any, List, Optional
from uuid import UUID
from cognee.modules.pipelines.tasks.task import Task
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.pipelines.models import (
PipelineRunStarted,
PipelineRunYield,
PipelineRunCompleted,
PipelineRunErrored,
)
from cognee.modules.pipelines.operations import log_pipeline_run_start, log_pipeline_run_complete
from cognee.modules.pipelines.utils.generate_pipeline_id import generate_pipeline_id
from cognee.modules.pipelines.operations import (
log_pipeline_run_start,
log_pipeline_run_complete,
log_pipeline_run_error,
)
from cognee.modules.pipelines.utils import generate_pipeline_id
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from .run_tasks_with_telemetry import run_tasks_with_telemetry
from cognee.modules.users.models import User
from cognee.modules.pipelines.exceptions import PipelineRunFailedError
from cognee.tasks.ingestion import resolve_data_directories
from .run_tasks_data_item import run_tasks_data_item
logger = get_logger("run_tasks_distributed()")
if modal:
import os
from distributed.app import app
from distributed.modal_image import image
secret_name = os.environ.get("MODAL_SECRET_NAME", "distributed_cognee")
@app.function(
retries=3,
image=image,
timeout=86400,
max_containers=50,
secrets=[modal.Secret.from_name("distributed_cognee")],
secrets=[modal.Secret.from_name(secret_name)],
)
async def run_tasks_on_modal(tasks, data_item, user, pipeline_name, context):
pipeline_run = run_tasks_with_telemetry(tasks, data_item, user, pipeline_name, context)
async def run_tasks_on_modal(
data_item,
dataset_id: UUID,
tasks: List[Task],
pipeline_name: str,
pipeline_id: str,
pipeline_run_id: str,
context: Optional[dict],
user: User,
incremental_loading: bool,
):
"""
Wrapper that runs the run_tasks_data_item function.
This is the code that runs on the Modal executor and produces the graph/vector DB objects.
"""
from cognee.infrastructure.databases.relational import get_relational_engine
run_info = None
async with get_relational_engine().get_async_session() as session:
from cognee.modules.data.models import Dataset
async for pipeline_run_info in pipeline_run:
run_info = pipeline_run_info
dataset = await session.get(Dataset, dataset_id)
return run_info
result = await run_tasks_data_item(
data_item=data_item,
dataset=dataset,
tasks=tasks,
pipeline_name=pipeline_name,
pipeline_id=pipeline_id,
pipeline_run_id=pipeline_run_id,
context=context,
user=user,
incremental_loading=incremental_loading,
)
return result
async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, context):
async def run_tasks_distributed(
tasks: List[Task],
dataset_id: UUID,
data: List[Any] = None,
user: User = None,
pipeline_name: str = "unknown_pipeline",
context: dict = None,
incremental_loading: bool = False,
):
if not user:
user = await get_default_user()
# Get dataset object
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
from cognee.modules.data.models import Dataset
@@ -53,9 +100,7 @@ async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, co
dataset = await session.get(Dataset, dataset_id)
pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name)
pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data)
pipeline_run_id = pipeline_run.pipeline_run_id
yield PipelineRunStarted(
@@ -65,30 +110,67 @@ async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, co
payload=data,
)
data_count = len(data) if isinstance(data, list) else 1
try:
if not isinstance(data, list):
data = [data]
arguments = [
[tasks] * data_count,
[[data_item] for data_item in data[:data_count]] if data_count > 1 else [data],
[user] * data_count,
[pipeline_name] * data_count,
[context] * data_count,
]
data = await resolve_data_directories(data)
async for result in run_tasks_on_modal.map.aio(*arguments):
logger.info(f"Received result: {result}")
number_of_data_items = len(data) if isinstance(data, list) else 1
yield PipelineRunYield(
data_item_tasks = [
data,
[dataset.id] * number_of_data_items,
[tasks] * number_of_data_items,
[pipeline_name] * number_of_data_items,
[pipeline_id] * number_of_data_items,
[pipeline_run_id] * number_of_data_items,
[context] * number_of_data_items,
[user] * number_of_data_items,
[incremental_loading] * number_of_data_items,
]
results = []
async for result in run_tasks_on_modal.map.aio(*data_item_tasks):
if not result:
continue
results.append(result)
# Remove skipped results
results = [r for r in results if r]
# If any data item failed, raise PipelineRunFailedError
errored = [
r
for r in results
if r and r.get("run_info") and isinstance(r["run_info"], PipelineRunErrored)
]
if errored:
raise PipelineRunFailedError("Pipeline run failed. Data item could not be processed.")
await log_pipeline_run_complete(
pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data
)
yield PipelineRunCompleted(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
payload=result,
data_ingestion_info=results,
)
await log_pipeline_run_complete(pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data)
except Exception as error:
await log_pipeline_run_error(
pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data, error
)
yield PipelineRunCompleted(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
)
yield PipelineRunErrored(
pipeline_run_id=pipeline_run_id,
payload=repr(error),
dataset_id=dataset.id,
dataset_name=dataset.name,
data_ingestion_info=locals().get("results"),
)
if not isinstance(error, PipelineRunFailedError):
raise
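
To make the fan-out above easier to follow: each list in data_item_tasks corresponds to one positional argument of run_tasks_on_modal, repeated once per data item, and Modal's .map consumes the lists column-wise. A plain-Python illustration of that alignment (the values are made up, and zip stands in for what .map.aio does with the argument lists):

data = ["report.pdf", "notes.txt"]   # hypothetical data items
n = len(data)

data_item_tasks = [
    data,                            # data_item, one per call
    ["dataset-uuid"] * n,            # same dataset id for every call
    [["extract", "embed"]] * n,      # same task list for every call
    ["cognify_pipeline"] * n,        # same pipeline name for every call
]

# Call i receives the i-th element of every list:
for call_args in zip(*data_item_tasks):
    print(call_args)
# ('report.pdf', 'dataset-uuid', ['extract', 'embed'], 'cognify_pipeline')
# ('notes.txt', 'dataset-uuid', ['extract', 'embed'], 'cognify_pipeline')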


@@ -194,7 +194,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
belongs_to_set=interactions_node_set,
)
await add_data_points(data_points=[cognee_user_interaction], update_edge_collection=False)
await add_data_points(data_points=[cognee_user_interaction])
relationships = []
relationship_name = "used_graph_element_to_answer"


@@ -8,7 +8,7 @@ from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.base_feedback import BaseFeedback
from cognee.modules.retrieval.utils.models import CogneeUserFeedback
from cognee.modules.retrieval.utils.models import UserFeedbackEvaluation
from cognee.tasks.storage import add_data_points
from cognee.tasks.storage import add_data_points, index_graph_edges
logger = get_logger("CompletionRetriever")
@@ -47,7 +47,7 @@ class UserQAFeedback(BaseFeedback):
belongs_to_set=feedbacks_node_set,
)
await add_data_points(data_points=[cognee_user_feedback], update_edge_collection=False)
await add_data_points(data_points=[cognee_user_feedback])
relationships = []
relationship_name = "gives_feedback_to"
@@ -76,6 +76,7 @@ class UserQAFeedback(BaseFeedback):
if len(relationships) > 0:
graph_engine = await get_graph_engine()
await graph_engine.add_edges(relationships)
await index_graph_edges(relationships)
await graph_engine.apply_feedback_weight(
node_ids=to_node_ids, weight=feedback_sentiment.score
)


@@ -124,5 +124,4 @@ async def add_rule_associations(
if len(edges_to_save) > 0:
await graph_engine.add_edges(edges_to_save)
await index_graph_edges()
await index_graph_edges(edges_to_save)


@@ -4,6 +4,7 @@ from pydantic import BaseModel
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.ontology.ontology_env_config import get_ontology_env_config
from cognee.tasks.storage import index_graph_edges
from cognee.tasks.storage.add_data_points import add_data_points
from cognee.modules.ontology.ontology_config import Config
from cognee.modules.ontology.get_default_ontology_resolver import (
@@ -88,6 +89,7 @@ async def integrate_chunk_graphs(
if len(graph_edges) > 0:
await graph_engine.add_edges(graph_edges)
await index_graph_edges(graph_edges)
return data_chunks


@@ -10,9 +10,7 @@ from cognee.tasks.storage.exceptions import (
)
async def add_data_points(
data_points: List[DataPoint], update_edge_collection: bool = True
) -> List[DataPoint]:
async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
"""
Add a batch of data points to the graph database by extracting nodes and edges,
deduplicating them, and indexing them for retrieval.
@@ -25,9 +23,6 @@ async def add_data_points(
Args:
data_points (List[DataPoint]):
A list of data points to process and insert into the graph.
update_edge_collection (bool, optional):
Whether to update the edge index after adding edges.
Defaults to True.
Returns:
List[DataPoint]:
@@ -73,12 +68,10 @@ async def add_data_points(
graph_engine = await get_graph_engine()
await graph_engine.add_nodes(nodes)
await index_data_points(nodes)
await graph_engine.add_nodes(nodes)
await graph_engine.add_edges(edges)
if update_edge_collection:
await index_graph_edges()
await index_graph_edges(edges)
return data_points


@@ -1,15 +1,18 @@
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
from cognee.shared.logging_utils import get_logger, ERROR
from cognee.shared.logging_utils import get_logger
from collections import Counter
from typing import Optional, Dict, Any, List, Tuple, Union
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.graph.models.EdgeType import EdgeType
from cognee.infrastructure.databases.graph.graph_db_interface import EdgeData
logger = get_logger(level=ERROR)
logger = get_logger()
async def index_graph_edges():
async def index_graph_edges(
edges_data: Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]] = None,
):
"""
Indexes graph edges by creating and managing vector indexes for relationship types.
@@ -35,13 +38,17 @@ async def index_graph_edges():
index_points = {}
vector_engine = get_vector_engine()
graph_engine = await get_graph_engine()
if edges_data is None:
graph_engine = await get_graph_engine()
_, edges_data = await graph_engine.get_graph_data()
logger.warning(
"Your graph edge embedding is deprecated, please pass edges to the index_graph_edges directly."
)
except Exception as e:
logger.error("Failed to initialize engines: %s", e)
raise RuntimeError("Initialization error") from e
_, edges_data = await graph_engine.get_graph_data()
edge_types = Counter(
item.get("relationship_name")
for edge in edges_data
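
A hedged usage sketch of the new index_graph_edges signature, assuming a configured cognee environment. The edge tuples follow the (source_id, target_id, relationship_name, properties) shape permitted by the Union type hint above; the ids and properties are placeholders.

import asyncio

from cognee.tasks.storage import index_graph_edges

edges = [
    ("node_a", "node_b", "is_part_of", {"weight": 1.0}),  # hypothetical edge tuple
    ("node_b", "node_c", "mentions", None),
]

async def main():
    # Preferred: index exactly the edges that were just added to the graph.
    await index_graph_edges(edges)
    # Legacy: calling with no argument falls back to a full graph scan and logs a
    # deprecation warning.
    # await index_graph_edges()

asyncio.run(main())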


@@ -29,6 +29,3 @@ RUN poetry install --extras neo4j --extras postgres --extras aws --extras distri
COPY cognee/ /app/cognee
COPY distributed/ /app/distributed
RUN chmod +x /app/distributed/entrypoint.sh
ENTRYPOINT ["/app/distributed/entrypoint.sh"]


@@ -10,6 +10,7 @@ from distributed.app import app
from distributed.queues import add_nodes_and_edges_queue, add_data_points_queue
from distributed.workers.graph_saving_worker import graph_saving_worker
from distributed.workers.data_point_saving_worker import data_point_saving_worker
from distributed.signal import QueueSignal
logger = get_logger()
@@ -23,13 +24,14 @@ async def main():
await add_nodes_and_edges_queue.clear.aio()
await add_data_points_queue.clear.aio()
number_of_graph_saving_workers = 1 # Total number of graph_saving_worker to spawn
number_of_data_point_saving_workers = 5 # Total number of graph_saving_worker to spawn
number_of_graph_saving_workers = 1 # Total number of graph_saving_worker to spawn (MAX 1)
number_of_data_point_saving_workers = (
10 # Total number of data_point_saving_worker to spawn (MAX 10)
)
results = []
consumer_futures = []
# await prune.prune_data() # We don't want to delete files on s3
await prune.prune_data() # This prunes the data from the file storage
# Delete DBs and saved files from metastore
await prune.prune_system(metadata=True)
@@ -45,16 +47,28 @@ async def main():
worker_future = data_point_saving_worker.spawn()
consumer_futures.append(worker_future)
""" Example: Setting and adding S3 path as input
s3_bucket_path = os.getenv("S3_BUCKET_PATH")
s3_data_path = "s3://" + s3_bucket_path
await cognee.add(s3_data_path, dataset_name="s3-files")
"""
await cognee.add(
[
"Audi is a German car manufacturer",
"The Netherlands is next to Germany",
"Berlin is the capital of Germany",
"The Rhine is a major European river",
"BMW produces luxury vehicles",
],
dataset_name="s3-files",
)
await cognee.cognify(datasets=["s3-files"])
# Push empty tuple into the queue to signal the end of data.
await add_nodes_and_edges_queue.put.aio(())
await add_data_points_queue.put.aio(())
# Put processing-end signal into the queues to stop the consumers
await add_nodes_and_edges_queue.put.aio(QueueSignal.STOP)
await add_data_points_queue.put.aio(QueueSignal.STOP)
for consumer_future in consumer_futures:
try:
@@ -64,8 +78,6 @@ async def main():
except Exception as e:
logger.error(e)
print(results)
if __name__ == "__main__":
asyncio.run(main())
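
The QueueSignal.STOP handling in the workers follows a re-enqueue-the-sentinel pattern, so that every consumer eventually sees the stop marker and shuts down. A plain-asyncio illustration of that pattern (not the Modal queue API):

import asyncio
from enum import Enum

class QueueSignal(str, Enum):
    STOP = "STOP"

async def consumer(name: str, queue: asyncio.Queue):
    while True:
        item = await queue.get()
        if item == QueueSignal.STOP:
            await queue.put(QueueSignal.STOP)  # propagate the sentinel to sibling consumers
            print(f"{name}: stop signal seen, exiting")
            return
        print(f"{name}: processing {item}")

async def main():
    queue: asyncio.Queue = asyncio.Queue()
    for i in range(5):
        await queue.put(f"item-{i}")
    await queue.put(QueueSignal.STOP)
    await asyncio.gather(consumer("worker-1", queue), consumer("worker-2", queue))

asyncio.run(main())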

distributed/poetry.lock (generated, 12238 lines changed)

File diff suppressed because it is too large.


@@ -1,185 +0,0 @@
[project]
name = "cognee"
version = "0.2.2.dev0"
description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
authors = [
{ name = "Vasilije Markovic" },
{ name = "Boris Arzentar" },
]
requires-python = ">=3.10,<=3.13"
readme = "README.md"
license = "Apache-2.0"
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Topic :: Software Development :: Libraries",
"Operating System :: MacOS :: MacOS X",
"Operating System :: POSIX :: Linux",
"Operating System :: Microsoft :: Windows",
]
dependencies = [
"openai>=1.80.1,<2",
"python-dotenv>=1.0.1,<2.0.0",
"pydantic>=2.11.7,<3.0.0",
"pydantic-settings>=2.10.1,<3",
"typing_extensions>=4.12.2,<5.0.0",
"nltk>=3.9.1,<4.0.0",
"numpy>=1.26.4, <=4.0.0",
"pandas>=2.2.2,<3.0.0",
# Note: New s3fs and boto3 versions don't work well together
# Always use compatible fixed versions of these two dependencies
"s3fs[boto3]==2025.3.2",
"sqlalchemy>=2.0.39,<3.0.0",
"aiosqlite>=0.20.0,<1.0.0",
"tiktoken>=0.8.0,<1.0.0",
"litellm>=1.57.4, <1.71.0",
"instructor>=1.9.1,<2.0.0",
"langfuse>=2.32.0,<3",
"filetype>=1.2.0,<2.0.0",
"aiohttp>=3.11.14,<4.0.0",
"aiofiles>=23.2.1,<24.0.0",
"rdflib>=7.1.4,<7.2.0",
"pypdf>=4.1.0,<7.0.0",
"jinja2>=3.1.3,<4",
"matplotlib>=3.8.3,<4",
"networkx>=3.4.2,<4",
"lancedb>=0.24.0,<1.0.0",
"alembic>=1.13.3,<2",
"pre-commit>=4.0.1,<5",
"scikit-learn>=1.6.1,<2",
"limits>=4.4.1,<5",
"fastapi>=0.115.7,<1.0.0",
"python-multipart>=0.0.20,<1.0.0",
"fastapi-users[sqlalchemy]>=14.0.1,<15.0.0",
"dlt[sqlalchemy]>=1.9.0,<2",
"sentry-sdk[fastapi]>=2.9.0,<3",
"structlog>=25.2.0,<26",
"pympler>=1.1,<2.0.0",
"onnxruntime>=1.0.0,<2.0.0",
"pylance>=0.22.0,<1.0.0",
"kuzu (==0.11.0)"
]
[project.optional-dependencies]
api = [
"uvicorn>=0.34.0,<1.0.0",
"gunicorn>=20.1.0,<24",
"websockets>=15.0.1,<16.0.0"
]
distributed = [
"modal>=1.0.5,<2.0.0",
]
neo4j = ["neo4j>=5.28.0,<6"]
postgres = [
"psycopg2>=2.9.10,<3",
"pgvector>=0.3.5,<0.4",
"asyncpg>=0.30.0,<1.0.0",
]
postgres-binary = [
"psycopg2-binary>=2.9.10,<3.0.0",
"pgvector>=0.3.5,<0.4",
"asyncpg>=0.30.0,<1.0.0",
]
notebook = ["notebook>=7.1.0,<8"]
langchain = [
"langsmith>=0.2.3,<1.0.0",
"langchain_text_splitters>=0.3.2,<1.0.0",
]
llama-index = ["llama-index-core>=0.12.11,<0.13"]
gemini = ["google-generativeai>=0.8.4,<0.9"]
huggingface = ["transformers>=4.46.3,<5"]
ollama = ["transformers>=4.46.3,<5"]
mistral = ["mistral-common>=1.5.2,<2"]
anthropic = ["anthropic>=0.26.1,<0.27"]
deepeval = ["deepeval>=2.0.1,<3"]
posthog = ["posthog>=3.5.0,<4"]
groq = ["groq>=0.8.0,<1.0.0"]
chromadb = [
"chromadb>=0.3.0,<0.7",
"pypika==0.48.8",
]
docs = ["unstructured[csv, doc, docx, epub, md, odt, org, ppt, pptx, rst, rtf, tsv, xlsx]>=0.18.1,<19"]
codegraph = [
"fastembed<=0.6.0 ; python_version < '3.13'",
"transformers>=4.46.3,<5",
"tree-sitter>=0.24.0,<0.25",
"tree-sitter-python>=0.23.6,<0.24",
]
evals = [
"plotly>=6.0.0,<7",
"gdown>=5.2.0,<6",
]
gui = [
"pyside6>=6.8.3,<7",
"qasync>=0.27.1,<0.28",
]
graphiti = ["graphiti-core>=0.7.0,<0.8"]
# Note: New s3fs and boto3 versions don't work well together
# Always use compatible fixed versions of these two dependencies
aws = ["s3fs[boto3]==2025.3.2"]
dev = [
"pytest>=7.4.0,<8",
"pytest-cov>=6.1.1,<7.0.0",
"pytest-asyncio>=0.21.1,<0.22",
"coverage>=7.3.2,<8",
"mypy>=1.7.1,<2",
"notebook>=7.1.0,<8",
"deptry>=0.20.0,<0.21",
"pylint>=3.0.3,<4",
"ruff>=0.9.2,<1.0.0",
"tweepy>=4.14.0,<5.0.0",
"gitpython>=3.1.43,<4",
"mkdocs-material>=9.5.42,<10",
"mkdocs-minify-plugin>=0.8.0,<0.9",
"mkdocstrings[python]>=0.26.2,<0.27",
]
debug = ["debugpy>=1.8.9,<2.0.0"]
[project.urls]
Homepage = "https://www.cognee.ai"
Repository = "https://github.com/topoteretes/cognee"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build]
exclude = [
"/bin",
"/dist",
"/.data",
"/.github",
"/alembic",
"/deployment",
"/cognee-mcp",
"/cognee-frontend",
"/examples",
"/helm",
"/licenses",
"/logs",
"/notebooks",
"/profiling",
"/tests",
"/tools",
]
[tool.hatch.build.targets.wheel]
packages = ["cognee", "distributed"]
[tool.ruff]
line-length = 100
exclude = [
"migrations/", # Ignore migrations directory
"notebooks/", # Ignore notebook files
"build/", # Ignore build directory
"cognee/pipelines.py",
"cognee/modules/users/models/Group.py",
"cognee/modules/users/models/ACL.py",
"cognee/modules/pipelines/models/Task.py",
"cognee/modules/data/models/Dataset.py"
]
[tool.ruff.lint]
ignore = ["F401"]

distributed/signal.py (new file, +5 lines)

@@ -0,0 +1,5 @@
from enum import Enum
class QueueSignal(str, Enum):
STOP = "STOP"


@@ -1,16 +1,17 @@
import os
import modal
import asyncio
from sqlalchemy.exc import OperationalError, DBAPIError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
from distributed.app import app
from distributed.signal import QueueSignal
from distributed.modal_image import image
from distributed.queues import add_data_points_queue
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine
logger = get_logger("data_point_saving_worker")
@@ -39,55 +40,84 @@ def is_deadlock_error(error):
return False
secret_name = os.environ.get("MODAL_SECRET_NAME", "distributed_cognee")
@app.function(
retries=3,
image=image,
timeout=86400,
max_containers=5,
secrets=[modal.Secret.from_name("distributed_cognee")],
max_containers=10,
secrets=[modal.Secret.from_name(secret_name)],
)
async def data_point_saving_worker():
print("Started processing of data points; starting vector engine queue.")
vector_engine = get_vector_engine()
# Defines how many data packets we glue together from the Modal queue before the embedding call and ingestion
BATCH_SIZE = 25
stop_seen = False
while True:
if stop_seen:
print("Finished processing all data points; stopping vector engine queue consumer.")
return True
if await add_data_points_queue.len.aio() != 0:
try:
add_data_points_request = await add_data_points_queue.get.aio(block=False)
print("Remaining elements in queue:")
print(await add_data_points_queue.len.aio())
# collect batched requests
batched_points = {}
for _ in range(min(BATCH_SIZE, await add_data_points_queue.len.aio())):
add_data_points_request = await add_data_points_queue.get.aio(block=False)
if not add_data_points_request:
continue
if add_data_points_request == QueueSignal.STOP:
await add_data_points_queue.put.aio(QueueSignal.STOP)
stop_seen = True
break
if len(add_data_points_request) == 2:
collection_name, data_points = add_data_points_request
if collection_name not in batched_points:
batched_points[collection_name] = []
batched_points[collection_name].extend(data_points)
else:
print("NoneType or invalid request detected.")
if batched_points:
for collection_name, data_points in batched_points.items():
print(
f"Adding {len(data_points)} data points to '{collection_name}' collection."
)
@retry(
retry=retry_if_exception_type(VectorDatabaseDeadlockError),
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=2, min=1, max=6),
)
async def add_data_points():
try:
await vector_engine.create_data_points(
collection_name, data_points, distributed=False
)
except DBAPIError as error:
if is_deadlock_error(error):
raise VectorDatabaseDeadlockError()
except OperationalError as error:
if is_deadlock_error(error):
raise VectorDatabaseDeadlockError()
await add_data_points()
print(f"Finished adding data points to '{collection_name}'.")
except modal.exception.DeserializationError as error:
logger.error(f"Deserialization error: {str(error)}")
continue
if len(add_data_points_request) == 0:
print("Finished processing all data points; stopping vector engine queue.")
return True
if len(add_data_points_request) == 2:
(collection_name, data_points) = add_data_points_request
print(f"Adding {len(data_points)} data points to '{collection_name}' collection.")
@retry(
retry=retry_if_exception_type(VectorDatabaseDeadlockError),
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=2, min=1, max=6),
)
async def add_data_points():
try:
await vector_engine.create_data_points(
collection_name, data_points, distributed=False
)
except DBAPIError as error:
if is_deadlock_error(error):
raise VectorDatabaseDeadlockError()
except OperationalError as error:
if is_deadlock_error(error):
raise VectorDatabaseDeadlockError()
await add_data_points()
print("Finished adding data points.")
else:
print("No jobs, go to sleep.")
await asyncio.sleep(5)
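
A self-contained sketch of the tenacity retry-on-deadlock pattern used in the worker above. VectorDatabaseDeadlockError and the simulated write are stand-ins for cognee's exception type and the vector_engine.create_data_points call.

import asyncio
import random

from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

class VectorDatabaseDeadlockError(Exception):
    """Stand-in for the deadlock error surfaced by the vector database."""

@retry(
    retry=retry_if_exception_type(VectorDatabaseDeadlockError),
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=2, min=1, max=6),
)
async def write_batch(batch: list) -> None:
    # Simulated write; the real worker calls
    # vector_engine.create_data_points(collection_name, data_points, distributed=False).
    if random.random() < 0.5:
        raise VectorDatabaseDeadlockError("simulated deadlock")
    print(f"wrote {len(batch)} points")

asyncio.run(write_batch([1, 2, 3]))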


@@ -1,8 +1,10 @@
import os
import modal
import asyncio
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
from distributed.app import app
from distributed.signal import QueueSignal
from distributed.modal_image import image
from distributed.queues import add_nodes_and_edges_queue
@@ -10,7 +12,6 @@ from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.graph.config import get_graph_config
logger = get_logger("graph_saving_worker")
@@ -37,68 +38,91 @@ def is_deadlock_error(error):
return False
secret_name = os.environ.get("MODAL_SECRET_NAME", "distributed_cognee")
@app.function(
retries=3,
image=image,
timeout=86400,
max_containers=5,
secrets=[modal.Secret.from_name("distributed_cognee")],
max_containers=1,
secrets=[modal.Secret.from_name(secret_name)],
)
async def graph_saving_worker():
print("Started processing of nodes and edges; starting graph engine queue.")
graph_engine = await get_graph_engine()
# Defines how many data packets we glue together from the queue before ingesting them into the graph database
BATCH_SIZE = 25
stop_seen = False
while True:
if stop_seen:
print("Finished processing all data points; stopping graph engine queue consumer.")
return True
if await add_nodes_and_edges_queue.len.aio() != 0:
try:
nodes_and_edges = await add_nodes_and_edges_queue.get.aio(block=False)
print("Remaining elements in queue:")
print(await add_nodes_and_edges_queue.len.aio())
all_nodes, all_edges = [], []
for _ in range(min(BATCH_SIZE, await add_nodes_and_edges_queue.len.aio())):
nodes_and_edges = await add_nodes_and_edges_queue.get.aio(block=False)
if not nodes_and_edges:
continue
if nodes_and_edges == QueueSignal.STOP:
await add_nodes_and_edges_queue.put.aio(QueueSignal.STOP)
stop_seen = True
break
if len(nodes_and_edges) == 2:
nodes, edges = nodes_and_edges
all_nodes.extend(nodes)
all_edges.extend(edges)
else:
print("None Type detected.")
if all_nodes or all_edges:
print(f"Adding {len(all_nodes)} nodes and {len(all_edges)} edges.")
@retry(
retry=retry_if_exception_type(GraphDatabaseDeadlockError),
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=2, min=1, max=6),
)
async def save_graph_nodes(new_nodes):
try:
await graph_engine.add_nodes(new_nodes, distributed=False)
except Exception as error:
if is_deadlock_error(error):
raise GraphDatabaseDeadlockError()
@retry(
retry=retry_if_exception_type(GraphDatabaseDeadlockError),
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=2, min=1, max=6),
)
async def save_graph_edges(new_edges):
try:
await graph_engine.add_edges(new_edges, distributed=False)
except Exception as error:
if is_deadlock_error(error):
raise GraphDatabaseDeadlockError()
if all_nodes:
await save_graph_nodes(all_nodes)
if all_edges:
await save_graph_edges(all_edges)
print("Finished adding nodes and edges.")
except modal.exception.DeserializationError as error:
logger.error(f"Deserialization error: {str(error)}")
continue
if len(nodes_and_edges) == 0:
print("Finished processing all nodes and edges; stopping graph engine queue.")
return True
if len(nodes_and_edges) == 2:
print(
f"Adding {len(nodes_and_edges[0])} nodes and {len(nodes_and_edges[1])} edges."
)
nodes = nodes_and_edges[0]
edges = nodes_and_edges[1]
@retry(
retry=retry_if_exception_type(GraphDatabaseDeadlockError),
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=2, min=1, max=6),
)
async def save_graph_nodes(new_nodes):
try:
await graph_engine.add_nodes(new_nodes, distributed=False)
except Exception as error:
if is_deadlock_error(error):
raise GraphDatabaseDeadlockError()
@retry(
retry=retry_if_exception_type(GraphDatabaseDeadlockError),
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=2, min=1, max=6),
)
async def save_graph_edges(new_edges):
try:
await graph_engine.add_edges(new_edges, distributed=False)
except Exception as error:
if is_deadlock_error(error):
raise GraphDatabaseDeadlockError()
if nodes:
await save_graph_nodes(nodes)
if edges:
await save_graph_edges(edges)
print("Finished adding nodes and edges.")
else:
print("No jobs, go to sleep.")
await asyncio.sleep(5)