feat: adds connection execture wrapper to keep track open executions

This commit is contained in:
hajdul88 2025-10-07 15:47:54 +02:00
parent 76c4a4bd4c
commit e3964cb208
2 changed files with 45 additions and 15 deletions

View file

@ -40,7 +40,7 @@ def create_cache_engine(
) )
def get_cache_engine() -> CacheDBInterface: def get_cache_engine(lock_key: str) -> CacheDBInterface:
""" """
Returns a cache adapter instance using current context configuration. Returns a cache adapter instance using current context configuration.
""" """
@ -48,7 +48,7 @@ def get_cache_engine() -> CacheDBInterface:
return create_cache_engine( return create_cache_engine(
cache_host=config.cache_host, cache_host=config.cache_host,
cache_port=config.cache_port, cache_port=config.cache_port,
lock_key=config.lock_key, lock_key=lock_key,
agentic_lock_expire=config.agentic_lock_expire, agentic_lock_expire=config.agentic_lock_expire,
agentic_lock_timeout=config.agentic_lock_timeout, agentic_lock_timeout=config.agentic_lock_timeout,
) )

View file

@ -4,7 +4,7 @@ import os
import json import json
import asyncio import asyncio
import tempfile import tempfile
from uuid import UUID from uuid import UUID, uuid5, NAMESPACE_OID
from kuzu import Connection from kuzu import Connection
from kuzu.database import Database from kuzu.database import Database
from datetime import datetime, timezone from datetime import datetime, timezone
@ -23,6 +23,12 @@ from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import JSONEncoder from cognee.modules.storage.utils import JSONEncoder
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
from cognee.tasks.temporal_graph.models import Timestamp from cognee.tasks.temporal_graph.models import Timestamp
from cognee.infrastructure.databases.cache.config import CacheConfig
cache_config = CacheConfig()
if cache_config.caching:
from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine
logger = get_logger() logger = get_logger()
@ -39,11 +45,17 @@ class KuzuAdapter(GraphDBInterface):
def __init__(self, db_path: str): def __init__(self, db_path: str):
"""Initialize Kuzu database connection and schema.""" """Initialize Kuzu database connection and schema."""
self.open_connections = 0
self.db_path = db_path # Path for the database directory self.db_path = db_path # Path for the database directory
self.db: Optional[Database] = None self.db: Optional[Database] = None
self.connection: Optional[Connection] = None self.connection: Optional[Connection] = None
self.executor = ThreadPoolExecutor() self.executor = ThreadPoolExecutor()
self._initialize_connection() if cache_config.caching:
self.redis_cache = get_cache_engine(
lock_key="kuzu-lock-" + (str)(uuid5(NAMESPACE_OID, self.db_path))
)
else:
self._initialize_connection()
self.KUZU_ASYNC_LOCK = asyncio.Lock() self.KUZU_ASYNC_LOCK = asyncio.Lock()
def _initialize_connection(self) -> None: def _initialize_connection(self) -> None:
@ -154,7 +166,7 @@ class KuzuAdapter(GraphDBInterface):
if self.connection: if self.connection:
async with self.KUZU_ASYNC_LOCK: async with self.KUZU_ASYNC_LOCK:
self.connection.execute("CHECKPOINT;") self.connection_execute_wrapper("CHECKPOINT;")
s3_file_storage.s3.put(self.temp_graph_file, self.db_path, recursive=True) s3_file_storage.s3.put(self.temp_graph_file, self.db_path, recursive=True)
@ -167,6 +179,30 @@ class KuzuAdapter(GraphDBInterface):
except FileNotFoundError: except FileNotFoundError:
logger.warning(f"Kuzu S3 storage file not found: {self.db_path}") logger.warning(f"Kuzu S3 storage file not found: {self.db_path}")
def connection_execute_wrapper(self, query: str, parameters: dict[str, Any] | None = None):
if cache_config.caching:
self.open_connections += 1
logger.info(f"Number of connections opened after opening: {self.open_connections}")
try:
if not self.connection:
logger.info("Reconnecting to Kuzu database...")
self._initialize_connection()
result = self.connection.execute(query, parameters=parameters)
except Exception as e:
logger.error(f"Failed to execute query: {e}")
self.open_connections -= 1
logger.info(f"Number of connections after query failure: {self.open_connections}")
raise e
self.open_connections -= 1
logger.info(f"Number of connections closed after executing: {self.open_connections}")
return result
else:
if not self.connection:
logger.info("Reconnecting to Kuzu database...")
self._initialize_connection()
return self.connection.execute(query, parameters=parameters)
async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]: async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]:
""" """
Execute a Kuzu query asynchronously with automatic reconnection. Execute a Kuzu query asynchronously with automatic reconnection.
@ -192,11 +228,7 @@ class KuzuAdapter(GraphDBInterface):
def blocking_query(): def blocking_query():
try: try:
if not self.connection: result = self.connection_execute_wrapper(query, params)
logger.debug("Reconnecting to Kuzu database...")
self._initialize_connection()
result = self.connection.execute(query, params)
rows = [] rows = []
while result.has_next(): while result.has_next():
@ -1562,16 +1594,14 @@ class KuzuAdapter(GraphDBInterface):
await file_storage.remove_all() await file_storage.remove_all()
logger.info(f"Deleted Kuzu database files at {self.db_path}") logger.info(f"Deleted Kuzu database files at {self.db_path}")
# Reinitialize the database
self._initialize_connection()
# Verify the database is empty # Verify the database is empty
result = self.connection.execute("MATCH (n:Node) RETURN COUNT(n)") result = self.connection_execute_wrapper("MATCH (n:Node) RETURN COUNT(n)")
count = result.get_next()[0] if result.has_next() else 0 count = result.get_next()[0] if result.has_next() else 0
if count > 0: if count > 0:
logger.warning( logger.warning(
f"Database still contains {count} nodes after clearing, forcing deletion" f"Database still contains {count} nodes after clearing, forcing deletion"
) )
self.connection.execute("MATCH (n:Node) DETACH DELETE n") self.connection_execute_wrapper("MATCH (n:Node) DETACH DELETE n")
logger.info("Database cleared successfully") logger.info("Database cleared successfully")
except Exception as e: except Exception as e:
logger.error(f"Error during database clearing: {e}") logger.error(f"Error during database clearing: {e}")