From 8e8aecb76ff66a48e4b5fe18a7ffaac48434f89a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrej=20Mili=C4=87evi=C4=87?= <85933103+siillee@users.noreply.github.com> Date: Tue, 11 Nov 2025 17:03:48 +0100 Subject: [PATCH] feat: enable multi user for falkor (#1689) ## Description Added multi-user support for Falkor. Adding support for the rest of the graph dbs should be a bit easier after this first one, especially since Falkor is hybrid. There are a few things code quality wise that might need changing, I am open to suggestions. ## Type of Change - [ ] Bug fix (non-breaking change that fixes an issue) - [x] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Screenshots/Videos (if applicable) ## Pre-submission Checklist - [x] **I have tested my changes thoroughly before submitting this PR** - [x] **This PR contains minimal changes necessary to address the issue/feature** - [ ] My code follows the project's coding standards and style guidelines - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I have added necessary documentation (if applicable) - [ ] All new and existing tests pass - [ ] I have searched existing PRs to ensure this change hasn't been submitted already - [ ] I have linked any relevant issues in the description - [ ] My commits have clear and descriptive messages ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. --------- Co-authored-by: Andrej Milicevic Co-authored-by: Igor Ilic <30923996+dexters1@users.noreply.github.com> Co-authored-by: Igor Ilic --- ..._expand_dataset_database_for_multi_user.py | 98 +++++++++++++++++++ cognee/context_global_variables.py | 25 ++--- .../infrastructure/databases/graph/config.py | 4 + .../databases/graph/get_graph_engine.py | 2 + .../utils/get_or_create_dataset_database.py | 40 +++++++- .../infrastructure/databases/vector/config.py | 3 + .../databases/vector/create_vector_engine.py | 4 + .../modules/users/models/DatasetDatabase.py | 9 ++ cognee/tests/test_parallel_databases.py | 2 + 9 files changed, 172 insertions(+), 15 deletions(-) create mode 100644 alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py diff --git a/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py b/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py new file mode 100644 index 000000000..7e13898ae --- /dev/null +++ b/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py @@ -0,0 +1,98 @@ +"""Expand dataset database for multi user + +Revision ID: 76625596c5c3 +Revises: 211ab850ef3d +Create Date: 2025-10-30 12:55:20.239562 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "76625596c5c3" +down_revision: Union[str, None] = "c946955da633" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def _get_column(inspector, table, name, schema=None): + for col in inspector.get_columns(table, schema=schema): + if col["name"] == name: + return col + return None + + +def upgrade() -> None: + conn = op.get_bind() + insp = sa.inspect(conn) + + vector_database_provider_column = _get_column( + insp, "dataset_database", "vector_database_provider" + ) + if not vector_database_provider_column: + op.add_column( + "dataset_database", + sa.Column( + "vector_database_provider", + sa.String(), + unique=False, + nullable=False, + server_default="lancedb", + ), + ) + + graph_database_provider_column = _get_column( + insp, "dataset_database", "graph_database_provider" + ) + if not graph_database_provider_column: + op.add_column( + "dataset_database", + sa.Column( + "graph_database_provider", + sa.String(), + unique=False, + nullable=False, + server_default="kuzu", + ), + ) + + vector_database_url_column = _get_column(insp, "dataset_database", "vector_database_url") + if not vector_database_url_column: + op.add_column( + "dataset_database", + sa.Column("vector_database_url", sa.String(), unique=False, nullable=True), + ) + + graph_database_url_column = _get_column(insp, "dataset_database", "graph_database_url") + if not graph_database_url_column: + op.add_column( + "dataset_database", + sa.Column("graph_database_url", sa.String(), unique=False, nullable=True), + ) + + vector_database_key_column = _get_column(insp, "dataset_database", "vector_database_key") + if not vector_database_key_column: + op.add_column( + "dataset_database", + sa.Column("vector_database_key", sa.String(), unique=False, nullable=True), + ) + + graph_database_key_column = _get_column(insp, "dataset_database", "graph_database_key") + if not graph_database_key_column: + op.add_column( + "dataset_database", + sa.Column("graph_database_key", sa.String(), unique=False, nullable=True), + ) + + +def downgrade() -> None: + op.drop_column("dataset_database", "vector_database_provider") + op.drop_column("dataset_database", "graph_database_provider") + op.drop_column("dataset_database", "vector_database_url") + op.drop_column("dataset_database", "graph_database_url") + op.drop_column("dataset_database", "vector_database_key") + op.drop_column("dataset_database", "graph_database_key") diff --git a/cognee/context_global_variables.py b/cognee/context_global_variables.py index f17c9187a..62e06fc64 100644 --- a/cognee/context_global_variables.py +++ b/cognee/context_global_variables.py @@ -16,8 +16,8 @@ vector_db_config = ContextVar("vector_db_config", default=None) graph_db_config = ContextVar("graph_db_config", default=None) session_user = ContextVar("session_user", default=None) -vector_dbs_with_multi_user_support = ["lancedb"] -graph_dbs_with_multi_user_support = ["kuzu"] +VECTOR_DBS_WITH_MULTI_USER_SUPPORT = ["lancedb", "falkor"] +GRAPH_DBS_WITH_MULTI_USER_SUPPORT = ["kuzu", "falkor"] async def set_session_user_context_variable(user): @@ -28,8 +28,8 @@ def multi_user_support_possible(): graph_db_config = get_graph_context_config() vector_db_config = get_vectordb_context_config() return ( - graph_db_config["graph_database_provider"] in graph_dbs_with_multi_user_support - and vector_db_config["vector_db_provider"] in vector_dbs_with_multi_user_support + graph_db_config["graph_database_provider"] in GRAPH_DBS_WITH_MULTI_USER_SUPPORT + and vector_db_config["vector_db_provider"] in VECTOR_DBS_WITH_MULTI_USER_SUPPORT ) @@ -69,8 +69,6 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_ """ - base_config = get_base_config() - if not backend_access_control_enabled(): return @@ -79,6 +77,7 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_ # To ensure permissions are enforced properly all datasets will have their own databases dataset_database = await get_or_create_dataset_database(dataset, user) + base_config = get_base_config() data_root_directory = os.path.join( base_config.data_root_directory, str(user.tenant_id or user.id) ) @@ -88,15 +87,17 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_ # Set vector and graph database configuration based on dataset database information vector_config = { - "vector_db_url": os.path.join( - databases_directory_path, dataset_database.vector_database_name - ), - "vector_db_key": "", - "vector_db_provider": "lancedb", + "vector_db_provider": dataset_database.vector_database_provider, + "vector_db_url": dataset_database.vector_database_url, + "vector_db_key": dataset_database.vector_database_key, + "vector_db_name": dataset_database.vector_database_name, } graph_config = { - "graph_database_provider": "kuzu", + "graph_database_provider": dataset_database.graph_database_provider, + "graph_database_url": dataset_database.graph_database_url, + "graph_database_name": dataset_database.graph_database_name, + "graph_database_key": dataset_database.graph_database_key, "graph_file_path": os.path.join( databases_directory_path, dataset_database.graph_database_name ), diff --git a/cognee/infrastructure/databases/graph/config.py b/cognee/infrastructure/databases/graph/config.py index b7907313c..23687b359 100644 --- a/cognee/infrastructure/databases/graph/config.py +++ b/cognee/infrastructure/databases/graph/config.py @@ -26,6 +26,7 @@ class GraphConfig(BaseSettings): - graph_database_username - graph_database_password - graph_database_port + - graph_database_key - graph_file_path - graph_model - graph_topology @@ -41,6 +42,7 @@ class GraphConfig(BaseSettings): graph_database_username: str = "" graph_database_password: str = "" graph_database_port: int = 123 + graph_database_key: str = "" graph_file_path: str = "" graph_filename: str = "" graph_model: object = KnowledgeGraph @@ -90,6 +92,7 @@ class GraphConfig(BaseSettings): "graph_database_username": self.graph_database_username, "graph_database_password": self.graph_database_password, "graph_database_port": self.graph_database_port, + "graph_database_key": self.graph_database_key, "graph_file_path": self.graph_file_path, "graph_model": self.graph_model, "graph_topology": self.graph_topology, @@ -116,6 +119,7 @@ class GraphConfig(BaseSettings): "graph_database_username": self.graph_database_username, "graph_database_password": self.graph_database_password, "graph_database_port": self.graph_database_port, + "graph_database_key": self.graph_database_key, "graph_file_path": self.graph_file_path, } diff --git a/cognee/infrastructure/databases/graph/get_graph_engine.py b/cognee/infrastructure/databases/graph/get_graph_engine.py index 1ea61d29f..82e3cad6e 100644 --- a/cognee/infrastructure/databases/graph/get_graph_engine.py +++ b/cognee/infrastructure/databases/graph/get_graph_engine.py @@ -33,6 +33,7 @@ def create_graph_engine( graph_database_username="", graph_database_password="", graph_database_port="", + graph_database_key="", ): """ Create a graph engine based on the specified provider type. @@ -69,6 +70,7 @@ def create_graph_engine( graph_database_url=graph_database_url, graph_database_username=graph_database_username, graph_database_password=graph_database_password, + database_name=graph_database_name, ) if graph_database_provider == "neo4j": diff --git a/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py b/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py index 29156025d..3684bb100 100644 --- a/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +++ b/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py @@ -1,11 +1,15 @@ +import os from uuid import UUID from typing import Union from sqlalchemy import select from sqlalchemy.exc import IntegrityError -from cognee.modules.data.methods import create_dataset +from cognee.base_config import get_base_config +from cognee.modules.data.methods import create_dataset from cognee.infrastructure.databases.relational import get_relational_engine +from cognee.infrastructure.databases.vector import get_vectordb_config +from cognee.infrastructure.databases.graph.config import get_graph_config from cognee.modules.data.methods import get_unique_dataset_id from cognee.modules.users.models import DatasetDatabase from cognee.modules.users.models import User @@ -32,8 +36,32 @@ async def get_or_create_dataset_database( dataset_id = await get_unique_dataset_id(dataset, user) - vector_db_name = f"{dataset_id}.lance.db" - graph_db_name = f"{dataset_id}.pkl" + vector_config = get_vectordb_config() + graph_config = get_graph_config() + + # Note: for hybrid databases both graph and vector DB name have to be the same + if graph_config.graph_database_provider == "kuzu": + graph_db_name = f"{dataset_id}.pkl" + else: + graph_db_name = f"{dataset_id}" + + if vector_config.vector_db_provider == "lancedb": + vector_db_name = f"{dataset_id}.lance.db" + else: + vector_db_name = f"{dataset_id}" + + base_config = get_base_config() + databases_directory_path = os.path.join( + base_config.system_root_directory, "databases", str(user.id) + ) + + # Determine vector database URL + if vector_config.vector_db_provider == "lancedb": + vector_db_url = os.path.join(databases_directory_path, vector_config.vector_db_name) + else: + vector_db_url = vector_config.vector_database_url + + # Determine graph database URL async with db_engine.get_async_session() as session: # Create dataset if it doesn't exist @@ -55,6 +83,12 @@ async def get_or_create_dataset_database( dataset_id=dataset_id, vector_database_name=vector_db_name, graph_database_name=graph_db_name, + vector_database_provider=vector_config.vector_db_provider, + graph_database_provider=graph_config.graph_database_provider, + vector_database_url=vector_db_url, + graph_database_url=graph_config.graph_database_url, + vector_database_key=vector_config.vector_db_key, + graph_database_key=graph_config.graph_database_key, ) try: diff --git a/cognee/infrastructure/databases/vector/config.py b/cognee/infrastructure/databases/vector/config.py index b6d3ae644..7d28f1668 100644 --- a/cognee/infrastructure/databases/vector/config.py +++ b/cognee/infrastructure/databases/vector/config.py @@ -18,12 +18,14 @@ class VectorConfig(BaseSettings): Instance variables: - vector_db_url: The URL of the vector database. - vector_db_port: The port for the vector database. + - vector_db_name: The name of the vector database. - vector_db_key: The key for accessing the vector database. - vector_db_provider: The provider for the vector database. """ vector_db_url: str = "" vector_db_port: int = 1234 + vector_db_name: str = "" vector_db_key: str = "" vector_db_provider: str = "lancedb" @@ -58,6 +60,7 @@ class VectorConfig(BaseSettings): return { "vector_db_url": self.vector_db_url, "vector_db_port": self.vector_db_port, + "vector_db_name": self.vector_db_name, "vector_db_key": self.vector_db_key, "vector_db_provider": self.vector_db_provider, } diff --git a/cognee/infrastructure/databases/vector/create_vector_engine.py b/cognee/infrastructure/databases/vector/create_vector_engine.py index c54d94f6c..b182f084b 100644 --- a/cognee/infrastructure/databases/vector/create_vector_engine.py +++ b/cognee/infrastructure/databases/vector/create_vector_engine.py @@ -1,5 +1,6 @@ from .supported_databases import supported_databases from .embeddings import get_embedding_engine +from cognee.infrastructure.databases.graph.config import get_graph_context_config from functools import lru_cache @@ -8,6 +9,7 @@ from functools import lru_cache def create_vector_engine( vector_db_provider: str, vector_db_url: str, + vector_db_name: str, vector_db_port: str = "", vector_db_key: str = "", ): @@ -27,6 +29,7 @@ def create_vector_engine( - vector_db_url (str): The URL for the vector database instance. - vector_db_port (str): The port for the vector database instance. Required for some providers. + - vector_db_name (str): The name of the vector database instance. - vector_db_key (str): The API key or access token for the vector database instance. - vector_db_provider (str): The name of the vector database provider to use (e.g., 'pgvector'). @@ -45,6 +48,7 @@ def create_vector_engine( url=vector_db_url, api_key=vector_db_key, embedding_engine=embedding_engine, + database_name=vector_db_name, ) if vector_db_provider.lower() == "pgvector": diff --git a/cognee/modules/users/models/DatasetDatabase.py b/cognee/modules/users/models/DatasetDatabase.py index 0d71d8413..25d610ab9 100644 --- a/cognee/modules/users/models/DatasetDatabase.py +++ b/cognee/modules/users/models/DatasetDatabase.py @@ -15,5 +15,14 @@ class DatasetDatabase(Base): vector_database_name = Column(String, unique=True, nullable=False) graph_database_name = Column(String, unique=True, nullable=False) + vector_database_provider = Column(String, unique=False, nullable=False) + graph_database_provider = Column(String, unique=False, nullable=False) + + vector_database_url = Column(String, unique=False, nullable=True) + graph_database_url = Column(String, unique=False, nullable=True) + + vector_database_key = Column(String, unique=False, nullable=True) + graph_database_key = Column(String, unique=False, nullable=True) + created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc)) diff --git a/cognee/tests/test_parallel_databases.py b/cognee/tests/test_parallel_databases.py index 9a590921a..3164206ed 100755 --- a/cognee/tests/test_parallel_databases.py +++ b/cognee/tests/test_parallel_databases.py @@ -33,11 +33,13 @@ async def main(): "vector_db_url": "cognee1.test", "vector_db_key": "", "vector_db_provider": "lancedb", + "vector_db_name": "", } task_2_config = { "vector_db_url": "cognee2.test", "vector_db_key": "", "vector_db_provider": "lancedb", + "vector_db_name": "", } task_1_graph_config = {