Compare commits: main...pr-1230-en (10 commits)
| SHA1 |
|---|
| 284a6d1cd6 |
| 4b78866403 |
| e8b7fe7d0e |
| 9d41bc9573 |
| 50a9fb91f2 |
| 4cbaa6502d |
| 330c6fe2ed |
| e23175507e |
| 6170d8972a |
| 805d147266 |
4 changed files with 630 additions and 91 deletions
cognee/infrastructure/databases/graph/kuzu/adapter.py

@@ -1,26 +1,27 @@
"""Adapter for Kuzu graph database."""

import os
import json
import asyncio
import json
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from uuid import UUID

from kuzu import Connection
from kuzu.database import Database
from datetime import datetime, timezone
from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any, List, Union, Optional, Tuple, Type

from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.utils.run_sync import run_sync
from cognee.infrastructure.files.storage import get_file_storage
from cognee.infrastructure.databases.graph.graph_db_interface import (
    GraphDBInterface,
    record_graph_changes,
)
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.files.storage import get_file_storage
from cognee.infrastructure.utils.run_sync import run_sync
from cognee.modules.storage.utils import JSONEncoder
from cognee.shared.logging_utils import get_logger

logger = get_logger()

@@ -48,16 +49,16 @@ class KuzuAdapter(GraphDBInterface):
        """Initialize the Kuzu database connection and schema."""
        try:
            if "s3://" in self.db_path:
                # Stage S3 DB to a local temp file
                with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
                    self.temp_graph_file = temp_file.name

                run_sync(self.pull_from_s3())

                self.db = Database(
                    self.temp_graph_file,
                    buffer_pool_size=2048 * 1024 * 1024,  # 2048MB buffer pool
                    max_db_size=4096 * 1024 * 1024,
                )
                # Open DB; on version mismatch auto-migrate and then push back to S3
                self.db, migrated = self._open_or_migrate(self.temp_graph_file)
                if migrated:
                    run_sync(self.push_to_s3())
            else:
                # Ensure the parent directory exists before creating the database
                db_dir = os.path.dirname(self.db_path)

@@ -73,36 +74,8 @@ class KuzuAdapter(GraphDBInterface):

                run_sync(file_storage.ensure_directory_exists())

                try:
                    self.db = Database(
                        self.db_path,
                        buffer_pool_size=2048 * 1024 * 1024,  # 2048MB buffer pool
                        max_db_size=4096 * 1024 * 1024,
                    )
                except RuntimeError:
                    from .kuzu_migrate import read_kuzu_storage_version
                    import kuzu

                    kuzu_db_version = read_kuzu_storage_version(self.db_path)
                    if (
                        kuzu_db_version == "0.9.0" or kuzu_db_version == "0.8.2"
                    ) and kuzu_db_version != kuzu.__version__:
                        # Try to migrate kuzu database to latest version
                        from .kuzu_migrate import kuzu_migration

                        kuzu_migration(
                            new_db=self.db_path + "_new",
                            old_db=self.db_path,
                            new_version=kuzu.__version__,
                            old_version=kuzu_db_version,
                            overwrite=True,
                        )

                    self.db = Database(
                        self.db_path,
                        buffer_pool_size=2048 * 1024 * 1024,  # 2048MB buffer pool
                        max_db_size=4096 * 1024 * 1024,
                    )
                # Open DB; on version mismatch auto-migrate and then retry
                self.db, _ = self._open_or_migrate(self.db_path)

            self.db.init_database()
            self.connection = Connection(self.db)

@@ -132,6 +105,45 @@ class KuzuAdapter(GraphDBInterface):
            logger.error(f"Failed to initialize Kuzu database: {e}")
            raise e

    def _open_or_migrate(self, path: str) -> Tuple[Database, bool]:
        """
        Try to open the Kuzu database at path. If it fails due to a version mismatch,
        detect the on-disk version and migrate in-place to the current installed Kuzu
        version. Returns the opened Database instance and a flag indicating whether a
        migration was performed.
        """
        did_migrate = False
        try:
            db = Database(
                path,
                buffer_pool_size=2048 * 1024 * 1024,  # 2048MB buffer pool
                max_db_size=4096 * 1024 * 1024,
            )
            return db, did_migrate
        except RuntimeError:
            import kuzu
            from .kuzu_migrate import kuzu_migration, read_kuzu_storage_version

            kuzu_db_version = read_kuzu_storage_version(path)
            # Only migrate known legacy versions and when different from the installed one
            if kuzu_db_version in ("0.9.0", "0.8.2") and kuzu_db_version != str(kuzu.__version__):
                kuzu_migration(
                    new_db=path + "_new",
                    old_db=path,
                    new_version=str(kuzu.__version__),
                    old_version=kuzu_db_version,
                    overwrite=True,
                )
                did_migrate = True

            # Retry opening after potential migration (or re-attempt if other transient issue)
            db = Database(
                path,
                buffer_pool_size=2048 * 1024 * 1024,  # 2048MB buffer pool
                max_db_size=4096 * 1024 * 1024,
            )
            return db, did_migrate

    async def push_to_s3(self) -> None:
        if os.getenv("STORAGE_BACKEND", "").lower() == "s3" and hasattr(self, "temp_graph_file"):
            from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage

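As context for the two call sites above: the S3 branch stages the database to a temp file, opens or migrates it via `_open_or_migrate`, and pushes the result back only when a migration actually happened, while `push_to_s3` itself is a no-op unless `STORAGE_BACKEND` is set to `s3` and a temp file was staged. A minimal sketch of driving that path (the bucket and key are made up; the import path and single-path constructor mirror the new test file at the bottom of this diff):

```python
# Illustrative only: exercise the adapter's S3 init path.
# "my-bucket/cognee/graph_db" is a hypothetical location.
import os

from cognee.infrastructure.databases.graph.kuzu.adapter import KuzuAdapter

os.environ["STORAGE_BACKEND"] = "s3"  # required for push_to_s3 to do anything
adapter = KuzuAdapter("s3://my-bucket/cognee/graph_db")
# Construction pulls the DB from S3, opens it (migrating if the on-disk version
# is 0.9.0 or 0.8.2), and pushes the migrated copy back to S3 when needed.
```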
@@ -217,7 +229,8 @@ class KuzuAdapter(GraphDBInterface):
        """Convert a raw node result (with JSON properties) into a dictionary."""
        if data.get("properties"):
            try:
                props = json.loads(data["properties"])
                # Parse JSON properties into a dict
                props: Dict[str, Any] = json.loads(data["properties"])
                # Remove the JSON field and merge its contents
                data.pop("properties")
                data.update(props)

cognee/infrastructure/databases/graph/kuzu/kuzu_migrate.py

@@ -27,13 +27,28 @@ Notes:
- Can only be used to migrate to newer Kuzu versions, from 0.11.0 onwards
"""

import tempfile
import sys
import struct
import shutil
import subprocess
import argparse
import os
import shutil
import struct
import subprocess
import sys
import tempfile
from typing import Any, Optional


# Lazy-import s3fs via our storage adapter only when needed, so local-only runs
# don't require S3 credentials or dependencies at import time.
def _is_s3_path(path: str) -> bool:
    return path.startswith("s3://")


def _get_s3_client() -> Any:  # Returns configured s3fs client via project storage adapter
    from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage

    storage: Any = S3FileStorage("")
    client: Any = storage.s3  # type: ignore[attr-defined]
    return client


kuzu_version_mapping = {

@@ -46,30 +61,48 @@ kuzu_version_mapping = {
}


def read_kuzu_storage_version(kuzu_db_path: str) -> int:
def read_kuzu_storage_version(kuzu_db_path: str) -> str:
    """
    Reads the Kùzu storage version code from the first catalog.bin file bytes.
    Read the Kuzu storage version from the first bytes of catalog.kz and map it
    to a human-readable Kuzu semantic version string (e.g. "0.9.0").

    :param kuzu_db_path: Path to the Kuzu database file/directory.
    :return: Storage version code as an integer.
    :param kuzu_db_path: Path/URI (local or s3://) to the Kuzu database file/directory.
    :return: Semantic version string (e.g. "0.9.0").
    """
    if os.path.isdir(kuzu_db_path):
        kuzu_version_file_path = os.path.join(kuzu_db_path, "catalog.kz")
        if not os.path.isfile(kuzu_version_file_path):
            raise FileExistsError("Kuzu catalog.kz file does not exist")
    if _is_s3_path(kuzu_db_path):
        s3 = _get_s3_client()
        # Determine whether the remote path is a directory or file
        version_key = kuzu_db_path
        try:
            if s3.isdir(kuzu_db_path):
                version_key = kuzu_db_path.rstrip("/") + "/catalog.kz"
            # Open directly from S3 without downloading the entire DB
            with s3.open(version_key, "rb") as f:
                f.seek(4)
                data = f.read(8)
        except FileNotFoundError:
            raise FileExistsError("Kuzu catalog.kz file does not exist on S3")
    else:
        kuzu_version_file_path = kuzu_db_path
        if os.path.isdir(kuzu_db_path):
            kuzu_version_file_path = os.path.join(kuzu_db_path, "catalog.kz")
            if not os.path.isfile(kuzu_version_file_path):
                raise FileExistsError("Kuzu catalog.kz file does not exist")
        else:
            kuzu_version_file_path = kuzu_db_path

    with open(kuzu_version_file_path, "rb") as f:
        # Skip the 3-byte magic "KUZ" and one byte of padding
        f.seek(4)
        # Read the next 8 bytes as a little-endian unsigned 64-bit integer
        data = f.read(8)
        if len(data) < 8:
            raise ValueError(
                f"File '{kuzu_version_file_path}' does not contain a storage version code."
            )
        version_code = struct.unpack("<Q", data)[0]
        with open(kuzu_version_file_path, "rb") as f:
            # Skip the 3-byte magic "KUZ" and one byte of padding
            f.seek(4)
            # Read the next 8 bytes as a little-endian unsigned 64-bit integer
            data = f.read(8)
            if len(data) < 8:
                raise ValueError(
                    f"File '{kuzu_version_file_path}' does not contain a storage version code."
                )

    if len(data) < 8:
        raise ValueError("catalog.kz does not contain a storage version code.")
    version_code = struct.unpack("<Q", data)[0]

    if kuzu_version_mapping.get(version_code):
        return kuzu_version_mapping[version_code]

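For reference, the on-disk layout this reader depends on is small: `catalog.kz` begins with the 3-byte `KUZ` magic plus one padding byte, followed by the storage version code as a little-endian unsigned 64-bit integer. A self-contained round-trip sketch of that layout (the mapping of code 39 to "0.11.0" is taken from the tests further down; nothing else here is from the PR):

```python
# Round-trip sketch of the catalog.kz header layout described above.
import struct


def make_catalog_header(version_code: int) -> bytes:
    # 3-byte "KUZ" magic + 1 padding byte, then the version code as <Q
    return b"KUZ\x00" + struct.pack("<Q", version_code)


def read_version_code(blob: bytes) -> int:
    # Skip the 4 header bytes, decode the next 8 as a little-endian uint64
    return struct.unpack("<Q", blob[4:12])[0]


assert read_version_code(make_catalog_header(39)) == 39  # 39 maps to "0.11.0" per the test fixture
```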
@@ -77,7 +110,7 @@ def read_kuzu_storage_version(kuzu_db_path: str) -> int:
    raise ValueError("Could not map version_code to proper Kuzu version.")


def ensure_env(version: str, export_dir) -> str:
def ensure_env(version: str, export_dir: str) -> str:
    """
    Create (if needed) a venv at .kuzu_envs/{version} and install kuzu=={version}.
    Returns the path to the venv's python executable.

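The docstring above is the whole contract: create a venv pinned to the requested Kuzu version and hand back its Python interpreter. A rough, hedged sketch of what that could look like (the `.kuzu_envs/{version}` location comes from the docstring; placing it under `export_dir`, the venv/pip invocations, and the POSIX `bin/` layout are assumptions, not the PR's implementation):

```python
# Hedged sketch of a per-version environment helper; not the PR's code.
import os
import subprocess
import sys


def ensure_env_sketch(version: str, export_dir: str) -> str:
    """Create .kuzu_envs/{version} if missing, install kuzu==version, return its python."""
    env_dir = os.path.join(export_dir, ".kuzu_envs", version)  # location assumed from the docstring
    python_bin = os.path.join(env_dir, "bin", "python")  # "Scripts\\python.exe" on Windows
    if not os.path.exists(python_bin):
        subprocess.run([sys.executable, "-m", "venv", env_dir], check=True)
        subprocess.run([python_bin, "-m", "pip", "install", f"kuzu=={version}"], check=True)
    return python_bin
```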
@@ -119,7 +152,14 @@ conn.execute(r\"\"\"{cypher}\"\"\")
    sys.exit(proc.returncode)


def kuzu_migration(new_db, old_db, new_version, old_version=None, overwrite=None, delete_old=None):
def kuzu_migration(
    new_db: str,
    old_db: str,
    new_version: str,
    old_version: Optional[str] = None,
    overwrite: Optional[bool] = None,
    delete_old: Optional[bool] = None,
) -> None:
    """
    Main migration function that handles the complete migration process.
    """

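The typed signature mirrors how the adapter already calls this helper. A usage example consistent with the `_open_or_migrate` call earlier in this diff (paths are illustrative):

```python
# Illustrative call; "/data/graph_db" is a made-up path.
import kuzu

from cognee.infrastructure.databases.graph.kuzu.kuzu_migrate import kuzu_migration

kuzu_migration(
    new_db="/data/graph_db_new",        # staging location for the migrated copy
    old_db="/data/graph_db",            # existing DB detected as 0.9.0 or 0.8.2
    new_version=str(kuzu.__version__),  # target the installed Kuzu version
    old_version="0.9.0",                # optional; read from catalog.kz when omitted
    overwrite=True,                     # move the migrated DB into the old location
)
```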
@@ -131,23 +171,52 @@ def kuzu_migration(new_db, old_db, new_version, old_version=None, overwrite=None
    if not old_version:
        old_version = read_kuzu_storage_version(old_db)

    # Check if old database exists
    if not os.path.exists(old_db):
        print(f"Source database '{old_db}' does not exist.", file=sys.stderr)
        sys.exit(1)
    # Check if old database exists (local or S3)
    if _is_s3_path(old_db):
        s3 = _get_s3_client()
        if not (s3.exists(old_db) or s3.exists(old_db.rstrip("/") + "/")):
            print(f"Source database '{old_db}' does not exist.", file=sys.stderr)
            sys.exit(1)
    else:
        if not os.path.exists(old_db):
            print(f"Source database '{old_db}' does not exist.", file=sys.stderr)
            sys.exit(1)

    # Prepare target - ensure parent directory exists but remove target if it exists
    parent_dir = os.path.dirname(new_db)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    if os.path.exists(new_db):
        raise FileExistsError(
            "File already exists at new database location, remove file or change new database file path to continue"
        )
    if _is_s3_path(new_db):
        # For S3 we don't create directories locally; just ensure the key doesn't already exist
        s3 = _get_s3_client()
        if s3.exists(new_db) or s3.exists(new_db.rstrip("/") + "/"):
            raise FileExistsError(
                "File already exists at new database location on S3; remove it or change new database path to continue"
            )
    else:
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        if os.path.exists(new_db):
            raise FileExistsError(
                "File already exists at new database location, remove file or change new database file path to continue"
            )

    # Use temp directory for all processing, it will be cleaned up after with statement
    with tempfile.TemporaryDirectory() as export_dir:
        is_old_s3 = _is_s3_path(old_db)
        is_new_s3 = _is_s3_path(new_db)

        # If old DB is on S3, download it locally first.
        local_old_db = old_db
        local_new_db = new_db
        if is_old_s3:
            s3 = _get_s3_client()
            local_old_db = os.path.join(export_dir, "old_kuzu_db")
            # Download either a file or a directory recursively
            print(f"⬇️ Downloading old DB from S3 → {local_old_db}", file=sys.stderr)
            s3.get(old_db, local_old_db, recursive=True)

        if is_new_s3:
            # Always stage new DB locally, then upload after migration
            local_new_db = os.path.join(export_dir, "new_kuzu_db")
        # Set up environments
        print(f"Setting up Kuzu {old_version} environment...", file=sys.stderr)
        old_py = ensure_env(old_version, export_dir)

@@ -156,7 +225,7 @@ def kuzu_migration(new_db, old_db, new_version, old_version=None, overwrite=None

        export_file = os.path.join(export_dir, "kuzu_export")
        print(f"Exporting old DB → {export_dir}", file=sys.stderr)
        run_migration_step(old_py, old_db, f"EXPORT DATABASE '{export_file}'")
        run_migration_step(old_py, local_old_db, f"EXPORT DATABASE '{export_file}'")
        print("Export complete.", file=sys.stderr)

        # Check if export files were created and have content

@@ -164,17 +233,36 @@ def kuzu_migration(new_db, old_db, new_version, old_version=None, overwrite=None
        if not os.path.exists(schema_file) or os.path.getsize(schema_file) == 0:
            raise ValueError(f"Schema file not found: {schema_file}")

        print(f"Importing into new DB at {new_db}", file=sys.stderr)
        run_migration_step(new_py, new_db, f"IMPORT DATABASE '{export_file}'")
        print(f"Importing into new DB at {local_new_db}", file=sys.stderr)
        run_migration_step(new_py, local_new_db, f"IMPORT DATABASE '{export_file}'")
        print("Import complete.", file=sys.stderr)

        # Rename new kuzu database to old kuzu database name if enabled
        # If the target is S3, upload the migrated DB now
        if is_new_s3:
            # Remove kuzu lock from migrated DB before upload if present
            lock_file = local_new_db + ".lock"
            if os.path.exists(lock_file):
                os.remove(lock_file)

            print(f"⬆️ Uploading new DB to S3: {new_db}", file=sys.stderr)
            s3 = _get_s3_client()
            s3.put(local_new_db, new_db, recursive=True)

        # Normalize flags
        overwrite = bool(overwrite)
        delete_old = bool(delete_old)

        # Rename/move results into place if requested
        if overwrite or delete_old:
            # Remove kuzu lock from migrated DB
            lock_file = new_db + ".lock"
            if os.path.exists(lock_file):
                os.remove(lock_file)
            rename_databases(old_db, old_version, new_db, delete_old)
            if _is_s3_path(new_db) or _is_s3_path(old_db):
                # S3-aware rename
                _s3_rename_databases(old_db, old_version, new_db, delete_old)
            else:
                # Remove kuzu lock from migrated DB
                lock_file = new_db + ".lock"
                if os.path.exists(lock_file):
                    os.remove(lock_file)
                rename_databases(old_db, old_version, new_db, delete_old)

    print("✅ Kuzu graph database migration finished successfully!")

@@ -224,6 +312,57 @@ def rename_databases(old_db: str, old_version: str, new_db: str, delete_old: boo
        print(f"Renamed '{src_new}' to '{dst_new}'", file=sys.stderr)


def _s3_rename_databases(old_db: str, old_version: str, new_db: str, delete_old: bool):
    """
    Perform S3-equivalent of rename_databases: optionally back up the original old_db
    to *_old_<version>, replace it with the new_db contents, and clean up.

    This function handles both file-based and directory-based Kuzu databases by using
    recursive copy and remove operations provided by s3fs.
    """
    s3 = _get_s3_client()

    # Normalize paths (keep s3:// URIs as they are; s3fs supports them)
    def _isdir(p: str) -> bool:
        try:
            return s3.isdir(p)
        except FileNotFoundError:
            return False

    def _isfile(p: str) -> bool:
        try:
            return s3.isfile(p)
        except FileNotFoundError:
            return False

    base_dir = os.path.dirname(old_db.rstrip("/"))
    name = os.path.basename(old_db.rstrip("/"))
    backup_database_name = f"{name}_old_" + old_version.replace(".", "_")
    backup_base = base_dir + "/" + backup_database_name

    # Back up or delete the original old_db
    if _isfile(old_db):
        if not delete_old:
            s3.copy(old_db, backup_base, recursive=True)
            print(f"Copied '{old_db}' to '{backup_base}' on S3", file=sys.stderr)
        s3.rm(old_db, recursive=True)
    elif _isdir(old_db):
        if not delete_old:
            s3.copy(old_db, backup_base, recursive=True)
            print(f"Copied directory '{old_db}' to '{backup_base}' on S3", file=sys.stderr)
        s3.rm(old_db, recursive=True)
    else:
        print(f"Original database path '{old_db}' not found on S3 for renaming.", file=sys.stderr)
        sys.exit(1)

    # Move new into place under the old name
    target_path = base_dir + "/" + name
    s3.copy(new_db, target_path, recursive=True)
    print(f"Copied '{new_db}' to '{target_path}' on S3", file=sys.stderr)
    # Remove the staging 'new_db' key
    s3.rm(new_db, recursive=True)


def main():
    p = argparse.ArgumentParser(
        description="Migrate Kùzu DB via PyPI versions",

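To make the backup convention of `_s3_rename_databases` above concrete: with `delete_old=False` the original object is copied to `<name>_old_<version with dots replaced by underscores>` before being replaced by the migrated copy, which is exactly what the S3 rename test below exercises. An illustrative call (bucket and keys are made up):

```python
# Illustrative only; the bucket/keys are hypothetical.
_s3_rename_databases(
    old_db="s3://bucket/graph.db",
    old_version="0.9.0",
    new_db="s3://bucket/graph_new.db",
    delete_old=False,
)
# Afterwards: s3://bucket/graph.db holds the migrated data,
# s3://bucket/graph.db_old_0_9_0 holds the pre-migration backup,
# and the staging key s3://bucket/graph_new.db is removed.
```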
New file (path not shown in the diff): tests covering the adapter's automatic Kuzu migration when the graph is stored on S3.

@@ -0,0 +1,192 @@
import importlib.util
import os
import sys
import types
from types import ModuleType

import pytest


class _DBOpenError(RuntimeError):
    pass


class _FakeDatabase:
    """Fake kuzu.Database that fails first, then succeeds."""

    calls = 0

    def __init__(self, path: str, **kwargs):
        _FakeDatabase.calls += 1
        if _FakeDatabase.calls == 1:
            raise _DBOpenError("version mismatch")

    def init_database(self):
        pass


class _FakeConnection:
    def __init__(self, db):
        pass

    def execute(self, query: str, params=None):
        class _Res:
            def has_next(self):
                return False

            def get_next(self):
                return []

        return _Res()


@pytest.fixture
def stub_import(monkeypatch):
    def _install_stub(name: str, module: ModuleType | None = None) -> ModuleType:
        mod = module or ModuleType(name)
        # Mark as package so submodule imports succeed when needed
        if not hasattr(mod, "__path__"):
            mod.__path__ = []  # type: ignore[attr-defined]
        monkeypatch.setitem(sys.modules, name, mod)
        return mod

    return _install_stub


def _find_repo_root(start_path: str) -> str:
    """Walk up directories until we find pyproject.toml (repo root)."""
    cur = os.path.abspath(start_path)
    while True:
        if os.path.exists(os.path.join(cur, "pyproject.toml")):
            return cur
        parent = os.path.dirname(cur)
        if parent == cur:
            raise RuntimeError("Could not locate repository root from: " + start_path)
        cur = parent


def _load_adapter_with_stubs(monkeypatch, stub_import):
    # Provide fake 'kuzu' and submodules used by adapter imports
    kuzu_mod = stub_import("kuzu")
    kuzu_mod.__dict__["__version__"] = "0.11.0"

    # Placeholders to satisfy adapter's "from kuzu import Connection" and "from kuzu.database import Database"
    class _PlaceholderConn:
        pass

    kuzu_mod.Connection = _PlaceholderConn
    kuzu_db_mod = stub_import("kuzu.database")

    class _PlaceholderDB:
        pass

    kuzu_db_mod.Database = _PlaceholderDB

    # Create minimal stub tree for required cognee imports to avoid executing package __init__
    stub_import("cognee")
    stub_import("cognee.infrastructure")
    stub_import("cognee.infrastructure.databases")
    stub_import("cognee.infrastructure.databases.graph")
    stub_import("cognee.infrastructure.databases.graph.kuzu")

    # graph_db_interface stub
    gdi_mod = stub_import("cognee.infrastructure.databases.graph.graph_db_interface")

    class _GraphDBInterface:  # bare minimum
        pass

    def record_graph_changes(fn):
        return fn

    gdi_mod.GraphDBInterface = _GraphDBInterface
    gdi_mod.record_graph_changes = record_graph_changes

    # engine.DataPoint stub
    engine_mod = stub_import("cognee.infrastructure.engine")

    class _DataPoint:
        def __init__(self, **kwargs):
            self.__dict__.update(kwargs)

    engine_mod.DataPoint = _DataPoint

    # files.storage.get_file_storage stub
    storage_pkg = stub_import("cognee.infrastructure.files.storage")
    storage_pkg.get_file_storage = lambda path: types.SimpleNamespace(
        ensure_directory_exists=lambda: None
    )

    # utils.run_sync stub
    run_sync_mod = stub_import("cognee.infrastructure.utils.run_sync")
    run_sync_mod.run_sync = lambda coro: None

    # modules.storage.utils JSONEncoder stub
    utils_mod2 = stub_import("cognee.modules.storage.utils")
    utils_mod2.JSONEncoder = object

    # shared.logging_utils.get_logger stub
    logging_utils_mod = stub_import("cognee.shared.logging_utils")

    class _Logger:
        def debug(self, *a, **k):
            pass

        def error(self, *a, **k):
            pass

    logging_utils_mod.get_logger = lambda: _Logger()

    # Now load adapter.py by path
    repo_root = _find_repo_root(os.path.dirname(__file__))
    adapter_path = os.path.join(
        repo_root, "cognee", "infrastructure", "databases", "graph", "kuzu", "adapter.py"
    )
    spec = importlib.util.spec_from_file_location(
        "cognee.infrastructure.databases.graph.kuzu.adapter", adapter_path
    )
    assert spec and spec.loader
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)  # type: ignore[attr-defined]

    # Replace Database/Connection in the loaded module
    monkeypatch.setattr(mod, "Database", _FakeDatabase, raising=True)
    monkeypatch.setattr(mod, "Connection", _FakeConnection, raising=True)

    # Patch migration helpers inside the kuzu_migrate module used by adapter
    # Load kuzu_migrate similarly
    km_path = os.path.join(
        repo_root, "cognee", "infrastructure", "databases", "graph", "kuzu", "kuzu_migrate.py"
    )
    km_spec = importlib.util.spec_from_file_location("kuzu_migrate_under_test", km_path)
    km_mod = importlib.util.module_from_spec(km_spec)
    assert km_spec and km_spec.loader
    km_spec.loader.exec_module(km_mod)  # type: ignore[attr-defined]

    calls = {"migrated": False}

    def fake_read_version(_):
        return "0.9.0"

    def fake_migration(**kwargs):
        calls["migrated"] = True

    monkeypatch.setattr(km_mod, "read_kuzu_storage_version", fake_read_version)
    monkeypatch.setattr(km_mod, "kuzu_migration", fake_migration)

    # Ensure adapter refers to our loaded km_mod
    monkeypatch.setitem(
        sys.modules, "cognee.infrastructure.databases.graph.kuzu.kuzu_migrate", km_mod
    )

    return mod, calls


def test_adapter_s3_auto_migration(monkeypatch, stub_import):
    mod, calls = _load_adapter_with_stubs(monkeypatch, stub_import)

    # ensure pull/push do not touch real S3
    monkeypatch.setattr(mod.KuzuAdapter, "pull_from_s3", lambda self: None)
    monkeypatch.setattr(mod.KuzuAdapter, "push_to_s3", lambda self: None)

    mod.KuzuAdapter("s3://bucket/db")
    assert calls["migrated"] is True

New file (path not shown in the diff): tests for the kuzu_migrate S3 helpers, using a fake in-memory S3 client.

@@ -0,0 +1,195 @@
import io
import os
import struct
import importlib.util
from types import ModuleType
from typing import Dict

import pytest


class _FakeS3:
    """
    Minimal fake S3 client implementing the subset used by kuzu_migrate helpers.

    Store layout is a dict mapping string keys to bytes (files). Directories are
    implicit via key prefixes. Methods operate on s3:// style keys.
    """

    def __init__(self, initial: Dict[str, bytes] | None = None):
        self.store: Dict[str, bytes] = dict(initial or {})

    # Helpers
    def _norm(self, path: str) -> str:
        return path.rstrip("/")

    def _is_prefix(self, prefix: str, key: str) -> bool:
        p = self._norm(prefix)
        return key == p or key.startswith(p + "/")

    # API used by kuzu_migrate
    def exists(self, path: str) -> bool:
        p = self._norm(path)
        if p in self.store:
            return True
        # any key under this prefix implies existence as a directory
        return any(self._is_prefix(p, k) for k in self.store)

    def isdir(self, path: str) -> bool:
        p = self._norm(path)
        # A directory is assumed if there is any key with this prefix and that key isn't exactly the same
        return any(self._is_prefix(p, k) and k != p for k in self.store)

    def isfile(self, path: str) -> bool:
        p = self._norm(path)
        return p in self.store

    def open(self, path: str, mode: str = "rb"):
        p = self._norm(path)
        if "r" in mode:
            if p not in self.store:
                raise FileNotFoundError(p)
            return io.BytesIO(self.store[p])
        elif "w" in mode:
            buf = io.BytesIO()

            def _close():
                self.store[p] = buf.getvalue()

            # monkeypatch close so that written data is persisted on close
            orig_close = buf.close

            def close_wrapper():
                _close()
                orig_close()

            buf.close = close_wrapper  # type: ignore[assignment]
            return buf
        else:
            raise ValueError(f"Unsupported mode: {mode}")

    def copy(self, src: str, dst: str, recursive: bool = True):
        s = self._norm(src)
        d = self._norm(dst)
        if recursive:
            # copy all keys under src prefix to dst prefix
            to_copy = [k for k in self.store if self._is_prefix(s, k)]
            for key in to_copy:
                new_key = key.replace(s, d, 1)
                self.store[new_key] = self.store[key]
        else:
            if s not in self.store:
                raise FileNotFoundError(s)
            self.store[d] = self.store[s]

    def rm(self, path: str, recursive: bool = False):
        p = self._norm(path)
        if recursive:
            for key in list(self.store.keys()):
                if self._is_prefix(p, key):
                    del self.store[key]
        else:
            if p in self.store:
                del self.store[p]
            else:
                raise FileNotFoundError(p)


def _load_module_by_path(path: str) -> ModuleType:
    spec = importlib.util.spec_from_file_location("kuzu_migrate_under_test", path)
    assert spec and spec.loader
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)  # type: ignore[attr-defined]
    return mod


def _find_repo_root(start_path: str) -> str:
    cur = os.path.abspath(start_path)
    while True:
        if os.path.exists(os.path.join(cur, "pyproject.toml")):
            return cur
        parent = os.path.dirname(cur)
        if parent == cur:
            raise RuntimeError("Could not locate repository root from: " + start_path)
        cur = parent


@pytest.fixture
def km_module(monkeypatch):
    # Load the kuzu_migrate module directly from file to avoid importing package __init__
    repo_root = _find_repo_root(os.path.dirname(__file__))
    target = os.path.join(
        repo_root, "cognee", "infrastructure", "databases", "graph", "kuzu", "kuzu_migrate.py"
    )
    mod = _load_module_by_path(target)
    return mod


@pytest.fixture
def patch_get_s3_client(monkeypatch, km_module):
    # Provide each test with its own fake client instance
    client = _FakeS3()
    monkeypatch.setattr(km_module, "_get_s3_client", lambda: client)
    return client


def _make_catalog_bytes(version_code: int) -> bytes:
    # 4 bytes header skipped + 8 bytes little-endian version code
    return b"KUZ\x00" + struct.pack("<Q", version_code) + b"padding"


def test_read_kuzu_storage_version_from_s3_directory(monkeypatch, patch_get_s3_client, km_module):
    s3 = patch_get_s3_client
    # Simulate a directory with catalog.kz
    dir_key = "s3://bucket/db"
    catalog_key = dir_key + "/catalog.kz"
    s3.store[catalog_key] = _make_catalog_bytes(39)  # maps to 0.11.0

    assert km_module.read_kuzu_storage_version(dir_key) == "0.11.0"


def test_s3_rename_file_backup(monkeypatch, patch_get_s3_client, km_module):
    s3 = patch_get_s3_client

    old_db = "s3://bucket/graph.db"
    new_db = "s3://bucket/graph_new.db"
    # seed store
    s3.store[old_db] = b"OLD"
    s3.store[new_db] = b"NEW"

    km_module._s3_rename_databases(old_db, "0.9.0", new_db, delete_old=False)

    # old is replaced with new
    assert s3.store.get(old_db) == b"NEW"
    # backup exists with version suffix
    backup = "s3://bucket/graph.db_old_0_9_0"
    assert s3.store.get(backup) == b"OLD"
    # staging removed
    assert new_db not in s3.store


def test_s3_rename_directory_delete_old(monkeypatch, patch_get_s3_client, km_module):
    s3 = patch_get_s3_client

    old_dir = "s3://bucket/graph_dir"
    new_dir = "s3://bucket/graph_dir_new"

    # Represent a directory by multiple keys under the prefix
    s3.store[old_dir + "/catalog.kz"] = b"OLD1"
    s3.store[old_dir + "/data.bin"] = b"OLD2"

    s3.store[new_dir + "/catalog.kz"] = b"NEW1"
    s3.store[new_dir + "/data.bin"] = b"NEW2"

    km_module._s3_rename_databases(old_dir, "0.9.0", new_dir, delete_old=True)

    # old dir contents replaced by new
    assert s3.store.get(old_dir + "/catalog.kz") == b"NEW1"
    assert s3.store.get(old_dir + "/data.bin") == b"NEW2"

    # no backup created when delete_old=True
    backup_prefix = os.path.dirname(old_dir) + "/" + os.path.basename(old_dir) + "_old_0_9_0"
    assert not any(k.startswith(backup_prefix) for k in s3.store)

    # staging removed
    assert not any(k.startswith(new_dir) for k in s3.store)

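Since the diff does not show the new test files' paths, the easiest way to target just these tests locally is by name. One hedged way to select them (keyword matching via `-k` is standard pytest behavior; the exact file locations are not known from this diff):

```python
# Run only the new S3-related tests by keyword, from Python.
# Equivalent to: pytest -k "adapter_s3_auto_migration or s3_rename or storage_version_from_s3"
import pytest

raise SystemExit(
    pytest.main(["-k", "adapter_s3_auto_migration or s3_rename or storage_version_from_s3"])
)
```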