feat: enable multi user for falkor (#1689)

<!-- .github/pull_request_template.md -->

## Description
<!--
Please provide a clear, human-generated description of the changes in
this PR.
DO NOT use AI-generated descriptions. We want to understand your thought
process and reasoning.
-->
Added multi-user support for Falkor. Adding support for the rest of the
graph dbs should be a bit easier after this first one, especially since
Falkor is hybrid.
There are a few things code quality wise that might need changing, I am
open to suggestions.

## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [ ] My code follows the project's coding standards and style
guidelines
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.

---------

Co-authored-by: Andrej Milicevic <milicevi@Andrejs-MacBook-Pro.local>
Co-authored-by: Igor Ilic <30923996+dexters1@users.noreply.github.com>
Co-authored-by: Igor Ilic <igorilic03@gmail.com>
This commit is contained in:
Andrej Milićević 2025-11-11 17:03:48 +01:00 committed by GitHub
parent 78b825f338
commit 8e8aecb76f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 172 additions and 15 deletions

View file

@ -0,0 +1,98 @@
"""Expand dataset database for multi user
Revision ID: 76625596c5c3
Revises: 211ab850ef3d
Create Date: 2025-10-30 12:55:20.239562
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "76625596c5c3"
down_revision: Union[str, None] = "c946955da633"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def _get_column(inspector, table, name, schema=None):
for col in inspector.get_columns(table, schema=schema):
if col["name"] == name:
return col
return None
def upgrade() -> None:
conn = op.get_bind()
insp = sa.inspect(conn)
vector_database_provider_column = _get_column(
insp, "dataset_database", "vector_database_provider"
)
if not vector_database_provider_column:
op.add_column(
"dataset_database",
sa.Column(
"vector_database_provider",
sa.String(),
unique=False,
nullable=False,
server_default="lancedb",
),
)
graph_database_provider_column = _get_column(
insp, "dataset_database", "graph_database_provider"
)
if not graph_database_provider_column:
op.add_column(
"dataset_database",
sa.Column(
"graph_database_provider",
sa.String(),
unique=False,
nullable=False,
server_default="kuzu",
),
)
vector_database_url_column = _get_column(insp, "dataset_database", "vector_database_url")
if not vector_database_url_column:
op.add_column(
"dataset_database",
sa.Column("vector_database_url", sa.String(), unique=False, nullable=True),
)
graph_database_url_column = _get_column(insp, "dataset_database", "graph_database_url")
if not graph_database_url_column:
op.add_column(
"dataset_database",
sa.Column("graph_database_url", sa.String(), unique=False, nullable=True),
)
vector_database_key_column = _get_column(insp, "dataset_database", "vector_database_key")
if not vector_database_key_column:
op.add_column(
"dataset_database",
sa.Column("vector_database_key", sa.String(), unique=False, nullable=True),
)
graph_database_key_column = _get_column(insp, "dataset_database", "graph_database_key")
if not graph_database_key_column:
op.add_column(
"dataset_database",
sa.Column("graph_database_key", sa.String(), unique=False, nullable=True),
)
def downgrade() -> None:
op.drop_column("dataset_database", "vector_database_provider")
op.drop_column("dataset_database", "graph_database_provider")
op.drop_column("dataset_database", "vector_database_url")
op.drop_column("dataset_database", "graph_database_url")
op.drop_column("dataset_database", "vector_database_key")
op.drop_column("dataset_database", "graph_database_key")

View file

@ -16,8 +16,8 @@ vector_db_config = ContextVar("vector_db_config", default=None)
graph_db_config = ContextVar("graph_db_config", default=None)
session_user = ContextVar("session_user", default=None)
vector_dbs_with_multi_user_support = ["lancedb"]
graph_dbs_with_multi_user_support = ["kuzu"]
VECTOR_DBS_WITH_MULTI_USER_SUPPORT = ["lancedb", "falkor"]
GRAPH_DBS_WITH_MULTI_USER_SUPPORT = ["kuzu", "falkor"]
async def set_session_user_context_variable(user):
@ -28,8 +28,8 @@ def multi_user_support_possible():
graph_db_config = get_graph_context_config()
vector_db_config = get_vectordb_context_config()
return (
graph_db_config["graph_database_provider"] in graph_dbs_with_multi_user_support
and vector_db_config["vector_db_provider"] in vector_dbs_with_multi_user_support
graph_db_config["graph_database_provider"] in GRAPH_DBS_WITH_MULTI_USER_SUPPORT
and vector_db_config["vector_db_provider"] in VECTOR_DBS_WITH_MULTI_USER_SUPPORT
)
@ -69,8 +69,6 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
"""
base_config = get_base_config()
if not backend_access_control_enabled():
return
@ -79,6 +77,7 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
# To ensure permissions are enforced properly all datasets will have their own databases
dataset_database = await get_or_create_dataset_database(dataset, user)
base_config = get_base_config()
data_root_directory = os.path.join(
base_config.data_root_directory, str(user.tenant_id or user.id)
)
@ -88,15 +87,17 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
# Set vector and graph database configuration based on dataset database information
vector_config = {
"vector_db_url": os.path.join(
databases_directory_path, dataset_database.vector_database_name
),
"vector_db_key": "",
"vector_db_provider": "lancedb",
"vector_db_provider": dataset_database.vector_database_provider,
"vector_db_url": dataset_database.vector_database_url,
"vector_db_key": dataset_database.vector_database_key,
"vector_db_name": dataset_database.vector_database_name,
}
graph_config = {
"graph_database_provider": "kuzu",
"graph_database_provider": dataset_database.graph_database_provider,
"graph_database_url": dataset_database.graph_database_url,
"graph_database_name": dataset_database.graph_database_name,
"graph_database_key": dataset_database.graph_database_key,
"graph_file_path": os.path.join(
databases_directory_path, dataset_database.graph_database_name
),

View file

@ -26,6 +26,7 @@ class GraphConfig(BaseSettings):
- graph_database_username
- graph_database_password
- graph_database_port
- graph_database_key
- graph_file_path
- graph_model
- graph_topology
@ -41,6 +42,7 @@ class GraphConfig(BaseSettings):
graph_database_username: str = ""
graph_database_password: str = ""
graph_database_port: int = 123
graph_database_key: str = ""
graph_file_path: str = ""
graph_filename: str = ""
graph_model: object = KnowledgeGraph
@ -90,6 +92,7 @@ class GraphConfig(BaseSettings):
"graph_database_username": self.graph_database_username,
"graph_database_password": self.graph_database_password,
"graph_database_port": self.graph_database_port,
"graph_database_key": self.graph_database_key,
"graph_file_path": self.graph_file_path,
"graph_model": self.graph_model,
"graph_topology": self.graph_topology,
@ -116,6 +119,7 @@ class GraphConfig(BaseSettings):
"graph_database_username": self.graph_database_username,
"graph_database_password": self.graph_database_password,
"graph_database_port": self.graph_database_port,
"graph_database_key": self.graph_database_key,
"graph_file_path": self.graph_file_path,
}

View file

@ -33,6 +33,7 @@ def create_graph_engine(
graph_database_username="",
graph_database_password="",
graph_database_port="",
graph_database_key="",
):
"""
Create a graph engine based on the specified provider type.
@ -69,6 +70,7 @@ def create_graph_engine(
graph_database_url=graph_database_url,
graph_database_username=graph_database_username,
graph_database_password=graph_database_password,
database_name=graph_database_name,
)
if graph_database_provider == "neo4j":

View file

@ -1,11 +1,15 @@
import os
from uuid import UUID
from typing import Union
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from cognee.modules.data.methods import create_dataset
from cognee.base_config import get_base_config
from cognee.modules.data.methods import create_dataset
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.databases.vector import get_vectordb_config
from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.modules.data.methods import get_unique_dataset_id
from cognee.modules.users.models import DatasetDatabase
from cognee.modules.users.models import User
@ -32,8 +36,32 @@ async def get_or_create_dataset_database(
dataset_id = await get_unique_dataset_id(dataset, user)
vector_db_name = f"{dataset_id}.lance.db"
graph_db_name = f"{dataset_id}.pkl"
vector_config = get_vectordb_config()
graph_config = get_graph_config()
# Note: for hybrid databases both graph and vector DB name have to be the same
if graph_config.graph_database_provider == "kuzu":
graph_db_name = f"{dataset_id}.pkl"
else:
graph_db_name = f"{dataset_id}"
if vector_config.vector_db_provider == "lancedb":
vector_db_name = f"{dataset_id}.lance.db"
else:
vector_db_name = f"{dataset_id}"
base_config = get_base_config()
databases_directory_path = os.path.join(
base_config.system_root_directory, "databases", str(user.id)
)
# Determine vector database URL
if vector_config.vector_db_provider == "lancedb":
vector_db_url = os.path.join(databases_directory_path, vector_config.vector_db_name)
else:
vector_db_url = vector_config.vector_database_url
# Determine graph database URL
async with db_engine.get_async_session() as session:
# Create dataset if it doesn't exist
@ -55,6 +83,12 @@ async def get_or_create_dataset_database(
dataset_id=dataset_id,
vector_database_name=vector_db_name,
graph_database_name=graph_db_name,
vector_database_provider=vector_config.vector_db_provider,
graph_database_provider=graph_config.graph_database_provider,
vector_database_url=vector_db_url,
graph_database_url=graph_config.graph_database_url,
vector_database_key=vector_config.vector_db_key,
graph_database_key=graph_config.graph_database_key,
)
try:

View file

@ -18,12 +18,14 @@ class VectorConfig(BaseSettings):
Instance variables:
- vector_db_url: The URL of the vector database.
- vector_db_port: The port for the vector database.
- vector_db_name: The name of the vector database.
- vector_db_key: The key for accessing the vector database.
- vector_db_provider: The provider for the vector database.
"""
vector_db_url: str = ""
vector_db_port: int = 1234
vector_db_name: str = ""
vector_db_key: str = ""
vector_db_provider: str = "lancedb"
@ -58,6 +60,7 @@ class VectorConfig(BaseSettings):
return {
"vector_db_url": self.vector_db_url,
"vector_db_port": self.vector_db_port,
"vector_db_name": self.vector_db_name,
"vector_db_key": self.vector_db_key,
"vector_db_provider": self.vector_db_provider,
}

View file

@ -1,5 +1,6 @@
from .supported_databases import supported_databases
from .embeddings import get_embedding_engine
from cognee.infrastructure.databases.graph.config import get_graph_context_config
from functools import lru_cache
@ -8,6 +9,7 @@ from functools import lru_cache
def create_vector_engine(
vector_db_provider: str,
vector_db_url: str,
vector_db_name: str,
vector_db_port: str = "",
vector_db_key: str = "",
):
@ -27,6 +29,7 @@ def create_vector_engine(
- vector_db_url (str): The URL for the vector database instance.
- vector_db_port (str): The port for the vector database instance. Required for some
providers.
- vector_db_name (str): The name of the vector database instance.
- vector_db_key (str): The API key or access token for the vector database instance.
- vector_db_provider (str): The name of the vector database provider to use (e.g.,
'pgvector').
@ -45,6 +48,7 @@ def create_vector_engine(
url=vector_db_url,
api_key=vector_db_key,
embedding_engine=embedding_engine,
database_name=vector_db_name,
)
if vector_db_provider.lower() == "pgvector":

View file

@ -15,5 +15,14 @@ class DatasetDatabase(Base):
vector_database_name = Column(String, unique=True, nullable=False)
graph_database_name = Column(String, unique=True, nullable=False)
vector_database_provider = Column(String, unique=False, nullable=False)
graph_database_provider = Column(String, unique=False, nullable=False)
vector_database_url = Column(String, unique=False, nullable=True)
graph_database_url = Column(String, unique=False, nullable=True)
vector_database_key = Column(String, unique=False, nullable=True)
graph_database_key = Column(String, unique=False, nullable=True)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))

View file

@ -33,11 +33,13 @@ async def main():
"vector_db_url": "cognee1.test",
"vector_db_key": "",
"vector_db_provider": "lancedb",
"vector_db_name": "",
}
task_2_config = {
"vector_db_url": "cognee2.test",
"vector_db_key": "",
"vector_db_provider": "lancedb",
"vector_db_name": "",
}
task_1_graph_config = {