Compare commits


10 commits

Author SHA1 Message Date
Daulet Amirkhanov
284a6d1cd6 lint: ruff format 2025-08-19 13:07:10 +01:00
Daulet Amirkhanov
4b78866403 fix: test modifies global sys.modules, affecting other tests 2025-08-19 13:06:40 +01:00
Daulet Amirkhanov
e8b7fe7d0e lint: ruff format fix 2025-08-19 11:59:44 +01:00
Daulet Amirkhanov
9d41bc9573 Merge branch 'dev' into feature/kuzu-s3-migration 2025-08-19 11:53:49 +01:00
Daulet Amirkhanov
50a9fb91f2 Apply suggestion from @daukadolt 2025-08-19 11:47:49 +01:00
Daulet Amirkhanov
4cbaa6502d Apply suggestion from @daukadolt 2025-08-19 11:47:40 +01:00
Daulet Amirkhanov
330c6fe2ed Apply suggestion from @daukadolt 2025-08-19 11:47:30 +01:00
Daulet Amirkhanov
e23175507e Apply suggestion from @daukadolt 2025-08-19 11:47:20 +01:00
Timmy
6170d8972a feat(kuzu): refactor database opening and migration logic
- Consolidate database opening and migration into a single method `_open_or_migrate`.
- Automatically handle version mismatches and perform migrations as needed.
- Simplify the logic for pushing migrated databases back to S3.
2025-08-11 18:54:26 +01:00
Timmy
805d147266 feat(kuzu): enable S3-aware Kuzu migration and auto-migrate in adapter
- Add S3 staging, version read from S3 catalog.kz, and S3 rename (copy+rm)
- Integrate auto-migration in KuzuAdapter for S3/local paths
- Add unit tests for S3 migration helpers and adapter auto-migration
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin
Signed-off-by: Timmy <timilehinoluwaseyi07@gmail.com>
2025-08-10 18:39:20 +01:00
4 changed files with 630 additions and 91 deletions

cognee/infrastructure/databases/graph/kuzu/adapter.py

@@ -1,26 +1,27 @@
"""Adapter for Kuzu graph database."""
import os
import json
import asyncio
import json
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from uuid import UUID
from kuzu import Connection
from kuzu.database import Database
from datetime import datetime, timezone
from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any, List, Union, Optional, Tuple, Type
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.utils.run_sync import run_sync
from cognee.infrastructure.files.storage import get_file_storage
from cognee.infrastructure.databases.graph.graph_db_interface import (
GraphDBInterface,
record_graph_changes,
)
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.files.storage import get_file_storage
from cognee.infrastructure.utils.run_sync import run_sync
from cognee.modules.storage.utils import JSONEncoder
from cognee.shared.logging_utils import get_logger
logger = get_logger()
@@ -48,16 +49,16 @@ class KuzuAdapter(GraphDBInterface):
"""Initialize the Kuzu database connection and schema."""
try:
if "s3://" in self.db_path:
# Stage S3 DB to a local temp file
with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
self.temp_graph_file = temp_file.name
run_sync(self.pull_from_s3())
self.db = Database(
self.temp_graph_file,
buffer_pool_size=2048 * 1024 * 1024, # 2048MB buffer pool
max_db_size=4096 * 1024 * 1024,
)
# Open DB; on version mismatch auto-migrate and then push back to S3
self.db, migrated = self._open_or_migrate(self.temp_graph_file)
if migrated:
run_sync(self.push_to_s3())
else:
# Ensure the parent directory exists before creating the database
db_dir = os.path.dirname(self.db_path)
@@ -73,36 +74,8 @@ class KuzuAdapter(GraphDBInterface):
run_sync(file_storage.ensure_directory_exists())
try:
self.db = Database(
self.db_path,
buffer_pool_size=2048 * 1024 * 1024, # 2048MB buffer pool
max_db_size=4096 * 1024 * 1024,
)
except RuntimeError:
from .kuzu_migrate import read_kuzu_storage_version
import kuzu
kuzu_db_version = read_kuzu_storage_version(self.db_path)
if (
kuzu_db_version == "0.9.0" or kuzu_db_version == "0.8.2"
) and kuzu_db_version != kuzu.__version__:
# Try to migrate kuzu database to latest version
from .kuzu_migrate import kuzu_migration
kuzu_migration(
new_db=self.db_path + "_new",
old_db=self.db_path,
new_version=kuzu.__version__,
old_version=kuzu_db_version,
overwrite=True,
)
self.db = Database(
self.db_path,
buffer_pool_size=2048 * 1024 * 1024, # 2048MB buffer pool
max_db_size=4096 * 1024 * 1024,
)
# Open DB; on version mismatch auto-migrate and then retry
self.db, _ = self._open_or_migrate(self.db_path)
self.db.init_database()
self.connection = Connection(self.db)
@@ -132,6 +105,45 @@ class KuzuAdapter(GraphDBInterface):
logger.error(f"Failed to initialize Kuzu database: {e}")
raise e
def _open_or_migrate(self, path: str) -> Tuple[Database, bool]:
"""
Try to open the Kuzu database at path. If it fails due to a version mismatch,
detect the on-disk version and migrate in-place to the current installed Kuzu
version. Returns the opened Database instance and a flag indicating whether a
migration was performed.
"""
did_migrate = False
try:
db = Database(
path,
buffer_pool_size=2048 * 1024 * 1024, # 2048MB buffer pool
max_db_size=4096 * 1024 * 1024,
)
return db, did_migrate
except RuntimeError:
import kuzu
from .kuzu_migrate import kuzu_migration, read_kuzu_storage_version
kuzu_db_version = read_kuzu_storage_version(path)
# Only migrate known legacy versions and when different from the installed one
if kuzu_db_version in ("0.9.0", "0.8.2") and kuzu_db_version != str(kuzu.__version__):
kuzu_migration(
new_db=path + "_new",
old_db=path,
new_version=str(kuzu.__version__),
old_version=kuzu_db_version,
overwrite=True,
)
did_migrate = True
# Retry opening after potential migration (or re-attempt if other transient issue)
db = Database(
path,
buffer_pool_size=2048 * 1024 * 1024, # 2048MB buffer pool
max_db_size=4096 * 1024 * 1024,
)
return db, did_migrate
async def push_to_s3(self) -> None:
if os.getenv("STORAGE_BACKEND", "").lower() == "s3" and hasattr(self, "temp_graph_file"):
from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage
@@ -217,7 +229,8 @@ class KuzuAdapter(GraphDBInterface):
"""Convert a raw node result (with JSON properties) into a dictionary."""
if data.get("properties"):
try:
props = json.loads(data["properties"])
# Parse JSON properties into a dict
props: Dict[str, Any] = json.loads(data["properties"])
# Remove the JSON field and merge its contents
data.pop("properties")
data.update(props)

cognee/infrastructure/databases/graph/kuzu/kuzu_migrate.py

@@ -27,13 +27,28 @@ Notes:
- Can only be used to migrate to newer Kuzu versions, from 0.11.0 onwards
"""
import tempfile
import sys
import struct
import shutil
import subprocess
import argparse
import os
import shutil
import struct
import subprocess
import sys
import tempfile
from typing import Any, Optional
# Lazy-import s3fs via our storage adapter only when needed, so local-only runs
# don't require S3 credentials or dependencies at import time.
def _is_s3_path(path: str) -> bool:
return path.startswith("s3://")
def _get_s3_client() -> Any: # Returns configured s3fs client via project storage adapter
from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage
storage: Any = S3FileStorage("")
client: Any = storage.s3 # type: ignore[attr-defined]
return client
kuzu_version_mapping = {
@@ -46,30 +61,48 @@ kuzu_version_mapping = {
}
def read_kuzu_storage_version(kuzu_db_path: str) -> int:
def read_kuzu_storage_version(kuzu_db_path: str) -> str:
"""
Reads the Kùzu storage version code from the first catalog.bin file bytes.
Read the Kuzu storage version from the first bytes of catalog.kz and map it
to a human-readable Kuzu semantic version string (e.g. "0.9.0").
:param kuzu_db_path: Path to the Kuzu database file/directory.
:return: Storage version code as an integer.
:param kuzu_db_path: Path/URI (local or s3://) to the Kuzu database file/directory.
:return: Semantic version string (e.g. "0.9.0").
"""
if os.path.isdir(kuzu_db_path):
kuzu_version_file_path = os.path.join(kuzu_db_path, "catalog.kz")
if not os.path.isfile(kuzu_version_file_path):
raise FileExistsError("Kuzu catalog.kz file does not exist")
if _is_s3_path(kuzu_db_path):
s3 = _get_s3_client()
# Determine whether the remote path is a directory or file
version_key = kuzu_db_path
try:
if s3.isdir(kuzu_db_path):
version_key = kuzu_db_path.rstrip("/") + "/catalog.kz"
# Open directly from S3 without downloading the entire DB
with s3.open(version_key, "rb") as f:
f.seek(4)
data = f.read(8)
except FileNotFoundError:
raise FileExistsError("Kuzu catalog.kz file does not exist on S3")
else:
kuzu_version_file_path = kuzu_db_path
if os.path.isdir(kuzu_db_path):
kuzu_version_file_path = os.path.join(kuzu_db_path, "catalog.kz")
if not os.path.isfile(kuzu_version_file_path):
raise FileExistsError("Kuzu catalog.kz file does not exist")
else:
kuzu_version_file_path = kuzu_db_path
with open(kuzu_version_file_path, "rb") as f:
# Skip the 3-byte magic "KUZ" and one byte of padding
f.seek(4)
# Read the next 8 bytes as a little-endian unsigned 64-bit integer
data = f.read(8)
if len(data) < 8:
raise ValueError(
f"File '{kuzu_version_file_path}' does not contain a storage version code."
)
version_code = struct.unpack("<Q", data)[0]
with open(kuzu_version_file_path, "rb") as f:
# Skip the 3-byte magic "KUZ" and one byte of padding
f.seek(4)
# Read the next 8 bytes as a little-endian unsigned 64-bit integer
data = f.read(8)
if len(data) < 8:
raise ValueError(
f"File '{kuzu_version_file_path}' does not contain a storage version code."
)
if len(data) < 8:
raise ValueError("catalog.kz does not contain a storage version code.")
version_code = struct.unpack("<Q", data)[0]
if kuzu_version_mapping.get(version_code):
return kuzu_version_mapping[version_code]
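As a side note, here is a minimal sketch of the catalog.kz header layout this helper parses; the magic/padding split follows the in-code comments, and version code 39 mapping to "0.11.0" is taken from the tests further down.

import struct

# Fabricated example bytes: 3-byte magic "KUZ" plus one padding byte, followed by
# the storage version code as a little-endian unsigned 64-bit integer.
header = b"KUZ\x00" + struct.pack("<Q", 39)

# read_kuzu_storage_version skips the first 4 bytes and unpacks the next 8.
version_code = struct.unpack("<Q", header[4:12])[0]
assert version_code == 39  # per the tests, code 39 maps to Kuzu "0.11.0"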
@@ -77,7 +110,7 @@ def read_kuzu_storage_version(kuzu_db_path: str) -> int:
raise ValueError("Could not map version_code to proper Kuzu version.")
def ensure_env(version: str, export_dir) -> str:
def ensure_env(version: str, export_dir: str) -> str:
"""
Create (if needed) a venv at .kuzu_envs/{version} and install kuzu=={version}.
Returns the path to the venv's python executable.
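The body of ensure_env sits outside this hunk; as a rough illustration of what the docstring describes (assuming stdlib venv plus pip, not the actual implementation):

import os
import subprocess
import sys

def ensure_env_sketch(version: str, export_dir: str) -> str:
    """Illustrative only: venv at .kuzu_envs/{version} with kuzu=={version} installed."""
    env_dir = os.path.join(".kuzu_envs", version)
    python_bin = os.path.join(env_dir, "bin", "python")  # assumes a POSIX venv layout
    if not os.path.exists(python_bin):
        subprocess.check_call([sys.executable, "-m", "venv", env_dir])
        subprocess.check_call([python_bin, "-m", "pip", "install", f"kuzu=={version}"])
    # export_dir is accepted for signature parity; how the real function uses it is not shown here.
    return python_bin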
@@ -119,7 +152,14 @@ conn.execute(r\"\"\"{cypher}\"\"\")
sys.exit(proc.returncode)
def kuzu_migration(new_db, old_db, new_version, old_version=None, overwrite=None, delete_old=None):
def kuzu_migration(
new_db: str,
old_db: str,
new_version: str,
old_version: Optional[str] = None,
overwrite: Optional[bool] = None,
delete_old: Optional[bool] = None,
) -> None:
"""
Main migration function that handles the complete migration process.
"""
@@ -131,23 +171,52 @@ def kuzu_migration(new_db, old_db, new_version, old_version=None, overwrite=None
if not old_version:
old_version = read_kuzu_storage_version(old_db)
# Check if old database exists
if not os.path.exists(old_db):
print(f"Source database '{old_db}' does not exist.", file=sys.stderr)
sys.exit(1)
# Check if old database exists (local or S3)
if _is_s3_path(old_db):
s3 = _get_s3_client()
if not (s3.exists(old_db) or s3.exists(old_db.rstrip("/") + "/")):
print(f"Source database '{old_db}' does not exist.", file=sys.stderr)
sys.exit(1)
else:
if not os.path.exists(old_db):
print(f"Source database '{old_db}' does not exist.", file=sys.stderr)
sys.exit(1)
# Prepare target - ensure parent directory exists but remove target if it exists
parent_dir = os.path.dirname(new_db)
if parent_dir:
os.makedirs(parent_dir, exist_ok=True)
if os.path.exists(new_db):
raise FileExistsError(
"File already exists at new database location, remove file or change new database file path to continue"
)
if _is_s3_path(new_db):
# For S3 we don't create directories locally; just ensure the key doesn't already exist
s3 = _get_s3_client()
if s3.exists(new_db) or s3.exists(new_db.rstrip("/") + "/"):
raise FileExistsError(
"File already exists at new database location on S3; remove it or change new database path to continue"
)
else:
if parent_dir:
os.makedirs(parent_dir, exist_ok=True)
if os.path.exists(new_db):
raise FileExistsError(
"File already exists at new database location, remove file or change new database file path to continue"
)
# Use temp directory for all processing, it will be cleaned up after with statement
with tempfile.TemporaryDirectory() as export_dir:
is_old_s3 = _is_s3_path(old_db)
is_new_s3 = _is_s3_path(new_db)
# If old DB is on S3, download it locally first.
local_old_db = old_db
local_new_db = new_db
if is_old_s3:
s3 = _get_s3_client()
local_old_db = os.path.join(export_dir, "old_kuzu_db")
# Download either a file or a directory recursively
print(f"⬇️ Downloading old DB from S3 → {local_old_db}", file=sys.stderr)
s3.get(old_db, local_old_db, recursive=True)
if is_new_s3:
# Always stage new DB locally, then upload after migration
local_new_db = os.path.join(export_dir, "new_kuzu_db")
# Set up environments
print(f"Setting up Kuzu {old_version} environment...", file=sys.stderr)
old_py = ensure_env(old_version, export_dir)
@@ -156,7 +225,7 @@ def kuzu_migration(new_db, old_db, new_version, old_version=None, overwrite=None
export_file = os.path.join(export_dir, "kuzu_export")
print(f"Exporting old DB → {export_dir}", file=sys.stderr)
run_migration_step(old_py, old_db, f"EXPORT DATABASE '{export_file}'")
run_migration_step(old_py, local_old_db, f"EXPORT DATABASE '{export_file}'")
print("Export complete.", file=sys.stderr)
# Check if export files were created and have content
@@ -164,17 +233,36 @@ def kuzu_migration(new_db, old_db, new_version, old_version=None, overwrite=None
if not os.path.exists(schema_file) or os.path.getsize(schema_file) == 0:
raise ValueError(f"Schema file not found: {schema_file}")
print(f"Importing into new DB at {new_db}", file=sys.stderr)
run_migration_step(new_py, new_db, f"IMPORT DATABASE '{export_file}'")
print(f"Importing into new DB at {local_new_db}", file=sys.stderr)
run_migration_step(new_py, local_new_db, f"IMPORT DATABASE '{export_file}'")
print("Import complete.", file=sys.stderr)
# Rename new kuzu database to old kuzu database name if enabled
# If the target is S3, upload the migrated DB now
if is_new_s3:
# Remove kuzu lock from migrated DB before upload if present
lock_file = local_new_db + ".lock"
if os.path.exists(lock_file):
os.remove(lock_file)
print(f"⬆️ Uploading new DB to S3: {new_db}", file=sys.stderr)
s3 = _get_s3_client()
s3.put(local_new_db, new_db, recursive=True)
# Normalize flags
overwrite = bool(overwrite)
delete_old = bool(delete_old)
# Rename/move results into place if requested
if overwrite or delete_old:
# Remove kuzu lock from migrated DB
lock_file = new_db + ".lock"
if os.path.exists(lock_file):
os.remove(lock_file)
rename_databases(old_db, old_version, new_db, delete_old)
if _is_s3_path(new_db) or _is_s3_path(old_db):
# S3-aware rename
_s3_rename_databases(old_db, old_version, new_db, delete_old)
else:
# Remove kuzu lock from migrated DB
lock_file = new_db + ".lock"
if os.path.exists(lock_file):
os.remove(lock_file)
rename_databases(old_db, old_version, new_db, delete_old)
print("✅ Kuzu graph database migration finished successfully!")
@@ -224,6 +312,57 @@ def rename_databases(old_db: str, old_version: str, new_db: str, delete_old: boo
print(f"Renamed '{src_new}' to '{dst_new}'", file=sys.stderr)
def _s3_rename_databases(old_db: str, old_version: str, new_db: str, delete_old: bool):
"""
Perform S3-equivalent of rename_databases: optionally back up the original old_db
to *_old_<version>, replace it with the new_db contents, and clean up.
This function handles both file-based and directory-based Kuzu databases by using
recursive copy and remove operations provided by s3fs.
"""
s3 = _get_s3_client()
# Normalize paths (keep s3:// URIs as they are; s3fs supports them)
def _isdir(p: str) -> bool:
try:
return s3.isdir(p)
except FileNotFoundError:
return False
def _isfile(p: str) -> bool:
try:
return s3.isfile(p)
except FileNotFoundError:
return False
base_dir = os.path.dirname(old_db.rstrip("/"))
name = os.path.basename(old_db.rstrip("/"))
backup_database_name = f"{name}_old_" + old_version.replace(".", "_")
backup_base = base_dir + "/" + backup_database_name
# Back up or delete the original old_db
if _isfile(old_db):
if not delete_old:
s3.copy(old_db, backup_base, recursive=True)
print(f"Copied '{old_db}' to '{backup_base}' on S3", file=sys.stderr)
s3.rm(old_db, recursive=True)
elif _isdir(old_db):
if not delete_old:
s3.copy(old_db, backup_base, recursive=True)
print(f"Copied directory '{old_db}' to '{backup_base}' on S3", file=sys.stderr)
s3.rm(old_db, recursive=True)
else:
print(f"Original database path '{old_db}' not found on S3 for renaming.", file=sys.stderr)
sys.exit(1)
# Move new into place under the old name
target_path = base_dir + "/" + name
s3.copy(new_db, target_path, recursive=True)
print(f"Copied '{new_db}' to '{target_path}' on S3", file=sys.stderr)
# Remove the staging 'new_db' key
s3.rm(new_db, recursive=True)
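For reference, a small invocation sketch that mirrors the file-based test case further down; the keys are placeholders.

# Hypothetical keys for illustration; see test_s3_rename_file_backup below.
_s3_rename_databases(
    old_db="s3://example-bucket/graph.db",      # ends up holding the migrated DB
    old_version="0.9.0",                        # backup key becomes graph.db_old_0_9_0
    new_db="s3://example-bucket/graph_new.db",  # staging key, removed afterwards
    delete_old=False,                           # keep the *_old_* backup copy
)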
def main():
p = argparse.ArgumentParser(
description="Migrate Kùzu DB via PyPI versions",

New test file (path not shown): KuzuAdapter S3 auto-migration test

@@ -0,0 +1,192 @@
import importlib.util
import os
import sys
import types
from types import ModuleType
import pytest
class _DBOpenError(RuntimeError):
pass
class _FakeDatabase:
"""Fake kuzu.Database that fails first, then succeeds."""
calls = 0
def __init__(self, path: str, **kwargs):
_FakeDatabase.calls += 1
if _FakeDatabase.calls == 1:
raise _DBOpenError("version mismatch")
def init_database(self):
pass
class _FakeConnection:
def __init__(self, db):
pass
def execute(self, query: str, params=None):
class _Res:
def has_next(self):
return False
def get_next(self):
return []
return _Res()
@pytest.fixture
def stub_import(monkeypatch):
def _install_stub(name: str, module: ModuleType | None = None) -> ModuleType:
mod = module or ModuleType(name)
# Mark as package so submodule imports succeed when needed
if not hasattr(mod, "__path__"):
mod.__path__ = [] # type: ignore[attr-defined]
monkeypatch.setitem(sys.modules, name, mod)
return mod
return _install_stub
def _find_repo_root(start_path: str) -> str:
"""Walk up directories until we find pyproject.toml (repo root)."""
cur = os.path.abspath(start_path)
while True:
if os.path.exists(os.path.join(cur, "pyproject.toml")):
return cur
parent = os.path.dirname(cur)
if parent == cur:
raise RuntimeError("Could not locate repository root from: " + start_path)
cur = parent
def _load_adapter_with_stubs(monkeypatch, stub_import):
# Provide fake 'kuzu' and submodules used by adapter imports
kuzu_mod = stub_import("kuzu")
kuzu_mod.__dict__["__version__"] = "0.11.0"
# Placeholders to satisfy adapter's "from kuzu import Connection" and "from kuzu.database import Database"
class _PlaceholderConn:
pass
kuzu_mod.Connection = _PlaceholderConn
kuzu_db_mod = stub_import("kuzu.database")
class _PlaceholderDB:
pass
kuzu_db_mod.Database = _PlaceholderDB
# Create minimal stub tree for required cognee imports to avoid executing package __init__
stub_import("cognee")
stub_import("cognee.infrastructure")
stub_import("cognee.infrastructure.databases")
stub_import("cognee.infrastructure.databases.graph")
stub_import("cognee.infrastructure.databases.graph.kuzu")
# graph_db_interface stub
gdi_mod = stub_import("cognee.infrastructure.databases.graph.graph_db_interface")
class _GraphDBInterface: # bare minimum
pass
def record_graph_changes(fn):
return fn
gdi_mod.GraphDBInterface = _GraphDBInterface
gdi_mod.record_graph_changes = record_graph_changes
# engine.DataPoint stub
engine_mod = stub_import("cognee.infrastructure.engine")
class _DataPoint:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
engine_mod.DataPoint = _DataPoint
# files.storage.get_file_storage stub
storage_pkg = stub_import("cognee.infrastructure.files.storage")
storage_pkg.get_file_storage = lambda path: types.SimpleNamespace(
ensure_directory_exists=lambda: None
)
# utils.run_sync stub
run_sync_mod = stub_import("cognee.infrastructure.utils.run_sync")
run_sync_mod.run_sync = lambda coro: None
# modules.storage.utils JSONEncoder stub
utils_mod2 = stub_import("cognee.modules.storage.utils")
utils_mod2.JSONEncoder = object
# shared.logging_utils.get_logger stub
logging_utils_mod = stub_import("cognee.shared.logging_utils")
class _Logger:
def debug(self, *a, **k):
pass
def error(self, *a, **k):
pass
logging_utils_mod.get_logger = lambda: _Logger()
# Now load adapter.py by path
repo_root = _find_repo_root(os.path.dirname(__file__))
adapter_path = os.path.join(
repo_root, "cognee", "infrastructure", "databases", "graph", "kuzu", "adapter.py"
)
spec = importlib.util.spec_from_file_location(
"cognee.infrastructure.databases.graph.kuzu.adapter", adapter_path
)
assert spec and spec.loader
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod) # type: ignore[attr-defined]
# Replace Database/Connection in the loaded module
monkeypatch.setattr(mod, "Database", _FakeDatabase, raising=True)
monkeypatch.setattr(mod, "Connection", _FakeConnection, raising=True)
# Patch migration helpers inside the kuzu_migrate module used by adapter
# Load kuzu_migrate similarly
km_path = os.path.join(
repo_root, "cognee", "infrastructure", "databases", "graph", "kuzu", "kuzu_migrate.py"
)
km_spec = importlib.util.spec_from_file_location("kuzu_migrate_under_test", km_path)
km_mod = importlib.util.module_from_spec(km_spec)
assert km_spec and km_spec.loader
km_spec.loader.exec_module(km_mod) # type: ignore[attr-defined]
calls = {"migrated": False}
def fake_read_version(_):
return "0.9.0"
def fake_migration(**kwargs):
calls["migrated"] = True
monkeypatch.setattr(km_mod, "read_kuzu_storage_version", fake_read_version)
monkeypatch.setattr(km_mod, "kuzu_migration", fake_migration)
# Ensure adapter refers to our loaded km_mod
monkeypatch.setitem(
sys.modules, "cognee.infrastructure.databases.graph.kuzu.kuzu_migrate", km_mod
)
return mod, calls
def test_adapter_s3_auto_migration(monkeypatch, stub_import):
mod, calls = _load_adapter_with_stubs(monkeypatch, stub_import)
# ensure pull/push do not touch real S3
monkeypatch.setattr(mod.KuzuAdapter, "pull_from_s3", lambda self: None)
monkeypatch.setattr(mod.KuzuAdapter, "push_to_s3", lambda self: None)
mod.KuzuAdapter("s3://bucket/db")
assert calls["migrated"] is True

New test file (path not shown): kuzu_migrate S3 helper tests

@@ -0,0 +1,195 @@
import io
import os
import struct
import importlib.util
from types import ModuleType
from typing import Dict
import pytest
class _FakeS3:
"""
Minimal fake S3 client implementing the subset used by kuzu_migrate helpers.
Store layout is a dict mapping string keys to bytes (files). Directories are
implicit via key prefixes. Methods operate on s3:// style keys.
"""
def __init__(self, initial: Dict[str, bytes] | None = None):
self.store: Dict[str, bytes] = dict(initial or {})
# Helpers
def _norm(self, path: str) -> str:
return path.rstrip("/")
def _is_prefix(self, prefix: str, key: str) -> bool:
p = self._norm(prefix)
return key == p or key.startswith(p + "/")
# API used by kuzu_migrate
def exists(self, path: str) -> bool:
p = self._norm(path)
if p in self.store:
return True
# any key under this prefix implies existence as a directory
return any(self._is_prefix(p, k) for k in self.store)
def isdir(self, path: str) -> bool:
p = self._norm(path)
# A directory is assumed if there is any key with this prefix and that key isn't exactly the same
return any(self._is_prefix(p, k) and k != p for k in self.store)
def isfile(self, path: str) -> bool:
p = self._norm(path)
return p in self.store
def open(self, path: str, mode: str = "rb"):
p = self._norm(path)
if "r" in mode:
if p not in self.store:
raise FileNotFoundError(p)
return io.BytesIO(self.store[p])
elif "w" in mode:
buf = io.BytesIO()
def _close():
self.store[p] = buf.getvalue()
# monkeypatch close so that written data is persisted on close
orig_close = buf.close
def close_wrapper():
_close()
orig_close()
buf.close = close_wrapper # type: ignore[assignment]
return buf
else:
raise ValueError(f"Unsupported mode: {mode}")
def copy(self, src: str, dst: str, recursive: bool = True):
s = self._norm(src)
d = self._norm(dst)
if recursive:
# copy all keys under src prefix to dst prefix
to_copy = [k for k in self.store if self._is_prefix(s, k)]
for key in to_copy:
new_key = key.replace(s, d, 1)
self.store[new_key] = self.store[key]
else:
if s not in self.store:
raise FileNotFoundError(s)
self.store[d] = self.store[s]
def rm(self, path: str, recursive: bool = False):
p = self._norm(path)
if recursive:
for key in list(self.store.keys()):
if self._is_prefix(p, key):
del self.store[key]
else:
if p in self.store:
del self.store[p]
else:
raise FileNotFoundError(p)
def _load_module_by_path(path: str) -> ModuleType:
spec = importlib.util.spec_from_file_location("kuzu_migrate_under_test", path)
assert spec and spec.loader
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod) # type: ignore[attr-defined]
return mod
def _find_repo_root(start_path: str) -> str:
cur = os.path.abspath(start_path)
while True:
if os.path.exists(os.path.join(cur, "pyproject.toml")):
return cur
parent = os.path.dirname(cur)
if parent == cur:
raise RuntimeError("Could not locate repository root from: " + start_path)
cur = parent
@pytest.fixture
def km_module(monkeypatch):
# Load the kuzu_migrate module directly from file to avoid importing package __init__
repo_root = _find_repo_root(os.path.dirname(__file__))
target = os.path.join(
repo_root, "cognee", "infrastructure", "databases", "graph", "kuzu", "kuzu_migrate.py"
)
mod = _load_module_by_path(target)
return mod
@pytest.fixture
def patch_get_s3_client(monkeypatch, km_module):
# Provide each test with its own fake client instance
client = _FakeS3()
monkeypatch.setattr(km_module, "_get_s3_client", lambda: client)
return client
def _make_catalog_bytes(version_code: int) -> bytes:
# 4 bytes header skipped + 8 bytes little-endian version code
return b"KUZ\x00" + struct.pack("<Q", version_code) + b"padding"
def test_read_kuzu_storage_version_from_s3_directory(monkeypatch, patch_get_s3_client, km_module):
s3 = patch_get_s3_client
# Simulate a directory with catalog.kz
dir_key = "s3://bucket/db"
catalog_key = dir_key + "/catalog.kz"
s3.store[catalog_key] = _make_catalog_bytes(39) # maps to 0.11.0
assert km_module.read_kuzu_storage_version(dir_key) == "0.11.0"
def test_s3_rename_file_backup(monkeypatch, patch_get_s3_client, km_module):
s3 = patch_get_s3_client
old_db = "s3://bucket/graph.db"
new_db = "s3://bucket/graph_new.db"
# seed store
s3.store[old_db] = b"OLD"
s3.store[new_db] = b"NEW"
km_module._s3_rename_databases(old_db, "0.9.0", new_db, delete_old=False)
# old is replaced with new
assert s3.store.get(old_db) == b"NEW"
# backup exists with version suffix
backup = "s3://bucket/graph.db_old_0_9_0"
assert s3.store.get(backup) == b"OLD"
# staging removed
assert new_db not in s3.store
def test_s3_rename_directory_delete_old(monkeypatch, patch_get_s3_client, km_module):
s3 = patch_get_s3_client
old_dir = "s3://bucket/graph_dir"
new_dir = "s3://bucket/graph_dir_new"
# Represent a directory by multiple keys under the prefix
s3.store[old_dir + "/catalog.kz"] = b"OLD1"
s3.store[old_dir + "/data.bin"] = b"OLD2"
s3.store[new_dir + "/catalog.kz"] = b"NEW1"
s3.store[new_dir + "/data.bin"] = b"NEW2"
km_module._s3_rename_databases(old_dir, "0.9.0", new_dir, delete_old=True)
# old dir contents replaced by new
assert s3.store.get(old_dir + "/catalog.kz") == b"NEW1"
assert s3.store.get(old_dir + "/data.bin") == b"NEW2"
# no backup created when delete_old=True
backup_prefix = os.path.dirname(old_dir) + "/" + os.path.basename(old_dir) + "_old_0_9_0"
assert not any(k.startswith(backup_prefix) for k in s3.store)
# staging removed
assert not any(k.startswith(new_dir) for k in s3.store)