refactor: Start adding multi-user functions to db interfaces

This commit is contained in:
Igor Ilic 2025-11-12 21:24:40 +01:00
parent 0176cd5a68
commit 6bb642d6b8
3 changed files with 61 additions and 7 deletions

View file

@ -20,15 +20,13 @@ from cognee.modules.users.models import User
async def _get_vector_db_info(dataset_id: UUID, user: User) -> dict:
vector_config = get_vectordb_config()
base_config = get_base_config()
databases_directory_path = os.path.join(
base_config.system_root_directory, "databases", str(user.id)
)
# Determine vector configuration
if vector_config.vector_db_provider == "lancedb":
vector_db_name = f"{dataset_id}.lance.db"
vector_db_url = os.path.join(databases_directory_path, vector_db_name)
# TODO: Have the create_database method be called from interface adapter automatically for all providers instead of specifically here
from cognee.infrastructure.databases.vector.lancedb.LanceDBAdapter import LanceDBAdapter
return await LanceDBAdapter.create_database(dataset_id, user)
else:
# Note: for hybrid databases both graph and vector DB name have to be the same
vector_db_name = vector_config.vector_db_name

View file

@ -1,10 +1,15 @@
import asyncio
from os import path
import os
from uuid import UUID
import lancedb
from pydantic import BaseModel
from lancedb.pydantic import LanceModel, Vector
from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.vector import get_vectordb_config
from cognee.modules.users.models import User
from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.utils import parse_id
@ -357,3 +362,20 @@ class LanceDBAdapter(VectorDBInterface):
},
exclude_fields=["metadata"] + related_models_fields,
)
@classmethod
async def create_database(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
    """
    Build the connection info for the LanceDB database of one dataset and user.

    The database name is derived from the dataset id and the database is placed in a
    per-user directory under ``<system_root_directory>/databases/<user.id>``. This method
    only computes the name and path from the current configuration.

    Args:
        dataset_id: UUID of the dataset the database belongs to. Required for LanceDB,
            because the database file name is derived from it.
        user: Owner of the dataset. Required for LanceDB, because the database lives in
            a directory named after the user's id.

    Returns:
        dict: Connection info with the keys ``vector_database_name``,
        ``vector_database_url``, ``vector_database_provider`` and ``vector_database_key``.

    Raises:
        ValueError: If ``dataset_id`` or ``user`` is None.
    """
    # The signature allows None to match the interface, but both values are needed here.
    # Fail loudly instead of hitting an AttributeError on `user.id` or silently mapping
    # every dataset without an id to the same "None.lance.db" file.
    if user is None:
        raise ValueError("A user is required to create a LanceDB database.")
    if dataset_id is None:
        raise ValueError("A dataset_id is required to create a LanceDB database.")

    vector_config = get_vectordb_config()
    base_config = get_base_config()

    databases_directory_path = os.path.join(
        base_config.system_root_directory, "databases", str(user.id)
    )
    vector_db_name = f"{dataset_id}.lance.db"

    return {
        "vector_database_name": vector_db_name,
        "vector_database_url": os.path.join(databases_directory_path, vector_db_name),
        "vector_database_provider": vector_config.vector_db_provider,
        "vector_database_key": vector_config.vector_db_key,
    }

View file

@ -2,6 +2,8 @@ from typing import List, Protocol, Optional, Union, Any
from abc import abstractmethod
from cognee.infrastructure.engine import DataPoint
from .models.PayloadSchema import PayloadSchema
from uuid import UUID
from cognee.modules.users.models import User
class VectorDBInterface(Protocol):
@ -217,3 +219,35 @@ class VectorDBInterface(Protocol):
- Any: The schema object suitable for this vector database
"""
return model_type
@classmethod
async def create_database(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
    """
    Provide the connection info of the vector database serving a dataset and user.

    Implementations are expected to deploy the underlying database themselves when that
    is required. The returned dictionary is stored as a DatasetDatabase row in the
    relational database, which is then used to map a dataset to its database connection
    info. This is what enables multi-tenant/multi-user operation and backend access
    control in Cognee: with backend access control enabled, every dataset must map to
    its own unique vector database instance.

    Args:
        dataset_id: UUID of the dataset, for implementations whose creation logic needs it.
        user: User object, for implementations whose creation logic needs it.

    Returns:
        dict: Connection info for the created vector database instance.
    """
async def delete_database(self, dataset_id: UUID, user: User) -> None:
    """
    Remove the vector database instance that belongs to a dataset and user.

    Implementations are expected to take care of deleting the actual database. This is
    required to keep databases maintained when Cognee runs in multi-tenant/multi-user
    mode with backend access control.

    Args:
        dataset_id: UUID of the dataset whose database should be deleted.
        user: User object the database belongs to.
    """