Merge branch 'dev' into add-s3-permissions-test
This commit is contained in:
commit
0cde551226
44 changed files with 1461 additions and 179 deletions
|
|
@ -97,6 +97,8 @@ DB_NAME=cognee_db
|
||||||
|
|
||||||
# Default (local file-based)
|
# Default (local file-based)
|
||||||
GRAPH_DATABASE_PROVIDER="kuzu"
|
GRAPH_DATABASE_PROVIDER="kuzu"
|
||||||
|
# Handler for multi-user access control mode: controls how a separate database is mapped/created for each Cognee dataset
|
||||||
|
GRAPH_DATASET_DATABASE_HANDLER="kuzu"
|
||||||
|
|
||||||
# -- To switch to Remote Kuzu uncomment and fill these: -------------------------------------------------------------
|
# -- To switch to Remote Kuzu uncomment and fill these: -------------------------------------------------------------
|
||||||
#GRAPH_DATABASE_PROVIDER="kuzu"
|
#GRAPH_DATABASE_PROVIDER="kuzu"
|
||||||
|
|
@ -121,6 +123,8 @@ VECTOR_DB_PROVIDER="lancedb"
|
||||||
# Not needed if a cloud vector database is not used
|
# Not needed if a cloud vector database is not used
|
||||||
VECTOR_DB_URL=
|
VECTOR_DB_URL=
|
||||||
VECTOR_DB_KEY=
|
VECTOR_DB_KEY=
|
||||||
|
# Handler for multi-user access control mode: controls how a separate database is mapped/created for each Cognee dataset
|
||||||
|
VECTOR_DATASET_DATABASE_HANDLER="lancedb"
|
||||||
|
|
||||||
################################################################################
|
################################################################################
|
||||||
# 🧩 Ontology resolver settings
|
# 🧩 Ontology resolver settings
|
||||||
|
|
|
||||||
2
.github/workflows/db_examples_tests.yml
vendored
2
.github/workflows/db_examples_tests.yml
vendored
|
|
@ -61,6 +61,7 @@ jobs:
|
||||||
- name: Run Neo4j Example
|
- name: Run Neo4j Example
|
||||||
env:
|
env:
|
||||||
ENV: 'dev'
|
ENV: 'dev'
|
||||||
|
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
|
|
@ -142,6 +143,7 @@ jobs:
|
||||||
- name: Run PGVector Example
|
- name: Run PGVector Example
|
||||||
env:
|
env:
|
||||||
ENV: 'dev'
|
ENV: 'dev'
|
||||||
|
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
|
|
|
||||||
1
.github/workflows/distributed_test.yml
vendored
1
.github/workflows/distributed_test.yml
vendored
|
|
@ -47,6 +47,7 @@ jobs:
|
||||||
- name: Run Distributed Cognee (Modal)
|
- name: Run Distributed Cognee (Modal)
|
||||||
env:
|
env:
|
||||||
ENV: 'dev'
|
ENV: 'dev'
|
||||||
|
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
|
|
|
||||||
53
.github/workflows/e2e_tests.yml
vendored
53
.github/workflows/e2e_tests.yml
vendored
|
|
@ -147,6 +147,7 @@ jobs:
|
||||||
- name: Run Deduplication Example
|
- name: Run Deduplication Example
|
||||||
env:
|
env:
|
||||||
ENV: 'dev'
|
ENV: 'dev'
|
||||||
|
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }} # Test needs OpenAI endpoint to handle multimedia
|
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }} # Test needs OpenAI endpoint to handle multimedia
|
||||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||||
|
|
@ -211,6 +212,31 @@ jobs:
|
||||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||||
run: uv run python ./cognee/tests/test_parallel_databases.py
|
run: uv run python ./cognee/tests/test_parallel_databases.py
|
||||||
|
|
||||||
|
# Job: runs cognee/tests/test_dataset_database_handler.py on ubuntu-22.04.
test-dataset-database-handler:
  name: Test dataset database handlers in Cognee
  runs-on: ubuntu-22.04
  steps:
    - name: Check out repository
      uses: actions/checkout@v4

    # Local composite action; Python version pinned to the 3.11 series.
    - name: Cognee Setup
      uses: ./.github/actions/cognee_setup
      with:
        python-version: '3.11.x'

    - name: Run dataset databases handler test
      env:
        ENV: 'dev'
        # LLM settings, read from the `secrets` context.
        LLM_MODEL: ${{ secrets.LLM_MODEL }}
        LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
        LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
        LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
        # Embedding settings, read from the `secrets` context.
        EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
        EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
        EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
        EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
      run: uv run python ./cognee/tests/test_dataset_database_handler.py
|
||||||
|
|
||||||
test-permissions:
|
test-permissions:
|
||||||
name: Test permissions with different situations in Cognee
|
name: Test permissions with different situations in Cognee
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
|
|
@ -556,3 +582,30 @@ jobs:
|
||||||
DB_USERNAME: cognee
|
DB_USERNAME: cognee
|
||||||
DB_PASSWORD: cognee
|
DB_PASSWORD: cognee
|
||||||
run: uv run python ./cognee/tests/test_conversation_history.py
|
run: uv run python ./cognee/tests/test_conversation_history.py
|
||||||
|
|
||||||
|
# Job: runs cognee/tests/test_pipeline_cache.py on ubuntu-22.04.
run-pipeline-cache-test:
  name: Test Pipeline Caching
  runs-on: ubuntu-22.04
  steps:
    - name: Check out
      uses: actions/checkout@v4
      with:
        fetch-depth: 0

    # Local composite action; Python version pinned to the 3.11 series.
    - name: Cognee Setup
      uses: ./.github/actions/cognee_setup
      with:
        python-version: '3.11.x'

    - name: Run Pipeline Cache Test
      env:
        ENV: 'dev'
        # LLM settings, read from the `secrets` context.
        LLM_MODEL: ${{ secrets.LLM_MODEL }}
        LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
        LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
        LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
        # Embedding settings, read from the `secrets` context.
        EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
        EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
        EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
        EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
      run: uv run python ./cognee/tests/test_pipeline_cache.py
|
||||||
|
|
|
||||||
1
.github/workflows/examples_tests.yml
vendored
1
.github/workflows/examples_tests.yml
vendored
|
|
@ -72,6 +72,7 @@ jobs:
|
||||||
- name: Run Descriptive Graph Metrics Example
|
- name: Run Descriptive Graph Metrics Example
|
||||||
env:
|
env:
|
||||||
ENV: 'dev'
|
ENV: 'dev'
|
||||||
|
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
|
|
|
||||||
1
.github/workflows/graph_db_tests.yml
vendored
1
.github/workflows/graph_db_tests.yml
vendored
|
|
@ -78,6 +78,7 @@ jobs:
|
||||||
- name: Run default Neo4j
|
- name: Run default Neo4j
|
||||||
env:
|
env:
|
||||||
ENV: 'dev'
|
ENV: 'dev'
|
||||||
|
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
|
|
|
||||||
3
.github/workflows/temporal_graph_tests.yml
vendored
3
.github/workflows/temporal_graph_tests.yml
vendored
|
|
@ -72,6 +72,7 @@ jobs:
|
||||||
- name: Run Temporal Graph with Neo4j (lancedb + sqlite)
|
- name: Run Temporal Graph with Neo4j (lancedb + sqlite)
|
||||||
env:
|
env:
|
||||||
ENV: 'dev'
|
ENV: 'dev'
|
||||||
|
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||||
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
|
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
|
||||||
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
|
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
|
||||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
|
|
@ -123,6 +124,7 @@ jobs:
|
||||||
- name: Run Temporal Graph with Kuzu (postgres + pgvector)
|
- name: Run Temporal Graph with Kuzu (postgres + pgvector)
|
||||||
env:
|
env:
|
||||||
ENV: dev
|
ENV: dev
|
||||||
|
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||||
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
|
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
|
||||||
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
|
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
|
||||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
|
|
@ -189,6 +191,7 @@ jobs:
|
||||||
- name: Run Temporal Graph with Neo4j (postgres + pgvector)
|
- name: Run Temporal Graph with Neo4j (postgres + pgvector)
|
||||||
env:
|
env:
|
||||||
ENV: dev
|
ENV: dev
|
||||||
|
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||||
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
|
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
|
||||||
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
|
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
|
||||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
|
|
|
||||||
3
.github/workflows/vector_db_tests.yml
vendored
3
.github/workflows/vector_db_tests.yml
vendored
|
|
@ -92,6 +92,7 @@ jobs:
|
||||||
- name: Run PGVector Tests
|
- name: Run PGVector Tests
|
||||||
env:
|
env:
|
||||||
ENV: 'dev'
|
ENV: 'dev'
|
||||||
|
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
|
|
@ -127,4 +128,4 @@ jobs:
|
||||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||||
run: uv run python ./cognee/tests/test_lancedb.py
|
run: uv run python ./cognee/tests/test_lancedb.py
|
||||||
|
|
|
||||||
3
.github/workflows/weighted_edges_tests.yml
vendored
3
.github/workflows/weighted_edges_tests.yml
vendored
|
|
@ -94,6 +94,7 @@ jobs:
|
||||||
|
|
||||||
- name: Run Weighted Edges Tests
|
- name: Run Weighted Edges Tests
|
||||||
env:
|
env:
|
||||||
|
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||||
GRAPH_DATABASE_PROVIDER: ${{ matrix.graph_db_provider }}
|
GRAPH_DATABASE_PROVIDER: ${{ matrix.graph_db_provider }}
|
||||||
GRAPH_DATABASE_URL: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-url || '' }}
|
GRAPH_DATABASE_URL: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-url || '' }}
|
||||||
GRAPH_DATABASE_USERNAME: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-username || '' }}
|
GRAPH_DATABASE_USERNAME: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-username || '' }}
|
||||||
|
|
@ -165,5 +166,3 @@ jobs:
|
||||||
uses: astral-sh/ruff-action@v2
|
uses: astral-sh/ruff-action@v2
|
||||||
with:
|
with:
|
||||||
args: "format --check cognee/modules/graph/utils/get_graph_from_model.py cognee/tests/unit/interfaces/graph/test_weighted_edges.py examples/python/weighted_edges_example.py"
|
args: "format --check cognee/modules/graph/utils/get_graph_from_model.py cognee/tests/unit/interfaces/graph/test_weighted_edges.py examples/python/weighted_edges_example.py"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1,333 @@
|
||||||
|
"""Expand dataset database with json connection field
|
||||||
|
|
||||||
|
Revision ID: 46a6ce2bd2b2
|
||||||
|
Revises: 76625596c5c3
|
||||||
|
Create Date: 2025-11-25 17:56:28.938931
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# Alembic revision identifiers: this migration's id and the revision it follows.
revision: str = "46a6ce2bd2b2"
down_revision: Union[str, None] = "76625596c5c3"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

# Names of the unique constraints on the database-name columns that upgrade()
# drops and downgrade() re-creates on PostgreSQL.
graph_constraint_name = "dataset_database_graph_database_name_key"
vector_constraint_name = "dataset_database_vector_database_name_key"

# Table modified by this migration.
TABLE_NAME = "dataset_database"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_column(inspector, table, name, schema=None):
|
||||||
|
for col in inspector.get_columns(table, schema=schema):
|
||||||
|
if col["name"] == name:
|
||||||
|
return col
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _recreate_table_without_unique_constraint_sqlite(op, insp):
    """Rebuild the table on SQLite without unique constraints on the database names.

    SQLite cannot drop unique constraints on individual columns. We must:
    1. Create ``<TABLE_NAME>_new`` with the same columns, but with no unique
       constraint on ``vector_database_name`` / ``graph_database_name``.
    2. Copy every row from the old table (NULL connection-info values are
       replaced by ``'{}'`` because the new columns are NOT NULL).
    3. Drop the old table.
    4. Rename the new table to ``TABLE_NAME``.

    Args:
        op: Alembic operations object used to obtain the connection and issue DDL.
        insp: Not used by this function; kept so existing callers keep working.
    """
    conn = op.get_bind()

    # Create new table definition (without unique constraints)
    op.create_table(
        f"{TABLE_NAME}_new",
        sa.Column("owner_id", sa.UUID()),
        sa.Column("dataset_id", sa.UUID(), primary_key=True, nullable=False),
        sa.Column("vector_database_name", sa.String(), nullable=False),
        sa.Column("graph_database_name", sa.String(), nullable=False),
        sa.Column("vector_database_provider", sa.String(), nullable=False),
        sa.Column("graph_database_provider", sa.String(), nullable=False),
        sa.Column(
            "vector_dataset_database_handler",
            sa.String(),
            unique=False,
            nullable=False,
            server_default="lancedb",
        ),
        sa.Column(
            "graph_dataset_database_handler",
            sa.String(),
            unique=False,
            nullable=False,
            server_default="kuzu",
        ),
        sa.Column("vector_database_url", sa.String()),
        sa.Column("graph_database_url", sa.String()),
        sa.Column("vector_database_key", sa.String()),
        sa.Column("graph_database_key", sa.String()),
        sa.Column(
            "graph_database_connection_info",
            sa.JSON(),
            nullable=False,
            server_default=sa.text("'{}'"),
        ),
        sa.Column(
            "vector_database_connection_info",
            sa.JSON(),
            nullable=False,
            server_default=sa.text("'{}'"),
        ),
        sa.Column("created_at", sa.DateTime()),
        sa.Column("updated_at", sa.DateTime()),
        sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"], ondelete="CASCADE"),
        sa.ForeignKeyConstraint(["owner_id"], ["principals.id"], ondelete="CASCADE"),
    )

    # Copy data into the new table. The target columns are named explicitly so
    # the copy does not depend on the column order of the table definition above.
    conn.execute(
        sa.text(f"""
            INSERT INTO {TABLE_NAME}_new (
                owner_id,
                dataset_id,
                vector_database_name,
                graph_database_name,
                vector_database_provider,
                graph_database_provider,
                vector_dataset_database_handler,
                graph_dataset_database_handler,
                vector_database_url,
                graph_database_url,
                vector_database_key,
                graph_database_key,
                graph_database_connection_info,
                vector_database_connection_info,
                created_at,
                updated_at
            )
            SELECT
                owner_id,
                dataset_id,
                vector_database_name,
                graph_database_name,
                vector_database_provider,
                graph_database_provider,
                vector_dataset_database_handler,
                graph_dataset_database_handler,
                vector_database_url,
                graph_database_url,
                vector_database_key,
                graph_database_key,
                COALESCE(graph_database_connection_info, '{{}}'),
                COALESCE(vector_database_connection_info, '{{}}'),
                created_at,
                updated_at
            FROM {TABLE_NAME}
        """)
    )

    # Drop old table
    op.drop_table(TABLE_NAME)

    # Rename new table
    op.rename_table(f"{TABLE_NAME}_new", TABLE_NAME)
|
||||||
|
|
||||||
|
|
||||||
|
def _recreate_table_with_unique_constraint_sqlite(op, insp):
    """Rebuild the table on SQLite with unique constraints on the database names.

    Used when downgrading. The table is rebuilt rather than altered in place:
    1. Create ``<TABLE_NAME>_new`` with the same columns, with ``unique=True``
       on ``vector_database_name`` and ``graph_database_name``.
    2. Copy every row from the old table (NULL connection-info values are
       replaced by ``'{}'`` because the new columns are NOT NULL).
    3. Drop the old table.
    4. Rename the new table to ``TABLE_NAME``.

    Args:
        op: Alembic operations object used to obtain the connection and issue DDL.
        insp: Not used by this function; kept so existing callers keep working.
    """
    conn = op.get_bind()

    # Create new table definition (with unique constraints on the name columns)
    op.create_table(
        f"{TABLE_NAME}_new",
        sa.Column("owner_id", sa.UUID()),
        sa.Column("dataset_id", sa.UUID(), primary_key=True, nullable=False),
        sa.Column("vector_database_name", sa.String(), nullable=False, unique=True),
        sa.Column("graph_database_name", sa.String(), nullable=False, unique=True),
        sa.Column("vector_database_provider", sa.String(), nullable=False),
        sa.Column("graph_database_provider", sa.String(), nullable=False),
        sa.Column(
            "vector_dataset_database_handler",
            sa.String(),
            unique=False,
            nullable=False,
            server_default="lancedb",
        ),
        sa.Column(
            "graph_dataset_database_handler",
            sa.String(),
            unique=False,
            nullable=False,
            server_default="kuzu",
        ),
        sa.Column("vector_database_url", sa.String()),
        sa.Column("graph_database_url", sa.String()),
        sa.Column("vector_database_key", sa.String()),
        sa.Column("graph_database_key", sa.String()),
        sa.Column(
            "graph_database_connection_info",
            sa.JSON(),
            nullable=False,
            server_default=sa.text("'{}'"),
        ),
        sa.Column(
            "vector_database_connection_info",
            sa.JSON(),
            nullable=False,
            server_default=sa.text("'{}'"),
        ),
        sa.Column("created_at", sa.DateTime()),
        sa.Column("updated_at", sa.DateTime()),
        sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"], ondelete="CASCADE"),
        sa.ForeignKeyConstraint(["owner_id"], ["principals.id"], ondelete="CASCADE"),
    )

    # Copy data into the new table. The target columns are named explicitly so
    # the copy does not depend on the column order of the table definition above.
    conn.execute(
        sa.text(f"""
            INSERT INTO {TABLE_NAME}_new (
                owner_id,
                dataset_id,
                vector_database_name,
                graph_database_name,
                vector_database_provider,
                graph_database_provider,
                vector_dataset_database_handler,
                graph_dataset_database_handler,
                vector_database_url,
                graph_database_url,
                vector_database_key,
                graph_database_key,
                graph_database_connection_info,
                vector_database_connection_info,
                created_at,
                updated_at
            )
            SELECT
                owner_id,
                dataset_id,
                vector_database_name,
                graph_database_name,
                vector_database_provider,
                graph_database_provider,
                vector_dataset_database_handler,
                graph_dataset_database_handler,
                vector_database_url,
                graph_database_url,
                vector_database_key,
                graph_database_key,
                COALESCE(graph_database_connection_info, '{{}}'),
                COALESCE(vector_database_connection_info, '{{}}'),
                created_at,
                updated_at
            FROM {TABLE_NAME}
        """)
    )

    # Drop old table
    op.drop_table(TABLE_NAME)

    # Rename new table
    op.rename_table(f"{TABLE_NAME}_new", TABLE_NAME)
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Add connection-info/handler columns and drop the database-name unique constraints.

    1. Adds, when missing, the columns ``vector_database_connection_info`` and
       ``graph_database_connection_info`` (JSON, NOT NULL, default ``'{}'``) and
       ``vector_dataset_database_handler`` / ``graph_dataset_database_handler``
       (String, NOT NULL, defaults ``lancedb`` / ``kuzu``).
    2. On PostgreSQL, drops the two named unique constraints when both exist.
    3. On SQLite, rebuilds the table without the unique constraints when a
       unique-constraint index on ``vector_database_name`` is found.
    """
    conn = op.get_bind()
    insp = sa.inspect(conn)
    dialect_name = op.get_context().dialect.name

    unique_constraints = insp.get_unique_constraints(TABLE_NAME)

    # Columns introduced by this revision, in the order they are added.
    new_columns = (
        sa.Column(
            "vector_database_connection_info",
            sa.JSON(),
            unique=False,
            nullable=False,
            server_default=sa.text("'{}'"),
        ),
        # LanceDB is the default vector dataset database handler
        sa.Column(
            "vector_dataset_database_handler",
            sa.String(),
            unique=False,
            nullable=False,
            server_default="lancedb",
        ),
        sa.Column(
            "graph_database_connection_info",
            sa.JSON(),
            unique=False,
            nullable=False,
            server_default=sa.text("'{}'"),
        ),
        # Kuzu is the default graph dataset database handler
        sa.Column(
            "graph_dataset_database_handler",
            sa.String(),
            unique=False,
            nullable=False,
            server_default="kuzu",
        ),
    )
    for column in new_columns:
        # Only add a column that is not already present on the table.
        if not _get_column(insp, TABLE_NAME, column.name):
            op.add_column(TABLE_NAME, column)

    # Constraints are matched by name (not by the columns they cover).
    existing_constraint_names = {uc["name"] for uc in unique_constraints}
    if (
        dialect_name == "postgresql"
        and vector_constraint_name in existing_constraint_names
        and graph_constraint_name in existing_constraint_names
    ):
        # PostgreSQL: drop the unique constraints to make unique=False
        with op.batch_alter_table(TABLE_NAME, schema=None) as batch_op:
            batch_op.drop_constraint(graph_constraint_name, type_="unique")
            batch_op.drop_constraint(vector_constraint_name, type_="unique")

    if dialect_name == "sqlite":
        # SQLite has hidden auto indexes for unique constraints that can't be
        # dropped or accessed directly, so we look for them and remove them by
        # recreating the table (altering the column also won't work).
        index_rows = conn.execute(sa.text(f"PRAGMA index_list('{TABLE_NAME}')")).fetchall()
        for row in index_rows:
            # row[3] is the index origin; "u" marks a UNIQUE-constraint index.
            if row[3] != "u":
                continue
            index_info = conn.execute(sa.text(f"PRAGMA index_info('{row[1]}')")).fetchall()
            # index_info[0][2] is the name of the first indexed column.
            if index_info and index_info[0][2] == "vector_database_name":
                # A unique index exists on vector_database_name: rebuild the
                # table, which removes it and the graph_database_name one.
                _recreate_table_without_unique_constraint_sqlite(op, insp)
                # The listed indexes no longer exist after the rebuild, so stop
                # here instead of querying them again or rebuilding twice.
                break
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Restore the database-name unique constraints and drop the added columns.

    On SQLite the table is rebuilt with the unique constraints; on PostgreSQL
    the two named unique constraints are re-created. Afterwards the four
    columns added by ``upgrade()`` are dropped.
    """
    conn = op.get_bind()
    insp = sa.inspect(conn)
    dialect_name = op.get_context().dialect.name

    if dialect_name == "sqlite":
        _recreate_table_with_unique_constraint_sqlite(op, insp)
    elif dialect_name == "postgresql":
        # Re-add both unique constraints to return to unique=True
        with op.batch_alter_table("dataset_database", schema=None) as batch_op:
            batch_op.create_unique_constraint(graph_constraint_name, ["graph_database_name"])
            batch_op.create_unique_constraint(vector_constraint_name, ["vector_database_name"])

    # Remove the columns introduced by this revision, in the original order.
    for added_column in (
        "vector_database_connection_info",
        "graph_database_connection_info",
        "vector_dataset_database_handler",
        "graph_dataset_database_handler",
    ):
        op.drop_column("dataset_database", added_column)
|
||||||
|
|
@ -205,6 +205,7 @@ async def add(
|
||||||
pipeline_name="add_pipeline",
|
pipeline_name="add_pipeline",
|
||||||
vector_db_config=vector_db_config,
|
vector_db_config=vector_db_config,
|
||||||
graph_db_config=graph_db_config,
|
graph_db_config=graph_db_config,
|
||||||
|
use_pipeline_cache=True,
|
||||||
incremental_loading=incremental_loading,
|
incremental_loading=incremental_loading,
|
||||||
data_per_batch=data_per_batch,
|
data_per_batch=data_per_batch,
|
||||||
):
|
):
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,6 @@ from cognee.modules.ontology.get_default_ontology_resolver import (
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
|
|
||||||
from cognee.tasks.documents import (
|
from cognee.tasks.documents import (
|
||||||
check_permissions_on_dataset,
|
|
||||||
classify_documents,
|
classify_documents,
|
||||||
extract_chunks_from_documents,
|
extract_chunks_from_documents,
|
||||||
)
|
)
|
||||||
|
|
@ -79,12 +78,11 @@ async def cognify(
|
||||||
|
|
||||||
Processing Pipeline:
|
Processing Pipeline:
|
||||||
1. **Document Classification**: Identifies document types and structures
|
1. **Document Classification**: Identifies document types and structures
|
||||||
2. **Permission Validation**: Ensures user has processing rights
|
2. **Text Chunking**: Breaks content into semantically meaningful segments
|
||||||
3. **Text Chunking**: Breaks content into semantically meaningful segments
|
3. **Entity Extraction**: Identifies key concepts, people, places, organizations
|
||||||
4. **Entity Extraction**: Identifies key concepts, people, places, organizations
|
4. **Relationship Detection**: Discovers connections between entities
|
||||||
5. **Relationship Detection**: Discovers connections between entities
|
5. **Graph Construction**: Builds semantic knowledge graph with embeddings
|
||||||
6. **Graph Construction**: Builds semantic knowledge graph with embeddings
|
6. **Content Summarization**: Creates hierarchical summaries for navigation
|
||||||
7. **Content Summarization**: Creates hierarchical summaries for navigation
|
|
||||||
|
|
||||||
Graph Model Customization:
|
Graph Model Customization:
|
||||||
The `graph_model` parameter allows custom knowledge structures:
|
The `graph_model` parameter allows custom knowledge structures:
|
||||||
|
|
@ -239,6 +237,7 @@ async def cognify(
|
||||||
vector_db_config=vector_db_config,
|
vector_db_config=vector_db_config,
|
||||||
graph_db_config=graph_db_config,
|
graph_db_config=graph_db_config,
|
||||||
incremental_loading=incremental_loading,
|
incremental_loading=incremental_loading,
|
||||||
|
use_pipeline_cache=True,
|
||||||
pipeline_name="cognify_pipeline",
|
pipeline_name="cognify_pipeline",
|
||||||
data_per_batch=data_per_batch,
|
data_per_batch=data_per_batch,
|
||||||
)
|
)
|
||||||
|
|
@ -278,7 +277,6 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
|
||||||
|
|
||||||
default_tasks = [
|
default_tasks = [
|
||||||
Task(classify_documents),
|
Task(classify_documents),
|
||||||
Task(check_permissions_on_dataset, user=user, permissions=["write"]),
|
|
||||||
Task(
|
Task(
|
||||||
extract_chunks_from_documents,
|
extract_chunks_from_documents,
|
||||||
max_chunk_size=chunk_size or get_max_chunk_tokens(),
|
max_chunk_size=chunk_size or get_max_chunk_tokens(),
|
||||||
|
|
@ -313,14 +311,13 @@ async def get_temporal_tasks(
|
||||||
|
|
||||||
The pipeline includes:
|
The pipeline includes:
|
||||||
1. Document classification.
|
1. Document classification.
|
||||||
2. Dataset permission checks (requires "write" access).
|
2. Document chunking with a specified or default chunk size.
|
||||||
3. Document chunking with a specified or default chunk size.
|
3. Event and timestamp extraction from chunks.
|
||||||
4. Event and timestamp extraction from chunks.
|
4. Knowledge graph extraction from events.
|
||||||
5. Knowledge graph extraction from events.
|
5. Batched insertion of data points.
|
||||||
6. Batched insertion of data points.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user (User, optional): The user requesting task execution, used for permission checks.
|
user (User, optional): The user requesting task execution.
|
||||||
chunker (Callable, optional): A text chunking function/class to split documents. Defaults to TextChunker.
|
chunker (Callable, optional): A text chunking function/class to split documents. Defaults to TextChunker.
|
||||||
chunk_size (int, optional): Maximum token size per chunk. If not provided, uses system default.
|
chunk_size (int, optional): Maximum token size per chunk. If not provided, uses system default.
|
||||||
chunks_per_batch (int, optional): Number of chunks to process in a single batch in Cognify
|
chunks_per_batch (int, optional): Number of chunks to process in a single batch in Cognify
|
||||||
|
|
@ -333,7 +330,6 @@ async def get_temporal_tasks(
|
||||||
|
|
||||||
temporal_tasks = [
|
temporal_tasks = [
|
||||||
Task(classify_documents),
|
Task(classify_documents),
|
||||||
Task(check_permissions_on_dataset, user=user, permissions=["write"]),
|
|
||||||
Task(
|
Task(
|
||||||
extract_chunks_from_documents,
|
extract_chunks_from_documents,
|
||||||
max_chunk_size=chunk_size or get_max_chunk_tokens(),
|
max_chunk_size=chunk_size or get_max_chunk_tokens(),
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from pathlib import Path
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from fastapi import UploadFile
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -45,8 +46,10 @@ class OntologyService:
|
||||||
json.dump(metadata, f, indent=2)
|
json.dump(metadata, f, indent=2)
|
||||||
|
|
||||||
async def upload_ontology(
|
async def upload_ontology(
|
||||||
self, ontology_key: str, file, user, description: Optional[str] = None
|
self, ontology_key: str, file: UploadFile, user, description: Optional[str] = None
|
||||||
) -> OntologyMetadata:
|
) -> OntologyMetadata:
|
||||||
|
if not file.filename:
|
||||||
|
raise ValueError("File must have a filename")
|
||||||
if not file.filename.lower().endswith(".owl"):
|
if not file.filename.lower().endswith(".owl"):
|
||||||
raise ValueError("File must be in .owl format")
|
raise ValueError("File must be in .owl format")
|
||||||
|
|
||||||
|
|
@ -57,8 +60,6 @@ class OntologyService:
|
||||||
raise ValueError(f"Ontology key '{ontology_key}' already exists")
|
raise ValueError(f"Ontology key '{ontology_key}' already exists")
|
||||||
|
|
||||||
content = await file.read()
|
content = await file.read()
|
||||||
if len(content) > 10 * 1024 * 1024:
|
|
||||||
raise ValueError("File size exceeds 10MB limit")
|
|
||||||
|
|
||||||
file_path = user_dir / f"{ontology_key}.owl"
|
file_path = user_dir / f"{ontology_key}.owl"
|
||||||
with open(file_path, "wb") as f:
|
with open(file_path, "wb") as f:
|
||||||
|
|
@ -82,7 +83,11 @@ class OntologyService:
|
||||||
)
|
)
|
||||||
|
|
||||||
async def upload_ontologies(
|
async def upload_ontologies(
|
||||||
self, ontology_key: List[str], files: List, user, descriptions: Optional[List[str]] = None
|
self,
|
||||||
|
ontology_key: List[str],
|
||||||
|
files: List[UploadFile],
|
||||||
|
user,
|
||||||
|
descriptions: Optional[List[str]] = None,
|
||||||
) -> List[OntologyMetadata]:
|
) -> List[OntologyMetadata]:
|
||||||
"""
|
"""
|
||||||
Upload ontology files with their respective keys.
|
Upload ontology files with their respective keys.
|
||||||
|
|
@ -105,47 +110,17 @@ class OntologyService:
|
||||||
if len(set(ontology_key)) != len(ontology_key):
|
if len(set(ontology_key)) != len(ontology_key):
|
||||||
raise ValueError("Duplicate ontology keys not allowed")
|
raise ValueError("Duplicate ontology keys not allowed")
|
||||||
|
|
||||||
if descriptions and len(descriptions) != len(files):
|
|
||||||
raise ValueError("Number of descriptions must match number of files")
|
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
user_dir = self._get_user_dir(str(user.id))
|
|
||||||
metadata = self._load_metadata(user_dir)
|
|
||||||
|
|
||||||
for i, (key, file) in enumerate(zip(ontology_key, files)):
|
for i, (key, file) in enumerate(zip(ontology_key, files)):
|
||||||
if key in metadata:
|
|
||||||
raise ValueError(f"Ontology key '{key}' already exists")
|
|
||||||
|
|
||||||
if not file.filename.lower().endswith(".owl"):
|
|
||||||
raise ValueError(f"File '{file.filename}' must be in .owl format")
|
|
||||||
|
|
||||||
content = await file.read()
|
|
||||||
if len(content) > 10 * 1024 * 1024:
|
|
||||||
raise ValueError(f"File '{file.filename}' exceeds 10MB limit")
|
|
||||||
|
|
||||||
file_path = user_dir / f"{key}.owl"
|
|
||||||
with open(file_path, "wb") as f:
|
|
||||||
f.write(content)
|
|
||||||
|
|
||||||
ontology_metadata = {
|
|
||||||
"filename": file.filename,
|
|
||||||
"size_bytes": len(content),
|
|
||||||
"uploaded_at": datetime.now(timezone.utc).isoformat(),
|
|
||||||
"description": descriptions[i] if descriptions else None,
|
|
||||||
}
|
|
||||||
metadata[key] = ontology_metadata
|
|
||||||
|
|
||||||
results.append(
|
results.append(
|
||||||
OntologyMetadata(
|
await self.upload_ontology(
|
||||||
ontology_key=key,
|
ontology_key=key,
|
||||||
filename=file.filename,
|
file=file,
|
||||||
size_bytes=len(content),
|
user=user,
|
||||||
uploaded_at=ontology_metadata["uploaded_at"],
|
|
||||||
description=descriptions[i] if descriptions else None,
|
description=descriptions[i] if descriptions else None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self._save_metadata(user_dir, metadata)
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def get_ontology_contents(self, ontology_key: List[str], user) -> List[str]:
|
def get_ontology_contents(self, ontology_key: List[str], user) -> List[str]:
|
||||||
|
|
|
||||||
|
|
@ -4,9 +4,10 @@ from typing import Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from cognee.base_config import get_base_config
|
from cognee.base_config import get_base_config
|
||||||
from cognee.infrastructure.databases.vector.config import get_vectordb_context_config
|
from cognee.infrastructure.databases.vector.config import get_vectordb_config
|
||||||
from cognee.infrastructure.databases.graph.config import get_graph_context_config
|
from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||||
from cognee.infrastructure.databases.utils import get_or_create_dataset_database
|
from cognee.infrastructure.databases.utils import get_or_create_dataset_database
|
||||||
|
from cognee.infrastructure.databases.utils import resolve_dataset_database_connection_info
|
||||||
from cognee.infrastructure.files.storage.config import file_storage_config
|
from cognee.infrastructure.files.storage.config import file_storage_config
|
||||||
from cognee.modules.users.methods import get_user
|
from cognee.modules.users.methods import get_user
|
||||||
|
|
||||||
|
|
@ -16,22 +17,59 @@ vector_db_config = ContextVar("vector_db_config", default=None)
|
||||||
graph_db_config = ContextVar("graph_db_config", default=None)
|
graph_db_config = ContextVar("graph_db_config", default=None)
|
||||||
session_user = ContextVar("session_user", default=None)
|
session_user = ContextVar("session_user", default=None)
|
||||||
|
|
||||||
VECTOR_DBS_WITH_MULTI_USER_SUPPORT = ["lancedb", "falkor"]
|
|
||||||
GRAPH_DBS_WITH_MULTI_USER_SUPPORT = ["kuzu", "falkor"]
|
|
||||||
|
|
||||||
|
|
||||||
async def set_session_user_context_variable(user):
|
async def set_session_user_context_variable(user):
|
||||||
session_user.set(user)
|
session_user.set(user)
|
||||||
|
|
||||||
|
|
||||||
def multi_user_support_possible():
|
def multi_user_support_possible():
|
||||||
graph_db_config = get_graph_context_config()
|
graph_db_config = get_graph_config()
|
||||||
vector_db_config = get_vectordb_context_config()
|
vector_db_config = get_vectordb_config()
|
||||||
return (
|
|
||||||
graph_db_config["graph_database_provider"] in GRAPH_DBS_WITH_MULTI_USER_SUPPORT
|
graph_handler = graph_db_config.graph_dataset_database_handler
|
||||||
and vector_db_config["vector_db_provider"] in VECTOR_DBS_WITH_MULTI_USER_SUPPORT
|
vector_handler = vector_db_config.vector_dataset_database_handler
|
||||||
|
from cognee.infrastructure.databases.dataset_database_handler import (
|
||||||
|
supported_dataset_database_handlers,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if graph_handler not in supported_dataset_database_handlers:
|
||||||
|
raise EnvironmentError(
|
||||||
|
"Unsupported graph dataset to database handler configured. Cannot add support for multi-user access control mode. Please use a supported graph dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
|
||||||
|
f"Selected graph dataset to database handler: {graph_handler}\n"
|
||||||
|
f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
if vector_handler not in supported_dataset_database_handlers:
|
||||||
|
raise EnvironmentError(
|
||||||
|
"Unsupported vector dataset to database handler configured. Cannot add support for multi-user access control mode. Please use a supported vector dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
|
||||||
|
f"Selected vector dataset to database handler: {vector_handler}\n"
|
||||||
|
f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
supported_dataset_database_handlers[graph_handler]["handler_provider"]
|
||||||
|
!= graph_db_config.graph_database_provider
|
||||||
|
):
|
||||||
|
raise EnvironmentError(
|
||||||
|
"The selected graph dataset to database handler does not work with the configured graph database provider. Cannot add support for multi-user access control mode. Please use a supported graph dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
|
||||||
|
f"Selected graph database provider: {graph_db_config.graph_database_provider}\n"
|
||||||
|
f"Selected graph dataset to database handler: {graph_handler}\n"
|
||||||
|
f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
supported_dataset_database_handlers[vector_handler]["handler_provider"]
|
||||||
|
!= vector_db_config.vector_db_provider
|
||||||
|
):
|
||||||
|
raise EnvironmentError(
|
||||||
|
"The selected vector dataset to database handler does not work with the configured vector database provider. Cannot add support for multi-user access control mode. Please use a supported vector dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
|
||||||
|
f"Selected vector database provider: {vector_db_config.vector_db_provider}\n"
|
||||||
|
f"Selected vector dataset to database handler: {vector_handler}\n"
|
||||||
|
f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def backend_access_control_enabled():
|
def backend_access_control_enabled():
|
||||||
backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)
|
backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)
|
||||||
|
|
@ -41,12 +79,7 @@ def backend_access_control_enabled():
|
||||||
return multi_user_support_possible()
|
return multi_user_support_possible()
|
||||||
elif backend_access_control.lower() == "true":
|
elif backend_access_control.lower() == "true":
|
||||||
# If enabled, ensure that the current graph and vector DBs can support it
|
# If enabled, ensure that the current graph and vector DBs can support it
|
||||||
multi_user_support = multi_user_support_possible()
|
return multi_user_support_possible()
|
||||||
if not multi_user_support:
|
|
||||||
raise EnvironmentError(
|
|
||||||
"ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -76,6 +109,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
||||||
|
|
||||||
# To ensure permissions are enforced properly all datasets will have their own databases
|
# To ensure permissions are enforced properly all datasets will have their own databases
|
||||||
dataset_database = await get_or_create_dataset_database(dataset, user)
|
dataset_database = await get_or_create_dataset_database(dataset, user)
|
||||||
|
# Ensure that all connection info is resolved properly
|
||||||
|
dataset_database = await resolve_dataset_database_connection_info(dataset_database)
|
||||||
|
|
||||||
base_config = get_base_config()
|
base_config = get_base_config()
|
||||||
data_root_directory = os.path.join(
|
data_root_directory = os.path.join(
|
||||||
|
|
@ -86,6 +121,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set vector and graph database configuration based on dataset database information
|
# Set vector and graph database configuration based on dataset database information
|
||||||
|
# TODO: Add better handling of vector and graph config across Cognee.
|
||||||
|
# LRU_CACHE takes into account order of inputs, if order of inputs is changed it will be registered as a new DB adapter
|
||||||
vector_config = {
|
vector_config = {
|
||||||
"vector_db_provider": dataset_database.vector_database_provider,
|
"vector_db_provider": dataset_database.vector_database_provider,
|
||||||
"vector_db_url": dataset_database.vector_database_url,
|
"vector_db_url": dataset_database.vector_database_url,
|
||||||
|
|
@ -101,6 +138,14 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
||||||
"graph_file_path": os.path.join(
|
"graph_file_path": os.path.join(
|
||||||
databases_directory_path, dataset_database.graph_database_name
|
databases_directory_path, dataset_database.graph_database_name
|
||||||
),
|
),
|
||||||
|
"graph_database_username": dataset_database.graph_database_connection_info.get(
|
||||||
|
"graph_database_username", ""
|
||||||
|
),
|
||||||
|
"graph_database_password": dataset_database.graph_database_connection_info.get(
|
||||||
|
"graph_database_password", ""
|
||||||
|
),
|
||||||
|
"graph_dataset_database_handler": "",
|
||||||
|
"graph_database_port": "",
|
||||||
}
|
}
|
||||||
|
|
||||||
storage_config = {
|
storage_config = {
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,6 @@ from cognee.modules.users.models import User
|
||||||
from cognee.shared.data_models import KnowledgeGraph
|
from cognee.shared.data_models import KnowledgeGraph
|
||||||
from cognee.shared.utils import send_telemetry
|
from cognee.shared.utils import send_telemetry
|
||||||
from cognee.tasks.documents import (
|
from cognee.tasks.documents import (
|
||||||
check_permissions_on_dataset,
|
|
||||||
classify_documents,
|
classify_documents,
|
||||||
extract_chunks_from_documents,
|
extract_chunks_from_documents,
|
||||||
)
|
)
|
||||||
|
|
@ -31,7 +30,6 @@ async def get_cascade_graph_tasks(
|
||||||
cognee_config = get_cognify_config()
|
cognee_config = get_cognify_config()
|
||||||
default_tasks = [
|
default_tasks = [
|
||||||
Task(classify_documents),
|
Task(classify_documents),
|
||||||
Task(check_permissions_on_dataset, user=user, permissions=["write"]),
|
|
||||||
Task(
|
Task(
|
||||||
extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens()
|
extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens()
|
||||||
), # Extract text chunks based on the document type.
|
), # Extract text chunks based on the document type.
|
||||||
|
|
|
||||||
|
|
@ -30,8 +30,8 @@ async def get_no_summary_tasks(
|
||||||
ontology_file_path=None,
|
ontology_file_path=None,
|
||||||
) -> List[Task]:
|
) -> List[Task]:
|
||||||
"""Returns default tasks without summarization tasks."""
|
"""Returns default tasks without summarization tasks."""
|
||||||
# Get base tasks (0=classify, 1=check_permissions, 2=extract_chunks)
|
# Get base tasks (0=classify, 1=extract_chunks)
|
||||||
base_tasks = await get_default_tasks_by_indices([0, 1, 2], chunk_size, chunker)
|
base_tasks = await get_default_tasks_by_indices([0, 1], chunk_size, chunker)
|
||||||
|
|
||||||
ontology_adapter = RDFLibOntologyResolver(ontology_file=ontology_file_path)
|
ontology_adapter = RDFLibOntologyResolver(ontology_file=ontology_file_path)
|
||||||
|
|
||||||
|
|
@ -51,8 +51,8 @@ async def get_just_chunks_tasks(
|
||||||
chunk_size: int = None, chunker=TextChunker, user=None
|
chunk_size: int = None, chunker=TextChunker, user=None
|
||||||
) -> List[Task]:
|
) -> List[Task]:
|
||||||
"""Returns default tasks with only chunk extraction and data points addition."""
|
"""Returns default tasks with only chunk extraction and data points addition."""
|
||||||
# Get base tasks (0=classify, 1=check_permissions, 2=extract_chunks)
|
# Get base tasks (0=classify, 1=extract_chunks)
|
||||||
base_tasks = await get_default_tasks_by_indices([0, 1, 2], chunk_size, chunker)
|
base_tasks = await get_default_tasks_by_indices([0, 1], chunk_size, chunker)
|
||||||
|
|
||||||
add_data_points_task = Task(add_data_points, task_config={"batch_size": 10})
|
add_data_points_task = Task(add_data_points, task_config={"batch_size": 10})
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
from .dataset_database_handler_interface import DatasetDatabaseHandlerInterface
|
||||||
|
from .supported_dataset_database_handlers import supported_dataset_database_handlers
|
||||||
|
from .use_dataset_database_handler import use_dataset_database_handler
|
||||||
|
|
@ -0,0 +1,80 @@
|
||||||
|
from typing import Optional
|
||||||
|
from uuid import UUID
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from cognee.modules.users.models.User import User
|
||||||
|
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetDatabaseHandlerInterface(ABC):
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
|
||||||
|
"""
|
||||||
|
Return a dictionary with database connection/resolution info for a graph or vector database for the given dataset.
|
||||||
|
The function may also deploy the actual database if needed, but this is not required.
|
||||||
|
Only providing connection info is sufficient, this info will be mapped when trying to connect to the provided dataset in the future.
|
||||||
|
Needed for Cognee multi-tenant/multi-user and backend access control support.
|
||||||
|
|
||||||
|
Dictionary returned from this function will be used to create a DatasetDatabase row in the relational database.
|
||||||
|
From which internal mapping of dataset -> database connection info will be done.
|
||||||
|
|
||||||
|
The returned dictionary is stored verbatim in the relational database and is later passed to
|
||||||
|
resolve_dataset_connection_info() at connection time. For safe credential handling, prefer
|
||||||
|
returning only references to secrets or role identifiers, not plaintext credentials.
|
||||||
|
|
||||||
|
Each dataset needs to map to a unique graph or vector database when backend access control is enabled to facilitate a separation of concerns for data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_id: UUID of the dataset if needed by the database creation logic
|
||||||
|
user: User object if needed by the database creation logic
|
||||||
|
Returns:
|
||||||
|
dict: Connection info for the created graph or vector database instance.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def resolve_dataset_connection_info(
|
||||||
|
cls, dataset_database: DatasetDatabase
|
||||||
|
) -> DatasetDatabase:
|
||||||
|
"""
|
||||||
|
Resolve runtime connection details for a dataset’s backing graph/vector database.
|
||||||
|
Function is intended to be overridden to implement custom logic for resolving connection info.
|
||||||
|
|
||||||
|
This method is invoked right before the application opens a connection for a given dataset.
|
||||||
|
It receives the DatasetDatabase row that was persisted when create_dataset() ran and must
|
||||||
|
return a modified instance of DatasetDatabase with concrete connection parameters that the client/driver can use.
|
||||||
|
Do not update these new DatasetDatabase values in the relational database to avoid storing secure credentials.
|
||||||
|
|
||||||
|
In case of separate graph and vector database handlers, each handler should implement its own logic for resolving
|
||||||
|
connection info and only change parameters related to its appropriate database, the resolution function will then
|
||||||
|
be called one after another with the updated DatasetDatabase value from the previous function as the input.
|
||||||
|
|
||||||
|
Typical behavior:
|
||||||
|
- If the DatasetDatabase row already contains raw connection fields (e.g., host/port/db/user/password
|
||||||
|
or api_url/api_key), return them as-is.
|
||||||
|
- If the row stores only references (e.g., secret IDs, vault paths, cloud resource ARNs/IDs, IAM
|
||||||
|
roles, SSO tokens), resolve those references by calling the appropriate secret manager or provider
|
||||||
|
API to obtain short-lived credentials and assemble the final connection DatasetDatabase object.
|
||||||
|
- Do not persist any resolved or decrypted secrets back to the relational database. Return them only
|
||||||
|
to the caller.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_database: DatasetDatabase row from the relational database
|
||||||
|
Returns:
|
||||||
|
DatasetDatabase: Updated instance with resolved connection info
|
||||||
|
"""
|
||||||
|
return dataset_database
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
async def delete_dataset(cls, dataset_database: DatasetDatabase) -> None:
|
||||||
|
"""
|
||||||
|
Delete the graph or vector database for the given dataset.
|
||||||
|
The function should either delete the actual database itself or send a request to the proper service to delete/mark the database as not needed for the given dataset.
|
||||||
|
Needed for maintaining a database for Cognee multi-tenant/multi-user and backend access control.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_database: DatasetDatabase row containing connection/resolution info for the graph or vector database to delete.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
@ -0,0 +1,18 @@
|
||||||
|
from cognee.infrastructure.databases.graph.neo4j_driver.Neo4jAuraDevDatasetDatabaseHandler import (
|
||||||
|
Neo4jAuraDevDatasetDatabaseHandler,
|
||||||
|
)
|
||||||
|
from cognee.infrastructure.databases.vector.lancedb.LanceDBDatasetDatabaseHandler import (
|
||||||
|
LanceDBDatasetDatabaseHandler,
|
||||||
|
)
|
||||||
|
from cognee.infrastructure.databases.graph.kuzu.KuzuDatasetDatabaseHandler import (
|
||||||
|
KuzuDatasetDatabaseHandler,
|
||||||
|
)
|
||||||
|
|
||||||
|
supported_dataset_database_handlers = {
|
||||||
|
"neo4j_aura_dev": {
|
||||||
|
"handler_instance": Neo4jAuraDevDatasetDatabaseHandler,
|
||||||
|
"handler_provider": "neo4j",
|
||||||
|
},
|
||||||
|
"lancedb": {"handler_instance": LanceDBDatasetDatabaseHandler, "handler_provider": "lancedb"},
|
||||||
|
"kuzu": {"handler_instance": KuzuDatasetDatabaseHandler, "handler_provider": "kuzu"},
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,10 @@
|
||||||
|
from .supported_dataset_database_handlers import supported_dataset_database_handlers
|
||||||
|
|
||||||
|
|
||||||
|
def use_dataset_database_handler(
|
||||||
|
dataset_database_handler_name, dataset_database_handler, dataset_database_provider
|
||||||
|
):
|
||||||
|
supported_dataset_database_handlers[dataset_database_handler_name] = {
|
||||||
|
"handler_instance": dataset_database_handler,
|
||||||
|
"handler_provider": dataset_database_provider,
|
||||||
|
}
|
||||||
|
|
@ -47,6 +47,7 @@ class GraphConfig(BaseSettings):
|
||||||
graph_filename: str = ""
|
graph_filename: str = ""
|
||||||
graph_model: object = KnowledgeGraph
|
graph_model: object = KnowledgeGraph
|
||||||
graph_topology: object = KnowledgeGraph
|
graph_topology: object = KnowledgeGraph
|
||||||
|
graph_dataset_database_handler: str = "kuzu"
|
||||||
model_config = SettingsConfigDict(env_file=".env", extra="allow", populate_by_name=True)
|
model_config = SettingsConfigDict(env_file=".env", extra="allow", populate_by_name=True)
|
||||||
|
|
||||||
# Model validator updates graph_filename and path dynamically after class creation based on current database provider
|
# Model validator updates graph_filename and path dynamically after class creation based on current database provider
|
||||||
|
|
@ -97,6 +98,7 @@ class GraphConfig(BaseSettings):
|
||||||
"graph_model": self.graph_model,
|
"graph_model": self.graph_model,
|
||||||
"graph_topology": self.graph_topology,
|
"graph_topology": self.graph_topology,
|
||||||
"model_config": self.model_config,
|
"model_config": self.model_config,
|
||||||
|
"graph_dataset_database_handler": self.graph_dataset_database_handler,
|
||||||
}
|
}
|
||||||
|
|
||||||
def to_hashable_dict(self) -> dict:
|
def to_hashable_dict(self) -> dict:
|
||||||
|
|
@ -121,6 +123,7 @@ class GraphConfig(BaseSettings):
|
||||||
"graph_database_port": self.graph_database_port,
|
"graph_database_port": self.graph_database_port,
|
||||||
"graph_database_key": self.graph_database_key,
|
"graph_database_key": self.graph_database_key,
|
||||||
"graph_file_path": self.graph_file_path,
|
"graph_file_path": self.graph_file_path,
|
||||||
|
"graph_dataset_database_handler": self.graph_dataset_database_handler,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@ def create_graph_engine(
|
||||||
graph_database_password="",
|
graph_database_password="",
|
||||||
graph_database_port="",
|
graph_database_port="",
|
||||||
graph_database_key="",
|
graph_database_key="",
|
||||||
|
graph_dataset_database_handler="",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a graph engine based on the specified provider type.
|
Create a graph engine based on the specified provider type.
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,81 @@
|
||||||
|
import os
|
||||||
|
from uuid import UUID
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
|
||||||
|
from cognee.base_config import get_base_config
|
||||||
|
from cognee.modules.users.models import User
|
||||||
|
from cognee.modules.users.models import DatasetDatabase
|
||||||
|
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
|
||||||
|
|
||||||
|
|
||||||
|
class KuzuDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||||
|
"""
|
||||||
|
Handler for interacting with Kuzu Dataset databases.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
|
||||||
|
"""
|
||||||
|
Create a new Kuzu instance for the dataset. Return connection info that will be mapped to the dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_id: Dataset UUID
|
||||||
|
user: User object who owns the dataset and is making the request
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Connection details for the created Kuzu instance
|
||||||
|
|
||||||
|
"""
|
||||||
|
from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||||
|
|
||||||
|
graph_config = get_graph_config()
|
||||||
|
|
||||||
|
if graph_config.graph_database_provider != "kuzu":
|
||||||
|
raise ValueError(
|
||||||
|
"KuzuDatasetDatabaseHandler can only be used with Kuzu graph database provider."
|
||||||
|
)
|
||||||
|
|
||||||
|
graph_db_name = f"{dataset_id}.pkl"
|
||||||
|
graph_db_url = graph_config.graph_database_url
|
||||||
|
graph_db_key = graph_config.graph_database_key
|
||||||
|
graph_db_username = graph_config.graph_database_username
|
||||||
|
graph_db_password = graph_config.graph_database_password
|
||||||
|
|
||||||
|
return {
|
||||||
|
"graph_database_name": graph_db_name,
|
||||||
|
"graph_database_url": graph_db_url,
|
||||||
|
"graph_database_provider": graph_config.graph_database_provider,
|
||||||
|
"graph_database_key": graph_db_key,
|
||||||
|
"graph_dataset_database_handler": "kuzu",
|
||||||
|
"graph_database_connection_info": {
|
||||||
|
"graph_database_username": graph_db_username,
|
||||||
|
"graph_database_password": graph_db_password,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def delete_dataset(cls, dataset_database: DatasetDatabase):
|
||||||
|
base_config = get_base_config()
|
||||||
|
databases_directory_path = os.path.join(
|
||||||
|
base_config.system_root_directory, "databases", str(dataset_database.owner_id)
|
||||||
|
)
|
||||||
|
graph_file_path = os.path.join(
|
||||||
|
databases_directory_path, dataset_database.graph_database_name
|
||||||
|
)
|
||||||
|
graph_engine = create_graph_engine(
|
||||||
|
graph_database_provider=dataset_database.graph_database_provider,
|
||||||
|
graph_database_url=dataset_database.graph_database_url,
|
||||||
|
graph_database_name=dataset_database.graph_database_name,
|
||||||
|
graph_database_key=dataset_database.graph_database_key,
|
||||||
|
graph_file_path=graph_file_path,
|
||||||
|
graph_database_username=dataset_database.graph_database_connection_info.get(
|
||||||
|
"graph_database_username", ""
|
||||||
|
),
|
||||||
|
graph_database_password=dataset_database.graph_database_connection_info.get(
|
||||||
|
"graph_database_password", ""
|
||||||
|
),
|
||||||
|
graph_dataset_database_handler="",
|
||||||
|
graph_database_port="",
|
||||||
|
)
|
||||||
|
await graph_engine.delete_graph()
|
||||||
|
|
@ -0,0 +1,168 @@
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import requests
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
from uuid import UUID
|
||||||
|
from typing import Optional
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_config
|
||||||
|
from cognee.modules.users.models import User, DatasetDatabase
|
||||||
|
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
|
||||||
|
|
||||||
|
|
||||||
|
class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||||
|
"""
|
||||||
|
Handler for a quick development PoC integration of Cognee multi-user and permission mode with Neo4j Aura databases.
|
||||||
|
This handler creates a new Neo4j Aura instance for each Cognee dataset created.
|
||||||
|
|
||||||
|
Improvements needed to be production ready:
|
||||||
|
- Secret management for client credentials, currently secrets are encrypted and stored in the Cognee relational database,
|
||||||
|
a secret manager or a similar system should be used instead.
|
||||||
|
|
||||||
|
Quality of life improvements:
|
||||||
|
- Allow configuration of different Neo4j Aura plans and regions.
|
||||||
|
- Requests should be made async, currently a blocking requests library is used.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
    """
    Create a new Neo4j Aura instance for the dataset. Return connection info that will be mapped to the dataset.

    The instance is requested through the Neo4j Aura REST API and then polled until it
    reports a "running" status. The generated database password is encrypted with a
    Fernet key derived from the NEO4J_ENCRYPTION_KEY environment variable before it is
    placed in the returned dictionary.

    Args:
        dataset_id: Dataset UUID
        user: User object who owns the dataset and is making the request

    Returns:
        dict: Connection details for the created Neo4j instance

    Raises:
        ValueError: If the configured graph provider is not "neo4j", or if any of
            NEO4J_CLIENT_ID, NEO4J_CLIENT_SECRET, NEO4J_TENANT_ID is unset.
        requests.HTTPError: If the token, instance-creation or status request fails.
        TimeoutError: If the instance is not "running" after 30 polls, 10 seconds apart.
    """
    graph_config = get_graph_config()

    if graph_config.graph_database_provider != "neo4j":
        raise ValueError(
            "Neo4jAuraDevDatasetDatabaseHandler can only be used with Neo4j graph database provider."
        )

    # Client credentials are validated first so a misconfiguration fails before any other work.
    client_id = os.environ.get("NEO4J_CLIENT_ID", None)
    client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
    tenant_id = os.environ.get("NEO4J_TENANT_ID", None)

    if client_id is None or client_secret is None or tenant_id is None:
        raise ValueError(
            "NEO4J_CLIENT_ID, NEO4J_CLIENT_SECRET, and NEO4J_TENANT_ID environment variables must be set to use Neo4j Aura DatasetDatabase Handling."
        )

    # NOTE(review): when NEO4J_ENCRYPTION_KEY is unset the well-known fallback "test_key" is
    # used, which offers no real protection for the stored password - confirm this handler is
    # only used in development setups.
    encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key")
    encryption_key = base64.urlsafe_b64encode(
        hashlib.sha256(encryption_env_key.encode()).digest()
    )
    cipher = Fernet(encryption_key)

    # TODO: Find better name to name Neo4j instance within 30 character limit
    instance_name = f"{dataset_id}"[0:29]

    # Upper bound (seconds) for each individual HTTP call so a stalled connection cannot
    # hang the coroutine forever. The calls themselves are still blocking.
    request_timeout = 30

    # Make the request with HTTP Basic Auth
    def get_aura_token(client_id: str, client_secret: str) -> dict:
        url = "https://api.neo4j.io/oauth/token"
        data = {"grant_type": "client_credentials"}  # sent as application/x-www-form-urlencoded

        resp = requests.post(
            url, data=data, auth=(client_id, client_secret), timeout=request_timeout
        )
        resp.raise_for_status()  # raises if the request failed
        return resp.json()

    async def _wait_for_neo4j_instance_provisioning(instance_id: str, headers: dict):
        # Poll until the instance is running
        status_url = f"https://api.neo4j.io/v1/instances/{instance_id}"
        status = ""
        for _ in range(30):  # Try for up to ~5 minutes
            status_resp = requests.get(
                status_url, headers=headers, timeout=request_timeout
            )  # TODO: Use async requests with httpx
            # Surface API errors as HTTPError instead of a KeyError on the payload below.
            status_resp.raise_for_status()
            status = status_resp.json()["data"]["status"]
            if status.lower() == "running":
                return
            await asyncio.sleep(10)
        raise TimeoutError(
            f"Neo4j instance '{instance_id}' did not become ready within 5 minutes. Status: {status}"
        )

    resp = get_aura_token(client_id, client_secret)

    url = "https://api.neo4j.io/v1/instances"

    headers = {
        "accept": "application/json",
        "Authorization": f"Bearer {resp['access_token']}",
        "Content-Type": "application/json",
    }

    # TODO: Maybe we can allow **kwargs parameter forwarding for cases like these
    # To allow different configurations between datasets
    payload = {
        "version": "5",
        "region": "europe-west1",
        "memory": "1GB",
        "name": instance_name,
        "type": "professional-db",
        "tenant_id": tenant_id,
        "cloud_provider": "gcp",
    }

    response = requests.post(url, headers=headers, json=payload, timeout=request_timeout)
    # Fail loudly if Aura rejected the request; otherwise the lookups below raise KeyError.
    response.raise_for_status()

    # Parse the creation response once and read every field from the same payload.
    instance = response.json()["data"]

    graph_db_name = "neo4j"  # Has to be 'neo4j' for Aura
    graph_db_url = instance["connection_url"]
    # NOTE(review): the OAuth access token is what gets stored as the database key - confirm
    # this is intended.
    graph_db_key = resp["access_token"]
    graph_db_username = instance["username"]
    graph_db_password = instance["password"]

    await _wait_for_neo4j_instance_provisioning(instance["id"], headers)

    # Only the encrypted form of the password is returned to the caller.
    encrypted_db_password_string = cipher.encrypt(graph_db_password.encode()).decode()

    return {
        "graph_database_name": graph_db_name,
        "graph_database_url": graph_db_url,
        "graph_database_provider": "neo4j",
        "graph_database_key": graph_db_key,
        "graph_dataset_database_handler": "neo4j_aura_dev",
        "graph_database_connection_info": {
            "graph_database_username": graph_db_username,
            "graph_database_password": encrypted_db_password_string,
        },
    }
|
||||||
|
|
||||||
|
@classmethod
async def resolve_dataset_connection_info(
    cls, dataset_database: DatasetDatabase
) -> DatasetDatabase:
    """
    Decrypt the Neo4j password held in the dataset's graph connection info.

    The Fernet key is derived from the NEO4J_ENCRYPTION_KEY environment variable
    (falling back to "test_key"). The decrypted password overwrites the encrypted
    value on the same DatasetDatabase instance, which is then returned.

    Args:
        dataset_database: DatasetDatabase instance containing encrypted connection info.

    Returns:
        The same DatasetDatabase instance with the password in plain text.
    """
    secret = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key")
    fernet_key = base64.urlsafe_b64encode(hashlib.sha256(secret.encode()).digest())
    fernet = Fernet(fernet_key)

    stored_password = dataset_database.graph_database_connection_info["graph_database_password"]
    plain_password = fernet.decrypt(stored_password.encode()).decode()

    # Replace the encrypted value in place on the passed-in object.
    dataset_database.graph_database_connection_info["graph_database_password"] = plain_password
    return dataset_database
|
||||||
|
|
||||||
|
@classmethod
async def delete_dataset(cls, dataset_database: DatasetDatabase):
    """
    Placeholder for dataset deletion; currently performs no action.

    Args:
        dataset_database: DatasetDatabase instance for the dataset (unused).
    """
    return None
|
||||||
|
|
@ -1 +1,2 @@
|
||||||
from .get_or_create_dataset_database import get_or_create_dataset_database
|
from .get_or_create_dataset_database import get_or_create_dataset_database
|
||||||
|
from .resolve_dataset_database_connection_info import resolve_dataset_database_connection_info
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,9 @@
|
||||||
import os
|
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import Union
|
from typing import Union, Optional
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
|
||||||
from cognee.base_config import get_base_config
|
|
||||||
from cognee.modules.data.methods import create_dataset
|
from cognee.modules.data.methods import create_dataset
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
from cognee.infrastructure.databases.vector import get_vectordb_config
|
from cognee.infrastructure.databases.vector import get_vectordb_config
|
||||||
|
|
@ -15,6 +13,53 @@ from cognee.modules.users.models import DatasetDatabase
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_vector_db_info(dataset_id: UUID, user: User) -> dict:
    """
    Ask the configured vector dataset database handler for connection info.

    The handler is looked up by the `vector_dataset_database_handler` value of the
    vector config, and its `create_dataset` result is returned unchanged.
    """
    handler_name = get_vectordb_config().vector_dataset_database_handler

    # Imported inside the function, as in the rest of this module.
    from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
        supported_dataset_database_handlers,
    )

    handler_entry = supported_dataset_database_handlers[handler_name]
    vector_db_handler = handler_entry["handler_instance"]
    return await vector_db_handler.create_dataset(dataset_id, user)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_graph_db_info(dataset_id: UUID, user: User) -> dict:
    """
    Ask the configured graph dataset database handler for connection info.

    The handler is looked up by the `graph_dataset_database_handler` value of the
    graph config, and its `create_dataset` result is returned unchanged.
    """
    handler_name = get_graph_config().graph_dataset_database_handler

    # Imported inside the function, as in the rest of this module.
    from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
        supported_dataset_database_handlers,
    )

    handler_entry = supported_dataset_database_handlers[handler_name]
    graph_db_handler = handler_entry["handler_instance"]
    return await graph_db_handler.create_dataset(dataset_id, user)
|
||||||
|
|
||||||
|
|
||||||
|
async def _existing_dataset_database(
    dataset_id: UUID,
    user: User,
) -> Optional[DatasetDatabase]:
    """
    Look up the DatasetDatabase row for the given owner and dataset.

    Args:
        dataset_id: ID of the dataset the row must belong to.
        user: User whose `id` must match the row's `owner_id`.

    Returns:
        The matching DatasetDatabase row, or None when no such row exists.
    """
    lookup = select(DatasetDatabase).where(
        DatasetDatabase.owner_id == user.id,
        DatasetDatabase.dataset_id == dataset_id,
    )

    async with get_relational_engine().get_async_session() as session:
        return await session.scalar(lookup)
|
||||||
|
|
||||||
|
|
||||||
async def get_or_create_dataset_database(
|
async def get_or_create_dataset_database(
|
||||||
dataset: Union[str, UUID],
|
dataset: Union[str, UUID],
|
||||||
user: User,
|
user: User,
|
||||||
|
|
@ -25,6 +70,8 @@ async def get_or_create_dataset_database(
|
||||||
• If the row already exists, it is fetched and returned.
|
• If the row already exists, it is fetched and returned.
|
||||||
• Otherwise a new one is created atomically and returned.
|
• Otherwise a new one is created atomically and returned.
|
||||||
|
|
||||||
|
DatasetDatabase row contains connection and provider info for vector and graph databases.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
user : User
|
user : User
|
||||||
|
|
@ -36,59 +83,26 @@ async def get_or_create_dataset_database(
|
||||||
|
|
||||||
dataset_id = await get_unique_dataset_id(dataset, user)
|
dataset_id = await get_unique_dataset_id(dataset, user)
|
||||||
|
|
||||||
vector_config = get_vectordb_config()
|
# If dataset is given as name make sure the dataset is created first
|
||||||
graph_config = get_graph_config()
|
if isinstance(dataset, str):
|
||||||
|
async with db_engine.get_async_session() as session:
|
||||||
|
await create_dataset(dataset, user, session)
|
||||||
|
|
||||||
# Note: for hybrid databases both graph and vector DB name have to be the same
|
# If dataset database already exists return it
|
||||||
if graph_config.graph_database_provider == "kuzu":
|
existing_dataset_database = await _existing_dataset_database(dataset_id, user)
|
||||||
graph_db_name = f"{dataset_id}.pkl"
|
if existing_dataset_database:
|
||||||
else:
|
return existing_dataset_database
|
||||||
graph_db_name = f"{dataset_id}"
|
|
||||||
|
|
||||||
if vector_config.vector_db_provider == "lancedb":
|
graph_config_dict = await _get_graph_db_info(dataset_id, user)
|
||||||
vector_db_name = f"{dataset_id}.lance.db"
|
vector_config_dict = await _get_vector_db_info(dataset_id, user)
|
||||||
else:
|
|
||||||
vector_db_name = f"{dataset_id}"
|
|
||||||
|
|
||||||
base_config = get_base_config()
|
|
||||||
databases_directory_path = os.path.join(
|
|
||||||
base_config.system_root_directory, "databases", str(user.id)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Determine vector database URL
|
|
||||||
if vector_config.vector_db_provider == "lancedb":
|
|
||||||
vector_db_url = os.path.join(databases_directory_path, vector_config.vector_db_name)
|
|
||||||
else:
|
|
||||||
vector_db_url = vector_config.vector_database_url
|
|
||||||
|
|
||||||
# Determine graph database URL
|
|
||||||
|
|
||||||
async with db_engine.get_async_session() as session:
|
async with db_engine.get_async_session() as session:
|
||||||
# Create dataset if it doesn't exist
|
|
||||||
if isinstance(dataset, str):
|
|
||||||
dataset = await create_dataset(dataset, user, session)
|
|
||||||
|
|
||||||
# Try to fetch an existing row first
|
|
||||||
stmt = select(DatasetDatabase).where(
|
|
||||||
DatasetDatabase.owner_id == user.id,
|
|
||||||
DatasetDatabase.dataset_id == dataset_id,
|
|
||||||
)
|
|
||||||
existing: DatasetDatabase = await session.scalar(stmt)
|
|
||||||
if existing:
|
|
||||||
return existing
|
|
||||||
|
|
||||||
# If there are no existing rows build a new row
|
# If there are no existing rows build a new row
|
||||||
record = DatasetDatabase(
|
record = DatasetDatabase(
|
||||||
owner_id=user.id,
|
owner_id=user.id,
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
vector_database_name=vector_db_name,
|
**graph_config_dict, # Unpack graph db config
|
||||||
graph_database_name=graph_db_name,
|
**vector_config_dict, # Unpack vector db config
|
||||||
vector_database_provider=vector_config.vector_db_provider,
|
|
||||||
graph_database_provider=graph_config.graph_database_provider,
|
|
||||||
vector_database_url=vector_db_url,
|
|
||||||
graph_database_url=graph_config.graph_database_url,
|
|
||||||
vector_database_key=vector_config.vector_db_key,
|
|
||||||
graph_database_key=graph_config.graph_database_key,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,36 @@
|
||||||
|
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_vector_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase:
    """
    Resolve vector database connection info through the handler recorded on the row.

    The handler is selected by `dataset_database.vector_dataset_database_handler` and
    its `resolve_dataset_connection_info` result is returned.
    """
    from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
        supported_dataset_database_handlers,
    )

    handler_name = dataset_database.vector_dataset_database_handler
    vector_db_handler = supported_dataset_database_handlers[handler_name]["handler_instance"]
    return await vector_db_handler.resolve_dataset_connection_info(dataset_database)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_graph_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase:
    """
    Resolve graph database connection info through the handler recorded on the row.

    The handler is selected by `dataset_database.graph_dataset_database_handler` and
    its `resolve_dataset_connection_info` result is returned.
    """
    from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
        supported_dataset_database_handlers,
    )

    handler_name = dataset_database.graph_dataset_database_handler
    graph_db_handler = supported_dataset_database_handlers[handler_name]["handler_instance"]
    return await graph_db_handler.resolve_dataset_connection_info(dataset_database)
|
||||||
|
|
||||||
|
|
||||||
|
async def resolve_dataset_database_connection_info(
    dataset_database: DatasetDatabase,
) -> DatasetDatabase:
    """
    Resolve vector and then graph connection info for a DatasetDatabase instance.

    Each step receives the result of the previous one, so the returned instance
    carries the resolved connection info of both databases.

    Args:
        dataset_database: DatasetDatabase instance
    Returns:
        DatasetDatabase instance with resolved connection info
    """
    # Order matters only in that it mirrors the original: vector first, graph second.
    for resolve_step in (_get_vector_db_connection_info, _get_graph_db_connection_info):
        dataset_database = await resolve_step(dataset_database)
    return dataset_database
|
||||||
|
|
@ -28,6 +28,7 @@ class VectorConfig(BaseSettings):
|
||||||
vector_db_name: str = ""
|
vector_db_name: str = ""
|
||||||
vector_db_key: str = ""
|
vector_db_key: str = ""
|
||||||
vector_db_provider: str = "lancedb"
|
vector_db_provider: str = "lancedb"
|
||||||
|
vector_dataset_database_handler: str = "lancedb"
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||||
|
|
||||||
|
|
@ -63,6 +64,7 @@ class VectorConfig(BaseSettings):
|
||||||
"vector_db_name": self.vector_db_name,
|
"vector_db_name": self.vector_db_name,
|
||||||
"vector_db_key": self.vector_db_key,
|
"vector_db_key": self.vector_db_key,
|
||||||
"vector_db_provider": self.vector_db_provider,
|
"vector_db_provider": self.vector_db_provider,
|
||||||
|
"vector_dataset_database_handler": self.vector_dataset_database_handler,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ def create_vector_engine(
|
||||||
vector_db_name: str,
|
vector_db_name: str,
|
||||||
vector_db_port: str = "",
|
vector_db_port: str = "",
|
||||||
vector_db_key: str = "",
|
vector_db_key: str = "",
|
||||||
|
vector_dataset_database_handler: str = "",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a vector database engine based on the specified provider.
|
Create a vector database engine based on the specified provider.
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,50 @@
|
||||||
|
import os
|
||||||
|
from uuid import UUID
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
|
||||||
|
from cognee.modules.users.models import User
|
||||||
|
from cognee.modules.users.models import DatasetDatabase
|
||||||
|
from cognee.base_config import get_base_config
|
||||||
|
from cognee.infrastructure.databases.vector import get_vectordb_config
|
||||||
|
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
|
||||||
|
|
||||||
|
|
||||||
|
class LanceDBDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
    """
    Dataset database handler for LanceDB: builds per-dataset connection info and
    prunes a dataset's vector database.
    """

    @classmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        """
        Build connection info for a LanceDB database dedicated to the given dataset.

        The database file is named "<dataset_id>.lance.db" and placed under
        "<system_root_directory>/databases/<user.id>".

        Args:
            dataset_id: UUID of the dataset; used to name the database.
            user: Owner of the dataset; `user.id` selects the directory.

        Returns:
            dict: Provider, URL, key, name and handler name for the vector database.

        Raises:
            ValueError: If the configured vector provider is not "lancedb".
        """
        vector_settings = get_vectordb_config()
        base_settings = get_base_config()

        provider = vector_settings.vector_db_provider
        if provider != "lancedb":
            raise ValueError(
                "LanceDBDatasetDatabaseHandler can only be used with LanceDB vector database provider."
            )

        user_databases_dir = os.path.join(
            base_settings.system_root_directory, "databases", str(user.id)
        )
        db_file_name = f"{dataset_id}.lance.db"
        db_location = os.path.join(user_databases_dir, db_file_name)

        connection_info = {
            "vector_database_provider": vector_settings.vector_db_provider,
            "vector_database_url": db_location,
            "vector_database_key": vector_settings.vector_db_key,
            "vector_database_name": db_file_name,
            "vector_dataset_database_handler": "lancedb",
        }
        return connection_info

    @classmethod
    async def delete_dataset(cls, dataset_database: DatasetDatabase):
        """
        Prune the vector database described by the given DatasetDatabase row.

        Args:
            dataset_database: Row holding the vector database provider, URL, key and name.
        """
        engine = create_vector_engine(
            vector_db_provider=dataset_database.vector_database_provider,
            vector_db_url=dataset_database.vector_database_url,
            vector_db_key=dataset_database.vector_database_key,
            vector_db_name=dataset_database.vector_database_name,
        )
        await engine.prune()
|
||||||
|
|
@ -2,6 +2,8 @@ from typing import List, Protocol, Optional, Union, Any
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from .models.PayloadSchema import PayloadSchema
|
from .models.PayloadSchema import PayloadSchema
|
||||||
|
from uuid import UUID
|
||||||
|
from cognee.modules.users.models import User
|
||||||
|
|
||||||
|
|
||||||
class VectorDBInterface(Protocol):
|
class VectorDBInterface(Protocol):
|
||||||
|
|
@ -217,3 +219,36 @@ class VectorDBInterface(Protocol):
|
||||||
- Any: The schema object suitable for this vector database
|
- Any: The schema object suitable for this vector database
|
||||||
"""
|
"""
|
||||||
return model_type
|
return model_type
|
||||||
|
|
||||||
|
@classmethod
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
    """
    Provide connection info for the vector database that backs the given dataset.

    Implementations may also deploy the actual database, but that is optional:
    returning connection info is enough. The info is stored and used later whenever
    a connection to this dataset is requested. This is required for Cognee
    multi-tenant/multi-user and backend access control support.

    The returned dictionary is used to create a DatasetDatabase row in the relational
    database, which is the source of the internal dataset -> database connection
    mapping.

    With backend access control enabled, every dataset must map to its own vector
    database so that data stays separated.

    Args:
        dataset_id: UUID of the dataset if needed by the database creation logic
        user: User object if needed by the database creation logic
    Returns:
        dict: Connection info for the created vector database instance.
    """
    pass
|
||||||
|
|
||||||
|
async def delete_dataset(self, dataset_id: UUID, user: User) -> None:
    """
    Remove the vector database that belongs to the given dataset.

    Implementations should either delete the actual database themselves or ask the
    responsible service to delete it. This is required to maintain databases for
    Cognee multi-tenant/multi-user and backend access control.

    Args:
        dataset_id: UUID of the dataset
        user: User object
    """
    pass
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,81 @@
|
||||||
|
from sqlalchemy.exc import OperationalError
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
|
||||||
|
from cognee.context_global_variables import backend_access_control_enabled
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
from cognee.shared.cache import delete_cache
|
from cognee.shared.cache import delete_cache
|
||||||
|
from cognee.modules.users.models import DatasetDatabase
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
async def prune_graph_databases():
    """
    Delete the graph database of every row in the "dataset_database" table.

    Each row is handed to the handler named by its `graph_dataset_database_handler`
    value. If an OperationalError or EntityNotFoundError is raised, it is logged at
    debug level and pruning stops without raising.
    """

    async def _delete_graph_db(row: DatasetDatabase):
        # Imported on first use, inside the helper, as in the original code.
        from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
            supported_dataset_database_handlers,
        )

        handler_name = row.graph_dataset_database_handler
        graph_db_handler = supported_dataset_database_handlers[handler_name]["handler_instance"]
        return await graph_db_handler.delete_dataset(row)

    relational_engine = get_relational_engine()
    try:
        rows = await relational_engine.get_all_data_from_table("dataset_database")
        # Go through each dataset database and delete the graph database
        for row in rows:
            await _delete_graph_db(row)
    except (OperationalError, EntityNotFoundError) as e:
        logger.debug(
            "Skipping pruning of graph DB. Error when accessing dataset_database table: %s",
            e,
        )
|
||||||
|
|
||||||
|
|
||||||
|
async def prune_vector_databases():
    """
    Delete the vector database of every row in the "dataset_database" table.

    Each row is handed to the handler named by its `vector_dataset_database_handler`
    value. If an OperationalError or EntityNotFoundError is raised, it is logged at
    debug level and pruning stops without raising.
    """

    async def _delete_vector_db(row: DatasetDatabase):
        # Imported on first use, inside the helper, as in the original code.
        from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
            supported_dataset_database_handlers,
        )

        handler_name = row.vector_dataset_database_handler
        vector_db_handler = supported_dataset_database_handlers[handler_name]["handler_instance"]
        return await vector_db_handler.delete_dataset(row)

    relational_engine = get_relational_engine()
    try:
        rows = await relational_engine.get_all_data_from_table("dataset_database")
        # Go through each dataset database and delete the vector database
        for row in rows:
            await _delete_vector_db(row)
    except (OperationalError, EntityNotFoundError) as e:
        logger.debug(
            "Skipping pruning of vector DB. Error when accessing dataset_database table: %s",
            e,
        )
|
||||||
|
|
||||||
|
|
||||||
async def prune_system(graph=True, vector=True, metadata=True, cache=True):
|
async def prune_system(graph=True, vector=True, metadata=True, cache=True):
|
||||||
if graph:
|
# Note: prune system should not be available through the API, it has no permission checks and will
|
||||||
|
# delete all graph and vector databases if called. It should only be used in development or testing environments.
|
||||||
|
if graph and not backend_access_control_enabled():
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
await graph_engine.delete_graph()
|
await graph_engine.delete_graph()
|
||||||
|
elif graph and backend_access_control_enabled():
|
||||||
|
await prune_graph_databases()
|
||||||
|
|
||||||
if vector:
|
if vector and not backend_access_control_enabled():
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
await vector_engine.prune()
|
await vector_engine.prune()
|
||||||
|
elif vector and backend_access_control_enabled():
|
||||||
|
await prune_vector_databases()
|
||||||
|
|
||||||
if metadata:
|
if metadata:
|
||||||
db_engine = get_relational_engine()
|
db_engine = get_relational_engine()
|
||||||
|
|
|
||||||
|
|
@ -12,9 +12,6 @@ from cognee.modules.users.models import User
|
||||||
from cognee.modules.pipelines.layers.resolve_authorized_user_datasets import (
|
from cognee.modules.pipelines.layers.resolve_authorized_user_datasets import (
|
||||||
resolve_authorized_user_datasets,
|
resolve_authorized_user_datasets,
|
||||||
)
|
)
|
||||||
from cognee.modules.pipelines.layers.reset_dataset_pipeline_run_status import (
|
|
||||||
reset_dataset_pipeline_run_status,
|
|
||||||
)
|
|
||||||
from cognee.modules.engine.operations.setup import setup
|
from cognee.modules.engine.operations.setup import setup
|
||||||
from cognee.modules.pipelines.layers.pipeline_execution_mode import get_pipeline_executor
|
from cognee.modules.pipelines.layers.pipeline_execution_mode import get_pipeline_executor
|
||||||
from cognee.tasks.memify.extract_subgraph_chunks import extract_subgraph_chunks
|
from cognee.tasks.memify.extract_subgraph_chunks import extract_subgraph_chunks
|
||||||
|
|
@ -97,10 +94,6 @@ async def memify(
|
||||||
*enrichment_tasks,
|
*enrichment_tasks,
|
||||||
]
|
]
|
||||||
|
|
||||||
await reset_dataset_pipeline_run_status(
|
|
||||||
authorized_dataset.id, user, pipeline_names=["memify_pipeline"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for
|
# By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for
|
||||||
pipeline_executor_func = get_pipeline_executor(run_in_background=run_in_background)
|
pipeline_executor_func = get_pipeline_executor(run_in_background=run_in_background)
|
||||||
|
|
||||||
|
|
@ -113,6 +106,7 @@ async def memify(
|
||||||
datasets=authorized_dataset.id,
|
datasets=authorized_dataset.id,
|
||||||
vector_db_config=vector_db_config,
|
vector_db_config=vector_db_config,
|
||||||
graph_db_config=graph_db_config,
|
graph_db_config=graph_db_config,
|
||||||
|
use_pipeline_cache=False,
|
||||||
incremental_loading=False,
|
incremental_loading=False,
|
||||||
pipeline_name="memify_pipeline",
|
pipeline_name="memify_pipeline",
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,9 @@ from cognee.modules.pipelines.layers.resolve_authorized_user_datasets import (
|
||||||
from cognee.modules.pipelines.layers.check_pipeline_run_qualification import (
|
from cognee.modules.pipelines.layers.check_pipeline_run_qualification import (
|
||||||
check_pipeline_run_qualification,
|
check_pipeline_run_qualification,
|
||||||
)
|
)
|
||||||
|
from cognee.modules.pipelines.models.PipelineRunInfo import (
|
||||||
|
PipelineRunStarted,
|
||||||
|
)
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
logger = get_logger("cognee.pipeline")
|
logger = get_logger("cognee.pipeline")
|
||||||
|
|
@ -35,6 +38,7 @@ async def run_pipeline(
|
||||||
pipeline_name: str = "custom_pipeline",
|
pipeline_name: str = "custom_pipeline",
|
||||||
vector_db_config: dict = None,
|
vector_db_config: dict = None,
|
||||||
graph_db_config: dict = None,
|
graph_db_config: dict = None,
|
||||||
|
use_pipeline_cache: bool = False,
|
||||||
incremental_loading: bool = False,
|
incremental_loading: bool = False,
|
||||||
data_per_batch: int = 20,
|
data_per_batch: int = 20,
|
||||||
):
|
):
|
||||||
|
|
@ -51,6 +55,7 @@ async def run_pipeline(
|
||||||
data=data,
|
data=data,
|
||||||
pipeline_name=pipeline_name,
|
pipeline_name=pipeline_name,
|
||||||
context={"dataset": dataset},
|
context={"dataset": dataset},
|
||||||
|
use_pipeline_cache=use_pipeline_cache,
|
||||||
incremental_loading=incremental_loading,
|
incremental_loading=incremental_loading,
|
||||||
data_per_batch=data_per_batch,
|
data_per_batch=data_per_batch,
|
||||||
):
|
):
|
||||||
|
|
@ -64,6 +69,7 @@ async def run_pipeline_per_dataset(
|
||||||
data=None,
|
data=None,
|
||||||
pipeline_name: str = "custom_pipeline",
|
pipeline_name: str = "custom_pipeline",
|
||||||
context: dict = None,
|
context: dict = None,
|
||||||
|
use_pipeline_cache=False,
|
||||||
incremental_loading=False,
|
incremental_loading=False,
|
||||||
data_per_batch: int = 20,
|
data_per_batch: int = 20,
|
||||||
):
|
):
|
||||||
|
|
@ -77,8 +83,18 @@ async def run_pipeline_per_dataset(
|
||||||
if process_pipeline_status:
|
if process_pipeline_status:
|
||||||
# If pipeline was already processed or is currently being processed
|
# If pipeline was already processed or is currently being processed
|
||||||
# return status information to async generator and finish execution
|
# return status information to async generator and finish execution
|
||||||
yield process_pipeline_status
|
if use_pipeline_cache:
|
||||||
return
|
# If pipeline caching is enabled we do not proceed with re-processing
|
||||||
|
yield process_pipeline_status
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# If pipeline caching is disabled we always return pipeline started information and proceed with re-processing
|
||||||
|
yield PipelineRunStarted(
|
||||||
|
pipeline_run_id=process_pipeline_status.pipeline_run_id,
|
||||||
|
dataset_id=dataset.id,
|
||||||
|
dataset_name=dataset.name,
|
||||||
|
payload=data,
|
||||||
|
)
|
||||||
|
|
||||||
pipeline_run = run_tasks(
|
pipeline_run = run_tasks(
|
||||||
tasks,
|
tasks,
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,8 @@ async def run_custom_pipeline(
|
||||||
user: User = None,
|
user: User = None,
|
||||||
vector_db_config: Optional[dict] = None,
|
vector_db_config: Optional[dict] = None,
|
||||||
graph_db_config: Optional[dict] = None,
|
graph_db_config: Optional[dict] = None,
|
||||||
|
use_pipeline_cache: bool = False,
|
||||||
|
incremental_loading: bool = False,
|
||||||
data_per_batch: int = 20,
|
data_per_batch: int = 20,
|
||||||
run_in_background: bool = False,
|
run_in_background: bool = False,
|
||||||
pipeline_name: str = "custom_pipeline",
|
pipeline_name: str = "custom_pipeline",
|
||||||
|
|
@ -40,6 +42,10 @@ async def run_custom_pipeline(
|
||||||
user: User context for authentication and data access. Uses default if None.
|
user: User context for authentication and data access. Uses default if None.
|
||||||
vector_db_config: Custom vector database configuration for embeddings storage.
|
vector_db_config: Custom vector database configuration for embeddings storage.
|
||||||
graph_db_config: Custom graph database configuration for relationship storage.
|
graph_db_config: Custom graph database configuration for relationship storage.
|
||||||
|
use_pipeline_cache: If True, pipelines with the same ID that are currently executing and pipelines with the same ID that were completed won't process data again.
|
||||||
|
Pipelines ID is created based on the generate_pipeline_id function. Pipeline status can be manually reset with the reset_dataset_pipeline_run_status function.
|
||||||
|
incremental_loading: If True, only new or modified data will be processed to avoid duplication. (Only works if data is used with the Cognee python Data model).
|
||||||
|
The incremental system stores and compares hashes of processed data in the Data model and skips data with the same content hash.
|
||||||
data_per_batch: Number of data items to be processed in parallel.
|
data_per_batch: Number of data items to be processed in parallel.
|
||||||
run_in_background: If True, starts processing asynchronously and returns immediately.
|
run_in_background: If True, starts processing asynchronously and returns immediately.
|
||||||
If False, waits for completion before returning.
|
If False, waits for completion before returning.
|
||||||
|
|
@ -63,7 +69,8 @@ async def run_custom_pipeline(
|
||||||
datasets=dataset,
|
datasets=dataset,
|
||||||
vector_db_config=vector_db_config,
|
vector_db_config=vector_db_config,
|
||||||
graph_db_config=graph_db_config,
|
graph_db_config=graph_db_config,
|
||||||
incremental_loading=False,
|
use_pipeline_cache=use_pipeline_cache,
|
||||||
|
incremental_loading=incremental_loading,
|
||||||
data_per_batch=data_per_batch,
|
data_per_batch=data_per_batch,
|
||||||
pipeline_name=pipeline_name,
|
pipeline_name=pipeline_name,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -12,8 +12,8 @@ logger = get_logger("get_authenticated_user")
|
||||||
|
|
||||||
# Check environment variable to determine authentication requirement
|
# Check environment variable to determine authentication requirement
|
||||||
REQUIRE_AUTHENTICATION = (
|
REQUIRE_AUTHENTICATION = (
|
||||||
os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true"
|
os.getenv("REQUIRE_AUTHENTICATION", "true").lower() == "true"
|
||||||
or backend_access_control_enabled()
|
or os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", "true").lower() == "true"
|
||||||
)
|
)
|
||||||
|
|
||||||
fastapi_users = get_fastapi_users()
|
fastapi_users = get_fastapi_users()
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from sqlalchemy import Column, DateTime, String, UUID, ForeignKey
|
from sqlalchemy import Column, DateTime, String, UUID, ForeignKey, JSON, text
|
||||||
from cognee.infrastructure.databases.relational import Base
|
from cognee.infrastructure.databases.relational import Base
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -12,17 +12,29 @@ class DatasetDatabase(Base):
|
||||||
UUID, ForeignKey("datasets.id", ondelete="CASCADE"), primary_key=True, index=True
|
UUID, ForeignKey("datasets.id", ondelete="CASCADE"), primary_key=True, index=True
|
||||||
)
|
)
|
||||||
|
|
||||||
vector_database_name = Column(String, unique=True, nullable=False)
|
vector_database_name = Column(String, unique=False, nullable=False)
|
||||||
graph_database_name = Column(String, unique=True, nullable=False)
|
graph_database_name = Column(String, unique=False, nullable=False)
|
||||||
|
|
||||||
vector_database_provider = Column(String, unique=False, nullable=False)
|
vector_database_provider = Column(String, unique=False, nullable=False)
|
||||||
graph_database_provider = Column(String, unique=False, nullable=False)
|
graph_database_provider = Column(String, unique=False, nullable=False)
|
||||||
|
|
||||||
|
graph_dataset_database_handler = Column(String, unique=False, nullable=False)
|
||||||
|
vector_dataset_database_handler = Column(String, unique=False, nullable=False)
|
||||||
|
|
||||||
vector_database_url = Column(String, unique=False, nullable=True)
|
vector_database_url = Column(String, unique=False, nullable=True)
|
||||||
graph_database_url = Column(String, unique=False, nullable=True)
|
graph_database_url = Column(String, unique=False, nullable=True)
|
||||||
|
|
||||||
vector_database_key = Column(String, unique=False, nullable=True)
|
vector_database_key = Column(String, unique=False, nullable=True)
|
||||||
graph_database_key = Column(String, unique=False, nullable=True)
|
graph_database_key = Column(String, unique=False, nullable=True)
|
||||||
|
|
||||||
|
# Configuration details for different database types. Storing them as JSON makes it more flexible to add new database types
|
||||||
|
# without changing the database schema.
|
||||||
|
graph_database_connection_info = Column(
|
||||||
|
JSON, unique=False, nullable=False, server_default=text("'{}'")
|
||||||
|
)
|
||||||
|
vector_database_connection_info = Column(
|
||||||
|
JSON, unique=False, nullable=False, server_default=text("'{}'")
|
||||||
|
)
|
||||||
|
|
||||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||||
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
|
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
|
||||||
|
|
|
||||||
|
|
@ -534,6 +534,10 @@ def setup_logging(log_level=None, name=None):
|
||||||
# Get a configured logger and log system information
|
# Get a configured logger and log system information
|
||||||
logger = structlog.get_logger(name if name else __name__)
|
logger = structlog.get_logger(name if name else __name__)
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
"From version 0.5.0 onwards, Cognee will run with multi-user access control mode set to on by default. Data isolation between different users and datasets will be enforced and data created before multi-user access control mode was turned on won't be accessible by default. To disable multi-user access control mode and regain access to old data set the environment variable ENABLE_BACKEND_ACCESS_CONTROL to false before starting Cognee. For more information, please refer to the Cognee documentation."
|
||||||
|
)
|
||||||
|
|
||||||
if logs_dir is not None:
|
if logs_dir is not None:
|
||||||
logger.info(f"Log file created at: {log_file_path}", log_file=log_file_path)
|
logger.info(f"Log file created at: {log_file_path}", log_file=log_file_path)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,2 @@
|
||||||
from .classify_documents import classify_documents
|
from .classify_documents import classify_documents
|
||||||
from .extract_chunks_from_documents import extract_chunks_from_documents
|
from .extract_chunks_from_documents import extract_chunks_from_documents
|
||||||
from .check_permissions_on_dataset import check_permissions_on_dataset
|
|
||||||
|
|
|
||||||
|
|
@ -1,26 +0,0 @@
|
||||||
from cognee.modules.data.processing.document_types import Document
|
|
||||||
from cognee.modules.users.permissions.methods import check_permission_on_dataset
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
|
|
||||||
async def check_permissions_on_dataset(
|
|
||||||
documents: List[Document], context: dict, user, permissions
|
|
||||||
) -> List[Document]:
|
|
||||||
"""
|
|
||||||
Validates a user's permissions on a list of documents.
|
|
||||||
|
|
||||||
Notes:
|
|
||||||
- This function assumes that `check_permission_on_documents` raises an exception if the permission check fails.
|
|
||||||
- It is designed to validate multiple permissions in a sequential manner for the same set of documents.
|
|
||||||
- Ensure that the `Document` and `user` objects conform to the expected structure and interfaces.
|
|
||||||
"""
|
|
||||||
|
|
||||||
for permission in permissions:
|
|
||||||
await check_permission_on_dataset(
|
|
||||||
user,
|
|
||||||
permission,
|
|
||||||
# TODO: pass dataset through argument instead of context
|
|
||||||
context["dataset"].id,
|
|
||||||
)
|
|
||||||
|
|
||||||
return documents
|
|
||||||
137
cognee/tests/test_dataset_database_handler.py
Normal file
137
cognee/tests/test_dataset_database_handler.py
Normal file
|
|
@ -0,0 +1,137 @@
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Set custom dataset database handler environment variable
|
||||||
|
os.environ["VECTOR_DATASET_DATABASE_HANDLER"] = "custom_lancedb_handler"
|
||||||
|
os.environ["GRAPH_DATASET_DATABASE_HANDLER"] = "custom_kuzu_handler"
|
||||||
|
|
||||||
|
import cognee
|
||||||
|
from cognee.modules.users.methods import get_default_user
|
||||||
|
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
|
||||||
|
from cognee.shared.logging_utils import setup_logging, ERROR
|
||||||
|
from cognee.api.v1.search import SearchType
|
||||||
|
|
||||||
|
|
||||||
|
class LanceDBTestDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||||
|
@classmethod
|
||||||
|
async def create_dataset(cls, dataset_id, user):
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
cognee_directory_path = str(
|
||||||
|
pathlib.Path(
|
||||||
|
os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".cognee_system/test_dataset_database_handler"
|
||||||
|
)
|
||||||
|
).resolve()
|
||||||
|
)
|
||||||
|
databases_directory_path = os.path.join(cognee_directory_path, "databases", str(user.id))
|
||||||
|
os.makedirs(databases_directory_path, exist_ok=True)
|
||||||
|
|
||||||
|
vector_db_name = "test.lance.db"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"vector_dataset_database_handler": "custom_lancedb_handler",
|
||||||
|
"vector_database_name": vector_db_name,
|
||||||
|
"vector_database_url": os.path.join(databases_directory_path, vector_db_name),
|
||||||
|
"vector_database_provider": "lancedb",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class KuzuTestDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||||
|
@classmethod
|
||||||
|
async def create_dataset(cls, dataset_id, user):
|
||||||
|
databases_directory_path = os.path.join("databases", str(user.id))
|
||||||
|
os.makedirs(databases_directory_path, exist_ok=True)
|
||||||
|
|
||||||
|
graph_db_name = "test.kuzu"
|
||||||
|
return {
|
||||||
|
"graph_dataset_database_handler": "custom_kuzu_handler",
|
||||||
|
"graph_database_name": graph_db_name,
|
||||||
|
"graph_database_url": os.path.join(databases_directory_path, graph_db_name),
|
||||||
|
"graph_database_provider": "kuzu",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
data_directory_path = str(
|
||||||
|
pathlib.Path(
|
||||||
|
os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".data_storage/test_dataset_database_handler"
|
||||||
|
)
|
||||||
|
).resolve()
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
cognee_directory_path = str(
|
||||||
|
pathlib.Path(
|
||||||
|
os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".cognee_system/test_dataset_database_handler"
|
||||||
|
)
|
||||||
|
).resolve()
|
||||||
|
)
|
||||||
|
cognee.config.system_root_directory(cognee_directory_path)
|
||||||
|
|
||||||
|
# Add custom dataset database handler
|
||||||
|
from cognee.infrastructure.databases.dataset_database_handler.use_dataset_database_handler import (
|
||||||
|
use_dataset_database_handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
use_dataset_database_handler(
|
||||||
|
"custom_lancedb_handler", LanceDBTestDatasetDatabaseHandler, "lancedb"
|
||||||
|
)
|
||||||
|
use_dataset_database_handler("custom_kuzu_handler", KuzuTestDatasetDatabaseHandler, "kuzu")
|
||||||
|
|
||||||
|
# Create a clean slate for cognee -- reset data and system state
|
||||||
|
print("Resetting cognee data...")
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
print("Data reset complete.\n")
|
||||||
|
|
||||||
|
# cognee knowledge graph will be created based on this text
|
||||||
|
text = """
|
||||||
|
Natural language processing (NLP) is an interdisciplinary
|
||||||
|
subfield of computer science and information retrieval.
|
||||||
|
"""
|
||||||
|
|
||||||
|
print("Adding text to cognee:")
|
||||||
|
print(text.strip())
|
||||||
|
|
||||||
|
# Add the text, and make it available for cognify
|
||||||
|
await cognee.add(text)
|
||||||
|
print("Text added successfully.\n")
|
||||||
|
|
||||||
|
# Use LLMs and cognee to create knowledge graph
|
||||||
|
await cognee.cognify()
|
||||||
|
print("Cognify process complete.\n")
|
||||||
|
|
||||||
|
query_text = "Tell me about NLP"
|
||||||
|
print(f"Searching cognee for insights with query: '{query_text}'")
|
||||||
|
# Query cognee for insights on the added text
|
||||||
|
search_results = await cognee.search(
|
||||||
|
query_type=SearchType.GRAPH_COMPLETION, query_text=query_text
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Search results:")
|
||||||
|
# Display results
|
||||||
|
for result_text in search_results:
|
||||||
|
print(result_text)
|
||||||
|
|
||||||
|
default_user = await get_default_user()
|
||||||
|
# Assert that the custom database files were created based on the custom dataset database handlers
|
||||||
|
assert os.path.exists(
|
||||||
|
os.path.join(cognee_directory_path, "databases", str(default_user.id), "test.kuzu")
|
||||||
|
), "Graph database file not found."
|
||||||
|
assert os.path.exists(
|
||||||
|
os.path.join(cognee_directory_path, "databases", str(default_user.id), "test.lance.db")
|
||||||
|
), "Vector database file not found."
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logger = setup_logging(log_level=ERROR)
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
try:
|
||||||
|
loop.run_until_complete(main())
|
||||||
|
finally:
|
||||||
|
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||||
164
cognee/tests/test_pipeline_cache.py
Normal file
164
cognee/tests/test_pipeline_cache.py
Normal file
|
|
@ -0,0 +1,164 @@
|
||||||
|
"""
|
||||||
|
Test suite for the pipeline_cache feature in Cognee pipelines.
|
||||||
|
|
||||||
|
This module tests the behavior of the `use_pipeline_cache` parameter, which controls
|
||||||
|
whether a pipeline should skip re-execution when it has already been completed
|
||||||
|
for the same dataset.
|
||||||
|
|
||||||
|
Architecture Overview:
|
||||||
|
---------------------
|
||||||
|
The pipeline_cache mechanism works at the dataset level:
|
||||||
|
1. When a pipeline runs, it logs its status (INITIATED -> STARTED -> COMPLETED)
|
||||||
|
2. Before each run, `check_pipeline_run_qualification()` checks the pipeline status
|
||||||
|
3. If `use_pipeline_cache=True` and status is COMPLETED/STARTED, the pipeline skips
|
||||||
|
4. If `use_pipeline_cache=False`, the pipeline always re-executes regardless of status
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import cognee
|
||||||
|
from cognee.modules.pipelines.tasks.task import Task
|
||||||
|
from cognee.modules.pipelines import run_pipeline
|
||||||
|
from cognee.modules.users.methods import get_default_user
|
||||||
|
|
||||||
|
from cognee.modules.pipelines.layers.reset_dataset_pipeline_run_status import (
|
||||||
|
reset_dataset_pipeline_run_status,
|
||||||
|
)
|
||||||
|
from cognee.infrastructure.databases.relational import create_db_and_tables
|
||||||
|
|
||||||
|
|
||||||
|
class ExecutionCounter:
|
||||||
|
"""Helper class to track task execution counts."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.count = 0
|
||||||
|
|
||||||
|
|
||||||
|
async def create_counting_task(data, counter: ExecutionCounter):
|
||||||
|
"""Task that increments the count on the given ExecutionCounter instance each time it is executed."""
|
||||||
|
counter.count += 1
|
||||||
|
return counter
|
||||||
|
|
||||||
|
|
||||||
|
class TestPipelineCache:
|
||||||
|
"""Tests for basic pipeline_cache on/off behavior."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pipeline_cache_off_allows_reexecution(self):
|
||||||
|
"""
|
||||||
|
Test that with use_pipeline_cache=False, the pipeline re-executes
|
||||||
|
even when it has already completed for the dataset.
|
||||||
|
|
||||||
|
Expected behavior:
|
||||||
|
- First run: Pipeline executes fully, task runs once
|
||||||
|
- Second run: Pipeline executes again, task runs again (total: 2 times)
|
||||||
|
"""
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await create_db_and_tables()
|
||||||
|
|
||||||
|
counter = ExecutionCounter()
|
||||||
|
user = await get_default_user()
|
||||||
|
|
||||||
|
tasks = [Task(create_counting_task, counter=counter)]
|
||||||
|
|
||||||
|
# First run
|
||||||
|
pipeline_results_1 = []
|
||||||
|
async for result in run_pipeline(
|
||||||
|
tasks=tasks,
|
||||||
|
datasets="test_dataset_cache_off",
|
||||||
|
data=["sample data"], # Data is necessary to trigger processing
|
||||||
|
user=user,
|
||||||
|
pipeline_name="test_cache_off_pipeline",
|
||||||
|
use_pipeline_cache=False,
|
||||||
|
):
|
||||||
|
pipeline_results_1.append(result)
|
||||||
|
|
||||||
|
first_run_count = counter.count
|
||||||
|
assert first_run_count >= 1, "Task should have executed at least once on first run"
|
||||||
|
|
||||||
|
# Second run with use_pipeline_cache=False
|
||||||
|
pipeline_results_2 = []
|
||||||
|
async for result in run_pipeline(
|
||||||
|
tasks=tasks,
|
||||||
|
datasets="test_dataset_cache_off",
|
||||||
|
data=["sample data"], # Data is necessary to trigger processing
|
||||||
|
user=user,
|
||||||
|
pipeline_name="test_cache_off_pipeline",
|
||||||
|
use_pipeline_cache=False,
|
||||||
|
):
|
||||||
|
pipeline_results_2.append(result)
|
||||||
|
|
||||||
|
second_run_count = counter.count
|
||||||
|
assert second_run_count > first_run_count, (
|
||||||
|
f"With pipeline_cache=False, task should re-execute. "
|
||||||
|
f"First run: {first_run_count}, After second run: {second_run_count}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reset_pipeline_status_allows_reexecution_with_cache(self):
|
||||||
|
"""
|
||||||
|
Test that resetting pipeline status allows re-execution even with
|
||||||
|
use_pipeline_cache=True.
|
||||||
|
"""
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await create_db_and_tables()
|
||||||
|
|
||||||
|
counter = ExecutionCounter()
|
||||||
|
user = await get_default_user()
|
||||||
|
dataset_name = "reset_status_test"
|
||||||
|
pipeline_name = "test_reset_pipeline"
|
||||||
|
|
||||||
|
tasks = [Task(create_counting_task, counter=counter)]
|
||||||
|
|
||||||
|
# First run
|
||||||
|
pipeline_result = []
|
||||||
|
async for result in run_pipeline(
|
||||||
|
tasks=tasks,
|
||||||
|
datasets=dataset_name,
|
||||||
|
user=user,
|
||||||
|
data=["sample data"], # Data is necessary to trigger processing
|
||||||
|
pipeline_name=pipeline_name,
|
||||||
|
use_pipeline_cache=True,
|
||||||
|
):
|
||||||
|
pipeline_result.append(result)
|
||||||
|
|
||||||
|
first_run_count = counter.count
|
||||||
|
assert first_run_count >= 1
|
||||||
|
|
||||||
|
# Second run without reset - should skip
|
||||||
|
async for _ in run_pipeline(
|
||||||
|
tasks=tasks,
|
||||||
|
datasets=dataset_name,
|
||||||
|
user=user,
|
||||||
|
data=["sample data"], # Data is necessary to trigger processing
|
||||||
|
pipeline_name=pipeline_name,
|
||||||
|
use_pipeline_cache=True,
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
after_second_run = counter.count
|
||||||
|
assert after_second_run == first_run_count, "Should have skipped due to cache"
|
||||||
|
|
||||||
|
# Reset the pipeline status
|
||||||
|
await reset_dataset_pipeline_run_status(
|
||||||
|
pipeline_result[0].dataset_id, user, pipeline_names=[pipeline_name]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Third run after reset - should execute
|
||||||
|
async for _ in run_pipeline(
|
||||||
|
tasks=tasks,
|
||||||
|
datasets=dataset_name,
|
||||||
|
user=user,
|
||||||
|
data=["sample data"], # Data is necessary to trigger processing
|
||||||
|
pipeline_name=pipeline_name,
|
||||||
|
use_pipeline_cache=True,
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
after_reset_run = counter.count
|
||||||
|
assert after_reset_run > after_second_run, (
|
||||||
|
f"After reset, pipeline should re-execute. "
|
||||||
|
f"Before reset: {after_second_run}, After reset run: {after_reset_run}"
|
||||||
|
)
|
||||||
|
|
@ -32,16 +32,13 @@ async def main():
|
||||||
print("Cognify process steps:")
|
print("Cognify process steps:")
|
||||||
print("1. Classifying the document: Determining the type and category of the input text.")
|
print("1. Classifying the document: Determining the type and category of the input text.")
|
||||||
print(
|
print(
|
||||||
"2. Checking permissions: Ensuring the user has the necessary rights to process the text."
|
"2. Extracting text chunks: Breaking down the text into sentences or phrases for analysis."
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
"3. Extracting text chunks: Breaking down the text into sentences or phrases for analysis."
|
"3. Generating knowledge graph: Extracting entities and relationships to form a knowledge graph."
|
||||||
)
|
)
|
||||||
print("4. Adding data points: Storing the extracted chunks for processing.")
|
print("4. Summarizing text: Creating concise summaries of the content for quick insights.")
|
||||||
print(
|
print("5. Adding data points: Storing the extracted chunks for processing.\n")
|
||||||
"5. Generating knowledge graph: Extracting entities and relationships to form a knowledge graph."
|
|
||||||
)
|
|
||||||
print("6. Summarizing text: Creating concise summaries of the content for quick insights.\n")
|
|
||||||
|
|
||||||
# Use LLMs and cognee to create knowledge graph
|
# Use LLMs and cognee to create knowledge graph
|
||||||
await cognee.cognify()
|
await cognee.cognify()
|
||||||
|
|
|
||||||
4
notebooks/cognee_demo.ipynb
vendored
4
notebooks/cognee_demo.ipynb
vendored
|
|
@ -591,7 +591,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 11,
|
"execution_count": null,
|
||||||
"id": "7c431fdef4921ae0",
|
"id": "7c431fdef4921ae0",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"ExecuteTime": {
|
"ExecuteTime": {
|
||||||
|
|
@ -609,7 +609,6 @@
|
||||||
"from cognee.modules.pipelines import run_tasks\n",
|
"from cognee.modules.pipelines import run_tasks\n",
|
||||||
"from cognee.modules.users.models import User\n",
|
"from cognee.modules.users.models import User\n",
|
||||||
"from cognee.tasks.documents import (\n",
|
"from cognee.tasks.documents import (\n",
|
||||||
" check_permissions_on_dataset,\n",
|
|
||||||
" classify_documents,\n",
|
" classify_documents,\n",
|
||||||
" extract_chunks_from_documents,\n",
|
" extract_chunks_from_documents,\n",
|
||||||
")\n",
|
")\n",
|
||||||
|
|
@ -627,7 +626,6 @@
|
||||||
"\n",
|
"\n",
|
||||||
" tasks = [\n",
|
" tasks = [\n",
|
||||||
" Task(classify_documents),\n",
|
" Task(classify_documents),\n",
|
||||||
" Task(check_permissions_on_dataset, user=user, permissions=[\"write\"]),\n",
|
|
||||||
" Task(\n",
|
" Task(\n",
|
||||||
" extract_chunks_from_documents, max_chunk_size=get_max_chunk_tokens()\n",
|
" extract_chunks_from_documents, max_chunk_size=get_max_chunk_tokens()\n",
|
||||||
" ), # Extract text chunks based on the document type.\n",
|
" ), # Extract text chunks based on the document type.\n",
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue