feat(kuzu): refactor database opening and migration logic

- Consolidate database opening and migration into a single method `_open_or_migrate`.
- Automatically handle version mismatches and perform migrations as needed.
- Simplify the logic for pushing migrated databases back to S3.
This commit is contained in:
Timmy 2025-08-11 18:54:26 +01:00
parent 805d147266
commit 6170d8972a

View file

@ -54,38 +54,10 @@ class KuzuAdapter(GraphDBInterface):
run_sync(self.pull_from_s3())
# Try to open; if it fails due to version mismatch, migrate the temp copy and push back
try:
self.db = Database(
self.temp_graph_file,
buffer_pool_size=2048 * 1024 * 1024, # 2048MB buffer pool
max_db_size=4096 * 1024 * 1024,
)
except RuntimeError:
import kuzu
from .kuzu_migrate import kuzu_migration, read_kuzu_storage_version
kuzu_db_version = read_kuzu_storage_version(self.temp_graph_file)
if (
kuzu_db_version == "0.9.0" or kuzu_db_version == "0.8.2"
) and kuzu_db_version != str(kuzu.__version__): # ensure string comparison
kuzu_migration(
new_db=self.temp_graph_file + "_new",
old_db=self.temp_graph_file,
new_version=str(kuzu.__version__), # pass str to satisfy types
old_version=kuzu_db_version,
overwrite=True,
)
# Push migrated DB back to S3
run_sync(self.push_to_s3())
# Retry opening after potential migration
self.db = Database(
self.temp_graph_file,
buffer_pool_size=2048 * 1024 * 1024,
max_db_size=4096 * 1024 * 1024,
)
# Open DB; on version mismatch auto-migrate and then push back to S3
self.db, migrated = self._open_or_migrate(self.temp_graph_file)
if migrated:
run_sync(self.push_to_s3())
else:
# Ensure the parent directory exists before creating the database
db_dir = os.path.dirname(self.db_path)
@ -101,37 +73,8 @@ class KuzuAdapter(GraphDBInterface):
run_sync(file_storage.ensure_directory_exists())
try:
self.db = Database(
self.db_path,
buffer_pool_size=2048 * 1024 * 1024, # 2048MB buffer pool
max_db_size=4096 * 1024 * 1024,
)
except RuntimeError:
import kuzu
from .kuzu_migrate import read_kuzu_storage_version
kuzu_db_version = read_kuzu_storage_version(self.db_path)
if (
kuzu_db_version == "0.9.0" or kuzu_db_version == "0.8.2"
) and kuzu_db_version != str(kuzu.__version__):
# Try to migrate kuzu database to latest version
from .kuzu_migrate import kuzu_migration
kuzu_migration(
new_db=self.db_path + "_new",
old_db=self.db_path,
new_version=str(kuzu.__version__),
old_version=kuzu_db_version,
overwrite=True,
)
self.db = Database(
self.db_path,
buffer_pool_size=2048 * 1024 * 1024, # 2048MB buffer pool
max_db_size=4096 * 1024 * 1024,
)
# Open DB; on version mismatch auto-migrate and then retry
self.db, _ = self._open_or_migrate(self.db_path)
self.db.init_database()
self.connection = Connection(self.db)
@ -161,6 +104,45 @@ class KuzuAdapter(GraphDBInterface):
logger.error(f"Failed to initialize Kuzu database: {e}")
raise e
def _open_or_migrate(self, path: str) -> Tuple[Database, bool]:
"""
Try to open the Kuzu database at path. If it fails due to a version mismatch,
detect the on-disk version and migrate in-place to the current installed Kuzu
version. Returns the opened Database instance and a flag indicating whether a
migration was performed.
"""
did_migrate = False
try:
db = Database(
path,
buffer_pool_size=2048 * 1024 * 1024, # 2048MB buffer pool
max_db_size=4096 * 1024 * 1024,
)
return db, did_migrate
except RuntimeError:
import kuzu
from .kuzu_migrate import kuzu_migration, read_kuzu_storage_version
kuzu_db_version = read_kuzu_storage_version(path)
# Only migrate known legacy versions and when different from the installed one
if kuzu_db_version in ("0.9.0", "0.8.2") and kuzu_db_version != str(kuzu.__version__):
kuzu_migration(
new_db=path + "_new",
old_db=path,
new_version=str(kuzu.__version__),
old_version=kuzu_db_version,
overwrite=True,
)
did_migrate = True
# Retry opening after potential migration (or re-attempt if other transient issue)
db = Database(
path,
buffer_pool_size=2048 * 1024 * 1024, # 2048MB buffer pool
max_db_size=4096 * 1024 * 1024,
)
return db, did_migrate
async def push_to_s3(self) -> None:
if os.getenv("STORAGE_BACKEND", "").lower() == "s3" and hasattr(self, "temp_graph_file"):
from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage