diff --git a/cognee/context_global_variables.py b/cognee/context_global_variables.py index 6a0f767ff..c2e9e82a9 100644 --- a/cognee/context_global_variables.py +++ b/cognee/context_global_variables.py @@ -69,8 +69,6 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_ """ - base_config = get_base_config() - if not backend_access_control_enabled(): return @@ -79,6 +77,7 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_ # To ensure permissions are enforced properly all datasets will have their own databases dataset_database = await get_or_create_dataset_database(dataset, user) + base_config = get_base_config() data_root_directory = os.path.join( base_config.data_root_directory, str(user.tenant_id or user.id) ) @@ -86,17 +85,10 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_ base_config.system_root_directory, "databases", str(user.id) ) - if dataset_database.vector_database_provider == "lancedb": - vector_db_url = os.path.join( - databases_directory_path, dataset_database.vector_database_name - ) - else: - vector_db_url = dataset_database.vector_database_url - # Set vector and graph database configuration based on dataset database information vector_config = { "vector_db_provider": dataset_database.vector_database_provider, - "vector_db_url": vector_db_url, + "vector_db_url": dataset_database.vector_database_url, "vector_db_key": dataset_database.vector_database_key, "vector_db_name": dataset_database.vector_database_name, } diff --git a/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py b/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py index 311f89ad7..0df3502ba 100644 --- a/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +++ b/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py @@ -1,10 +1,12 @@ +import os from uuid import UUID from typing import Union from sqlalchemy import select from sqlalchemy.exc import IntegrityError -from cognee.modules.data.methods import create_dataset +from cognee.base_config import get_base_config +from cognee.modules.data.methods import create_dataset from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.vector import get_vectordb_config from cognee.infrastructure.databases.graph.config import get_graph_config @@ -36,6 +38,7 @@ async def get_or_create_dataset_database( vector_config = get_vectordb_config() graph_config = get_graph_config() + # Note: for hybrid databases both graph and vector DB name have to be the same if graph_config.graph_database_provider == "kuzu": graph_db_name = f"{dataset_id}.pkl" @@ -47,6 +50,19 @@ async def get_or_create_dataset_database( else: vector_db_name = dataset_id + base_config = get_base_config() + databases_directory_path = os.path.join( + base_config.system_root_directory, "databases", str(user.id) + ) + + # Determine vector database URL + if vector_config.vector_db_provider == "lancedb": + vector_db_url = os.path.join(databases_directory_path, vector_config.vector_db_name) + else: + vector_db_url = vector_config.vector_database_url + + # Determine graph database URL + async with db_engine.get_async_session() as session: # Create dataset if it doesn't exist if isinstance(dataset, str): @@ -69,7 +85,7 @@ async def get_or_create_dataset_database( graph_database_name=graph_db_name, vector_database_provider=vector_config.vector_db_provider, graph_database_provider=graph_config.graph_database_provider, - vector_database_url=vector_config.vector_db_url, + vector_database_url=vector_db_url, graph_database_url=graph_config.graph_database_url, vector_database_key=vector_config.vector_db_key, graph_database_key=graph_config.graph_database_key,