Merge branch 'dev' into feature/cog-3532-empower-test_search-db-retrievers-tests-reorg
This commit is contained in:
commit
3a48930c3b
15 changed files with 302 additions and 25 deletions
27
.github/workflows/e2e_tests.yml
vendored
27
.github/workflows/e2e_tests.yml
vendored
|
|
@ -582,3 +582,30 @@ jobs:
|
||||||
DB_USERNAME: cognee
|
DB_USERNAME: cognee
|
||||||
DB_PASSWORD: cognee
|
DB_PASSWORD: cognee
|
||||||
run: uv run python ./cognee/tests/test_conversation_history.py
|
run: uv run python ./cognee/tests/test_conversation_history.py
|
||||||
|
|
||||||
|
# Runs the pipeline-cache test module against the configured LLM/embedding backends.
run-pipeline-cache-test:
  name: Test Pipeline Caching
  runs-on: ubuntu-22.04
  steps:
    - name: Check out
      uses: actions/checkout@v4
      with:
        fetch-depth: 0

    - name: Cognee Setup
      uses: ./.github/actions/cognee_setup
      with:
        python-version: '3.11.x'

    - name: Run Pipeline Cache Test
      env:
        ENV: 'dev'
        LLM_MODEL: ${{ secrets.LLM_MODEL }}
        LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
        LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
        LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
        EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
        EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
        EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
        EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
      # test_pipeline_cache.py only defines pytest-style tests (a Test* class with
      # @pytest.mark.asyncio methods) and has no __main__ entry point, so it must be
      # collected by pytest; running it with plain `python` would execute no tests
      # and the job would pass vacuously.
      run: uv run pytest ./cognee/tests/test_pipeline_cache.py
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,20 @@ def _recreate_table_without_unique_constraint_sqlite(op, insp):
|
||||||
sa.Column("graph_database_name", sa.String(), nullable=False),
|
sa.Column("graph_database_name", sa.String(), nullable=False),
|
||||||
sa.Column("vector_database_provider", sa.String(), nullable=False),
|
sa.Column("vector_database_provider", sa.String(), nullable=False),
|
||||||
sa.Column("graph_database_provider", sa.String(), nullable=False),
|
sa.Column("graph_database_provider", sa.String(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"vector_dataset_database_handler",
|
||||||
|
sa.String(),
|
||||||
|
unique=False,
|
||||||
|
nullable=False,
|
||||||
|
server_default="lancedb",
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"graph_dataset_database_handler",
|
||||||
|
sa.String(),
|
||||||
|
unique=False,
|
||||||
|
nullable=False,
|
||||||
|
server_default="kuzu",
|
||||||
|
),
|
||||||
sa.Column("vector_database_url", sa.String()),
|
sa.Column("vector_database_url", sa.String()),
|
||||||
sa.Column("graph_database_url", sa.String()),
|
sa.Column("graph_database_url", sa.String()),
|
||||||
sa.Column("vector_database_key", sa.String()),
|
sa.Column("vector_database_key", sa.String()),
|
||||||
|
|
@ -82,6 +96,8 @@ def _recreate_table_without_unique_constraint_sqlite(op, insp):
|
||||||
graph_database_name,
|
graph_database_name,
|
||||||
vector_database_provider,
|
vector_database_provider,
|
||||||
graph_database_provider,
|
graph_database_provider,
|
||||||
|
vector_dataset_database_handler,
|
||||||
|
graph_dataset_database_handler,
|
||||||
vector_database_url,
|
vector_database_url,
|
||||||
graph_database_url,
|
graph_database_url,
|
||||||
vector_database_key,
|
vector_database_key,
|
||||||
|
|
@ -120,6 +136,20 @@ def _recreate_table_with_unique_constraint_sqlite(op, insp):
|
||||||
sa.Column("graph_database_name", sa.String(), nullable=False, unique=True),
|
sa.Column("graph_database_name", sa.String(), nullable=False, unique=True),
|
||||||
sa.Column("vector_database_provider", sa.String(), nullable=False),
|
sa.Column("vector_database_provider", sa.String(), nullable=False),
|
||||||
sa.Column("graph_database_provider", sa.String(), nullable=False),
|
sa.Column("graph_database_provider", sa.String(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"vector_dataset_database_handler",
|
||||||
|
sa.String(),
|
||||||
|
unique=False,
|
||||||
|
nullable=False,
|
||||||
|
server_default="lancedb",
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"graph_dataset_database_handler",
|
||||||
|
sa.String(),
|
||||||
|
unique=False,
|
||||||
|
nullable=False,
|
||||||
|
server_default="kuzu",
|
||||||
|
),
|
||||||
sa.Column("vector_database_url", sa.String()),
|
sa.Column("vector_database_url", sa.String()),
|
||||||
sa.Column("graph_database_url", sa.String()),
|
sa.Column("graph_database_url", sa.String()),
|
||||||
sa.Column("vector_database_key", sa.String()),
|
sa.Column("vector_database_key", sa.String()),
|
||||||
|
|
@ -153,6 +183,8 @@ def _recreate_table_with_unique_constraint_sqlite(op, insp):
|
||||||
graph_database_name,
|
graph_database_name,
|
||||||
vector_database_provider,
|
vector_database_provider,
|
||||||
graph_database_provider,
|
graph_database_provider,
|
||||||
|
vector_dataset_database_handler,
|
||||||
|
graph_dataset_database_handler,
|
||||||
vector_database_url,
|
vector_database_url,
|
||||||
graph_database_url,
|
graph_database_url,
|
||||||
vector_database_key,
|
vector_database_key,
|
||||||
|
|
@ -193,6 +225,22 @@ def upgrade() -> None:
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
vector_dataset_database_handler = _get_column(
|
||||||
|
insp, "dataset_database", "vector_dataset_database_handler"
|
||||||
|
)
|
||||||
|
if not vector_dataset_database_handler:
|
||||||
|
# Add LanceDB as the default vector dataset database handler
|
||||||
|
op.add_column(
|
||||||
|
"dataset_database",
|
||||||
|
sa.Column(
|
||||||
|
"vector_dataset_database_handler",
|
||||||
|
sa.String(),
|
||||||
|
unique=False,
|
||||||
|
nullable=False,
|
||||||
|
server_default="lancedb",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
graph_database_connection_info_column = _get_column(
|
graph_database_connection_info_column = _get_column(
|
||||||
insp, "dataset_database", "graph_database_connection_info"
|
insp, "dataset_database", "graph_database_connection_info"
|
||||||
)
|
)
|
||||||
|
|
@ -208,6 +256,22 @@ def upgrade() -> None:
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
graph_dataset_database_handler = _get_column(
|
||||||
|
insp, "dataset_database", "graph_dataset_database_handler"
|
||||||
|
)
|
||||||
|
if not graph_dataset_database_handler:
|
||||||
|
# Add Kuzu as the default graph dataset database handler
|
||||||
|
op.add_column(
|
||||||
|
"dataset_database",
|
||||||
|
sa.Column(
|
||||||
|
"graph_dataset_database_handler",
|
||||||
|
sa.String(),
|
||||||
|
unique=False,
|
||||||
|
nullable=False,
|
||||||
|
server_default="kuzu",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
with op.batch_alter_table("dataset_database", schema=None) as batch_op:
|
with op.batch_alter_table("dataset_database", schema=None) as batch_op:
|
||||||
# Drop the unique constraint to make unique=False
|
# Drop the unique constraint to make unique=False
|
||||||
graph_constraint_to_drop = None
|
graph_constraint_to_drop = None
|
||||||
|
|
@ -265,3 +329,5 @@ def downgrade() -> None:
|
||||||
|
|
||||||
op.drop_column("dataset_database", "vector_database_connection_info")
|
op.drop_column("dataset_database", "vector_database_connection_info")
|
||||||
op.drop_column("dataset_database", "graph_database_connection_info")
|
op.drop_column("dataset_database", "graph_database_connection_info")
|
||||||
|
op.drop_column("dataset_database", "vector_dataset_database_handler")
|
||||||
|
op.drop_column("dataset_database", "graph_dataset_database_handler")
|
||||||
|
|
|
||||||
|
|
@ -205,6 +205,7 @@ async def add(
|
||||||
pipeline_name="add_pipeline",
|
pipeline_name="add_pipeline",
|
||||||
vector_db_config=vector_db_config,
|
vector_db_config=vector_db_config,
|
||||||
graph_db_config=graph_db_config,
|
graph_db_config=graph_db_config,
|
||||||
|
use_pipeline_cache=True,
|
||||||
incremental_loading=incremental_loading,
|
incremental_loading=incremental_loading,
|
||||||
data_per_batch=data_per_batch,
|
data_per_batch=data_per_batch,
|
||||||
):
|
):
|
||||||
|
|
|
||||||
|
|
@ -237,6 +237,7 @@ async def cognify(
|
||||||
vector_db_config=vector_db_config,
|
vector_db_config=vector_db_config,
|
||||||
graph_db_config=graph_db_config,
|
graph_db_config=graph_db_config,
|
||||||
incremental_loading=incremental_loading,
|
incremental_loading=incremental_loading,
|
||||||
|
use_pipeline_cache=True,
|
||||||
pipeline_name="cognify_pipeline",
|
pipeline_name="cognify_pipeline",
|
||||||
data_per_batch=data_per_batch,
|
data_per_batch=data_per_batch,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -47,6 +47,7 @@ class KuzuDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||||
"graph_database_url": graph_db_url,
|
"graph_database_url": graph_db_url,
|
||||||
"graph_database_provider": graph_config.graph_database_provider,
|
"graph_database_provider": graph_config.graph_database_provider,
|
||||||
"graph_database_key": graph_db_key,
|
"graph_database_key": graph_db_key,
|
||||||
|
"graph_dataset_database_handler": "kuzu",
|
||||||
"graph_database_connection_info": {
|
"graph_database_connection_info": {
|
||||||
"graph_database_username": graph_db_username,
|
"graph_database_username": graph_db_username,
|
||||||
"graph_database_password": graph_db_password,
|
"graph_database_password": graph_db_password,
|
||||||
|
|
|
||||||
|
|
@ -131,6 +131,7 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||||
"graph_database_url": graph_db_url,
|
"graph_database_url": graph_db_url,
|
||||||
"graph_database_provider": "neo4j",
|
"graph_database_provider": "neo4j",
|
||||||
"graph_database_key": graph_db_key,
|
"graph_database_key": graph_db_key,
|
||||||
|
"graph_dataset_database_handler": "neo4j_aura_dev",
|
||||||
"graph_database_connection_info": {
|
"graph_database_connection_info": {
|
||||||
"graph_database_username": graph_db_username,
|
"graph_database_username": graph_db_username,
|
||||||
"graph_database_password": encrypted_db_password_string,
|
"graph_database_password": encrypted_db_password_string,
|
||||||
|
|
|
||||||
|
|
@ -1,27 +1,21 @@
|
||||||
from cognee.infrastructure.databases.vector import get_vectordb_config
|
|
||||||
from cognee.infrastructure.databases.graph.config import get_graph_config
|
|
||||||
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
|
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
|
||||||
|
|
||||||
|
|
||||||
async def _get_vector_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase:
|
async def _get_vector_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase:
|
||||||
vector_config = get_vectordb_config()
|
|
||||||
|
|
||||||
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
||||||
supported_dataset_database_handlers,
|
supported_dataset_database_handlers,
|
||||||
)
|
)
|
||||||
|
|
||||||
handler = supported_dataset_database_handlers[vector_config.vector_dataset_database_handler]
|
handler = supported_dataset_database_handlers[dataset_database.vector_dataset_database_handler]
|
||||||
return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database)
|
return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database)
|
||||||
|
|
||||||
|
|
||||||
async def _get_graph_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase:
|
async def _get_graph_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase:
|
||||||
graph_config = get_graph_config()
|
|
||||||
|
|
||||||
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
||||||
supported_dataset_database_handlers,
|
supported_dataset_database_handlers,
|
||||||
)
|
)
|
||||||
|
|
||||||
handler = supported_dataset_database_handlers[graph_config.graph_dataset_database_handler]
|
handler = supported_dataset_database_handlers[dataset_database.graph_dataset_database_handler]
|
||||||
return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database)
|
return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,7 @@ class LanceDBDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||||
"vector_database_url": os.path.join(databases_directory_path, vector_db_name),
|
"vector_database_url": os.path.join(databases_directory_path, vector_db_name),
|
||||||
"vector_database_key": vector_config.vector_db_key,
|
"vector_database_key": vector_config.vector_db_key,
|
||||||
"vector_database_name": vector_db_name,
|
"vector_database_name": vector_db_name,
|
||||||
|
"vector_dataset_database_handler": "lancedb",
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
||||||
|
|
@ -5,8 +5,6 @@ from cognee.context_global_variables import backend_access_control_enabled
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
from cognee.infrastructure.databases.vector.config import get_vectordb_config
|
|
||||||
from cognee.infrastructure.databases.graph.config import get_graph_config
|
|
||||||
from cognee.shared.cache import delete_cache
|
from cognee.shared.cache import delete_cache
|
||||||
from cognee.modules.users.models import DatasetDatabase
|
from cognee.modules.users.models import DatasetDatabase
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
@ -16,12 +14,13 @@ logger = get_logger()
|
||||||
|
|
||||||
async def prune_graph_databases():
|
async def prune_graph_databases():
|
||||||
async def _prune_graph_db(dataset_database: DatasetDatabase) -> dict:
|
async def _prune_graph_db(dataset_database: DatasetDatabase) -> dict:
|
||||||
graph_config = get_graph_config()
|
|
||||||
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
||||||
supported_dataset_database_handlers,
|
supported_dataset_database_handlers,
|
||||||
)
|
)
|
||||||
|
|
||||||
handler = supported_dataset_database_handlers[graph_config.graph_dataset_database_handler]
|
handler = supported_dataset_database_handlers[
|
||||||
|
dataset_database.graph_dataset_database_handler
|
||||||
|
]
|
||||||
return await handler["handler_instance"].delete_dataset(dataset_database)
|
return await handler["handler_instance"].delete_dataset(dataset_database)
|
||||||
|
|
||||||
db_engine = get_relational_engine()
|
db_engine = get_relational_engine()
|
||||||
|
|
@ -40,13 +39,13 @@ async def prune_graph_databases():
|
||||||
|
|
||||||
async def prune_vector_databases():
|
async def prune_vector_databases():
|
||||||
async def _prune_vector_db(dataset_database: DatasetDatabase) -> dict:
|
async def _prune_vector_db(dataset_database: DatasetDatabase) -> dict:
|
||||||
vector_config = get_vectordb_config()
|
|
||||||
|
|
||||||
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
||||||
supported_dataset_database_handlers,
|
supported_dataset_database_handlers,
|
||||||
)
|
)
|
||||||
|
|
||||||
handler = supported_dataset_database_handlers[vector_config.vector_dataset_database_handler]
|
handler = supported_dataset_database_handlers[
|
||||||
|
dataset_database.vector_dataset_database_handler
|
||||||
|
]
|
||||||
return await handler["handler_instance"].delete_dataset(dataset_database)
|
return await handler["handler_instance"].delete_dataset(dataset_database)
|
||||||
|
|
||||||
db_engine = get_relational_engine()
|
db_engine = get_relational_engine()
|
||||||
|
|
|
||||||
|
|
@ -12,9 +12,6 @@ from cognee.modules.users.models import User
|
||||||
from cognee.modules.pipelines.layers.resolve_authorized_user_datasets import (
|
from cognee.modules.pipelines.layers.resolve_authorized_user_datasets import (
|
||||||
resolve_authorized_user_datasets,
|
resolve_authorized_user_datasets,
|
||||||
)
|
)
|
||||||
from cognee.modules.pipelines.layers.reset_dataset_pipeline_run_status import (
|
|
||||||
reset_dataset_pipeline_run_status,
|
|
||||||
)
|
|
||||||
from cognee.modules.engine.operations.setup import setup
|
from cognee.modules.engine.operations.setup import setup
|
||||||
from cognee.modules.pipelines.layers.pipeline_execution_mode import get_pipeline_executor
|
from cognee.modules.pipelines.layers.pipeline_execution_mode import get_pipeline_executor
|
||||||
from cognee.tasks.memify.extract_subgraph_chunks import extract_subgraph_chunks
|
from cognee.tasks.memify.extract_subgraph_chunks import extract_subgraph_chunks
|
||||||
|
|
@ -97,10 +94,6 @@ async def memify(
|
||||||
*enrichment_tasks,
|
*enrichment_tasks,
|
||||||
]
|
]
|
||||||
|
|
||||||
await reset_dataset_pipeline_run_status(
|
|
||||||
authorized_dataset.id, user, pipeline_names=["memify_pipeline"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for
|
# By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for
|
||||||
pipeline_executor_func = get_pipeline_executor(run_in_background=run_in_background)
|
pipeline_executor_func = get_pipeline_executor(run_in_background=run_in_background)
|
||||||
|
|
||||||
|
|
@ -113,6 +106,7 @@ async def memify(
|
||||||
datasets=authorized_dataset.id,
|
datasets=authorized_dataset.id,
|
||||||
vector_db_config=vector_db_config,
|
vector_db_config=vector_db_config,
|
||||||
graph_db_config=graph_db_config,
|
graph_db_config=graph_db_config,
|
||||||
|
use_pipeline_cache=False,
|
||||||
incremental_loading=False,
|
incremental_loading=False,
|
||||||
pipeline_name="memify_pipeline",
|
pipeline_name="memify_pipeline",
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,9 @@ from cognee.modules.pipelines.layers.resolve_authorized_user_datasets import (
|
||||||
from cognee.modules.pipelines.layers.check_pipeline_run_qualification import (
|
from cognee.modules.pipelines.layers.check_pipeline_run_qualification import (
|
||||||
check_pipeline_run_qualification,
|
check_pipeline_run_qualification,
|
||||||
)
|
)
|
||||||
|
from cognee.modules.pipelines.models.PipelineRunInfo import (
|
||||||
|
PipelineRunStarted,
|
||||||
|
)
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
logger = get_logger("cognee.pipeline")
|
logger = get_logger("cognee.pipeline")
|
||||||
|
|
@ -35,6 +38,7 @@ async def run_pipeline(
|
||||||
pipeline_name: str = "custom_pipeline",
|
pipeline_name: str = "custom_pipeline",
|
||||||
vector_db_config: dict = None,
|
vector_db_config: dict = None,
|
||||||
graph_db_config: dict = None,
|
graph_db_config: dict = None,
|
||||||
|
use_pipeline_cache: bool = False,
|
||||||
incremental_loading: bool = False,
|
incremental_loading: bool = False,
|
||||||
data_per_batch: int = 20,
|
data_per_batch: int = 20,
|
||||||
):
|
):
|
||||||
|
|
@ -51,6 +55,7 @@ async def run_pipeline(
|
||||||
data=data,
|
data=data,
|
||||||
pipeline_name=pipeline_name,
|
pipeline_name=pipeline_name,
|
||||||
context={"dataset": dataset},
|
context={"dataset": dataset},
|
||||||
|
use_pipeline_cache=use_pipeline_cache,
|
||||||
incremental_loading=incremental_loading,
|
incremental_loading=incremental_loading,
|
||||||
data_per_batch=data_per_batch,
|
data_per_batch=data_per_batch,
|
||||||
):
|
):
|
||||||
|
|
@ -64,6 +69,7 @@ async def run_pipeline_per_dataset(
|
||||||
data=None,
|
data=None,
|
||||||
pipeline_name: str = "custom_pipeline",
|
pipeline_name: str = "custom_pipeline",
|
||||||
context: dict = None,
|
context: dict = None,
|
||||||
|
use_pipeline_cache=False,
|
||||||
incremental_loading=False,
|
incremental_loading=False,
|
||||||
data_per_batch: int = 20,
|
data_per_batch: int = 20,
|
||||||
):
|
):
|
||||||
|
|
@ -77,8 +83,18 @@ async def run_pipeline_per_dataset(
|
||||||
if process_pipeline_status:
|
if process_pipeline_status:
|
||||||
# If pipeline was already processed or is currently being processed
|
# If pipeline was already processed or is currently being processed
|
||||||
# return status information to async generator and finish execution
|
# return status information to async generator and finish execution
|
||||||
yield process_pipeline_status
|
if use_pipeline_cache:
|
||||||
return
|
# If pipeline caching is enabled we do not proceed with re-processing
|
||||||
|
yield process_pipeline_status
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# If pipeline caching is disabled we always return pipeline started information and proceed with re-processing
|
||||||
|
yield PipelineRunStarted(
|
||||||
|
pipeline_run_id=process_pipeline_status.pipeline_run_id,
|
||||||
|
dataset_id=dataset.id,
|
||||||
|
dataset_name=dataset.name,
|
||||||
|
payload=data,
|
||||||
|
)
|
||||||
|
|
||||||
pipeline_run = run_tasks(
|
pipeline_run = run_tasks(
|
||||||
tasks,
|
tasks,
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,8 @@ async def run_custom_pipeline(
|
||||||
user: User = None,
|
user: User = None,
|
||||||
vector_db_config: Optional[dict] = None,
|
vector_db_config: Optional[dict] = None,
|
||||||
graph_db_config: Optional[dict] = None,
|
graph_db_config: Optional[dict] = None,
|
||||||
|
use_pipeline_cache: bool = False,
|
||||||
|
incremental_loading: bool = False,
|
||||||
data_per_batch: int = 20,
|
data_per_batch: int = 20,
|
||||||
run_in_background: bool = False,
|
run_in_background: bool = False,
|
||||||
pipeline_name: str = "custom_pipeline",
|
pipeline_name: str = "custom_pipeline",
|
||||||
|
|
@ -40,6 +42,10 @@ async def run_custom_pipeline(
|
||||||
user: User context for authentication and data access. Uses default if None.
|
user: User context for authentication and data access. Uses default if None.
|
||||||
vector_db_config: Custom vector database configuration for embeddings storage.
|
vector_db_config: Custom vector database configuration for embeddings storage.
|
||||||
graph_db_config: Custom graph database configuration for relationship storage.
|
graph_db_config: Custom graph database configuration for relationship storage.
|
||||||
|
use_pipeline_cache: If True, pipelines with the same ID that are currently executing and pipelines with the same ID that were completed won't process data again.
|
||||||
|
The pipeline ID is created by the generate_pipeline_id function. Pipeline status can be manually reset with the reset_dataset_pipeline_run_status function.
|
||||||
|
incremental_loading: If True, only new or modified data will be processed to avoid duplication. (Only works if data is used with the Cognee python Data model).
|
||||||
|
The incremental system stores and compares hashes of processed data in the Data model and skips data with the same content hash.
|
||||||
data_per_batch: Number of data items to be processed in parallel.
|
data_per_batch: Number of data items to be processed in parallel.
|
||||||
run_in_background: If True, starts processing asynchronously and returns immediately.
|
run_in_background: If True, starts processing asynchronously and returns immediately.
|
||||||
If False, waits for completion before returning.
|
If False, waits for completion before returning.
|
||||||
|
|
@ -63,7 +69,8 @@ async def run_custom_pipeline(
|
||||||
datasets=dataset,
|
datasets=dataset,
|
||||||
vector_db_config=vector_db_config,
|
vector_db_config=vector_db_config,
|
||||||
graph_db_config=graph_db_config,
|
graph_db_config=graph_db_config,
|
||||||
incremental_loading=False,
|
use_pipeline_cache=use_pipeline_cache,
|
||||||
|
incremental_loading=incremental_loading,
|
||||||
data_per_batch=data_per_batch,
|
data_per_batch=data_per_batch,
|
||||||
pipeline_name=pipeline_name,
|
pipeline_name=pipeline_name,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,9 @@ class DatasetDatabase(Base):
|
||||||
vector_database_provider = Column(String, unique=False, nullable=False)
|
vector_database_provider = Column(String, unique=False, nullable=False)
|
||||||
graph_database_provider = Column(String, unique=False, nullable=False)
|
graph_database_provider = Column(String, unique=False, nullable=False)
|
||||||
|
|
||||||
|
graph_dataset_database_handler = Column(String, unique=False, nullable=False)
|
||||||
|
vector_dataset_database_handler = Column(String, unique=False, nullable=False)
|
||||||
|
|
||||||
vector_database_url = Column(String, unique=False, nullable=True)
|
vector_database_url = Column(String, unique=False, nullable=True)
|
||||||
graph_database_url = Column(String, unique=False, nullable=True)
|
graph_database_url = Column(String, unique=False, nullable=True)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@ class LanceDBTestDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||||
vector_db_name = "test.lance.db"
|
vector_db_name = "test.lance.db"
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
"vector_dataset_database_handler": "custom_lancedb_handler",
|
||||||
"vector_database_name": vector_db_name,
|
"vector_database_name": vector_db_name,
|
||||||
"vector_database_url": os.path.join(databases_directory_path, vector_db_name),
|
"vector_database_url": os.path.join(databases_directory_path, vector_db_name),
|
||||||
"vector_database_provider": "lancedb",
|
"vector_database_provider": "lancedb",
|
||||||
|
|
@ -44,6 +45,7 @@ class KuzuTestDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||||
|
|
||||||
graph_db_name = "test.kuzu"
|
graph_db_name = "test.kuzu"
|
||||||
return {
|
return {
|
||||||
|
"graph_dataset_database_handler": "custom_kuzu_handler",
|
||||||
"graph_database_name": graph_db_name,
|
"graph_database_name": graph_db_name,
|
||||||
"graph_database_url": os.path.join(databases_directory_path, graph_db_name),
|
"graph_database_url": os.path.join(databases_directory_path, graph_db_name),
|
||||||
"graph_database_provider": "kuzu",
|
"graph_database_provider": "kuzu",
|
||||||
|
|
|
||||||
164
cognee/tests/test_pipeline_cache.py
Normal file
164
cognee/tests/test_pipeline_cache.py
Normal file
|
|
@ -0,0 +1,164 @@
|
||||||
|
"""
|
||||||
|
Test suite for the pipeline_cache feature in Cognee pipelines.
|
||||||
|
|
||||||
|
This module tests the behavior of the `use_pipeline_cache` parameter, which controls
|
||||||
|
whether a pipeline should skip re-execution when it has already been completed
|
||||||
|
for the same dataset.
|
||||||
|
|
||||||
|
Architecture Overview:
|
||||||
|
---------------------
|
||||||
|
The pipeline_cache mechanism works at the dataset level:
|
||||||
|
1. When a pipeline runs, it logs its status (INITIATED -> STARTED -> COMPLETED)
|
||||||
|
2. Before each run, `check_pipeline_run_qualification()` checks the pipeline status
|
||||||
|
3. If `use_pipeline_cache=True` and status is COMPLETED/STARTED, the pipeline skips
|
||||||
|
4. If `use_pipeline_cache=False`, the pipeline always re-executes regardless of status
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import cognee
|
||||||
|
from cognee.modules.pipelines.tasks.task import Task
|
||||||
|
from cognee.modules.pipelines import run_pipeline
|
||||||
|
from cognee.modules.users.methods import get_default_user
|
||||||
|
|
||||||
|
from cognee.modules.pipelines.layers.reset_dataset_pipeline_run_status import (
|
||||||
|
reset_dataset_pipeline_run_status,
|
||||||
|
)
|
||||||
|
from cognee.infrastructure.databases.relational import create_db_and_tables
|
||||||
|
|
||||||
|
|
||||||
|
class ExecutionCounter:
    """Mutable tally of how many times a counting task has been executed."""

    def __init__(self) -> None:
        # Every counter starts at zero; tasks increment it as they run.
        self.count: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
async def create_counting_task(data, counter: ExecutionCounter):
    """Record one more execution on ``counter`` and return that same counter.

    Args:
        data: Input handed to the task; it is not read by this function.
        counter: The ExecutionCounter whose ``count`` is incremented by one.

    Returns:
        The ``counter`` object that was passed in, after the increment.
    """
    counter.count = counter.count + 1
    return counter
|
||||||
|
|
||||||
|
|
||||||
|
class TestPipelineCache:
    """Checks how `use_pipeline_cache` affects repeated runs of the same pipeline."""

    @pytest.mark.asyncio
    async def test_pipeline_cache_off_allows_reexecution(self):
        """
        With use_pipeline_cache=False a second run of the same pipeline on the
        same dataset must execute the task again.

        Expected behavior:
        - First run: the task runs at least once
        - Second run: the task runs again, so the counter grows
        """
        # Prune data/system state and recreate the tables before running
        # (presumably so earlier runs cannot affect the stored pipeline status).
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await create_db_and_tables()

        counter = ExecutionCounter()
        user = await get_default_user()

        # Both runs use exactly the same arguments, including the same task list.
        run_kwargs = dict(
            tasks=[Task(create_counting_task, counter=counter)],
            datasets="test_dataset_cache_off",
            data=["sample data"],  # Data is necessary to trigger processing
            user=user,
            pipeline_name="test_cache_off_pipeline",
            use_pipeline_cache=False,
        )

        # First run: drain the async generator, collecting everything it yields.
        pipeline_results_1 = [info async for info in run_pipeline(**run_kwargs)]

        first_run_count = counter.count
        assert first_run_count >= 1, "Task should have executed at least once on first run"

        # Second run, caching still disabled.
        pipeline_results_2 = [info async for info in run_pipeline(**run_kwargs)]

        second_run_count = counter.count
        assert second_run_count > first_run_count, (
            f"With pipeline_cache=False, task should re-execute. "
            f"First run: {first_run_count}, After second run: {second_run_count}"
        )

    @pytest.mark.asyncio
    async def test_reset_pipeline_status_allows_reexecution_with_cache(self):
        """
        With use_pipeline_cache=True a repeated run is skipped, but after the
        pipeline status is reset the next run executes the task again.
        """
        # Prune data/system state and recreate the tables before running
        # (presumably so earlier runs cannot affect the stored pipeline status).
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await create_db_and_tables()

        counter = ExecutionCounter()
        user = await get_default_user()
        dataset_name = "reset_status_test"
        pipeline_name = "test_reset_pipeline"

        # All three runs share the same arguments and the same task list.
        run_kwargs = dict(
            tasks=[Task(create_counting_task, counter=counter)],
            datasets=dataset_name,
            user=user,
            data=["sample data"],  # Data is necessary to trigger processing
            pipeline_name=pipeline_name,
            use_pipeline_cache=True,
        )

        # First run: keep the yielded run info, its dataset_id is needed for the reset.
        pipeline_result = [info async for info in run_pipeline(**run_kwargs)]

        first_run_count = counter.count
        assert first_run_count >= 1

        # Second run without reset - should skip
        async for _ in run_pipeline(**run_kwargs):
            pass

        after_second_run = counter.count
        assert after_second_run == first_run_count, "Should have skipped due to cache"

        # Reset the pipeline status
        await reset_dataset_pipeline_run_status(
            pipeline_result[0].dataset_id, user, pipeline_names=[pipeline_name]
        )

        # Third run after reset - should execute
        async for _ in run_pipeline(**run_kwargs):
            pass

        after_reset_run = counter.count
        assert after_reset_run > after_second_run, (
            f"After reset, pipeline should re-execute. "
            f"Before reset: {after_second_run}, After reset run: {after_reset_run}"
        )
|
||||||
Loading…
Add table
Reference in a new issue