Merge branch 'dev' into add-s3-permissions-test
This commit is contained in:
commit
0cde551226
44 changed files with 1461 additions and 179 deletions
|
|
@ -97,6 +97,8 @@ DB_NAME=cognee_db
|
|||
|
||||
# Default (local file-based)
|
||||
GRAPH_DATABASE_PROVIDER="kuzu"
|
||||
# Handler for multi-user access control mode, it handles how should the mapping/creation of separate DBs be handled per Cognee dataset
|
||||
GRAPH_DATASET_DATABASE_HANDLER="kuzu"
|
||||
|
||||
# -- To switch to Remote Kuzu uncomment and fill these: -------------------------------------------------------------
|
||||
#GRAPH_DATABASE_PROVIDER="kuzu"
|
||||
|
|
@ -121,6 +123,8 @@ VECTOR_DB_PROVIDER="lancedb"
|
|||
# Not needed if a cloud vector database is not used
|
||||
VECTOR_DB_URL=
|
||||
VECTOR_DB_KEY=
|
||||
# Handler for multi-user access control mode, it handles how should the mapping/creation of separate DBs be handled per Cognee dataset
|
||||
VECTOR_DATASET_DATABASE_HANDLER="lancedb"
|
||||
|
||||
################################################################################
|
||||
# 🧩 Ontology resolver settings
|
||||
|
|
|
|||
2
.github/workflows/db_examples_tests.yml
vendored
2
.github/workflows/db_examples_tests.yml
vendored
|
|
@ -61,6 +61,7 @@ jobs:
|
|||
- name: Run Neo4j Example
|
||||
env:
|
||||
ENV: 'dev'
|
||||
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
|
|
@ -142,6 +143,7 @@ jobs:
|
|||
- name: Run PGVector Example
|
||||
env:
|
||||
ENV: 'dev'
|
||||
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
|
|
|
|||
1
.github/workflows/distributed_test.yml
vendored
1
.github/workflows/distributed_test.yml
vendored
|
|
@ -47,6 +47,7 @@ jobs:
|
|||
- name: Run Distributed Cognee (Modal)
|
||||
env:
|
||||
ENV: 'dev'
|
||||
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
|
|
|
|||
53
.github/workflows/e2e_tests.yml
vendored
53
.github/workflows/e2e_tests.yml
vendored
|
|
@ -147,6 +147,7 @@ jobs:
|
|||
- name: Run Deduplication Example
|
||||
env:
|
||||
ENV: 'dev'
|
||||
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }} # Test needs OpenAI endpoint to handle multimedia
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
|
|
@ -211,6 +212,31 @@ jobs:
|
|||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./cognee/tests/test_parallel_databases.py
|
||||
|
||||
test-dataset-database-handler:
|
||||
name: Test dataset database handlers in Cognee
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Cognee Setup
|
||||
uses: ./.github/actions/cognee_setup
|
||||
with:
|
||||
python-version: '3.11.x'
|
||||
|
||||
- name: Run dataset databases handler test
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./cognee/tests/test_dataset_database_handler.py
|
||||
|
||||
test-permissions:
|
||||
name: Test permissions with different situations in Cognee
|
||||
runs-on: ubuntu-22.04
|
||||
|
|
@ -556,3 +582,30 @@ jobs:
|
|||
DB_USERNAME: cognee
|
||||
DB_PASSWORD: cognee
|
||||
run: uv run python ./cognee/tests/test_conversation_history.py
|
||||
|
||||
run-pipeline-cache-test:
|
||||
name: Test Pipeline Caching
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Check out
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Cognee Setup
|
||||
uses: ./.github/actions/cognee_setup
|
||||
with:
|
||||
python-version: '3.11.x'
|
||||
|
||||
- name: Run Pipeline Cache Test
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./cognee/tests/test_pipeline_cache.py
|
||||
|
|
|
|||
1
.github/workflows/examples_tests.yml
vendored
1
.github/workflows/examples_tests.yml
vendored
|
|
@ -72,6 +72,7 @@ jobs:
|
|||
- name: Run Descriptive Graph Metrics Example
|
||||
env:
|
||||
ENV: 'dev'
|
||||
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
|
|
|
|||
1
.github/workflows/graph_db_tests.yml
vendored
1
.github/workflows/graph_db_tests.yml
vendored
|
|
@ -78,6 +78,7 @@ jobs:
|
|||
- name: Run default Neo4j
|
||||
env:
|
||||
ENV: 'dev'
|
||||
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
|
|
|
|||
3
.github/workflows/temporal_graph_tests.yml
vendored
3
.github/workflows/temporal_graph_tests.yml
vendored
|
|
@ -72,6 +72,7 @@ jobs:
|
|||
- name: Run Temporal Graph with Neo4j (lancedb + sqlite)
|
||||
env:
|
||||
ENV: 'dev'
|
||||
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
|
@ -123,6 +124,7 @@ jobs:
|
|||
- name: Run Temporal Graph with Kuzu (postgres + pgvector)
|
||||
env:
|
||||
ENV: dev
|
||||
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
|
@ -189,6 +191,7 @@ jobs:
|
|||
- name: Run Temporal Graph with Neo4j (postgres + pgvector)
|
||||
env:
|
||||
ENV: dev
|
||||
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
|
|
|||
3
.github/workflows/vector_db_tests.yml
vendored
3
.github/workflows/vector_db_tests.yml
vendored
|
|
@ -92,6 +92,7 @@ jobs:
|
|||
- name: Run PGVector Tests
|
||||
env:
|
||||
ENV: 'dev'
|
||||
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
|
|
@ -127,4 +128,4 @@ jobs:
|
|||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./cognee/tests/test_lancedb.py
|
||||
run: uv run python ./cognee/tests/test_lancedb.py
|
||||
|
|
|
|||
3
.github/workflows/weighted_edges_tests.yml
vendored
3
.github/workflows/weighted_edges_tests.yml
vendored
|
|
@ -94,6 +94,7 @@ jobs:
|
|||
|
||||
- name: Run Weighted Edges Tests
|
||||
env:
|
||||
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||
GRAPH_DATABASE_PROVIDER: ${{ matrix.graph_db_provider }}
|
||||
GRAPH_DATABASE_URL: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-url || '' }}
|
||||
GRAPH_DATABASE_USERNAME: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-username || '' }}
|
||||
|
|
@ -165,5 +166,3 @@ jobs:
|
|||
uses: astral-sh/ruff-action@v2
|
||||
with:
|
||||
args: "format --check cognee/modules/graph/utils/get_graph_from_model.py cognee/tests/unit/interfaces/graph/test_weighted_edges.py examples/python/weighted_edges_example.py"
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,333 @@
|
|||
"""Expand dataset database with json connection field
|
||||
|
||||
Revision ID: 46a6ce2bd2b2
|
||||
Revises: 76625596c5c3
|
||||
Create Date: 2025-11-25 17:56:28.938931
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "46a6ce2bd2b2"
|
||||
down_revision: Union[str, None] = "76625596c5c3"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
graph_constraint_name = "dataset_database_graph_database_name_key"
|
||||
vector_constraint_name = "dataset_database_vector_database_name_key"
|
||||
TABLE_NAME = "dataset_database"
|
||||
|
||||
|
||||
def _get_column(inspector, table, name, schema=None):
|
||||
for col in inspector.get_columns(table, schema=schema):
|
||||
if col["name"] == name:
|
||||
return col
|
||||
return None
|
||||
|
||||
|
||||
def _recreate_table_without_unique_constraint_sqlite(op, insp):
|
||||
"""
|
||||
SQLite cannot drop unique constraints on individual columns. We must:
|
||||
1. Create a new table without the unique constraints.
|
||||
2. Copy data from the old table.
|
||||
3. Drop the old table.
|
||||
4. Rename the new table.
|
||||
"""
|
||||
conn = op.get_bind()
|
||||
|
||||
# Create new table definition (without unique constraints)
|
||||
op.create_table(
|
||||
f"{TABLE_NAME}_new",
|
||||
sa.Column("owner_id", sa.UUID()),
|
||||
sa.Column("dataset_id", sa.UUID(), primary_key=True, nullable=False),
|
||||
sa.Column("vector_database_name", sa.String(), nullable=False),
|
||||
sa.Column("graph_database_name", sa.String(), nullable=False),
|
||||
sa.Column("vector_database_provider", sa.String(), nullable=False),
|
||||
sa.Column("graph_database_provider", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"vector_dataset_database_handler",
|
||||
sa.String(),
|
||||
unique=False,
|
||||
nullable=False,
|
||||
server_default="lancedb",
|
||||
),
|
||||
sa.Column(
|
||||
"graph_dataset_database_handler",
|
||||
sa.String(),
|
||||
unique=False,
|
||||
nullable=False,
|
||||
server_default="kuzu",
|
||||
),
|
||||
sa.Column("vector_database_url", sa.String()),
|
||||
sa.Column("graph_database_url", sa.String()),
|
||||
sa.Column("vector_database_key", sa.String()),
|
||||
sa.Column("graph_database_key", sa.String()),
|
||||
sa.Column(
|
||||
"graph_database_connection_info",
|
||||
sa.JSON(),
|
||||
nullable=False,
|
||||
server_default=sa.text("'{}'"),
|
||||
),
|
||||
sa.Column(
|
||||
"vector_database_connection_info",
|
||||
sa.JSON(),
|
||||
nullable=False,
|
||||
server_default=sa.text("'{}'"),
|
||||
),
|
||||
sa.Column("created_at", sa.DateTime()),
|
||||
sa.Column("updated_at", sa.DateTime()),
|
||||
sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["owner_id"], ["principals.id"], ondelete="CASCADE"),
|
||||
)
|
||||
|
||||
# Copy data into new table
|
||||
conn.execute(
|
||||
sa.text(f"""
|
||||
INSERT INTO {TABLE_NAME}_new
|
||||
SELECT
|
||||
owner_id,
|
||||
dataset_id,
|
||||
vector_database_name,
|
||||
graph_database_name,
|
||||
vector_database_provider,
|
||||
graph_database_provider,
|
||||
vector_dataset_database_handler,
|
||||
graph_dataset_database_handler,
|
||||
vector_database_url,
|
||||
graph_database_url,
|
||||
vector_database_key,
|
||||
graph_database_key,
|
||||
COALESCE(graph_database_connection_info, '{{}}'),
|
||||
COALESCE(vector_database_connection_info, '{{}}'),
|
||||
created_at,
|
||||
updated_at
|
||||
FROM {TABLE_NAME}
|
||||
""")
|
||||
)
|
||||
|
||||
# Drop old table
|
||||
op.drop_table(TABLE_NAME)
|
||||
|
||||
# Rename new table
|
||||
op.rename_table(f"{TABLE_NAME}_new", TABLE_NAME)
|
||||
|
||||
|
||||
def _recreate_table_with_unique_constraint_sqlite(op, insp):
|
||||
"""
|
||||
SQLite cannot drop unique constraints on individual columns. We must:
|
||||
1. Create a new table without the unique constraints.
|
||||
2. Copy data from the old table.
|
||||
3. Drop the old table.
|
||||
4. Rename the new table.
|
||||
"""
|
||||
conn = op.get_bind()
|
||||
|
||||
# Create new table definition (without unique constraints)
|
||||
op.create_table(
|
||||
f"{TABLE_NAME}_new",
|
||||
sa.Column("owner_id", sa.UUID()),
|
||||
sa.Column("dataset_id", sa.UUID(), primary_key=True, nullable=False),
|
||||
sa.Column("vector_database_name", sa.String(), nullable=False, unique=True),
|
||||
sa.Column("graph_database_name", sa.String(), nullable=False, unique=True),
|
||||
sa.Column("vector_database_provider", sa.String(), nullable=False),
|
||||
sa.Column("graph_database_provider", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"vector_dataset_database_handler",
|
||||
sa.String(),
|
||||
unique=False,
|
||||
nullable=False,
|
||||
server_default="lancedb",
|
||||
),
|
||||
sa.Column(
|
||||
"graph_dataset_database_handler",
|
||||
sa.String(),
|
||||
unique=False,
|
||||
nullable=False,
|
||||
server_default="kuzu",
|
||||
),
|
||||
sa.Column("vector_database_url", sa.String()),
|
||||
sa.Column("graph_database_url", sa.String()),
|
||||
sa.Column("vector_database_key", sa.String()),
|
||||
sa.Column("graph_database_key", sa.String()),
|
||||
sa.Column(
|
||||
"graph_database_connection_info",
|
||||
sa.JSON(),
|
||||
nullable=False,
|
||||
server_default=sa.text("'{}'"),
|
||||
),
|
||||
sa.Column(
|
||||
"vector_database_connection_info",
|
||||
sa.JSON(),
|
||||
nullable=False,
|
||||
server_default=sa.text("'{}'"),
|
||||
),
|
||||
sa.Column("created_at", sa.DateTime()),
|
||||
sa.Column("updated_at", sa.DateTime()),
|
||||
sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["owner_id"], ["principals.id"], ondelete="CASCADE"),
|
||||
)
|
||||
|
||||
# Copy data into new table
|
||||
conn.execute(
|
||||
sa.text(f"""
|
||||
INSERT INTO {TABLE_NAME}_new
|
||||
SELECT
|
||||
owner_id,
|
||||
dataset_id,
|
||||
vector_database_name,
|
||||
graph_database_name,
|
||||
vector_database_provider,
|
||||
graph_database_provider,
|
||||
vector_dataset_database_handler,
|
||||
graph_dataset_database_handler,
|
||||
vector_database_url,
|
||||
graph_database_url,
|
||||
vector_database_key,
|
||||
graph_database_key,
|
||||
COALESCE(graph_database_connection_info, '{{}}'),
|
||||
COALESCE(vector_database_connection_info, '{{}}'),
|
||||
created_at,
|
||||
updated_at
|
||||
FROM {TABLE_NAME}
|
||||
""")
|
||||
)
|
||||
|
||||
# Drop old table
|
||||
op.drop_table(TABLE_NAME)
|
||||
|
||||
# Rename new table
|
||||
op.rename_table(f"{TABLE_NAME}_new", TABLE_NAME)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
insp = sa.inspect(conn)
|
||||
|
||||
unique_constraints = insp.get_unique_constraints(TABLE_NAME)
|
||||
|
||||
vector_database_connection_info_column = _get_column(
|
||||
insp, "dataset_database", "vector_database_connection_info"
|
||||
)
|
||||
if not vector_database_connection_info_column:
|
||||
op.add_column(
|
||||
"dataset_database",
|
||||
sa.Column(
|
||||
"vector_database_connection_info",
|
||||
sa.JSON(),
|
||||
unique=False,
|
||||
nullable=False,
|
||||
server_default=sa.text("'{}'"),
|
||||
),
|
||||
)
|
||||
|
||||
vector_dataset_database_handler = _get_column(
|
||||
insp, "dataset_database", "vector_dataset_database_handler"
|
||||
)
|
||||
if not vector_dataset_database_handler:
|
||||
# Add LanceDB as the default graph dataset database handler
|
||||
op.add_column(
|
||||
"dataset_database",
|
||||
sa.Column(
|
||||
"vector_dataset_database_handler",
|
||||
sa.String(),
|
||||
unique=False,
|
||||
nullable=False,
|
||||
server_default="lancedb",
|
||||
),
|
||||
)
|
||||
|
||||
graph_database_connection_info_column = _get_column(
|
||||
insp, "dataset_database", "graph_database_connection_info"
|
||||
)
|
||||
if not graph_database_connection_info_column:
|
||||
op.add_column(
|
||||
"dataset_database",
|
||||
sa.Column(
|
||||
"graph_database_connection_info",
|
||||
sa.JSON(),
|
||||
unique=False,
|
||||
nullable=False,
|
||||
server_default=sa.text("'{}'"),
|
||||
),
|
||||
)
|
||||
|
||||
graph_dataset_database_handler = _get_column(
|
||||
insp, "dataset_database", "graph_dataset_database_handler"
|
||||
)
|
||||
if not graph_dataset_database_handler:
|
||||
# Add Kuzu as the default graph dataset database handler
|
||||
op.add_column(
|
||||
"dataset_database",
|
||||
sa.Column(
|
||||
"graph_dataset_database_handler",
|
||||
sa.String(),
|
||||
unique=False,
|
||||
nullable=False,
|
||||
server_default="kuzu",
|
||||
),
|
||||
)
|
||||
|
||||
with op.batch_alter_table("dataset_database", schema=None) as batch_op:
|
||||
# Drop the unique constraint to make unique=False
|
||||
graph_constraint_to_drop = None
|
||||
for uc in unique_constraints:
|
||||
# Check if the constraint covers ONLY the target column
|
||||
if uc["name"] == graph_constraint_name:
|
||||
graph_constraint_to_drop = uc["name"]
|
||||
break
|
||||
|
||||
vector_constraint_to_drop = None
|
||||
for uc in unique_constraints:
|
||||
# Check if the constraint covers ONLY the target column
|
||||
if uc["name"] == vector_constraint_name:
|
||||
vector_constraint_to_drop = uc["name"]
|
||||
break
|
||||
|
||||
if (
|
||||
vector_constraint_to_drop
|
||||
and graph_constraint_to_drop
|
||||
and op.get_context().dialect.name == "postgresql"
|
||||
):
|
||||
# PostgreSQL
|
||||
batch_op.drop_constraint(graph_constraint_name, type_="unique")
|
||||
batch_op.drop_constraint(vector_constraint_name, type_="unique")
|
||||
|
||||
if op.get_context().dialect.name == "sqlite":
|
||||
conn = op.get_bind()
|
||||
# Fun fact: SQLite has hidden auto indexes for unique constraints that can't be dropped or accessed directly
|
||||
# So we need to check for them and drop them by recreating the table (altering column also won't work)
|
||||
result = conn.execute(sa.text("PRAGMA index_list('dataset_database')"))
|
||||
rows = result.fetchall()
|
||||
unique_auto_indexes = [row for row in rows if row[3] == "u"]
|
||||
for row in unique_auto_indexes:
|
||||
result = conn.execute(sa.text(f"PRAGMA index_info('{row[1]}')"))
|
||||
index_info = result.fetchall()
|
||||
if index_info[0][2] == "vector_database_name":
|
||||
# In case a unique index exists on vector_database_name, drop it and the graph_database_name one
|
||||
_recreate_table_without_unique_constraint_sqlite(op, insp)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
insp = sa.inspect(conn)
|
||||
|
||||
if op.get_context().dialect.name == "sqlite":
|
||||
_recreate_table_with_unique_constraint_sqlite(op, insp)
|
||||
elif op.get_context().dialect.name == "postgresql":
|
||||
with op.batch_alter_table("dataset_database", schema=None) as batch_op:
|
||||
# Re-add the unique constraint to return to unique=True
|
||||
batch_op.create_unique_constraint(graph_constraint_name, ["graph_database_name"])
|
||||
|
||||
with op.batch_alter_table("dataset_database", schema=None) as batch_op:
|
||||
# Re-add the unique constraint to return to unique=True
|
||||
batch_op.create_unique_constraint(vector_constraint_name, ["vector_database_name"])
|
||||
|
||||
op.drop_column("dataset_database", "vector_database_connection_info")
|
||||
op.drop_column("dataset_database", "graph_database_connection_info")
|
||||
op.drop_column("dataset_database", "vector_dataset_database_handler")
|
||||
op.drop_column("dataset_database", "graph_dataset_database_handler")
|
||||
|
|
@ -205,6 +205,7 @@ async def add(
|
|||
pipeline_name="add_pipeline",
|
||||
vector_db_config=vector_db_config,
|
||||
graph_db_config=graph_db_config,
|
||||
use_pipeline_cache=True,
|
||||
incremental_loading=incremental_loading,
|
||||
data_per_batch=data_per_batch,
|
||||
):
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ from cognee.modules.ontology.get_default_ontology_resolver import (
|
|||
from cognee.modules.users.models import User
|
||||
|
||||
from cognee.tasks.documents import (
|
||||
check_permissions_on_dataset,
|
||||
classify_documents,
|
||||
extract_chunks_from_documents,
|
||||
)
|
||||
|
|
@ -79,12 +78,11 @@ async def cognify(
|
|||
|
||||
Processing Pipeline:
|
||||
1. **Document Classification**: Identifies document types and structures
|
||||
2. **Permission Validation**: Ensures user has processing rights
|
||||
3. **Text Chunking**: Breaks content into semantically meaningful segments
|
||||
4. **Entity Extraction**: Identifies key concepts, people, places, organizations
|
||||
5. **Relationship Detection**: Discovers connections between entities
|
||||
6. **Graph Construction**: Builds semantic knowledge graph with embeddings
|
||||
7. **Content Summarization**: Creates hierarchical summaries for navigation
|
||||
2. **Text Chunking**: Breaks content into semantically meaningful segments
|
||||
3. **Entity Extraction**: Identifies key concepts, people, places, organizations
|
||||
4. **Relationship Detection**: Discovers connections between entities
|
||||
5. **Graph Construction**: Builds semantic knowledge graph with embeddings
|
||||
6. **Content Summarization**: Creates hierarchical summaries for navigation
|
||||
|
||||
Graph Model Customization:
|
||||
The `graph_model` parameter allows custom knowledge structures:
|
||||
|
|
@ -239,6 +237,7 @@ async def cognify(
|
|||
vector_db_config=vector_db_config,
|
||||
graph_db_config=graph_db_config,
|
||||
incremental_loading=incremental_loading,
|
||||
use_pipeline_cache=True,
|
||||
pipeline_name="cognify_pipeline",
|
||||
data_per_batch=data_per_batch,
|
||||
)
|
||||
|
|
@ -278,7 +277,6 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
|
|||
|
||||
default_tasks = [
|
||||
Task(classify_documents),
|
||||
Task(check_permissions_on_dataset, user=user, permissions=["write"]),
|
||||
Task(
|
||||
extract_chunks_from_documents,
|
||||
max_chunk_size=chunk_size or get_max_chunk_tokens(),
|
||||
|
|
@ -313,14 +311,13 @@ async def get_temporal_tasks(
|
|||
|
||||
The pipeline includes:
|
||||
1. Document classification.
|
||||
2. Dataset permission checks (requires "write" access).
|
||||
3. Document chunking with a specified or default chunk size.
|
||||
4. Event and timestamp extraction from chunks.
|
||||
5. Knowledge graph extraction from events.
|
||||
6. Batched insertion of data points.
|
||||
2. Document chunking with a specified or default chunk size.
|
||||
3. Event and timestamp extraction from chunks.
|
||||
4. Knowledge graph extraction from events.
|
||||
5. Batched insertion of data points.
|
||||
|
||||
Args:
|
||||
user (User, optional): The user requesting task execution, used for permission checks.
|
||||
user (User, optional): The user requesting task execution.
|
||||
chunker (Callable, optional): A text chunking function/class to split documents. Defaults to TextChunker.
|
||||
chunk_size (int, optional): Maximum token size per chunk. If not provided, uses system default.
|
||||
chunks_per_batch (int, optional): Number of chunks to process in a single batch in Cognify
|
||||
|
|
@ -333,7 +330,6 @@ async def get_temporal_tasks(
|
|||
|
||||
temporal_tasks = [
|
||||
Task(classify_documents),
|
||||
Task(check_permissions_on_dataset, user=user, permissions=["write"]),
|
||||
Task(
|
||||
extract_chunks_from_documents,
|
||||
max_chunk_size=chunk_size or get_max_chunk_tokens(),
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from pathlib import Path
|
|||
from datetime import datetime, timezone
|
||||
from typing import Optional, List
|
||||
from dataclasses import dataclass
|
||||
from fastapi import UploadFile
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -45,8 +46,10 @@ class OntologyService:
|
|||
json.dump(metadata, f, indent=2)
|
||||
|
||||
async def upload_ontology(
|
||||
self, ontology_key: str, file, user, description: Optional[str] = None
|
||||
self, ontology_key: str, file: UploadFile, user, description: Optional[str] = None
|
||||
) -> OntologyMetadata:
|
||||
if not file.filename:
|
||||
raise ValueError("File must have a filename")
|
||||
if not file.filename.lower().endswith(".owl"):
|
||||
raise ValueError("File must be in .owl format")
|
||||
|
||||
|
|
@ -57,8 +60,6 @@ class OntologyService:
|
|||
raise ValueError(f"Ontology key '{ontology_key}' already exists")
|
||||
|
||||
content = await file.read()
|
||||
if len(content) > 10 * 1024 * 1024:
|
||||
raise ValueError("File size exceeds 10MB limit")
|
||||
|
||||
file_path = user_dir / f"{ontology_key}.owl"
|
||||
with open(file_path, "wb") as f:
|
||||
|
|
@ -82,7 +83,11 @@ class OntologyService:
|
|||
)
|
||||
|
||||
async def upload_ontologies(
|
||||
self, ontology_key: List[str], files: List, user, descriptions: Optional[List[str]] = None
|
||||
self,
|
||||
ontology_key: List[str],
|
||||
files: List[UploadFile],
|
||||
user,
|
||||
descriptions: Optional[List[str]] = None,
|
||||
) -> List[OntologyMetadata]:
|
||||
"""
|
||||
Upload ontology files with their respective keys.
|
||||
|
|
@ -105,47 +110,17 @@ class OntologyService:
|
|||
if len(set(ontology_key)) != len(ontology_key):
|
||||
raise ValueError("Duplicate ontology keys not allowed")
|
||||
|
||||
if descriptions and len(descriptions) != len(files):
|
||||
raise ValueError("Number of descriptions must match number of files")
|
||||
|
||||
results = []
|
||||
user_dir = self._get_user_dir(str(user.id))
|
||||
metadata = self._load_metadata(user_dir)
|
||||
|
||||
for i, (key, file) in enumerate(zip(ontology_key, files)):
|
||||
if key in metadata:
|
||||
raise ValueError(f"Ontology key '{key}' already exists")
|
||||
|
||||
if not file.filename.lower().endswith(".owl"):
|
||||
raise ValueError(f"File '{file.filename}' must be in .owl format")
|
||||
|
||||
content = await file.read()
|
||||
if len(content) > 10 * 1024 * 1024:
|
||||
raise ValueError(f"File '{file.filename}' exceeds 10MB limit")
|
||||
|
||||
file_path = user_dir / f"{key}.owl"
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
ontology_metadata = {
|
||||
"filename": file.filename,
|
||||
"size_bytes": len(content),
|
||||
"uploaded_at": datetime.now(timezone.utc).isoformat(),
|
||||
"description": descriptions[i] if descriptions else None,
|
||||
}
|
||||
metadata[key] = ontology_metadata
|
||||
|
||||
results.append(
|
||||
OntologyMetadata(
|
||||
await self.upload_ontology(
|
||||
ontology_key=key,
|
||||
filename=file.filename,
|
||||
size_bytes=len(content),
|
||||
uploaded_at=ontology_metadata["uploaded_at"],
|
||||
file=file,
|
||||
user=user,
|
||||
description=descriptions[i] if descriptions else None,
|
||||
)
|
||||
)
|
||||
|
||||
self._save_metadata(user_dir, metadata)
|
||||
return results
|
||||
|
||||
def get_ontology_contents(self, ontology_key: List[str], user) -> List[str]:
|
||||
|
|
|
|||
|
|
@ -4,9 +4,10 @@ from typing import Union
|
|||
from uuid import UUID
|
||||
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.infrastructure.databases.vector.config import get_vectordb_context_config
|
||||
from cognee.infrastructure.databases.graph.config import get_graph_context_config
|
||||
from cognee.infrastructure.databases.vector.config import get_vectordb_config
|
||||
from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||
from cognee.infrastructure.databases.utils import get_or_create_dataset_database
|
||||
from cognee.infrastructure.databases.utils import resolve_dataset_database_connection_info
|
||||
from cognee.infrastructure.files.storage.config import file_storage_config
|
||||
from cognee.modules.users.methods import get_user
|
||||
|
||||
|
|
@ -16,22 +17,59 @@ vector_db_config = ContextVar("vector_db_config", default=None)
|
|||
graph_db_config = ContextVar("graph_db_config", default=None)
|
||||
session_user = ContextVar("session_user", default=None)
|
||||
|
||||
VECTOR_DBS_WITH_MULTI_USER_SUPPORT = ["lancedb", "falkor"]
|
||||
GRAPH_DBS_WITH_MULTI_USER_SUPPORT = ["kuzu", "falkor"]
|
||||
|
||||
|
||||
async def set_session_user_context_variable(user):
|
||||
session_user.set(user)
|
||||
|
||||
|
||||
def multi_user_support_possible():
|
||||
graph_db_config = get_graph_context_config()
|
||||
vector_db_config = get_vectordb_context_config()
|
||||
return (
|
||||
graph_db_config["graph_database_provider"] in GRAPH_DBS_WITH_MULTI_USER_SUPPORT
|
||||
and vector_db_config["vector_db_provider"] in VECTOR_DBS_WITH_MULTI_USER_SUPPORT
|
||||
graph_db_config = get_graph_config()
|
||||
vector_db_config = get_vectordb_config()
|
||||
|
||||
graph_handler = graph_db_config.graph_dataset_database_handler
|
||||
vector_handler = vector_db_config.vector_dataset_database_handler
|
||||
from cognee.infrastructure.databases.dataset_database_handler import (
|
||||
supported_dataset_database_handlers,
|
||||
)
|
||||
|
||||
if graph_handler not in supported_dataset_database_handlers:
|
||||
raise EnvironmentError(
|
||||
"Unsupported graph dataset to database handler configured. Cannot add support for multi-user access control mode. Please use a supported graph dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
|
||||
f"Selected graph dataset to database handler: {graph_handler}\n"
|
||||
f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
|
||||
)
|
||||
|
||||
if vector_handler not in supported_dataset_database_handlers:
|
||||
raise EnvironmentError(
|
||||
"Unsupported vector dataset to database handler configured. Cannot add support for multi-user access control mode. Please use a supported vector dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
|
||||
f"Selected vector dataset to database handler: {vector_handler}\n"
|
||||
f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
|
||||
)
|
||||
|
||||
if (
|
||||
supported_dataset_database_handlers[graph_handler]["handler_provider"]
|
||||
!= graph_db_config.graph_database_provider
|
||||
):
|
||||
raise EnvironmentError(
|
||||
"The selected graph dataset to database handler does not work with the configured graph database provider. Cannot add support for multi-user access control mode. Please use a supported graph dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
|
||||
f"Selected graph database provider: {graph_db_config.graph_database_provider}\n"
|
||||
f"Selected graph dataset to database handler: {graph_handler}\n"
|
||||
f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
|
||||
)
|
||||
|
||||
if (
|
||||
supported_dataset_database_handlers[vector_handler]["handler_provider"]
|
||||
!= vector_db_config.vector_db_provider
|
||||
):
|
||||
raise EnvironmentError(
|
||||
"The selected vector dataset to database handler does not work with the configured vector database provider. Cannot add support for multi-user access control mode. Please use a supported vector dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
|
||||
f"Selected vector database provider: {vector_db_config.vector_db_provider}\n"
|
||||
f"Selected vector dataset to database handler: {vector_handler}\n"
|
||||
f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def backend_access_control_enabled():
|
||||
backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)
|
||||
|
|
@ -41,12 +79,7 @@ def backend_access_control_enabled():
|
|||
return multi_user_support_possible()
|
||||
elif backend_access_control.lower() == "true":
|
||||
# If enabled, ensure that the current graph and vector DBs can support it
|
||||
multi_user_support = multi_user_support_possible()
|
||||
if not multi_user_support:
|
||||
raise EnvironmentError(
|
||||
"ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
|
||||
)
|
||||
return True
|
||||
return multi_user_support_possible()
|
||||
return False
|
||||
|
||||
|
||||
|
|
@ -76,6 +109,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
|||
|
||||
# To ensure permissions are enforced properly all datasets will have their own databases
|
||||
dataset_database = await get_or_create_dataset_database(dataset, user)
|
||||
# Ensure that all connection info is resolved properly
|
||||
dataset_database = await resolve_dataset_database_connection_info(dataset_database)
|
||||
|
||||
base_config = get_base_config()
|
||||
data_root_directory = os.path.join(
|
||||
|
|
@ -86,6 +121,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
|||
)
|
||||
|
||||
# Set vector and graph database configuration based on dataset database information
|
||||
# TODO: Add better handling of vector and graph config accross Cognee.
|
||||
# LRU_CACHE takes into account order of inputs, if order of inputs is changed it will be registered as a new DB adapter
|
||||
vector_config = {
|
||||
"vector_db_provider": dataset_database.vector_database_provider,
|
||||
"vector_db_url": dataset_database.vector_database_url,
|
||||
|
|
@ -101,6 +138,14 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
|||
"graph_file_path": os.path.join(
|
||||
databases_directory_path, dataset_database.graph_database_name
|
||||
),
|
||||
"graph_database_username": dataset_database.graph_database_connection_info.get(
|
||||
"graph_database_username", ""
|
||||
),
|
||||
"graph_database_password": dataset_database.graph_database_connection_info.get(
|
||||
"graph_database_password", ""
|
||||
),
|
||||
"graph_dataset_database_handler": "",
|
||||
"graph_database_port": "",
|
||||
}
|
||||
|
||||
storage_config = {
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ from cognee.modules.users.models import User
|
|||
from cognee.shared.data_models import KnowledgeGraph
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.tasks.documents import (
|
||||
check_permissions_on_dataset,
|
||||
classify_documents,
|
||||
extract_chunks_from_documents,
|
||||
)
|
||||
|
|
@ -31,7 +30,6 @@ async def get_cascade_graph_tasks(
|
|||
cognee_config = get_cognify_config()
|
||||
default_tasks = [
|
||||
Task(classify_documents),
|
||||
Task(check_permissions_on_dataset, user=user, permissions=["write"]),
|
||||
Task(
|
||||
extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens()
|
||||
), # Extract text chunks based on the document type.
|
||||
|
|
|
|||
|
|
@ -30,8 +30,8 @@ async def get_no_summary_tasks(
|
|||
ontology_file_path=None,
|
||||
) -> List[Task]:
|
||||
"""Returns default tasks without summarization tasks."""
|
||||
# Get base tasks (0=classify, 1=check_permissions, 2=extract_chunks)
|
||||
base_tasks = await get_default_tasks_by_indices([0, 1, 2], chunk_size, chunker)
|
||||
# Get base tasks (0=classify, 1=extract_chunks)
|
||||
base_tasks = await get_default_tasks_by_indices([0, 1], chunk_size, chunker)
|
||||
|
||||
ontology_adapter = RDFLibOntologyResolver(ontology_file=ontology_file_path)
|
||||
|
||||
|
|
@ -51,8 +51,8 @@ async def get_just_chunks_tasks(
|
|||
chunk_size: int = None, chunker=TextChunker, user=None
|
||||
) -> List[Task]:
|
||||
"""Returns default tasks with only chunk extraction and data points addition."""
|
||||
# Get base tasks (0=classify, 1=check_permissions, 2=extract_chunks)
|
||||
base_tasks = await get_default_tasks_by_indices([0, 1, 2], chunk_size, chunker)
|
||||
# Get base tasks (0=classify, 1=extract_chunks)
|
||||
base_tasks = await get_default_tasks_by_indices([0, 1], chunk_size, chunker)
|
||||
|
||||
add_data_points_task = Task(add_data_points, task_config={"batch_size": 10})
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
from .dataset_database_handler_interface import DatasetDatabaseHandlerInterface
|
||||
from .supported_dataset_database_handlers import supported_dataset_database_handlers
|
||||
from .use_dataset_database_handler import use_dataset_database_handler
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from cognee.modules.users.models.User import User
|
||||
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
|
||||
|
||||
|
||||
class DatasetDatabaseHandlerInterface(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
|
||||
"""
|
||||
Return a dictionary with database connection/resolution info for a graph or vector database for the given dataset.
|
||||
Function can auto handle deploying of the actual database if needed, but is not necessary.
|
||||
Only providing connection info is sufficient, this info will be mapped when trying to connect to the provided dataset in the future.
|
||||
Needed for Cognee multi-tenant/multi-user and backend access control support.
|
||||
|
||||
Dictionary returned from this function will be used to create a DatasetDatabase row in the relational database.
|
||||
From which internal mapping of dataset -> database connection info will be done.
|
||||
|
||||
The returned dictionary is stored verbatim in the relational database and is later passed to
|
||||
resolve_dataset_connection_info() at connection time. For safe credential handling, prefer
|
||||
returning only references to secrets or role identifiers, not plaintext credentials.
|
||||
|
||||
Each dataset needs to map to a unique graph or vector database when backend access control is enabled to facilitate a separation of concern for data.
|
||||
|
||||
Args:
|
||||
dataset_id: UUID of the dataset if needed by the database creation logic
|
||||
user: User object if needed by the database creation logic
|
||||
Returns:
|
||||
dict: Connection info for the created graph or vector database instance.
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
async def resolve_dataset_connection_info(
|
||||
cls, dataset_database: DatasetDatabase
|
||||
) -> DatasetDatabase:
|
||||
"""
|
||||
Resolve runtime connection details for a dataset’s backing graph/vector database.
|
||||
Function is intended to be overwritten to implement custom logic for resolving connection info.
|
||||
|
||||
This method is invoked right before the application opens a connection for a given dataset.
|
||||
It receives the DatasetDatabase row that was persisted when create_dataset() ran and must
|
||||
return a modified instance of DatasetDatabase with concrete connection parameters that the client/driver can use.
|
||||
Do not update these new DatasetDatabase values in the relational database to avoid storing secure credentials.
|
||||
|
||||
In case of separate graph and vector database handlers, each handler should implement its own logic for resolving
|
||||
connection info and only change parameters related to its appropriate database, the resolution function will then
|
||||
be called one after another with the updated DatasetDatabase value from the previous function as the input.
|
||||
|
||||
Typical behavior:
|
||||
- If the DatasetDatabase row already contains raw connection fields (e.g., host/port/db/user/password
|
||||
or api_url/api_key), return them as-is.
|
||||
- If the row stores only references (e.g., secret IDs, vault paths, cloud resource ARNs/IDs, IAM
|
||||
roles, SSO tokens), resolve those references by calling the appropriate secret manager or provider
|
||||
API to obtain short-lived credentials and assemble the final connection DatasetDatabase object.
|
||||
- Do not persist any resolved or decrypted secrets back to the relational database. Return them only
|
||||
to the caller.
|
||||
|
||||
Args:
|
||||
dataset_database: DatasetDatabase row from the relational database
|
||||
Returns:
|
||||
DatasetDatabase: Updated instance with resolved connection info
|
||||
"""
|
||||
return dataset_database
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def delete_dataset(cls, dataset_database: DatasetDatabase) -> None:
|
||||
"""
|
||||
Delete the graph or vector database for the given dataset.
|
||||
Function should auto handle deleting of the actual database or send a request to the proper service to delete/mark the database as not needed for the given dataset.
|
||||
Needed for maintaining a database for Cognee multi-tenant/multi-user and backend access control.
|
||||
|
||||
Args:
|
||||
dataset_database: DatasetDatabase row containing connection/resolution info for the graph or vector database to delete.
|
||||
"""
|
||||
pass
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
from cognee.infrastructure.databases.graph.neo4j_driver.Neo4jAuraDevDatasetDatabaseHandler import (
|
||||
Neo4jAuraDevDatasetDatabaseHandler,
|
||||
)
|
||||
from cognee.infrastructure.databases.vector.lancedb.LanceDBDatasetDatabaseHandler import (
|
||||
LanceDBDatasetDatabaseHandler,
|
||||
)
|
||||
from cognee.infrastructure.databases.graph.kuzu.KuzuDatasetDatabaseHandler import (
|
||||
KuzuDatasetDatabaseHandler,
|
||||
)
|
||||
|
||||
supported_dataset_database_handlers = {
|
||||
"neo4j_aura_dev": {
|
||||
"handler_instance": Neo4jAuraDevDatasetDatabaseHandler,
|
||||
"handler_provider": "neo4j",
|
||||
},
|
||||
"lancedb": {"handler_instance": LanceDBDatasetDatabaseHandler, "handler_provider": "lancedb"},
|
||||
"kuzu": {"handler_instance": KuzuDatasetDatabaseHandler, "handler_provider": "kuzu"},
|
||||
}
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
from .supported_dataset_database_handlers import supported_dataset_database_handlers
|
||||
|
||||
|
||||
def use_dataset_database_handler(
|
||||
dataset_database_handler_name, dataset_database_handler, dataset_database_provider
|
||||
):
|
||||
supported_dataset_database_handlers[dataset_database_handler_name] = {
|
||||
"handler_instance": dataset_database_handler,
|
||||
"handler_provider": dataset_database_provider,
|
||||
}
|
||||
|
|
@ -47,6 +47,7 @@ class GraphConfig(BaseSettings):
|
|||
graph_filename: str = ""
|
||||
graph_model: object = KnowledgeGraph
|
||||
graph_topology: object = KnowledgeGraph
|
||||
graph_dataset_database_handler: str = "kuzu"
|
||||
model_config = SettingsConfigDict(env_file=".env", extra="allow", populate_by_name=True)
|
||||
|
||||
# Model validator updates graph_filename and path dynamically after class creation based on current database provider
|
||||
|
|
@ -97,6 +98,7 @@ class GraphConfig(BaseSettings):
|
|||
"graph_model": self.graph_model,
|
||||
"graph_topology": self.graph_topology,
|
||||
"model_config": self.model_config,
|
||||
"graph_dataset_database_handler": self.graph_dataset_database_handler,
|
||||
}
|
||||
|
||||
def to_hashable_dict(self) -> dict:
|
||||
|
|
@ -121,6 +123,7 @@ class GraphConfig(BaseSettings):
|
|||
"graph_database_port": self.graph_database_port,
|
||||
"graph_database_key": self.graph_database_key,
|
||||
"graph_file_path": self.graph_file_path,
|
||||
"graph_dataset_database_handler": self.graph_dataset_database_handler,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ def create_graph_engine(
|
|||
graph_database_password="",
|
||||
graph_database_port="",
|
||||
graph_database_key="",
|
||||
graph_dataset_database_handler="",
|
||||
):
|
||||
"""
|
||||
Create a graph engine based on the specified provider type.
|
||||
|
|
|
|||
|
|
@ -0,0 +1,81 @@
|
|||
import os
|
||||
from uuid import UUID
|
||||
from typing import Optional
|
||||
|
||||
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.models import DatasetDatabase
|
||||
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
|
||||
|
||||
|
||||
class KuzuDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||
"""
|
||||
Handler for interacting with Kuzu Dataset databases.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
|
||||
"""
|
||||
Create a new Kuzu instance for the dataset. Return connection info that will be mapped to the dataset.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset UUID
|
||||
user: User object who owns the dataset and is making the request
|
||||
|
||||
Returns:
|
||||
dict: Connection details for the created Kuzu instance
|
||||
|
||||
"""
|
||||
from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||
|
||||
graph_config = get_graph_config()
|
||||
|
||||
if graph_config.graph_database_provider != "kuzu":
|
||||
raise ValueError(
|
||||
"KuzuDatasetDatabaseHandler can only be used with Kuzu graph database provider."
|
||||
)
|
||||
|
||||
graph_db_name = f"{dataset_id}.pkl"
|
||||
graph_db_url = graph_config.graph_database_url
|
||||
graph_db_key = graph_config.graph_database_key
|
||||
graph_db_username = graph_config.graph_database_username
|
||||
graph_db_password = graph_config.graph_database_password
|
||||
|
||||
return {
|
||||
"graph_database_name": graph_db_name,
|
||||
"graph_database_url": graph_db_url,
|
||||
"graph_database_provider": graph_config.graph_database_provider,
|
||||
"graph_database_key": graph_db_key,
|
||||
"graph_dataset_database_handler": "kuzu",
|
||||
"graph_database_connection_info": {
|
||||
"graph_database_username": graph_db_username,
|
||||
"graph_database_password": graph_db_password,
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
async def delete_dataset(cls, dataset_database: DatasetDatabase):
|
||||
base_config = get_base_config()
|
||||
databases_directory_path = os.path.join(
|
||||
base_config.system_root_directory, "databases", str(dataset_database.owner_id)
|
||||
)
|
||||
graph_file_path = os.path.join(
|
||||
databases_directory_path, dataset_database.graph_database_name
|
||||
)
|
||||
graph_engine = create_graph_engine(
|
||||
graph_database_provider=dataset_database.graph_database_provider,
|
||||
graph_database_url=dataset_database.graph_database_url,
|
||||
graph_database_name=dataset_database.graph_database_name,
|
||||
graph_database_key=dataset_database.graph_database_key,
|
||||
graph_file_path=graph_file_path,
|
||||
graph_database_username=dataset_database.graph_database_connection_info.get(
|
||||
"graph_database_username", ""
|
||||
),
|
||||
graph_database_password=dataset_database.graph_database_connection_info.get(
|
||||
"graph_database_password", ""
|
||||
),
|
||||
graph_dataset_database_handler="",
|
||||
graph_database_port="",
|
||||
)
|
||||
await graph_engine.delete_graph()
|
||||
|
|
@ -0,0 +1,168 @@
|
|||
import os
|
||||
import asyncio
|
||||
import requests
|
||||
import base64
|
||||
import hashlib
|
||||
from uuid import UUID
|
||||
from typing import Optional
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_config
|
||||
from cognee.modules.users.models import User, DatasetDatabase
|
||||
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
|
||||
|
||||
|
||||
class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||
"""
|
||||
Handler for a quick development PoC integration of Cognee multi-user and permission mode with Neo4j Aura databases.
|
||||
This handler creates a new Neo4j Aura instance for each Cognee dataset created.
|
||||
|
||||
Improvements needed to be production ready:
|
||||
- Secret management for client credentials, currently secrets are encrypted and stored in the Cognee relational database,
|
||||
a secret manager or a similar system should be used instead.
|
||||
|
||||
Quality of life improvements:
|
||||
- Allow configuration of different Neo4j Aura plans and regions.
|
||||
- Requests should be made async, currently a blocking requests library is used.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
|
||||
"""
|
||||
Create a new Neo4j Aura instance for the dataset. Return connection info that will be mapped to the dataset.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset UUID
|
||||
user: User object who owns the dataset and is making the request
|
||||
|
||||
Returns:
|
||||
dict: Connection details for the created Neo4j instance
|
||||
|
||||
"""
|
||||
graph_config = get_graph_config()
|
||||
|
||||
if graph_config.graph_database_provider != "neo4j":
|
||||
raise ValueError(
|
||||
"Neo4jAuraDevDatasetDatabaseHandler can only be used with Neo4j graph database provider."
|
||||
)
|
||||
|
||||
graph_db_name = f"{dataset_id}"
|
||||
|
||||
# Client credentials and encryption
|
||||
client_id = os.environ.get("NEO4J_CLIENT_ID", None)
|
||||
client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
|
||||
tenant_id = os.environ.get("NEO4J_TENANT_ID", None)
|
||||
encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key")
|
||||
encryption_key = base64.urlsafe_b64encode(
|
||||
hashlib.sha256(encryption_env_key.encode()).digest()
|
||||
)
|
||||
cipher = Fernet(encryption_key)
|
||||
|
||||
if client_id is None or client_secret is None or tenant_id is None:
|
||||
raise ValueError(
|
||||
"NEO4J_CLIENT_ID, NEO4J_CLIENT_SECRET, and NEO4J_TENANT_ID environment variables must be set to use Neo4j Aura DatasetDatabase Handling."
|
||||
)
|
||||
|
||||
# Make the request with HTTP Basic Auth
|
||||
def get_aura_token(client_id: str, client_secret: str) -> dict:
|
||||
url = "https://api.neo4j.io/oauth/token"
|
||||
data = {"grant_type": "client_credentials"} # sent as application/x-www-form-urlencoded
|
||||
|
||||
resp = requests.post(url, data=data, auth=(client_id, client_secret))
|
||||
resp.raise_for_status() # raises if the request failed
|
||||
return resp.json()
|
||||
|
||||
resp = get_aura_token(client_id, client_secret)
|
||||
|
||||
url = "https://api.neo4j.io/v1/instances"
|
||||
|
||||
headers = {
|
||||
"accept": "application/json",
|
||||
"Authorization": f"Bearer {resp['access_token']}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# TODO: Maybe we can allow **kwargs parameter forwarding for cases like these
|
||||
# Too allow different configurations between datasets
|
||||
payload = {
|
||||
"version": "5",
|
||||
"region": "europe-west1",
|
||||
"memory": "1GB",
|
||||
"name": graph_db_name[
|
||||
0:29
|
||||
], # TODO: Find better name to name Neo4j instance within 30 character limit
|
||||
"type": "professional-db",
|
||||
"tenant_id": tenant_id,
|
||||
"cloud_provider": "gcp",
|
||||
}
|
||||
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
|
||||
graph_db_name = "neo4j" # Has to be 'neo4j' for Aura
|
||||
graph_db_url = response.json()["data"]["connection_url"]
|
||||
graph_db_key = resp["access_token"]
|
||||
graph_db_username = response.json()["data"]["username"]
|
||||
graph_db_password = response.json()["data"]["password"]
|
||||
|
||||
async def _wait_for_neo4j_instance_provisioning(instance_id: str, headers: dict):
|
||||
# Poll until the instance is running
|
||||
status_url = f"https://api.neo4j.io/v1/instances/{instance_id}"
|
||||
status = ""
|
||||
for attempt in range(30): # Try for up to ~5 minutes
|
||||
status_resp = requests.get(
|
||||
status_url, headers=headers
|
||||
) # TODO: Use async requests with httpx
|
||||
status = status_resp.json()["data"]["status"]
|
||||
if status.lower() == "running":
|
||||
return
|
||||
await asyncio.sleep(10)
|
||||
raise TimeoutError(
|
||||
f"Neo4j instance '{graph_db_name}' did not become ready within 5 minutes. Status: {status}"
|
||||
)
|
||||
|
||||
instance_id = response.json()["data"]["id"]
|
||||
await _wait_for_neo4j_instance_provisioning(instance_id, headers)
|
||||
|
||||
encrypted_db_password_bytes = cipher.encrypt(graph_db_password.encode())
|
||||
encrypted_db_password_string = encrypted_db_password_bytes.decode()
|
||||
|
||||
return {
|
||||
"graph_database_name": graph_db_name,
|
||||
"graph_database_url": graph_db_url,
|
||||
"graph_database_provider": "neo4j",
|
||||
"graph_database_key": graph_db_key,
|
||||
"graph_dataset_database_handler": "neo4j_aura_dev",
|
||||
"graph_database_connection_info": {
|
||||
"graph_database_username": graph_db_username,
|
||||
"graph_database_password": encrypted_db_password_string,
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
async def resolve_dataset_connection_info(
|
||||
cls, dataset_database: DatasetDatabase
|
||||
) -> DatasetDatabase:
|
||||
"""
|
||||
Resolve and decrypt connection info for the Neo4j dataset database.
|
||||
In this case, decrypt the password stored in the database.
|
||||
|
||||
Args:
|
||||
dataset_database: DatasetDatabase instance containing encrypted connection info.
|
||||
"""
|
||||
encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key")
|
||||
encryption_key = base64.urlsafe_b64encode(
|
||||
hashlib.sha256(encryption_env_key.encode()).digest()
|
||||
)
|
||||
cipher = Fernet(encryption_key)
|
||||
graph_db_password = cipher.decrypt(
|
||||
dataset_database.graph_database_connection_info["graph_database_password"].encode()
|
||||
).decode()
|
||||
|
||||
dataset_database.graph_database_connection_info["graph_database_password"] = (
|
||||
graph_db_password
|
||||
)
|
||||
return dataset_database
|
||||
|
||||
@classmethod
|
||||
async def delete_dataset(cls, dataset_database: DatasetDatabase):
|
||||
pass
|
||||
|
|
@ -1 +1,2 @@
|
|||
from .get_or_create_dataset_database import get_or_create_dataset_database
|
||||
from .resolve_dataset_database_connection_info import resolve_dataset_database_connection_info
|
||||
|
|
|
|||
|
|
@ -1,11 +1,9 @@
|
|||
import os
|
||||
from uuid import UUID
|
||||
from typing import Union
|
||||
from typing import Union, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.modules.data.methods import create_dataset
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.infrastructure.databases.vector import get_vectordb_config
|
||||
|
|
@ -15,6 +13,53 @@ from cognee.modules.users.models import DatasetDatabase
|
|||
from cognee.modules.users.models import User
|
||||
|
||||
|
||||
async def _get_vector_db_info(dataset_id: UUID, user: User) -> dict:
|
||||
vector_config = get_vectordb_config()
|
||||
|
||||
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
||||
supported_dataset_database_handlers,
|
||||
)
|
||||
|
||||
handler = supported_dataset_database_handlers[vector_config.vector_dataset_database_handler]
|
||||
return await handler["handler_instance"].create_dataset(dataset_id, user)
|
||||
|
||||
|
||||
async def _get_graph_db_info(dataset_id: UUID, user: User) -> dict:
|
||||
graph_config = get_graph_config()
|
||||
|
||||
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
||||
supported_dataset_database_handlers,
|
||||
)
|
||||
|
||||
handler = supported_dataset_database_handlers[graph_config.graph_dataset_database_handler]
|
||||
return await handler["handler_instance"].create_dataset(dataset_id, user)
|
||||
|
||||
|
||||
async def _existing_dataset_database(
|
||||
dataset_id: UUID,
|
||||
user: User,
|
||||
) -> Optional[DatasetDatabase]:
|
||||
"""
|
||||
Check if a DatasetDatabase row already exists for the given owner + dataset.
|
||||
Return None if it doesn't exist, return the row if it does.
|
||||
Args:
|
||||
dataset_id:
|
||||
user:
|
||||
|
||||
Returns:
|
||||
DatasetDatabase or None
|
||||
"""
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
async with db_engine.get_async_session() as session:
|
||||
stmt = select(DatasetDatabase).where(
|
||||
DatasetDatabase.owner_id == user.id,
|
||||
DatasetDatabase.dataset_id == dataset_id,
|
||||
)
|
||||
existing: DatasetDatabase = await session.scalar(stmt)
|
||||
return existing
|
||||
|
||||
|
||||
async def get_or_create_dataset_database(
|
||||
dataset: Union[str, UUID],
|
||||
user: User,
|
||||
|
|
@ -25,6 +70,8 @@ async def get_or_create_dataset_database(
|
|||
• If the row already exists, it is fetched and returned.
|
||||
• Otherwise a new one is created atomically and returned.
|
||||
|
||||
DatasetDatabase row contains connection and provider info for vector and graph databases.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
user : User
|
||||
|
|
@ -36,59 +83,26 @@ async def get_or_create_dataset_database(
|
|||
|
||||
dataset_id = await get_unique_dataset_id(dataset, user)
|
||||
|
||||
vector_config = get_vectordb_config()
|
||||
graph_config = get_graph_config()
|
||||
# If dataset is given as name make sure the dataset is created first
|
||||
if isinstance(dataset, str):
|
||||
async with db_engine.get_async_session() as session:
|
||||
await create_dataset(dataset, user, session)
|
||||
|
||||
# Note: for hybrid databases both graph and vector DB name have to be the same
|
||||
if graph_config.graph_database_provider == "kuzu":
|
||||
graph_db_name = f"{dataset_id}.pkl"
|
||||
else:
|
||||
graph_db_name = f"{dataset_id}"
|
||||
# If dataset database already exists return it
|
||||
existing_dataset_database = await _existing_dataset_database(dataset_id, user)
|
||||
if existing_dataset_database:
|
||||
return existing_dataset_database
|
||||
|
||||
if vector_config.vector_db_provider == "lancedb":
|
||||
vector_db_name = f"{dataset_id}.lance.db"
|
||||
else:
|
||||
vector_db_name = f"{dataset_id}"
|
||||
|
||||
base_config = get_base_config()
|
||||
databases_directory_path = os.path.join(
|
||||
base_config.system_root_directory, "databases", str(user.id)
|
||||
)
|
||||
|
||||
# Determine vector database URL
|
||||
if vector_config.vector_db_provider == "lancedb":
|
||||
vector_db_url = os.path.join(databases_directory_path, vector_config.vector_db_name)
|
||||
else:
|
||||
vector_db_url = vector_config.vector_database_url
|
||||
|
||||
# Determine graph database URL
|
||||
graph_config_dict = await _get_graph_db_info(dataset_id, user)
|
||||
vector_config_dict = await _get_vector_db_info(dataset_id, user)
|
||||
|
||||
async with db_engine.get_async_session() as session:
|
||||
# Create dataset if it doesn't exist
|
||||
if isinstance(dataset, str):
|
||||
dataset = await create_dataset(dataset, user, session)
|
||||
|
||||
# Try to fetch an existing row first
|
||||
stmt = select(DatasetDatabase).where(
|
||||
DatasetDatabase.owner_id == user.id,
|
||||
DatasetDatabase.dataset_id == dataset_id,
|
||||
)
|
||||
existing: DatasetDatabase = await session.scalar(stmt)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
# If there are no existing rows build a new row
|
||||
record = DatasetDatabase(
|
||||
owner_id=user.id,
|
||||
dataset_id=dataset_id,
|
||||
vector_database_name=vector_db_name,
|
||||
graph_database_name=graph_db_name,
|
||||
vector_database_provider=vector_config.vector_db_provider,
|
||||
graph_database_provider=graph_config.graph_database_provider,
|
||||
vector_database_url=vector_db_url,
|
||||
graph_database_url=graph_config.graph_database_url,
|
||||
vector_database_key=vector_config.vector_db_key,
|
||||
graph_database_key=graph_config.graph_database_key,
|
||||
**graph_config_dict, # Unpack graph db config
|
||||
**vector_config_dict, # Unpack vector db config
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,36 @@
|
|||
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
|
||||
|
||||
|
||||
async def _get_vector_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase:
|
||||
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
||||
supported_dataset_database_handlers,
|
||||
)
|
||||
|
||||
handler = supported_dataset_database_handlers[dataset_database.vector_dataset_database_handler]
|
||||
return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database)
|
||||
|
||||
|
||||
async def _get_graph_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase:
|
||||
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
||||
supported_dataset_database_handlers,
|
||||
)
|
||||
|
||||
handler = supported_dataset_database_handlers[dataset_database.graph_dataset_database_handler]
|
||||
return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database)
|
||||
|
||||
|
||||
async def resolve_dataset_database_connection_info(
|
||||
dataset_database: DatasetDatabase,
|
||||
) -> DatasetDatabase:
|
||||
"""
|
||||
Resolve the connection info for the given DatasetDatabase instance.
|
||||
Resolve both vector and graph database connection info and return the updated DatasetDatabase instance.
|
||||
|
||||
Args:
|
||||
dataset_database: DatasetDatabase instance
|
||||
Returns:
|
||||
DatasetDatabase instance with resolved connection info
|
||||
"""
|
||||
dataset_database = await _get_vector_db_connection_info(dataset_database)
|
||||
dataset_database = await _get_graph_db_connection_info(dataset_database)
|
||||
return dataset_database
|
||||
|
|
@ -28,6 +28,7 @@ class VectorConfig(BaseSettings):
|
|||
vector_db_name: str = ""
|
||||
vector_db_key: str = ""
|
||||
vector_db_provider: str = "lancedb"
|
||||
vector_dataset_database_handler: str = "lancedb"
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||
|
||||
|
|
@ -63,6 +64,7 @@ class VectorConfig(BaseSettings):
|
|||
"vector_db_name": self.vector_db_name,
|
||||
"vector_db_key": self.vector_db_key,
|
||||
"vector_db_provider": self.vector_db_provider,
|
||||
"vector_dataset_database_handler": self.vector_dataset_database_handler,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ def create_vector_engine(
|
|||
vector_db_name: str,
|
||||
vector_db_port: str = "",
|
||||
vector_db_key: str = "",
|
||||
vector_dataset_database_handler: str = "",
|
||||
):
|
||||
"""
|
||||
Create a vector database engine based on the specified provider.
|
||||
|
|
|
|||
|
|
@ -0,0 +1,50 @@
|
|||
import os
|
||||
from uuid import UUID
|
||||
from typing import Optional
|
||||
|
||||
from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.models import DatasetDatabase
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.infrastructure.databases.vector import get_vectordb_config
|
||||
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
|
||||
|
||||
|
||||
class LanceDBDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||
"""
|
||||
Handler for interacting with LanceDB Dataset databases.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
|
||||
vector_config = get_vectordb_config()
|
||||
base_config = get_base_config()
|
||||
|
||||
if vector_config.vector_db_provider != "lancedb":
|
||||
raise ValueError(
|
||||
"LanceDBDatasetDatabaseHandler can only be used with LanceDB vector database provider."
|
||||
)
|
||||
|
||||
databases_directory_path = os.path.join(
|
||||
base_config.system_root_directory, "databases", str(user.id)
|
||||
)
|
||||
|
||||
vector_db_name = f"{dataset_id}.lance.db"
|
||||
|
||||
return {
|
||||
"vector_database_provider": vector_config.vector_db_provider,
|
||||
"vector_database_url": os.path.join(databases_directory_path, vector_db_name),
|
||||
"vector_database_key": vector_config.vector_db_key,
|
||||
"vector_database_name": vector_db_name,
|
||||
"vector_dataset_database_handler": "lancedb",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
async def delete_dataset(cls, dataset_database: DatasetDatabase):
|
||||
vector_engine = create_vector_engine(
|
||||
vector_db_provider=dataset_database.vector_database_provider,
|
||||
vector_db_url=dataset_database.vector_database_url,
|
||||
vector_db_key=dataset_database.vector_database_key,
|
||||
vector_db_name=dataset_database.vector_database_name,
|
||||
)
|
||||
await vector_engine.prune()
|
||||
|
|
@ -2,6 +2,8 @@ from typing import List, Protocol, Optional, Union, Any
|
|||
from abc import abstractmethod
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from .models.PayloadSchema import PayloadSchema
|
||||
from uuid import UUID
|
||||
from cognee.modules.users.models import User
|
||||
|
||||
|
||||
class VectorDBInterface(Protocol):
|
||||
|
|
@ -217,3 +219,36 @@ class VectorDBInterface(Protocol):
|
|||
- Any: The schema object suitable for this vector database
|
||||
"""
|
||||
return model_type
|
||||
|
||||
@classmethod
|
||||
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
|
||||
"""
|
||||
Return a dictionary with connection info for a vector database for the given dataset.
|
||||
Function can auto handle deploying of the actual database if needed, but is not necessary.
|
||||
Only providing connection info is sufficient, this info will be mapped when trying to connect to the provided dataset in the future.
|
||||
Needed for Cognee multi-tenant/multi-user and backend access control support.
|
||||
|
||||
Dictionary returned from this function will be used to create a DatasetDatabase row in the relational database.
|
||||
From which internal mapping of dataset -> database connection info will be done.
|
||||
|
||||
Each dataset needs to map to a unique vector database when backend access control is enabled to facilitate a separation of concern for data.
|
||||
|
||||
Args:
|
||||
dataset_id: UUID of the dataset if needed by the database creation logic
|
||||
user: User object if needed by the database creation logic
|
||||
Returns:
|
||||
dict: Connection info for the created vector database instance.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def delete_dataset(self, dataset_id: UUID, user: User) -> None:
|
||||
"""
|
||||
Delete the vector database for the given dataset.
|
||||
Function should auto handle deleting of the actual database or send a request to the proper service to delete the database.
|
||||
Needed for maintaining a database for Cognee multi-tenant/multi-user and backend access control.
|
||||
|
||||
Args:
|
||||
dataset_id: UUID of the dataset
|
||||
user: User object
|
||||
"""
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,17 +1,81 @@
|
|||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
|
||||
from cognee.context_global_variables import backend_access_control_enabled
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.shared.cache import delete_cache
|
||||
from cognee.modules.users.models import DatasetDatabase
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
async def prune_graph_databases():
|
||||
async def _prune_graph_db(dataset_database: DatasetDatabase) -> dict:
|
||||
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
||||
supported_dataset_database_handlers,
|
||||
)
|
||||
|
||||
handler = supported_dataset_database_handlers[
|
||||
dataset_database.graph_dataset_database_handler
|
||||
]
|
||||
return await handler["handler_instance"].delete_dataset(dataset_database)
|
||||
|
||||
db_engine = get_relational_engine()
|
||||
try:
|
||||
data = await db_engine.get_all_data_from_table("dataset_database")
|
||||
# Go through each dataset database and delete the graph database
|
||||
for data_item in data:
|
||||
await _prune_graph_db(data_item)
|
||||
except (OperationalError, EntityNotFoundError) as e:
|
||||
logger.debug(
|
||||
"Skipping pruning of graph DB. Error when accessing dataset_database table: %s",
|
||||
e,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
async def prune_vector_databases():
|
||||
async def _prune_vector_db(dataset_database: DatasetDatabase) -> dict:
|
||||
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
||||
supported_dataset_database_handlers,
|
||||
)
|
||||
|
||||
handler = supported_dataset_database_handlers[
|
||||
dataset_database.vector_dataset_database_handler
|
||||
]
|
||||
return await handler["handler_instance"].delete_dataset(dataset_database)
|
||||
|
||||
db_engine = get_relational_engine()
|
||||
try:
|
||||
data = await db_engine.get_all_data_from_table("dataset_database")
|
||||
# Go through each dataset database and delete the vector database
|
||||
for data_item in data:
|
||||
await _prune_vector_db(data_item)
|
||||
except (OperationalError, EntityNotFoundError) as e:
|
||||
logger.debug(
|
||||
"Skipping pruning of vector DB. Error when accessing dataset_database table: %s",
|
||||
e,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
async def prune_system(graph=True, vector=True, metadata=True, cache=True):
|
||||
if graph:
|
||||
# Note: prune system should not be available through the API, it has no permission checks and will
|
||||
# delete all graph and vector databases if called. It should only be used in development or testing environments.
|
||||
if graph and not backend_access_control_enabled():
|
||||
graph_engine = await get_graph_engine()
|
||||
await graph_engine.delete_graph()
|
||||
elif graph and backend_access_control_enabled():
|
||||
await prune_graph_databases()
|
||||
|
||||
if vector:
|
||||
if vector and not backend_access_control_enabled():
|
||||
vector_engine = get_vector_engine()
|
||||
await vector_engine.prune()
|
||||
elif vector and backend_access_control_enabled():
|
||||
await prune_vector_databases()
|
||||
|
||||
if metadata:
|
||||
db_engine = get_relational_engine()
|
||||
|
|
|
|||
|
|
@ -12,9 +12,6 @@ from cognee.modules.users.models import User
|
|||
from cognee.modules.pipelines.layers.resolve_authorized_user_datasets import (
|
||||
resolve_authorized_user_datasets,
|
||||
)
|
||||
from cognee.modules.pipelines.layers.reset_dataset_pipeline_run_status import (
|
||||
reset_dataset_pipeline_run_status,
|
||||
)
|
||||
from cognee.modules.engine.operations.setup import setup
|
||||
from cognee.modules.pipelines.layers.pipeline_execution_mode import get_pipeline_executor
|
||||
from cognee.tasks.memify.extract_subgraph_chunks import extract_subgraph_chunks
|
||||
|
|
@ -97,10 +94,6 @@ async def memify(
|
|||
*enrichment_tasks,
|
||||
]
|
||||
|
||||
await reset_dataset_pipeline_run_status(
|
||||
authorized_dataset.id, user, pipeline_names=["memify_pipeline"]
|
||||
)
|
||||
|
||||
# By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for
|
||||
pipeline_executor_func = get_pipeline_executor(run_in_background=run_in_background)
|
||||
|
||||
|
|
@ -113,6 +106,7 @@ async def memify(
|
|||
datasets=authorized_dataset.id,
|
||||
vector_db_config=vector_db_config,
|
||||
graph_db_config=graph_db_config,
|
||||
use_pipeline_cache=False,
|
||||
incremental_loading=False,
|
||||
pipeline_name="memify_pipeline",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -20,6 +20,9 @@ from cognee.modules.pipelines.layers.resolve_authorized_user_datasets import (
|
|||
from cognee.modules.pipelines.layers.check_pipeline_run_qualification import (
|
||||
check_pipeline_run_qualification,
|
||||
)
|
||||
from cognee.modules.pipelines.models.PipelineRunInfo import (
|
||||
PipelineRunStarted,
|
||||
)
|
||||
from typing import Any
|
||||
|
||||
logger = get_logger("cognee.pipeline")
|
||||
|
|
@ -35,6 +38,7 @@ async def run_pipeline(
|
|||
pipeline_name: str = "custom_pipeline",
|
||||
vector_db_config: dict = None,
|
||||
graph_db_config: dict = None,
|
||||
use_pipeline_cache: bool = False,
|
||||
incremental_loading: bool = False,
|
||||
data_per_batch: int = 20,
|
||||
):
|
||||
|
|
@ -51,6 +55,7 @@ async def run_pipeline(
|
|||
data=data,
|
||||
pipeline_name=pipeline_name,
|
||||
context={"dataset": dataset},
|
||||
use_pipeline_cache=use_pipeline_cache,
|
||||
incremental_loading=incremental_loading,
|
||||
data_per_batch=data_per_batch,
|
||||
):
|
||||
|
|
@ -64,6 +69,7 @@ async def run_pipeline_per_dataset(
|
|||
data=None,
|
||||
pipeline_name: str = "custom_pipeline",
|
||||
context: dict = None,
|
||||
use_pipeline_cache=False,
|
||||
incremental_loading=False,
|
||||
data_per_batch: int = 20,
|
||||
):
|
||||
|
|
@ -77,8 +83,18 @@ async def run_pipeline_per_dataset(
|
|||
if process_pipeline_status:
|
||||
# If pipeline was already processed or is currently being processed
|
||||
# return status information to async generator and finish execution
|
||||
yield process_pipeline_status
|
||||
return
|
||||
if use_pipeline_cache:
|
||||
# If pipeline caching is enabled we do not proceed with re-processing
|
||||
yield process_pipeline_status
|
||||
return
|
||||
else:
|
||||
# If pipeline caching is disabled we always return pipeline started information and proceed with re-processing
|
||||
yield PipelineRunStarted(
|
||||
pipeline_run_id=process_pipeline_status.pipeline_run_id,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
payload=data,
|
||||
)
|
||||
|
||||
pipeline_run = run_tasks(
|
||||
tasks,
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ async def run_custom_pipeline(
|
|||
user: User = None,
|
||||
vector_db_config: Optional[dict] = None,
|
||||
graph_db_config: Optional[dict] = None,
|
||||
use_pipeline_cache: bool = False,
|
||||
incremental_loading: bool = False,
|
||||
data_per_batch: int = 20,
|
||||
run_in_background: bool = False,
|
||||
pipeline_name: str = "custom_pipeline",
|
||||
|
|
@ -40,6 +42,10 @@ async def run_custom_pipeline(
|
|||
user: User context for authentication and data access. Uses default if None.
|
||||
vector_db_config: Custom vector database configuration for embeddings storage.
|
||||
graph_db_config: Custom graph database configuration for relationship storage.
|
||||
use_pipeline_cache: If True, pipelines with the same ID that are currently executing and pipelines with the same ID that were completed won't process data again.
|
||||
Pipelines ID is created based on the generate_pipeline_id function. Pipeline status can be manually reset with the reset_dataset_pipeline_run_status function.
|
||||
incremental_loading: If True, only new or modified data will be processed to avoid duplication. (Only works if data is used with the Cognee python Data model).
|
||||
The incremental system stores and compares hashes of processed data in the Data model and skips data with the same content hash.
|
||||
data_per_batch: Number of data items to be processed in parallel.
|
||||
run_in_background: If True, starts processing asynchronously and returns immediately.
|
||||
If False, waits for completion before returning.
|
||||
|
|
@ -63,7 +69,8 @@ async def run_custom_pipeline(
|
|||
datasets=dataset,
|
||||
vector_db_config=vector_db_config,
|
||||
graph_db_config=graph_db_config,
|
||||
incremental_loading=False,
|
||||
use_pipeline_cache=use_pipeline_cache,
|
||||
incremental_loading=incremental_loading,
|
||||
data_per_batch=data_per_batch,
|
||||
pipeline_name=pipeline_name,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -12,8 +12,8 @@ logger = get_logger("get_authenticated_user")
|
|||
|
||||
# Check environment variable to determine authentication requirement
|
||||
REQUIRE_AUTHENTICATION = (
|
||||
os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true"
|
||||
or backend_access_control_enabled()
|
||||
os.getenv("REQUIRE_AUTHENTICATION", "true").lower() == "true"
|
||||
or os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", "true").lower() == "true"
|
||||
)
|
||||
|
||||
fastapi_users = get_fastapi_users()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import Column, DateTime, String, UUID, ForeignKey
|
||||
from sqlalchemy import Column, DateTime, String, UUID, ForeignKey, JSON, text
|
||||
from cognee.infrastructure.databases.relational import Base
|
||||
|
||||
|
||||
|
|
@ -12,17 +12,29 @@ class DatasetDatabase(Base):
|
|||
UUID, ForeignKey("datasets.id", ondelete="CASCADE"), primary_key=True, index=True
|
||||
)
|
||||
|
||||
vector_database_name = Column(String, unique=True, nullable=False)
|
||||
graph_database_name = Column(String, unique=True, nullable=False)
|
||||
vector_database_name = Column(String, unique=False, nullable=False)
|
||||
graph_database_name = Column(String, unique=False, nullable=False)
|
||||
|
||||
vector_database_provider = Column(String, unique=False, nullable=False)
|
||||
graph_database_provider = Column(String, unique=False, nullable=False)
|
||||
|
||||
graph_dataset_database_handler = Column(String, unique=False, nullable=False)
|
||||
vector_dataset_database_handler = Column(String, unique=False, nullable=False)
|
||||
|
||||
vector_database_url = Column(String, unique=False, nullable=True)
|
||||
graph_database_url = Column(String, unique=False, nullable=True)
|
||||
|
||||
vector_database_key = Column(String, unique=False, nullable=True)
|
||||
graph_database_key = Column(String, unique=False, nullable=True)
|
||||
|
||||
# configuration details for different database types. This would make it more flexible to add new database types
|
||||
# without changing the database schema.
|
||||
graph_database_connection_info = Column(
|
||||
JSON, unique=False, nullable=False, server_default=text("'{}'")
|
||||
)
|
||||
vector_database_connection_info = Column(
|
||||
JSON, unique=False, nullable=False, server_default=text("'{}'")
|
||||
)
|
||||
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
|
|
|||
|
|
@ -534,6 +534,10 @@ def setup_logging(log_level=None, name=None):
|
|||
# Get a configured logger and log system information
|
||||
logger = structlog.get_logger(name if name else __name__)
|
||||
|
||||
logger.warning(
|
||||
"From version 0.5.0 onwards, Cognee will run with multi-user access control mode set to on by default. Data isolation between different users and datasets will be enforced and data created before multi-user access control mode was turned on won't be accessible by default. To disable multi-user access control mode and regain access to old data set the environment variable ENABLE_BACKEND_ACCESS_CONTROL to false before starting Cognee. For more information, please refer to the Cognee documentation."
|
||||
)
|
||||
|
||||
if logs_dir is not None:
|
||||
logger.info(f"Log file created at: {log_file_path}", log_file=log_file_path)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,2 @@
|
|||
from .classify_documents import classify_documents
|
||||
from .extract_chunks_from_documents import extract_chunks_from_documents
|
||||
from .check_permissions_on_dataset import check_permissions_on_dataset
|
||||
|
|
|
|||
|
|
@ -1,26 +0,0 @@
|
|||
from cognee.modules.data.processing.document_types import Document
|
||||
from cognee.modules.users.permissions.methods import check_permission_on_dataset
|
||||
from typing import List
|
||||
|
||||
|
||||
async def check_permissions_on_dataset(
|
||||
documents: List[Document], context: dict, user, permissions
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Validates a user's permissions on a list of documents.
|
||||
|
||||
Notes:
|
||||
- This function assumes that `check_permission_on_documents` raises an exception if the permission check fails.
|
||||
- It is designed to validate multiple permissions in a sequential manner for the same set of documents.
|
||||
- Ensure that the `Document` and `user` objects conform to the expected structure and interfaces.
|
||||
"""
|
||||
|
||||
for permission in permissions:
|
||||
await check_permission_on_dataset(
|
||||
user,
|
||||
permission,
|
||||
# TODO: pass dataset through argument instead of context
|
||||
context["dataset"].id,
|
||||
)
|
||||
|
||||
return documents
|
||||
137
cognee/tests/test_dataset_database_handler.py
Normal file
137
cognee/tests/test_dataset_database_handler.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
import asyncio
|
||||
import os
|
||||
|
||||
# Set custom dataset database handler environment variable
|
||||
os.environ["VECTOR_DATASET_DATABASE_HANDLER"] = "custom_lancedb_handler"
|
||||
os.environ["GRAPH_DATASET_DATABASE_HANDLER"] = "custom_kuzu_handler"
|
||||
|
||||
import cognee
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
|
||||
from cognee.shared.logging_utils import setup_logging, ERROR
|
||||
from cognee.api.v1.search import SearchType
|
||||
|
||||
|
||||
class LanceDBTestDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||
@classmethod
|
||||
async def create_dataset(cls, dataset_id, user):
|
||||
import pathlib
|
||||
|
||||
cognee_directory_path = str(
|
||||
pathlib.Path(
|
||||
os.path.join(
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_dataset_database_handler"
|
||||
)
|
||||
).resolve()
|
||||
)
|
||||
databases_directory_path = os.path.join(cognee_directory_path, "databases", str(user.id))
|
||||
os.makedirs(databases_directory_path, exist_ok=True)
|
||||
|
||||
vector_db_name = "test.lance.db"
|
||||
|
||||
return {
|
||||
"vector_dataset_database_handler": "custom_lancedb_handler",
|
||||
"vector_database_name": vector_db_name,
|
||||
"vector_database_url": os.path.join(databases_directory_path, vector_db_name),
|
||||
"vector_database_provider": "lancedb",
|
||||
}
|
||||
|
||||
|
||||
class KuzuTestDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||
@classmethod
|
||||
async def create_dataset(cls, dataset_id, user):
|
||||
databases_directory_path = os.path.join("databases", str(user.id))
|
||||
os.makedirs(databases_directory_path, exist_ok=True)
|
||||
|
||||
graph_db_name = "test.kuzu"
|
||||
return {
|
||||
"graph_dataset_database_handler": "custom_kuzu_handler",
|
||||
"graph_database_name": graph_db_name,
|
||||
"graph_database_url": os.path.join(databases_directory_path, graph_db_name),
|
||||
"graph_database_provider": "kuzu",
|
||||
}
|
||||
|
||||
|
||||
async def main():
|
||||
import pathlib
|
||||
|
||||
data_directory_path = str(
|
||||
pathlib.Path(
|
||||
os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_dataset_database_handler"
|
||||
)
|
||||
).resolve()
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
cognee_directory_path = str(
|
||||
pathlib.Path(
|
||||
os.path.join(
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_dataset_database_handler"
|
||||
)
|
||||
).resolve()
|
||||
)
|
||||
cognee.config.system_root_directory(cognee_directory_path)
|
||||
|
||||
# Add custom dataset database handler
|
||||
from cognee.infrastructure.databases.dataset_database_handler.use_dataset_database_handler import (
|
||||
use_dataset_database_handler,
|
||||
)
|
||||
|
||||
use_dataset_database_handler(
|
||||
"custom_lancedb_handler", LanceDBTestDatasetDatabaseHandler, "lancedb"
|
||||
)
|
||||
use_dataset_database_handler("custom_kuzu_handler", KuzuTestDatasetDatabaseHandler, "kuzu")
|
||||
|
||||
# Create a clean slate for cognee -- reset data and system state
|
||||
print("Resetting cognee data...")
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
print("Data reset complete.\n")
|
||||
|
||||
# cognee knowledge graph will be created based on this text
|
||||
text = """
|
||||
Natural language processing (NLP) is an interdisciplinary
|
||||
subfield of computer science and information retrieval.
|
||||
"""
|
||||
|
||||
print("Adding text to cognee:")
|
||||
print(text.strip())
|
||||
|
||||
# Add the text, and make it available for cognify
|
||||
await cognee.add(text)
|
||||
print("Text added successfully.\n")
|
||||
|
||||
# Use LLMs and cognee to create knowledge graph
|
||||
await cognee.cognify()
|
||||
print("Cognify process complete.\n")
|
||||
|
||||
query_text = "Tell me about NLP"
|
||||
print(f"Searching cognee for insights with query: '{query_text}'")
|
||||
# Query cognee for insights on the added text
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION, query_text=query_text
|
||||
)
|
||||
|
||||
print("Search results:")
|
||||
# Display results
|
||||
for result_text in search_results:
|
||||
print(result_text)
|
||||
|
||||
default_user = await get_default_user()
|
||||
# Assert that the custom database files were created based on the custom dataset database handlers
|
||||
assert os.path.exists(
|
||||
os.path.join(cognee_directory_path, "databases", str(default_user.id), "test.kuzu")
|
||||
), "Graph database file not found."
|
||||
assert os.path.exists(
|
||||
os.path.join(cognee_directory_path, "databases", str(default_user.id), "test.lance.db")
|
||||
), "Vector database file not found."
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger = setup_logging(log_level=ERROR)
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(main())
|
||||
finally:
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
164
cognee/tests/test_pipeline_cache.py
Normal file
164
cognee/tests/test_pipeline_cache.py
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
"""
|
||||
Test suite for the pipeline_cache feature in Cognee pipelines.
|
||||
|
||||
This module tests the behavior of the `pipeline_cache` parameter which controls
|
||||
whether a pipeline should skip re-execution when it has already been completed
|
||||
for the same dataset.
|
||||
|
||||
Architecture Overview:
|
||||
---------------------
|
||||
The pipeline_cache mechanism works at the dataset level:
|
||||
1. When a pipeline runs, it logs its status (INITIATED -> STARTED -> COMPLETED)
|
||||
2. Before each run, `check_pipeline_run_qualification()` checks the pipeline status
|
||||
3. If `use_pipeline_cache=True` and status is COMPLETED/STARTED, the pipeline skips
|
||||
4. If `use_pipeline_cache=False`, the pipeline always re-executes regardless of status
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
import cognee
|
||||
from cognee.modules.pipelines.tasks.task import Task
|
||||
from cognee.modules.pipelines import run_pipeline
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
|
||||
from cognee.modules.pipelines.layers.reset_dataset_pipeline_run_status import (
|
||||
reset_dataset_pipeline_run_status,
|
||||
)
|
||||
from cognee.infrastructure.databases.relational import create_db_and_tables
|
||||
|
||||
|
||||
class ExecutionCounter:
|
||||
"""Helper class to track task execution counts."""
|
||||
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
|
||||
|
||||
async def create_counting_task(data, counter: ExecutionCounter):
|
||||
"""Create a task that increments a counter from the ExecutionCounter instance when executed."""
|
||||
counter.count += 1
|
||||
return counter
|
||||
|
||||
|
||||
class TestPipelineCache:
|
||||
"""Tests for basic pipeline_cache on/off behavior."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_cache_off_allows_reexecution(self):
|
||||
"""
|
||||
Test that with use_pipeline_cache=False, the pipeline re-executes
|
||||
even when it has already completed for the dataset.
|
||||
|
||||
Expected behavior:
|
||||
- First run: Pipeline executes fully, task runs once
|
||||
- Second run: Pipeline executes again, task runs again (total: 2 times)
|
||||
"""
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
await create_db_and_tables()
|
||||
|
||||
counter = ExecutionCounter()
|
||||
user = await get_default_user()
|
||||
|
||||
tasks = [Task(create_counting_task, counter=counter)]
|
||||
|
||||
# First run
|
||||
pipeline_results_1 = []
|
||||
async for result in run_pipeline(
|
||||
tasks=tasks,
|
||||
datasets="test_dataset_cache_off",
|
||||
data=["sample data"], # Data is necessary to trigger processing
|
||||
user=user,
|
||||
pipeline_name="test_cache_off_pipeline",
|
||||
use_pipeline_cache=False,
|
||||
):
|
||||
pipeline_results_1.append(result)
|
||||
|
||||
first_run_count = counter.count
|
||||
assert first_run_count >= 1, "Task should have executed at least once on first run"
|
||||
|
||||
# Second run with pipeline_cache=False
|
||||
pipeline_results_2 = []
|
||||
async for result in run_pipeline(
|
||||
tasks=tasks,
|
||||
datasets="test_dataset_cache_off",
|
||||
data=["sample data"], # Data is necessary to trigger processing
|
||||
user=user,
|
||||
pipeline_name="test_cache_off_pipeline",
|
||||
use_pipeline_cache=False,
|
||||
):
|
||||
pipeline_results_2.append(result)
|
||||
|
||||
second_run_count = counter.count
|
||||
assert second_run_count > first_run_count, (
|
||||
f"With pipeline_cache=False, task should re-execute. "
|
||||
f"First run: {first_run_count}, After second run: {second_run_count}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_pipeline_status_allows_reexecution_with_cache(self):
|
||||
"""
|
||||
Test that resetting pipeline status allows re-execution even with
|
||||
pipeline_cache=True.
|
||||
"""
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
await create_db_and_tables()
|
||||
|
||||
counter = ExecutionCounter()
|
||||
user = await get_default_user()
|
||||
dataset_name = "reset_status_test"
|
||||
pipeline_name = "test_reset_pipeline"
|
||||
|
||||
tasks = [Task(create_counting_task, counter=counter)]
|
||||
|
||||
# First run
|
||||
pipeline_result = []
|
||||
async for result in run_pipeline(
|
||||
tasks=tasks,
|
||||
datasets=dataset_name,
|
||||
user=user,
|
||||
data=["sample data"], # Data is necessary to trigger processing
|
||||
pipeline_name=pipeline_name,
|
||||
use_pipeline_cache=True,
|
||||
):
|
||||
pipeline_result.append(result)
|
||||
|
||||
first_run_count = counter.count
|
||||
assert first_run_count >= 1
|
||||
|
||||
# Second run without reset - should skip
|
||||
async for _ in run_pipeline(
|
||||
tasks=tasks,
|
||||
datasets=dataset_name,
|
||||
user=user,
|
||||
data=["sample data"], # Data is necessary to trigger processing
|
||||
pipeline_name=pipeline_name,
|
||||
use_pipeline_cache=True,
|
||||
):
|
||||
pass
|
||||
|
||||
after_second_run = counter.count
|
||||
assert after_second_run == first_run_count, "Should have skipped due to cache"
|
||||
|
||||
# Reset the pipeline status
|
||||
await reset_dataset_pipeline_run_status(
|
||||
pipeline_result[0].dataset_id, user, pipeline_names=[pipeline_name]
|
||||
)
|
||||
|
||||
# Third run after reset - should execute
|
||||
async for _ in run_pipeline(
|
||||
tasks=tasks,
|
||||
datasets=dataset_name,
|
||||
user=user,
|
||||
data=["sample data"], # Data is necessary to trigger processing
|
||||
pipeline_name=pipeline_name,
|
||||
use_pipeline_cache=True,
|
||||
):
|
||||
pass
|
||||
|
||||
after_reset_run = counter.count
|
||||
assert after_reset_run > after_second_run, (
|
||||
f"After reset, pipeline should re-execute. "
|
||||
f"Before reset: {after_second_run}, After reset run: {after_reset_run}"
|
||||
)
|
||||
|
|
@ -32,16 +32,13 @@ async def main():
|
|||
print("Cognify process steps:")
|
||||
print("1. Classifying the document: Determining the type and category of the input text.")
|
||||
print(
|
||||
"2. Checking permissions: Ensuring the user has the necessary rights to process the text."
|
||||
"2. Extracting text chunks: Breaking down the text into sentences or phrases for analysis."
|
||||
)
|
||||
print(
|
||||
"3. Extracting text chunks: Breaking down the text into sentences or phrases for analysis."
|
||||
"3. Generating knowledge graph: Extracting entities and relationships to form a knowledge graph."
|
||||
)
|
||||
print("4. Adding data points: Storing the extracted chunks for processing.")
|
||||
print(
|
||||
"5. Generating knowledge graph: Extracting entities and relationships to form a knowledge graph."
|
||||
)
|
||||
print("6. Summarizing text: Creating concise summaries of the content for quick insights.\n")
|
||||
print("4. Summarizing text: Creating concise summaries of the content for quick insights.")
|
||||
print("5. Adding data points: Storing the extracted chunks for processing.\n")
|
||||
|
||||
# Use LLMs and cognee to create knowledge graph
|
||||
await cognee.cognify()
|
||||
|
|
|
|||
4
notebooks/cognee_demo.ipynb
vendored
4
notebooks/cognee_demo.ipynb
vendored
|
|
@ -591,7 +591,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": null,
|
||||
"id": "7c431fdef4921ae0",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
|
|
@ -609,7 +609,6 @@
|
|||
"from cognee.modules.pipelines import run_tasks\n",
|
||||
"from cognee.modules.users.models import User\n",
|
||||
"from cognee.tasks.documents import (\n",
|
||||
" check_permissions_on_dataset,\n",
|
||||
" classify_documents,\n",
|
||||
" extract_chunks_from_documents,\n",
|
||||
")\n",
|
||||
|
|
@ -627,7 +626,6 @@
|
|||
"\n",
|
||||
" tasks = [\n",
|
||||
" Task(classify_documents),\n",
|
||||
" Task(check_permissions_on_dataset, user=user, permissions=[\"write\"]),\n",
|
||||
" Task(\n",
|
||||
" extract_chunks_from_documents, max_chunk_size=get_max_chunk_tokens()\n",
|
||||
" ), # Extract text chunks based on the document type.\n",
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue