diff --git a/cognee/infrastructure/databases/graph/get_graph_engine.py b/cognee/infrastructure/databases/graph/get_graph_engine.py
index 70c27aab3..82e3cad6e 100644
--- a/cognee/infrastructure/databases/graph/get_graph_engine.py
+++ b/cognee/infrastructure/databases/graph/get_graph_engine.py
@@ -70,7 +70,7 @@ def create_graph_engine(
             graph_database_url=graph_database_url,
             graph_database_username=graph_database_username,
             graph_database_password=graph_database_password,
-            graph_name=graph_database_name,
+            database_name=graph_database_name,
         )
 
     if graph_database_provider == "neo4j":
diff --git a/cognee/infrastructure/databases/utils/constants.py b/cognee/infrastructure/databases/utils/constants.py
new file mode 100644
index 000000000..fe6390a07
--- /dev/null
+++ b/cognee/infrastructure/databases/utils/constants.py
@@ -0,0 +1,4 @@
+VECTOR_DBS_WITH_MULTI_USER_SUPPORT = ["lancedb", "falkor"]
+GRAPH_DBS_WITH_MULTI_USER_SUPPORT = ["kuzu", "falkor"]
+
+HYBRID_DBS = ["falkor"]
diff --git a/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py b/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py
index 1552a7bbc..deea46541 100644
--- a/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py
+++ b/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py
@@ -11,17 +11,11 @@ from cognee.infrastructure.databases.graph.config import get_graph_config
 from cognee.modules.data.methods import get_unique_dataset_id
 from cognee.modules.users.models import DatasetDatabase
 from cognee.modules.users.models import User
-
-
-# TODO: Find a better place to define these
-default_vector_db_provider = "lancedb"
-default_graph_db_provider = "kuzu"
-default_vector_db_url = None
-default_graph_db_url = None
-default_vector_db_key = None
-default_graph_db_key = None
-vector_dbs_with_multi_user_support = ["lancedb", "falkor"]
-graph_dbs_with_multi_user_support = ["kuzu", "falkor"]
+from .constants import (
+    GRAPH_DBS_WITH_MULTI_USER_SUPPORT,
+    VECTOR_DBS_WITH_MULTI_USER_SUPPORT,
+    HYBRID_DBS,
+)
 
 
 async def get_or_create_dataset_database(
@@ -45,12 +39,19 @@ async def get_or_create_dataset_database(
 
     dataset_id = await get_unique_dataset_id(dataset, user)
 
-    vector_db_name = f"{dataset_id}.db"
-    graph_db_name = f"{dataset_id}.pkl"
-
     vector_config = get_vectordb_config()
     graph_config = get_graph_config()
 
+    graph_db_name = f"{dataset_id}.pkl"
+
+    if graph_config.graph_database_provider in HYBRID_DBS:
+        vector_db_name = graph_db_name
+    else:
+        if vector_config.vector_db_provider == "lancedb":
+            vector_db_name = f"{dataset_id}.lance.db"
+        else:
+            vector_db_name = f"{dataset_id}.db"
+
     async with db_engine.get_async_session() as session:
         # Create dataset if it doesn't exist
         if isinstance(dataset, str):
@@ -66,23 +67,18 @@ async def get_or_create_dataset_database(
             return existing
 
         # Check if we support multi-user for this provider. If not, use default
-        if graph_config.graph_database_provider in graph_dbs_with_multi_user_support:
-            graph_provider = graph_config.graph_database_provider
-            graph_url = graph_config.graph_database_url
-            graph_key = graph_config.graph_database_key
-        else:
-            graph_provider = default_graph_db_provider
-            graph_url = default_graph_db_url
-            graph_key = default_graph_db_key
-
-        if vector_config.vector_db_provider in vector_dbs_with_multi_user_support:
-            vector_provider = vector_config.vector_db_provider
-            vector_url = vector_config.vector_db_url
-            vector_key = vector_config.vector_db_key
-        else:
-            vector_provider = default_vector_db_provider
-            vector_url = default_vector_db_url
-            vector_key = default_vector_db_key
+        if graph_config.graph_database_provider not in GRAPH_DBS_WITH_MULTI_USER_SUPPORT:
+            raise EnvironmentError(
+                f"Multi-user is currently not supported for the graph database provider: {graph_config.graph_database_provider}. "
+                f"Supported providers are: {', '.join(GRAPH_DBS_WITH_MULTI_USER_SUPPORT)}. Either use one of these "
+                f"providers, or disable BACKEND_ACCESS_CONTROL"
+            )
+        if vector_config.vector_db_provider not in VECTOR_DBS_WITH_MULTI_USER_SUPPORT:
+            raise EnvironmentError(
+                f"Multi-user is currently not supported for the vector database provider: {vector_config.vector_db_provider}. "
+                f"Supported providers are: {', '.join(VECTOR_DBS_WITH_MULTI_USER_SUPPORT)}. Either use one of these "
+                f"providers, or disable BACKEND_ACCESS_CONTROL"
+            )
 
         # If there are no existing rows build a new row
         record = DatasetDatabase(
@@ -90,12 +86,12 @@ async def get_or_create_dataset_database(
             dataset_id=dataset_id,
             vector_database_name=vector_db_name,
             graph_database_name=graph_db_name,
-            vector_database_provider=vector_provider,
-            graph_database_provider=graph_provider,
-            vector_database_url=vector_url,
-            graph_database_url=graph_url,
-            vector_database_key=vector_key,
-            graph_database_key=graph_key,
+            vector_database_provider=vector_config.vector_db_provider,
+            graph_database_provider=graph_config.graph_database_provider,
+            vector_database_url=vector_config.vector_db_url,
+            graph_database_url=graph_config.graph_database_url,
+            vector_database_key=vector_config.vector_db_key,
+            graph_database_key=graph_config.graph_database_key,
         )
 
         try:
diff --git a/cognee/infrastructure/databases/vector/config.py b/cognee/infrastructure/databases/vector/config.py
index b6d3ae644..7d28f1668 100644
--- a/cognee/infrastructure/databases/vector/config.py
+++ b/cognee/infrastructure/databases/vector/config.py
@@ -18,12 +18,14 @@ class VectorConfig(BaseSettings):
     Instance variables:
     - vector_db_url: The URL of the vector database.
     - vector_db_port: The port for the vector database.
+    - vector_db_name: The name of the vector database.
     - vector_db_key: The key for accessing the vector database.
     - vector_db_provider: The provider for the vector database.
     """
 
     vector_db_url: str = ""
     vector_db_port: int = 1234
+    vector_db_name: str = ""
     vector_db_key: str = ""
     vector_db_provider: str = "lancedb"
 
@@ -58,6 +60,7 @@ class VectorConfig(BaseSettings):
         return {
             "vector_db_url": self.vector_db_url,
             "vector_db_port": self.vector_db_port,
+            "vector_db_name": self.vector_db_name,
             "vector_db_key": self.vector_db_key,
             "vector_db_provider": self.vector_db_provider,
         }
diff --git a/cognee/infrastructure/databases/vector/create_vector_engine.py b/cognee/infrastructure/databases/vector/create_vector_engine.py
index 35bbc110a..3fe926978 100644
--- a/cognee/infrastructure/databases/vector/create_vector_engine.py
+++ b/cognee/infrastructure/databases/vector/create_vector_engine.py
@@ -9,6 +9,7 @@ from functools import lru_cache
 def create_vector_engine(
     vector_db_provider: str,
     vector_db_url: str,
+    vector_db_name: str,
     vector_db_port: str = "",
     vector_db_key: str = "",
 ):
@@ -28,6 +29,7 @@ def create_vector_engine(
         - vector_db_url (str): The URL for the vector database instance.
         - vector_db_port (str): The port for the vector database instance. Required for some
           providers.
+        - vector_db_name (str): The name of the vector database instance.
         - vector_db_key (str): The API key or access token for the vector database instance.
         - vector_db_provider (str): The name of the vector database provider to use (e.g.,
           'pgvector').
@@ -46,7 +48,7 @@ def create_vector_engine(
             url=vector_db_url,
             api_key=vector_db_key,
             embedding_engine=embedding_engine,
-            graph_name=get_graph_context_config()["graph_database_name"],
+            database_name=vector_db_name,
         )
 
     if vector_db_provider == "pgvector":