feat: enable multi user for falkor (#1689)
<!-- .github/pull_request_template.md --> ## Description <!-- Please provide a clear, human-generated description of the changes in this PR. DO NOT use AI-generated descriptions. We want to understand your thought process and reasoning. --> Added multi-user support for Falkor. Adding support for the rest of the graph dbs should be a bit easier after this first one, especially since Falkor is hybrid. There are a few things code quality wise that might need changing, I am open to suggestions. ## Type of Change <!-- Please check the relevant option --> - [ ] Bug fix (non-breaking change that fixes an issue) - [x] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Screenshots/Videos (if applicable) <!-- Add screenshots or videos to help explain your changes --> ## Pre-submission Checklist <!-- Please check all boxes that apply before submitting your PR --> - [x] **I have tested my changes thoroughly before submitting this PR** - [x] **This PR contains minimal changes necessary to address the issue/feature** - [ ] My code follows the project's coding standards and style guidelines - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I have added necessary documentation (if applicable) - [ ] All new and existing tests pass - [ ] I have searched existing PRs to ensure this change hasn't been submitted already - [ ] I have linked any relevant issues in the description - [ ] My commits have clear and descriptive messages ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. 
--------- Co-authored-by: Andrej Milicevic <milicevi@Andrejs-MacBook-Pro.local> Co-authored-by: Igor Ilic <30923996+dexters1@users.noreply.github.com> Co-authored-by: Igor Ilic <igorilic03@gmail.com>
This commit is contained in:
parent
78b825f338
commit
8e8aecb76f
9 changed files with 172 additions and 15 deletions
|
|
@ -0,0 +1,98 @@
|
|||
"""Expand dataset database for multi user
|
||||
|
||||
Revision ID: 76625596c5c3
|
||||
Revises: c946955da633
|
||||
Create Date: 2025-10-30 12:55:20.239562
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "76625596c5c3"
|
||||
down_revision: Union[str, None] = "c946955da633"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def _get_column(inspector, table, name, schema=None):
|
||||
for col in inspector.get_columns(table, schema=schema):
|
||||
if col["name"] == name:
|
||||
return col
|
||||
return None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the provider, URL and key columns to ``dataset_database``.

    Each column is added only when the inspector does not already report it on
    the table, so columns that are already present are left untouched.
    """
    table_name = "dataset_database"
    inspector = sa.inspect(op.get_bind())

    # (column name, keyword arguments for sa.Column), listed in the order the
    # columns are added. The provider columns are NOT NULL, so they carry a
    # server default; the URL and key columns are nullable.
    column_specs = (
        ("vector_database_provider", {"nullable": False, "server_default": "lancedb"}),
        ("graph_database_provider", {"nullable": False, "server_default": "kuzu"}),
        ("vector_database_url", {"nullable": True}),
        ("graph_database_url", {"nullable": True}),
        ("vector_database_key", {"nullable": True}),
        ("graph_database_key", {"nullable": True}),
    )

    for column_name, column_options in column_specs:
        if _get_column(inspector, table_name, column_name):
            # Column already exists on the table; nothing to add.
            continue
        op.add_column(
            table_name,
            sa.Column(column_name, sa.String(), unique=False, **column_options),
        )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the multi-user provider, URL and key columns from ``dataset_database``.

    ``upgrade`` only adds a column when it is missing, so this function mirrors
    that guard: a column is dropped only if the inspector reports it on the
    table. This keeps the downgrade from failing on a column that is already
    absent (for example after a partially applied downgrade).
    """
    table_name = "dataset_database"
    inspector = sa.inspect(op.get_bind())

    # Reflect the table once; the set gives O(1) membership checks below.
    existing_columns = {column["name"] for column in inspector.get_columns(table_name)}

    columns_to_drop = (
        "vector_database_provider",
        "graph_database_provider",
        "vector_database_url",
        "graph_database_url",
        "vector_database_key",
        "graph_database_key",
    )
    for column_name in columns_to_drop:
        if column_name in existing_columns:
            op.drop_column(table_name, column_name)
|
||||
|
|
@ -16,8 +16,8 @@ vector_db_config = ContextVar("vector_db_config", default=None)
|
|||
graph_db_config = ContextVar("graph_db_config", default=None)
|
||||
session_user = ContextVar("session_user", default=None)
|
||||
|
||||
vector_dbs_with_multi_user_support = ["lancedb"]
|
||||
graph_dbs_with_multi_user_support = ["kuzu"]
|
||||
VECTOR_DBS_WITH_MULTI_USER_SUPPORT = ["lancedb", "falkor"]
|
||||
GRAPH_DBS_WITH_MULTI_USER_SUPPORT = ["kuzu", "falkor"]
|
||||
|
||||
|
||||
async def set_session_user_context_variable(user):
|
||||
|
|
@ -28,8 +28,8 @@ def multi_user_support_possible():
|
|||
graph_db_config = get_graph_context_config()
|
||||
vector_db_config = get_vectordb_context_config()
|
||||
return (
|
||||
graph_db_config["graph_database_provider"] in graph_dbs_with_multi_user_support
|
||||
and vector_db_config["vector_db_provider"] in vector_dbs_with_multi_user_support
|
||||
graph_db_config["graph_database_provider"] in GRAPH_DBS_WITH_MULTI_USER_SUPPORT
|
||||
and vector_db_config["vector_db_provider"] in VECTOR_DBS_WITH_MULTI_USER_SUPPORT
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -69,8 +69,6 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
|||
|
||||
"""
|
||||
|
||||
base_config = get_base_config()
|
||||
|
||||
if not backend_access_control_enabled():
|
||||
return
|
||||
|
||||
|
|
@ -79,6 +77,7 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
|||
# To ensure permissions are enforced properly all datasets will have their own databases
|
||||
dataset_database = await get_or_create_dataset_database(dataset, user)
|
||||
|
||||
base_config = get_base_config()
|
||||
data_root_directory = os.path.join(
|
||||
base_config.data_root_directory, str(user.tenant_id or user.id)
|
||||
)
|
||||
|
|
@ -88,15 +87,17 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
|||
|
||||
# Set vector and graph database configuration based on dataset database information
|
||||
vector_config = {
|
||||
"vector_db_url": os.path.join(
|
||||
databases_directory_path, dataset_database.vector_database_name
|
||||
),
|
||||
"vector_db_key": "",
|
||||
"vector_db_provider": "lancedb",
|
||||
"vector_db_provider": dataset_database.vector_database_provider,
|
||||
"vector_db_url": dataset_database.vector_database_url,
|
||||
"vector_db_key": dataset_database.vector_database_key,
|
||||
"vector_db_name": dataset_database.vector_database_name,
|
||||
}
|
||||
|
||||
graph_config = {
|
||||
"graph_database_provider": "kuzu",
|
||||
"graph_database_provider": dataset_database.graph_database_provider,
|
||||
"graph_database_url": dataset_database.graph_database_url,
|
||||
"graph_database_name": dataset_database.graph_database_name,
|
||||
"graph_database_key": dataset_database.graph_database_key,
|
||||
"graph_file_path": os.path.join(
|
||||
databases_directory_path, dataset_database.graph_database_name
|
||||
),
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ class GraphConfig(BaseSettings):
|
|||
- graph_database_username
|
||||
- graph_database_password
|
||||
- graph_database_port
|
||||
- graph_database_key
|
||||
- graph_file_path
|
||||
- graph_model
|
||||
- graph_topology
|
||||
|
|
@ -41,6 +42,7 @@ class GraphConfig(BaseSettings):
|
|||
graph_database_username: str = ""
|
||||
graph_database_password: str = ""
|
||||
graph_database_port: int = 123
|
||||
graph_database_key: str = ""
|
||||
graph_file_path: str = ""
|
||||
graph_filename: str = ""
|
||||
graph_model: object = KnowledgeGraph
|
||||
|
|
@ -90,6 +92,7 @@ class GraphConfig(BaseSettings):
|
|||
"graph_database_username": self.graph_database_username,
|
||||
"graph_database_password": self.graph_database_password,
|
||||
"graph_database_port": self.graph_database_port,
|
||||
"graph_database_key": self.graph_database_key,
|
||||
"graph_file_path": self.graph_file_path,
|
||||
"graph_model": self.graph_model,
|
||||
"graph_topology": self.graph_topology,
|
||||
|
|
@ -116,6 +119,7 @@ class GraphConfig(BaseSettings):
|
|||
"graph_database_username": self.graph_database_username,
|
||||
"graph_database_password": self.graph_database_password,
|
||||
"graph_database_port": self.graph_database_port,
|
||||
"graph_database_key": self.graph_database_key,
|
||||
"graph_file_path": self.graph_file_path,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ def create_graph_engine(
|
|||
graph_database_username="",
|
||||
graph_database_password="",
|
||||
graph_database_port="",
|
||||
graph_database_key="",
|
||||
):
|
||||
"""
|
||||
Create a graph engine based on the specified provider type.
|
||||
|
|
@ -69,6 +70,7 @@ def create_graph_engine(
|
|||
graph_database_url=graph_database_url,
|
||||
graph_database_username=graph_database_username,
|
||||
graph_database_password=graph_database_password,
|
||||
database_name=graph_database_name,
|
||||
)
|
||||
|
||||
if graph_database_provider == "neo4j":
|
||||
|
|
|
|||
|
|
@ -1,11 +1,15 @@
|
|||
import os
|
||||
from uuid import UUID
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from cognee.modules.data.methods import create_dataset
|
||||
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.modules.data.methods import create_dataset
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.infrastructure.databases.vector import get_vectordb_config
|
||||
from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||
from cognee.modules.data.methods import get_unique_dataset_id
|
||||
from cognee.modules.users.models import DatasetDatabase
|
||||
from cognee.modules.users.models import User
|
||||
|
|
@ -32,8 +36,32 @@ async def get_or_create_dataset_database(
|
|||
|
||||
dataset_id = await get_unique_dataset_id(dataset, user)
|
||||
|
||||
vector_db_name = f"{dataset_id}.lance.db"
|
||||
graph_db_name = f"{dataset_id}.pkl"
|
||||
vector_config = get_vectordb_config()
|
||||
graph_config = get_graph_config()
|
||||
|
||||
# Note: for hybrid databases both graph and vector DB name have to be the same
|
||||
if graph_config.graph_database_provider == "kuzu":
|
||||
graph_db_name = f"{dataset_id}.pkl"
|
||||
else:
|
||||
graph_db_name = f"{dataset_id}"
|
||||
|
||||
if vector_config.vector_db_provider == "lancedb":
|
||||
vector_db_name = f"{dataset_id}.lance.db"
|
||||
else:
|
||||
vector_db_name = f"{dataset_id}"
|
||||
|
||||
base_config = get_base_config()
|
||||
databases_directory_path = os.path.join(
|
||||
base_config.system_root_directory, "databases", str(user.id)
|
||||
)
|
||||
|
||||
# Determine vector database URL
|
||||
if vector_config.vector_db_provider == "lancedb":
|
||||
vector_db_url = os.path.join(databases_directory_path, vector_config.vector_db_name)
|
||||
else:
|
||||
vector_db_url = vector_config.vector_database_url
|
||||
|
||||
# Determine graph database URL
|
||||
|
||||
async with db_engine.get_async_session() as session:
|
||||
# Create dataset if it doesn't exist
|
||||
|
|
@ -55,6 +83,12 @@ async def get_or_create_dataset_database(
|
|||
dataset_id=dataset_id,
|
||||
vector_database_name=vector_db_name,
|
||||
graph_database_name=graph_db_name,
|
||||
vector_database_provider=vector_config.vector_db_provider,
|
||||
graph_database_provider=graph_config.graph_database_provider,
|
||||
vector_database_url=vector_db_url,
|
||||
graph_database_url=graph_config.graph_database_url,
|
||||
vector_database_key=vector_config.vector_db_key,
|
||||
graph_database_key=graph_config.graph_database_key,
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -18,12 +18,14 @@ class VectorConfig(BaseSettings):
|
|||
Instance variables:
|
||||
- vector_db_url: The URL of the vector database.
|
||||
- vector_db_port: The port for the vector database.
|
||||
- vector_db_name: The name of the vector database.
|
||||
- vector_db_key: The key for accessing the vector database.
|
||||
- vector_db_provider: The provider for the vector database.
|
||||
"""
|
||||
|
||||
vector_db_url: str = ""
|
||||
vector_db_port: int = 1234
|
||||
vector_db_name: str = ""
|
||||
vector_db_key: str = ""
|
||||
vector_db_provider: str = "lancedb"
|
||||
|
||||
|
|
@ -58,6 +60,7 @@ class VectorConfig(BaseSettings):
|
|||
return {
|
||||
"vector_db_url": self.vector_db_url,
|
||||
"vector_db_port": self.vector_db_port,
|
||||
"vector_db_name": self.vector_db_name,
|
||||
"vector_db_key": self.vector_db_key,
|
||||
"vector_db_provider": self.vector_db_provider,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from .supported_databases import supported_databases
|
||||
from .embeddings import get_embedding_engine
|
||||
from cognee.infrastructure.databases.graph.config import get_graph_context_config
|
||||
|
||||
from functools import lru_cache
|
||||
|
||||
|
|
@ -8,6 +9,7 @@ from functools import lru_cache
|
|||
def create_vector_engine(
|
||||
vector_db_provider: str,
|
||||
vector_db_url: str,
|
||||
vector_db_name: str,
|
||||
vector_db_port: str = "",
|
||||
vector_db_key: str = "",
|
||||
):
|
||||
|
|
@ -27,6 +29,7 @@ def create_vector_engine(
|
|||
- vector_db_url (str): The URL for the vector database instance.
|
||||
- vector_db_port (str): The port for the vector database instance. Required for some
|
||||
providers.
|
||||
- vector_db_name (str): The name of the vector database instance.
|
||||
- vector_db_key (str): The API key or access token for the vector database instance.
|
||||
- vector_db_provider (str): The name of the vector database provider to use (e.g.,
|
||||
'pgvector').
|
||||
|
|
@ -45,6 +48,7 @@ def create_vector_engine(
|
|||
url=vector_db_url,
|
||||
api_key=vector_db_key,
|
||||
embedding_engine=embedding_engine,
|
||||
database_name=vector_db_name,
|
||||
)
|
||||
|
||||
if vector_db_provider.lower() == "pgvector":
|
||||
|
|
|
|||
|
|
@ -15,5 +15,14 @@ class DatasetDatabase(Base):
|
|||
vector_database_name = Column(String, unique=True, nullable=False)
|
||||
graph_database_name = Column(String, unique=True, nullable=False)
|
||||
|
||||
vector_database_provider = Column(String, unique=False, nullable=False)
|
||||
graph_database_provider = Column(String, unique=False, nullable=False)
|
||||
|
||||
vector_database_url = Column(String, unique=False, nullable=True)
|
||||
graph_database_url = Column(String, unique=False, nullable=True)
|
||||
|
||||
vector_database_key = Column(String, unique=False, nullable=True)
|
||||
graph_database_key = Column(String, unique=False, nullable=True)
|
||||
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
|
|
|||
|
|
@ -33,11 +33,13 @@ async def main():
|
|||
"vector_db_url": "cognee1.test",
|
||||
"vector_db_key": "",
|
||||
"vector_db_provider": "lancedb",
|
||||
"vector_db_name": "",
|
||||
}
|
||||
task_2_config = {
|
||||
"vector_db_url": "cognee2.test",
|
||||
"vector_db_key": "",
|
||||
"vector_db_provider": "lancedb",
|
||||
"vector_db_name": "",
|
||||
}
|
||||
|
||||
task_1_graph_config = {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue