mypy: first fix KuzuAdapter mypy errors

This commit is contained in:
Daulet Amirkhanov 2025-09-04 15:03:36 +01:00
parent 86f3d46bf5
commit 6b2301ff28

View file

@ -10,7 +10,7 @@ from kuzu.database import Database
from datetime import datetime, timezone
from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any, List, Union, Optional, Tuple, Type
from typing import Dict, Any, List, Union, Optional, Tuple, Type, AsyncGenerator
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.utils.run_sync import run_sync
@ -22,7 +22,7 @@ from cognee.infrastructure.databases.graph.graph_db_interface import (
from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import JSONEncoder
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
from cognee.tasks.temporal_graph.models import Timestamp
from cognee.modules.engine.models.Timestamp import Timestamp
logger = get_logger()
@ -167,7 +167,7 @@ class KuzuAdapter(GraphDBInterface):
except FileNotFoundError:
logger.warning(f"Kuzu S3 storage file not found: {self.db_path}")
async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]:
async def query(self, query: str, params: Optional[Dict[str, Any]] = None) -> List[Tuple[Any, ...]]:
"""
Execute a Kuzu query asynchronously with automatic reconnection.
@ -190,23 +190,32 @@ class KuzuAdapter(GraphDBInterface):
loop = asyncio.get_running_loop()
params = params or {}
def blocking_query():
def blocking_query() -> List[Tuple[Any, ...]]:
try:
if not self.connection:
logger.debug("Reconnecting to Kuzu database...")
self._initialize_connection()
if not self.connection:
raise RuntimeError("Failed to establish database connection")
result = self.connection.execute(query, params)
rows = []
while result.has_next():
row = result.get_next()
processed_rows = []
for val in row:
if hasattr(val, "as_py"):
val = val.as_py()
processed_rows.append(val)
rows.append(tuple(processed_rows))
if not isinstance(result, list):
result = [result]
# Handle QueryResult vs List[QueryResult] union type
for single_result in result:
while single_result.has_next():
row = single_result.get_next()
processed_rows = []
for val in row:
if hasattr(val, "as_py"):
val = val.as_py()
processed_rows.append(val)
rows.append(tuple(processed_rows))
return rows
except Exception as e:
logger.error(f"Query execution failed: {str(e)}")
@ -215,7 +224,7 @@ class KuzuAdapter(GraphDBInterface):
return await loop.run_in_executor(self.executor, blocking_query)
@asynccontextmanager
async def get_session(self):
async def get_session(self) -> AsyncGenerator[Optional[Connection], None]:
"""
Get a database session.
@ -255,7 +264,7 @@ class KuzuAdapter(GraphDBInterface):
def _edge_query_and_params(
self, from_node: str, to_node: str, relationship_name: str, properties: Dict[str, Any]
) -> Tuple[str, dict]:
) -> Tuple[str, Dict[str, Any]]:
"""Build the edge creation query and parameters."""
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
query = """
@ -305,7 +314,7 @@ class KuzuAdapter(GraphDBInterface):
result = await self.query(query_str, {"id": node_id})
return result[0][0] if result else False
async def add_node(self, node: DataPoint) -> None:
async def add_node(self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None) -> None:
"""
Add a single node to the graph if it doesn't exist.
@ -319,20 +328,30 @@ class KuzuAdapter(GraphDBInterface):
- node (DataPoint): The node to be added, represented as a DataPoint.
"""
try:
properties = node.model_dump() if hasattr(node, "model_dump") else vars(node)
if isinstance(node, str):
# Handle string node ID with properties parameter
node_properties = properties or {}
core_properties = {
"id": node,
"name": str(node_properties.get("name", "")),
"type": str(node_properties.get("type", "")),
}
# Use the passed properties, excluding core fields
other_properties = {k: v for k, v in node_properties.items()
if k not in ["id", "name", "type"]}
else:
# Handle DataPoint object
node_properties = node.model_dump()
core_properties = {
"id": str(node_properties.get("id", "")),
"name": str(node_properties.get("name", "")),
"type": str(node_properties.get("type", "")),
}
# Remove core fields from other properties
other_properties = {k: v for k, v in node_properties.items()
if k not in ["id", "name", "type"]}
# Extract core fields with defaults if not present
core_properties = {
"id": str(properties.get("id", "")),
"name": str(properties.get("name", "")),
"type": str(properties.get("type", "")),
}
# Remove core fields from other properties
for key in core_properties:
properties.pop(key, None)
core_properties["properties"] = json.dumps(properties, cls=JSONEncoder)
core_properties["properties"] = json.dumps(other_properties, cls=JSONEncoder)
# Add timestamps for new node
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
@ -360,7 +379,7 @@ class KuzuAdapter(GraphDBInterface):
logger.error(f"Failed to add node: {e}")
raise
@record_graph_changes
@record_graph_changes # type: ignore
async def add_nodes(self, nodes: List[DataPoint]) -> None:
"""
Add multiple nodes to the graph in a batch operation.
@ -568,7 +587,7 @@ class KuzuAdapter(GraphDBInterface):
)
return result[0][0] if result else False
async def has_edges(self, edges: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str]]:
async def has_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> List[Tuple[str, str, str, Dict[str, Any]]]:
"""
Check if multiple edges exist in a batch operation.
@ -599,7 +618,7 @@ class KuzuAdapter(GraphDBInterface):
"to_id": str(to_node), # Ensure string type
"relationship_name": str(edge_label), # Ensure string type
}
for from_node, to_node, edge_label in edges
for from_node, to_node, edge_label, _ in edges
]
# Batch check query with direct string comparison
@ -615,9 +634,21 @@ class KuzuAdapter(GraphDBInterface):
results = await self.query(query, {"edges": edge_params})
# Convert results back to tuples and ensure string types
existing_edges = [(str(row[0]), str(row[1]), str(row[2])) for row in results]
# Find the original edge properties for each existing edge
# TODO: get review on this
existing_edges = []
for row in results:
from_id, to_id, rel_name = str(row[0]), str(row[1]), str(row[2])
# Find the original properties from the input edges
original_props = {}
for orig_from, orig_to, orig_rel, orig_props in edges:
if orig_from == from_id and orig_to == to_id and orig_rel == rel_name:
original_props = orig_props
break
existing_edges.append((from_id, to_id, rel_name, original_props))
logger.debug(f"Found {len(existing_edges)} existing edges out of {len(edges)} checked")
# TODO: otherwise, we can just return dummy properties since they are not used apparently
return existing_edges
except Exception as e:
@ -626,10 +657,10 @@ class KuzuAdapter(GraphDBInterface):
async def add_edge(
self,
from_node: str,
to_node: str,
source_id: str,
target_id: str,
relationship_name: str,
edge_properties: Dict[str, Any] = {},
properties: Optional[Dict[str, Any]] = None,
) -> None:
"""
Add an edge between two nodes.
@ -641,23 +672,23 @@ class KuzuAdapter(GraphDBInterface):
Parameters:
-----------
- from_node (str): The identifier of the source node from which the edge originates.
- to_node (str): The identifier of the target node to which the edge points.
- source_id (str): The identifier of the source node from which the edge originates.
- target_id (str): The identifier of the target node to which the edge points.
- relationship_name (str): The label of the edge to be created, representing the
relationship name.
- edge_properties (Dict[str, Any]): A dictionary containing properties for the edge.
(default {})
- properties (Optional[Dict[str, Any]]): A dictionary containing properties for the edge.
(default None)
"""
try:
query, params = self._edge_query_and_params(
from_node, to_node, relationship_name, edge_properties
source_id, target_id, relationship_name, properties or {}
)
await self.query(query, params)
except Exception as e:
logger.error(f"Failed to add edge: {e}")
raise
@record_graph_changes
@record_graph_changes # type: ignore
async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
"""
Add multiple edges in a batch operation.
@ -712,7 +743,7 @@ class KuzuAdapter(GraphDBInterface):
logger.error(f"Failed to add edges in batch: {e}")
raise
async def get_edges(self, node_id: str) -> List[Tuple[Dict[str, Any], str, Dict[str, Any]]]:
async def get_edges(self, node_id: str) -> List[Tuple[str, str, str, Dict[str, Any]]]:
"""
Get all edges connected to a node.
@ -727,9 +758,8 @@ class KuzuAdapter(GraphDBInterface):
Returns:
--------
- List[Tuple[Dict[str, Any], str, Dict[str, Any]]]: A list of tuples where each
tuple contains (source_node, relationship_name, target_node), with source_node and
target_node as dictionaries of node properties.
- List[Tuple[str, str, str, Dict[str, Any]]]: A list of tuples where each
tuple contains (source_id, relationship_name, target_id, edge_properties).
"""
query_str = """
MATCH (n:Node)-[r]-(m:Node)
@ -750,12 +780,14 @@ class KuzuAdapter(GraphDBInterface):
"""
try:
results = await self.query(query_str, {"node_id": node_id})
edges = []
edges: List[Tuple[str, str, str, Dict[str, Any]]] = []
for row in results:
if row and len(row) == 3:
source_node = self._parse_node_properties(row[0])
relationship_name = row[1]
target_node = self._parse_node_properties(row[2])
edges.append((source_node, row[1], target_node))
# TODO: any edge properties we can add? Adding empty to avoid modifying query without reason
edges.append((source_node, relationship_name, target_node, {})) # type: ignore # currently each node is a dict, wihle typing expects nodes to be strings
return edges
except Exception as e:
logger.error(f"Failed to get edges for node {node_id}: {e}")
@ -977,7 +1009,7 @@ class KuzuAdapter(GraphDBInterface):
return []
async def get_connections(
self, node_id: str
self, node_id: Union[str, UUID]
) -> List[Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]]:
"""
Get all nodes connected to a given node.
@ -1019,7 +1051,9 @@ class KuzuAdapter(GraphDBInterface):
}
"""
try:
results = await self.query(query_str, {"node_id": node_id})
# Convert UUID to string if needed
node_id_str = str(node_id)
results = await self.query(query_str, {"node_id": node_id_str})
edges = []
for row in results:
if row and len(row) == 3:
@ -1177,7 +1211,7 @@ class KuzuAdapter(GraphDBInterface):
async def get_nodeset_subgraph(
self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[str, dict]], List[Tuple[str, str, str, dict]]]:
) -> Tuple[List[Tuple[int, Dict[str, Any]]], List[Tuple[int, int, str, Dict[str, Any]]]]:
"""
Get subgraph for a set of nodes based on type and names.
@ -1225,9 +1259,9 @@ class KuzuAdapter(GraphDBInterface):
RETURN n.id, n.name, n.type, n.properties
"""
node_rows = await self.query(nodes_query, {"ids": all_ids})
nodes: List[Tuple[str, dict]] = []
nodes: List[Tuple[str, Dict[str, Any]]] = []
for node_id, name, typ, props in node_rows:
data = {"id": node_id, "name": name, "type": typ}
data: Dict[str, Any] = {"id": node_id, "name": name, "type": typ}
if props:
try:
data.update(json.loads(props))
@ -1241,22 +1275,22 @@ class KuzuAdapter(GraphDBInterface):
RETURN a.id, b.id, r.relationship_name, r.properties
"""
edge_rows = await self.query(edges_query, {"ids": all_ids})
edges: List[Tuple[str, str, str, dict]] = []
edges: List[Tuple[str, str, str, Dict[str, Any]]] = []
for from_id, to_id, rel_type, props in edge_rows:
data = {}
edge_data: Dict[str, Any] = {}
if props:
try:
data = json.loads(props)
edge_data = json.loads(props)
except json.JSONDecodeError:
logger.warning(f"Failed to parse JSON props for edge {from_id}->{to_id}")
edges.append((from_id, to_id, rel_type, data))
edges.append((from_id, to_id, rel_type, edge_data))
return nodes, edges
return nodes, edges # type: ignore # Interface expects int IDs but string IDs are more natural for graph DBs
async def get_filtered_graph_data(
self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
):
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""
Get filtered nodes and relationships based on attributes.
@ -1299,7 +1333,7 @@ class KuzuAdapter(GraphDBInterface):
)
return ([n[0] for n in nodes], [e[0] for e in edges])
async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]:
async def get_graph_metrics(self, include_optional: bool = False) -> Dict[str, Any]:
"""
Get metrics on graph structure and connectivity.
@ -1322,8 +1356,8 @@ class KuzuAdapter(GraphDBInterface):
try:
# Get basic graph data
nodes, edges = await self.get_model_independent_graph_data()
num_nodes = len(nodes[0]["nodes"]) if nodes else 0
num_edges = len(edges[0]["elements"]) if edges else 0
num_nodes = len(nodes[0]["nodes"]) if nodes else 0 # type: ignore # nodes is type string?
num_edges = len(edges[0]["elements"]) if edges else 0 # type: ignore # edges is type string?
# Calculate mandatory metrics
mandatory_metrics = {
@ -1531,9 +1565,16 @@ class KuzuAdapter(GraphDBInterface):
# Reinitialize the database
self._initialize_connection()
if not self.connection:
raise RuntimeError("Failed to establish database connection")
# Verify the database is empty
result = self.connection.execute("MATCH (n:Node) RETURN COUNT(n)")
count = result.get_next()[0] if result.has_next() else 0
if not isinstance(result, list):
result = [result]
for single_result in result:
count = single_result.get_next()[0] if single_result.has_next() else 0 # type: ignore
if count > 0:
logger.warning(
f"Database still contains {count} nodes after clearing, forcing deletion"
@ -1544,7 +1585,7 @@ class KuzuAdapter(GraphDBInterface):
logger.error(f"Error during database clearing: {e}")
raise
async def get_document_subgraph(self, data_id: str):
async def get_document_subgraph(self, data_id: str) -> Optional[Dict[str, Any]]:
"""
Get all nodes that should be deleted when removing a document.
@ -1616,7 +1657,7 @@ class KuzuAdapter(GraphDBInterface):
"orphan_types": result[0][4],
}
async def get_degree_one_nodes(self, node_type: str):
async def get_degree_one_nodes(self, node_type: str) -> List[Dict[str, Any]]:
"""
Get all nodes that have only one connection.
@ -1769,8 +1810,8 @@ class KuzuAdapter(GraphDBInterface):
ids: List[str] = []
if time_from and time_to:
time_from = date_to_int(time_from)
time_to = date_to_int(time_to)
time_from_int = date_to_int(time_from)
time_to_int = date_to_int(time_to)
cypher = f"""
MATCH (n:Node)
@ -1782,13 +1823,13 @@ class KuzuAdapter(GraphDBInterface):
WHEN t_str IS NULL OR t_str = '' THEN NULL
ELSE CAST(t_str AS INT64)
END AS t
WHERE t >= {time_from}
AND t <= {time_to}
WHERE t >= {time_from_int}
AND t <= {time_to_int}
RETURN n.id as id
"""
elif time_from:
time_from = date_to_int(time_from)
time_from_int = date_to_int(time_from)
cypher = f"""
MATCH (n:Node)
@ -1800,12 +1841,12 @@ class KuzuAdapter(GraphDBInterface):
WHEN t_str IS NULL OR t_str = '' THEN NULL
ELSE CAST(t_str AS INT64)
END AS t
WHERE t >= {time_from}
WHERE t >= {time_from_int}
RETURN n.id as id
"""
elif time_to:
time_to = date_to_int(time_to)
time_to_int = date_to_int(time_to)
cypher = f"""
MATCH (n:Node)
@ -1817,12 +1858,12 @@ class KuzuAdapter(GraphDBInterface):
WHEN t_str IS NULL OR t_str = '' THEN NULL
ELSE CAST(t_str AS INT64)
END AS t
WHERE t <= {time_to}
WHERE t <= {time_to_int}
RETURN n.id as id
"""
else:
return ids
return ", ".join(f"'{uid}'" for uid in ids)
time_nodes = await self.query(cypher)
time_ids_list = [item[0] for item in time_nodes]