From 9a2bf0f137dbd94e8b3cde396db309c978943ff5 Mon Sep 17 00:00:00 2001 From: Daulet Amirkhanov Date: Thu, 4 Sep 2025 15:09:48 +0100 Subject: [PATCH] kuzu - improve type inference for connection --- .../databases/graph/kuzu/adapter.py | 35 ++++++++++++------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/cognee/infrastructure/databases/graph/kuzu/adapter.py b/cognee/infrastructure/databases/graph/kuzu/adapter.py index 538ca4fe0..e550889de 100644 --- a/cognee/infrastructure/databases/graph/kuzu/adapter.py +++ b/cognee/infrastructure/databases/graph/kuzu/adapter.py @@ -145,6 +145,12 @@ class KuzuAdapter(GraphDBInterface): except Exception as e: logger.error(f"Failed to initialize Kuzu database: {e}") raise e + + def _get_connection(self) -> Connection: + """Get the connection to the Kuzu database.""" + if not self.connection: + raise RuntimeError("Kuzu database connection not initialized") + return self.connection async def push_to_s3(self) -> None: if os.getenv("STORAGE_BACKEND", "").lower() == "s3" and hasattr(self, "temp_graph_file"): @@ -152,9 +158,9 @@ class KuzuAdapter(GraphDBInterface): s3_file_storage = S3FileStorage("") - if self.connection: + if self._get_connection(): async with self.KUZU_ASYNC_LOCK: - self.connection.execute("CHECKPOINT;") + self._get_connection().execute("CHECKPOINT;") s3_file_storage.s3.put(self.temp_graph_file, self.db_path, recursive=True) @@ -192,14 +198,14 @@ class KuzuAdapter(GraphDBInterface): def blocking_query() -> List[Tuple[Any, ...]]: try: - if not self.connection: + if not self._get_connection(): logger.debug("Reconnecting to Kuzu database...") self._initialize_connection() - if not self.connection: + if not self._get_connection(): raise RuntimeError("Failed to establish database connection") - result = self.connection.execute(query, params) + result = self._get_connection().execute(query, params) rows = [] if not isinstance(result, list): @@ -233,7 +239,7 @@ class KuzuAdapter(GraphDBInterface): and on exit performs cleanup if necessary. """ try: - yield self.connection + yield self._get_connection() finally: pass @@ -1517,8 +1523,8 @@ class KuzuAdapter(GraphDBInterface): It raises exceptions for failures occurring during deletion processes. """ try: - if self.connection: - self.connection.close() + if self._get_connection(): + self._get_connection().close() self.connection = None if self.db: self.db.close() @@ -1549,7 +1555,7 @@ class KuzuAdapter(GraphDBInterface): occur during file deletions or initializations carefully. """ try: - if self.connection: + if self._get_connection(): self.connection = None if self.db: self.db.close() @@ -1566,20 +1572,23 @@ class KuzuAdapter(GraphDBInterface): # Reinitialize the database self._initialize_connection() - if not self.connection: + if not self._get_connection(): raise RuntimeError("Failed to establish database connection") # Verify the database is empty - result = self.connection.execute("MATCH (n:Node) RETURN COUNT(n)") + result = self._get_connection().execute("MATCH (n:Node) RETURN COUNT(n)") if not isinstance(result, list): result = [result] for single_result in result: - count = single_result.get_next()[0] if single_result.has_next() else 0 # type: ignore + _next = single_result.get_next() + if not isinstance(_next, list): + raise RuntimeError("Expected list of results") + count = _next[0] if _next else 0 if count > 0: logger.warning( f"Database still contains {count} nodes after clearing, forcing deletion" ) - self.connection.execute("MATCH (n:Node) DETACH DELETE n") + self._get_connection().execute("MATCH (n:Node) DETACH DELETE n") logger.info("Database cleared successfully") except Exception as e: logger.error(f"Error during database clearing: {e}")