Update adapter.py
This commit is contained in:
parent
03eedddf29
commit
7ec1c75bee
1 changed files with 48 additions and 3 deletions
|
|
@ -4,7 +4,7 @@ import os
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
import tempfile
|
import tempfile
|
||||||
from uuid import UUID
|
from uuid import UUID, uuid5, NAMESPACE_OID
|
||||||
from kuzu import Connection
|
from kuzu import Connection
|
||||||
from kuzu.database import Database
|
from kuzu.database import Database
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
@ -23,9 +23,14 @@ from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.modules.storage.utils import JSONEncoder
|
from cognee.modules.storage.utils import JSONEncoder
|
||||||
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
|
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
|
||||||
from cognee.tasks.temporal_graph.models import Timestamp
|
from cognee.tasks.temporal_graph.models import Timestamp
|
||||||
|
from cognee.infrastructure.databases.cache.config import get_cache_config
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
cache_config = get_cache_config()
|
||||||
|
if cache_config.caching:
|
||||||
|
from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine
|
||||||
|
|
||||||
|
|
||||||
class KuzuAdapter(GraphDBInterface):
|
class KuzuAdapter(GraphDBInterface):
|
||||||
"""
|
"""
|
||||||
|
|
@ -39,12 +44,18 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
|
|
||||||
def __init__(self, db_path: str):
|
def __init__(self, db_path: str):
|
||||||
"""Initialize Kuzu database connection and schema."""
|
"""Initialize Kuzu database connection and schema."""
|
||||||
|
self.open_connections = 0
|
||||||
|
self._is_closed = False
|
||||||
self.db_path = db_path # Path for the database directory
|
self.db_path = db_path # Path for the database directory
|
||||||
self.db: Optional[Database] = None
|
self.db: Optional[Database] = None
|
||||||
self.connection: Optional[Connection] = None
|
self.connection: Optional[Connection] = None
|
||||||
self.executor = ThreadPoolExecutor()
|
self.executor = ThreadPoolExecutor()
|
||||||
self._initialize_connection()
|
if cache_config.caching:
|
||||||
|
self.redis_lock = get_cache_engine(lock_key="kuzu-lock-" + str(uuid5(NAMESPACE_OID, db_path)))
|
||||||
|
else:
|
||||||
|
self._initialize_connection()
|
||||||
self.KUZU_ASYNC_LOCK = asyncio.Lock()
|
self.KUZU_ASYNC_LOCK = asyncio.Lock()
|
||||||
|
self._counter_lock = asyncio.Lock()
|
||||||
|
|
||||||
def _initialize_connection(self) -> None:
|
def _initialize_connection(self) -> None:
|
||||||
"""Initialize the Kuzu database connection and schema."""
|
"""Initialize the Kuzu database connection and schema."""
|
||||||
|
|
@ -212,7 +223,41 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
logger.error(f"Query execution failed: {str(e)}")
|
logger.error(f"Query execution failed: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
return await loop.run_in_executor(self.executor, blocking_query)
|
if cache_config.caching:
|
||||||
|
if self._is_closed:
|
||||||
|
self.reopen()
|
||||||
|
async with self._counter_lock:
|
||||||
|
self.open_connections += 1
|
||||||
|
logger.info(f"Open connections after open: {self.open_connections}")
|
||||||
|
|
||||||
|
|
||||||
|
result = await loop.run_in_executor(self.executor, blocking_query)
|
||||||
|
|
||||||
|
if cache_config.caching:
|
||||||
|
async with self._counter_lock:
|
||||||
|
self.open_connections -= 1
|
||||||
|
logger.info(f"Opened connections after closing {self.open_connections}")
|
||||||
|
if self.open_connections == 0:
|
||||||
|
self.connection.execute('CHECKPOINT;')
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
if self.connection:
|
||||||
|
del self.connection
|
||||||
|
self.connection = None
|
||||||
|
if self.db:
|
||||||
|
del self.db
|
||||||
|
self.db = None
|
||||||
|
self._is_closed = True
|
||||||
|
logger.info(f"Kuzu database closed successfully")
|
||||||
|
|
||||||
|
def reopen(self):
|
||||||
|
if self._is_closed:
|
||||||
|
self._is_closed = False
|
||||||
|
self._initialize_connection()
|
||||||
|
logger.info(f"Kuzu database re-opened successfully")
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def get_session(self):
|
async def get_session(self):
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue