From e3964cb2084d05da9e1ea8a5860c7ffb470f759e Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 7 Oct 2025 15:47:54 +0200 Subject: [PATCH] feat: adds connection execture wrapper to keep track open executions --- .../databases/cache/get_cache_engine.py | 4 +- .../databases/graph/kuzu/adapter.py | 56 ++++++++++++++----- 2 files changed, 45 insertions(+), 15 deletions(-) diff --git a/cognee/infrastructure/databases/cache/get_cache_engine.py b/cognee/infrastructure/databases/cache/get_cache_engine.py index e22beb549..e14fdb31c 100644 --- a/cognee/infrastructure/databases/cache/get_cache_engine.py +++ b/cognee/infrastructure/databases/cache/get_cache_engine.py @@ -40,7 +40,7 @@ def create_cache_engine( ) -def get_cache_engine() -> CacheDBInterface: +def get_cache_engine(lock_key: str) -> CacheDBInterface: """ Returns a cache adapter instance using current context configuration. """ @@ -48,7 +48,7 @@ def get_cache_engine() -> CacheDBInterface: return create_cache_engine( cache_host=config.cache_host, cache_port=config.cache_port, - lock_key=config.lock_key, + lock_key=lock_key, agentic_lock_expire=config.agentic_lock_expire, agentic_lock_timeout=config.agentic_lock_timeout, ) diff --git a/cognee/infrastructure/databases/graph/kuzu/adapter.py b/cognee/infrastructure/databases/graph/kuzu/adapter.py index 015dcaa78..177a66aab 100644 --- a/cognee/infrastructure/databases/graph/kuzu/adapter.py +++ b/cognee/infrastructure/databases/graph/kuzu/adapter.py @@ -4,7 +4,7 @@ import os import json import asyncio import tempfile -from uuid import UUID +from uuid import UUID, uuid5, NAMESPACE_OID from kuzu import Connection from kuzu.database import Database from datetime import datetime, timezone @@ -23,6 +23,12 @@ from cognee.infrastructure.engine import DataPoint from cognee.modules.storage.utils import JSONEncoder from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int from cognee.tasks.temporal_graph.models import Timestamp +from cognee.infrastructure.databases.cache.config import CacheConfig + +cache_config = CacheConfig() + +if cache_config.caching: + from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine logger = get_logger() @@ -39,11 +45,17 @@ class KuzuAdapter(GraphDBInterface): def __init__(self, db_path: str): """Initialize Kuzu database connection and schema.""" + self.open_connections = 0 self.db_path = db_path # Path for the database directory self.db: Optional[Database] = None self.connection: Optional[Connection] = None self.executor = ThreadPoolExecutor() - self._initialize_connection() + if cache_config.caching: + self.redis_cache = get_cache_engine( + lock_key="kuzu-lock-" + (str)(uuid5(NAMESPACE_OID, self.db_path)) + ) + else: + self._initialize_connection() self.KUZU_ASYNC_LOCK = asyncio.Lock() def _initialize_connection(self) -> None: @@ -154,7 +166,7 @@ class KuzuAdapter(GraphDBInterface): if self.connection: async with self.KUZU_ASYNC_LOCK: - self.connection.execute("CHECKPOINT;") + self.connection_execute_wrapper("CHECKPOINT;") s3_file_storage.s3.put(self.temp_graph_file, self.db_path, recursive=True) @@ -167,6 +179,30 @@ class KuzuAdapter(GraphDBInterface): except FileNotFoundError: logger.warning(f"Kuzu S3 storage file not found: {self.db_path}") + def connection_execute_wrapper(self, query: str, parameters: dict[str, Any] | None = None): + if cache_config.caching: + self.open_connections += 1 + logger.info(f"Number of connections opened after opening: {self.open_connections}") + try: + if not self.connection: + logger.info("Reconnecting to Kuzu database...") + self._initialize_connection() + result = self.connection.execute(query, parameters=parameters) + except Exception as e: + logger.error(f"Failed to execute query: {e}") + self.open_connections -= 1 + logger.info(f"Number of connections after query failure: {self.open_connections}") + raise e + + self.open_connections -= 1 + logger.info(f"Number of connections closed after executing: {self.open_connections}") + return result + else: + if not self.connection: + logger.info("Reconnecting to Kuzu database...") + self._initialize_connection() + return self.connection.execute(query, parameters=parameters) + async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]: """ Execute a Kuzu query asynchronously with automatic reconnection. @@ -192,11 +228,7 @@ class KuzuAdapter(GraphDBInterface): def blocking_query(): try: - if not self.connection: - logger.debug("Reconnecting to Kuzu database...") - self._initialize_connection() - - result = self.connection.execute(query, params) + result = self.connection_execute_wrapper(query, params) rows = [] while result.has_next(): @@ -1562,16 +1594,14 @@ class KuzuAdapter(GraphDBInterface): await file_storage.remove_all() logger.info(f"Deleted Kuzu database files at {self.db_path}") - # Reinitialize the database - self._initialize_connection() # Verify the database is empty - result = self.connection.execute("MATCH (n:Node) RETURN COUNT(n)") + result = self.connection_execute_wrapper("MATCH (n:Node) RETURN COUNT(n)") count = result.get_next()[0] if result.has_next() else 0 if count > 0: logger.warning( f"Database still contains {count} nodes after clearing, forcing deletion" ) - self.connection.execute("MATCH (n:Node) DETACH DELETE n") + self.connection_execute_wrapper("MATCH (n:Node) DETACH DELETE n") logger.info("Database cleared successfully") except Exception as e: logger.error(f"Error during database clearing: {e}") @@ -1860,4 +1890,4 @@ class KuzuAdapter(GraphDBInterface): time_nodes = await self.query(cypher) time_ids_list = [item[0] for item in time_nodes] - return ", ".join(f"'{uid}'" for uid in time_ids_list) \ No newline at end of file + return ", ".join(f"'{uid}'" for uid in time_ids_list)