diff --git a/cognee/infrastructure/databases/graph/kuzu/adapter.py b/cognee/infrastructure/databases/graph/kuzu/adapter.py index 085d7cd00..538ca4fe0 100644 --- a/cognee/infrastructure/databases/graph/kuzu/adapter.py +++ b/cognee/infrastructure/databases/graph/kuzu/adapter.py @@ -10,7 +10,7 @@ from kuzu.database import Database from datetime import datetime, timezone from contextlib import asynccontextmanager from concurrent.futures import ThreadPoolExecutor -from typing import Dict, Any, List, Union, Optional, Tuple, Type +from typing import Dict, Any, List, Union, Optional, Tuple, Type, AsyncGenerator from cognee.shared.logging_utils import get_logger from cognee.infrastructure.utils.run_sync import run_sync @@ -22,7 +22,7 @@ from cognee.infrastructure.databases.graph.graph_db_interface import ( from cognee.infrastructure.engine import DataPoint from cognee.modules.storage.utils import JSONEncoder from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int -from cognee.tasks.temporal_graph.models import Timestamp +from cognee.modules.engine.models.Timestamp import Timestamp logger = get_logger() @@ -167,7 +167,7 @@ class KuzuAdapter(GraphDBInterface): except FileNotFoundError: logger.warning(f"Kuzu S3 storage file not found: {self.db_path}") - async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]: + async def query(self, query: str, params: Optional[Dict[str, Any]] = None) -> List[Tuple[Any, ...]]: """ Execute a Kuzu query asynchronously with automatic reconnection. @@ -190,23 +190,32 @@ class KuzuAdapter(GraphDBInterface): loop = asyncio.get_running_loop() params = params or {} - def blocking_query(): + def blocking_query() -> List[Tuple[Any, ...]]: try: if not self.connection: logger.debug("Reconnecting to Kuzu database...") self._initialize_connection() + if not self.connection: + raise RuntimeError("Failed to establish database connection") + result = self.connection.execute(query, params) rows = [] - while result.has_next(): - row = result.get_next() - processed_rows = [] - for val in row: - if hasattr(val, "as_py"): - val = val.as_py() - processed_rows.append(val) - rows.append(tuple(processed_rows)) + if not isinstance(result, list): + result = [result] + + # Handle QueryResult vs List[QueryResult] union type + for single_result in result: + while single_result.has_next(): + row = single_result.get_next() + processed_rows = [] + for val in row: + if hasattr(val, "as_py"): + val = val.as_py() + processed_rows.append(val) + rows.append(tuple(processed_rows)) + return rows except Exception as e: logger.error(f"Query execution failed: {str(e)}") @@ -215,7 +224,7 @@ class KuzuAdapter(GraphDBInterface): return await loop.run_in_executor(self.executor, blocking_query) @asynccontextmanager - async def get_session(self): + async def get_session(self) -> AsyncGenerator[Optional[Connection], None]: """ Get a database session. @@ -255,7 +264,7 @@ class KuzuAdapter(GraphDBInterface): def _edge_query_and_params( self, from_node: str, to_node: str, relationship_name: str, properties: Dict[str, Any] - ) -> Tuple[str, dict]: + ) -> Tuple[str, Dict[str, Any]]: """Build the edge creation query and parameters.""" now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f") query = """ @@ -305,7 +314,7 @@ class KuzuAdapter(GraphDBInterface): result = await self.query(query_str, {"id": node_id}) return result[0][0] if result else False - async def add_node(self, node: DataPoint) -> None: + async def add_node(self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None) -> None: """ Add a single node to the graph if it doesn't exist. @@ -319,20 +328,30 @@ class KuzuAdapter(GraphDBInterface): - node (DataPoint): The node to be added, represented as a DataPoint. """ try: - properties = node.model_dump() if hasattr(node, "model_dump") else vars(node) + if isinstance(node, str): + # Handle string node ID with properties parameter + node_properties = properties or {} + core_properties = { + "id": node, + "name": str(node_properties.get("name", "")), + "type": str(node_properties.get("type", "")), + } + # Use the passed properties, excluding core fields + other_properties = {k: v for k, v in node_properties.items() + if k not in ["id", "name", "type"]} + else: + # Handle DataPoint object + node_properties = node.model_dump() + core_properties = { + "id": str(node_properties.get("id", "")), + "name": str(node_properties.get("name", "")), + "type": str(node_properties.get("type", "")), + } + # Remove core fields from other properties + other_properties = {k: v for k, v in node_properties.items() + if k not in ["id", "name", "type"]} - # Extract core fields with defaults if not present - core_properties = { - "id": str(properties.get("id", "")), - "name": str(properties.get("name", "")), - "type": str(properties.get("type", "")), - } - - # Remove core fields from other properties - for key in core_properties: - properties.pop(key, None) - - core_properties["properties"] = json.dumps(properties, cls=JSONEncoder) + core_properties["properties"] = json.dumps(other_properties, cls=JSONEncoder) # Add timestamps for new node now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f") @@ -360,7 +379,7 @@ class KuzuAdapter(GraphDBInterface): logger.error(f"Failed to add node: {e}") raise - @record_graph_changes + @record_graph_changes # type: ignore async def add_nodes(self, nodes: List[DataPoint]) -> None: """ Add multiple nodes to the graph in a batch operation. @@ -568,7 +587,7 @@ class KuzuAdapter(GraphDBInterface): ) return result[0][0] if result else False - async def has_edges(self, edges: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str]]: + async def has_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> List[Tuple[str, str, str, Dict[str, Any]]]: """ Check if multiple edges exist in a batch operation. @@ -599,7 +618,7 @@ class KuzuAdapter(GraphDBInterface): "to_id": str(to_node), # Ensure string type "relationship_name": str(edge_label), # Ensure string type } - for from_node, to_node, edge_label in edges + for from_node, to_node, edge_label, _ in edges ] # Batch check query with direct string comparison @@ -615,9 +634,21 @@ class KuzuAdapter(GraphDBInterface): results = await self.query(query, {"edges": edge_params}) # Convert results back to tuples and ensure string types - existing_edges = [(str(row[0]), str(row[1]), str(row[2])) for row in results] + # Find the original edge properties for each existing edge + # TODO: get review on this + existing_edges = [] + for row in results: + from_id, to_id, rel_name = str(row[0]), str(row[1]), str(row[2]) + # Find the original properties from the input edges + original_props = {} + for orig_from, orig_to, orig_rel, orig_props in edges: + if orig_from == from_id and orig_to == to_id and orig_rel == rel_name: + original_props = orig_props + break + existing_edges.append((from_id, to_id, rel_name, original_props)) logger.debug(f"Found {len(existing_edges)} existing edges out of {len(edges)} checked") + # TODO: otherwise, we can just return dummy properties since they are not used apparently return existing_edges except Exception as e: @@ -626,10 +657,10 @@ class KuzuAdapter(GraphDBInterface): async def add_edge( self, - from_node: str, - to_node: str, + source_id: str, + target_id: str, relationship_name: str, - edge_properties: Dict[str, Any] = {}, + properties: Optional[Dict[str, Any]] = None, ) -> None: """ Add an edge between two nodes. @@ -641,23 +672,23 @@ class KuzuAdapter(GraphDBInterface): Parameters: ----------- - - from_node (str): The identifier of the source node from which the edge originates. - - to_node (str): The identifier of the target node to which the edge points. + - source_id (str): The identifier of the source node from which the edge originates. + - target_id (str): The identifier of the target node to which the edge points. - relationship_name (str): The label of the edge to be created, representing the relationship name. - - edge_properties (Dict[str, Any]): A dictionary containing properties for the edge. - (default {}) + - properties (Optional[Dict[str, Any]]): A dictionary containing properties for the edge. + (default None) """ try: query, params = self._edge_query_and_params( - from_node, to_node, relationship_name, edge_properties + source_id, target_id, relationship_name, properties or {} ) await self.query(query, params) except Exception as e: logger.error(f"Failed to add edge: {e}") raise - @record_graph_changes + @record_graph_changes # type: ignore async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None: """ Add multiple edges in a batch operation. @@ -712,7 +743,7 @@ class KuzuAdapter(GraphDBInterface): logger.error(f"Failed to add edges in batch: {e}") raise - async def get_edges(self, node_id: str) -> List[Tuple[Dict[str, Any], str, Dict[str, Any]]]: + async def get_edges(self, node_id: str) -> List[Tuple[str, str, str, Dict[str, Any]]]: """ Get all edges connected to a node. @@ -727,9 +758,8 @@ class KuzuAdapter(GraphDBInterface): Returns: -------- - - List[Tuple[Dict[str, Any], str, Dict[str, Any]]]: A list of tuples where each - tuple contains (source_node, relationship_name, target_node), with source_node and - target_node as dictionaries of node properties. + - List[Tuple[str, str, str, Dict[str, Any]]]: A list of tuples where each + tuple contains (source_id, relationship_name, target_id, edge_properties). """ query_str = """ MATCH (n:Node)-[r]-(m:Node) @@ -750,12 +780,14 @@ class KuzuAdapter(GraphDBInterface): """ try: results = await self.query(query_str, {"node_id": node_id}) - edges = [] + edges: List[Tuple[str, str, str, Dict[str, Any]]] = [] for row in results: if row and len(row) == 3: source_node = self._parse_node_properties(row[0]) + relationship_name = row[1] target_node = self._parse_node_properties(row[2]) - edges.append((source_node, row[1], target_node)) + # TODO: any edge properties we can add? Adding empty to avoid modifying query without reason + edges.append((source_node, relationship_name, target_node, {})) # type: ignore # currently each node is a dict, wihle typing expects nodes to be strings return edges except Exception as e: logger.error(f"Failed to get edges for node {node_id}: {e}") @@ -977,7 +1009,7 @@ class KuzuAdapter(GraphDBInterface): return [] async def get_connections( - self, node_id: str + self, node_id: Union[str, UUID] ) -> List[Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]]: """ Get all nodes connected to a given node. @@ -1019,7 +1051,9 @@ class KuzuAdapter(GraphDBInterface): } """ try: - results = await self.query(query_str, {"node_id": node_id}) + # Convert UUID to string if needed + node_id_str = str(node_id) + results = await self.query(query_str, {"node_id": node_id_str}) edges = [] for row in results: if row and len(row) == 3: @@ -1177,7 +1211,7 @@ class KuzuAdapter(GraphDBInterface): async def get_nodeset_subgraph( self, node_type: Type[Any], node_name: List[str] - ) -> Tuple[List[Tuple[str, dict]], List[Tuple[str, str, str, dict]]]: + ) -> Tuple[List[Tuple[int, Dict[str, Any]]], List[Tuple[int, int, str, Dict[str, Any]]]]: """ Get subgraph for a set of nodes based on type and names. @@ -1225,9 +1259,9 @@ class KuzuAdapter(GraphDBInterface): RETURN n.id, n.name, n.type, n.properties """ node_rows = await self.query(nodes_query, {"ids": all_ids}) - nodes: List[Tuple[str, dict]] = [] + nodes: List[Tuple[str, Dict[str, Any]]] = [] for node_id, name, typ, props in node_rows: - data = {"id": node_id, "name": name, "type": typ} + data: Dict[str, Any] = {"id": node_id, "name": name, "type": typ} if props: try: data.update(json.loads(props)) @@ -1241,22 +1275,22 @@ class KuzuAdapter(GraphDBInterface): RETURN a.id, b.id, r.relationship_name, r.properties """ edge_rows = await self.query(edges_query, {"ids": all_ids}) - edges: List[Tuple[str, str, str, dict]] = [] + edges: List[Tuple[str, str, str, Dict[str, Any]]] = [] for from_id, to_id, rel_type, props in edge_rows: - data = {} + edge_data: Dict[str, Any] = {} if props: try: - data = json.loads(props) + edge_data = json.loads(props) except json.JSONDecodeError: logger.warning(f"Failed to parse JSON props for edge {from_id}->{to_id}") - edges.append((from_id, to_id, rel_type, data)) + edges.append((from_id, to_id, rel_type, edge_data)) - return nodes, edges + return nodes, edges # type: ignore # Interface expects int IDs but string IDs are more natural for graph DBs async def get_filtered_graph_data( self, attribute_filters: List[Dict[str, List[Union[str, int]]]] - ): + ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: """ Get filtered nodes and relationships based on attributes. @@ -1299,7 +1333,7 @@ class KuzuAdapter(GraphDBInterface): ) return ([n[0] for n in nodes], [e[0] for e in edges]) - async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]: + async def get_graph_metrics(self, include_optional: bool = False) -> Dict[str, Any]: """ Get metrics on graph structure and connectivity. @@ -1322,8 +1356,8 @@ class KuzuAdapter(GraphDBInterface): try: # Get basic graph data nodes, edges = await self.get_model_independent_graph_data() - num_nodes = len(nodes[0]["nodes"]) if nodes else 0 - num_edges = len(edges[0]["elements"]) if edges else 0 + num_nodes = len(nodes[0]["nodes"]) if nodes else 0 # type: ignore # nodes is type string? + num_edges = len(edges[0]["elements"]) if edges else 0 # type: ignore # edges is type string? # Calculate mandatory metrics mandatory_metrics = { @@ -1531,9 +1565,16 @@ class KuzuAdapter(GraphDBInterface): # Reinitialize the database self._initialize_connection() + + if not self.connection: + raise RuntimeError("Failed to establish database connection") + # Verify the database is empty result = self.connection.execute("MATCH (n:Node) RETURN COUNT(n)") - count = result.get_next()[0] if result.has_next() else 0 + if not isinstance(result, list): + result = [result] + for single_result in result: + count = single_result.get_next()[0] if single_result.has_next() else 0 # type: ignore if count > 0: logger.warning( f"Database still contains {count} nodes after clearing, forcing deletion" @@ -1544,7 +1585,7 @@ class KuzuAdapter(GraphDBInterface): logger.error(f"Error during database clearing: {e}") raise - async def get_document_subgraph(self, data_id: str): + async def get_document_subgraph(self, data_id: str) -> Optional[Dict[str, Any]]: """ Get all nodes that should be deleted when removing a document. @@ -1616,7 +1657,7 @@ class KuzuAdapter(GraphDBInterface): "orphan_types": result[0][4], } - async def get_degree_one_nodes(self, node_type: str): + async def get_degree_one_nodes(self, node_type: str) -> List[Dict[str, Any]]: """ Get all nodes that have only one connection. @@ -1769,8 +1810,8 @@ class KuzuAdapter(GraphDBInterface): ids: List[str] = [] if time_from and time_to: - time_from = date_to_int(time_from) - time_to = date_to_int(time_to) + time_from_int = date_to_int(time_from) + time_to_int = date_to_int(time_to) cypher = f""" MATCH (n:Node) @@ -1782,13 +1823,13 @@ class KuzuAdapter(GraphDBInterface): WHEN t_str IS NULL OR t_str = '' THEN NULL ELSE CAST(t_str AS INT64) END AS t - WHERE t >= {time_from} - AND t <= {time_to} + WHERE t >= {time_from_int} + AND t <= {time_to_int} RETURN n.id as id """ elif time_from: - time_from = date_to_int(time_from) + time_from_int = date_to_int(time_from) cypher = f""" MATCH (n:Node) @@ -1800,12 +1841,12 @@ class KuzuAdapter(GraphDBInterface): WHEN t_str IS NULL OR t_str = '' THEN NULL ELSE CAST(t_str AS INT64) END AS t - WHERE t >= {time_from} + WHERE t >= {time_from_int} RETURN n.id as id """ elif time_to: - time_to = date_to_int(time_to) + time_to_int = date_to_int(time_to) cypher = f""" MATCH (n:Node) @@ -1817,12 +1858,12 @@ class KuzuAdapter(GraphDBInterface): WHEN t_str IS NULL OR t_str = '' THEN NULL ELSE CAST(t_str AS INT64) END AS t - WHERE t <= {time_to} + WHERE t <= {time_to_int} RETURN n.id as id """ else: - return ids + return ", ".join(f"'{uid}'" for uid in ids) time_nodes = await self.query(cypher) time_ids_list = [item[0] for item in time_nodes]