feat: add dataset database handler logic and neo4j/lancedb/kuzu handlers (#1776)
<!-- .github/pull_request_template.md -->
## Description
Add ability to use multi tenant multi user mode with Neo4j
## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [ ] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):
## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->
## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [ ] **I have tested my changes thoroughly before submitting this PR**
- [ ] **This PR contains minimal changes necessary to address the
issue/feature**
- [ ] My code follows the project's coding standards and style
guidelines
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages
## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
## Release Notes
* **New Features**
* Multi-user support with per-dataset database isolation enabled by
default, allowing backend access control for secure data separation.
* Configurable database handlers via environment variables
(GRAPH_DATASET_DATABASE_HANDLER, VECTOR_DATASET_DATABASE_HANDLER) for
flexible deployment options.
* **Chores**
* Database schema migration to support per-user dataset database
configurations.
<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
commit
46ddd4fd12
31 changed files with 1143 additions and 74 deletions
|
|
@ -97,6 +97,8 @@ DB_NAME=cognee_db
|
||||||
|
|
||||||
# Default (local file-based)
|
# Default (local file-based)
|
||||||
GRAPH_DATABASE_PROVIDER="kuzu"
|
GRAPH_DATABASE_PROVIDER="kuzu"
|
||||||
|
# Handler for multi-user access control mode, it handles how should the mapping/creation of separate DBs be handled per Cognee dataset
|
||||||
|
GRAPH_DATASET_DATABASE_HANDLER="kuzu"
|
||||||
|
|
||||||
# -- To switch to Remote Kuzu uncomment and fill these: -------------------------------------------------------------
|
# -- To switch to Remote Kuzu uncomment and fill these: -------------------------------------------------------------
|
||||||
#GRAPH_DATABASE_PROVIDER="kuzu"
|
#GRAPH_DATABASE_PROVIDER="kuzu"
|
||||||
|
|
@ -121,6 +123,8 @@ VECTOR_DB_PROVIDER="lancedb"
|
||||||
# Not needed if a cloud vector database is not used
|
# Not needed if a cloud vector database is not used
|
||||||
VECTOR_DB_URL=
|
VECTOR_DB_URL=
|
||||||
VECTOR_DB_KEY=
|
VECTOR_DB_KEY=
|
||||||
|
# Handler for multi-user access control mode, it handles how should the mapping/creation of separate DBs be handled per Cognee dataset
|
||||||
|
VECTOR_DATASET_DATABASE_HANDLER="lancedb"
|
||||||
|
|
||||||
################################################################################
|
################################################################################
|
||||||
# 🧩 Ontology resolver settings
|
# 🧩 Ontology resolver settings
|
||||||
|
|
|
||||||
2
.github/workflows/db_examples_tests.yml
vendored
2
.github/workflows/db_examples_tests.yml
vendored
|
|
@ -61,6 +61,7 @@ jobs:
|
||||||
- name: Run Neo4j Example
|
- name: Run Neo4j Example
|
||||||
env:
|
env:
|
||||||
ENV: 'dev'
|
ENV: 'dev'
|
||||||
|
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
|
|
@ -142,6 +143,7 @@ jobs:
|
||||||
- name: Run PGVector Example
|
- name: Run PGVector Example
|
||||||
env:
|
env:
|
||||||
ENV: 'dev'
|
ENV: 'dev'
|
||||||
|
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
|
|
|
||||||
1
.github/workflows/distributed_test.yml
vendored
1
.github/workflows/distributed_test.yml
vendored
|
|
@ -47,6 +47,7 @@ jobs:
|
||||||
- name: Run Distributed Cognee (Modal)
|
- name: Run Distributed Cognee (Modal)
|
||||||
env:
|
env:
|
||||||
ENV: 'dev'
|
ENV: 'dev'
|
||||||
|
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
|
|
|
||||||
26
.github/workflows/e2e_tests.yml
vendored
26
.github/workflows/e2e_tests.yml
vendored
|
|
@ -147,6 +147,7 @@ jobs:
|
||||||
- name: Run Deduplication Example
|
- name: Run Deduplication Example
|
||||||
env:
|
env:
|
||||||
ENV: 'dev'
|
ENV: 'dev'
|
||||||
|
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }} # Test needs OpenAI endpoint to handle multimedia
|
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }} # Test needs OpenAI endpoint to handle multimedia
|
||||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||||
|
|
@ -211,6 +212,31 @@ jobs:
|
||||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||||
run: uv run python ./cognee/tests/test_parallel_databases.py
|
run: uv run python ./cognee/tests/test_parallel_databases.py
|
||||||
|
|
||||||
|
test-dataset-database-handler:
|
||||||
|
name: Test dataset database handlers in Cognee
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Check out repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Cognee Setup
|
||||||
|
uses: ./.github/actions/cognee_setup
|
||||||
|
with:
|
||||||
|
python-version: '3.11.x'
|
||||||
|
|
||||||
|
- name: Run dataset databases handler test
|
||||||
|
env:
|
||||||
|
ENV: 'dev'
|
||||||
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
|
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||||
|
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||||
|
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||||
|
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||||
|
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||||
|
run: uv run python ./cognee/tests/test_dataset_database_handler.py
|
||||||
|
|
||||||
test-permissions:
|
test-permissions:
|
||||||
name: Test permissions with different situations in Cognee
|
name: Test permissions with different situations in Cognee
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
|
|
|
||||||
1
.github/workflows/examples_tests.yml
vendored
1
.github/workflows/examples_tests.yml
vendored
|
|
@ -72,6 +72,7 @@ jobs:
|
||||||
- name: Run Descriptive Graph Metrics Example
|
- name: Run Descriptive Graph Metrics Example
|
||||||
env:
|
env:
|
||||||
ENV: 'dev'
|
ENV: 'dev'
|
||||||
|
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
|
|
|
||||||
1
.github/workflows/graph_db_tests.yml
vendored
1
.github/workflows/graph_db_tests.yml
vendored
|
|
@ -78,6 +78,7 @@ jobs:
|
||||||
- name: Run default Neo4j
|
- name: Run default Neo4j
|
||||||
env:
|
env:
|
||||||
ENV: 'dev'
|
ENV: 'dev'
|
||||||
|
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
|
|
|
||||||
3
.github/workflows/temporal_graph_tests.yml
vendored
3
.github/workflows/temporal_graph_tests.yml
vendored
|
|
@ -72,6 +72,7 @@ jobs:
|
||||||
- name: Run Temporal Graph with Neo4j (lancedb + sqlite)
|
- name: Run Temporal Graph with Neo4j (lancedb + sqlite)
|
||||||
env:
|
env:
|
||||||
ENV: 'dev'
|
ENV: 'dev'
|
||||||
|
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||||
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
|
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
|
||||||
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
|
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
|
||||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
|
|
@ -123,6 +124,7 @@ jobs:
|
||||||
- name: Run Temporal Graph with Kuzu (postgres + pgvector)
|
- name: Run Temporal Graph with Kuzu (postgres + pgvector)
|
||||||
env:
|
env:
|
||||||
ENV: dev
|
ENV: dev
|
||||||
|
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||||
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
|
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
|
||||||
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
|
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
|
||||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
|
|
@ -189,6 +191,7 @@ jobs:
|
||||||
- name: Run Temporal Graph with Neo4j (postgres + pgvector)
|
- name: Run Temporal Graph with Neo4j (postgres + pgvector)
|
||||||
env:
|
env:
|
||||||
ENV: dev
|
ENV: dev
|
||||||
|
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||||
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
|
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
|
||||||
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
|
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
|
||||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
|
|
|
||||||
3
.github/workflows/vector_db_tests.yml
vendored
3
.github/workflows/vector_db_tests.yml
vendored
|
|
@ -92,6 +92,7 @@ jobs:
|
||||||
- name: Run PGVector Tests
|
- name: Run PGVector Tests
|
||||||
env:
|
env:
|
||||||
ENV: 'dev'
|
ENV: 'dev'
|
||||||
|
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
|
|
@ -127,4 +128,4 @@ jobs:
|
||||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||||
run: uv run python ./cognee/tests/test_lancedb.py
|
run: uv run python ./cognee/tests/test_lancedb.py
|
||||||
|
|
|
||||||
3
.github/workflows/weighted_edges_tests.yml
vendored
3
.github/workflows/weighted_edges_tests.yml
vendored
|
|
@ -94,6 +94,7 @@ jobs:
|
||||||
|
|
||||||
- name: Run Weighted Edges Tests
|
- name: Run Weighted Edges Tests
|
||||||
env:
|
env:
|
||||||
|
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||||
GRAPH_DATABASE_PROVIDER: ${{ matrix.graph_db_provider }}
|
GRAPH_DATABASE_PROVIDER: ${{ matrix.graph_db_provider }}
|
||||||
GRAPH_DATABASE_URL: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-url || '' }}
|
GRAPH_DATABASE_URL: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-url || '' }}
|
||||||
GRAPH_DATABASE_USERNAME: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-username || '' }}
|
GRAPH_DATABASE_USERNAME: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-username || '' }}
|
||||||
|
|
@ -165,5 +166,3 @@ jobs:
|
||||||
uses: astral-sh/ruff-action@v2
|
uses: astral-sh/ruff-action@v2
|
||||||
with:
|
with:
|
||||||
args: "format --check cognee/modules/graph/utils/get_graph_from_model.py cognee/tests/unit/interfaces/graph/test_weighted_edges.py examples/python/weighted_edges_example.py"
|
args: "format --check cognee/modules/graph/utils/get_graph_from_model.py cognee/tests/unit/interfaces/graph/test_weighted_edges.py examples/python/weighted_edges_example.py"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1,267 @@
|
||||||
|
"""Expand dataset database with json connection field
|
||||||
|
|
||||||
|
Revision ID: 46a6ce2bd2b2
|
||||||
|
Revises: 76625596c5c3
|
||||||
|
Create Date: 2025-11-25 17:56:28.938931
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "46a6ce2bd2b2"
|
||||||
|
down_revision: Union[str, None] = "76625596c5c3"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
graph_constraint_name = "dataset_database_graph_database_name_key"
|
||||||
|
vector_constraint_name = "dataset_database_vector_database_name_key"
|
||||||
|
TABLE_NAME = "dataset_database"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_column(inspector, table, name, schema=None):
|
||||||
|
for col in inspector.get_columns(table, schema=schema):
|
||||||
|
if col["name"] == name:
|
||||||
|
return col
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _recreate_table_without_unique_constraint_sqlite(op, insp):
|
||||||
|
"""
|
||||||
|
SQLite cannot drop unique constraints on individual columns. We must:
|
||||||
|
1. Create a new table without the unique constraints.
|
||||||
|
2. Copy data from the old table.
|
||||||
|
3. Drop the old table.
|
||||||
|
4. Rename the new table.
|
||||||
|
"""
|
||||||
|
conn = op.get_bind()
|
||||||
|
|
||||||
|
# Create new table definition (without unique constraints)
|
||||||
|
op.create_table(
|
||||||
|
f"{TABLE_NAME}_new",
|
||||||
|
sa.Column("owner_id", sa.UUID()),
|
||||||
|
sa.Column("dataset_id", sa.UUID(), primary_key=True, nullable=False),
|
||||||
|
sa.Column("vector_database_name", sa.String(), nullable=False),
|
||||||
|
sa.Column("graph_database_name", sa.String(), nullable=False),
|
||||||
|
sa.Column("vector_database_provider", sa.String(), nullable=False),
|
||||||
|
sa.Column("graph_database_provider", sa.String(), nullable=False),
|
||||||
|
sa.Column("vector_database_url", sa.String()),
|
||||||
|
sa.Column("graph_database_url", sa.String()),
|
||||||
|
sa.Column("vector_database_key", sa.String()),
|
||||||
|
sa.Column("graph_database_key", sa.String()),
|
||||||
|
sa.Column(
|
||||||
|
"graph_database_connection_info",
|
||||||
|
sa.JSON(),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("'{}'"),
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"vector_database_connection_info",
|
||||||
|
sa.JSON(),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("'{}'"),
|
||||||
|
),
|
||||||
|
sa.Column("created_at", sa.DateTime()),
|
||||||
|
sa.Column("updated_at", sa.DateTime()),
|
||||||
|
sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"], ondelete="CASCADE"),
|
||||||
|
sa.ForeignKeyConstraint(["owner_id"], ["principals.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Copy data into new table
|
||||||
|
conn.execute(
|
||||||
|
sa.text(f"""
|
||||||
|
INSERT INTO {TABLE_NAME}_new
|
||||||
|
SELECT
|
||||||
|
owner_id,
|
||||||
|
dataset_id,
|
||||||
|
vector_database_name,
|
||||||
|
graph_database_name,
|
||||||
|
vector_database_provider,
|
||||||
|
graph_database_provider,
|
||||||
|
vector_database_url,
|
||||||
|
graph_database_url,
|
||||||
|
vector_database_key,
|
||||||
|
graph_database_key,
|
||||||
|
COALESCE(graph_database_connection_info, '{{}}'),
|
||||||
|
COALESCE(vector_database_connection_info, '{{}}'),
|
||||||
|
created_at,
|
||||||
|
updated_at
|
||||||
|
FROM {TABLE_NAME}
|
||||||
|
""")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Drop old table
|
||||||
|
op.drop_table(TABLE_NAME)
|
||||||
|
|
||||||
|
# Rename new table
|
||||||
|
op.rename_table(f"{TABLE_NAME}_new", TABLE_NAME)
|
||||||
|
|
||||||
|
|
||||||
|
def _recreate_table_with_unique_constraint_sqlite(op, insp):
|
||||||
|
"""
|
||||||
|
SQLite cannot drop unique constraints on individual columns. We must:
|
||||||
|
1. Create a new table without the unique constraints.
|
||||||
|
2. Copy data from the old table.
|
||||||
|
3. Drop the old table.
|
||||||
|
4. Rename the new table.
|
||||||
|
"""
|
||||||
|
conn = op.get_bind()
|
||||||
|
|
||||||
|
# Create new table definition (without unique constraints)
|
||||||
|
op.create_table(
|
||||||
|
f"{TABLE_NAME}_new",
|
||||||
|
sa.Column("owner_id", sa.UUID()),
|
||||||
|
sa.Column("dataset_id", sa.UUID(), primary_key=True, nullable=False),
|
||||||
|
sa.Column("vector_database_name", sa.String(), nullable=False, unique=True),
|
||||||
|
sa.Column("graph_database_name", sa.String(), nullable=False, unique=True),
|
||||||
|
sa.Column("vector_database_provider", sa.String(), nullable=False),
|
||||||
|
sa.Column("graph_database_provider", sa.String(), nullable=False),
|
||||||
|
sa.Column("vector_database_url", sa.String()),
|
||||||
|
sa.Column("graph_database_url", sa.String()),
|
||||||
|
sa.Column("vector_database_key", sa.String()),
|
||||||
|
sa.Column("graph_database_key", sa.String()),
|
||||||
|
sa.Column(
|
||||||
|
"graph_database_connection_info",
|
||||||
|
sa.JSON(),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("'{}'"),
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"vector_database_connection_info",
|
||||||
|
sa.JSON(),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("'{}'"),
|
||||||
|
),
|
||||||
|
sa.Column("created_at", sa.DateTime()),
|
||||||
|
sa.Column("updated_at", sa.DateTime()),
|
||||||
|
sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"], ondelete="CASCADE"),
|
||||||
|
sa.ForeignKeyConstraint(["owner_id"], ["principals.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Copy data into new table
|
||||||
|
conn.execute(
|
||||||
|
sa.text(f"""
|
||||||
|
INSERT INTO {TABLE_NAME}_new
|
||||||
|
SELECT
|
||||||
|
owner_id,
|
||||||
|
dataset_id,
|
||||||
|
vector_database_name,
|
||||||
|
graph_database_name,
|
||||||
|
vector_database_provider,
|
||||||
|
graph_database_provider,
|
||||||
|
vector_database_url,
|
||||||
|
graph_database_url,
|
||||||
|
vector_database_key,
|
||||||
|
graph_database_key,
|
||||||
|
COALESCE(graph_database_connection_info, '{{}}'),
|
||||||
|
COALESCE(vector_database_connection_info, '{{}}'),
|
||||||
|
created_at,
|
||||||
|
updated_at
|
||||||
|
FROM {TABLE_NAME}
|
||||||
|
""")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Drop old table
|
||||||
|
op.drop_table(TABLE_NAME)
|
||||||
|
|
||||||
|
# Rename new table
|
||||||
|
op.rename_table(f"{TABLE_NAME}_new", TABLE_NAME)
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
conn = op.get_bind()
|
||||||
|
insp = sa.inspect(conn)
|
||||||
|
|
||||||
|
unique_constraints = insp.get_unique_constraints(TABLE_NAME)
|
||||||
|
|
||||||
|
vector_database_connection_info_column = _get_column(
|
||||||
|
insp, "dataset_database", "vector_database_connection_info"
|
||||||
|
)
|
||||||
|
if not vector_database_connection_info_column:
|
||||||
|
op.add_column(
|
||||||
|
"dataset_database",
|
||||||
|
sa.Column(
|
||||||
|
"vector_database_connection_info",
|
||||||
|
sa.JSON(),
|
||||||
|
unique=False,
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("'{}'"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
graph_database_connection_info_column = _get_column(
|
||||||
|
insp, "dataset_database", "graph_database_connection_info"
|
||||||
|
)
|
||||||
|
if not graph_database_connection_info_column:
|
||||||
|
op.add_column(
|
||||||
|
"dataset_database",
|
||||||
|
sa.Column(
|
||||||
|
"graph_database_connection_info",
|
||||||
|
sa.JSON(),
|
||||||
|
unique=False,
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("'{}'"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
with op.batch_alter_table("dataset_database", schema=None) as batch_op:
|
||||||
|
# Drop the unique constraint to make unique=False
|
||||||
|
graph_constraint_to_drop = None
|
||||||
|
for uc in unique_constraints:
|
||||||
|
# Check if the constraint covers ONLY the target column
|
||||||
|
if uc["name"] == graph_constraint_name:
|
||||||
|
graph_constraint_to_drop = uc["name"]
|
||||||
|
break
|
||||||
|
|
||||||
|
vector_constraint_to_drop = None
|
||||||
|
for uc in unique_constraints:
|
||||||
|
# Check if the constraint covers ONLY the target column
|
||||||
|
if uc["name"] == vector_constraint_name:
|
||||||
|
vector_constraint_to_drop = uc["name"]
|
||||||
|
break
|
||||||
|
|
||||||
|
if (
|
||||||
|
vector_constraint_to_drop
|
||||||
|
and graph_constraint_to_drop
|
||||||
|
and op.get_context().dialect.name == "postgresql"
|
||||||
|
):
|
||||||
|
# PostgreSQL
|
||||||
|
batch_op.drop_constraint(graph_constraint_name, type_="unique")
|
||||||
|
batch_op.drop_constraint(vector_constraint_name, type_="unique")
|
||||||
|
|
||||||
|
if op.get_context().dialect.name == "sqlite":
|
||||||
|
conn = op.get_bind()
|
||||||
|
# Fun fact: SQLite has hidden auto indexes for unique constraints that can't be dropped or accessed directly
|
||||||
|
# So we need to check for them and drop them by recreating the table (altering column also won't work)
|
||||||
|
result = conn.execute(sa.text("PRAGMA index_list('dataset_database')"))
|
||||||
|
rows = result.fetchall()
|
||||||
|
unique_auto_indexes = [row for row in rows if row[3] == "u"]
|
||||||
|
for row in unique_auto_indexes:
|
||||||
|
result = conn.execute(sa.text(f"PRAGMA index_info('{row[1]}')"))
|
||||||
|
index_info = result.fetchall()
|
||||||
|
if index_info[0][2] == "vector_database_name":
|
||||||
|
# In case a unique index exists on vector_database_name, drop it and the graph_database_name one
|
||||||
|
_recreate_table_without_unique_constraint_sqlite(op, insp)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
conn = op.get_bind()
|
||||||
|
insp = sa.inspect(conn)
|
||||||
|
|
||||||
|
if op.get_context().dialect.name == "sqlite":
|
||||||
|
_recreate_table_with_unique_constraint_sqlite(op, insp)
|
||||||
|
elif op.get_context().dialect.name == "postgresql":
|
||||||
|
with op.batch_alter_table("dataset_database", schema=None) as batch_op:
|
||||||
|
# Re-add the unique constraint to return to unique=True
|
||||||
|
batch_op.create_unique_constraint(graph_constraint_name, ["graph_database_name"])
|
||||||
|
|
||||||
|
with op.batch_alter_table("dataset_database", schema=None) as batch_op:
|
||||||
|
# Re-add the unique constraint to return to unique=True
|
||||||
|
batch_op.create_unique_constraint(vector_constraint_name, ["vector_database_name"])
|
||||||
|
|
||||||
|
op.drop_column("dataset_database", "vector_database_connection_info")
|
||||||
|
op.drop_column("dataset_database", "graph_database_connection_info")
|
||||||
|
|
@ -4,9 +4,10 @@ from typing import Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from cognee.base_config import get_base_config
|
from cognee.base_config import get_base_config
|
||||||
from cognee.infrastructure.databases.vector.config import get_vectordb_context_config
|
from cognee.infrastructure.databases.vector.config import get_vectordb_config
|
||||||
from cognee.infrastructure.databases.graph.config import get_graph_context_config
|
from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||||
from cognee.infrastructure.databases.utils import get_or_create_dataset_database
|
from cognee.infrastructure.databases.utils import get_or_create_dataset_database
|
||||||
|
from cognee.infrastructure.databases.utils import resolve_dataset_database_connection_info
|
||||||
from cognee.infrastructure.files.storage.config import file_storage_config
|
from cognee.infrastructure.files.storage.config import file_storage_config
|
||||||
from cognee.modules.users.methods import get_user
|
from cognee.modules.users.methods import get_user
|
||||||
|
|
||||||
|
|
@ -16,22 +17,59 @@ vector_db_config = ContextVar("vector_db_config", default=None)
|
||||||
graph_db_config = ContextVar("graph_db_config", default=None)
|
graph_db_config = ContextVar("graph_db_config", default=None)
|
||||||
session_user = ContextVar("session_user", default=None)
|
session_user = ContextVar("session_user", default=None)
|
||||||
|
|
||||||
VECTOR_DBS_WITH_MULTI_USER_SUPPORT = ["lancedb", "falkor"]
|
|
||||||
GRAPH_DBS_WITH_MULTI_USER_SUPPORT = ["kuzu", "falkor"]
|
|
||||||
|
|
||||||
|
|
||||||
async def set_session_user_context_variable(user):
|
async def set_session_user_context_variable(user):
|
||||||
session_user.set(user)
|
session_user.set(user)
|
||||||
|
|
||||||
|
|
||||||
def multi_user_support_possible():
|
def multi_user_support_possible():
|
||||||
graph_db_config = get_graph_context_config()
|
graph_db_config = get_graph_config()
|
||||||
vector_db_config = get_vectordb_context_config()
|
vector_db_config = get_vectordb_config()
|
||||||
return (
|
|
||||||
graph_db_config["graph_database_provider"] in GRAPH_DBS_WITH_MULTI_USER_SUPPORT
|
graph_handler = graph_db_config.graph_dataset_database_handler
|
||||||
and vector_db_config["vector_db_provider"] in VECTOR_DBS_WITH_MULTI_USER_SUPPORT
|
vector_handler = vector_db_config.vector_dataset_database_handler
|
||||||
|
from cognee.infrastructure.databases.dataset_database_handler import (
|
||||||
|
supported_dataset_database_handlers,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if graph_handler not in supported_dataset_database_handlers:
|
||||||
|
raise EnvironmentError(
|
||||||
|
"Unsupported graph dataset to database handler configured. Cannot add support for multi-user access control mode. Please use a supported graph dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
|
||||||
|
f"Selected graph dataset to database handler: {graph_handler}\n"
|
||||||
|
f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
if vector_handler not in supported_dataset_database_handlers:
|
||||||
|
raise EnvironmentError(
|
||||||
|
"Unsupported vector dataset to database handler configured. Cannot add support for multi-user access control mode. Please use a supported vector dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
|
||||||
|
f"Selected vector dataset to database handler: {vector_handler}\n"
|
||||||
|
f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
supported_dataset_database_handlers[graph_handler]["handler_provider"]
|
||||||
|
!= graph_db_config.graph_database_provider
|
||||||
|
):
|
||||||
|
raise EnvironmentError(
|
||||||
|
"The selected graph dataset to database handler does not work with the configured graph database provider. Cannot add support for multi-user access control mode. Please use a supported graph dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
|
||||||
|
f"Selected graph database provider: {graph_db_config.graph_database_provider}\n"
|
||||||
|
f"Selected graph dataset to database handler: {graph_handler}\n"
|
||||||
|
f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
supported_dataset_database_handlers[vector_handler]["handler_provider"]
|
||||||
|
!= vector_db_config.vector_db_provider
|
||||||
|
):
|
||||||
|
raise EnvironmentError(
|
||||||
|
"The selected vector dataset to database handler does not work with the configured vector database provider. Cannot add support for multi-user access control mode. Please use a supported vector dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
|
||||||
|
f"Selected vector database provider: {vector_db_config.vector_db_provider}\n"
|
||||||
|
f"Selected vector dataset to database handler: {vector_handler}\n"
|
||||||
|
f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def backend_access_control_enabled():
|
def backend_access_control_enabled():
|
||||||
backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)
|
backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)
|
||||||
|
|
@ -41,12 +79,7 @@ def backend_access_control_enabled():
|
||||||
return multi_user_support_possible()
|
return multi_user_support_possible()
|
||||||
elif backend_access_control.lower() == "true":
|
elif backend_access_control.lower() == "true":
|
||||||
# If enabled, ensure that the current graph and vector DBs can support it
|
# If enabled, ensure that the current graph and vector DBs can support it
|
||||||
multi_user_support = multi_user_support_possible()
|
return multi_user_support_possible()
|
||||||
if not multi_user_support:
|
|
||||||
raise EnvironmentError(
|
|
||||||
"ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -76,6 +109,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
||||||
|
|
||||||
# To ensure permissions are enforced properly all datasets will have their own databases
|
# To ensure permissions are enforced properly all datasets will have their own databases
|
||||||
dataset_database = await get_or_create_dataset_database(dataset, user)
|
dataset_database = await get_or_create_dataset_database(dataset, user)
|
||||||
|
# Ensure that all connection info is resolved properly
|
||||||
|
dataset_database = await resolve_dataset_database_connection_info(dataset_database)
|
||||||
|
|
||||||
base_config = get_base_config()
|
base_config = get_base_config()
|
||||||
data_root_directory = os.path.join(
|
data_root_directory = os.path.join(
|
||||||
|
|
@ -86,6 +121,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set vector and graph database configuration based on dataset database information
|
# Set vector and graph database configuration based on dataset database information
|
||||||
|
# TODO: Add better handling of vector and graph config accross Cognee.
|
||||||
|
# LRU_CACHE takes into account order of inputs, if order of inputs is changed it will be registered as a new DB adapter
|
||||||
vector_config = {
|
vector_config = {
|
||||||
"vector_db_provider": dataset_database.vector_database_provider,
|
"vector_db_provider": dataset_database.vector_database_provider,
|
||||||
"vector_db_url": dataset_database.vector_database_url,
|
"vector_db_url": dataset_database.vector_database_url,
|
||||||
|
|
@ -101,6 +138,14 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
||||||
"graph_file_path": os.path.join(
|
"graph_file_path": os.path.join(
|
||||||
databases_directory_path, dataset_database.graph_database_name
|
databases_directory_path, dataset_database.graph_database_name
|
||||||
),
|
),
|
||||||
|
"graph_database_username": dataset_database.graph_database_connection_info.get(
|
||||||
|
"graph_database_username", ""
|
||||||
|
),
|
||||||
|
"graph_database_password": dataset_database.graph_database_connection_info.get(
|
||||||
|
"graph_database_password", ""
|
||||||
|
),
|
||||||
|
"graph_dataset_database_handler": "",
|
||||||
|
"graph_database_port": "",
|
||||||
}
|
}
|
||||||
|
|
||||||
storage_config = {
|
storage_config = {
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
from .dataset_database_handler_interface import DatasetDatabaseHandlerInterface
|
||||||
|
from .supported_dataset_database_handlers import supported_dataset_database_handlers
|
||||||
|
from .use_dataset_database_handler import use_dataset_database_handler
|
||||||
|
|
@ -0,0 +1,80 @@
|
||||||
|
from typing import Optional
|
||||||
|
from uuid import UUID
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from cognee.modules.users.models.User import User
|
||||||
|
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetDatabaseHandlerInterface(ABC):
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
|
||||||
|
"""
|
||||||
|
Return a dictionary with database connection/resolution info for a graph or vector database for the given dataset.
|
||||||
|
Function can auto handle deploying of the actual database if needed, but is not necessary.
|
||||||
|
Only providing connection info is sufficient, this info will be mapped when trying to connect to the provided dataset in the future.
|
||||||
|
Needed for Cognee multi-tenant/multi-user and backend access control support.
|
||||||
|
|
||||||
|
Dictionary returned from this function will be used to create a DatasetDatabase row in the relational database.
|
||||||
|
From which internal mapping of dataset -> database connection info will be done.
|
||||||
|
|
||||||
|
The returned dictionary is stored verbatim in the relational database and is later passed to
|
||||||
|
resolve_dataset_connection_info() at connection time. For safe credential handling, prefer
|
||||||
|
returning only references to secrets or role identifiers, not plaintext credentials.
|
||||||
|
|
||||||
|
Each dataset needs to map to a unique graph or vector database when backend access control is enabled to facilitate a separation of concern for data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_id: UUID of the dataset if needed by the database creation logic
|
||||||
|
user: User object if needed by the database creation logic
|
||||||
|
Returns:
|
||||||
|
dict: Connection info for the created graph or vector database instance.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def resolve_dataset_connection_info(
|
||||||
|
cls, dataset_database: DatasetDatabase
|
||||||
|
) -> DatasetDatabase:
|
||||||
|
"""
|
||||||
|
Resolve runtime connection details for a dataset’s backing graph/vector database.
|
||||||
|
Function is intended to be overwritten to implement custom logic for resolving connection info.
|
||||||
|
|
||||||
|
This method is invoked right before the application opens a connection for a given dataset.
|
||||||
|
It receives the DatasetDatabase row that was persisted when create_dataset() ran and must
|
||||||
|
return a modified instance of DatasetDatabase with concrete connection parameters that the client/driver can use.
|
||||||
|
Do not update these new DatasetDatabase values in the relational database to avoid storing secure credentials.
|
||||||
|
|
||||||
|
In case of separate graph and vector database handlers, each handler should implement its own logic for resolving
|
||||||
|
connection info and only change parameters related to its appropriate database, the resolution function will then
|
||||||
|
be called one after another with the updated DatasetDatabase value from the previous function as the input.
|
||||||
|
|
||||||
|
Typical behavior:
|
||||||
|
- If the DatasetDatabase row already contains raw connection fields (e.g., host/port/db/user/password
|
||||||
|
or api_url/api_key), return them as-is.
|
||||||
|
- If the row stores only references (e.g., secret IDs, vault paths, cloud resource ARNs/IDs, IAM
|
||||||
|
roles, SSO tokens), resolve those references by calling the appropriate secret manager or provider
|
||||||
|
API to obtain short-lived credentials and assemble the final connection DatasetDatabase object.
|
||||||
|
- Do not persist any resolved or decrypted secrets back to the relational database. Return them only
|
||||||
|
to the caller.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_database: DatasetDatabase row from the relational database
|
||||||
|
Returns:
|
||||||
|
DatasetDatabase: Updated instance with resolved connection info
|
||||||
|
"""
|
||||||
|
return dataset_database
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
async def delete_dataset(cls, dataset_database: DatasetDatabase) -> None:
|
||||||
|
"""
|
||||||
|
Delete the graph or vector database for the given dataset.
|
||||||
|
Function should auto handle deleting of the actual database or send a request to the proper service to delete/mark the database as not needed for the given dataset.
|
||||||
|
Needed for maintaining a database for Cognee multi-tenant/multi-user and backend access control.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_database: DatasetDatabase row containing connection/resolution info for the graph or vector database to delete.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
@ -0,0 +1,18 @@
|
||||||
|
from cognee.infrastructure.databases.graph.neo4j_driver.Neo4jAuraDevDatasetDatabaseHandler import (
|
||||||
|
Neo4jAuraDevDatasetDatabaseHandler,
|
||||||
|
)
|
||||||
|
from cognee.infrastructure.databases.vector.lancedb.LanceDBDatasetDatabaseHandler import (
|
||||||
|
LanceDBDatasetDatabaseHandler,
|
||||||
|
)
|
||||||
|
from cognee.infrastructure.databases.graph.kuzu.KuzuDatasetDatabaseHandler import (
|
||||||
|
KuzuDatasetDatabaseHandler,
|
||||||
|
)
|
||||||
|
|
||||||
|
supported_dataset_database_handlers = {
|
||||||
|
"neo4j_aura_dev": {
|
||||||
|
"handler_instance": Neo4jAuraDevDatasetDatabaseHandler,
|
||||||
|
"handler_provider": "neo4j",
|
||||||
|
},
|
||||||
|
"lancedb": {"handler_instance": LanceDBDatasetDatabaseHandler, "handler_provider": "lancedb"},
|
||||||
|
"kuzu": {"handler_instance": KuzuDatasetDatabaseHandler, "handler_provider": "kuzu"},
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,10 @@
|
||||||
|
from .supported_dataset_database_handlers import supported_dataset_database_handlers
|
||||||
|
|
||||||
|
|
||||||
|
def use_dataset_database_handler(
|
||||||
|
dataset_database_handler_name, dataset_database_handler, dataset_database_provider
|
||||||
|
):
|
||||||
|
supported_dataset_database_handlers[dataset_database_handler_name] = {
|
||||||
|
"handler_instance": dataset_database_handler,
|
||||||
|
"handler_provider": dataset_database_provider,
|
||||||
|
}
|
||||||
|
|
@ -47,6 +47,7 @@ class GraphConfig(BaseSettings):
|
||||||
graph_filename: str = ""
|
graph_filename: str = ""
|
||||||
graph_model: object = KnowledgeGraph
|
graph_model: object = KnowledgeGraph
|
||||||
graph_topology: object = KnowledgeGraph
|
graph_topology: object = KnowledgeGraph
|
||||||
|
graph_dataset_database_handler: str = "kuzu"
|
||||||
model_config = SettingsConfigDict(env_file=".env", extra="allow", populate_by_name=True)
|
model_config = SettingsConfigDict(env_file=".env", extra="allow", populate_by_name=True)
|
||||||
|
|
||||||
# Model validator updates graph_filename and path dynamically after class creation based on current database provider
|
# Model validator updates graph_filename and path dynamically after class creation based on current database provider
|
||||||
|
|
@ -97,6 +98,7 @@ class GraphConfig(BaseSettings):
|
||||||
"graph_model": self.graph_model,
|
"graph_model": self.graph_model,
|
||||||
"graph_topology": self.graph_topology,
|
"graph_topology": self.graph_topology,
|
||||||
"model_config": self.model_config,
|
"model_config": self.model_config,
|
||||||
|
"graph_dataset_database_handler": self.graph_dataset_database_handler,
|
||||||
}
|
}
|
||||||
|
|
||||||
def to_hashable_dict(self) -> dict:
|
def to_hashable_dict(self) -> dict:
|
||||||
|
|
@ -121,6 +123,7 @@ class GraphConfig(BaseSettings):
|
||||||
"graph_database_port": self.graph_database_port,
|
"graph_database_port": self.graph_database_port,
|
||||||
"graph_database_key": self.graph_database_key,
|
"graph_database_key": self.graph_database_key,
|
||||||
"graph_file_path": self.graph_file_path,
|
"graph_file_path": self.graph_file_path,
|
||||||
|
"graph_dataset_database_handler": self.graph_dataset_database_handler,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@ def create_graph_engine(
|
||||||
graph_database_password="",
|
graph_database_password="",
|
||||||
graph_database_port="",
|
graph_database_port="",
|
||||||
graph_database_key="",
|
graph_database_key="",
|
||||||
|
graph_dataset_database_handler="",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a graph engine based on the specified provider type.
|
Create a graph engine based on the specified provider type.
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,80 @@
|
||||||
|
import os
|
||||||
|
from uuid import UUID
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
|
||||||
|
from cognee.base_config import get_base_config
|
||||||
|
from cognee.modules.users.models import User
|
||||||
|
from cognee.modules.users.models import DatasetDatabase
|
||||||
|
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
|
||||||
|
|
||||||
|
|
||||||
|
class KuzuDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||||
|
"""
|
||||||
|
Handler for interacting with Kuzu Dataset databases.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
|
||||||
|
"""
|
||||||
|
Create a new Kuzu instance for the dataset. Return connection info that will be mapped to the dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_id: Dataset UUID
|
||||||
|
user: User object who owns the dataset and is making the request
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Connection details for the created Kuzu instance
|
||||||
|
|
||||||
|
"""
|
||||||
|
from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||||
|
|
||||||
|
graph_config = get_graph_config()
|
||||||
|
|
||||||
|
if graph_config.graph_database_provider != "kuzu":
|
||||||
|
raise ValueError(
|
||||||
|
"KuzuDatasetDatabaseHandler can only be used with Kuzu graph database provider."
|
||||||
|
)
|
||||||
|
|
||||||
|
graph_db_name = f"{dataset_id}.pkl"
|
||||||
|
graph_db_url = graph_config.graph_database_url
|
||||||
|
graph_db_key = graph_config.graph_database_key
|
||||||
|
graph_db_username = graph_config.graph_database_username
|
||||||
|
graph_db_password = graph_config.graph_database_password
|
||||||
|
|
||||||
|
return {
|
||||||
|
"graph_database_name": graph_db_name,
|
||||||
|
"graph_database_url": graph_db_url,
|
||||||
|
"graph_database_provider": graph_config.graph_database_provider,
|
||||||
|
"graph_database_key": graph_db_key,
|
||||||
|
"graph_database_connection_info": {
|
||||||
|
"graph_database_username": graph_db_username,
|
||||||
|
"graph_database_password": graph_db_password,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def delete_dataset(cls, dataset_database: DatasetDatabase):
|
||||||
|
base_config = get_base_config()
|
||||||
|
databases_directory_path = os.path.join(
|
||||||
|
base_config.system_root_directory, "databases", str(dataset_database.owner_id)
|
||||||
|
)
|
||||||
|
graph_file_path = os.path.join(
|
||||||
|
databases_directory_path, dataset_database.graph_database_name
|
||||||
|
)
|
||||||
|
graph_engine = create_graph_engine(
|
||||||
|
graph_database_provider=dataset_database.graph_database_provider,
|
||||||
|
graph_database_url=dataset_database.graph_database_url,
|
||||||
|
graph_database_name=dataset_database.graph_database_name,
|
||||||
|
graph_database_key=dataset_database.graph_database_key,
|
||||||
|
graph_file_path=graph_file_path,
|
||||||
|
graph_database_username=dataset_database.graph_database_connection_info.get(
|
||||||
|
"graph_database_username", ""
|
||||||
|
),
|
||||||
|
graph_database_password=dataset_database.graph_database_connection_info.get(
|
||||||
|
"graph_database_password", ""
|
||||||
|
),
|
||||||
|
graph_dataset_database_handler="",
|
||||||
|
graph_database_port="",
|
||||||
|
)
|
||||||
|
await graph_engine.delete_graph()
|
||||||
|
|
@ -0,0 +1,167 @@
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import requests
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
from uuid import UUID
|
||||||
|
from typing import Optional
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_config
|
||||||
|
from cognee.modules.users.models import User, DatasetDatabase
|
||||||
|
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
|
||||||
|
|
||||||
|
|
||||||
|
class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||||
|
"""
|
||||||
|
Handler for a quick development PoC integration of Cognee multi-user and permission mode with Neo4j Aura databases.
|
||||||
|
This handler creates a new Neo4j Aura instance for each Cognee dataset created.
|
||||||
|
|
||||||
|
Improvements needed to be production ready:
|
||||||
|
- Secret management for client credentials, currently secrets are encrypted and stored in the Cognee relational database,
|
||||||
|
a secret manager or a similar system should be used instead.
|
||||||
|
|
||||||
|
Quality of life improvements:
|
||||||
|
- Allow configuration of different Neo4j Aura plans and regions.
|
||||||
|
- Requests should be made async, currently a blocking requests library is used.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
|
||||||
|
"""
|
||||||
|
Create a new Neo4j Aura instance for the dataset. Return connection info that will be mapped to the dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_id: Dataset UUID
|
||||||
|
user: User object who owns the dataset and is making the request
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Connection details for the created Neo4j instance
|
||||||
|
|
||||||
|
"""
|
||||||
|
graph_config = get_graph_config()
|
||||||
|
|
||||||
|
if graph_config.graph_database_provider != "neo4j":
|
||||||
|
raise ValueError(
|
||||||
|
"Neo4jAuraDevDatasetDatabaseHandler can only be used with Neo4j graph database provider."
|
||||||
|
)
|
||||||
|
|
||||||
|
graph_db_name = f"{dataset_id}"
|
||||||
|
|
||||||
|
# Client credentials and encryption
|
||||||
|
client_id = os.environ.get("NEO4J_CLIENT_ID", None)
|
||||||
|
client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
|
||||||
|
tenant_id = os.environ.get("NEO4J_TENANT_ID", None)
|
||||||
|
encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key")
|
||||||
|
encryption_key = base64.urlsafe_b64encode(
|
||||||
|
hashlib.sha256(encryption_env_key.encode()).digest()
|
||||||
|
)
|
||||||
|
cipher = Fernet(encryption_key)
|
||||||
|
|
||||||
|
if client_id is None or client_secret is None or tenant_id is None:
|
||||||
|
raise ValueError(
|
||||||
|
"NEO4J_CLIENT_ID, NEO4J_CLIENT_SECRET, and NEO4J_TENANT_ID environment variables must be set to use Neo4j Aura DatasetDatabase Handling."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make the request with HTTP Basic Auth
|
||||||
|
def get_aura_token(client_id: str, client_secret: str) -> dict:
|
||||||
|
url = "https://api.neo4j.io/oauth/token"
|
||||||
|
data = {"grant_type": "client_credentials"} # sent as application/x-www-form-urlencoded
|
||||||
|
|
||||||
|
resp = requests.post(url, data=data, auth=(client_id, client_secret))
|
||||||
|
resp.raise_for_status() # raises if the request failed
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
resp = get_aura_token(client_id, client_secret)
|
||||||
|
|
||||||
|
url = "https://api.neo4j.io/v1/instances"
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"accept": "application/json",
|
||||||
|
"Authorization": f"Bearer {resp['access_token']}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
# TODO: Maybe we can allow **kwargs parameter forwarding for cases like these
|
||||||
|
# Too allow different configurations between datasets
|
||||||
|
payload = {
|
||||||
|
"version": "5",
|
||||||
|
"region": "europe-west1",
|
||||||
|
"memory": "1GB",
|
||||||
|
"name": graph_db_name[
|
||||||
|
0:29
|
||||||
|
], # TODO: Find better name to name Neo4j instance within 30 character limit
|
||||||
|
"type": "professional-db",
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"cloud_provider": "gcp",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(url, headers=headers, json=payload)
|
||||||
|
|
||||||
|
graph_db_name = "neo4j" # Has to be 'neo4j' for Aura
|
||||||
|
graph_db_url = response.json()["data"]["connection_url"]
|
||||||
|
graph_db_key = resp["access_token"]
|
||||||
|
graph_db_username = response.json()["data"]["username"]
|
||||||
|
graph_db_password = response.json()["data"]["password"]
|
||||||
|
|
||||||
|
async def _wait_for_neo4j_instance_provisioning(instance_id: str, headers: dict):
|
||||||
|
# Poll until the instance is running
|
||||||
|
status_url = f"https://api.neo4j.io/v1/instances/{instance_id}"
|
||||||
|
status = ""
|
||||||
|
for attempt in range(30): # Try for up to ~5 minutes
|
||||||
|
status_resp = requests.get(
|
||||||
|
status_url, headers=headers
|
||||||
|
) # TODO: Use async requests with httpx
|
||||||
|
status = status_resp.json()["data"]["status"]
|
||||||
|
if status.lower() == "running":
|
||||||
|
return
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
raise TimeoutError(
|
||||||
|
f"Neo4j instance '{graph_db_name}' did not become ready within 5 minutes. Status: {status}"
|
||||||
|
)
|
||||||
|
|
||||||
|
instance_id = response.json()["data"]["id"]
|
||||||
|
await _wait_for_neo4j_instance_provisioning(instance_id, headers)
|
||||||
|
|
||||||
|
encrypted_db_password_bytes = cipher.encrypt(graph_db_password.encode())
|
||||||
|
encrypted_db_password_string = encrypted_db_password_bytes.decode()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"graph_database_name": graph_db_name,
|
||||||
|
"graph_database_url": graph_db_url,
|
||||||
|
"graph_database_provider": "neo4j",
|
||||||
|
"graph_database_key": graph_db_key,
|
||||||
|
"graph_database_connection_info": {
|
||||||
|
"graph_database_username": graph_db_username,
|
||||||
|
"graph_database_password": encrypted_db_password_string,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def resolve_dataset_connection_info(
|
||||||
|
cls, dataset_database: DatasetDatabase
|
||||||
|
) -> DatasetDatabase:
|
||||||
|
"""
|
||||||
|
Resolve and decrypt connection info for the Neo4j dataset database.
|
||||||
|
In this case, decrypt the password stored in the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_database: DatasetDatabase instance containing encrypted connection info.
|
||||||
|
"""
|
||||||
|
encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key")
|
||||||
|
encryption_key = base64.urlsafe_b64encode(
|
||||||
|
hashlib.sha256(encryption_env_key.encode()).digest()
|
||||||
|
)
|
||||||
|
cipher = Fernet(encryption_key)
|
||||||
|
graph_db_password = cipher.decrypt(
|
||||||
|
dataset_database.graph_database_connection_info["graph_database_password"].encode()
|
||||||
|
).decode()
|
||||||
|
|
||||||
|
dataset_database.graph_database_connection_info["graph_database_password"] = (
|
||||||
|
graph_db_password
|
||||||
|
)
|
||||||
|
return dataset_database
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def delete_dataset(cls, dataset_database: DatasetDatabase):
|
||||||
|
pass
|
||||||
|
|
@ -1 +1,2 @@
|
||||||
from .get_or_create_dataset_database import get_or_create_dataset_database
|
from .get_or_create_dataset_database import get_or_create_dataset_database
|
||||||
|
from .resolve_dataset_database_connection_info import resolve_dataset_database_connection_info
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,9 @@
|
||||||
import os
|
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import Union
|
from typing import Union, Optional
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
|
||||||
from cognee.base_config import get_base_config
|
|
||||||
from cognee.modules.data.methods import create_dataset
|
from cognee.modules.data.methods import create_dataset
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
from cognee.infrastructure.databases.vector import get_vectordb_config
|
from cognee.infrastructure.databases.vector import get_vectordb_config
|
||||||
|
|
@ -15,6 +13,53 @@ from cognee.modules.users.models import DatasetDatabase
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_vector_db_info(dataset_id: UUID, user: User) -> dict:
|
||||||
|
vector_config = get_vectordb_config()
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
||||||
|
supported_dataset_database_handlers,
|
||||||
|
)
|
||||||
|
|
||||||
|
handler = supported_dataset_database_handlers[vector_config.vector_dataset_database_handler]
|
||||||
|
return await handler["handler_instance"].create_dataset(dataset_id, user)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_graph_db_info(dataset_id: UUID, user: User) -> dict:
|
||||||
|
graph_config = get_graph_config()
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
||||||
|
supported_dataset_database_handlers,
|
||||||
|
)
|
||||||
|
|
||||||
|
handler = supported_dataset_database_handlers[graph_config.graph_dataset_database_handler]
|
||||||
|
return await handler["handler_instance"].create_dataset(dataset_id, user)
|
||||||
|
|
||||||
|
|
||||||
|
async def _existing_dataset_database(
|
||||||
|
dataset_id: UUID,
|
||||||
|
user: User,
|
||||||
|
) -> Optional[DatasetDatabase]:
|
||||||
|
"""
|
||||||
|
Check if a DatasetDatabase row already exists for the given owner + dataset.
|
||||||
|
Return None if it doesn't exist, return the row if it does.
|
||||||
|
Args:
|
||||||
|
dataset_id:
|
||||||
|
user:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DatasetDatabase or None
|
||||||
|
"""
|
||||||
|
db_engine = get_relational_engine()
|
||||||
|
|
||||||
|
async with db_engine.get_async_session() as session:
|
||||||
|
stmt = select(DatasetDatabase).where(
|
||||||
|
DatasetDatabase.owner_id == user.id,
|
||||||
|
DatasetDatabase.dataset_id == dataset_id,
|
||||||
|
)
|
||||||
|
existing: DatasetDatabase = await session.scalar(stmt)
|
||||||
|
return existing
|
||||||
|
|
||||||
|
|
||||||
async def get_or_create_dataset_database(
|
async def get_or_create_dataset_database(
|
||||||
dataset: Union[str, UUID],
|
dataset: Union[str, UUID],
|
||||||
user: User,
|
user: User,
|
||||||
|
|
@ -25,6 +70,8 @@ async def get_or_create_dataset_database(
|
||||||
• If the row already exists, it is fetched and returned.
|
• If the row already exists, it is fetched and returned.
|
||||||
• Otherwise a new one is created atomically and returned.
|
• Otherwise a new one is created atomically and returned.
|
||||||
|
|
||||||
|
DatasetDatabase row contains connection and provider info for vector and graph databases.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
user : User
|
user : User
|
||||||
|
|
@ -36,59 +83,26 @@ async def get_or_create_dataset_database(
|
||||||
|
|
||||||
dataset_id = await get_unique_dataset_id(dataset, user)
|
dataset_id = await get_unique_dataset_id(dataset, user)
|
||||||
|
|
||||||
vector_config = get_vectordb_config()
|
# If dataset is given as name make sure the dataset is created first
|
||||||
graph_config = get_graph_config()
|
if isinstance(dataset, str):
|
||||||
|
async with db_engine.get_async_session() as session:
|
||||||
|
await create_dataset(dataset, user, session)
|
||||||
|
|
||||||
# Note: for hybrid databases both graph and vector DB name have to be the same
|
# If dataset database already exists return it
|
||||||
if graph_config.graph_database_provider == "kuzu":
|
existing_dataset_database = await _existing_dataset_database(dataset_id, user)
|
||||||
graph_db_name = f"{dataset_id}.pkl"
|
if existing_dataset_database:
|
||||||
else:
|
return existing_dataset_database
|
||||||
graph_db_name = f"{dataset_id}"
|
|
||||||
|
|
||||||
if vector_config.vector_db_provider == "lancedb":
|
graph_config_dict = await _get_graph_db_info(dataset_id, user)
|
||||||
vector_db_name = f"{dataset_id}.lance.db"
|
vector_config_dict = await _get_vector_db_info(dataset_id, user)
|
||||||
else:
|
|
||||||
vector_db_name = f"{dataset_id}"
|
|
||||||
|
|
||||||
base_config = get_base_config()
|
|
||||||
databases_directory_path = os.path.join(
|
|
||||||
base_config.system_root_directory, "databases", str(user.id)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Determine vector database URL
|
|
||||||
if vector_config.vector_db_provider == "lancedb":
|
|
||||||
vector_db_url = os.path.join(databases_directory_path, vector_config.vector_db_name)
|
|
||||||
else:
|
|
||||||
vector_db_url = vector_config.vector_database_url
|
|
||||||
|
|
||||||
# Determine graph database URL
|
|
||||||
|
|
||||||
async with db_engine.get_async_session() as session:
|
async with db_engine.get_async_session() as session:
|
||||||
# Create dataset if it doesn't exist
|
|
||||||
if isinstance(dataset, str):
|
|
||||||
dataset = await create_dataset(dataset, user, session)
|
|
||||||
|
|
||||||
# Try to fetch an existing row first
|
|
||||||
stmt = select(DatasetDatabase).where(
|
|
||||||
DatasetDatabase.owner_id == user.id,
|
|
||||||
DatasetDatabase.dataset_id == dataset_id,
|
|
||||||
)
|
|
||||||
existing: DatasetDatabase = await session.scalar(stmt)
|
|
||||||
if existing:
|
|
||||||
return existing
|
|
||||||
|
|
||||||
# If there are no existing rows build a new row
|
# If there are no existing rows build a new row
|
||||||
record = DatasetDatabase(
|
record = DatasetDatabase(
|
||||||
owner_id=user.id,
|
owner_id=user.id,
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
vector_database_name=vector_db_name,
|
**graph_config_dict, # Unpack graph db config
|
||||||
graph_database_name=graph_db_name,
|
**vector_config_dict, # Unpack vector db config
|
||||||
vector_database_provider=vector_config.vector_db_provider,
|
|
||||||
graph_database_provider=graph_config.graph_database_provider,
|
|
||||||
vector_database_url=vector_db_url,
|
|
||||||
graph_database_url=graph_config.graph_database_url,
|
|
||||||
vector_database_key=vector_config.vector_db_key,
|
|
||||||
graph_database_key=graph_config.graph_database_key,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,42 @@
|
||||||
|
from cognee.infrastructure.databases.vector import get_vectordb_config
|
||||||
|
from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||||
|
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_vector_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase:
|
||||||
|
vector_config = get_vectordb_config()
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
||||||
|
supported_dataset_database_handlers,
|
||||||
|
)
|
||||||
|
|
||||||
|
handler = supported_dataset_database_handlers[vector_config.vector_dataset_database_handler]
|
||||||
|
return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_graph_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase:
|
||||||
|
graph_config = get_graph_config()
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
||||||
|
supported_dataset_database_handlers,
|
||||||
|
)
|
||||||
|
|
||||||
|
handler = supported_dataset_database_handlers[graph_config.graph_dataset_database_handler]
|
||||||
|
return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database)
|
||||||
|
|
||||||
|
|
||||||
|
async def resolve_dataset_database_connection_info(
|
||||||
|
dataset_database: DatasetDatabase,
|
||||||
|
) -> DatasetDatabase:
|
||||||
|
"""
|
||||||
|
Resolve the connection info for the given DatasetDatabase instance.
|
||||||
|
Resolve both vector and graph database connection info and return the updated DatasetDatabase instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_database: DatasetDatabase instance
|
||||||
|
Returns:
|
||||||
|
DatasetDatabase instance with resolved connection info
|
||||||
|
"""
|
||||||
|
dataset_database = await _get_vector_db_connection_info(dataset_database)
|
||||||
|
dataset_database = await _get_graph_db_connection_info(dataset_database)
|
||||||
|
return dataset_database
|
||||||
|
|
@ -28,6 +28,7 @@ class VectorConfig(BaseSettings):
|
||||||
vector_db_name: str = ""
|
vector_db_name: str = ""
|
||||||
vector_db_key: str = ""
|
vector_db_key: str = ""
|
||||||
vector_db_provider: str = "lancedb"
|
vector_db_provider: str = "lancedb"
|
||||||
|
vector_dataset_database_handler: str = "lancedb"
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||||
|
|
||||||
|
|
@ -63,6 +64,7 @@ class VectorConfig(BaseSettings):
|
||||||
"vector_db_name": self.vector_db_name,
|
"vector_db_name": self.vector_db_name,
|
||||||
"vector_db_key": self.vector_db_key,
|
"vector_db_key": self.vector_db_key,
|
||||||
"vector_db_provider": self.vector_db_provider,
|
"vector_db_provider": self.vector_db_provider,
|
||||||
|
"vector_dataset_database_handler": self.vector_dataset_database_handler,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ def create_vector_engine(
|
||||||
vector_db_name: str,
|
vector_db_name: str,
|
||||||
vector_db_port: str = "",
|
vector_db_port: str = "",
|
||||||
vector_db_key: str = "",
|
vector_db_key: str = "",
|
||||||
|
vector_dataset_database_handler: str = "",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a vector database engine based on the specified provider.
|
Create a vector database engine based on the specified provider.
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,49 @@
|
||||||
|
import os
|
||||||
|
from uuid import UUID
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
|
||||||
|
from cognee.modules.users.models import User
|
||||||
|
from cognee.modules.users.models import DatasetDatabase
|
||||||
|
from cognee.base_config import get_base_config
|
||||||
|
from cognee.infrastructure.databases.vector import get_vectordb_config
|
||||||
|
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
|
||||||
|
|
||||||
|
|
||||||
|
class LanceDBDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||||
|
"""
|
||||||
|
Handler for interacting with LanceDB Dataset databases.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
|
||||||
|
vector_config = get_vectordb_config()
|
||||||
|
base_config = get_base_config()
|
||||||
|
|
||||||
|
if vector_config.vector_db_provider != "lancedb":
|
||||||
|
raise ValueError(
|
||||||
|
"LanceDBDatasetDatabaseHandler can only be used with LanceDB vector database provider."
|
||||||
|
)
|
||||||
|
|
||||||
|
databases_directory_path = os.path.join(
|
||||||
|
base_config.system_root_directory, "databases", str(user.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_db_name = f"{dataset_id}.lance.db"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"vector_database_provider": vector_config.vector_db_provider,
|
||||||
|
"vector_database_url": os.path.join(databases_directory_path, vector_db_name),
|
||||||
|
"vector_database_key": vector_config.vector_db_key,
|
||||||
|
"vector_database_name": vector_db_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def delete_dataset(cls, dataset_database: DatasetDatabase):
|
||||||
|
vector_engine = create_vector_engine(
|
||||||
|
vector_db_provider=dataset_database.vector_database_provider,
|
||||||
|
vector_db_url=dataset_database.vector_database_url,
|
||||||
|
vector_db_key=dataset_database.vector_database_key,
|
||||||
|
vector_db_name=dataset_database.vector_database_name,
|
||||||
|
)
|
||||||
|
await vector_engine.prune()
|
||||||
|
|
@ -2,6 +2,8 @@ from typing import List, Protocol, Optional, Union, Any
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from .models.PayloadSchema import PayloadSchema
|
from .models.PayloadSchema import PayloadSchema
|
||||||
|
from uuid import UUID
|
||||||
|
from cognee.modules.users.models import User
|
||||||
|
|
||||||
|
|
||||||
class VectorDBInterface(Protocol):
|
class VectorDBInterface(Protocol):
|
||||||
|
|
@ -217,3 +219,36 @@ class VectorDBInterface(Protocol):
|
||||||
- Any: The schema object suitable for this vector database
|
- Any: The schema object suitable for this vector database
|
||||||
"""
|
"""
|
||||||
return model_type
|
return model_type
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
|
||||||
|
"""
|
||||||
|
Return a dictionary with connection info for a vector database for the given dataset.
|
||||||
|
Function can auto handle deploying of the actual database if needed, but is not necessary.
|
||||||
|
Only providing connection info is sufficient, this info will be mapped when trying to connect to the provided dataset in the future.
|
||||||
|
Needed for Cognee multi-tenant/multi-user and backend access control support.
|
||||||
|
|
||||||
|
Dictionary returned from this function will be used to create a DatasetDatabase row in the relational database.
|
||||||
|
From which internal mapping of dataset -> database connection info will be done.
|
||||||
|
|
||||||
|
Each dataset needs to map to a unique vector database when backend access control is enabled to facilitate a separation of concern for data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_id: UUID of the dataset if needed by the database creation logic
|
||||||
|
user: User object if needed by the database creation logic
|
||||||
|
Returns:
|
||||||
|
dict: Connection info for the created vector database instance.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def delete_dataset(self, dataset_id: UUID, user: User) -> None:
|
||||||
|
"""
|
||||||
|
Delete the vector database for the given dataset.
|
||||||
|
Function should auto handle deleting of the actual database or send a request to the proper service to delete the database.
|
||||||
|
Needed for maintaining a database for Cognee multi-tenant/multi-user and backend access control.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_id: UUID of the dataset
|
||||||
|
user: User object
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,82 @@
|
||||||
|
from sqlalchemy.exc import OperationalError
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
|
||||||
|
from cognee.context_global_variables import backend_access_control_enabled
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
|
from cognee.infrastructure.databases.vector.config import get_vectordb_config
|
||||||
|
from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||||
from cognee.shared.cache import delete_cache
|
from cognee.shared.cache import delete_cache
|
||||||
|
from cognee.modules.users.models import DatasetDatabase
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
async def prune_graph_databases():
|
||||||
|
async def _prune_graph_db(dataset_database: DatasetDatabase) -> dict:
|
||||||
|
graph_config = get_graph_config()
|
||||||
|
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
||||||
|
supported_dataset_database_handlers,
|
||||||
|
)
|
||||||
|
|
||||||
|
handler = supported_dataset_database_handlers[graph_config.graph_dataset_database_handler]
|
||||||
|
return await handler["handler_instance"].delete_dataset(dataset_database)
|
||||||
|
|
||||||
|
db_engine = get_relational_engine()
|
||||||
|
try:
|
||||||
|
data = await db_engine.get_all_data_from_table("dataset_database")
|
||||||
|
# Go through each dataset database and delete the graph database
|
||||||
|
for data_item in data:
|
||||||
|
await _prune_graph_db(data_item)
|
||||||
|
except (OperationalError, EntityNotFoundError) as e:
|
||||||
|
logger.debug(
|
||||||
|
"Skipping pruning of graph DB. Error when accessing dataset_database table: %s",
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
async def prune_vector_databases():
|
||||||
|
async def _prune_vector_db(dataset_database: DatasetDatabase) -> dict:
|
||||||
|
vector_config = get_vectordb_config()
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
||||||
|
supported_dataset_database_handlers,
|
||||||
|
)
|
||||||
|
|
||||||
|
handler = supported_dataset_database_handlers[vector_config.vector_dataset_database_handler]
|
||||||
|
return await handler["handler_instance"].delete_dataset(dataset_database)
|
||||||
|
|
||||||
|
db_engine = get_relational_engine()
|
||||||
|
try:
|
||||||
|
data = await db_engine.get_all_data_from_table("dataset_database")
|
||||||
|
# Go through each dataset database and delete the vector database
|
||||||
|
for data_item in data:
|
||||||
|
await _prune_vector_db(data_item)
|
||||||
|
except (OperationalError, EntityNotFoundError) as e:
|
||||||
|
logger.debug(
|
||||||
|
"Skipping pruning of vector DB. Error when accessing dataset_database table: %s",
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
async def prune_system(graph=True, vector=True, metadata=True, cache=True):
|
async def prune_system(graph=True, vector=True, metadata=True, cache=True):
|
||||||
if graph:
|
# Note: prune system should not be available through the API, it has no permission checks and will
|
||||||
|
# delete all graph and vector databases if called. It should only be used in development or testing environments.
|
||||||
|
if graph and not backend_access_control_enabled():
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
await graph_engine.delete_graph()
|
await graph_engine.delete_graph()
|
||||||
|
elif graph and backend_access_control_enabled():
|
||||||
|
await prune_graph_databases()
|
||||||
|
|
||||||
if vector:
|
if vector and not backend_access_control_enabled():
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
await vector_engine.prune()
|
await vector_engine.prune()
|
||||||
|
elif vector and backend_access_control_enabled():
|
||||||
|
await prune_vector_databases()
|
||||||
|
|
||||||
if metadata:
|
if metadata:
|
||||||
db_engine = get_relational_engine()
|
db_engine = get_relational_engine()
|
||||||
|
|
|
||||||
|
|
@ -12,8 +12,8 @@ logger = get_logger("get_authenticated_user")
|
||||||
|
|
||||||
# Check environment variable to determine authentication requirement
|
# Check environment variable to determine authentication requirement
|
||||||
REQUIRE_AUTHENTICATION = (
|
REQUIRE_AUTHENTICATION = (
|
||||||
os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true"
|
os.getenv("REQUIRE_AUTHENTICATION", "true").lower() == "true"
|
||||||
or backend_access_control_enabled()
|
or os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", "true").lower() == "true"
|
||||||
)
|
)
|
||||||
|
|
||||||
fastapi_users = get_fastapi_users()
|
fastapi_users = get_fastapi_users()
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from sqlalchemy import Column, DateTime, String, UUID, ForeignKey
|
from sqlalchemy import Column, DateTime, String, UUID, ForeignKey, JSON, text
|
||||||
from cognee.infrastructure.databases.relational import Base
|
from cognee.infrastructure.databases.relational import Base
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -12,8 +12,8 @@ class DatasetDatabase(Base):
|
||||||
UUID, ForeignKey("datasets.id", ondelete="CASCADE"), primary_key=True, index=True
|
UUID, ForeignKey("datasets.id", ondelete="CASCADE"), primary_key=True, index=True
|
||||||
)
|
)
|
||||||
|
|
||||||
vector_database_name = Column(String, unique=True, nullable=False)
|
vector_database_name = Column(String, unique=False, nullable=False)
|
||||||
graph_database_name = Column(String, unique=True, nullable=False)
|
graph_database_name = Column(String, unique=False, nullable=False)
|
||||||
|
|
||||||
vector_database_provider = Column(String, unique=False, nullable=False)
|
vector_database_provider = Column(String, unique=False, nullable=False)
|
||||||
graph_database_provider = Column(String, unique=False, nullable=False)
|
graph_database_provider = Column(String, unique=False, nullable=False)
|
||||||
|
|
@ -24,5 +24,14 @@ class DatasetDatabase(Base):
|
||||||
vector_database_key = Column(String, unique=False, nullable=True)
|
vector_database_key = Column(String, unique=False, nullable=True)
|
||||||
graph_database_key = Column(String, unique=False, nullable=True)
|
graph_database_key = Column(String, unique=False, nullable=True)
|
||||||
|
|
||||||
|
# configuration details for different database types. This would make it more flexible to add new database types
|
||||||
|
# without changing the database schema.
|
||||||
|
graph_database_connection_info = Column(
|
||||||
|
JSON, unique=False, nullable=False, server_default=text("'{}'")
|
||||||
|
)
|
||||||
|
vector_database_connection_info = Column(
|
||||||
|
JSON, unique=False, nullable=False, server_default=text("'{}'")
|
||||||
|
)
|
||||||
|
|
||||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||||
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
|
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
|
||||||
|
|
|
||||||
|
|
@ -534,6 +534,10 @@ def setup_logging(log_level=None, name=None):
|
||||||
# Get a configured logger and log system information
|
# Get a configured logger and log system information
|
||||||
logger = structlog.get_logger(name if name else __name__)
|
logger = structlog.get_logger(name if name else __name__)
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
"From version 0.5.0 onwards, Cognee will run with multi-user access control mode set to on by default. Data isolation between different users and datasets will be enforced and data created before multi-user access control mode was turned on won't be accessible by default. To disable multi-user access control mode and regain access to old data set the environment variable ENABLE_BACKEND_ACCESS_CONTROL to false before starting Cognee. For more information, please refer to the Cognee documentation."
|
||||||
|
)
|
||||||
|
|
||||||
if logs_dir is not None:
|
if logs_dir is not None:
|
||||||
logger.info(f"Log file created at: {log_file_path}", log_file=log_file_path)
|
logger.info(f"Log file created at: {log_file_path}", log_file=log_file_path)
|
||||||
|
|
||||||
|
|
|
||||||
135
cognee/tests/test_dataset_database_handler.py
Normal file
135
cognee/tests/test_dataset_database_handler.py
Normal file
|
|
@ -0,0 +1,135 @@
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Set custom dataset database handler environment variable
|
||||||
|
os.environ["VECTOR_DATASET_DATABASE_HANDLER"] = "custom_lancedb_handler"
|
||||||
|
os.environ["GRAPH_DATASET_DATABASE_HANDLER"] = "custom_kuzu_handler"
|
||||||
|
|
||||||
|
import cognee
|
||||||
|
from cognee.modules.users.methods import get_default_user
|
||||||
|
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
|
||||||
|
from cognee.shared.logging_utils import setup_logging, ERROR
|
||||||
|
from cognee.api.v1.search import SearchType
|
||||||
|
|
||||||
|
|
||||||
|
class LanceDBTestDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||||
|
@classmethod
|
||||||
|
async def create_dataset(cls, dataset_id, user):
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
cognee_directory_path = str(
|
||||||
|
pathlib.Path(
|
||||||
|
os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".cognee_system/test_dataset_database_handler"
|
||||||
|
)
|
||||||
|
).resolve()
|
||||||
|
)
|
||||||
|
databases_directory_path = os.path.join(cognee_directory_path, "databases", str(user.id))
|
||||||
|
os.makedirs(databases_directory_path, exist_ok=True)
|
||||||
|
|
||||||
|
vector_db_name = "test.lance.db"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"vector_database_name": vector_db_name,
|
||||||
|
"vector_database_url": os.path.join(databases_directory_path, vector_db_name),
|
||||||
|
"vector_database_provider": "lancedb",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class KuzuTestDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||||
|
@classmethod
|
||||||
|
async def create_dataset(cls, dataset_id, user):
|
||||||
|
databases_directory_path = os.path.join("databases", str(user.id))
|
||||||
|
os.makedirs(databases_directory_path, exist_ok=True)
|
||||||
|
|
||||||
|
graph_db_name = "test.kuzu"
|
||||||
|
return {
|
||||||
|
"graph_database_name": graph_db_name,
|
||||||
|
"graph_database_url": os.path.join(databases_directory_path, graph_db_name),
|
||||||
|
"graph_database_provider": "kuzu",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
data_directory_path = str(
|
||||||
|
pathlib.Path(
|
||||||
|
os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".data_storage/test_dataset_database_handler"
|
||||||
|
)
|
||||||
|
).resolve()
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
cognee_directory_path = str(
|
||||||
|
pathlib.Path(
|
||||||
|
os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".cognee_system/test_dataset_database_handler"
|
||||||
|
)
|
||||||
|
).resolve()
|
||||||
|
)
|
||||||
|
cognee.config.system_root_directory(cognee_directory_path)
|
||||||
|
|
||||||
|
# Add custom dataset database handler
|
||||||
|
from cognee.infrastructure.databases.dataset_database_handler.use_dataset_database_handler import (
|
||||||
|
use_dataset_database_handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
use_dataset_database_handler(
|
||||||
|
"custom_lancedb_handler", LanceDBTestDatasetDatabaseHandler, "lancedb"
|
||||||
|
)
|
||||||
|
use_dataset_database_handler("custom_kuzu_handler", KuzuTestDatasetDatabaseHandler, "kuzu")
|
||||||
|
|
||||||
|
# Create a clean slate for cognee -- reset data and system state
|
||||||
|
print("Resetting cognee data...")
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
print("Data reset complete.\n")
|
||||||
|
|
||||||
|
# cognee knowledge graph will be created based on this text
|
||||||
|
text = """
|
||||||
|
Natural language processing (NLP) is an interdisciplinary
|
||||||
|
subfield of computer science and information retrieval.
|
||||||
|
"""
|
||||||
|
|
||||||
|
print("Adding text to cognee:")
|
||||||
|
print(text.strip())
|
||||||
|
|
||||||
|
# Add the text, and make it available for cognify
|
||||||
|
await cognee.add(text)
|
||||||
|
print("Text added successfully.\n")
|
||||||
|
|
||||||
|
# Use LLMs and cognee to create knowledge graph
|
||||||
|
await cognee.cognify()
|
||||||
|
print("Cognify process complete.\n")
|
||||||
|
|
||||||
|
query_text = "Tell me about NLP"
|
||||||
|
print(f"Searching cognee for insights with query: '{query_text}'")
|
||||||
|
# Query cognee for insights on the added text
|
||||||
|
search_results = await cognee.search(
|
||||||
|
query_type=SearchType.GRAPH_COMPLETION, query_text=query_text
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Search results:")
|
||||||
|
# Display results
|
||||||
|
for result_text in search_results:
|
||||||
|
print(result_text)
|
||||||
|
|
||||||
|
default_user = await get_default_user()
|
||||||
|
# Assert that the custom database files were created based on the custom dataset database handlers
|
||||||
|
assert os.path.exists(
|
||||||
|
os.path.join(cognee_directory_path, "databases", str(default_user.id), "test.kuzu")
|
||||||
|
), "Graph database file not found."
|
||||||
|
assert os.path.exists(
|
||||||
|
os.path.join(cognee_directory_path, "databases", str(default_user.id), "test.lance.db")
|
||||||
|
), "Vector database file not found."
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logger = setup_logging(log_level=ERROR)
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
try:
|
||||||
|
loop.run_until_complete(main())
|
||||||
|
finally:
|
||||||
|
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||||
Loading…
Add table
Reference in a new issue