fix: Update vector db url properly
This commit is contained in:
parent
cfc131307f
commit
6a64023876
2 changed files with 20 additions and 12 deletions
|
|
@ -69,8 +69,6 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
base_config = get_base_config()
|
|
||||||
|
|
||||||
if not backend_access_control_enabled():
|
if not backend_access_control_enabled():
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -79,6 +77,7 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
||||||
# To ensure permissions are enforced properly all datasets will have their own databases
|
# To ensure permissions are enforced properly all datasets will have their own databases
|
||||||
dataset_database = await get_or_create_dataset_database(dataset, user)
|
dataset_database = await get_or_create_dataset_database(dataset, user)
|
||||||
|
|
||||||
|
base_config = get_base_config()
|
||||||
data_root_directory = os.path.join(
|
data_root_directory = os.path.join(
|
||||||
base_config.data_root_directory, str(user.tenant_id or user.id)
|
base_config.data_root_directory, str(user.tenant_id or user.id)
|
||||||
)
|
)
|
||||||
|
|
@ -86,17 +85,10 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
||||||
base_config.system_root_directory, "databases", str(user.id)
|
base_config.system_root_directory, "databases", str(user.id)
|
||||||
)
|
)
|
||||||
|
|
||||||
if dataset_database.vector_database_provider == "lancedb":
|
|
||||||
vector_db_url = os.path.join(
|
|
||||||
databases_directory_path, dataset_database.vector_database_name
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
vector_db_url = dataset_database.vector_database_url
|
|
||||||
|
|
||||||
# Set vector and graph database configuration based on dataset database information
|
# Set vector and graph database configuration based on dataset database information
|
||||||
vector_config = {
|
vector_config = {
|
||||||
"vector_db_provider": dataset_database.vector_database_provider,
|
"vector_db_provider": dataset_database.vector_database_provider,
|
||||||
"vector_db_url": vector_db_url,
|
"vector_db_url": dataset_database.vector_database_url,
|
||||||
"vector_db_key": dataset_database.vector_database_key,
|
"vector_db_key": dataset_database.vector_database_key,
|
||||||
"vector_db_name": dataset_database.vector_database_name,
|
"vector_db_name": dataset_database.vector_database_name,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,12 @@
|
||||||
|
import os
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from cognee.modules.data.methods import create_dataset
|
|
||||||
|
|
||||||
|
from cognee.base_config import get_base_config
|
||||||
|
from cognee.modules.data.methods import create_dataset
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
from cognee.infrastructure.databases.vector import get_vectordb_config
|
from cognee.infrastructure.databases.vector import get_vectordb_config
|
||||||
from cognee.infrastructure.databases.graph.config import get_graph_config
|
from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||||
|
|
@ -36,6 +38,7 @@ async def get_or_create_dataset_database(
|
||||||
|
|
||||||
vector_config = get_vectordb_config()
|
vector_config = get_vectordb_config()
|
||||||
graph_config = get_graph_config()
|
graph_config = get_graph_config()
|
||||||
|
|
||||||
# Note: for hybrid databases both graph and vector DB name have to be the same
|
# Note: for hybrid databases both graph and vector DB name have to be the same
|
||||||
if graph_config.graph_database_provider == "kuzu":
|
if graph_config.graph_database_provider == "kuzu":
|
||||||
graph_db_name = f"{dataset_id}.pkl"
|
graph_db_name = f"{dataset_id}.pkl"
|
||||||
|
|
@ -47,6 +50,19 @@ async def get_or_create_dataset_database(
|
||||||
else:
|
else:
|
||||||
vector_db_name = dataset_id
|
vector_db_name = dataset_id
|
||||||
|
|
||||||
|
base_config = get_base_config()
|
||||||
|
databases_directory_path = os.path.join(
|
||||||
|
base_config.system_root_directory, "databases", str(user.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine vector database URL
|
||||||
|
if vector_config.vector_db_provider == "lancedb":
|
||||||
|
vector_db_url = os.path.join(databases_directory_path, vector_config.vector_db_name)
|
||||||
|
else:
|
||||||
|
vector_db_url = vector_config.vector_database_url
|
||||||
|
|
||||||
|
# Determine graph database URL
|
||||||
|
|
||||||
async with db_engine.get_async_session() as session:
|
async with db_engine.get_async_session() as session:
|
||||||
# Create dataset if it doesn't exist
|
# Create dataset if it doesn't exist
|
||||||
if isinstance(dataset, str):
|
if isinstance(dataset, str):
|
||||||
|
|
@ -69,7 +85,7 @@ async def get_or_create_dataset_database(
|
||||||
graph_database_name=graph_db_name,
|
graph_database_name=graph_db_name,
|
||||||
vector_database_provider=vector_config.vector_db_provider,
|
vector_database_provider=vector_config.vector_db_provider,
|
||||||
graph_database_provider=graph_config.graph_database_provider,
|
graph_database_provider=graph_config.graph_database_provider,
|
||||||
vector_database_url=vector_config.vector_db_url,
|
vector_database_url=vector_db_url,
|
||||||
graph_database_url=graph_config.graph_database_url,
|
graph_database_url=graph_config.graph_database_url,
|
||||||
vector_database_key=vector_config.vector_db_key,
|
vector_database_key=vector_config.vector_db_key,
|
||||||
graph_database_key=graph_config.graph_database_key,
|
graph_database_key=graph_config.graph_database_key,
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue