diff --git a/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py b/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py index e15a98b7c..25b94a724 100644 --- a/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py +++ b/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py @@ -49,6 +49,20 @@ def _recreate_table_without_unique_constraint_sqlite(op, insp): sa.Column("graph_database_name", sa.String(), nullable=False), sa.Column("vector_database_provider", sa.String(), nullable=False), sa.Column("graph_database_provider", sa.String(), nullable=False), + sa.Column( + "vector_dataset_database_handler", + sa.String(), + unique=False, + nullable=False, + server_default="lancedb", + ), + sa.Column( + "graph_dataset_database_handler", + sa.String(), + unique=False, + nullable=False, + server_default="kuzu", + ), sa.Column("vector_database_url", sa.String()), sa.Column("graph_database_url", sa.String()), sa.Column("vector_database_key", sa.String()), @@ -82,6 +96,8 @@ def _recreate_table_without_unique_constraint_sqlite(op, insp): graph_database_name, vector_database_provider, graph_database_provider, + vector_dataset_database_handler, + graph_dataset_database_handler, vector_database_url, graph_database_url, vector_database_key, @@ -120,6 +136,20 @@ def _recreate_table_with_unique_constraint_sqlite(op, insp): sa.Column("graph_database_name", sa.String(), nullable=False, unique=True), sa.Column("vector_database_provider", sa.String(), nullable=False), sa.Column("graph_database_provider", sa.String(), nullable=False), + sa.Column( + "vector_dataset_database_handler", + sa.String(), + unique=False, + nullable=False, + server_default="lancedb", + ), + sa.Column( + "graph_dataset_database_handler", + sa.String(), + unique=False, + nullable=False, + server_default="kuzu", + ), sa.Column("vector_database_url", sa.String()), sa.Column("graph_database_url", sa.String()), sa.Column("vector_database_key", sa.String()), @@ -153,6 +183,8 @@ def _recreate_table_with_unique_constraint_sqlite(op, insp): graph_database_name, vector_database_provider, graph_database_provider, + vector_dataset_database_handler, + graph_dataset_database_handler, vector_database_url, graph_database_url, vector_database_key, @@ -193,6 +225,22 @@ def upgrade() -> None: ), ) + vector_dataset_database_handler = _get_column( + insp, "dataset_database", "vector_dataset_database_handler" + ) + if not vector_dataset_database_handler: + # Add LanceDB as the default graph dataset database handler + op.add_column( + "dataset_database", + sa.Column( + "vector_dataset_database_handler", + sa.String(), + unique=False, + nullable=False, + server_default="lancedb", + ), + ) + graph_database_connection_info_column = _get_column( insp, "dataset_database", "graph_database_connection_info" ) @@ -208,6 +256,22 @@ def upgrade() -> None: ), ) + graph_dataset_database_handler = _get_column( + insp, "dataset_database", "graph_dataset_database_handler" + ) + if not graph_dataset_database_handler: + # Add Kuzu as the default graph dataset database handler + op.add_column( + "dataset_database", + sa.Column( + "graph_dataset_database_handler", + sa.String(), + unique=False, + nullable=False, + server_default="kuzu", + ), + ) + with op.batch_alter_table("dataset_database", schema=None) as batch_op: # Drop the unique constraint to make unique=False graph_constraint_to_drop = None @@ -265,3 +329,5 @@ def downgrade() -> None: op.drop_column("dataset_database", "vector_database_connection_info") op.drop_column("dataset_database", "graph_database_connection_info") + op.drop_column("dataset_database", "vector_dataset_database_handler") + op.drop_column("dataset_database", "graph_dataset_database_handler") diff --git a/cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py b/cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py index edc6d5c39..61ff84870 100644 --- a/cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py +++ b/cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py @@ -47,6 +47,7 @@ class KuzuDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): "graph_database_url": graph_db_url, "graph_database_provider": graph_config.graph_database_provider, "graph_database_key": graph_db_key, + "graph_dataset_database_handler": "kuzu", "graph_database_connection_info": { "graph_database_username": graph_db_username, "graph_database_password": graph_db_password, diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py b/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py index 73f057fa8..eb6cbc55a 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py @@ -131,6 +131,7 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): "graph_database_url": graph_db_url, "graph_database_provider": "neo4j", "graph_database_key": graph_db_key, + "graph_dataset_database_handler": "neo4j_aura_dev", "graph_database_connection_info": { "graph_database_username": graph_db_username, "graph_database_password": encrypted_db_password_string, diff --git a/cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py b/cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py index 4d8c19403..d33169642 100644 --- a/cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py +++ b/cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py @@ -1,27 +1,21 @@ -from cognee.infrastructure.databases.vector import get_vectordb_config -from cognee.infrastructure.databases.graph.config import get_graph_config from cognee.modules.users.models.DatasetDatabase import DatasetDatabase async def _get_vector_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase: - vector_config = get_vectordb_config() - from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( supported_dataset_database_handlers, ) - handler = supported_dataset_database_handlers[vector_config.vector_dataset_database_handler] + handler = supported_dataset_database_handlers[dataset_database.vector_dataset_database_handler] return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database) async def _get_graph_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase: - graph_config = get_graph_config() - from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( supported_dataset_database_handlers, ) - handler = supported_dataset_database_handlers[graph_config.graph_dataset_database_handler] + handler = supported_dataset_database_handlers[dataset_database.graph_dataset_database_handler] return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database) diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py index f165a7ea4..e392b7eb8 100644 --- a/cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py +++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py @@ -36,6 +36,7 @@ class LanceDBDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): "vector_database_url": os.path.join(databases_directory_path, vector_db_name), "vector_database_key": vector_config.vector_db_key, "vector_database_name": vector_db_name, + "vector_dataset_database_handler": "lancedb", } @classmethod diff --git a/cognee/modules/data/deletion/prune_system.py b/cognee/modules/data/deletion/prune_system.py index b43cab1f7..645e1a223 100644 --- a/cognee/modules/data/deletion/prune_system.py +++ b/cognee/modules/data/deletion/prune_system.py @@ -5,8 +5,6 @@ from cognee.context_global_variables import backend_access_control_enabled from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine from cognee.infrastructure.databases.relational import get_relational_engine -from cognee.infrastructure.databases.vector.config import get_vectordb_config -from cognee.infrastructure.databases.graph.config import get_graph_config from cognee.shared.cache import delete_cache from cognee.modules.users.models import DatasetDatabase from cognee.shared.logging_utils import get_logger @@ -16,12 +14,13 @@ logger = get_logger() async def prune_graph_databases(): async def _prune_graph_db(dataset_database: DatasetDatabase) -> dict: - graph_config = get_graph_config() from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( supported_dataset_database_handlers, ) - handler = supported_dataset_database_handlers[graph_config.graph_dataset_database_handler] + handler = supported_dataset_database_handlers[ + dataset_database.graph_dataset_database_handler + ] return await handler["handler_instance"].delete_dataset(dataset_database) db_engine = get_relational_engine() @@ -40,13 +39,13 @@ async def prune_graph_databases(): async def prune_vector_databases(): async def _prune_vector_db(dataset_database: DatasetDatabase) -> dict: - vector_config = get_vectordb_config() - from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( supported_dataset_database_handlers, ) - handler = supported_dataset_database_handlers[vector_config.vector_dataset_database_handler] + handler = supported_dataset_database_handlers[ + dataset_database.vector_dataset_database_handler + ] return await handler["handler_instance"].delete_dataset(dataset_database) db_engine = get_relational_engine() diff --git a/cognee/modules/users/models/DatasetDatabase.py b/cognee/modules/users/models/DatasetDatabase.py index 15964f032..08c4b5311 100644 --- a/cognee/modules/users/models/DatasetDatabase.py +++ b/cognee/modules/users/models/DatasetDatabase.py @@ -18,6 +18,9 @@ class DatasetDatabase(Base): vector_database_provider = Column(String, unique=False, nullable=False) graph_database_provider = Column(String, unique=False, nullable=False) + graph_dataset_database_handler = Column(String, unique=False, nullable=False) + vector_dataset_database_handler = Column(String, unique=False, nullable=False) + vector_database_url = Column(String, unique=False, nullable=True) graph_database_url = Column(String, unique=False, nullable=True) diff --git a/cognee/tests/test_dataset_database_handler.py b/cognee/tests/test_dataset_database_handler.py index be1b249d2..e4c9b0177 100644 --- a/cognee/tests/test_dataset_database_handler.py +++ b/cognee/tests/test_dataset_database_handler.py @@ -30,6 +30,7 @@ class LanceDBTestDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): vector_db_name = "test.lance.db" return { + "vector_dataset_database_handler": "custom_lancedb_handler", "vector_database_name": vector_db_name, "vector_database_url": os.path.join(databases_directory_path, vector_db_name), "vector_database_provider": "lancedb", @@ -44,6 +45,7 @@ class KuzuTestDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): graph_db_name = "test.kuzu" return { + "graph_dataset_database_handler": "custom_kuzu_handler", "graph_database_name": graph_db_name, "graph_database_url": os.path.join(databases_directory_path, graph_db_name), "graph_database_provider": "kuzu",