kuzu - improve type inference for connection

This commit is contained in:
Daulet Amirkhanov 2025-09-04 15:09:48 +01:00
parent b59087841b
commit 9a2bf0f137

View file

@ -145,6 +145,12 @@ class KuzuAdapter(GraphDBInterface):
except Exception as e:
logger.error(f"Failed to initialize Kuzu database: {e}")
raise e
def _get_connection(self) -> Connection:
"""Get the connection to the Kuzu database."""
if not self.connection:
raise RuntimeError("Kuzu database connection not initialized")
return self.connection
async def push_to_s3(self) -> None:
if os.getenv("STORAGE_BACKEND", "").lower() == "s3" and hasattr(self, "temp_graph_file"):
@ -152,9 +158,9 @@ class KuzuAdapter(GraphDBInterface):
s3_file_storage = S3FileStorage("")
if self.connection:
if self._get_connection():
async with self.KUZU_ASYNC_LOCK:
self.connection.execute("CHECKPOINT;")
self._get_connection().execute("CHECKPOINT;")
s3_file_storage.s3.put(self.temp_graph_file, self.db_path, recursive=True)
@ -192,14 +198,14 @@ class KuzuAdapter(GraphDBInterface):
def blocking_query() -> List[Tuple[Any, ...]]:
try:
if not self.connection:
if not self._get_connection():
logger.debug("Reconnecting to Kuzu database...")
self._initialize_connection()
if not self.connection:
if not self._get_connection():
raise RuntimeError("Failed to establish database connection")
result = self.connection.execute(query, params)
result = self._get_connection().execute(query, params)
rows = []
if not isinstance(result, list):
@ -233,7 +239,7 @@ class KuzuAdapter(GraphDBInterface):
and on exit performs cleanup if necessary.
"""
try:
yield self.connection
yield self._get_connection()
finally:
pass
@ -1517,8 +1523,8 @@ class KuzuAdapter(GraphDBInterface):
It raises exceptions for failures occurring during deletion processes.
"""
try:
if self.connection:
self.connection.close()
if self._get_connection():
self._get_connection().close()
self.connection = None
if self.db:
self.db.close()
@ -1549,7 +1555,7 @@ class KuzuAdapter(GraphDBInterface):
occur during file deletions or initializations carefully.
"""
try:
if self.connection:
if self._get_connection():
self.connection = None
if self.db:
self.db.close()
@ -1566,20 +1572,23 @@ class KuzuAdapter(GraphDBInterface):
# Reinitialize the database
self._initialize_connection()
if not self.connection:
if not self._get_connection():
raise RuntimeError("Failed to establish database connection")
# Verify the database is empty
result = self.connection.execute("MATCH (n:Node) RETURN COUNT(n)")
result = self._get_connection().execute("MATCH (n:Node) RETURN COUNT(n)")
if not isinstance(result, list):
result = [result]
for single_result in result:
count = single_result.get_next()[0] if single_result.has_next() else 0 # type: ignore
_next = single_result.get_next()
if not isinstance(_next, list):
raise RuntimeError("Expected list of results")
count = _next[0] if _next else 0
if count > 0:
logger.warning(
f"Database still contains {count} nodes after clearing, forcing deletion"
)
self.connection.execute("MATCH (n:Node) DETACH DELETE n")
self._get_connection().execute("MATCH (n:Node) DETACH DELETE n")
logger.info("Database cleared successfully")
except Exception as e:
logger.error(f"Error during database clearing: {e}")