diff --git a/cognee/infrastructure/databases/graph/kuzu/adapter.py b/cognee/infrastructure/databases/graph/kuzu/adapter.py
index 015dcaa78..2cfe2adca 100644
--- a/cognee/infrastructure/databases/graph/kuzu/adapter.py
+++ b/cognee/infrastructure/databases/graph/kuzu/adapter.py
@@ -4,7 +4,7 @@ import os
 import json
 import asyncio
 import tempfile
-from uuid import UUID
+from uuid import UUID, uuid5, NAMESPACE_OID
 from kuzu import Connection
 from kuzu.database import Database
 from datetime import datetime, timezone
@@ -23,9 +23,14 @@ from cognee.infrastructure.engine import DataPoint
 from cognee.modules.storage.utils import JSONEncoder
 from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
 from cognee.tasks.temporal_graph.models import Timestamp
+from cognee.infrastructure.databases.cache.config import get_cache_config
 
 logger = get_logger()
 
+cache_config = get_cache_config()
+if cache_config.caching:
+    from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine
+
 
 class KuzuAdapter(GraphDBInterface):
     """
@@ -39,12 +44,18 @@ class KuzuAdapter(GraphDBInterface):
 
     def __init__(self, db_path: str):
         """Initialize Kuzu database connection and schema."""
+        self.open_connections = 0
+        self._is_closed = False
         self.db_path = db_path  # Path for the database directory
         self.db: Optional[Database] = None
         self.connection: Optional[Connection] = None
         self.executor = ThreadPoolExecutor()
-        self._initialize_connection()
+        if cache_config.caching:
+            self.redis_lock = get_cache_engine(lock_key="kuzu-lock-" + str(uuid5(NAMESPACE_OID, db_path)))
+        else:
+            self._initialize_connection()
         self.KUZU_ASYNC_LOCK = asyncio.Lock()
+        self._counter_lock = asyncio.Lock()
 
     def _initialize_connection(self) -> None:
         """Initialize the Kuzu database connection and schema."""
@@ -212,7 +223,48 @@ class KuzuAdapter(GraphDBInterface):
                 logger.error(f"Query execution failed: {str(e)}")
                 raise
 
-        return await loop.run_in_executor(self.executor, blocking_query)
+        if cache_config.caching:
+            # Reopen and count under one lock so a concurrent close cannot
+            # slip in between the closed-check and the increment.
+            async with self._counter_lock:
+                if self._is_closed:
+                    self.reopen()
+                self.open_connections += 1
+                logger.info(f"Open connections after open: {self.open_connections}")
+
+        try:
+            result = await loop.run_in_executor(self.executor, blocking_query)
+        finally:
+            # Always release our slot, even when the query raised; otherwise
+            # the counter never reaches zero and the database is never closed.
+            if cache_config.caching:
+                async with self._counter_lock:
+                    self.open_connections -= 1
+                    logger.info(f"Opened connections after closing {self.open_connections}")
+                    if self.open_connections == 0:
+                        if self.connection:
+                            self.connection.execute("CHECKPOINT;")
+                        self.close()
+
+        return result
+
+    def close(self):
+        """Drop the connection and database handles and mark the adapter closed."""
+        if self.connection:
+            del self.connection
+            self.connection = None
+        if self.db:
+            del self.db
+            self.db = None
+        self._is_closed = True
+        logger.info("Kuzu database closed successfully")
+
+    def reopen(self):
+        """Re-initialize the connection if the adapter was previously closed."""
+        if self._is_closed:
+            self._is_closed = False
+            self._initialize_connection()
+            logger.info("Kuzu database re-opened successfully")
 
     @asynccontextmanager
     async def get_session(self):