feat: adds connection execute wrapper to keep track of open executions
This commit is contained in:
parent
76c4a4bd4c
commit
e3964cb208
2 changed files with 45 additions and 15 deletions
|
|
@ -40,7 +40,7 @@ def create_cache_engine(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_cache_engine() -> CacheDBInterface:
|
def get_cache_engine(lock_key: str) -> CacheDBInterface:
|
||||||
"""
|
"""
|
||||||
Returns a cache adapter instance using current context configuration.
|
Returns a cache adapter instance using current context configuration.
|
||||||
"""
|
"""
|
||||||
|
|
@ -48,7 +48,7 @@ def get_cache_engine() -> CacheDBInterface:
|
||||||
return create_cache_engine(
|
return create_cache_engine(
|
||||||
cache_host=config.cache_host,
|
cache_host=config.cache_host,
|
||||||
cache_port=config.cache_port,
|
cache_port=config.cache_port,
|
||||||
lock_key=config.lock_key,
|
lock_key=lock_key,
|
||||||
agentic_lock_expire=config.agentic_lock_expire,
|
agentic_lock_expire=config.agentic_lock_expire,
|
||||||
agentic_lock_timeout=config.agentic_lock_timeout,
|
agentic_lock_timeout=config.agentic_lock_timeout,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import os
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
import tempfile
|
import tempfile
|
||||||
from uuid import UUID
|
from uuid import UUID, uuid5, NAMESPACE_OID
|
||||||
from kuzu import Connection
|
from kuzu import Connection
|
||||||
from kuzu.database import Database
|
from kuzu.database import Database
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
@ -23,6 +23,12 @@ from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.modules.storage.utils import JSONEncoder
|
from cognee.modules.storage.utils import JSONEncoder
|
||||||
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
|
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
|
||||||
from cognee.tasks.temporal_graph.models import Timestamp
|
from cognee.tasks.temporal_graph.models import Timestamp
|
||||||
|
from cognee.infrastructure.databases.cache.config import CacheConfig
|
||||||
|
|
||||||
|
cache_config = CacheConfig()
|
||||||
|
|
||||||
|
if cache_config.caching:
|
||||||
|
from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
@ -39,11 +45,17 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
|
|
||||||
def __init__(self, db_path: str):
|
def __init__(self, db_path: str):
|
||||||
"""Initialize Kuzu database connection and schema."""
|
"""Initialize Kuzu database connection and schema."""
|
||||||
|
self.open_connections = 0
|
||||||
self.db_path = db_path # Path for the database directory
|
self.db_path = db_path # Path for the database directory
|
||||||
self.db: Optional[Database] = None
|
self.db: Optional[Database] = None
|
||||||
self.connection: Optional[Connection] = None
|
self.connection: Optional[Connection] = None
|
||||||
self.executor = ThreadPoolExecutor()
|
self.executor = ThreadPoolExecutor()
|
||||||
self._initialize_connection()
|
if cache_config.caching:
|
||||||
|
self.redis_cache = get_cache_engine(
|
||||||
|
lock_key="kuzu-lock-" + (str)(uuid5(NAMESPACE_OID, self.db_path))
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._initialize_connection()
|
||||||
self.KUZU_ASYNC_LOCK = asyncio.Lock()
|
self.KUZU_ASYNC_LOCK = asyncio.Lock()
|
||||||
|
|
||||||
def _initialize_connection(self) -> None:
|
def _initialize_connection(self) -> None:
|
||||||
|
|
@ -154,7 +166,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
|
|
||||||
if self.connection:
|
if self.connection:
|
||||||
async with self.KUZU_ASYNC_LOCK:
|
async with self.KUZU_ASYNC_LOCK:
|
||||||
self.connection.execute("CHECKPOINT;")
|
self.connection_execute_wrapper("CHECKPOINT;")
|
||||||
|
|
||||||
s3_file_storage.s3.put(self.temp_graph_file, self.db_path, recursive=True)
|
s3_file_storage.s3.put(self.temp_graph_file, self.db_path, recursive=True)
|
||||||
|
|
||||||
|
|
@ -167,6 +179,30 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logger.warning(f"Kuzu S3 storage file not found: {self.db_path}")
|
logger.warning(f"Kuzu S3 storage file not found: {self.db_path}")
|
||||||
|
|
||||||
|
def connection_execute_wrapper(self, query: str, parameters: dict[str, Any] | None = None):
|
||||||
|
if cache_config.caching:
|
||||||
|
self.open_connections += 1
|
||||||
|
logger.info(f"Number of connections opened after opening: {self.open_connections}")
|
||||||
|
try:
|
||||||
|
if not self.connection:
|
||||||
|
logger.info("Reconnecting to Kuzu database...")
|
||||||
|
self._initialize_connection()
|
||||||
|
result = self.connection.execute(query, parameters=parameters)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to execute query: {e}")
|
||||||
|
self.open_connections -= 1
|
||||||
|
logger.info(f"Number of connections after query failure: {self.open_connections}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
self.open_connections -= 1
|
||||||
|
logger.info(f"Number of connections closed after executing: {self.open_connections}")
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
if not self.connection:
|
||||||
|
logger.info("Reconnecting to Kuzu database...")
|
||||||
|
self._initialize_connection()
|
||||||
|
return self.connection.execute(query, parameters=parameters)
|
||||||
|
|
||||||
async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]:
|
async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]:
|
||||||
"""
|
"""
|
||||||
Execute a Kuzu query asynchronously with automatic reconnection.
|
Execute a Kuzu query asynchronously with automatic reconnection.
|
||||||
|
|
@ -192,11 +228,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
|
|
||||||
def blocking_query():
|
def blocking_query():
|
||||||
try:
|
try:
|
||||||
if not self.connection:
|
result = self.connection_execute_wrapper(query, params)
|
||||||
logger.debug("Reconnecting to Kuzu database...")
|
|
||||||
self._initialize_connection()
|
|
||||||
|
|
||||||
result = self.connection.execute(query, params)
|
|
||||||
rows = []
|
rows = []
|
||||||
|
|
||||||
while result.has_next():
|
while result.has_next():
|
||||||
|
|
@ -1562,16 +1594,14 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
await file_storage.remove_all()
|
await file_storage.remove_all()
|
||||||
logger.info(f"Deleted Kuzu database files at {self.db_path}")
|
logger.info(f"Deleted Kuzu database files at {self.db_path}")
|
||||||
|
|
||||||
# Reinitialize the database
|
|
||||||
self._initialize_connection()
|
|
||||||
# Verify the database is empty
|
# Verify the database is empty
|
||||||
result = self.connection.execute("MATCH (n:Node) RETURN COUNT(n)")
|
result = self.connection_execute_wrapper("MATCH (n:Node) RETURN COUNT(n)")
|
||||||
count = result.get_next()[0] if result.has_next() else 0
|
count = result.get_next()[0] if result.has_next() else 0
|
||||||
if count > 0:
|
if count > 0:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Database still contains {count} nodes after clearing, forcing deletion"
|
f"Database still contains {count} nodes after clearing, forcing deletion"
|
||||||
)
|
)
|
||||||
self.connection.execute("MATCH (n:Node) DETACH DELETE n")
|
self.connection_execute_wrapper("MATCH (n:Node) DETACH DELETE n")
|
||||||
logger.info("Database cleared successfully")
|
logger.info("Database cleared successfully")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error during database clearing: {e}")
|
logger.error(f"Error during database clearing: {e}")
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue