feat: enable multi user for falkor (#1689)

<!-- .github/pull_request_template.md -->

## Description
<!--
Please provide a clear, human-generated description of the changes in
this PR.
DO NOT use AI-generated descriptions. We want to understand your thought
process and reasoning.
-->
Added multi-user support for Falkor. Adding support for the rest of the
graph dbs should be a bit easier after this first one, especially since
Falkor is hybrid.
There are a few things code quality wise that might need changing, I am
open to suggestions.

## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [ ] My code follows the project's coding standards and style
guidelines
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.

---------

Co-authored-by: Andrej Milicevic <milicevi@Andrejs-MacBook-Pro.local>
Co-authored-by: Igor Ilic <30923996+dexters1@users.noreply.github.com>
Co-authored-by: Igor Ilic <igorilic03@gmail.com>
This commit is contained in:
Andrej Milićević 2025-11-11 17:03:48 +01:00 committed by GitHub
parent 78b825f338
commit 8e8aecb76f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 172 additions and 15 deletions

View file

@ -0,0 +1,98 @@
"""Expand dataset database for multi user
Revision ID: 76625596c5c3
Revises: 211ab850ef3d
Create Date: 2025-10-30 12:55:20.239562
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "76625596c5c3"
down_revision: Union[str, None] = "c946955da633"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def _get_column(inspector, table, name, schema=None):
for col in inspector.get_columns(table, schema=schema):
if col["name"] == name:
return col
return None
def upgrade() -> None:
conn = op.get_bind()
insp = sa.inspect(conn)
vector_database_provider_column = _get_column(
insp, "dataset_database", "vector_database_provider"
)
if not vector_database_provider_column:
op.add_column(
"dataset_database",
sa.Column(
"vector_database_provider",
sa.String(),
unique=False,
nullable=False,
server_default="lancedb",
),
)
graph_database_provider_column = _get_column(
insp, "dataset_database", "graph_database_provider"
)
if not graph_database_provider_column:
op.add_column(
"dataset_database",
sa.Column(
"graph_database_provider",
sa.String(),
unique=False,
nullable=False,
server_default="kuzu",
),
)
vector_database_url_column = _get_column(insp, "dataset_database", "vector_database_url")
if not vector_database_url_column:
op.add_column(
"dataset_database",
sa.Column("vector_database_url", sa.String(), unique=False, nullable=True),
)
graph_database_url_column = _get_column(insp, "dataset_database", "graph_database_url")
if not graph_database_url_column:
op.add_column(
"dataset_database",
sa.Column("graph_database_url", sa.String(), unique=False, nullable=True),
)
vector_database_key_column = _get_column(insp, "dataset_database", "vector_database_key")
if not vector_database_key_column:
op.add_column(
"dataset_database",
sa.Column("vector_database_key", sa.String(), unique=False, nullable=True),
)
graph_database_key_column = _get_column(insp, "dataset_database", "graph_database_key")
if not graph_database_key_column:
op.add_column(
"dataset_database",
sa.Column("graph_database_key", sa.String(), unique=False, nullable=True),
)
def downgrade() -> None:
op.drop_column("dataset_database", "vector_database_provider")
op.drop_column("dataset_database", "graph_database_provider")
op.drop_column("dataset_database", "vector_database_url")
op.drop_column("dataset_database", "graph_database_url")
op.drop_column("dataset_database", "vector_database_key")
op.drop_column("dataset_database", "graph_database_key")

View file

@ -16,8 +16,8 @@ vector_db_config = ContextVar("vector_db_config", default=None)
graph_db_config = ContextVar("graph_db_config", default=None) graph_db_config = ContextVar("graph_db_config", default=None)
session_user = ContextVar("session_user", default=None) session_user = ContextVar("session_user", default=None)
vector_dbs_with_multi_user_support = ["lancedb"] VECTOR_DBS_WITH_MULTI_USER_SUPPORT = ["lancedb", "falkor"]
graph_dbs_with_multi_user_support = ["kuzu"] GRAPH_DBS_WITH_MULTI_USER_SUPPORT = ["kuzu", "falkor"]
async def set_session_user_context_variable(user): async def set_session_user_context_variable(user):
@ -28,8 +28,8 @@ def multi_user_support_possible():
graph_db_config = get_graph_context_config() graph_db_config = get_graph_context_config()
vector_db_config = get_vectordb_context_config() vector_db_config = get_vectordb_context_config()
return ( return (
graph_db_config["graph_database_provider"] in graph_dbs_with_multi_user_support graph_db_config["graph_database_provider"] in GRAPH_DBS_WITH_MULTI_USER_SUPPORT
and vector_db_config["vector_db_provider"] in vector_dbs_with_multi_user_support and vector_db_config["vector_db_provider"] in VECTOR_DBS_WITH_MULTI_USER_SUPPORT
) )
@ -69,8 +69,6 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
""" """
base_config = get_base_config()
if not backend_access_control_enabled(): if not backend_access_control_enabled():
return return
@ -79,6 +77,7 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
# To ensure permissions are enforced properly all datasets will have their own databases # To ensure permissions are enforced properly all datasets will have their own databases
dataset_database = await get_or_create_dataset_database(dataset, user) dataset_database = await get_or_create_dataset_database(dataset, user)
base_config = get_base_config()
data_root_directory = os.path.join( data_root_directory = os.path.join(
base_config.data_root_directory, str(user.tenant_id or user.id) base_config.data_root_directory, str(user.tenant_id or user.id)
) )
@ -88,15 +87,17 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
# Set vector and graph database configuration based on dataset database information # Set vector and graph database configuration based on dataset database information
vector_config = { vector_config = {
"vector_db_url": os.path.join( "vector_db_provider": dataset_database.vector_database_provider,
databases_directory_path, dataset_database.vector_database_name "vector_db_url": dataset_database.vector_database_url,
), "vector_db_key": dataset_database.vector_database_key,
"vector_db_key": "", "vector_db_name": dataset_database.vector_database_name,
"vector_db_provider": "lancedb",
} }
graph_config = { graph_config = {
"graph_database_provider": "kuzu", "graph_database_provider": dataset_database.graph_database_provider,
"graph_database_url": dataset_database.graph_database_url,
"graph_database_name": dataset_database.graph_database_name,
"graph_database_key": dataset_database.graph_database_key,
"graph_file_path": os.path.join( "graph_file_path": os.path.join(
databases_directory_path, dataset_database.graph_database_name databases_directory_path, dataset_database.graph_database_name
), ),

View file

@ -26,6 +26,7 @@ class GraphConfig(BaseSettings):
- graph_database_username - graph_database_username
- graph_database_password - graph_database_password
- graph_database_port - graph_database_port
- graph_database_key
- graph_file_path - graph_file_path
- graph_model - graph_model
- graph_topology - graph_topology
@ -41,6 +42,7 @@ class GraphConfig(BaseSettings):
graph_database_username: str = "" graph_database_username: str = ""
graph_database_password: str = "" graph_database_password: str = ""
graph_database_port: int = 123 graph_database_port: int = 123
graph_database_key: str = ""
graph_file_path: str = "" graph_file_path: str = ""
graph_filename: str = "" graph_filename: str = ""
graph_model: object = KnowledgeGraph graph_model: object = KnowledgeGraph
@ -90,6 +92,7 @@ class GraphConfig(BaseSettings):
"graph_database_username": self.graph_database_username, "graph_database_username": self.graph_database_username,
"graph_database_password": self.graph_database_password, "graph_database_password": self.graph_database_password,
"graph_database_port": self.graph_database_port, "graph_database_port": self.graph_database_port,
"graph_database_key": self.graph_database_key,
"graph_file_path": self.graph_file_path, "graph_file_path": self.graph_file_path,
"graph_model": self.graph_model, "graph_model": self.graph_model,
"graph_topology": self.graph_topology, "graph_topology": self.graph_topology,
@ -116,6 +119,7 @@ class GraphConfig(BaseSettings):
"graph_database_username": self.graph_database_username, "graph_database_username": self.graph_database_username,
"graph_database_password": self.graph_database_password, "graph_database_password": self.graph_database_password,
"graph_database_port": self.graph_database_port, "graph_database_port": self.graph_database_port,
"graph_database_key": self.graph_database_key,
"graph_file_path": self.graph_file_path, "graph_file_path": self.graph_file_path,
} }

View file

@ -33,6 +33,7 @@ def create_graph_engine(
graph_database_username="", graph_database_username="",
graph_database_password="", graph_database_password="",
graph_database_port="", graph_database_port="",
graph_database_key="",
): ):
""" """
Create a graph engine based on the specified provider type. Create a graph engine based on the specified provider type.
@ -69,6 +70,7 @@ def create_graph_engine(
graph_database_url=graph_database_url, graph_database_url=graph_database_url,
graph_database_username=graph_database_username, graph_database_username=graph_database_username,
graph_database_password=graph_database_password, graph_database_password=graph_database_password,
database_name=graph_database_name,
) )
if graph_database_provider == "neo4j": if graph_database_provider == "neo4j":

View file

@ -1,11 +1,15 @@
import os
from uuid import UUID from uuid import UUID
from typing import Union from typing import Union
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from cognee.modules.data.methods import create_dataset
from cognee.base_config import get_base_config
from cognee.modules.data.methods import create_dataset
from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.databases.vector import get_vectordb_config
from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.modules.data.methods import get_unique_dataset_id from cognee.modules.data.methods import get_unique_dataset_id
from cognee.modules.users.models import DatasetDatabase from cognee.modules.users.models import DatasetDatabase
from cognee.modules.users.models import User from cognee.modules.users.models import User
@ -32,8 +36,32 @@ async def get_or_create_dataset_database(
dataset_id = await get_unique_dataset_id(dataset, user) dataset_id = await get_unique_dataset_id(dataset, user)
vector_db_name = f"{dataset_id}.lance.db" vector_config = get_vectordb_config()
graph_db_name = f"{dataset_id}.pkl" graph_config = get_graph_config()
# Note: for hybrid databases both graph and vector DB name have to be the same
if graph_config.graph_database_provider == "kuzu":
graph_db_name = f"{dataset_id}.pkl"
else:
graph_db_name = f"{dataset_id}"
if vector_config.vector_db_provider == "lancedb":
vector_db_name = f"{dataset_id}.lance.db"
else:
vector_db_name = f"{dataset_id}"
base_config = get_base_config()
databases_directory_path = os.path.join(
base_config.system_root_directory, "databases", str(user.id)
)
# Determine vector database URL
if vector_config.vector_db_provider == "lancedb":
vector_db_url = os.path.join(databases_directory_path, vector_config.vector_db_name)
else:
vector_db_url = vector_config.vector_database_url
# Determine graph database URL
async with db_engine.get_async_session() as session: async with db_engine.get_async_session() as session:
# Create dataset if it doesn't exist # Create dataset if it doesn't exist
@ -55,6 +83,12 @@ async def get_or_create_dataset_database(
dataset_id=dataset_id, dataset_id=dataset_id,
vector_database_name=vector_db_name, vector_database_name=vector_db_name,
graph_database_name=graph_db_name, graph_database_name=graph_db_name,
vector_database_provider=vector_config.vector_db_provider,
graph_database_provider=graph_config.graph_database_provider,
vector_database_url=vector_db_url,
graph_database_url=graph_config.graph_database_url,
vector_database_key=vector_config.vector_db_key,
graph_database_key=graph_config.graph_database_key,
) )
try: try:

View file

@ -18,12 +18,14 @@ class VectorConfig(BaseSettings):
Instance variables: Instance variables:
- vector_db_url: The URL of the vector database. - vector_db_url: The URL of the vector database.
- vector_db_port: The port for the vector database. - vector_db_port: The port for the vector database.
- vector_db_name: The name of the vector database.
- vector_db_key: The key for accessing the vector database. - vector_db_key: The key for accessing the vector database.
- vector_db_provider: The provider for the vector database. - vector_db_provider: The provider for the vector database.
""" """
vector_db_url: str = "" vector_db_url: str = ""
vector_db_port: int = 1234 vector_db_port: int = 1234
vector_db_name: str = ""
vector_db_key: str = "" vector_db_key: str = ""
vector_db_provider: str = "lancedb" vector_db_provider: str = "lancedb"
@ -58,6 +60,7 @@ class VectorConfig(BaseSettings):
return { return {
"vector_db_url": self.vector_db_url, "vector_db_url": self.vector_db_url,
"vector_db_port": self.vector_db_port, "vector_db_port": self.vector_db_port,
"vector_db_name": self.vector_db_name,
"vector_db_key": self.vector_db_key, "vector_db_key": self.vector_db_key,
"vector_db_provider": self.vector_db_provider, "vector_db_provider": self.vector_db_provider,
} }

View file

@ -1,5 +1,6 @@
from .supported_databases import supported_databases from .supported_databases import supported_databases
from .embeddings import get_embedding_engine from .embeddings import get_embedding_engine
from cognee.infrastructure.databases.graph.config import get_graph_context_config
from functools import lru_cache from functools import lru_cache
@ -8,6 +9,7 @@ from functools import lru_cache
def create_vector_engine( def create_vector_engine(
vector_db_provider: str, vector_db_provider: str,
vector_db_url: str, vector_db_url: str,
vector_db_name: str,
vector_db_port: str = "", vector_db_port: str = "",
vector_db_key: str = "", vector_db_key: str = "",
): ):
@ -27,6 +29,7 @@ def create_vector_engine(
- vector_db_url (str): The URL for the vector database instance. - vector_db_url (str): The URL for the vector database instance.
- vector_db_port (str): The port for the vector database instance. Required for some - vector_db_port (str): The port for the vector database instance. Required for some
providers. providers.
- vector_db_name (str): The name of the vector database instance.
- vector_db_key (str): The API key or access token for the vector database instance. - vector_db_key (str): The API key or access token for the vector database instance.
- vector_db_provider (str): The name of the vector database provider to use (e.g., - vector_db_provider (str): The name of the vector database provider to use (e.g.,
'pgvector'). 'pgvector').
@ -45,6 +48,7 @@ def create_vector_engine(
url=vector_db_url, url=vector_db_url,
api_key=vector_db_key, api_key=vector_db_key,
embedding_engine=embedding_engine, embedding_engine=embedding_engine,
database_name=vector_db_name,
) )
if vector_db_provider.lower() == "pgvector": if vector_db_provider.lower() == "pgvector":

View file

@ -15,5 +15,14 @@ class DatasetDatabase(Base):
vector_database_name = Column(String, unique=True, nullable=False) vector_database_name = Column(String, unique=True, nullable=False)
graph_database_name = Column(String, unique=True, nullable=False) graph_database_name = Column(String, unique=True, nullable=False)
vector_database_provider = Column(String, unique=False, nullable=False)
graph_database_provider = Column(String, unique=False, nullable=False)
vector_database_url = Column(String, unique=False, nullable=True)
graph_database_url = Column(String, unique=False, nullable=True)
vector_database_key = Column(String, unique=False, nullable=True)
graph_database_key = Column(String, unique=False, nullable=True)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc)) updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))

View file

@ -33,11 +33,13 @@ async def main():
"vector_db_url": "cognee1.test", "vector_db_url": "cognee1.test",
"vector_db_key": "", "vector_db_key": "",
"vector_db_provider": "lancedb", "vector_db_provider": "lancedb",
"vector_db_name": "",
} }
task_2_config = { task_2_config = {
"vector_db_url": "cognee2.test", "vector_db_url": "cognee2.test",
"vector_db_key": "", "vector_db_key": "",
"vector_db_provider": "lancedb", "vector_db_provider": "lancedb",
"vector_db_name": "",
} }
task_1_graph_config = { task_1_graph_config = {