diff --git a/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py b/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py index 0a2638dc5..27c0d62a3 100644 --- a/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +++ b/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py @@ -39,6 +39,30 @@ async def get_or_create_dataset_database( vector_config = get_vectordb_config() graph_config = get_graph_config() + # Note: for hybrid databases both graph and vector DB name have to be the same + if graph_config.graph_database_provider == "kuzu": + graph_db_name = f"{dataset_id}.pkl" + else: + graph_db_name = f"{dataset_id}" + + if vector_config.vector_db_provider == "lancedb": + vector_db_name = f"{dataset_id}.lance.db" + else: + vector_db_name = f"{dataset_id}" + + base_config = get_base_config() + databases_directory_path = os.path.join( + base_config.system_root_directory, "databases", str(user.id) + ) + + # Determine vector database URL + if vector_config.vector_db_provider == "lancedb": + vector_db_url = os.path.join(databases_directory_path, vector_config.vector_db_name) + else: + vector_db_url = vector_config.vector_database_url + + # Determine graph database URL + async with db_engine.get_async_session() as session: # Create dataset if it doesn't exist if isinstance(dataset, str): diff --git a/cognee/modules/users/models/DatasetDatabase.py b/cognee/modules/users/models/DatasetDatabase.py index 5d2e4fcd5..f4b7c2aed 100644 --- a/cognee/modules/users/models/DatasetDatabase.py +++ b/cognee/modules/users/models/DatasetDatabase.py @@ -27,5 +27,14 @@ class DatasetDatabase(Base): graph_database_username = Column(String, unique=False, nullable=True) graph_database_password = Column(String, unique=False, nullable=True) + vector_database_provider = Column(String, unique=False, nullable=False) + graph_database_provider = Column(String, unique=False, nullable=False) + + vector_database_url = Column(String, unique=False, nullable=True) + graph_database_url = Column(String, unique=False, nullable=True) + + vector_database_key = Column(String, unique=False, nullable=True) + graph_database_key = Column(String, unique=False, nullable=True) + created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))