feat: add dataset database handler logic and neo4j/lancedb/kuzu handlers (#1776)
<!-- .github/pull_request_template.md -->
## Description
Add the ability to use multi-tenant, multi-user mode with Neo4j.
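
For context, a minimal configuration sketch of this mode (the variable names come from this PR's env template and the Neo4j Aura handler below; the values are illustrative and depend on your deployment):

```bash
ENABLE_BACKEND_ACCESS_CONTROL=true

GRAPH_DATABASE_PROVIDER="neo4j"
GRAPH_DATASET_DATABASE_HANDLER="neo4j_aura_dev"

VECTOR_DB_PROVIDER="lancedb"
VECTOR_DATASET_DATABASE_HANDLER="lancedb"

# Required by the Neo4j Aura dev handler
NEO4J_CLIENT_ID=...
NEO4J_CLIENT_SECRET=...
NEO4J_TENANT_ID=...
NEO4J_ENCRYPTION_KEY=...
```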
## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):
## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->
## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [ ] **I have tested my changes thoroughly before submitting this PR**
- [ ] **This PR contains minimal changes necessary to address the
issue/feature**
- [ ] My code follows the project's coding standards and style
guidelines
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages
## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
## Release Notes
* **New Features**
  * Multi-user support with per-dataset database isolation, enabled by default, allowing backend access control for secure data separation.
  * Configurable database handlers via environment variables (GRAPH_DATASET_DATABASE_HANDLER, VECTOR_DATASET_DATABASE_HANDLER) for flexible deployment options.
* **Chores**
  * Database schema migration to support per-user dataset database configurations.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
Commit 46ddd4fd12 · 31 changed files with 1143 additions and 74 deletions
### Environment template

```diff
@@ -97,6 +97,8 @@ DB_NAME=cognee_db
 # Default (local file-based)
 GRAPH_DATABASE_PROVIDER="kuzu"
+# Handler for multi-user access control mode; it controls how the mapping/creation of separate DBs is handled per Cognee dataset
+GRAPH_DATASET_DATABASE_HANDLER="kuzu"
 
 # -- To switch to Remote Kuzu uncomment and fill these: -------------------------------------------------------------
 #GRAPH_DATABASE_PROVIDER="kuzu"
@@ -121,6 +123,8 @@ VECTOR_DB_PROVIDER="lancedb"
 # Not needed if a cloud vector database is not used
 VECTOR_DB_URL=
 VECTOR_DB_KEY=
+# Handler for multi-user access control mode; it controls how the mapping/creation of separate DBs is handled per Cognee dataset
+VECTOR_DATASET_DATABASE_HANDLER="lancedb"
 
 ################################################################################
 # 🧩 Ontology resolver settings
```
### .github/workflows/db_examples_tests.yml (2 changes)

```diff
@@ -61,6 +61,7 @@ jobs:
       - name: Run Neo4j Example
         env:
           ENV: 'dev'
+          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
           LLM_MODEL: ${{ secrets.LLM_MODEL }}
           LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
           LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
@@ -142,6 +143,7 @@ jobs:
       - name: Run PGVector Example
         env:
           ENV: 'dev'
+          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
           LLM_MODEL: ${{ secrets.LLM_MODEL }}
           LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
           LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
```
### .github/workflows/distributed_test.yml (1 change)

```diff
@@ -47,6 +47,7 @@ jobs:
       - name: Run Distributed Cognee (Modal)
         env:
           ENV: 'dev'
+          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
           LLM_MODEL: ${{ secrets.LLM_MODEL }}
           LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
           LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
```
### .github/workflows/e2e_tests.yml (26 changes)

```diff
@@ -147,6 +147,7 @@ jobs:
       - name: Run Deduplication Example
         env:
           ENV: 'dev'
+          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
           LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }} # Test needs OpenAI endpoint to handle multimedia
           OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
           EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
@@ -211,6 +212,31 @@ jobs:
           EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
         run: uv run python ./cognee/tests/test_parallel_databases.py
 
+  test-dataset-database-handler:
+    name: Test dataset database handlers in Cognee
+    runs-on: ubuntu-22.04
+    steps:
+      - name: Check out repository
+        uses: actions/checkout@v4
+
+      - name: Cognee Setup
+        uses: ./.github/actions/cognee_setup
+        with:
+          python-version: '3.11.x'
+
+      - name: Run dataset databases handler test
+        env:
+          ENV: 'dev'
+          LLM_MODEL: ${{ secrets.LLM_MODEL }}
+          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
+          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
+          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
+          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
+          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
+          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
+          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
+        run: uv run python ./cognee/tests/test_dataset_database_handler.py
+
   test-permissions:
     name: Test permissions with different situations in Cognee
     runs-on: ubuntu-22.04
```
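The new job can be reproduced locally; a rough equivalent of its run step, assuming the same LLM_* and EMBEDDING_* variables are exported in your shell:

```bash
export ENV=dev
# export LLM_MODEL=... LLM_ENDPOINT=... LLM_API_KEY=... LLM_API_VERSION=...
# export EMBEDDING_MODEL=... EMBEDDING_ENDPOINT=... EMBEDDING_API_KEY=... EMBEDDING_API_VERSION=...
uv run python ./cognee/tests/test_dataset_database_handler.py
```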
### .github/workflows/examples_tests.yml (1 change)

```diff
@@ -72,6 +72,7 @@ jobs:
       - name: Run Descriptive Graph Metrics Example
         env:
           ENV: 'dev'
+          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
           LLM_MODEL: ${{ secrets.LLM_MODEL }}
           LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
           LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
```
### .github/workflows/graph_db_tests.yml (1 change)

```diff
@@ -78,6 +78,7 @@ jobs:
       - name: Run default Neo4j
         env:
           ENV: 'dev'
+          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
           LLM_MODEL: ${{ secrets.LLM_MODEL }}
           LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
           LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
```
### .github/workflows/temporal_graph_tests.yml (3 changes)

```diff
@@ -72,6 +72,7 @@ jobs:
       - name: Run Temporal Graph with Neo4j (lancedb + sqlite)
         env:
           ENV: 'dev'
+          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
           LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
           LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
           LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -123,6 +124,7 @@ jobs:
       - name: Run Temporal Graph with Kuzu (postgres + pgvector)
         env:
           ENV: dev
+          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
           LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
           LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
           LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -189,6 +191,7 @@ jobs:
       - name: Run Temporal Graph with Neo4j (postgres + pgvector)
         env:
           ENV: dev
+          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
           LLM_MODEL: ${{ secrets.OPENAI_MODEL }}
           LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }}
           LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
```
### .github/workflows/vector_db_tests.yml (3 changes)

```diff
@@ -92,6 +92,7 @@ jobs:
       - name: Run PGVector Tests
         env:
           ENV: 'dev'
+          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
           LLM_MODEL: ${{ secrets.LLM_MODEL }}
           LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
           LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
@@ -127,4 +128,4 @@ jobs:
           EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
           EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
           EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
-        run: uv run python ./cognee/tests/test_lancedb.py
+        run: uv run python ./cognee/tests/test_lancedb.py
```
### .github/workflows/weighted_edges_tests.yml (3 changes)

```diff
@@ -94,6 +94,7 @@ jobs:
 
       - name: Run Weighted Edges Tests
         env:
+          ENABLE_BACKEND_ACCESS_CONTROL: 'false'
           GRAPH_DATABASE_PROVIDER: ${{ matrix.graph_db_provider }}
           GRAPH_DATABASE_URL: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-url || '' }}
           GRAPH_DATABASE_USERNAME: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-username || '' }}
@@ -165,5 +166,3 @@ jobs:
         uses: astral-sh/ruff-action@v2
         with:
           args: "format --check cognee/modules/graph/utils/get_graph_from_model.py cognee/tests/unit/interfaces/graph/test_weighted_edges.py examples/python/weighted_edges_example.py"
-
-
```
### New file: Alembic migration 46a6ce2bd2b2 (+267 lines)

```python
"""Expand dataset database with json connection field

Revision ID: 46a6ce2bd2b2
Revises: 76625596c5c3
Create Date: 2025-11-25 17:56:28.938931

"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "46a6ce2bd2b2"
down_revision: Union[str, None] = "76625596c5c3"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

graph_constraint_name = "dataset_database_graph_database_name_key"
vector_constraint_name = "dataset_database_vector_database_name_key"
TABLE_NAME = "dataset_database"


def _get_column(inspector, table, name, schema=None):
    for col in inspector.get_columns(table, schema=schema):
        if col["name"] == name:
            return col
    return None


def _recreate_table_without_unique_constraint_sqlite(op, insp):
    """
    SQLite cannot drop unique constraints on individual columns. We must:
    1. Create a new table without the unique constraints.
    2. Copy data from the old table.
    3. Drop the old table.
    4. Rename the new table.
    """
    conn = op.get_bind()

    # Create new table definition (without unique constraints)
    op.create_table(
        f"{TABLE_NAME}_new",
        sa.Column("owner_id", sa.UUID()),
        sa.Column("dataset_id", sa.UUID(), primary_key=True, nullable=False),
        sa.Column("vector_database_name", sa.String(), nullable=False),
        sa.Column("graph_database_name", sa.String(), nullable=False),
        sa.Column("vector_database_provider", sa.String(), nullable=False),
        sa.Column("graph_database_provider", sa.String(), nullable=False),
        sa.Column("vector_database_url", sa.String()),
        sa.Column("graph_database_url", sa.String()),
        sa.Column("vector_database_key", sa.String()),
        sa.Column("graph_database_key", sa.String()),
        sa.Column(
            "graph_database_connection_info",
            sa.JSON(),
            nullable=False,
            server_default=sa.text("'{}'"),
        ),
        sa.Column(
            "vector_database_connection_info",
            sa.JSON(),
            nullable=False,
            server_default=sa.text("'{}'"),
        ),
        sa.Column("created_at", sa.DateTime()),
        sa.Column("updated_at", sa.DateTime()),
        sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"], ondelete="CASCADE"),
        sa.ForeignKeyConstraint(["owner_id"], ["principals.id"], ondelete="CASCADE"),
    )

    # Copy data into new table
    conn.execute(
        sa.text(f"""
        INSERT INTO {TABLE_NAME}_new
        SELECT
            owner_id,
            dataset_id,
            vector_database_name,
            graph_database_name,
            vector_database_provider,
            graph_database_provider,
            vector_database_url,
            graph_database_url,
            vector_database_key,
            graph_database_key,
            COALESCE(graph_database_connection_info, '{{}}'),
            COALESCE(vector_database_connection_info, '{{}}'),
            created_at,
            updated_at
        FROM {TABLE_NAME}
        """)
    )

    # Drop old table
    op.drop_table(TABLE_NAME)

    # Rename new table
    op.rename_table(f"{TABLE_NAME}_new", TABLE_NAME)


def _recreate_table_with_unique_constraint_sqlite(op, insp):
    """
    SQLite cannot re-add unique constraints on individual columns in place. We must:
    1. Create a new table with the unique constraints.
    2. Copy data from the old table.
    3. Drop the old table.
    4. Rename the new table.
    """
    conn = op.get_bind()

    # Create new table definition (with unique constraints restored)
    op.create_table(
        f"{TABLE_NAME}_new",
        sa.Column("owner_id", sa.UUID()),
        sa.Column("dataset_id", sa.UUID(), primary_key=True, nullable=False),
        sa.Column("vector_database_name", sa.String(), nullable=False, unique=True),
        sa.Column("graph_database_name", sa.String(), nullable=False, unique=True),
        sa.Column("vector_database_provider", sa.String(), nullable=False),
        sa.Column("graph_database_provider", sa.String(), nullable=False),
        sa.Column("vector_database_url", sa.String()),
        sa.Column("graph_database_url", sa.String()),
        sa.Column("vector_database_key", sa.String()),
        sa.Column("graph_database_key", sa.String()),
        sa.Column(
            "graph_database_connection_info",
            sa.JSON(),
            nullable=False,
            server_default=sa.text("'{}'"),
        ),
        sa.Column(
            "vector_database_connection_info",
            sa.JSON(),
            nullable=False,
            server_default=sa.text("'{}'"),
        ),
        sa.Column("created_at", sa.DateTime()),
        sa.Column("updated_at", sa.DateTime()),
        sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"], ondelete="CASCADE"),
        sa.ForeignKeyConstraint(["owner_id"], ["principals.id"], ondelete="CASCADE"),
    )

    # Copy data into new table
    conn.execute(
        sa.text(f"""
        INSERT INTO {TABLE_NAME}_new
        SELECT
            owner_id,
            dataset_id,
            vector_database_name,
            graph_database_name,
            vector_database_provider,
            graph_database_provider,
            vector_database_url,
            graph_database_url,
            vector_database_key,
            graph_database_key,
            COALESCE(graph_database_connection_info, '{{}}'),
            COALESCE(vector_database_connection_info, '{{}}'),
            created_at,
            updated_at
        FROM {TABLE_NAME}
        """)
    )

    # Drop old table
    op.drop_table(TABLE_NAME)

    # Rename new table
    op.rename_table(f"{TABLE_NAME}_new", TABLE_NAME)


def upgrade() -> None:
    conn = op.get_bind()
    insp = sa.inspect(conn)

    unique_constraints = insp.get_unique_constraints(TABLE_NAME)

    vector_database_connection_info_column = _get_column(
        insp, "dataset_database", "vector_database_connection_info"
    )
    if not vector_database_connection_info_column:
        op.add_column(
            "dataset_database",
            sa.Column(
                "vector_database_connection_info",
                sa.JSON(),
                unique=False,
                nullable=False,
                server_default=sa.text("'{}'"),
            ),
        )

    graph_database_connection_info_column = _get_column(
        insp, "dataset_database", "graph_database_connection_info"
    )
    if not graph_database_connection_info_column:
        op.add_column(
            "dataset_database",
            sa.Column(
                "graph_database_connection_info",
                sa.JSON(),
                unique=False,
                nullable=False,
                server_default=sa.text("'{}'"),
            ),
        )

    with op.batch_alter_table("dataset_database", schema=None) as batch_op:
        # Drop the unique constraints to make the name columns non-unique
        graph_constraint_to_drop = None
        for uc in unique_constraints:
            # Check if the constraint covers ONLY the target column
            if uc["name"] == graph_constraint_name:
                graph_constraint_to_drop = uc["name"]
                break

        vector_constraint_to_drop = None
        for uc in unique_constraints:
            # Check if the constraint covers ONLY the target column
            if uc["name"] == vector_constraint_name:
                vector_constraint_to_drop = uc["name"]
                break

        if (
            vector_constraint_to_drop
            and graph_constraint_to_drop
            and op.get_context().dialect.name == "postgresql"
        ):
            # PostgreSQL can drop the unique constraints directly
            batch_op.drop_constraint(graph_constraint_name, type_="unique")
            batch_op.drop_constraint(vector_constraint_name, type_="unique")

    if op.get_context().dialect.name == "sqlite":
        conn = op.get_bind()
        # Fun fact: SQLite has hidden auto indexes for unique constraints that can't be dropped or accessed directly,
        # so we need to check for them and drop them by recreating the table (altering the column also won't work)
        result = conn.execute(sa.text("PRAGMA index_list('dataset_database')"))
        rows = result.fetchall()
        unique_auto_indexes = [row for row in rows if row[3] == "u"]
        for row in unique_auto_indexes:
            result = conn.execute(sa.text(f"PRAGMA index_info('{row[1]}')"))
            index_info = result.fetchall()
            if index_info[0][2] == "vector_database_name":
                # In case a unique index exists on vector_database_name, drop it and the graph_database_name one
                _recreate_table_without_unique_constraint_sqlite(op, insp)


def downgrade() -> None:
    conn = op.get_bind()
    insp = sa.inspect(conn)

    if op.get_context().dialect.name == "sqlite":
        _recreate_table_with_unique_constraint_sqlite(op, insp)
    elif op.get_context().dialect.name == "postgresql":
        with op.batch_alter_table("dataset_database", schema=None) as batch_op:
            # Re-add the unique constraint to return to unique=True
            batch_op.create_unique_constraint(graph_constraint_name, ["graph_database_name"])

        with op.batch_alter_table("dataset_database", schema=None) as batch_op:
            # Re-add the unique constraint to return to unique=True
            batch_op.create_unique_constraint(vector_constraint_name, ["vector_database_name"])

    op.drop_column("dataset_database", "vector_database_connection_info")
    op.drop_column("dataset_database", "graph_database_connection_info")
```
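Assuming the repository's standard Alembic setup, the migration is applied and reverted with the usual commands (a sketch, not part of the diff):

```bash
alembic upgrade head              # adds the JSON connection-info columns, drops the unique name constraints
alembic downgrade 76625596c5c3    # restores the unique constraints and drops the new columns
```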
### cognee/context_global_variables.py

```diff
@@ -4,9 +4,10 @@ from typing import Union
 from uuid import UUID
 
 from cognee.base_config import get_base_config
-from cognee.infrastructure.databases.vector.config import get_vectordb_context_config
-from cognee.infrastructure.databases.graph.config import get_graph_context_config
+from cognee.infrastructure.databases.vector.config import get_vectordb_config
+from cognee.infrastructure.databases.graph.config import get_graph_config
 from cognee.infrastructure.databases.utils import get_or_create_dataset_database
+from cognee.infrastructure.databases.utils import resolve_dataset_database_connection_info
 from cognee.infrastructure.files.storage.config import file_storage_config
 from cognee.modules.users.methods import get_user
@@ -16,22 +17,59 @@ vector_db_config = ContextVar("vector_db_config", default=None)
 graph_db_config = ContextVar("graph_db_config", default=None)
 session_user = ContextVar("session_user", default=None)
 
 VECTOR_DBS_WITH_MULTI_USER_SUPPORT = ["lancedb", "falkor"]
 GRAPH_DBS_WITH_MULTI_USER_SUPPORT = ["kuzu", "falkor"]
 
 
 async def set_session_user_context_variable(user):
     session_user.set(user)
 
 
 def multi_user_support_possible():
-    graph_db_config = get_graph_context_config()
-    vector_db_config = get_vectordb_context_config()
-    return (
-        graph_db_config["graph_database_provider"] in GRAPH_DBS_WITH_MULTI_USER_SUPPORT
-        and vector_db_config["vector_db_provider"] in VECTOR_DBS_WITH_MULTI_USER_SUPPORT
-    )
+    graph_db_config = get_graph_config()
+    vector_db_config = get_vectordb_config()
+
+    graph_handler = graph_db_config.graph_dataset_database_handler
+    vector_handler = vector_db_config.vector_dataset_database_handler
+    from cognee.infrastructure.databases.dataset_database_handler import (
+        supported_dataset_database_handlers,
+    )
+
+    if graph_handler not in supported_dataset_database_handlers:
+        raise EnvironmentError(
+            "Unsupported graph dataset-to-database handler configured. Cannot add support for multi-user access control mode. Please use a supported graph dataset-to-database handler or set the environment variable ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
+            f"Selected graph dataset-to-database handler: {graph_handler}\n"
+            f"Supported dataset-to-database handlers: {list(supported_dataset_database_handlers.keys())}\n"
+        )
+
+    if vector_handler not in supported_dataset_database_handlers:
+        raise EnvironmentError(
+            "Unsupported vector dataset-to-database handler configured. Cannot add support for multi-user access control mode. Please use a supported vector dataset-to-database handler or set the environment variable ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
+            f"Selected vector dataset-to-database handler: {vector_handler}\n"
+            f"Supported dataset-to-database handlers: {list(supported_dataset_database_handlers.keys())}\n"
+        )
+
+    if (
+        supported_dataset_database_handlers[graph_handler]["handler_provider"]
+        != graph_db_config.graph_database_provider
+    ):
+        raise EnvironmentError(
+            "The selected graph dataset-to-database handler does not work with the configured graph database provider. Cannot add support for multi-user access control mode. Please use a supported graph dataset-to-database handler or set the environment variable ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
+            f"Selected graph database provider: {graph_db_config.graph_database_provider}\n"
+            f"Selected graph dataset-to-database handler: {graph_handler}\n"
+            f"Supported dataset-to-database handlers: {list(supported_dataset_database_handlers.keys())}\n"
+        )
+
+    if (
+        supported_dataset_database_handlers[vector_handler]["handler_provider"]
+        != vector_db_config.vector_db_provider
+    ):
+        raise EnvironmentError(
+            "The selected vector dataset-to-database handler does not work with the configured vector database provider. Cannot add support for multi-user access control mode. Please use a supported vector dataset-to-database handler or set the environment variable ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n"
+            f"Selected vector database provider: {vector_db_config.vector_db_provider}\n"
+            f"Selected vector dataset-to-database handler: {vector_handler}\n"
+            f"Supported dataset-to-database handlers: {list(supported_dataset_database_handlers.keys())}\n"
+        )
+
+    return True
 
 
 def backend_access_control_enabled():
     backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)
@@ -41,12 +79,7 @@ def backend_access_control_enabled():
         return multi_user_support_possible()
     elif backend_access_control.lower() == "true":
-        # If enabled, ensure that the current graph and vector DBs can support it
-        multi_user_support = multi_user_support_possible()
-        if not multi_user_support:
-            raise EnvironmentError(
-                "ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
-            )
-        return True
+        return multi_user_support_possible()
     return False
@@ -76,6 +109,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
 
     # To ensure permissions are enforced properly all datasets will have their own databases
     dataset_database = await get_or_create_dataset_database(dataset, user)
+    # Ensure that all connection info is resolved properly
+    dataset_database = await resolve_dataset_database_connection_info(dataset_database)
 
     base_config = get_base_config()
     data_root_directory = os.path.join(
@@ -86,6 +121,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
     )
 
     # Set vector and graph database configuration based on dataset database information
+    # TODO: Add better handling of vector and graph config across Cognee.
+    # LRU_CACHE takes the order of inputs into account; if the order of inputs changes, it will be registered as a new DB adapter
     vector_config = {
         "vector_db_provider": dataset_database.vector_database_provider,
         "vector_db_url": dataset_database.vector_database_url,
@@ -101,6 +138,14 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
         "graph_file_path": os.path.join(
             databases_directory_path, dataset_database.graph_database_name
         ),
+        "graph_database_username": dataset_database.graph_database_connection_info.get(
+            "graph_database_username", ""
+        ),
+        "graph_database_password": dataset_database.graph_database_connection_info.get(
+            "graph_database_password", ""
+        ),
+        "graph_dataset_database_handler": "",
+        "graph_database_port": "",
     }
 
     storage_config = {
```
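To illustrate how the gate is driven from the environment, a hypothetical driver snippet (only `backend_access_control_enabled` and the environment variable names come from this PR; the variables must be set before Cognee loads its settings):

```python
import os

# Force access control on; with an unsupported handler/provider pairing this makes
# multi_user_support_possible() raise EnvironmentError instead of silently degrading.
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "true"
os.environ["GRAPH_DATASET_DATABASE_HANDLER"] = "kuzu"      # must pair with GRAPH_DATABASE_PROVIDER="kuzu"
os.environ["VECTOR_DATASET_DATABASE_HANDLER"] = "lancedb"  # must pair with VECTOR_DB_PROVIDER="lancedb"

from cognee.context_global_variables import backend_access_control_enabled

print(backend_access_control_enabled())  # True when the configured pairing is valid
```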
### New file: cognee/infrastructure/databases/dataset_database_handler/__init__.py (+3 lines)

```python
from .dataset_database_handler_interface import DatasetDatabaseHandlerInterface
from .supported_dataset_database_handlers import supported_dataset_database_handlers
from .use_dataset_database_handler import use_dataset_database_handler
```
### New file: cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py (+80 lines)

```python
from typing import Optional
from uuid import UUID
from abc import ABC, abstractmethod

from cognee.modules.users.models.User import User
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase


class DatasetDatabaseHandlerInterface(ABC):
    @classmethod
    @abstractmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        """
        Return a dictionary with database connection/resolution info for a graph or vector database for the given dataset.
        The function may also deploy the actual database if needed, but this is not required:
        providing connection info alone is sufficient, and this info will be mapped when connecting to the given dataset in the future.
        Needed for Cognee multi-tenant/multi-user and backend access control support.

        The dictionary returned from this function is used to create a DatasetDatabase row in the relational database,
        from which the internal mapping of dataset -> database connection info is done.

        The returned dictionary is stored verbatim in the relational database and is later passed to
        resolve_dataset_connection_info() at connection time. For safe credential handling, prefer
        returning only references to secrets or role identifiers, not plaintext credentials.

        Each dataset needs to map to a unique graph or vector database when backend access control is enabled to facilitate a separation of concern for data.

        Args:
            dataset_id: UUID of the dataset if needed by the database creation logic
            user: User object if needed by the database creation logic
        Returns:
            dict: Connection info for the created graph or vector database instance.
        """
        pass

    @classmethod
    async def resolve_dataset_connection_info(
        cls, dataset_database: DatasetDatabase
    ) -> DatasetDatabase:
        """
        Resolve runtime connection details for a dataset's backing graph/vector database.
        This method is intended to be overridden to implement custom logic for resolving connection info.

        It is invoked right before the application opens a connection for a given dataset.
        It receives the DatasetDatabase row that was persisted when create_dataset() ran and must
        return a modified DatasetDatabase instance with concrete connection parameters that the client/driver can use.
        Do not persist these resolved DatasetDatabase values back to the relational database, to avoid storing sensitive credentials.

        When separate graph and vector database handlers are configured, each handler should implement its own logic for
        resolving connection info and only change parameters related to its own database; the resolution functions are then
        called one after another, each receiving the updated DatasetDatabase value from the previous one as input.

        Typical behavior:
        - If the DatasetDatabase row already contains raw connection fields (e.g., host/port/db/user/password
          or api_url/api_key), return them as-is.
        - If the row stores only references (e.g., secret IDs, vault paths, cloud resource ARNs/IDs, IAM
          roles, SSO tokens), resolve those references by calling the appropriate secret manager or provider
          API to obtain short-lived credentials and assemble the final DatasetDatabase object.
        - Do not persist any resolved or decrypted secrets back to the relational database. Return them only
          to the caller.

        Args:
            dataset_database: DatasetDatabase row from the relational database
        Returns:
            DatasetDatabase: Updated instance with resolved connection info
        """
        return dataset_database

    @classmethod
    @abstractmethod
    async def delete_dataset(cls, dataset_database: DatasetDatabase) -> None:
        """
        Delete the graph or vector database for the given dataset.
        The function should handle deleting the actual database, or send a request to the proper service to
        delete the database or mark it as no longer needed for the given dataset.
        Needed for maintaining databases for Cognee multi-tenant/multi-user and backend access control.

        Args:
            dataset_database: DatasetDatabase row containing connection/resolution info for the graph or vector database to delete.
        """
        pass
```
### New file: cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py (+18 lines)

```python
from cognee.infrastructure.databases.graph.neo4j_driver.Neo4jAuraDevDatasetDatabaseHandler import (
    Neo4jAuraDevDatasetDatabaseHandler,
)
from cognee.infrastructure.databases.vector.lancedb.LanceDBDatasetDatabaseHandler import (
    LanceDBDatasetDatabaseHandler,
)
from cognee.infrastructure.databases.graph.kuzu.KuzuDatasetDatabaseHandler import (
    KuzuDatasetDatabaseHandler,
)

supported_dataset_database_handlers = {
    "neo4j_aura_dev": {
        "handler_instance": Neo4jAuraDevDatasetDatabaseHandler,
        "handler_provider": "neo4j",
    },
    "lancedb": {"handler_instance": LanceDBDatasetDatabaseHandler, "handler_provider": "lancedb"},
    "kuzu": {"handler_instance": KuzuDatasetDatabaseHandler, "handler_provider": "kuzu"},
}
```
### New file: cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py (+10 lines)

```python
from .supported_dataset_database_handlers import supported_dataset_database_handlers


def use_dataset_database_handler(
    dataset_database_handler_name, dataset_database_handler, dataset_database_provider
):
    supported_dataset_database_handlers[dataset_database_handler_name] = {
        "handler_instance": dataset_database_handler,
        "handler_provider": dataset_database_provider,
    }
```
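Together, the interface and registry let deployments plug in their own provisioning logic. A minimal sketch of a custom graph handler and its registration; the class, its return values, the endpoint, and the "my_graph_db" name are hypothetical, and only the imported names come from this PR:

```python
from typing import Optional
from uuid import UUID

from cognee.infrastructure.databases.dataset_database_handler import (
    DatasetDatabaseHandlerInterface,
    use_dataset_database_handler,
)


class MyNeo4jDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
    @classmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user) -> dict:
        # Hypothetical: name (or provision) one graph database per dataset and
        # store only a secret reference, not plaintext credentials.
        return {
            "graph_database_name": f"{dataset_id}",
            "graph_database_url": "bolt://graph.example.internal:7687",  # assumed endpoint
            "graph_database_provider": "neo4j",
            "graph_database_key": "",
            "graph_database_connection_info": {"secret_ref": f"vault/datasets/{dataset_id}"},
        }

    @classmethod
    async def resolve_dataset_connection_info(cls, dataset_database):
        # Hypothetical: swap the stored secret reference for live credentials at connect time.
        return dataset_database

    @classmethod
    async def delete_dataset(cls, dataset_database) -> None:
        # Tear down or deregister the per-dataset database here.
        pass


# Register under a name that GRAPH_DATASET_DATABASE_HANDLER can then select;
# the third argument must match the configured graph database provider.
use_dataset_database_handler("my_graph_db", MyNeo4jDatasetDatabaseHandler, "neo4j")
```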
### cognee/infrastructure/databases/graph/config.py

```diff
@@ -47,6 +47,7 @@ class GraphConfig(BaseSettings):
     graph_filename: str = ""
     graph_model: object = KnowledgeGraph
     graph_topology: object = KnowledgeGraph
+    graph_dataset_database_handler: str = "kuzu"
     model_config = SettingsConfigDict(env_file=".env", extra="allow", populate_by_name=True)
 
     # Model validator updates graph_filename and path dynamically after class creation based on current database provider
@@ -97,6 +98,7 @@ class GraphConfig(BaseSettings):
             "graph_model": self.graph_model,
             "graph_topology": self.graph_topology,
             "model_config": self.model_config,
+            "graph_dataset_database_handler": self.graph_dataset_database_handler,
         }
 
     def to_hashable_dict(self) -> dict:
@@ -121,6 +123,7 @@ class GraphConfig(BaseSettings):
             "graph_database_port": self.graph_database_port,
             "graph_database_key": self.graph_database_key,
             "graph_file_path": self.graph_file_path,
+            "graph_dataset_database_handler": self.graph_dataset_database_handler,
         }
```
### cognee/infrastructure/databases/graph/get_graph_engine.py

```diff
@@ -34,6 +34,7 @@ def create_graph_engine(
     graph_database_password="",
     graph_database_port="",
     graph_database_key="",
+    graph_dataset_database_handler="",
 ):
     """
     Create a graph engine based on the specified provider type.
```
### New file: cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py (+80 lines)

```python
import os
from uuid import UUID
from typing import Optional

from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
from cognee.base_config import get_base_config
from cognee.modules.users.models import User
from cognee.modules.users.models import DatasetDatabase
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface


class KuzuDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
    """
    Handler for interacting with Kuzu dataset databases.
    """

    @classmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        """
        Create a new Kuzu instance for the dataset. Return connection info that will be mapped to the dataset.

        Args:
            dataset_id: Dataset UUID
            user: User object who owns the dataset and is making the request

        Returns:
            dict: Connection details for the created Kuzu instance
        """
        from cognee.infrastructure.databases.graph.config import get_graph_config

        graph_config = get_graph_config()

        if graph_config.graph_database_provider != "kuzu":
            raise ValueError(
                "KuzuDatasetDatabaseHandler can only be used with the Kuzu graph database provider."
            )

        graph_db_name = f"{dataset_id}.pkl"
        graph_db_url = graph_config.graph_database_url
        graph_db_key = graph_config.graph_database_key
        graph_db_username = graph_config.graph_database_username
        graph_db_password = graph_config.graph_database_password

        return {
            "graph_database_name": graph_db_name,
            "graph_database_url": graph_db_url,
            "graph_database_provider": graph_config.graph_database_provider,
            "graph_database_key": graph_db_key,
            "graph_database_connection_info": {
                "graph_database_username": graph_db_username,
                "graph_database_password": graph_db_password,
            },
        }

    @classmethod
    async def delete_dataset(cls, dataset_database: DatasetDatabase):
        base_config = get_base_config()
        databases_directory_path = os.path.join(
            base_config.system_root_directory, "databases", str(dataset_database.owner_id)
        )
        graph_file_path = os.path.join(
            databases_directory_path, dataset_database.graph_database_name
        )
        graph_engine = create_graph_engine(
            graph_database_provider=dataset_database.graph_database_provider,
            graph_database_url=dataset_database.graph_database_url,
            graph_database_name=dataset_database.graph_database_name,
            graph_database_key=dataset_database.graph_database_key,
            graph_file_path=graph_file_path,
            graph_database_username=dataset_database.graph_database_connection_info.get(
                "graph_database_username", ""
            ),
            graph_database_password=dataset_database.graph_database_connection_info.get(
                "graph_database_password", ""
            ),
            graph_dataset_database_handler="",
            graph_database_port="",
        )
        await graph_engine.delete_graph()
```
### New file: cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py (+167 lines)

```python
import os
import asyncio
import requests
import base64
import hashlib
from uuid import UUID
from typing import Optional
from cryptography.fernet import Fernet

from cognee.infrastructure.databases.graph import get_graph_config
from cognee.modules.users.models import User, DatasetDatabase
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface


class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
    """
    Handler for a quick development PoC integration of Cognee multi-user and permission mode with Neo4j Aura databases.
    This handler creates a new Neo4j Aura instance for each Cognee dataset created.

    Improvements needed to be production ready:
    - Secret management for client credentials: currently secrets are encrypted and stored in the Cognee relational
      database; a secret manager or a similar system should be used instead.

    Quality of life improvements:
    - Allow configuration of different Neo4j Aura plans and regions.
    - Requests should be made async; currently the blocking requests library is used.
    """

    @classmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        """
        Create a new Neo4j Aura instance for the dataset. Return connection info that will be mapped to the dataset.

        Args:
            dataset_id: Dataset UUID
            user: User object who owns the dataset and is making the request

        Returns:
            dict: Connection details for the created Neo4j instance
        """
        graph_config = get_graph_config()

        if graph_config.graph_database_provider != "neo4j":
            raise ValueError(
                "Neo4jAuraDevDatasetDatabaseHandler can only be used with the Neo4j graph database provider."
            )

        graph_db_name = f"{dataset_id}"

        # Client credentials and encryption
        client_id = os.environ.get("NEO4J_CLIENT_ID", None)
        client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
        tenant_id = os.environ.get("NEO4J_TENANT_ID", None)
        encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key")
        encryption_key = base64.urlsafe_b64encode(
            hashlib.sha256(encryption_env_key.encode()).digest()
        )
        cipher = Fernet(encryption_key)

        if client_id is None or client_secret is None or tenant_id is None:
            raise ValueError(
                "NEO4J_CLIENT_ID, NEO4J_CLIENT_SECRET, and NEO4J_TENANT_ID environment variables must be set to use Neo4j Aura DatasetDatabase handling."
            )

        # Make the request with HTTP Basic Auth
        def get_aura_token(client_id: str, client_secret: str) -> dict:
            url = "https://api.neo4j.io/oauth/token"
            data = {"grant_type": "client_credentials"}  # sent as application/x-www-form-urlencoded

            resp = requests.post(url, data=data, auth=(client_id, client_secret))
            resp.raise_for_status()  # raises if the request failed
            return resp.json()

        resp = get_aura_token(client_id, client_secret)

        url = "https://api.neo4j.io/v1/instances"

        headers = {
            "accept": "application/json",
            "Authorization": f"Bearer {resp['access_token']}",
            "Content-Type": "application/json",
        }

        # TODO: Maybe we can allow **kwargs parameter forwarding for cases like these
        # to allow different configurations between datasets
        payload = {
            "version": "5",
            "region": "europe-west1",
            "memory": "1GB",
            "name": graph_db_name[
                0:29
            ],  # TODO: Find a better naming scheme for the Neo4j instance within the 30-character limit
            "type": "professional-db",
            "tenant_id": tenant_id,
            "cloud_provider": "gcp",
        }

        response = requests.post(url, headers=headers, json=payload)

        graph_db_name = "neo4j"  # Has to be 'neo4j' for Aura
        graph_db_url = response.json()["data"]["connection_url"]
        graph_db_key = resp["access_token"]
        graph_db_username = response.json()["data"]["username"]
        graph_db_password = response.json()["data"]["password"]

        async def _wait_for_neo4j_instance_provisioning(instance_id: str, headers: dict):
            # Poll until the instance is running
            status_url = f"https://api.neo4j.io/v1/instances/{instance_id}"
            status = ""
            for attempt in range(30):  # Try for up to ~5 minutes
                status_resp = requests.get(
                    status_url, headers=headers
                )  # TODO: Use async requests with httpx
                status = status_resp.json()["data"]["status"]
                if status.lower() == "running":
                    return
                await asyncio.sleep(10)
            raise TimeoutError(
                f"Neo4j instance '{graph_db_name}' did not become ready within 5 minutes. Status: {status}"
            )

        instance_id = response.json()["data"]["id"]
        await _wait_for_neo4j_instance_provisioning(instance_id, headers)

        encrypted_db_password_bytes = cipher.encrypt(graph_db_password.encode())
        encrypted_db_password_string = encrypted_db_password_bytes.decode()

        return {
            "graph_database_name": graph_db_name,
            "graph_database_url": graph_db_url,
            "graph_database_provider": "neo4j",
            "graph_database_key": graph_db_key,
            "graph_database_connection_info": {
                "graph_database_username": graph_db_username,
                "graph_database_password": encrypted_db_password_string,
            },
        }

    @classmethod
    async def resolve_dataset_connection_info(
        cls, dataset_database: DatasetDatabase
    ) -> DatasetDatabase:
        """
        Resolve and decrypt connection info for the Neo4j dataset database.
        In this case, decrypt the password stored in the database.

        Args:
            dataset_database: DatasetDatabase instance containing encrypted connection info.
        """
        encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key")
        encryption_key = base64.urlsafe_b64encode(
            hashlib.sha256(encryption_env_key.encode()).digest()
        )
        cipher = Fernet(encryption_key)
        graph_db_password = cipher.decrypt(
            dataset_database.graph_database_connection_info["graph_database_password"].encode()
        ).decode()

        dataset_database.graph_database_connection_info["graph_database_password"] = (
            graph_db_password
        )
        return dataset_database

    @classmethod
    async def delete_dataset(cls, dataset_database: DatasetDatabase):
        pass
```
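The credential handling above is a plain Fernet round-trip keyed off NEO4J_ENCRYPTION_KEY; a self-contained sketch of the same scheme (illustrative only, mirroring the handler's key derivation):

```python
import base64
import hashlib

from cryptography.fernet import Fernet

# Derive a urlsafe 32-byte Fernet key from an arbitrary passphrase, as the handler does
passphrase = "test_key"  # stands in for NEO4J_ENCRYPTION_KEY
key = base64.urlsafe_b64encode(hashlib.sha256(passphrase.encode()).digest())
cipher = Fernet(key)

token = cipher.encrypt(b"instance-password")  # what gets stored in graph_database_connection_info
assert cipher.decrypt(token) == b"instance-password"  # what resolve_dataset_connection_info undoes
```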
### cognee/infrastructure/databases/utils/__init__.py

```diff
@@ -1 +1,2 @@
 from .get_or_create_dataset_database import get_or_create_dataset_database
+from .resolve_dataset_database_connection_info import resolve_dataset_database_connection_info
```
### cognee/infrastructure/databases/utils/get_or_create_dataset_database.py

```diff
@@ -1,11 +1,9 @@
-import os
 from uuid import UUID
-from typing import Union
+from typing import Union, Optional
 
 from sqlalchemy import select
 from sqlalchemy.exc import IntegrityError
 
-from cognee.base_config import get_base_config
 from cognee.modules.data.methods import create_dataset
 from cognee.infrastructure.databases.relational import get_relational_engine
 from cognee.infrastructure.databases.vector import get_vectordb_config
@@ -15,6 +13,53 @@ from cognee.modules.users.models import DatasetDatabase
 from cognee.modules.users.models import User
 
 
+async def _get_vector_db_info(dataset_id: UUID, user: User) -> dict:
+    vector_config = get_vectordb_config()
+
+    from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
+        supported_dataset_database_handlers,
+    )
+
+    handler = supported_dataset_database_handlers[vector_config.vector_dataset_database_handler]
+    return await handler["handler_instance"].create_dataset(dataset_id, user)
+
+
+async def _get_graph_db_info(dataset_id: UUID, user: User) -> dict:
+    graph_config = get_graph_config()
+
+    from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
+        supported_dataset_database_handlers,
+    )
+
+    handler = supported_dataset_database_handlers[graph_config.graph_dataset_database_handler]
+    return await handler["handler_instance"].create_dataset(dataset_id, user)
+
+
+async def _existing_dataset_database(
+    dataset_id: UUID,
+    user: User,
+) -> Optional[DatasetDatabase]:
+    """
+    Check if a DatasetDatabase row already exists for the given owner + dataset.
+    Return the row if it exists, otherwise return None.
+
+    Args:
+        dataset_id: UUID of the dataset
+        user: Owner of the dataset
+
+    Returns:
+        DatasetDatabase or None
+    """
+    db_engine = get_relational_engine()
+
+    async with db_engine.get_async_session() as session:
+        stmt = select(DatasetDatabase).where(
+            DatasetDatabase.owner_id == user.id,
+            DatasetDatabase.dataset_id == dataset_id,
+        )
+        existing: DatasetDatabase = await session.scalar(stmt)
+        return existing
+
+
 async def get_or_create_dataset_database(
     dataset: Union[str, UUID],
     user: User,
@@ -25,6 +70,8 @@ async def get_or_create_dataset_database(
     • If the row already exists, it is fetched and returned.
     • Otherwise a new one is created atomically and returned.
 
+    The DatasetDatabase row contains connection and provider info for the vector and graph databases.
+
     Parameters
     ----------
     user : User
@@ -36,59 +83,26 @@ async def get_or_create_dataset_database(
 
     dataset_id = await get_unique_dataset_id(dataset, user)
 
-    vector_config = get_vectordb_config()
-    graph_config = get_graph_config()
+    # If the dataset is given as a name, make sure the dataset is created first
+    if isinstance(dataset, str):
+        async with db_engine.get_async_session() as session:
+            await create_dataset(dataset, user, session)
 
-    # Note: for hybrid databases both graph and vector DB name have to be the same
-    if graph_config.graph_database_provider == "kuzu":
-        graph_db_name = f"{dataset_id}.pkl"
-    else:
-        graph_db_name = f"{dataset_id}"
+    # If the dataset database already exists, return it
+    existing_dataset_database = await _existing_dataset_database(dataset_id, user)
+    if existing_dataset_database:
+        return existing_dataset_database
 
-    if vector_config.vector_db_provider == "lancedb":
-        vector_db_name = f"{dataset_id}.lance.db"
-    else:
-        vector_db_name = f"{dataset_id}"
-
-    base_config = get_base_config()
-    databases_directory_path = os.path.join(
-        base_config.system_root_directory, "databases", str(user.id)
-    )
-
-    # Determine vector database URL
-    if vector_config.vector_db_provider == "lancedb":
-        vector_db_url = os.path.join(databases_directory_path, vector_config.vector_db_name)
-    else:
-        vector_db_url = vector_config.vector_database_url
-
-    # Determine graph database URL
+    graph_config_dict = await _get_graph_db_info(dataset_id, user)
+    vector_config_dict = await _get_vector_db_info(dataset_id, user)
 
     async with db_engine.get_async_session() as session:
-        # Create dataset if it doesn't exist
-        if isinstance(dataset, str):
-            dataset = await create_dataset(dataset, user, session)
-
-        # Try to fetch an existing row first
-        stmt = select(DatasetDatabase).where(
-            DatasetDatabase.owner_id == user.id,
-            DatasetDatabase.dataset_id == dataset_id,
-        )
-        existing: DatasetDatabase = await session.scalar(stmt)
-        if existing:
-            return existing
-
         # If there are no existing rows build a new row
         record = DatasetDatabase(
            owner_id=user.id,
            dataset_id=dataset_id,
-           vector_database_name=vector_db_name,
-           graph_database_name=graph_db_name,
-           vector_database_provider=vector_config.vector_db_provider,
-           graph_database_provider=graph_config.graph_database_provider,
-           vector_database_url=vector_db_url,
-           graph_database_url=graph_config.graph_database_url,
-           vector_database_key=vector_config.vector_db_key,
-           graph_database_key=graph_config.graph_database_key,
+           **graph_config_dict,  # Unpack graph db config
+           **vector_config_dict,  # Unpack vector db config
        )
 
        try:
```
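As a usage sketch (hypothetical call site; the function, the model, and the naming conventions come from this PR):

```python
# Fetch, or lazily provision via the configured handlers, the per-dataset databases
dataset_database = await get_or_create_dataset_database("my_dataset", user)

dataset_database.graph_database_name   # e.g. "<dataset_id>.pkl" from the Kuzu handler
dataset_database.vector_database_name  # e.g. "<dataset_id>.lance.db" from the LanceDB handler
```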
### New file: cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py (+42 lines)

```python
from cognee.infrastructure.databases.vector import get_vectordb_config
from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase


async def _get_vector_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase:
    vector_config = get_vectordb_config()

    from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
        supported_dataset_database_handlers,
    )

    handler = supported_dataset_database_handlers[vector_config.vector_dataset_database_handler]
    return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database)


async def _get_graph_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase:
    graph_config = get_graph_config()

    from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
        supported_dataset_database_handlers,
    )

    handler = supported_dataset_database_handlers[graph_config.graph_dataset_database_handler]
    return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database)


async def resolve_dataset_database_connection_info(
    dataset_database: DatasetDatabase,
) -> DatasetDatabase:
    """
    Resolve the connection info for the given DatasetDatabase instance.
    Resolves both vector and graph database connection info and returns the updated DatasetDatabase instance.

    Args:
        dataset_database: DatasetDatabase instance
    Returns:
        DatasetDatabase instance with resolved connection info
    """
    dataset_database = await _get_vector_db_connection_info(dataset_database)
    dataset_database = await _get_graph_db_connection_info(dataset_database)
    return dataset_database
```
### cognee/infrastructure/databases/vector/config.py

```diff
@@ -28,6 +28,7 @@ class VectorConfig(BaseSettings):
     vector_db_name: str = ""
     vector_db_key: str = ""
     vector_db_provider: str = "lancedb"
+    vector_dataset_database_handler: str = "lancedb"
 
     model_config = SettingsConfigDict(env_file=".env", extra="allow")
 
@@ -63,6 +64,7 @@ class VectorConfig(BaseSettings):
             "vector_db_name": self.vector_db_name,
             "vector_db_key": self.vector_db_key,
             "vector_db_provider": self.vector_db_provider,
+            "vector_dataset_database_handler": self.vector_dataset_database_handler,
         }
```
### cognee/infrastructure/databases/vector/create_vector_engine.py

```diff
@@ -12,6 +12,7 @@ def create_vector_engine(
     vector_db_name: str,
     vector_db_port: str = "",
     vector_db_key: str = "",
+    vector_dataset_database_handler: str = "",
 ):
     """
     Create a vector database engine based on the specified provider.
```
### New file: cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py (+49 lines)

```python
import os
from uuid import UUID
from typing import Optional

from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
from cognee.modules.users.models import User
from cognee.modules.users.models import DatasetDatabase
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.vector import get_vectordb_config
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface


class LanceDBDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
    """
    Handler for interacting with LanceDB dataset databases.
    """

    @classmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        vector_config = get_vectordb_config()
        base_config = get_base_config()

        if vector_config.vector_db_provider != "lancedb":
            raise ValueError(
                "LanceDBDatasetDatabaseHandler can only be used with the LanceDB vector database provider."
            )

        databases_directory_path = os.path.join(
            base_config.system_root_directory, "databases", str(user.id)
        )

        vector_db_name = f"{dataset_id}.lance.db"

        return {
            "vector_database_provider": vector_config.vector_db_provider,
            "vector_database_url": os.path.join(databases_directory_path, vector_db_name),
            "vector_database_key": vector_config.vector_db_key,
            "vector_database_name": vector_db_name,
        }

    @classmethod
    async def delete_dataset(cls, dataset_database: DatasetDatabase):
        vector_engine = create_vector_engine(
            vector_db_provider=dataset_database.vector_database_provider,
            vector_db_url=dataset_database.vector_database_url,
            vector_db_key=dataset_database.vector_database_key,
            vector_db_name=dataset_database.vector_database_name,
        )
        await vector_engine.prune()
```
@ -2,6 +2,8 @@ from typing import List, Protocol, Optional, Union, Any
|
|||
from abc import abstractmethod
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from .models.PayloadSchema import PayloadSchema
|
||||
from uuid import UUID
|
||||
from cognee.modules.users.models import User
|
||||
|
||||
|
||||
class VectorDBInterface(Protocol):
|
||||
|
|
@@ -217,3 +219,36 @@ class VectorDBInterface(Protocol):
        - Any: The schema object suitable for this vector database
        """
        return model_type

    @classmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        """
        Return a dictionary with connection info for a vector database for the given dataset.
        The function may also handle deploying the actual database if needed, but that is not
        required; providing connection info is sufficient. This info will be used whenever a
        connection to the given dataset is made in the future.
        Needed for Cognee multi-tenant/multi-user and backend access control support.

        The dictionary returned from this function will be used to create a DatasetDatabase row
        in the relational database, from which the internal mapping of dataset -> database
        connection info is resolved.

        Each dataset must map to a unique vector database when backend access control is
        enabled, to enforce separation of data between datasets.

        Args:
            dataset_id: UUID of the dataset, if needed by the database creation logic
            user: User object, if needed by the database creation logic
        Returns:
            dict: Connection info for the created vector database instance.
        """
        pass

    async def delete_dataset(self, dataset_id: UUID, user: User) -> None:
        """
        Delete the vector database for the given dataset.
        The function should delete the actual database itself, or send a request to the
        appropriate service to delete it.
        Needed to maintain per-dataset databases for Cognee multi-tenant/multi-user and
        backend access control support.

        Args:
            dataset_id: UUID of the dataset
            user: User object
        """
        pass

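To illustrate the contract above, a minimal adapter method satisfying it might look like the sketch below; the provider name and URL scheme are invented, and only the returned keys mirror the handlers in this PR:

```python
from typing import Optional
from uuid import UUID


class SketchVectorAdapter:
    @classmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user) -> dict:
        # Nothing is deployed here; returning connection info alone is sufficient,
        # and this dict is what becomes a DatasetDatabase row.
        name = f"{dataset_id}.sketch"
        return {
            "vector_database_provider": "sketchdb",  # hypothetical provider
            "vector_database_name": name,
            "vector_database_url": f"sketch://{user.id}/{name}",  # hypothetical scheme
            "vector_database_key": None,
        }
```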
@@ -1,17 +1,82 @@
from sqlalchemy.exc import OperationalError

from cognee.infrastructure.databases.exceptions import EntityNotFoundError
from cognee.context_global_variables import backend_access_control_enabled
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.databases.vector.config import get_vectordb_config
from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.shared.cache import delete_cache
from cognee.modules.users.models import DatasetDatabase
from cognee.shared.logging_utils import get_logger

logger = get_logger()


async def prune_graph_databases():
    async def _prune_graph_db(dataset_database: DatasetDatabase) -> dict:
        graph_config = get_graph_config()
        from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
            supported_dataset_database_handlers,
        )

        handler = supported_dataset_database_handlers[graph_config.graph_dataset_database_handler]
        return await handler["handler_instance"].delete_dataset(dataset_database)

    db_engine = get_relational_engine()
    try:
        data = await db_engine.get_all_data_from_table("dataset_database")
        # Go through each dataset database and delete the graph database
        for data_item in data:
            await _prune_graph_db(data_item)
    except (OperationalError, EntityNotFoundError) as e:
        logger.debug(
            "Skipping pruning of graph DB. Error when accessing dataset_database table: %s",
            e,
        )
        return


async def prune_vector_databases():
    async def _prune_vector_db(dataset_database: DatasetDatabase) -> dict:
        vector_config = get_vectordb_config()

        from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
            supported_dataset_database_handlers,
        )

        handler = supported_dataset_database_handlers[vector_config.vector_dataset_database_handler]
        return await handler["handler_instance"].delete_dataset(dataset_database)

    db_engine = get_relational_engine()
    try:
        data = await db_engine.get_all_data_from_table("dataset_database")
        # Go through each dataset database and delete the vector database
        for data_item in data:
            await _prune_vector_db(data_item)
    except (OperationalError, EntityNotFoundError) as e:
        logger.debug(
            "Skipping pruning of vector DB. Error when accessing dataset_database table: %s",
            e,
        )
        return


async def prune_system(graph=True, vector=True, metadata=True, cache=True):
    # Note: prune_system should not be available through the API; it has no permission checks
    # and will delete all graph and vector databases if called. It should only be used in
    # development or testing environments.
-   if graph:
+   if graph and not backend_access_control_enabled():
        graph_engine = await get_graph_engine()
        await graph_engine.delete_graph()
    elif graph and backend_access_control_enabled():
        await prune_graph_databases()

-   if vector:
+   if vector and not backend_access_control_enabled():
        vector_engine = get_vector_engine()
        await vector_engine.prune()
    elif vector and backend_access_control_enabled():
        await prune_vector_databases()

    if metadata:
        db_engine = get_relational_engine()

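With these helpers in place, a development-only reset routes through the per-dataset handlers whenever backend access control is enabled. A sketch of the intended call pattern, mirroring the test at the bottom of this diff:

```python
import asyncio

import cognee


async def dev_reset():
    # Development/testing only: prune_system has no permission checks and
    # deletes all graph and vector databases.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(graph=True, vector=True, metadata=True)


asyncio.run(dev_reset())
```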
@@ -12,8 +12,8 @@ logger = get_logger("get_authenticated_user")

# Check environment variable to determine authentication requirement
REQUIRE_AUTHENTICATION = (
-    os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true"
-    or backend_access_control_enabled()
+    os.getenv("REQUIRE_AUTHENTICATION", "true").lower() == "true"
+    or os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", "true").lower() == "true"
)

fastapi_users = get_fastapi_users()

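Sketched in isolation, the net effect of the new defaults (illustrative, not code from this diff):

```python
import os

# With neither variable set, authentication is now required by default,
# because both fallbacks are "true".
require_auth = (
    os.getenv("REQUIRE_AUTHENTICATION", "true").lower() == "true"
    or os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", "true").lower() == "true"
)
# Opting out now requires explicitly setting both REQUIRE_AUTHENTICATION=false
# and ENABLE_BACKEND_ACCESS_CONTROL=false in the environment.
```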
@@ -1,6 +1,6 @@
from datetime import datetime, timezone

-from sqlalchemy import Column, DateTime, String, UUID, ForeignKey
+from sqlalchemy import Column, DateTime, String, UUID, ForeignKey, JSON, text
from cognee.infrastructure.databases.relational import Base

@@ -12,8 +12,8 @@ class DatasetDatabase(Base):
        UUID, ForeignKey("datasets.id", ondelete="CASCADE"), primary_key=True, index=True
    )

-    vector_database_name = Column(String, unique=True, nullable=False)
-    graph_database_name = Column(String, unique=True, nullable=False)
+    vector_database_name = Column(String, unique=False, nullable=False)
+    graph_database_name = Column(String, unique=False, nullable=False)

    vector_database_provider = Column(String, unique=False, nullable=False)
    graph_database_provider = Column(String, unique=False, nullable=False)

@@ -24,5 +24,14 @@ class DatasetDatabase(Base):
    vector_database_key = Column(String, unique=False, nullable=True)
    graph_database_key = Column(String, unique=False, nullable=True)

    # Configuration details for different database types. This keeps the model flexible
    # enough to add new database types without changing the database schema.
    graph_database_connection_info = Column(
        JSON, unique=False, nullable=False, server_default=text("'{}'")
    )
    vector_database_connection_info = Column(
        JSON, unique=False, nullable=False, server_default=text("'{}'")
    )

    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))

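For illustration, a sketch of how the handler dictionaries might be combined into a row; the `dataset_id` keyword and the `*_database_url` columns are assumed from context rather than shown in this hunk:

```python
# vector_info / graph_info are the dicts returned by the vector and graph handlers.
dataset_database = DatasetDatabase(
    dataset_id=dataset_id,  # assumed primary-key column name (FK to datasets.id)
    vector_database_provider=vector_info["vector_database_provider"],
    vector_database_name=vector_info["vector_database_name"],
    vector_database_url=vector_info["vector_database_url"],  # assumed column
    vector_database_key=vector_info.get("vector_database_key"),
    vector_database_connection_info=vector_info.get("vector_database_connection_info", {}),
    graph_database_provider=graph_info["graph_database_provider"],
    graph_database_name=graph_info["graph_database_name"],
    graph_database_url=graph_info["graph_database_url"],  # assumed column
    graph_database_key=graph_info.get("graph_database_key"),
    graph_database_connection_info=graph_info.get("graph_database_connection_info", {}),
)
```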
@@ -534,6 +534,10 @@ def setup_logging(log_level=None, name=None):
    # Get a configured logger and log system information
    logger = structlog.get_logger(name if name else __name__)

    logger.warning(
        "From version 0.5.0 onwards, Cognee will run with multi-user access control mode set to on "
        "by default. Data isolation between different users and datasets will be enforced, and data "
        "created before multi-user access control mode was turned on won't be accessible by default. "
        "To disable multi-user access control mode and regain access to old data, set the environment "
        "variable ENABLE_BACKEND_ACCESS_CONTROL to false before starting Cognee. For more information, "
        "please refer to the Cognee documentation."
    )

    if logs_dir is not None:
        logger.info(f"Log file created at: {log_file_path}", log_file=log_file_path)
cognee/tests/test_dataset_database_handler.py (new file, 135 lines)

@@ -0,0 +1,135 @@
import asyncio
import os

# Set custom dataset database handler environment variables
os.environ["VECTOR_DATASET_DATABASE_HANDLER"] = "custom_lancedb_handler"
os.environ["GRAPH_DATASET_DATABASE_HANDLER"] = "custom_kuzu_handler"

import cognee
from cognee.modules.users.methods import get_default_user
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
from cognee.shared.logging_utils import setup_logging, ERROR
from cognee.api.v1.search import SearchType


class LanceDBTestDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
    @classmethod
    async def create_dataset(cls, dataset_id, user):
        import pathlib

        cognee_directory_path = str(
            pathlib.Path(
                os.path.join(
                    pathlib.Path(__file__).parent, ".cognee_system/test_dataset_database_handler"
                )
            ).resolve()
        )
        databases_directory_path = os.path.join(cognee_directory_path, "databases", str(user.id))
        os.makedirs(databases_directory_path, exist_ok=True)

        vector_db_name = "test.lance.db"

        return {
            "vector_database_name": vector_db_name,
            "vector_database_url": os.path.join(databases_directory_path, vector_db_name),
            "vector_database_provider": "lancedb",
        }


class KuzuTestDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
    @classmethod
    async def create_dataset(cls, dataset_id, user):
        databases_directory_path = os.path.join("databases", str(user.id))
        os.makedirs(databases_directory_path, exist_ok=True)

        graph_db_name = "test.kuzu"
        return {
            "graph_database_name": graph_db_name,
            "graph_database_url": os.path.join(databases_directory_path, graph_db_name),
            "graph_database_provider": "kuzu",
        }


async def main():
    import pathlib

    data_directory_path = str(
        pathlib.Path(
            os.path.join(
                pathlib.Path(__file__).parent, ".data_storage/test_dataset_database_handler"
            )
        ).resolve()
    )
    cognee.config.data_root_directory(data_directory_path)
    cognee_directory_path = str(
        pathlib.Path(
            os.path.join(
                pathlib.Path(__file__).parent, ".cognee_system/test_dataset_database_handler"
            )
        ).resolve()
    )
    cognee.config.system_root_directory(cognee_directory_path)

    # Add custom dataset database handlers
    from cognee.infrastructure.databases.dataset_database_handler.use_dataset_database_handler import (
        use_dataset_database_handler,
    )

    use_dataset_database_handler(
        "custom_lancedb_handler", LanceDBTestDatasetDatabaseHandler, "lancedb"
    )
    use_dataset_database_handler("custom_kuzu_handler", KuzuTestDatasetDatabaseHandler, "kuzu")

    # Create a clean slate for cognee -- reset data and system state
    print("Resetting cognee data...")
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    print("Data reset complete.\n")

    # The cognee knowledge graph will be created based on this text
    text = """
    Natural language processing (NLP) is an interdisciplinary
    subfield of computer science and information retrieval.
    """

    print("Adding text to cognee:")
    print(text.strip())

    # Add the text and make it available for cognify
    await cognee.add(text)
    print("Text added successfully.\n")

    # Use LLMs and cognee to create the knowledge graph
    await cognee.cognify()
    print("Cognify process complete.\n")

    query_text = "Tell me about NLP"
    print(f"Searching cognee for insights with query: '{query_text}'")
    # Query cognee for insights on the added text
    search_results = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION, query_text=query_text
    )

    print("Search results:")
    # Display results
    for result_text in search_results:
        print(result_text)

    default_user = await get_default_user()
    # Assert that the custom database files were created by the custom dataset database handlers
    assert os.path.exists(
        os.path.join(cognee_directory_path, "databases", str(default_user.id), "test.kuzu")
    ), "Graph database file not found."
    assert os.path.exists(
        os.path.join(cognee_directory_path, "databases", str(default_user.id), "test.lance.db")
    ), "Vector database file not found."


if __name__ == "__main__":
    logger = setup_logging(log_level=ERROR)
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(main())
    finally:
        loop.run_until_complete(loop.shutdown_asyncgens())
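Worth noting from this test: the `*_DATASET_DATABASE_HANDLER` variables are set before `cognee` is imported so the config picks them up, and `use_dataset_database_handler(...)` registers the custom classes before any `prune`/`add` call needs to resolve a handler. Presumably custom handlers in user code need the same ordering.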