refactor: Start adding multi-user functions to db interfaces
This commit is contained in:
parent
0176cd5a68
commit
6bb642d6b8
3 changed files with 61 additions and 7 deletions
|
|
@ -20,15 +20,13 @@ from cognee.modules.users.models import User
|
||||||
async def _get_vector_db_info(dataset_id: UUID, user: User) -> dict:
|
async def _get_vector_db_info(dataset_id: UUID, user: User) -> dict:
|
||||||
vector_config = get_vectordb_config()
|
vector_config = get_vectordb_config()
|
||||||
|
|
||||||
base_config = get_base_config()
|
|
||||||
databases_directory_path = os.path.join(
|
|
||||||
base_config.system_root_directory, "databases", str(user.id)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Determine vector configuration
|
# Determine vector configuration
|
||||||
if vector_config.vector_db_provider == "lancedb":
|
if vector_config.vector_db_provider == "lancedb":
|
||||||
vector_db_name = f"{dataset_id}.lance.db"
|
# TODO: Have the create_database method be called from interface adapter automatically for all providers instead of specifically here
|
||||||
vector_db_url = os.path.join(databases_directory_path, vector_db_name)
|
from cognee.infrastructure.databases.vector.lancedb.LanceDBAdapter import LanceDBAdapter
|
||||||
|
|
||||||
|
return await LanceDBAdapter.create_database(dataset_id, user)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Note: for hybrid databases both graph and vector DB name have to be the same
|
# Note: for hybrid databases both graph and vector DB name have to be the same
|
||||||
vector_db_name = vector_config.vector_db_name
|
vector_db_name = vector_config.vector_db_name
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,15 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from os import path
|
from os import path
|
||||||
|
import os
|
||||||
|
from uuid import UUID
|
||||||
import lancedb
|
import lancedb
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from lancedb.pydantic import LanceModel, Vector
|
from lancedb.pydantic import LanceModel, Vector
|
||||||
from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
||||||
|
|
||||||
|
from cognee.base_config import get_base_config
|
||||||
|
from cognee.infrastructure.databases.vector import get_vectordb_config
|
||||||
|
from cognee.modules.users.models import User
|
||||||
from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
|
from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.infrastructure.engine.utils import parse_id
|
from cognee.infrastructure.engine.utils import parse_id
|
||||||
|
|
@ -357,3 +362,20 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
},
|
},
|
||||||
exclude_fields=["metadata"] + related_models_fields,
|
exclude_fields=["metadata"] + related_models_fields,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create_database(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
|
||||||
|
vector_config = get_vectordb_config()
|
||||||
|
base_config = get_base_config()
|
||||||
|
databases_directory_path = os.path.join(
|
||||||
|
base_config.system_root_directory, "databases", str(user.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_db_name = f"{dataset_id}.lance.db"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"vector_database_name": vector_db_name,
|
||||||
|
"vector_database_url": os.path.join(databases_directory_path, vector_db_name),
|
||||||
|
"vector_database_provider": vector_config.vector_db_provider,
|
||||||
|
"vector_database_key": vector_config.vector_db_key,
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,8 @@ from typing import List, Protocol, Optional, Union, Any
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from .models.PayloadSchema import PayloadSchema
|
from .models.PayloadSchema import PayloadSchema
|
||||||
|
from uuid import UUID
|
||||||
|
from cognee.modules.users.models import User
|
||||||
|
|
||||||
|
|
||||||
class VectorDBInterface(Protocol):
|
class VectorDBInterface(Protocol):
|
||||||
|
|
@ -217,3 +219,35 @@ class VectorDBInterface(Protocol):
|
||||||
- Any: The schema object suitable for this vector database
|
- Any: The schema object suitable for this vector database
|
||||||
"""
|
"""
|
||||||
return model_type
|
return model_type
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create_database(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
|
||||||
|
"""
|
||||||
|
Return a dictionary with connection info for a vector database for the given dataset and user.
|
||||||
|
Function should auto handle deploying of the actual database if needed.
|
||||||
|
Needed for Cognee multi-tenant/multi-user and backend access control support.
|
||||||
|
|
||||||
|
Dictionary returned from this function will be used to create a DatasetDatabase row in the relational database.
|
||||||
|
From which internal mapping of dataset -> database connection info will be done.
|
||||||
|
|
||||||
|
Each dataset needs to map to a unique vector database instance when backend access control is enabled.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_id: UUID of the dataset if needed by the database creation logic
|
||||||
|
user: User object if needed by the database creation logic
|
||||||
|
Returns:
|
||||||
|
dict: Connection info for the created vector database instance.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def delete_database(self, dataset_id: UUID, user: User) -> None:
|
||||||
|
"""
|
||||||
|
Delete the vector database instance for the given dataset and user.
|
||||||
|
Function should auto handle deleting of the actual database.
|
||||||
|
Needed for maintaining a database for Cognee multi-tenant/multi-user and backend access control.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_id: UUID of the dataset
|
||||||
|
user: User object
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue