Merge branch 'dev' into add-s3-permissions-test

Igor Ilic 2025-12-12 13:22:50 +01:00 committed by GitHub
commit 0cde551226
GPG key ID: B5690EEEBB952194
44 changed files with 1461 additions and 179 deletions

View file

@ -97,6 +97,8 @@ DB_NAME=cognee_db
# Default (local file-based)
GRAPH_DATABASE_PROVIDER="kuzu"
# Handler for multi-user access control mode; it determines how the mapping/creation of separate DBs is handled per Cognee dataset
GRAPH_DATASET_DATABASE_HANDLER="kuzu"
# -- To switch to Remote Kuzu uncomment and fill these: -------------------------------------------------------------
#GRAPH_DATABASE_PROVIDER="kuzu"
@ -121,6 +123,8 @@ VECTOR_DB_PROVIDER="lancedb"
# Not needed if a cloud vector database is not used
VECTOR_DB_URL=
VECTOR_DB_KEY=
# Handler for multi-user access control mode; it determines how the mapping/creation of separate DBs is handled per Cognee dataset
VECTOR_DATASET_DATABASE_HANDLER="lancedb"
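# -- Example (not part of this diff): to route per-dataset graph databases through the Neo4j Aura
# -- dev handler added in this PR, the handler, provider and Aura credentials would be combined
# -- roughly like this (all values below are placeholders):
#GRAPH_DATABASE_PROVIDER="neo4j"
#GRAPH_DATASET_DATABASE_HANDLER="neo4j_aura_dev"
#NEO4J_CLIENT_ID="your-aura-api-client-id"
#NEO4J_CLIENT_SECRET="your-aura-api-client-secret"
#NEO4J_TENANT_ID="your-aura-tenant-id"
#NEO4J_ENCRYPTION_KEY="a-strong-encryption-key"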
################################################################################
# 🧩 Ontology resolver settings

View file

@ -61,6 +61,7 @@ jobs:
- name: Run Neo4j Example
env:
ENV: 'dev'
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
@ -142,6 +143,7 @@ jobs:
- name: Run PGVector Example
env:
ENV: 'dev'
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}

View file

@ -47,6 +47,7 @@ jobs:
- name: Run Distributed Cognee (Modal)
env:
ENV: 'dev'
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}

View file

@ -147,6 +147,7 @@ jobs:
- name: Run Deduplication Example
env:
ENV: 'dev'
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }} # Test needs OpenAI endpoint to handle multimedia
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
@ -211,6 +212,31 @@ jobs:
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_parallel_databases.py
test-dataset-database-handler:
name: Test dataset database handlers in Cognee
runs-on: ubuntu-22.04
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
- name: Run dataset databases handler test
env:
ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_dataset_database_handler.py
test-permissions:
name: Test permissions with different situations in Cognee
runs-on: ubuntu-22.04
@ -556,3 +582,30 @@ jobs:
DB_USERNAME: cognee
DB_PASSWORD: cognee
run: uv run python ./cognee/tests/test_conversation_history.py
run-pipeline-cache-test:
name: Test Pipeline Caching
runs-on: ubuntu-22.04
steps:
- name: Check out
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
- name: Run Pipeline Cache Test
env:
ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_pipeline_cache.py

View file

@ -72,6 +72,7 @@ jobs:
- name: Run Descriptive Graph Metrics Example
env:
ENV: 'dev'
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}

View file

@ -78,6 +78,7 @@ jobs:
- name: Run default Neo4j
env:
ENV: 'dev'
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}

View file

@ -72,6 +72,7 @@ jobs:
- name: Run Temporal Graph with Neo4j (lancedb + sqlite)
env:
ENV: 'dev'
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@ -123,6 +124,7 @@ jobs:
- name: Run Temporal Graph with Kuzu (postgres + pgvector)
env:
ENV: dev
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@ -189,6 +191,7 @@ jobs:
- name: Run Temporal Graph with Neo4j (postgres + pgvector)
env:
ENV: dev
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}

View file

@ -92,6 +92,7 @@ jobs:
- name: Run PGVector Tests
env:
ENV: 'dev'
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
@ -127,4 +128,4 @@ jobs:
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_lancedb.py
run: uv run python ./cognee/tests/test_lancedb.py

View file

@ -94,6 +94,7 @@ jobs:
- name: Run Weighted Edges Tests
env:
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
GRAPH_DATABASE_PROVIDER: ${{ matrix.graph_db_provider }}
GRAPH_DATABASE_URL: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-url || '' }}
GRAPH_DATABASE_USERNAME: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-username || '' }}
@ -165,5 +166,3 @@ jobs:
uses: astral-sh/ruff-action@v2
with:
args: "format --check cognee/modules/graph/utils/get_graph_from_model.py cognee/tests/unit/interfaces/graph/test_weighted_edges.py examples/python/weighted_edges_example.py"

View file

@ -0,0 +1,333 @@
"""Expand dataset database with json connection field
Revision ID: 46a6ce2bd2b2
Revises: 76625596c5c3
Create Date: 2025-11-25 17:56:28.938931
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "46a6ce2bd2b2"
down_revision: Union[str, None] = "76625596c5c3"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
graph_constraint_name = "dataset_database_graph_database_name_key"
vector_constraint_name = "dataset_database_vector_database_name_key"
TABLE_NAME = "dataset_database"
def _get_column(inspector, table, name, schema=None):
for col in inspector.get_columns(table, schema=schema):
if col["name"] == name:
return col
return None
def _recreate_table_without_unique_constraint_sqlite(op, insp):
"""
SQLite cannot drop unique constraints on individual columns. We must:
1. Create a new table without the unique constraints.
2. Copy data from the old table.
3. Drop the old table.
4. Rename the new table.
"""
conn = op.get_bind()
# Create new table definition (without unique constraints)
op.create_table(
f"{TABLE_NAME}_new",
sa.Column("owner_id", sa.UUID()),
sa.Column("dataset_id", sa.UUID(), primary_key=True, nullable=False),
sa.Column("vector_database_name", sa.String(), nullable=False),
sa.Column("graph_database_name", sa.String(), nullable=False),
sa.Column("vector_database_provider", sa.String(), nullable=False),
sa.Column("graph_database_provider", sa.String(), nullable=False),
sa.Column(
"vector_dataset_database_handler",
sa.String(),
unique=False,
nullable=False,
server_default="lancedb",
),
sa.Column(
"graph_dataset_database_handler",
sa.String(),
unique=False,
nullable=False,
server_default="kuzu",
),
sa.Column("vector_database_url", sa.String()),
sa.Column("graph_database_url", sa.String()),
sa.Column("vector_database_key", sa.String()),
sa.Column("graph_database_key", sa.String()),
sa.Column(
"graph_database_connection_info",
sa.JSON(),
nullable=False,
server_default=sa.text("'{}'"),
),
sa.Column(
"vector_database_connection_info",
sa.JSON(),
nullable=False,
server_default=sa.text("'{}'"),
),
sa.Column("created_at", sa.DateTime()),
sa.Column("updated_at", sa.DateTime()),
sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["owner_id"], ["principals.id"], ondelete="CASCADE"),
)
# Copy data into new table
conn.execute(
sa.text(f"""
INSERT INTO {TABLE_NAME}_new
SELECT
owner_id,
dataset_id,
vector_database_name,
graph_database_name,
vector_database_provider,
graph_database_provider,
vector_dataset_database_handler,
graph_dataset_database_handler,
vector_database_url,
graph_database_url,
vector_database_key,
graph_database_key,
COALESCE(graph_database_connection_info, '{{}}'),
COALESCE(vector_database_connection_info, '{{}}'),
created_at,
updated_at
FROM {TABLE_NAME}
""")
)
# Drop old table
op.drop_table(TABLE_NAME)
# Rename new table
op.rename_table(f"{TABLE_NAME}_new", TABLE_NAME)
def _recreate_table_with_unique_constraint_sqlite(op, insp):
"""
SQLite cannot add unique constraints to existing columns either. To restore them we must:
1. Create a new table with the unique constraints.
2. Copy data from the old table.
3. Drop the old table.
4. Rename the new table.
"""
conn = op.get_bind()
# Create new table definition (with the unique constraints restored)
op.create_table(
f"{TABLE_NAME}_new",
sa.Column("owner_id", sa.UUID()),
sa.Column("dataset_id", sa.UUID(), primary_key=True, nullable=False),
sa.Column("vector_database_name", sa.String(), nullable=False, unique=True),
sa.Column("graph_database_name", sa.String(), nullable=False, unique=True),
sa.Column("vector_database_provider", sa.String(), nullable=False),
sa.Column("graph_database_provider", sa.String(), nullable=False),
sa.Column(
"vector_dataset_database_handler",
sa.String(),
unique=False,
nullable=False,
server_default="lancedb",
),
sa.Column(
"graph_dataset_database_handler",
sa.String(),
unique=False,
nullable=False,
server_default="kuzu",
),
sa.Column("vector_database_url", sa.String()),
sa.Column("graph_database_url", sa.String()),
sa.Column("vector_database_key", sa.String()),
sa.Column("graph_database_key", sa.String()),
sa.Column(
"graph_database_connection_info",
sa.JSON(),
nullable=False,
server_default=sa.text("'{}'"),
),
sa.Column(
"vector_database_connection_info",
sa.JSON(),
nullable=False,
server_default=sa.text("'{}'"),
),
sa.Column("created_at", sa.DateTime()),
sa.Column("updated_at", sa.DateTime()),
sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["owner_id"], ["principals.id"], ondelete="CASCADE"),
)
# Copy data into new table
conn.execute(
sa.text(f"""
INSERT INTO {TABLE_NAME}_new
SELECT
owner_id,
dataset_id,
vector_database_name,
graph_database_name,
vector_database_provider,
graph_database_provider,
vector_dataset_database_handler,
graph_dataset_database_handler,
vector_database_url,
graph_database_url,
vector_database_key,
graph_database_key,
COALESCE(graph_database_connection_info, '{{}}'),
COALESCE(vector_database_connection_info, '{{}}'),
created_at,
updated_at
FROM {TABLE_NAME}
""")
)
# Drop old table
op.drop_table(TABLE_NAME)
# Rename new table
op.rename_table(f"{TABLE_NAME}_new", TABLE_NAME)
def upgrade() -> None:
conn = op.get_bind()
insp = sa.inspect(conn)
unique_constraints = insp.get_unique_constraints(TABLE_NAME)
vector_database_connection_info_column = _get_column(
insp, "dataset_database", "vector_database_connection_info"
)
if not vector_database_connection_info_column:
op.add_column(
"dataset_database",
sa.Column(
"vector_database_connection_info",
sa.JSON(),
unique=False,
nullable=False,
server_default=sa.text("'{}'"),
),
)
vector_dataset_database_handler = _get_column(
insp, "dataset_database", "vector_dataset_database_handler"
)
if not vector_dataset_database_handler:
# Add LanceDB as the default vector dataset database handler
op.add_column(
"dataset_database",
sa.Column(
"vector_dataset_database_handler",
sa.String(),
unique=False,
nullable=False,
server_default="lancedb",
),
)
graph_database_connection_info_column = _get_column(
insp, "dataset_database", "graph_database_connection_info"
)
if not graph_database_connection_info_column:
op.add_column(
"dataset_database",
sa.Column(
"graph_database_connection_info",
sa.JSON(),
unique=False,
nullable=False,
server_default=sa.text("'{}'"),
),
)
graph_dataset_database_handler = _get_column(
insp, "dataset_database", "graph_dataset_database_handler"
)
if not graph_dataset_database_handler:
# Add Kuzu as the default graph dataset database handler
op.add_column(
"dataset_database",
sa.Column(
"graph_dataset_database_handler",
sa.String(),
unique=False,
nullable=False,
server_default="kuzu",
),
)
with op.batch_alter_table("dataset_database", schema=None) as batch_op:
# Drop the unique constraint to make unique=False
graph_constraint_to_drop = None
for uc in unique_constraints:
# Check if the constraint covers ONLY the target column
if uc["name"] == graph_constraint_name:
graph_constraint_to_drop = uc["name"]
break
vector_constraint_to_drop = None
for uc in unique_constraints:
# Check if the constraint covers ONLY the target column
if uc["name"] == vector_constraint_name:
vector_constraint_to_drop = uc["name"]
break
if (
vector_constraint_to_drop
and graph_constraint_to_drop
and op.get_context().dialect.name == "postgresql"
):
# PostgreSQL
batch_op.drop_constraint(graph_constraint_name, type_="unique")
batch_op.drop_constraint(vector_constraint_name, type_="unique")
if op.get_context().dialect.name == "sqlite":
conn = op.get_bind()
# Fun fact: SQLite has hidden auto indexes for unique constraints that can't be dropped or accessed directly,
# so we need to check for them and drop them by recreating the table (altering the column also won't work)
result = conn.execute(sa.text("PRAGMA index_list('dataset_database')"))
rows = result.fetchall()
unique_auto_indexes = [row for row in rows if row[3] == "u"]
for row in unique_auto_indexes:
result = conn.execute(sa.text(f"PRAGMA index_info('{row[1]}')"))
index_info = result.fetchall()
if index_info[0][2] == "vector_database_name":
# In case a unique index exists on vector_database_name, drop it and the graph_database_name one
_recreate_table_without_unique_constraint_sqlite(op, insp)
def downgrade() -> None:
conn = op.get_bind()
insp = sa.inspect(conn)
if op.get_context().dialect.name == "sqlite":
_recreate_table_with_unique_constraint_sqlite(op, insp)
elif op.get_context().dialect.name == "postgresql":
with op.batch_alter_table("dataset_database", schema=None) as batch_op:
# Re-add the unique constraint to return to unique=True
batch_op.create_unique_constraint(graph_constraint_name, ["graph_database_name"])
with op.batch_alter_table("dataset_database", schema=None) as batch_op:
# Re-add the unique constraint to return to unique=True
batch_op.create_unique_constraint(vector_constraint_name, ["vector_database_name"])
op.drop_column("dataset_database", "vector_database_connection_info")
op.drop_column("dataset_database", "graph_database_connection_info")
op.drop_column("dataset_database", "vector_dataset_database_handler")
op.drop_column("dataset_database", "graph_dataset_database_handler")

View file

@ -205,6 +205,7 @@ async def add(
pipeline_name="add_pipeline",
vector_db_config=vector_db_config,
graph_db_config=graph_db_config,
use_pipeline_cache=True,
incremental_loading=incremental_loading,
data_per_batch=data_per_batch,
):

View file

@ -20,7 +20,6 @@ from cognee.modules.ontology.get_default_ontology_resolver import (
from cognee.modules.users.models import User
from cognee.tasks.documents import (
check_permissions_on_dataset,
classify_documents,
extract_chunks_from_documents,
)
@ -79,12 +78,11 @@ async def cognify(
Processing Pipeline:
1. **Document Classification**: Identifies document types and structures
2. **Permission Validation**: Ensures user has processing rights
3. **Text Chunking**: Breaks content into semantically meaningful segments
4. **Entity Extraction**: Identifies key concepts, people, places, organizations
5. **Relationship Detection**: Discovers connections between entities
6. **Graph Construction**: Builds semantic knowledge graph with embeddings
7. **Content Summarization**: Creates hierarchical summaries for navigation
2. **Text Chunking**: Breaks content into semantically meaningful segments
3. **Entity Extraction**: Identifies key concepts, people, places, organizations
4. **Relationship Detection**: Discovers connections between entities
5. **Graph Construction**: Builds semantic knowledge graph with embeddings
6. **Content Summarization**: Creates hierarchical summaries for navigation
Graph Model Customization:
The `graph_model` parameter allows custom knowledge structures:
@ -239,6 +237,7 @@ async def cognify(
vector_db_config=vector_db_config,
graph_db_config=graph_db_config,
incremental_loading=incremental_loading,
use_pipeline_cache=True,
pipeline_name="cognify_pipeline",
data_per_batch=data_per_batch,
)
@ -278,7 +277,6 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
default_tasks = [
Task(classify_documents),
Task(check_permissions_on_dataset, user=user, permissions=["write"]),
Task(
extract_chunks_from_documents,
max_chunk_size=chunk_size or get_max_chunk_tokens(),
@ -313,14 +311,13 @@ async def get_temporal_tasks(
The pipeline includes:
1. Document classification.
2. Dataset permission checks (requires "write" access).
3. Document chunking with a specified or default chunk size.
4. Event and timestamp extraction from chunks.
5. Knowledge graph extraction from events.
6. Batched insertion of data points.
2. Document chunking with a specified or default chunk size.
3. Event and timestamp extraction from chunks.
4. Knowledge graph extraction from events.
5. Batched insertion of data points.
Args:
user (User, optional): The user requesting task execution, used for permission checks.
user (User, optional): The user requesting task execution.
chunker (Callable, optional): A text chunking function/class to split documents. Defaults to TextChunker.
chunk_size (int, optional): Maximum token size per chunk. If not provided, uses system default.
chunks_per_batch (int, optional): Number of chunks to process in a single batch in Cognify
@ -333,7 +330,6 @@ async def get_temporal_tasks(
temporal_tasks = [
Task(classify_documents),
Task(check_permissions_on_dataset, user=user, permissions=["write"]),
Task(
extract_chunks_from_documents,
max_chunk_size=chunk_size or get_max_chunk_tokens(),

View file

@ -5,6 +5,7 @@ from pathlib import Path
from datetime import datetime, timezone
from typing import Optional, List
from dataclasses import dataclass
from fastapi import UploadFile
@dataclass
@ -45,8 +46,10 @@ class OntologyService:
json.dump(metadata, f, indent=2)
async def upload_ontology(
self, ontology_key: str, file, user, description: Optional[str] = None
self, ontology_key: str, file: UploadFile, user, description: Optional[str] = None
) -> OntologyMetadata:
if not file.filename:
raise ValueError("File must have a filename")
if not file.filename.lower().endswith(".owl"):
raise ValueError("File must be in .owl format")
@ -57,8 +60,6 @@ class OntologyService:
raise ValueError(f"Ontology key '{ontology_key}' already exists")
content = await file.read()
if len(content) > 10 * 1024 * 1024:
raise ValueError("File size exceeds 10MB limit")
file_path = user_dir / f"{ontology_key}.owl"
with open(file_path, "wb") as f:
@ -82,7 +83,11 @@ class OntologyService:
)
async def upload_ontologies(
self, ontology_key: List[str], files: List, user, descriptions: Optional[List[str]] = None
self,
ontology_key: List[str],
files: List[UploadFile],
user,
descriptions: Optional[List[str]] = None,
) -> List[OntologyMetadata]:
"""
Upload ontology files with their respective keys.
@ -105,47 +110,17 @@ class OntologyService:
if len(set(ontology_key)) != len(ontology_key):
raise ValueError("Duplicate ontology keys not allowed")
if descriptions and len(descriptions) != len(files):
raise ValueError("Number of descriptions must match number of files")
results = []
user_dir = self._get_user_dir(str(user.id))
metadata = self._load_metadata(user_dir)
for i, (key, file) in enumerate(zip(ontology_key, files)):
if key in metadata:
raise ValueError(f"Ontology key '{key}' already exists")
if not file.filename.lower().endswith(".owl"):
raise ValueError(f"File '{file.filename}' must be in .owl format")
content = await file.read()
if len(content) > 10 * 1024 * 1024:
raise ValueError(f"File '{file.filename}' exceeds 10MB limit")
file_path = user_dir / f"{key}.owl"
with open(file_path, "wb") as f:
f.write(content)
ontology_metadata = {
"filename": file.filename,
"size_bytes": len(content),
"uploaded_at": datetime.now(timezone.utc).isoformat(),
"description": descriptions[i] if descriptions else None,
}
metadata[key] = ontology_metadata
results.append(
OntologyMetadata(
await self.upload_ontology(
ontology_key=key,
filename=file.filename,
size_bytes=len(content),
uploaded_at=ontology_metadata["uploaded_at"],
file=file,
user=user,
description=descriptions[i] if descriptions else None,
)
)
self._save_metadata(user_dir, metadata)
return results
def get_ontology_contents(self, ontology_key: List[str], user) -> List[str]:

View file

@ -4,9 +4,10 @@ from typing import Union
from uuid import UUID
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.vector.config import get_vectordb_context_config
from cognee.infrastructure.databases.graph.config import get_graph_context_config
from cognee.infrastructure.databases.vector.config import get_vectordb_config
from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.infrastructure.databases.utils import get_or_create_dataset_database
from cognee.infrastructure.databases.utils import resolve_dataset_database_connection_info
from cognee.infrastructure.files.storage.config import file_storage_config
from cognee.modules.users.methods import get_user
@ -16,22 +17,59 @@ vector_db_config = ContextVar("vector_db_config", default=None)
graph_db_config = ContextVar("graph_db_config", default=None)
session_user = ContextVar("session_user", default=None)
VECTOR_DBS_WITH_MULTI_USER_SUPPORT = ["lancedb", "falkor"]
GRAPH_DBS_WITH_MULTI_USER_SUPPORT = ["kuzu", "falkor"]
async def set_session_user_context_variable(user):
session_user.set(user)
def multi_user_support_possible():
graph_db_config = get_graph_context_config()
vector_db_config = get_vectordb_context_config()
return (
graph_db_config["graph_database_provider"] in GRAPH_DBS_WITH_MULTI_USER_SUPPORT
and vector_db_config["vector_db_provider"] in VECTOR_DBS_WITH_MULTI_USER_SUPPORT
graph_db_config = get_graph_config()
vector_db_config = get_vectordb_config()
graph_handler = graph_db_config.graph_dataset_database_handler
vector_handler = vector_db_config.vector_dataset_database_handler
from cognee.infrastructure.databases.dataset_database_handler import (
supported_dataset_database_handlers,
)
if graph_handler not in supported_dataset_database_handlers:
raise EnvironmentError(
"Unsupported graph dataset to database handler configured. Cannot add support for multi-user access control mode. Please use a supported graph dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
f"Selected graph dataset to database handler: {graph_handler}\n"
f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
)
if vector_handler not in supported_dataset_database_handlers:
raise EnvironmentError(
"Unsupported vector dataset to database handler configured. Cannot add support for multi-user access control mode. Please use a supported vector dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
f"Selected vector dataset to database handler: {vector_handler}\n"
f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
)
if (
supported_dataset_database_handlers[graph_handler]["handler_provider"]
!= graph_db_config.graph_database_provider
):
raise EnvironmentError(
"The selected graph dataset to database handler does not work with the configured graph database provider. Cannot add support for multi-user access control mode. Please use a supported graph dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
f"Selected graph database provider: {graph_db_config.graph_database_provider}\n"
f"Selected graph dataset to database handler: {graph_handler}\n"
f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
)
if (
supported_dataset_database_handlers[vector_handler]["handler_provider"]
!= vector_db_config.vector_db_provider
):
raise EnvironmentError(
"The selected vector dataset to database handler does not work with the configured vector database provider. Cannot add support for multi-user access control mode. Please use a supported vector dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
f"Selected vector database provider: {vector_db_config.vector_db_provider}\n"
f"Selected vector dataset to database handler: {vector_handler}\n"
f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
)
return True
def backend_access_control_enabled():
backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)
@ -41,12 +79,7 @@ def backend_access_control_enabled():
return multi_user_support_possible()
elif backend_access_control.lower() == "true":
# If enabled, ensure that the current graph and vector DBs can support it
multi_user_support = multi_user_support_possible()
if not multi_user_support:
raise EnvironmentError(
"ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
)
return True
return multi_user_support_possible()
return False
@ -76,6 +109,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
# To ensure permissions are enforced properly all datasets will have their own databases
dataset_database = await get_or_create_dataset_database(dataset, user)
# Ensure that all connection info is resolved properly
dataset_database = await resolve_dataset_database_connection_info(dataset_database)
base_config = get_base_config()
data_root_directory = os.path.join(
@ -86,6 +121,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
)
# Set vector and graph database configuration based on dataset database information
# TODO: Add better handling of vector and graph config across Cognee.
# LRU_CACHE takes the order of inputs into account; if the order of inputs changes, it will be registered as a new DB adapter
vector_config = {
"vector_db_provider": dataset_database.vector_database_provider,
"vector_db_url": dataset_database.vector_database_url,
@ -101,6 +138,14 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
"graph_file_path": os.path.join(
databases_directory_path, dataset_database.graph_database_name
),
"graph_database_username": dataset_database.graph_database_connection_info.get(
"graph_database_username", ""
),
"graph_database_password": dataset_database.graph_database_connection_info.get(
"graph_database_password", ""
),
"graph_dataset_database_handler": "",
"graph_database_port": "",
}
storage_config = {

View file

@ -8,7 +8,6 @@ from cognee.modules.users.models import User
from cognee.shared.data_models import KnowledgeGraph
from cognee.shared.utils import send_telemetry
from cognee.tasks.documents import (
check_permissions_on_dataset,
classify_documents,
extract_chunks_from_documents,
)
@ -31,7 +30,6 @@ async def get_cascade_graph_tasks(
cognee_config = get_cognify_config()
default_tasks = [
Task(classify_documents),
Task(check_permissions_on_dataset, user=user, permissions=["write"]),
Task(
extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens()
), # Extract text chunks based on the document type.

View file

@ -30,8 +30,8 @@ async def get_no_summary_tasks(
ontology_file_path=None,
) -> List[Task]:
"""Returns default tasks without summarization tasks."""
# Get base tasks (0=classify, 1=check_permissions, 2=extract_chunks)
base_tasks = await get_default_tasks_by_indices([0, 1, 2], chunk_size, chunker)
# Get base tasks (0=classify, 1=extract_chunks)
base_tasks = await get_default_tasks_by_indices([0, 1], chunk_size, chunker)
ontology_adapter = RDFLibOntologyResolver(ontology_file=ontology_file_path)
@ -51,8 +51,8 @@ async def get_just_chunks_tasks(
chunk_size: int = None, chunker=TextChunker, user=None
) -> List[Task]:
"""Returns default tasks with only chunk extraction and data points addition."""
# Get base tasks (0=classify, 1=check_permissions, 2=extract_chunks)
base_tasks = await get_default_tasks_by_indices([0, 1, 2], chunk_size, chunker)
# Get base tasks (0=classify, 1=extract_chunks)
base_tasks = await get_default_tasks_by_indices([0, 1], chunk_size, chunker)
add_data_points_task = Task(add_data_points, task_config={"batch_size": 10})

View file

@ -0,0 +1,3 @@
from .dataset_database_handler_interface import DatasetDatabaseHandlerInterface
from .supported_dataset_database_handlers import supported_dataset_database_handlers
from .use_dataset_database_handler import use_dataset_database_handler

View file

@ -0,0 +1,80 @@
from typing import Optional
from uuid import UUID
from abc import ABC, abstractmethod
from cognee.modules.users.models.User import User
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
class DatasetDatabaseHandlerInterface(ABC):
@classmethod
@abstractmethod
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
"""
Return a dictionary with database connection/resolution info for a graph or vector database for the given dataset.
The function can also handle deploying the actual database if needed, but this is not required;
providing connection info alone is sufficient, and this info will be mapped when connecting to the given dataset in the future.
Needed for Cognee multi-tenant/multi-user and backend access control support.
The dictionary returned from this function is used to create a DatasetDatabase row in the relational database,
from which the internal mapping of dataset -> database connection info is done.
The returned dictionary is stored verbatim in the relational database and is later passed to
resolve_dataset_connection_info() at connection time. For safe credential handling, prefer
returning only references to secrets or role identifiers, not plaintext credentials.
Each dataset needs to map to a unique graph or vector database when backend access control is enabled to facilitate a separation of concerns for data.
Args:
dataset_id: UUID of the dataset if needed by the database creation logic
user: User object if needed by the database creation logic
Returns:
dict: Connection info for the created graph or vector database instance.
"""
pass
@classmethod
async def resolve_dataset_connection_info(
cls, dataset_database: DatasetDatabase
) -> DatasetDatabase:
"""
Resolve runtime connection details for a dataset's backing graph/vector database.
It is intended to be overridden to implement custom logic for resolving connection info.
This method is invoked right before the application opens a connection for a given dataset.
It receives the DatasetDatabase row that was persisted when create_dataset() ran and must
return a modified instance of DatasetDatabase with concrete connection parameters that the client/driver can use.
Do not write these resolved DatasetDatabase values back to the relational database, to avoid storing sensitive credentials.
When separate graph and vector database handlers are used, each handler should implement its own logic for resolving
connection info and only change parameters related to its own database; the resolution functions are then
called one after another, with the updated DatasetDatabase value from the previous function as the input.
Typical behavior:
- If the DatasetDatabase row already contains raw connection fields (e.g., host/port/db/user/password
or api_url/api_key), return them as-is.
- If the row stores only references (e.g., secret IDs, vault paths, cloud resource ARNs/IDs, IAM
roles, SSO tokens), resolve those references by calling the appropriate secret manager or provider
API to obtain short-lived credentials and assemble the final connection DatasetDatabase object.
- Do not persist any resolved or decrypted secrets back to the relational database. Return them only
to the caller.
Args:
dataset_database: DatasetDatabase row from the relational database
Returns:
DatasetDatabase: Updated instance with resolved connection info
"""
return dataset_database
@classmethod
@abstractmethod
async def delete_dataset(cls, dataset_database: DatasetDatabase) -> None:
"""
Delete the graph or vector database for the given dataset.
The function should handle deleting the actual database, or send a request to the proper service to delete the database or mark it as no longer needed for the given dataset.
Needed for maintaining per-dataset databases for Cognee multi-tenant/multi-user and backend access control.
Args:
dataset_database: DatasetDatabase row containing connection/resolution info for the graph or vector database to delete.
"""
pass
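To make the contract above concrete, a minimal custom handler might look like the following sketch; the class name, the endpoint URL, and the secret-reference scheme are illustrative assumptions, not part of this PR:

from typing import Optional
from uuid import UUID

from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
from cognee.modules.users.models.User import User


class ExampleGraphDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
    """Hypothetical handler that maps every Cognee dataset to its own remote graph database."""

    @classmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        # Return only connection/resolution info; store a secret reference rather than
        # a plaintext credential, as recommended by the interface docstring.
        return {
            "graph_database_name": f"dataset-{dataset_id}",
            "graph_database_provider": "neo4j",  # must match the configured graph provider
            "graph_database_url": "bolt://graph.example.internal:7687",  # placeholder endpoint
            "graph_database_key": "",
            "graph_dataset_database_handler": "example_graph_handler",
            "graph_database_connection_info": {"secret_ref": f"vault/cognee/{dataset_id}"},
        }

    @classmethod
    async def resolve_dataset_connection_info(cls, dataset_database: DatasetDatabase) -> DatasetDatabase:
        # Resolve the stored reference into concrete credentials at connection time and
        # do not persist the resolved values back to the relational database.
        secret_ref = dataset_database.graph_database_connection_info["secret_ref"]
        dataset_database.graph_database_connection_info = {
            "graph_database_username": "neo4j",  # would come from the secret manager
            "graph_database_password": f"<resolved from {secret_ref}>",  # placeholder
        }
        return dataset_database

    @classmethod
    async def delete_dataset(cls, dataset_database: DatasetDatabase) -> None:
        # Deprovision or mark the backing database as no longer needed for this dataset.
        pass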

View file

@ -0,0 +1,18 @@
from cognee.infrastructure.databases.graph.neo4j_driver.Neo4jAuraDevDatasetDatabaseHandler import (
Neo4jAuraDevDatasetDatabaseHandler,
)
from cognee.infrastructure.databases.vector.lancedb.LanceDBDatasetDatabaseHandler import (
LanceDBDatasetDatabaseHandler,
)
from cognee.infrastructure.databases.graph.kuzu.KuzuDatasetDatabaseHandler import (
KuzuDatasetDatabaseHandler,
)
supported_dataset_database_handlers = {
"neo4j_aura_dev": {
"handler_instance": Neo4jAuraDevDatasetDatabaseHandler,
"handler_provider": "neo4j",
},
"lancedb": {"handler_instance": LanceDBDatasetDatabaseHandler, "handler_provider": "lancedb"},
"kuzu": {"handler_instance": KuzuDatasetDatabaseHandler, "handler_provider": "kuzu"},
}

View file

@ -0,0 +1,10 @@
from .supported_dataset_database_handlers import supported_dataset_database_handlers
def use_dataset_database_handler(
dataset_database_handler_name, dataset_database_handler, dataset_database_provider
):
supported_dataset_database_handlers[dataset_database_handler_name] = {
"handler_instance": dataset_database_handler,
"handler_provider": dataset_database_provider,
}
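As a usage sketch (the handler class and module below are hypothetical), a custom handler could then be registered so it becomes selectable via the *_DATASET_DATABASE_HANDLER settings:

from cognee.infrastructure.databases.dataset_database_handler import use_dataset_database_handler

# Hypothetical import: a custom handler implementing DatasetDatabaseHandlerInterface
# (see the sketch after the interface above).
from my_project.handlers import ExampleGraphDatasetDatabaseHandler

# Register the handler under a name that can then be selected via
# GRAPH_DATASET_DATABASE_HANDLER="example_graph_handler"; the provider must match
# the configured GRAPH_DATABASE_PROVIDER.
use_dataset_database_handler(
    dataset_database_handler_name="example_graph_handler",
    dataset_database_handler=ExampleGraphDatasetDatabaseHandler,
    dataset_database_provider="neo4j",
)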

View file

@ -47,6 +47,7 @@ class GraphConfig(BaseSettings):
graph_filename: str = ""
graph_model: object = KnowledgeGraph
graph_topology: object = KnowledgeGraph
graph_dataset_database_handler: str = "kuzu"
model_config = SettingsConfigDict(env_file=".env", extra="allow", populate_by_name=True)
# Model validator updates graph_filename and path dynamically after class creation based on current database provider
@ -97,6 +98,7 @@ class GraphConfig(BaseSettings):
"graph_model": self.graph_model,
"graph_topology": self.graph_topology,
"model_config": self.model_config,
"graph_dataset_database_handler": self.graph_dataset_database_handler,
}
def to_hashable_dict(self) -> dict:
@ -121,6 +123,7 @@ class GraphConfig(BaseSettings):
"graph_database_port": self.graph_database_port,
"graph_database_key": self.graph_database_key,
"graph_file_path": self.graph_file_path,
"graph_dataset_database_handler": self.graph_dataset_database_handler,
}

View file

@ -34,6 +34,7 @@ def create_graph_engine(
graph_database_password="",
graph_database_port="",
graph_database_key="",
graph_dataset_database_handler="",
):
"""
Create a graph engine based on the specified provider type.

View file

@ -0,0 +1,81 @@
import os
from uuid import UUID
from typing import Optional
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
from cognee.base_config import get_base_config
from cognee.modules.users.models import User
from cognee.modules.users.models import DatasetDatabase
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
class KuzuDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
"""
Handler for interacting with Kuzu Dataset databases.
"""
@classmethod
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
"""
Create a new Kuzu instance for the dataset. Return connection info that will be mapped to the dataset.
Args:
dataset_id: Dataset UUID
user: User object who owns the dataset and is making the request
Returns:
dict: Connection details for the created Kuzu instance
"""
from cognee.infrastructure.databases.graph.config import get_graph_config
graph_config = get_graph_config()
if graph_config.graph_database_provider != "kuzu":
raise ValueError(
"KuzuDatasetDatabaseHandler can only be used with Kuzu graph database provider."
)
graph_db_name = f"{dataset_id}.pkl"
graph_db_url = graph_config.graph_database_url
graph_db_key = graph_config.graph_database_key
graph_db_username = graph_config.graph_database_username
graph_db_password = graph_config.graph_database_password
return {
"graph_database_name": graph_db_name,
"graph_database_url": graph_db_url,
"graph_database_provider": graph_config.graph_database_provider,
"graph_database_key": graph_db_key,
"graph_dataset_database_handler": "kuzu",
"graph_database_connection_info": {
"graph_database_username": graph_db_username,
"graph_database_password": graph_db_password,
},
}
@classmethod
async def delete_dataset(cls, dataset_database: DatasetDatabase):
base_config = get_base_config()
databases_directory_path = os.path.join(
base_config.system_root_directory, "databases", str(dataset_database.owner_id)
)
graph_file_path = os.path.join(
databases_directory_path, dataset_database.graph_database_name
)
graph_engine = create_graph_engine(
graph_database_provider=dataset_database.graph_database_provider,
graph_database_url=dataset_database.graph_database_url,
graph_database_name=dataset_database.graph_database_name,
graph_database_key=dataset_database.graph_database_key,
graph_file_path=graph_file_path,
graph_database_username=dataset_database.graph_database_connection_info.get(
"graph_database_username", ""
),
graph_database_password=dataset_database.graph_database_connection_info.get(
"graph_database_password", ""
),
graph_dataset_database_handler="",
graph_database_port="",
)
await graph_engine.delete_graph()

View file

@ -0,0 +1,168 @@
import os
import asyncio
import requests
import base64
import hashlib
from uuid import UUID
from typing import Optional
from cryptography.fernet import Fernet
from cognee.infrastructure.databases.graph import get_graph_config
from cognee.modules.users.models import User, DatasetDatabase
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
"""
Handler for a quick development PoC integration of Cognee multi-user and permission mode with Neo4j Aura databases.
This handler creates a new Neo4j Aura instance for each Cognee dataset created.
Improvements needed to be production ready:
- Secret management for client credentials: currently secrets are encrypted and stored in the Cognee relational database;
a secret manager or a similar system should be used instead.
Quality of life improvements:
- Allow configuration of different Neo4j Aura plans and regions.
- Requests should be made async; currently the blocking requests library is used.
"""
@classmethod
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
"""
Create a new Neo4j Aura instance for the dataset. Return connection info that will be mapped to the dataset.
Args:
dataset_id: Dataset UUID
user: User object who owns the dataset and is making the request
Returns:
dict: Connection details for the created Neo4j instance
"""
graph_config = get_graph_config()
if graph_config.graph_database_provider != "neo4j":
raise ValueError(
"Neo4jAuraDevDatasetDatabaseHandler can only be used with Neo4j graph database provider."
)
graph_db_name = f"{dataset_id}"
# Client credentials and encryption
client_id = os.environ.get("NEO4J_CLIENT_ID", None)
client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
tenant_id = os.environ.get("NEO4J_TENANT_ID", None)
encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key")
encryption_key = base64.urlsafe_b64encode(
hashlib.sha256(encryption_env_key.encode()).digest()
)
cipher = Fernet(encryption_key)
if client_id is None or client_secret is None or tenant_id is None:
raise ValueError(
"NEO4J_CLIENT_ID, NEO4J_CLIENT_SECRET, and NEO4J_TENANT_ID environment variables must be set to use Neo4j Aura DatasetDatabase Handling."
)
# Make the request with HTTP Basic Auth
def get_aura_token(client_id: str, client_secret: str) -> dict:
url = "https://api.neo4j.io/oauth/token"
data = {"grant_type": "client_credentials"} # sent as application/x-www-form-urlencoded
resp = requests.post(url, data=data, auth=(client_id, client_secret))
resp.raise_for_status() # raises if the request failed
return resp.json()
resp = get_aura_token(client_id, client_secret)
url = "https://api.neo4j.io/v1/instances"
headers = {
"accept": "application/json",
"Authorization": f"Bearer {resp['access_token']}",
"Content-Type": "application/json",
}
# TODO: Maybe we can allow **kwargs parameter forwarding for cases like these
# To allow different configurations between datasets
payload = {
"version": "5",
"region": "europe-west1",
"memory": "1GB",
"name": graph_db_name[
0:29
],  # TODO: Find a better way to name the Neo4j instance within the 30 character limit
"type": "professional-db",
"tenant_id": tenant_id,
"cloud_provider": "gcp",
}
response = requests.post(url, headers=headers, json=payload)
graph_db_name = "neo4j" # Has to be 'neo4j' for Aura
graph_db_url = response.json()["data"]["connection_url"]
graph_db_key = resp["access_token"]
graph_db_username = response.json()["data"]["username"]
graph_db_password = response.json()["data"]["password"]
async def _wait_for_neo4j_instance_provisioning(instance_id: str, headers: dict):
# Poll until the instance is running
status_url = f"https://api.neo4j.io/v1/instances/{instance_id}"
status = ""
for attempt in range(30): # Try for up to ~5 minutes
status_resp = requests.get(
status_url, headers=headers
) # TODO: Use async requests with httpx
status = status_resp.json()["data"]["status"]
if status.lower() == "running":
return
await asyncio.sleep(10)
raise TimeoutError(
f"Neo4j instance '{graph_db_name}' did not become ready within 5 minutes. Status: {status}"
)
instance_id = response.json()["data"]["id"]
await _wait_for_neo4j_instance_provisioning(instance_id, headers)
encrypted_db_password_bytes = cipher.encrypt(graph_db_password.encode())
encrypted_db_password_string = encrypted_db_password_bytes.decode()
return {
"graph_database_name": graph_db_name,
"graph_database_url": graph_db_url,
"graph_database_provider": "neo4j",
"graph_database_key": graph_db_key,
"graph_dataset_database_handler": "neo4j_aura_dev",
"graph_database_connection_info": {
"graph_database_username": graph_db_username,
"graph_database_password": encrypted_db_password_string,
},
}
@classmethod
async def resolve_dataset_connection_info(
cls, dataset_database: DatasetDatabase
) -> DatasetDatabase:
"""
Resolve and decrypt connection info for the Neo4j dataset database.
In this case, decrypt the password stored in the database.
Args:
dataset_database: DatasetDatabase instance containing encrypted connection info.
"""
encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key")
encryption_key = base64.urlsafe_b64encode(
hashlib.sha256(encryption_env_key.encode()).digest()
)
cipher = Fernet(encryption_key)
graph_db_password = cipher.decrypt(
dataset_database.graph_database_connection_info["graph_database_password"].encode()
).decode()
dataset_database.graph_database_connection_info["graph_database_password"] = (
graph_db_password
)
return dataset_database
@classmethod
async def delete_dataset(cls, dataset_database: DatasetDatabase):
pass
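A small sketch, not part of the diff, of the symmetric Fernet scheme the handler relies on: the same NEO4J_ENCRYPTION_KEY-derived key must be used when the Aura password is encrypted in create_dataset and when it is decrypted in resolve_dataset_connection_info; the values below are placeholders.

import base64
import hashlib

from cryptography.fernet import Fernet

# NEO4J_ENCRYPTION_KEY is read from the environment in the handler; "test_key" is only a placeholder.
encryption_env_key = "test_key"
encryption_key = base64.urlsafe_b64encode(hashlib.sha256(encryption_env_key.encode()).digest())
cipher = Fernet(encryption_key)

# create_dataset() stores the encrypted Aura password in graph_database_connection_info ...
stored_password = cipher.encrypt("aura-generated-password".encode()).decode()

# ... and resolve_dataset_connection_info() decrypts it again with the same key at connection time.
assert cipher.decrypt(stored_password.encode()).decode() == "aura-generated-password"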

View file

@ -1 +1,2 @@
from .get_or_create_dataset_database import get_or_create_dataset_database
from .resolve_dataset_database_connection_info import resolve_dataset_database_connection_info

View file

@ -1,11 +1,9 @@
import os
from uuid import UUID
from typing import Union
from typing import Union, Optional
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from cognee.base_config import get_base_config
from cognee.modules.data.methods import create_dataset
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.databases.vector import get_vectordb_config
@ -15,6 +13,53 @@ from cognee.modules.users.models import DatasetDatabase
from cognee.modules.users.models import User
async def _get_vector_db_info(dataset_id: UUID, user: User) -> dict:
vector_config = get_vectordb_config()
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
supported_dataset_database_handlers,
)
handler = supported_dataset_database_handlers[vector_config.vector_dataset_database_handler]
return await handler["handler_instance"].create_dataset(dataset_id, user)
async def _get_graph_db_info(dataset_id: UUID, user: User) -> dict:
graph_config = get_graph_config()
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
supported_dataset_database_handlers,
)
handler = supported_dataset_database_handlers[graph_config.graph_dataset_database_handler]
return await handler["handler_instance"].create_dataset(dataset_id, user)
async def _existing_dataset_database(
dataset_id: UUID,
user: User,
) -> Optional[DatasetDatabase]:
"""
Check if a DatasetDatabase row already exists for the given owner + dataset.
Return None if it doesn't exist; return the row if it does.
Args:
dataset_id:
user:
Returns:
DatasetDatabase or None
"""
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
stmt = select(DatasetDatabase).where(
DatasetDatabase.owner_id == user.id,
DatasetDatabase.dataset_id == dataset_id,
)
existing: DatasetDatabase = await session.scalar(stmt)
return existing
async def get_or_create_dataset_database(
dataset: Union[str, UUID],
user: User,
@ -25,6 +70,8 @@ async def get_or_create_dataset_database(
If the row already exists, it is fetched and returned.
Otherwise a new one is created atomically and returned.
DatasetDatabase row contains connection and provider info for vector and graph databases.
Parameters
----------
user : User
@ -36,59 +83,26 @@ async def get_or_create_dataset_database(
dataset_id = await get_unique_dataset_id(dataset, user)
vector_config = get_vectordb_config()
graph_config = get_graph_config()
# If dataset is given as name make sure the dataset is created first
if isinstance(dataset, str):
async with db_engine.get_async_session() as session:
await create_dataset(dataset, user, session)
# Note: for hybrid databases both graph and vector DB name have to be the same
if graph_config.graph_database_provider == "kuzu":
graph_db_name = f"{dataset_id}.pkl"
else:
graph_db_name = f"{dataset_id}"
# If dataset database already exists return it
existing_dataset_database = await _existing_dataset_database(dataset_id, user)
if existing_dataset_database:
return existing_dataset_database
if vector_config.vector_db_provider == "lancedb":
vector_db_name = f"{dataset_id}.lance.db"
else:
vector_db_name = f"{dataset_id}"
base_config = get_base_config()
databases_directory_path = os.path.join(
base_config.system_root_directory, "databases", str(user.id)
)
# Determine vector database URL
if vector_config.vector_db_provider == "lancedb":
vector_db_url = os.path.join(databases_directory_path, vector_config.vector_db_name)
else:
vector_db_url = vector_config.vector_database_url
# Determine graph database URL
graph_config_dict = await _get_graph_db_info(dataset_id, user)
vector_config_dict = await _get_vector_db_info(dataset_id, user)
async with db_engine.get_async_session() as session:
# Create dataset if it doesn't exist
if isinstance(dataset, str):
dataset = await create_dataset(dataset, user, session)
# Try to fetch an existing row first
stmt = select(DatasetDatabase).where(
DatasetDatabase.owner_id == user.id,
DatasetDatabase.dataset_id == dataset_id,
)
existing: DatasetDatabase = await session.scalar(stmt)
if existing:
return existing
# If there are no existing rows build a new row
record = DatasetDatabase(
owner_id=user.id,
dataset_id=dataset_id,
vector_database_name=vector_db_name,
graph_database_name=graph_db_name,
vector_database_provider=vector_config.vector_db_provider,
graph_database_provider=graph_config.graph_database_provider,
vector_database_url=vector_db_url,
graph_database_url=graph_config.graph_database_url,
vector_database_key=vector_config.vector_db_key,
graph_database_key=graph_config.graph_database_key,
**graph_config_dict, # Unpack graph db config
**vector_config_dict, # Unpack vector db config
)
try:

View file

@ -0,0 +1,36 @@
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
async def _get_vector_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase:
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
supported_dataset_database_handlers,
)
handler = supported_dataset_database_handlers[dataset_database.vector_dataset_database_handler]
return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database)
async def _get_graph_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase:
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
supported_dataset_database_handlers,
)
handler = supported_dataset_database_handlers[dataset_database.graph_dataset_database_handler]
return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database)
async def resolve_dataset_database_connection_info(
dataset_database: DatasetDatabase,
) -> DatasetDatabase:
"""
Resolve the connection info for the given DatasetDatabase instance.
Resolve both vector and graph database connection info and return the updated DatasetDatabase instance.
Args:
dataset_database: DatasetDatabase instance
Returns:
DatasetDatabase instance with resolved connection info
"""
dataset_database = await _get_vector_db_connection_info(dataset_database)
dataset_database = await _get_graph_db_connection_info(dataset_database)
return dataset_database

View file

@ -28,6 +28,7 @@ class VectorConfig(BaseSettings):
vector_db_name: str = ""
vector_db_key: str = ""
vector_db_provider: str = "lancedb"
vector_dataset_database_handler: str = "lancedb"
model_config = SettingsConfigDict(env_file=".env", extra="allow")
@ -63,6 +64,7 @@ class VectorConfig(BaseSettings):
"vector_db_name": self.vector_db_name,
"vector_db_key": self.vector_db_key,
"vector_db_provider": self.vector_db_provider,
"vector_dataset_database_handler": self.vector_dataset_database_handler,
}

View file

@ -12,6 +12,7 @@ def create_vector_engine(
vector_db_name: str,
vector_db_port: str = "",
vector_db_key: str = "",
vector_dataset_database_handler: str = "",
):
"""
Create a vector database engine based on the specified provider.

View file

@ -0,0 +1,50 @@
import os
from uuid import UUID
from typing import Optional
from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
from cognee.modules.users.models import User
from cognee.modules.users.models import DatasetDatabase
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.vector import get_vectordb_config
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
class LanceDBDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
"""
Handler for interacting with LanceDB Dataset databases.
"""
@classmethod
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
vector_config = get_vectordb_config()
base_config = get_base_config()
if vector_config.vector_db_provider != "lancedb":
raise ValueError(
"LanceDBDatasetDatabaseHandler can only be used with LanceDB vector database provider."
)
databases_directory_path = os.path.join(
base_config.system_root_directory, "databases", str(user.id)
)
vector_db_name = f"{dataset_id}.lance.db"
return {
"vector_database_provider": vector_config.vector_db_provider,
"vector_database_url": os.path.join(databases_directory_path, vector_db_name),
"vector_database_key": vector_config.vector_db_key,
"vector_database_name": vector_db_name,
"vector_dataset_database_handler": "lancedb",
}
@classmethod
async def delete_dataset(cls, dataset_database: DatasetDatabase):
vector_engine = create_vector_engine(
vector_db_provider=dataset_database.vector_database_provider,
vector_db_url=dataset_database.vector_database_url,
vector_db_key=dataset_database.vector_database_key,
vector_db_name=dataset_database.vector_database_name,
)
await vector_engine.prune()

View file

@ -2,6 +2,8 @@ from typing import List, Protocol, Optional, Union, Any
from abc import abstractmethod
from cognee.infrastructure.engine import DataPoint
from .models.PayloadSchema import PayloadSchema
from uuid import UUID
from cognee.modules.users.models import User
class VectorDBInterface(Protocol):
@ -217,3 +219,36 @@ class VectorDBInterface(Protocol):
- Any: The schema object suitable for this vector database
"""
return model_type
@classmethod
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
"""
Return a dictionary with connection info for a vector database for the given dataset.
The function can also handle deploying the actual database if needed, but this is not required;
providing connection info alone is sufficient, and this info will be mapped when connecting to the given dataset in the future.
Needed for Cognee multi-tenant/multi-user and backend access control support.
The dictionary returned from this function is used to create a DatasetDatabase row in the relational database,
from which the internal mapping of dataset -> database connection info is done.
Each dataset needs to map to a unique vector database when backend access control is enabled to facilitate a separation of concerns for data.
Args:
dataset_id: UUID of the dataset if needed by the database creation logic
user: User object if needed by the database creation logic
Returns:
dict: Connection info for the created vector database instance.
"""
pass
async def delete_dataset(self, dataset_id: UUID, user: User) -> None:
"""
Delete the vector database for the given dataset.
The function should handle deleting the actual database or send a request to the proper service to delete it.
Needed for maintaining per-dataset databases for Cognee multi-tenant/multi-user and backend access control.
Args:
dataset_id: UUID of the dataset
user: User object
"""
pass

View file

@ -1,17 +1,81 @@
from sqlalchemy.exc import OperationalError
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
from cognee.context_global_variables import backend_access_control_enabled
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.shared.cache import delete_cache
from cognee.modules.users.models import DatasetDatabase
from cognee.shared.logging_utils import get_logger
logger = get_logger()
async def prune_graph_databases():
async def _prune_graph_db(dataset_database: DatasetDatabase) -> dict:
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
supported_dataset_database_handlers,
)
handler = supported_dataset_database_handlers[
dataset_database.graph_dataset_database_handler
]
return await handler["handler_instance"].delete_dataset(dataset_database)
db_engine = get_relational_engine()
try:
data = await db_engine.get_all_data_from_table("dataset_database")
# Go through each dataset database and delete the graph database
for data_item in data:
await _prune_graph_db(data_item)
except (OperationalError, EntityNotFoundError) as e:
logger.debug(
"Skipping pruning of graph DB. Error when accessing dataset_database table: %s",
e,
)
return
async def prune_vector_databases():
async def _prune_vector_db(dataset_database: DatasetDatabase) -> dict:
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
supported_dataset_database_handlers,
)
handler = supported_dataset_database_handlers[
dataset_database.vector_dataset_database_handler
]
return await handler["handler_instance"].delete_dataset(dataset_database)
db_engine = get_relational_engine()
try:
data = await db_engine.get_all_data_from_table("dataset_database")
# Go through each dataset database and delete the vector database
for data_item in data:
await _prune_vector_db(data_item)
except (OperationalError, EntityNotFoundError) as e:
logger.debug(
"Skipping pruning of vector DB. Error when accessing dataset_database table: %s",
e,
)
return
async def prune_system(graph=True, vector=True, metadata=True, cache=True):
if graph:
# Note: prune_system should not be exposed through the API; it has no permission checks and will
# delete all graph and vector databases when called. It should only be used in development or testing environments.
if graph and not backend_access_control_enabled():
graph_engine = await get_graph_engine()
await graph_engine.delete_graph()
elif graph and backend_access_control_enabled():
await prune_graph_databases()
if vector:
if vector and not backend_access_control_enabled():
vector_engine = get_vector_engine()
await vector_engine.prune()
elif vector and backend_access_control_enabled():
await prune_vector_databases()
if metadata:
db_engine = get_relational_engine()
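As a usage note, the development-only reset performed by the tests in this change looks roughly like the sketch below. With backend access control enabled, prune_system iterates the dataset_database table and deletes each per-dataset database through its registered handler instead of pruning the single global engines.

import asyncio

import cognee

async def reset_for_tests():
    # Development/testing only: removes stored data and all graph, vector and metadata state.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(graph=True, vector=True, metadata=True)

asyncio.run(reset_for_tests())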

View file

@@ -12,9 +12,6 @@ from cognee.modules.users.models import User
from cognee.modules.pipelines.layers.resolve_authorized_user_datasets import (
resolve_authorized_user_datasets,
)
from cognee.modules.pipelines.layers.reset_dataset_pipeline_run_status import (
reset_dataset_pipeline_run_status,
)
from cognee.modules.engine.operations.setup import setup
from cognee.modules.pipelines.layers.pipeline_execution_mode import get_pipeline_executor
from cognee.tasks.memify.extract_subgraph_chunks import extract_subgraph_chunks
@@ -97,10 +94,6 @@ async def memify(
*enrichment_tasks,
]
await reset_dataset_pipeline_run_status(
authorized_dataset.id, user, pipeline_names=["memify_pipeline"]
)
# get_pipeline_executor returns either a function that runs run_pipeline in the background or one whose completion must be awaited
pipeline_executor_func = get_pipeline_executor(run_in_background=run_in_background)
@@ -113,6 +106,7 @@ async def memify(
datasets=authorized_dataset.id,
vector_db_config=vector_db_config,
graph_db_config=graph_db_config,
use_pipeline_cache=False,
incremental_loading=False,
pipeline_name="memify_pipeline",
)

View file

@@ -20,6 +20,9 @@ from cognee.modules.pipelines.layers.resolve_authorized_user_datasets import (
from cognee.modules.pipelines.layers.check_pipeline_run_qualification import (
check_pipeline_run_qualification,
)
from cognee.modules.pipelines.models.PipelineRunInfo import (
PipelineRunStarted,
)
from typing import Any
logger = get_logger("cognee.pipeline")
@@ -35,6 +38,7 @@ async def run_pipeline(
pipeline_name: str = "custom_pipeline",
vector_db_config: dict = None,
graph_db_config: dict = None,
use_pipeline_cache: bool = False,
incremental_loading: bool = False,
data_per_batch: int = 20,
):
@@ -51,6 +55,7 @@ async def run_pipeline(
data=data,
pipeline_name=pipeline_name,
context={"dataset": dataset},
use_pipeline_cache=use_pipeline_cache,
incremental_loading=incremental_loading,
data_per_batch=data_per_batch,
):
@@ -64,6 +69,7 @@ async def run_pipeline_per_dataset(
data=None,
pipeline_name: str = "custom_pipeline",
context: dict = None,
use_pipeline_cache=False,
incremental_loading=False,
data_per_batch: int = 20,
):
@@ -77,8 +83,18 @@ async def run_pipeline_per_dataset(
if process_pipeline_status:
# If pipeline was already processed or is currently being processed
# return status information to async generator and finish execution
yield process_pipeline_status
return
if use_pipeline_cache:
# If pipeline caching is enabled we do not proceed with re-processing
yield process_pipeline_status
return
else:
# If pipeline caching is disabled, we always yield pipeline-started information and proceed with re-processing
yield PipelineRunStarted(
pipeline_run_id=process_pipeline_status.pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
payload=data,
)
pipeline_run = run_tasks(
tasks,
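A small sketch of what a caller observes with the new flag: on a repeat run with use_pipeline_cache=False, the generator first yields a PipelineRunStarted status and then re-processes, instead of short-circuiting with the cached status. The imports mirror the ones used in the test file further down; the dataset and pipeline names are placeholders.

from cognee.modules.pipelines import run_pipeline
from cognee.modules.pipelines.models.PipelineRunInfo import PipelineRunStarted

async def rerun_without_cache(tasks, user):
    async for run_info in run_pipeline(
        tasks=tasks,
        datasets="example_dataset",
        data=["sample data"],
        user=user,
        pipeline_name="custom_pipeline",
        use_pipeline_cache=False,
    ):
        # On a repeat run the first yielded item is a PipelineRunStarted status
        # rather than the previously completed run's status.
        if isinstance(run_info, PipelineRunStarted):
            print("Re-processing dataset:", run_info.dataset_name)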

View file

@@ -18,6 +18,8 @@ async def run_custom_pipeline(
user: User = None,
vector_db_config: Optional[dict] = None,
graph_db_config: Optional[dict] = None,
use_pipeline_cache: bool = False,
incremental_loading: bool = False,
data_per_batch: int = 20,
run_in_background: bool = False,
pipeline_name: str = "custom_pipeline",
@@ -40,6 +42,10 @@ async def run_custom_pipeline(
user: User context for authentication and data access. Uses default if None.
vector_db_config: Custom vector database configuration for embeddings storage.
graph_db_config: Custom graph database configuration for relationship storage.
use_pipeline_cache: If True, a pipeline run with the same ID that is currently executing or has already completed won't process the data again.
The pipeline ID is derived via the generate_pipeline_id function. Pipeline status can be manually reset with the reset_dataset_pipeline_run_status function.
incremental_loading: If True, only new or modified data is processed, to avoid duplication (only applies when data goes through the Cognee Python Data model).
The incremental system stores hashes of processed data in the Data model and skips data whose content hash has already been processed.
data_per_batch: Number of data items to be processed in parallel.
run_in_background: If True, starts processing asynchronously and returns immediately.
If False, waits for completion before returning.
@@ -63,7 +69,8 @@ async def run_custom_pipeline(
datasets=dataset,
vector_db_config=vector_db_config,
graph_db_config=graph_db_config,
incremental_loading=False,
use_pipeline_cache=use_pipeline_cache,
incremental_loading=incremental_loading,
data_per_batch=data_per_batch,
pipeline_name=pipeline_name,
)
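To complement the use_pipeline_cache note above, a hedged sketch of clearing a dataset's recorded pipeline status so that a cached pipeline runs again, mirroring the reset used in the tests below; the pipeline name is an example.

from cognee.modules.pipelines.layers.reset_dataset_pipeline_run_status import (
    reset_dataset_pipeline_run_status,
)

async def force_rerun(dataset_id, user):
    # Clears the recorded COMPLETED/STARTED status for the named pipeline so a
    # subsequent run with use_pipeline_cache=True processes the data again.
    await reset_dataset_pipeline_run_status(
        dataset_id, user, pipeline_names=["custom_pipeline"]
    )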

View file

@@ -12,8 +12,8 @@ logger = get_logger("get_authenticated_user")
# Check environment variable to determine authentication requirement
REQUIRE_AUTHENTICATION = (
os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true"
or backend_access_control_enabled()
os.getenv("REQUIRE_AUTHENTICATION", "true").lower() == "true"
or os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", "true").lower() == "true"
)
fastapi_users = get_fastapi_users()
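With both flags now defaulting to true, opting out for a local single-user setup comes down to setting the environment variables before cognee is imported, in the same way the handler test below sets its variables first. This is a sketch for local development only.

import os

# Local development only: disable authentication and backend access control.
# REQUIRE_AUTHENTICATION is evaluated when the auth module is imported, so set both early.
os.environ["REQUIRE_AUTHENTICATION"] = "false"
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"

import cognee  # noqa: E402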

View file

@@ -1,6 +1,6 @@
from datetime import datetime, timezone
from sqlalchemy import Column, DateTime, String, UUID, ForeignKey
from sqlalchemy import Column, DateTime, String, UUID, ForeignKey, JSON, text
from cognee.infrastructure.databases.relational import Base
@@ -12,17 +12,29 @@ class DatasetDatabase(Base):
UUID, ForeignKey("datasets.id", ondelete="CASCADE"), primary_key=True, index=True
)
vector_database_name = Column(String, unique=True, nullable=False)
graph_database_name = Column(String, unique=True, nullable=False)
vector_database_name = Column(String, unique=False, nullable=False)
graph_database_name = Column(String, unique=False, nullable=False)
vector_database_provider = Column(String, unique=False, nullable=False)
graph_database_provider = Column(String, unique=False, nullable=False)
graph_dataset_database_handler = Column(String, unique=False, nullable=False)
vector_dataset_database_handler = Column(String, unique=False, nullable=False)
vector_database_url = Column(String, unique=False, nullable=True)
graph_database_url = Column(String, unique=False, nullable=True)
vector_database_key = Column(String, unique=False, nullable=True)
graph_database_key = Column(String, unique=False, nullable=True)
# Configuration details for different database types, stored as JSON so new database types can be added
# without changing the database schema.
graph_database_connection_info = Column(
JSON, unique=False, nullable=False, server_default=text("'{}'")
)
vector_database_connection_info = Column(
JSON, unique=False, nullable=False, server_default=text("'{}'")
)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
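Illustrative only: the new JSON columns let provider-specific settings travel with the row without further schema changes. The keys below are hypothetical examples, not a schema prescribed by the model.

# Hypothetical payloads for the JSON columns; key names are examples only.
graph_database_connection_info = {
    "host": "remote-graph.internal",
    "port": 8000,
    "username": "cognee",
}
vector_database_connection_info = {
    "endpoint": "https://vector.example.internal",
    "api_key": None,
}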

View file

@@ -534,6 +534,10 @@ def setup_logging(log_level=None, name=None):
# Get a configured logger and log system information
logger = structlog.get_logger(name if name else __name__)
logger.warning(
"From version 0.5.0 onwards, Cognee will run with multi-user access control mode set to on by default. Data isolation between different users and datasets will be enforced and data created before multi-user access control mode was turned on won't be accessible by default. To disable multi-user access control mode and regain access to old data set the environment variable ENABLE_BACKEND_ACCESS_CONTROL to false before starting Cognee. For more information, please refer to the Cognee documentation."
)
if logs_dir is not None:
logger.info(f"Log file created at: {log_file_path}", log_file=log_file_path)

View file

@@ -1,3 +1,2 @@
from .classify_documents import classify_documents
from .extract_chunks_from_documents import extract_chunks_from_documents
from .check_permissions_on_dataset import check_permissions_on_dataset

View file

@@ -1,26 +0,0 @@
from cognee.modules.data.processing.document_types import Document
from cognee.modules.users.permissions.methods import check_permission_on_dataset
from typing import List
async def check_permissions_on_dataset(
documents: List[Document], context: dict, user, permissions
) -> List[Document]:
"""
Validates a user's permissions on a list of documents.
Notes:
- This function assumes that `check_permission_on_documents` raises an exception if the permission check fails.
- It is designed to validate multiple permissions in a sequential manner for the same set of documents.
- Ensure that the `Document` and `user` objects conform to the expected structure and interfaces.
"""
for permission in permissions:
await check_permission_on_dataset(
user,
permission,
# TODO: pass dataset through argument instead of context
context["dataset"].id,
)
return documents

View file

@@ -0,0 +1,137 @@
import asyncio
import os
# Set custom dataset database handler environment variables
os.environ["VECTOR_DATASET_DATABASE_HANDLER"] = "custom_lancedb_handler"
os.environ["GRAPH_DATASET_DATABASE_HANDLER"] = "custom_kuzu_handler"
import cognee
from cognee.modules.users.methods import get_default_user
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
from cognee.shared.logging_utils import setup_logging, ERROR
from cognee.api.v1.search import SearchType
class LanceDBTestDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
@classmethod
async def create_dataset(cls, dataset_id, user):
import pathlib
cognee_directory_path = str(
pathlib.Path(
os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_dataset_database_handler"
)
).resolve()
)
databases_directory_path = os.path.join(cognee_directory_path, "databases", str(user.id))
os.makedirs(databases_directory_path, exist_ok=True)
vector_db_name = "test.lance.db"
return {
"vector_dataset_database_handler": "custom_lancedb_handler",
"vector_database_name": vector_db_name,
"vector_database_url": os.path.join(databases_directory_path, vector_db_name),
"vector_database_provider": "lancedb",
}
class KuzuTestDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
@classmethod
async def create_dataset(cls, dataset_id, user):
databases_directory_path = os.path.join("databases", str(user.id))
os.makedirs(databases_directory_path, exist_ok=True)
graph_db_name = "test.kuzu"
return {
"graph_dataset_database_handler": "custom_kuzu_handler",
"graph_database_name": graph_db_name,
"graph_database_url": os.path.join(databases_directory_path, graph_db_name),
"graph_database_provider": "kuzu",
}
async def main():
import pathlib
data_directory_path = str(
pathlib.Path(
os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_dataset_database_handler"
)
).resolve()
)
cognee.config.data_root_directory(data_directory_path)
cognee_directory_path = str(
pathlib.Path(
os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_dataset_database_handler"
)
).resolve()
)
cognee.config.system_root_directory(cognee_directory_path)
# Add custom dataset database handler
from cognee.infrastructure.databases.dataset_database_handler.use_dataset_database_handler import (
use_dataset_database_handler,
)
use_dataset_database_handler(
"custom_lancedb_handler", LanceDBTestDatasetDatabaseHandler, "lancedb"
)
use_dataset_database_handler("custom_kuzu_handler", KuzuTestDatasetDatabaseHandler, "kuzu")
# Create a clean slate for cognee -- reset data and system state
print("Resetting cognee data...")
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
print("Data reset complete.\n")
# cognee knowledge graph will be created based on this text
text = """
Natural language processing (NLP) is an interdisciplinary
subfield of computer science and information retrieval.
"""
print("Adding text to cognee:")
print(text.strip())
# Add the text, and make it available for cognify
await cognee.add(text)
print("Text added successfully.\n")
# Use LLMs and cognee to create knowledge graph
await cognee.cognify()
print("Cognify process complete.\n")
query_text = "Tell me about NLP"
print(f"Searching cognee for insights with query: '{query_text}'")
# Query cognee for insights on the added text
search_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION, query_text=query_text
)
print("Search results:")
# Display results
for result_text in search_results:
print(result_text)
default_user = await get_default_user()
# Assert that the custom database files were created based on the custom dataset database handlers
assert os.path.exists(
os.path.join(cognee_directory_path, "databases", str(default_user.id), "test.kuzu")
), "Graph database file not found."
assert os.path.exists(
os.path.join(cognee_directory_path, "databases", str(default_user.id), "test.lance.db")
), "Vector database file not found."
if __name__ == "__main__":
logger = setup_logging(log_level=ERROR)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(main())
finally:
loop.run_until_complete(loop.shutdown_asyncgens())

View file

@@ -0,0 +1,164 @@
"""
Test suite for the pipeline_cache feature in Cognee pipelines.
This module tests the behavior of the `use_pipeline_cache` parameter, which controls
whether a pipeline should skip re-execution when it has already been completed
for the same dataset.
Architecture Overview:
---------------------
The pipeline_cache mechanism works at the dataset level:
1. When a pipeline runs, it logs its status (INITIATED -> STARTED -> COMPLETED)
2. Before each run, `check_pipeline_run_qualification()` checks the pipeline status
3. If `use_pipeline_cache=True` and status is COMPLETED/STARTED, the pipeline skips
4. If `use_pipeline_cache=False`, the pipeline always re-executes regardless of status
"""
import pytest
import cognee
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.pipelines import run_pipeline
from cognee.modules.users.methods import get_default_user
from cognee.modules.pipelines.layers.reset_dataset_pipeline_run_status import (
reset_dataset_pipeline_run_status,
)
from cognee.infrastructure.databases.relational import create_db_and_tables
class ExecutionCounter:
"""Helper class to track task execution counts."""
def __init__(self):
self.count = 0
async def create_counting_task(data, counter: ExecutionCounter):
"""Create a task that increments a counter from the ExecutionCounter instance when executed."""
counter.count += 1
return counter
class TestPipelineCache:
"""Tests for basic pipeline_cache on/off behavior."""
@pytest.mark.asyncio
async def test_pipeline_cache_off_allows_reexecution(self):
"""
Test that with use_pipeline_cache=False, the pipeline re-executes
even when it has already completed for the dataset.
Expected behavior:
- First run: Pipeline executes fully, task runs once
- Second run: Pipeline executes again, task runs again (total: 2 times)
"""
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await create_db_and_tables()
counter = ExecutionCounter()
user = await get_default_user()
tasks = [Task(create_counting_task, counter=counter)]
# First run
pipeline_results_1 = []
async for result in run_pipeline(
tasks=tasks,
datasets="test_dataset_cache_off",
data=["sample data"], # Data is necessary to trigger processing
user=user,
pipeline_name="test_cache_off_pipeline",
use_pipeline_cache=False,
):
pipeline_results_1.append(result)
first_run_count = counter.count
assert first_run_count >= 1, "Task should have executed at least once on first run"
# Second run with pipeline_cache=False
pipeline_results_2 = []
async for result in run_pipeline(
tasks=tasks,
datasets="test_dataset_cache_off",
data=["sample data"], # Data is necessary to trigger processing
user=user,
pipeline_name="test_cache_off_pipeline",
use_pipeline_cache=False,
):
pipeline_results_2.append(result)
second_run_count = counter.count
assert second_run_count > first_run_count, (
f"With pipeline_cache=False, task should re-execute. "
f"First run: {first_run_count}, After second run: {second_run_count}"
)
@pytest.mark.asyncio
async def test_reset_pipeline_status_allows_reexecution_with_cache(self):
"""
Test that resetting pipeline status allows re-execution even with
pipeline_cache=True.
"""
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await create_db_and_tables()
counter = ExecutionCounter()
user = await get_default_user()
dataset_name = "reset_status_test"
pipeline_name = "test_reset_pipeline"
tasks = [Task(create_counting_task, counter=counter)]
# First run
pipeline_result = []
async for result in run_pipeline(
tasks=tasks,
datasets=dataset_name,
user=user,
data=["sample data"], # Data is necessary to trigger processing
pipeline_name=pipeline_name,
use_pipeline_cache=True,
):
pipeline_result.append(result)
first_run_count = counter.count
assert first_run_count >= 1
# Second run without reset - should skip
async for _ in run_pipeline(
tasks=tasks,
datasets=dataset_name,
user=user,
data=["sample data"], # Data is necessary to trigger processing
pipeline_name=pipeline_name,
use_pipeline_cache=True,
):
pass
after_second_run = counter.count
assert after_second_run == first_run_count, "Should have skipped due to cache"
# Reset the pipeline status
await reset_dataset_pipeline_run_status(
pipeline_result[0].dataset_id, user, pipeline_names=[pipeline_name]
)
# Third run after reset - should execute
async for _ in run_pipeline(
tasks=tasks,
datasets=dataset_name,
user=user,
data=["sample data"], # Data is necessary to trigger processing
pipeline_name=pipeline_name,
use_pipeline_cache=True,
):
pass
after_reset_run = counter.count
assert after_reset_run > after_second_run, (
f"After reset, pipeline should re-execute. "
f"Before reset: {after_second_run}, After reset run: {after_reset_run}"
)

View file

@@ -32,16 +32,13 @@ async def main():
print("Cognify process steps:")
print("1. Classifying the document: Determining the type and category of the input text.")
print(
"2. Checking permissions: Ensuring the user has the necessary rights to process the text."
"2. Extracting text chunks: Breaking down the text into sentences or phrases for analysis."
)
print(
"3. Extracting text chunks: Breaking down the text into sentences or phrases for analysis."
"3. Generating knowledge graph: Extracting entities and relationships to form a knowledge graph."
)
print("4. Adding data points: Storing the extracted chunks for processing.")
print(
"5. Generating knowledge graph: Extracting entities and relationships to form a knowledge graph."
)
print("6. Summarizing text: Creating concise summaries of the content for quick insights.\n")
print("4. Summarizing text: Creating concise summaries of the content for quick insights.")
print("5. Adding data points: Storing the extracted chunks for processing.\n")
# Use LLMs and cognee to create knowledge graph
await cognee.cognify()

View file

@@ -591,7 +591,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"id": "7c431fdef4921ae0",
"metadata": {
"ExecuteTime": {
@@ -609,7 +609,6 @@
"from cognee.modules.pipelines import run_tasks\n",
"from cognee.modules.users.models import User\n",
"from cognee.tasks.documents import (\n",
" check_permissions_on_dataset,\n",
" classify_documents,\n",
" extract_chunks_from_documents,\n",
")\n",
@@ -627,7 +626,6 @@
"\n",
" tasks = [\n",
" Task(classify_documents),\n",
" Task(check_permissions_on_dataset, user=user, permissions=[\"write\"]),\n",
" Task(\n",
" extract_chunks_from_documents, max_chunk_size=get_max_chunk_tokens()\n",
" ), # Extract text chunks based on the document type.\n",