feat: add dataset database handler logic and neo4j/lancedb/kuzu handlers (#1776)

<!-- .github/pull_request_template.md -->

## Description
Add the ability to use multi-tenant, multi-user mode with Neo4j.
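
For reviewers trying this out, a minimal `.env` sketch (assuming the variables introduced in this PR; the `neo4j_aura_dev` handler name comes from the registry below, and the Aura credentials are illustrative placeholders):

```env
# Multi-user access control mode (on by default from 0.5.0 onwards)
ENABLE_BACKEND_ACCESS_CONTROL="true"

# Graph side: Neo4j provider with the Aura dev dataset database handler
GRAPH_DATABASE_PROVIDER="neo4j"
GRAPH_DATASET_DATABASE_HANDLER="neo4j_aura_dev"

# Credentials required by the Aura handler (placeholders)
NEO4J_CLIENT_ID="<client-id>"
NEO4J_CLIENT_SECRET="<client-secret>"
NEO4J_TENANT_ID="<tenant-id>"
NEO4J_ENCRYPTION_KEY="<encryption-passphrase>"

# Vector side: default LanceDB handler
VECTOR_DB_PROVIDER="lancedb"
VECTOR_DATASET_DATABASE_HANDLER="lancedb"
```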

## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [ ] **I have tested my changes thoroughly before submitting this PR**
- [ ] **This PR contains minimal changes necessary to address the
issue/feature**
- [ ] My code follows the project's coding standards and style
guidelines
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Release Notes

* **New Features**
* Multi-user support with per-dataset database isolation enabled by
default, allowing backend access control for secure data separation.
* Configurable database handlers via environment variables
(GRAPH_DATASET_DATABASE_HANDLER, VECTOR_DATASET_DATABASE_HANDLER) for
flexible deployment options.

* **Chores**
* Database schema migration to support per-user dataset database
configurations.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
Commit 46ddd4fd12 by Igor Ilic, committed 2025-12-11 14:15:20 +01:00 via GitHub.
31 changed files with 1143 additions and 74 deletions

View file

@ -97,6 +97,8 @@ DB_NAME=cognee_db
# Default (local file-based)
GRAPH_DATABASE_PROVIDER="kuzu"
# Handler for multi-user access control mode; it controls how the mapping/creation of separate DBs is handled per Cognee dataset
GRAPH_DATASET_DATABASE_HANDLER="kuzu"
# -- To switch to Remote Kuzu uncomment and fill these: -------------------------------------------------------------
#GRAPH_DATABASE_PROVIDER="kuzu"
@ -121,6 +123,8 @@ VECTOR_DB_PROVIDER="lancedb"
# Not needed if a cloud vector database is not used
VECTOR_DB_URL=
VECTOR_DB_KEY=
# Handler for multi-user access control mode; it controls how the mapping/creation of separate DBs is handled per Cognee dataset
VECTOR_DATASET_DATABASE_HANDLER="lancedb"
################################################################################
# 🧩 Ontology resolver settings

View file

@ -61,6 +61,7 @@ jobs:
- name: Run Neo4j Example
env:
ENV: 'dev'
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
@ -142,6 +143,7 @@ jobs:
- name: Run PGVector Example
env:
ENV: 'dev'
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}

View file

@ -47,6 +47,7 @@ jobs:
- name: Run Distributed Cognee (Modal)
env:
ENV: 'dev'
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}

View file

@ -147,6 +147,7 @@ jobs:
- name: Run Deduplication Example
env:
ENV: 'dev'
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }} # Test needs OpenAI endpoint to handle multimedia
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
@ -211,6 +212,31 @@ jobs:
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_parallel_databases.py
test-dataset-database-handler:
name: Test dataset database handlers in Cognee
runs-on: ubuntu-22.04
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
- name: Run dataset databases handler test
env:
ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_dataset_database_handler.py
test-permissions:
name: Test permissions with different situations in Cognee
runs-on: ubuntu-22.04

View file

@ -72,6 +72,7 @@ jobs:
- name: Run Descriptive Graph Metrics Example
env:
ENV: 'dev'
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}

View file

@ -78,6 +78,7 @@ jobs:
- name: Run default Neo4j
env:
ENV: 'dev'
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}

View file

@ -72,6 +72,7 @@ jobs:
- name: Run Temporal Graph with Neo4j (lancedb + sqlite)
env:
ENV: 'dev'
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@ -123,6 +124,7 @@ jobs:
- name: Run Temporal Graph with Kuzu (postgres + pgvector)
env:
ENV: dev
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@ -189,6 +191,7 @@ jobs:
- name: Run Temporal Graph with Neo4j (postgres + pgvector)
env:
ENV: dev
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}

View file

@ -92,6 +92,7 @@ jobs:
- name: Run PGVector Tests
env:
ENV: 'dev'
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
@ -127,4 +128,4 @@ jobs:
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_lancedb.py

View file

@ -94,6 +94,7 @@ jobs:
- name: Run Weighted Edges Tests
env:
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
GRAPH_DATABASE_PROVIDER: ${{ matrix.graph_db_provider }}
GRAPH_DATABASE_URL: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-url || '' }}
GRAPH_DATABASE_USERNAME: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-username || '' }}
@ -165,5 +166,3 @@ jobs:
uses: astral-sh/ruff-action@v2
with:
args: "format --check cognee/modules/graph/utils/get_graph_from_model.py cognee/tests/unit/interfaces/graph/test_weighted_edges.py examples/python/weighted_edges_example.py"

View file

@ -0,0 +1,267 @@
"""Expand dataset database with json connection field
Revision ID: 46a6ce2bd2b2
Revises: 76625596c5c3
Create Date: 2025-11-25 17:56:28.938931
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "46a6ce2bd2b2"
down_revision: Union[str, None] = "76625596c5c3"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
graph_constraint_name = "dataset_database_graph_database_name_key"
vector_constraint_name = "dataset_database_vector_database_name_key"
TABLE_NAME = "dataset_database"
def _get_column(inspector, table, name, schema=None):
for col in inspector.get_columns(table, schema=schema):
if col["name"] == name:
return col
return None
def _recreate_table_without_unique_constraint_sqlite(op, insp):
"""
SQLite cannot drop unique constraints on individual columns. We must:
1. Create a new table without the unique constraints.
2. Copy data from the old table.
3. Drop the old table.
4. Rename the new table.
"""
conn = op.get_bind()
# Create new table definition (without unique constraints)
op.create_table(
f"{TABLE_NAME}_new",
sa.Column("owner_id", sa.UUID()),
sa.Column("dataset_id", sa.UUID(), primary_key=True, nullable=False),
sa.Column("vector_database_name", sa.String(), nullable=False),
sa.Column("graph_database_name", sa.String(), nullable=False),
sa.Column("vector_database_provider", sa.String(), nullable=False),
sa.Column("graph_database_provider", sa.String(), nullable=False),
sa.Column("vector_database_url", sa.String()),
sa.Column("graph_database_url", sa.String()),
sa.Column("vector_database_key", sa.String()),
sa.Column("graph_database_key", sa.String()),
sa.Column(
"graph_database_connection_info",
sa.JSON(),
nullable=False,
server_default=sa.text("'{}'"),
),
sa.Column(
"vector_database_connection_info",
sa.JSON(),
nullable=False,
server_default=sa.text("'{}'"),
),
sa.Column("created_at", sa.DateTime()),
sa.Column("updated_at", sa.DateTime()),
sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["owner_id"], ["principals.id"], ondelete="CASCADE"),
)
# Copy data into new table
conn.execute(
sa.text(f"""
INSERT INTO {TABLE_NAME}_new
SELECT
owner_id,
dataset_id,
vector_database_name,
graph_database_name,
vector_database_provider,
graph_database_provider,
vector_database_url,
graph_database_url,
vector_database_key,
graph_database_key,
COALESCE(graph_database_connection_info, '{{}}'),
COALESCE(vector_database_connection_info, '{{}}'),
created_at,
updated_at
FROM {TABLE_NAME}
""")
)
# Drop old table
op.drop_table(TABLE_NAME)
# Rename new table
op.rename_table(f"{TABLE_NAME}_new", TABLE_NAME)
def _recreate_table_with_unique_constraint_sqlite(op, insp):
"""
SQLite cannot add unique constraints to existing columns in place. We must:
1. Create a new table with the unique constraints restored.
2. Copy data from the old table.
3. Drop the old table.
4. Rename the new table.
"""
conn = op.get_bind()
# Create new table definition (with unique constraints)
op.create_table(
f"{TABLE_NAME}_new",
sa.Column("owner_id", sa.UUID()),
sa.Column("dataset_id", sa.UUID(), primary_key=True, nullable=False),
sa.Column("vector_database_name", sa.String(), nullable=False, unique=True),
sa.Column("graph_database_name", sa.String(), nullable=False, unique=True),
sa.Column("vector_database_provider", sa.String(), nullable=False),
sa.Column("graph_database_provider", sa.String(), nullable=False),
sa.Column("vector_database_url", sa.String()),
sa.Column("graph_database_url", sa.String()),
sa.Column("vector_database_key", sa.String()),
sa.Column("graph_database_key", sa.String()),
sa.Column(
"graph_database_connection_info",
sa.JSON(),
nullable=False,
server_default=sa.text("'{}'"),
),
sa.Column(
"vector_database_connection_info",
sa.JSON(),
nullable=False,
server_default=sa.text("'{}'"),
),
sa.Column("created_at", sa.DateTime()),
sa.Column("updated_at", sa.DateTime()),
sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["owner_id"], ["principals.id"], ondelete="CASCADE"),
)
# Copy data into new table
conn.execute(
sa.text(f"""
INSERT INTO {TABLE_NAME}_new
SELECT
owner_id,
dataset_id,
vector_database_name,
graph_database_name,
vector_database_provider,
graph_database_provider,
vector_database_url,
graph_database_url,
vector_database_key,
graph_database_key,
COALESCE(graph_database_connection_info, '{{}}'),
COALESCE(vector_database_connection_info, '{{}}'),
created_at,
updated_at
FROM {TABLE_NAME}
""")
)
# Drop old table
op.drop_table(TABLE_NAME)
# Rename new table
op.rename_table(f"{TABLE_NAME}_new", TABLE_NAME)
def upgrade() -> None:
conn = op.get_bind()
insp = sa.inspect(conn)
unique_constraints = insp.get_unique_constraints(TABLE_NAME)
vector_database_connection_info_column = _get_column(
insp, "dataset_database", "vector_database_connection_info"
)
if not vector_database_connection_info_column:
op.add_column(
"dataset_database",
sa.Column(
"vector_database_connection_info",
sa.JSON(),
unique=False,
nullable=False,
server_default=sa.text("'{}'"),
),
)
graph_database_connection_info_column = _get_column(
insp, "dataset_database", "graph_database_connection_info"
)
if not graph_database_connection_info_column:
op.add_column(
"dataset_database",
sa.Column(
"graph_database_connection_info",
sa.JSON(),
unique=False,
nullable=False,
server_default=sa.text("'{}'"),
),
)
with op.batch_alter_table("dataset_database", schema=None) as batch_op:
# Drop the unique constraint to make unique=False
graph_constraint_to_drop = None
for uc in unique_constraints:
# Check if the constraint covers ONLY the target column
if uc["name"] == graph_constraint_name:
graph_constraint_to_drop = uc["name"]
break
vector_constraint_to_drop = None
for uc in unique_constraints:
# Check if the constraint covers ONLY the target column
if uc["name"] == vector_constraint_name:
vector_constraint_to_drop = uc["name"]
break
if (
vector_constraint_to_drop
and graph_constraint_to_drop
and op.get_context().dialect.name == "postgresql"
):
# PostgreSQL
batch_op.drop_constraint(graph_constraint_name, type_="unique")
batch_op.drop_constraint(vector_constraint_name, type_="unique")
if op.get_context().dialect.name == "sqlite":
conn = op.get_bind()
# Fun fact: SQLite has hidden auto indexes for unique constraints that can't be dropped or accessed directly
# So we need to check for them and drop them by recreating the table (altering column also won't work)
result = conn.execute(sa.text("PRAGMA index_list('dataset_database')"))
rows = result.fetchall()
unique_auto_indexes = [row for row in rows if row[3] == "u"]
for row in unique_auto_indexes:
result = conn.execute(sa.text(f"PRAGMA index_info('{row[1]}')"))
index_info = result.fetchall()
if index_info[0][2] == "vector_database_name":
# In case a unique index exists on vector_database_name, drop it and the graph_database_name one
_recreate_table_without_unique_constraint_sqlite(op, insp)
def downgrade() -> None:
conn = op.get_bind()
insp = sa.inspect(conn)
if op.get_context().dialect.name == "sqlite":
_recreate_table_with_unique_constraint_sqlite(op, insp)
elif op.get_context().dialect.name == "postgresql":
with op.batch_alter_table("dataset_database", schema=None) as batch_op:
# Re-add the unique constraint to return to unique=True
batch_op.create_unique_constraint(graph_constraint_name, ["graph_database_name"])
with op.batch_alter_table("dataset_database", schema=None) as batch_op:
# Re-add the unique constraint to return to unique=True
batch_op.create_unique_constraint(vector_constraint_name, ["vector_database_name"])
op.drop_column("dataset_database", "vector_database_connection_info")
op.drop_column("dataset_database", "graph_database_connection_info")

View file

@ -4,9 +4,10 @@ from typing import Union
from uuid import UUID
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.vector.config import get_vectordb_context_config
from cognee.infrastructure.databases.graph.config import get_graph_context_config
from cognee.infrastructure.databases.vector.config import get_vectordb_config
from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.infrastructure.databases.utils import get_or_create_dataset_database
from cognee.infrastructure.databases.utils import resolve_dataset_database_connection_info
from cognee.infrastructure.files.storage.config import file_storage_config
from cognee.modules.users.methods import get_user
@ -16,22 +17,59 @@ vector_db_config = ContextVar("vector_db_config", default=None)
graph_db_config = ContextVar("graph_db_config", default=None)
session_user = ContextVar("session_user", default=None)
VECTOR_DBS_WITH_MULTI_USER_SUPPORT = ["lancedb", "falkor"]
GRAPH_DBS_WITH_MULTI_USER_SUPPORT = ["kuzu", "falkor"]
async def set_session_user_context_variable(user):
session_user.set(user)
def multi_user_support_possible():
graph_db_config = get_graph_context_config()
vector_db_config = get_vectordb_context_config()
return (
graph_db_config["graph_database_provider"] in GRAPH_DBS_WITH_MULTI_USER_SUPPORT
and vector_db_config["vector_db_provider"] in VECTOR_DBS_WITH_MULTI_USER_SUPPORT
graph_db_config = get_graph_config()
vector_db_config = get_vectordb_config()
graph_handler = graph_db_config.graph_dataset_database_handler
vector_handler = vector_db_config.vector_dataset_database_handler
from cognee.infrastructure.databases.dataset_database_handler import (
supported_dataset_database_handlers,
)
if graph_handler not in supported_dataset_database_handlers:
raise EnvironmentError(
"Unsupported graph dataset to database handler configured. Cannot add support for multi-user access control mode. Please use a supported graph dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
f"Selected graph dataset to database handler: {graph_handler}\n"
f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
)
if vector_handler not in supported_dataset_database_handlers:
raise EnvironmentError(
"Unsupported vector dataset to database handler configured. Cannot add support for multi-user access control mode. Please use a supported vector dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
f"Selected vector dataset to database handler: {vector_handler}\n"
f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
)
if (
supported_dataset_database_handlers[graph_handler]["handler_provider"]
!= graph_db_config.graph_database_provider
):
raise EnvironmentError(
"The selected graph dataset to database handler does not work with the configured graph database provider. Cannot add support for multi-user access control mode. Please use a supported graph dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
f"Selected graph database provider: {graph_db_config.graph_database_provider}\n"
f"Selected graph dataset to database handler: {graph_handler}\n"
f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
)
if (
supported_dataset_database_handlers[vector_handler]["handler_provider"]
!= vector_db_config.vector_db_provider
):
raise EnvironmentError(
"The selected vector dataset to database handler does not work with the configured vector database provider. Cannot add support for multi-user access control mode. Please use a supported vector dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
f"Selected vector database provider: {vector_db_config.vector_db_provider}\n"
f"Selected vector dataset to database handler: {vector_handler}\n"
f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n"
)
return True
def backend_access_control_enabled():
backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)
@ -41,12 +79,7 @@ def backend_access_control_enabled():
return multi_user_support_possible()
elif backend_access_control.lower() == "true":
# If enabled, ensure that the current graph and vector DBs can support it
multi_user_support = multi_user_support_possible()
if not multi_user_support:
raise EnvironmentError(
"ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
)
return True
return multi_user_support_possible()
return False
@ -76,6 +109,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
# To ensure permissions are enforced properly all datasets will have their own databases
dataset_database = await get_or_create_dataset_database(dataset, user)
# Ensure that all connection info is resolved properly
dataset_database = await resolve_dataset_database_connection_info(dataset_database)
base_config = get_base_config()
data_root_directory = os.path.join(
@ -86,6 +121,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
)
# Set vector and graph database configuration based on dataset database information
# TODO: Add better handling of vector and graph config across Cognee.
# LRU_CACHE takes the order of inputs into account; if the order of inputs changes, it will be registered as a new DB adapter
vector_config = {
"vector_db_provider": dataset_database.vector_database_provider,
"vector_db_url": dataset_database.vector_database_url,
@ -101,6 +138,14 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
"graph_file_path": os.path.join(
databases_directory_path, dataset_database.graph_database_name
),
"graph_database_username": dataset_database.graph_database_connection_info.get(
"graph_database_username", ""
),
"graph_database_password": dataset_database.graph_database_connection_info.get(
"graph_database_password", ""
),
"graph_dataset_database_handler": "",
"graph_database_port": "",
}
storage_config = {

View file

@ -0,0 +1,3 @@
from .dataset_database_handler_interface import DatasetDatabaseHandlerInterface
from .supported_dataset_database_handlers import supported_dataset_database_handlers
from .use_dataset_database_handler import use_dataset_database_handler

View file

@ -0,0 +1,80 @@
from typing import Optional
from uuid import UUID
from abc import ABC, abstractmethod
from cognee.modules.users.models.User import User
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
class DatasetDatabaseHandlerInterface(ABC):
@classmethod
@abstractmethod
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
"""
Return a dictionary with database connection/resolution info for a graph or vector database for the given dataset.
The handler may also provision the actual database, but this is not required; returning connection info alone
is sufficient, and it will be used to map future connections to the given dataset.
Needed for Cognee multi-tenant/multi-user and backend access control support.
The returned dictionary is used to create a DatasetDatabase row in the relational database, from which the
internal mapping of dataset -> database connection info is done.
The dictionary is stored verbatim in the relational database and is later passed to
resolve_dataset_connection_info() at connection time. For safe credential handling, prefer
returning only references to secrets or role identifiers, not plaintext credentials.
When backend access control is enabled, each dataset must map to a unique graph or vector database to enforce separation of data.
Args:
dataset_id: UUID of the dataset if needed by the database creation logic
user: User object if needed by the database creation logic
Returns:
dict: Connection info for the created graph or vector database instance.
"""
pass
@classmethod
async def resolve_dataset_connection_info(
cls, dataset_database: DatasetDatabase
) -> DatasetDatabase:
"""
Resolve runtime connection details for a datasets backing graph/vector database.
Function is intended to be overwritten to implement custom logic for resolving connection info.
This method is invoked right before the application opens a connection for a given dataset.
It receives the DatasetDatabase row that was persisted when create_dataset() ran and must
return a modified instance of DatasetDatabase with concrete connection parameters that the client/driver can use.
Do not update these new DatasetDatabase values in the relational database to avoid storing secure credentials.
In case of separate graph and vector database handlers, each handler should implement its own logic for resolving
connection info and only change parameters related to its appropriate database, the resolution function will then
be called one after another with the updated DatasetDatabase value from the previous function as the input.
Typical behavior:
- If the DatasetDatabase row already contains raw connection fields (e.g., host/port/db/user/password
or api_url/api_key), return them as-is.
- If the row stores only references (e.g., secret IDs, vault paths, cloud resource ARNs/IDs, IAM
roles, SSO tokens), resolve those references by calling the appropriate secret manager or provider
API to obtain short-lived credentials and assemble the final connection DatasetDatabase object.
- Do not persist any resolved or decrypted secrets back to the relational database. Return them only
to the caller.
Args:
dataset_database: DatasetDatabase row from the relational database
Returns:
DatasetDatabase: Updated instance with resolved connection info
"""
return dataset_database
@classmethod
@abstractmethod
async def delete_dataset(cls, dataset_database: DatasetDatabase) -> None:
"""
Delete the graph or vector database for the given dataset.
The handler should either delete the actual database itself or send a request to the appropriate service to
delete the database (or mark it as no longer needed) for the given dataset.
Needed to maintain databases for Cognee multi-tenant/multi-user and backend access control support.
Args:
dataset_database: DatasetDatabase row containing connection/resolution info for the graph or vector database to delete.
"""
pass
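
To illustrate the contract, a minimal sketch of a custom handler (hypothetical names; it only returns connection info and provisions nothing, which the interface explicitly allows):

```python
from typing import Optional
from uuid import UUID

from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
from cognee.modules.users.models.User import User
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase


class StaticKuzuHandler(DatasetDatabaseHandlerInterface):
    """Hypothetical handler: maps every dataset to a local Kuzu file, no provisioning."""

    @classmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        # Connection info only; this dict becomes the DatasetDatabase row.
        return {
            "graph_database_provider": "kuzu",
            "graph_database_name": f"{dataset_id}.kuzu",
            "graph_database_url": "",
        }

    @classmethod
    async def delete_dataset(cls, dataset_database: DatasetDatabase) -> None:
        # Nothing to tear down in this sketch; a real handler would delete the database here.
        pass
```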

View file

@ -0,0 +1,18 @@
from cognee.infrastructure.databases.graph.neo4j_driver.Neo4jAuraDevDatasetDatabaseHandler import (
Neo4jAuraDevDatasetDatabaseHandler,
)
from cognee.infrastructure.databases.vector.lancedb.LanceDBDatasetDatabaseHandler import (
LanceDBDatasetDatabaseHandler,
)
from cognee.infrastructure.databases.graph.kuzu.KuzuDatasetDatabaseHandler import (
KuzuDatasetDatabaseHandler,
)
supported_dataset_database_handlers = {
"neo4j_aura_dev": {
"handler_instance": Neo4jAuraDevDatasetDatabaseHandler,
"handler_provider": "neo4j",
},
"lancedb": {"handler_instance": LanceDBDatasetDatabaseHandler, "handler_provider": "lancedb"},
"kuzu": {"handler_instance": KuzuDatasetDatabaseHandler, "handler_provider": "kuzu"},
}

View file

@ -0,0 +1,10 @@
from .supported_dataset_database_handlers import supported_dataset_database_handlers
def use_dataset_database_handler(
dataset_database_handler_name, dataset_database_handler, dataset_database_provider
):
supported_dataset_database_handlers[dataset_database_handler_name] = {
"handler_instance": dataset_database_handler,
"handler_provider": dataset_database_provider,
}
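
Registration is then a one-liner paired with the matching env var, mirroring the test at the end of this PR (`StaticKuzuHandler` is the hypothetical handler sketched above):

```python
import os

# The env var selects a handler by its registered name.
os.environ["GRAPH_DATASET_DATABASE_HANDLER"] = "static_kuzu"

from cognee.infrastructure.databases.dataset_database_handler import use_dataset_database_handler

# Register the custom handler for the "kuzu" provider under the name "static_kuzu".
use_dataset_database_handler("static_kuzu", StaticKuzuHandler, "kuzu")
```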

View file

@ -47,6 +47,7 @@ class GraphConfig(BaseSettings):
graph_filename: str = ""
graph_model: object = KnowledgeGraph
graph_topology: object = KnowledgeGraph
graph_dataset_database_handler: str = "kuzu"
model_config = SettingsConfigDict(env_file=".env", extra="allow", populate_by_name=True)
# Model validator updates graph_filename and path dynamically after class creation based on current database provider
@ -97,6 +98,7 @@ class GraphConfig(BaseSettings):
"graph_model": self.graph_model,
"graph_topology": self.graph_topology,
"model_config": self.model_config,
"graph_dataset_database_handler": self.graph_dataset_database_handler,
}
def to_hashable_dict(self) -> dict:
@ -121,6 +123,7 @@ class GraphConfig(BaseSettings):
"graph_database_port": self.graph_database_port,
"graph_database_key": self.graph_database_key,
"graph_file_path": self.graph_file_path,
"graph_dataset_database_handler": self.graph_dataset_database_handler,
}

View file

@ -34,6 +34,7 @@ def create_graph_engine(
graph_database_password="",
graph_database_port="",
graph_database_key="",
graph_dataset_database_handler="",
):
"""
Create a graph engine based on the specified provider type.

View file

@ -0,0 +1,80 @@
import os
from uuid import UUID
from typing import Optional
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
from cognee.base_config import get_base_config
from cognee.modules.users.models import User
from cognee.modules.users.models import DatasetDatabase
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
class KuzuDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
"""
Handler for interacting with Kuzu Dataset databases.
"""
@classmethod
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
"""
Create a new Kuzu instance for the dataset. Return connection info that will be mapped to the dataset.
Args:
dataset_id: Dataset UUID
user: User object who owns the dataset and is making the request
Returns:
dict: Connection details for the created Kuzu instance
"""
from cognee.infrastructure.databases.graph.config import get_graph_config
graph_config = get_graph_config()
if graph_config.graph_database_provider != "kuzu":
raise ValueError(
"KuzuDatasetDatabaseHandler can only be used with Kuzu graph database provider."
)
graph_db_name = f"{dataset_id}.pkl"
graph_db_url = graph_config.graph_database_url
graph_db_key = graph_config.graph_database_key
graph_db_username = graph_config.graph_database_username
graph_db_password = graph_config.graph_database_password
return {
"graph_database_name": graph_db_name,
"graph_database_url": graph_db_url,
"graph_database_provider": graph_config.graph_database_provider,
"graph_database_key": graph_db_key,
"graph_database_connection_info": {
"graph_database_username": graph_db_username,
"graph_database_password": graph_db_password,
},
}
@classmethod
async def delete_dataset(cls, dataset_database: DatasetDatabase):
base_config = get_base_config()
databases_directory_path = os.path.join(
base_config.system_root_directory, "databases", str(dataset_database.owner_id)
)
graph_file_path = os.path.join(
databases_directory_path, dataset_database.graph_database_name
)
graph_engine = create_graph_engine(
graph_database_provider=dataset_database.graph_database_provider,
graph_database_url=dataset_database.graph_database_url,
graph_database_name=dataset_database.graph_database_name,
graph_database_key=dataset_database.graph_database_key,
graph_file_path=graph_file_path,
graph_database_username=dataset_database.graph_database_connection_info.get(
"graph_database_username", ""
),
graph_database_password=dataset_database.graph_database_connection_info.get(
"graph_database_password", ""
),
graph_dataset_database_handler="",
graph_database_port="",
)
await graph_engine.delete_graph()

View file

@ -0,0 +1,167 @@
import os
import asyncio
import requests
import base64
import hashlib
from uuid import UUID
from typing import Optional
from cryptography.fernet import Fernet
from cognee.infrastructure.databases.graph import get_graph_config
from cognee.modules.users.models import User, DatasetDatabase
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
"""
Handler for a quick development PoC integration of Cognee multi-user and permission mode with Neo4j Aura databases.
This handler creates a new Neo4j Aura instance for each Cognee dataset created.
Improvements needed to be production-ready:
- Secret management for client credentials: currently secrets are encrypted and stored in the Cognee relational database;
a secret manager or a similar system should be used instead.
Quality-of-life improvements:
- Allow configuration of different Neo4j Aura plans and regions.
- Requests should be made async; currently the blocking requests library is used.
"""
@classmethod
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
"""
Create a new Neo4j Aura instance for the dataset. Return connection info that will be mapped to the dataset.
Args:
dataset_id: Dataset UUID
user: User object who owns the dataset and is making the request
Returns:
dict: Connection details for the created Neo4j instance
"""
graph_config = get_graph_config()
if graph_config.graph_database_provider != "neo4j":
raise ValueError(
"Neo4jAuraDevDatasetDatabaseHandler can only be used with Neo4j graph database provider."
)
graph_db_name = f"{dataset_id}"
# Client credentials and encryption
client_id = os.environ.get("NEO4J_CLIENT_ID", None)
client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
tenant_id = os.environ.get("NEO4J_TENANT_ID", None)
encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key")
encryption_key = base64.urlsafe_b64encode(
hashlib.sha256(encryption_env_key.encode()).digest()
)
cipher = Fernet(encryption_key)
if client_id is None or client_secret is None or tenant_id is None:
raise ValueError(
"NEO4J_CLIENT_ID, NEO4J_CLIENT_SECRET, and NEO4J_TENANT_ID environment variables must be set to use Neo4j Aura DatasetDatabase Handling."
)
# Make the request with HTTP Basic Auth
def get_aura_token(client_id: str, client_secret: str) -> dict:
url = "https://api.neo4j.io/oauth/token"
data = {"grant_type": "client_credentials"} # sent as application/x-www-form-urlencoded
resp = requests.post(url, data=data, auth=(client_id, client_secret))
resp.raise_for_status() # raises if the request failed
return resp.json()
resp = get_aura_token(client_id, client_secret)
url = "https://api.neo4j.io/v1/instances"
headers = {
"accept": "application/json",
"Authorization": f"Bearer {resp['access_token']}",
"Content-Type": "application/json",
}
# TODO: Maybe we can allow **kwargs parameter forwarding for cases like these
# To allow different configurations between datasets
payload = {
"version": "5",
"region": "europe-west1",
"memory": "1GB",
"name": graph_db_name[
0:29
], # TODO: Find a better way to name the Neo4j instance within the 30-character limit
"type": "professional-db",
"tenant_id": tenant_id,
"cloud_provider": "gcp",
}
response = requests.post(url, headers=headers, json=payload)
graph_db_name = "neo4j" # Has to be 'neo4j' for Aura
graph_db_url = response.json()["data"]["connection_url"]
graph_db_key = resp["access_token"]
graph_db_username = response.json()["data"]["username"]
graph_db_password = response.json()["data"]["password"]
async def _wait_for_neo4j_instance_provisioning(instance_id: str, headers: dict):
# Poll until the instance is running
status_url = f"https://api.neo4j.io/v1/instances/{instance_id}"
status = ""
for attempt in range(30): # Try for up to ~5 minutes
status_resp = requests.get(
status_url, headers=headers
) # TODO: Use async requests with httpx
status = status_resp.json()["data"]["status"]
if status.lower() == "running":
return
await asyncio.sleep(10)
raise TimeoutError(
f"Neo4j instance '{graph_db_name}' did not become ready within 5 minutes. Status: {status}"
)
instance_id = response.json()["data"]["id"]
await _wait_for_neo4j_instance_provisioning(instance_id, headers)
encrypted_db_password_bytes = cipher.encrypt(graph_db_password.encode())
encrypted_db_password_string = encrypted_db_password_bytes.decode()
return {
"graph_database_name": graph_db_name,
"graph_database_url": graph_db_url,
"graph_database_provider": "neo4j",
"graph_database_key": graph_db_key,
"graph_database_connection_info": {
"graph_database_username": graph_db_username,
"graph_database_password": encrypted_db_password_string,
},
}
@classmethod
async def resolve_dataset_connection_info(
cls, dataset_database: DatasetDatabase
) -> DatasetDatabase:
"""
Resolve and decrypt connection info for the Neo4j dataset database.
In this case, decrypt the password stored in the database.
Args:
dataset_database: DatasetDatabase instance containing encrypted connection info.
"""
encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key")
encryption_key = base64.urlsafe_b64encode(
hashlib.sha256(encryption_env_key.encode()).digest()
)
cipher = Fernet(encryption_key)
graph_db_password = cipher.decrypt(
dataset_database.graph_database_connection_info["graph_database_password"].encode()
).decode()
dataset_database.graph_database_connection_info["graph_database_password"] = (
graph_db_password
)
return dataset_database
@classmethod
async def delete_dataset(cls, dataset_database: DatasetDatabase):
pass
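
As a standalone illustration of the key-derivation and encrypt/decrypt roundtrip used above (same scheme: a urlsafe-base64-encoded SHA-256 digest of the passphrase serves as the Fernet key; values are illustrative):

```python
import base64
import hashlib

from cryptography.fernet import Fernet

passphrase = "test_key"  # stands in for NEO4J_ENCRYPTION_KEY
key = base64.urlsafe_b64encode(hashlib.sha256(passphrase.encode()).digest())
cipher = Fernet(key)

token = cipher.encrypt(b"aura-generated-password")  # what create_dataset() persists
assert cipher.decrypt(token) == b"aura-generated-password"  # what resolve_dataset_connection_info() recovers
```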

View file

@ -1 +1,2 @@
from .get_or_create_dataset_database import get_or_create_dataset_database
from .resolve_dataset_database_connection_info import resolve_dataset_database_connection_info

View file

@ -1,11 +1,9 @@
import os
from uuid import UUID
from typing import Union
from typing import Union, Optional
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from cognee.base_config import get_base_config
from cognee.modules.data.methods import create_dataset
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.databases.vector import get_vectordb_config
@ -15,6 +13,53 @@ from cognee.modules.users.models import DatasetDatabase
from cognee.modules.users.models import User
async def _get_vector_db_info(dataset_id: UUID, user: User) -> dict:
vector_config = get_vectordb_config()
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
supported_dataset_database_handlers,
)
handler = supported_dataset_database_handlers[vector_config.vector_dataset_database_handler]
return await handler["handler_instance"].create_dataset(dataset_id, user)
async def _get_graph_db_info(dataset_id: UUID, user: User) -> dict:
graph_config = get_graph_config()
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
supported_dataset_database_handlers,
)
handler = supported_dataset_database_handlers[graph_config.graph_dataset_database_handler]
return await handler["handler_instance"].create_dataset(dataset_id, user)
async def _existing_dataset_database(
dataset_id: UUID,
user: User,
) -> Optional[DatasetDatabase]:
"""
Check if a DatasetDatabase row already exists for the given owner + dataset.
Return None if it doesn't exist, return the row if it does.
Args:
dataset_id: UUID of the dataset
user: Owner of the dataset
Returns:
DatasetDatabase or None
"""
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
stmt = select(DatasetDatabase).where(
DatasetDatabase.owner_id == user.id,
DatasetDatabase.dataset_id == dataset_id,
)
existing: DatasetDatabase = await session.scalar(stmt)
return existing
async def get_or_create_dataset_database(
dataset: Union[str, UUID],
user: User,
@ -25,6 +70,8 @@ async def get_or_create_dataset_database(
If the row already exists, it is fetched and returned.
Otherwise a new one is created atomically and returned.
DatasetDatabase row contains connection and provider info for vector and graph databases.
Parameters
----------
user : User
@ -36,59 +83,26 @@ async def get_or_create_dataset_database(
dataset_id = await get_unique_dataset_id(dataset, user)
vector_config = get_vectordb_config()
graph_config = get_graph_config()
# If dataset is given as name make sure the dataset is created first
if isinstance(dataset, str):
async with db_engine.get_async_session() as session:
await create_dataset(dataset, user, session)
# Note: for hybrid databases both graph and vector DB name have to be the same
if graph_config.graph_database_provider == "kuzu":
graph_db_name = f"{dataset_id}.pkl"
else:
graph_db_name = f"{dataset_id}"
# If dataset database already exists return it
existing_dataset_database = await _existing_dataset_database(dataset_id, user)
if existing_dataset_database:
return existing_dataset_database
if vector_config.vector_db_provider == "lancedb":
vector_db_name = f"{dataset_id}.lance.db"
else:
vector_db_name = f"{dataset_id}"
base_config = get_base_config()
databases_directory_path = os.path.join(
base_config.system_root_directory, "databases", str(user.id)
)
# Determine vector database URL
if vector_config.vector_db_provider == "lancedb":
vector_db_url = os.path.join(databases_directory_path, vector_config.vector_db_name)
else:
vector_db_url = vector_config.vector_database_url
# Determine graph database URL
graph_config_dict = await _get_graph_db_info(dataset_id, user)
vector_config_dict = await _get_vector_db_info(dataset_id, user)
async with db_engine.get_async_session() as session:
# Create dataset if it doesn't exist
if isinstance(dataset, str):
dataset = await create_dataset(dataset, user, session)
# Try to fetch an existing row first
stmt = select(DatasetDatabase).where(
DatasetDatabase.owner_id == user.id,
DatasetDatabase.dataset_id == dataset_id,
)
existing: DatasetDatabase = await session.scalar(stmt)
if existing:
return existing
# If there are no existing rows build a new row
record = DatasetDatabase(
owner_id=user.id,
dataset_id=dataset_id,
vector_database_name=vector_db_name,
graph_database_name=graph_db_name,
vector_database_provider=vector_config.vector_db_provider,
graph_database_provider=graph_config.graph_database_provider,
vector_database_url=vector_db_url,
graph_database_url=graph_config.graph_database_url,
vector_database_key=vector_config.vector_db_key,
graph_database_key=graph_config.graph_database_key,
**graph_config_dict, # Unpack graph db config
**vector_config_dict, # Unpack vector db config
)
try:

View file

@ -0,0 +1,42 @@
from cognee.infrastructure.databases.vector import get_vectordb_config
from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
async def _get_vector_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase:
vector_config = get_vectordb_config()
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
supported_dataset_database_handlers,
)
handler = supported_dataset_database_handlers[vector_config.vector_dataset_database_handler]
return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database)
async def _get_graph_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase:
graph_config = get_graph_config()
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
supported_dataset_database_handlers,
)
handler = supported_dataset_database_handlers[graph_config.graph_dataset_database_handler]
return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database)
async def resolve_dataset_database_connection_info(
dataset_database: DatasetDatabase,
) -> DatasetDatabase:
"""
Resolve the connection info for the given DatasetDatabase instance.
Resolve both vector and graph database connection info and return the updated DatasetDatabase instance.
Args:
dataset_database: DatasetDatabase instance
Returns:
DatasetDatabase instance with resolved connection info
"""
dataset_database = await _get_vector_db_connection_info(dataset_database)
dataset_database = await _get_graph_db_connection_info(dataset_database)
return dataset_database

View file

@ -28,6 +28,7 @@ class VectorConfig(BaseSettings):
vector_db_name: str = ""
vector_db_key: str = ""
vector_db_provider: str = "lancedb"
vector_dataset_database_handler: str = "lancedb"
model_config = SettingsConfigDict(env_file=".env", extra="allow")
@ -63,6 +64,7 @@ class VectorConfig(BaseSettings):
"vector_db_name": self.vector_db_name,
"vector_db_key": self.vector_db_key,
"vector_db_provider": self.vector_db_provider,
"vector_dataset_database_handler": self.vector_dataset_database_handler,
}

View file

@ -12,6 +12,7 @@ def create_vector_engine(
vector_db_name: str,
vector_db_port: str = "",
vector_db_key: str = "",
vector_dataset_database_handler: str = "",
):
"""
Create a vector database engine based on the specified provider.

View file

@ -0,0 +1,49 @@
import os
from uuid import UUID
from typing import Optional
from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
from cognee.modules.users.models import User
from cognee.modules.users.models import DatasetDatabase
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.vector import get_vectordb_config
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
class LanceDBDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
"""
Handler for interacting with LanceDB Dataset databases.
"""
@classmethod
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
vector_config = get_vectordb_config()
base_config = get_base_config()
if vector_config.vector_db_provider != "lancedb":
raise ValueError(
"LanceDBDatasetDatabaseHandler can only be used with LanceDB vector database provider."
)
databases_directory_path = os.path.join(
base_config.system_root_directory, "databases", str(user.id)
)
vector_db_name = f"{dataset_id}.lance.db"
return {
"vector_database_provider": vector_config.vector_db_provider,
"vector_database_url": os.path.join(databases_directory_path, vector_db_name),
"vector_database_key": vector_config.vector_db_key,
"vector_database_name": vector_db_name,
}
@classmethod
async def delete_dataset(cls, dataset_database: DatasetDatabase):
vector_engine = create_vector_engine(
vector_db_provider=dataset_database.vector_database_provider,
vector_db_url=dataset_database.vector_database_url,
vector_db_key=dataset_database.vector_database_key,
vector_db_name=dataset_database.vector_database_name,
)
await vector_engine.prune()

View file

@ -2,6 +2,8 @@ from typing import List, Protocol, Optional, Union, Any
from abc import abstractmethod
from cognee.infrastructure.engine import DataPoint
from .models.PayloadSchema import PayloadSchema
from uuid import UUID
from cognee.modules.users.models import User
class VectorDBInterface(Protocol):
@ -217,3 +219,36 @@ class VectorDBInterface(Protocol):
- Any: The schema object suitable for this vector database
"""
return model_type
@classmethod
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
"""
Return a dictionary with connection info for a vector database for the given dataset.
The handler may also provision the actual database, but this is not required; returning connection info alone
is sufficient, and it will be used to map future connections to the given dataset.
Needed for Cognee multi-tenant/multi-user and backend access control support.
The returned dictionary is used to create a DatasetDatabase row in the relational database, from which the
internal mapping of dataset -> database connection info is done.
When backend access control is enabled, each dataset must map to a unique vector database to enforce separation of data.
Args:
dataset_id: UUID of the dataset if needed by the database creation logic
user: User object if needed by the database creation logic
Returns:
dict: Connection info for the created vector database instance.
"""
pass
async def delete_dataset(self, dataset_id: UUID, user: User) -> None:
"""
Delete the vector database for the given dataset.
The handler should either delete the actual database itself or send a request to the appropriate service to delete it.
Needed to maintain databases for Cognee multi-tenant/multi-user and backend access control support.
Args:
dataset_id: UUID of the dataset
user: User object
"""
pass

View file

@ -1,17 +1,82 @@
from sqlalchemy.exc import OperationalError
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
from cognee.context_global_variables import backend_access_control_enabled
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.databases.vector.config import get_vectordb_config
from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.shared.cache import delete_cache
from cognee.modules.users.models import DatasetDatabase
from cognee.shared.logging_utils import get_logger
logger = get_logger()
async def prune_graph_databases():
async def _prune_graph_db(dataset_database: DatasetDatabase) -> dict:
graph_config = get_graph_config()
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
supported_dataset_database_handlers,
)
handler = supported_dataset_database_handlers[graph_config.graph_dataset_database_handler]
return await handler["handler_instance"].delete_dataset(dataset_database)
db_engine = get_relational_engine()
try:
data = await db_engine.get_all_data_from_table("dataset_database")
# Go through each dataset database and delete the graph database
for data_item in data:
await _prune_graph_db(data_item)
except (OperationalError, EntityNotFoundError) as e:
logger.debug(
"Skipping pruning of graph DB. Error when accessing dataset_database table: %s",
e,
)
return
async def prune_vector_databases():
async def _prune_vector_db(dataset_database: DatasetDatabase) -> dict:
vector_config = get_vectordb_config()
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
supported_dataset_database_handlers,
)
handler = supported_dataset_database_handlers[vector_config.vector_dataset_database_handler]
return await handler["handler_instance"].delete_dataset(dataset_database)
db_engine = get_relational_engine()
try:
data = await db_engine.get_all_data_from_table("dataset_database")
# Go through each dataset database and delete the vector database
for data_item in data:
await _prune_vector_db(data_item)
except (OperationalError, EntityNotFoundError) as e:
logger.debug(
"Skipping pruning of vector DB. Error when accessing dataset_database table: %s",
e,
)
return
async def prune_system(graph=True, vector=True, metadata=True, cache=True):
if graph:
# Note: prune_system should not be exposed through the API; it has no permission checks and will
# delete all graph and vector databases if called. It should only be used in development or testing environments.
if graph and not backend_access_control_enabled():
graph_engine = await get_graph_engine()
await graph_engine.delete_graph()
elif graph and backend_access_control_enabled():
await prune_graph_databases()
if vector:
if vector and not backend_access_control_enabled():
vector_engine = get_vector_engine()
await vector_engine.prune()
elif vector and backend_access_control_enabled():
await prune_vector_databases()
if metadata:
db_engine = get_relational_engine()

View file

@ -12,8 +12,8 @@ logger = get_logger("get_authenticated_user")
# Check environment variable to determine authentication requirement
REQUIRE_AUTHENTICATION = (
os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true"
or backend_access_control_enabled()
os.getenv("REQUIRE_AUTHENTICATION", "true").lower() == "true"
or os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", "true").lower() == "true"
)
fastapi_users = get_fastapi_users()

View file

@ -1,6 +1,6 @@
from datetime import datetime, timezone
from sqlalchemy import Column, DateTime, String, UUID, ForeignKey
from sqlalchemy import Column, DateTime, String, UUID, ForeignKey, JSON, text
from cognee.infrastructure.databases.relational import Base
@ -12,8 +12,8 @@ class DatasetDatabase(Base):
UUID, ForeignKey("datasets.id", ondelete="CASCADE"), primary_key=True, index=True
)
vector_database_name = Column(String, unique=True, nullable=False)
graph_database_name = Column(String, unique=True, nullable=False)
vector_database_name = Column(String, unique=False, nullable=False)
graph_database_name = Column(String, unique=False, nullable=False)
vector_database_provider = Column(String, unique=False, nullable=False)
graph_database_provider = Column(String, unique=False, nullable=False)
@ -24,5 +24,14 @@ class DatasetDatabase(Base):
vector_database_key = Column(String, unique=False, nullable=True)
graph_database_key = Column(String, unique=False, nullable=True)
# Configuration details for different database types; this makes it easier to add new database types
# without changing the database schema.
graph_database_connection_info = Column(
JSON, unique=False, nullable=False, server_default=text("'{}'")
)
vector_database_connection_info = Column(
JSON, unique=False, nullable=False, server_default=text("'{}'")
)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
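
For reference, a hedged sketch of the payloads these JSON columns hold, based on the handlers in this PR (values illustrative; the Aura handler stores the password Fernet-encrypted):

```python
graph_database_connection_info = {
    "graph_database_username": "neo4j",
    "graph_database_password": "<fernet-encrypted-password>",
}
vector_database_connection_info = {}  # LanceDB needs no extra connection fields
```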

View file

@ -534,6 +534,10 @@ def setup_logging(log_level=None, name=None):
# Get a configured logger and log system information
logger = structlog.get_logger(name if name else __name__)
logger.warning(
"From version 0.5.0 onwards, Cognee will run with multi-user access control mode set to on by default. Data isolation between different users and datasets will be enforced and data created before multi-user access control mode was turned on won't be accessible by default. To disable multi-user access control mode and regain access to old data set the environment variable ENABLE_BACKEND_ACCESS_CONTROL to false before starting Cognee. For more information, please refer to the Cognee documentation."
)
if logs_dir is not None:
logger.info(f"Log file created at: {log_file_path}", log_file=log_file_path)

View file

@ -0,0 +1,135 @@
import asyncio
import os
# Set custom dataset database handler environment variables
os.environ["VECTOR_DATASET_DATABASE_HANDLER"] = "custom_lancedb_handler"
os.environ["GRAPH_DATASET_DATABASE_HANDLER"] = "custom_kuzu_handler"
import cognee
from cognee.modules.users.methods import get_default_user
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
from cognee.shared.logging_utils import setup_logging, ERROR
from cognee.api.v1.search import SearchType
class LanceDBTestDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
@classmethod
async def create_dataset(cls, dataset_id, user):
import pathlib
cognee_directory_path = str(
pathlib.Path(
os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_dataset_database_handler"
)
).resolve()
)
databases_directory_path = os.path.join(cognee_directory_path, "databases", str(user.id))
os.makedirs(databases_directory_path, exist_ok=True)
vector_db_name = "test.lance.db"
return {
"vector_database_name": vector_db_name,
"vector_database_url": os.path.join(databases_directory_path, vector_db_name),
"vector_database_provider": "lancedb",
}
class KuzuTestDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
@classmethod
async def create_dataset(cls, dataset_id, user):
databases_directory_path = os.path.join("databases", str(user.id))
os.makedirs(databases_directory_path, exist_ok=True)
graph_db_name = "test.kuzu"
return {
"graph_database_name": graph_db_name,
"graph_database_url": os.path.join(databases_directory_path, graph_db_name),
"graph_database_provider": "kuzu",
}
async def main():
import pathlib
data_directory_path = str(
pathlib.Path(
os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_dataset_database_handler"
)
).resolve()
)
cognee.config.data_root_directory(data_directory_path)
cognee_directory_path = str(
pathlib.Path(
os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_dataset_database_handler"
)
).resolve()
)
cognee.config.system_root_directory(cognee_directory_path)
# Add custom dataset database handler
from cognee.infrastructure.databases.dataset_database_handler.use_dataset_database_handler import (
use_dataset_database_handler,
)
use_dataset_database_handler(
"custom_lancedb_handler", LanceDBTestDatasetDatabaseHandler, "lancedb"
)
use_dataset_database_handler("custom_kuzu_handler", KuzuTestDatasetDatabaseHandler, "kuzu")
# Create a clean slate for cognee -- reset data and system state
print("Resetting cognee data...")
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
print("Data reset complete.\n")
# cognee knowledge graph will be created based on this text
text = """
Natural language processing (NLP) is an interdisciplinary
subfield of computer science and information retrieval.
"""
print("Adding text to cognee:")
print(text.strip())
# Add the text, and make it available for cognify
await cognee.add(text)
print("Text added successfully.\n")
# Use LLMs and cognee to create knowledge graph
await cognee.cognify()
print("Cognify process complete.\n")
query_text = "Tell me about NLP"
print(f"Searching cognee for insights with query: '{query_text}'")
# Query cognee for insights on the added text
search_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION, query_text=query_text
)
print("Search results:")
# Display results
for result_text in search_results:
print(result_text)
default_user = await get_default_user()
# Assert that the custom database files were created based on the custom dataset database handlers
assert os.path.exists(
os.path.join(cognee_directory_path, "databases", str(default_user.id), "test.kuzu")
), "Graph database file not found."
assert os.path.exists(
os.path.join(cognee_directory_path, "databases", str(default_user.id), "test.lance.db")
), "Vector database file not found."
if __name__ == "__main__":
logger = setup_logging(log_level=ERROR)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(main())
finally:
loop.run_until_complete(loop.shutdown_asyncgens())
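
The test is wired into CI above; locally it can be run the same way (assuming the repo's uv setup):

```bash
uv run python ./cognee/tests/test_dataset_database_handler.py
```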