mypy: first fix KuzuAdapter mypy errors
This commit is contained in:
parent
86f3d46bf5
commit
6b2301ff28
1 changed files with 114 additions and 73 deletions
|
|
@ -10,7 +10,7 @@ from kuzu.database import Database
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from typing import Dict, Any, List, Union, Optional, Tuple, Type
|
from typing import Dict, Any, List, Union, Optional, Tuple, Type, AsyncGenerator
|
||||||
|
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.infrastructure.utils.run_sync import run_sync
|
from cognee.infrastructure.utils.run_sync import run_sync
|
||||||
|
|
@ -22,7 +22,7 @@ from cognee.infrastructure.databases.graph.graph_db_interface import (
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.modules.storage.utils import JSONEncoder
|
from cognee.modules.storage.utils import JSONEncoder
|
||||||
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
|
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
|
||||||
from cognee.tasks.temporal_graph.models import Timestamp
|
from cognee.modules.engine.models.Timestamp import Timestamp
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
@ -167,7 +167,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logger.warning(f"Kuzu S3 storage file not found: {self.db_path}")
|
logger.warning(f"Kuzu S3 storage file not found: {self.db_path}")
|
||||||
|
|
||||||
async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]:
|
async def query(self, query: str, params: Optional[Dict[str, Any]] = None) -> List[Tuple[Any, ...]]:
|
||||||
"""
|
"""
|
||||||
Execute a Kuzu query asynchronously with automatic reconnection.
|
Execute a Kuzu query asynchronously with automatic reconnection.
|
||||||
|
|
||||||
|
|
@ -190,23 +190,32 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
params = params or {}
|
params = params or {}
|
||||||
|
|
||||||
def blocking_query():
|
def blocking_query() -> List[Tuple[Any, ...]]:
|
||||||
try:
|
try:
|
||||||
if not self.connection:
|
if not self.connection:
|
||||||
logger.debug("Reconnecting to Kuzu database...")
|
logger.debug("Reconnecting to Kuzu database...")
|
||||||
self._initialize_connection()
|
self._initialize_connection()
|
||||||
|
|
||||||
|
if not self.connection:
|
||||||
|
raise RuntimeError("Failed to establish database connection")
|
||||||
|
|
||||||
result = self.connection.execute(query, params)
|
result = self.connection.execute(query, params)
|
||||||
rows = []
|
rows = []
|
||||||
|
|
||||||
while result.has_next():
|
if not isinstance(result, list):
|
||||||
row = result.get_next()
|
result = [result]
|
||||||
processed_rows = []
|
|
||||||
for val in row:
|
# Handle QueryResult vs List[QueryResult] union type
|
||||||
if hasattr(val, "as_py"):
|
for single_result in result:
|
||||||
val = val.as_py()
|
while single_result.has_next():
|
||||||
processed_rows.append(val)
|
row = single_result.get_next()
|
||||||
rows.append(tuple(processed_rows))
|
processed_rows = []
|
||||||
|
for val in row:
|
||||||
|
if hasattr(val, "as_py"):
|
||||||
|
val = val.as_py()
|
||||||
|
processed_rows.append(val)
|
||||||
|
rows.append(tuple(processed_rows))
|
||||||
|
|
||||||
return rows
|
return rows
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Query execution failed: {str(e)}")
|
logger.error(f"Query execution failed: {str(e)}")
|
||||||
|
|
@ -215,7 +224,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
return await loop.run_in_executor(self.executor, blocking_query)
|
return await loop.run_in_executor(self.executor, blocking_query)
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def get_session(self):
|
async def get_session(self) -> AsyncGenerator[Optional[Connection], None]:
|
||||||
"""
|
"""
|
||||||
Get a database session.
|
Get a database session.
|
||||||
|
|
||||||
|
|
@ -255,7 +264,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
|
|
||||||
def _edge_query_and_params(
|
def _edge_query_and_params(
|
||||||
self, from_node: str, to_node: str, relationship_name: str, properties: Dict[str, Any]
|
self, from_node: str, to_node: str, relationship_name: str, properties: Dict[str, Any]
|
||||||
) -> Tuple[str, dict]:
|
) -> Tuple[str, Dict[str, Any]]:
|
||||||
"""Build the edge creation query and parameters."""
|
"""Build the edge creation query and parameters."""
|
||||||
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
|
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
|
||||||
query = """
|
query = """
|
||||||
|
|
@ -305,7 +314,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
result = await self.query(query_str, {"id": node_id})
|
result = await self.query(query_str, {"id": node_id})
|
||||||
return result[0][0] if result else False
|
return result[0][0] if result else False
|
||||||
|
|
||||||
async def add_node(self, node: DataPoint) -> None:
|
async def add_node(self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Add a single node to the graph if it doesn't exist.
|
Add a single node to the graph if it doesn't exist.
|
||||||
|
|
||||||
|
|
@ -319,20 +328,30 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
- node (DataPoint): The node to be added, represented as a DataPoint.
|
- node (DataPoint): The node to be added, represented as a DataPoint.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
properties = node.model_dump() if hasattr(node, "model_dump") else vars(node)
|
if isinstance(node, str):
|
||||||
|
# Handle string node ID with properties parameter
|
||||||
|
node_properties = properties or {}
|
||||||
|
core_properties = {
|
||||||
|
"id": node,
|
||||||
|
"name": str(node_properties.get("name", "")),
|
||||||
|
"type": str(node_properties.get("type", "")),
|
||||||
|
}
|
||||||
|
# Use the passed properties, excluding core fields
|
||||||
|
other_properties = {k: v for k, v in node_properties.items()
|
||||||
|
if k not in ["id", "name", "type"]}
|
||||||
|
else:
|
||||||
|
# Handle DataPoint object
|
||||||
|
node_properties = node.model_dump()
|
||||||
|
core_properties = {
|
||||||
|
"id": str(node_properties.get("id", "")),
|
||||||
|
"name": str(node_properties.get("name", "")),
|
||||||
|
"type": str(node_properties.get("type", "")),
|
||||||
|
}
|
||||||
|
# Remove core fields from other properties
|
||||||
|
other_properties = {k: v for k, v in node_properties.items()
|
||||||
|
if k not in ["id", "name", "type"]}
|
||||||
|
|
||||||
# Extract core fields with defaults if not present
|
core_properties["properties"] = json.dumps(other_properties, cls=JSONEncoder)
|
||||||
core_properties = {
|
|
||||||
"id": str(properties.get("id", "")),
|
|
||||||
"name": str(properties.get("name", "")),
|
|
||||||
"type": str(properties.get("type", "")),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Remove core fields from other properties
|
|
||||||
for key in core_properties:
|
|
||||||
properties.pop(key, None)
|
|
||||||
|
|
||||||
core_properties["properties"] = json.dumps(properties, cls=JSONEncoder)
|
|
||||||
|
|
||||||
# Add timestamps for new node
|
# Add timestamps for new node
|
||||||
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
|
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
|
||||||
|
|
@ -360,7 +379,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
logger.error(f"Failed to add node: {e}")
|
logger.error(f"Failed to add node: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@record_graph_changes
|
@record_graph_changes # type: ignore
|
||||||
async def add_nodes(self, nodes: List[DataPoint]) -> None:
|
async def add_nodes(self, nodes: List[DataPoint]) -> None:
|
||||||
"""
|
"""
|
||||||
Add multiple nodes to the graph in a batch operation.
|
Add multiple nodes to the graph in a batch operation.
|
||||||
|
|
@ -568,7 +587,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
)
|
)
|
||||||
return result[0][0] if result else False
|
return result[0][0] if result else False
|
||||||
|
|
||||||
async def has_edges(self, edges: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str]]:
|
async def has_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> List[Tuple[str, str, str, Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Check if multiple edges exist in a batch operation.
|
Check if multiple edges exist in a batch operation.
|
||||||
|
|
||||||
|
|
@ -599,7 +618,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
"to_id": str(to_node), # Ensure string type
|
"to_id": str(to_node), # Ensure string type
|
||||||
"relationship_name": str(edge_label), # Ensure string type
|
"relationship_name": str(edge_label), # Ensure string type
|
||||||
}
|
}
|
||||||
for from_node, to_node, edge_label in edges
|
for from_node, to_node, edge_label, _ in edges
|
||||||
]
|
]
|
||||||
|
|
||||||
# Batch check query with direct string comparison
|
# Batch check query with direct string comparison
|
||||||
|
|
@ -615,9 +634,21 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
results = await self.query(query, {"edges": edge_params})
|
results = await self.query(query, {"edges": edge_params})
|
||||||
|
|
||||||
# Convert results back to tuples and ensure string types
|
# Convert results back to tuples and ensure string types
|
||||||
existing_edges = [(str(row[0]), str(row[1]), str(row[2])) for row in results]
|
# Find the original edge properties for each existing edge
|
||||||
|
# TODO: get review on this
|
||||||
|
existing_edges = []
|
||||||
|
for row in results:
|
||||||
|
from_id, to_id, rel_name = str(row[0]), str(row[1]), str(row[2])
|
||||||
|
# Find the original properties from the input edges
|
||||||
|
original_props = {}
|
||||||
|
for orig_from, orig_to, orig_rel, orig_props in edges:
|
||||||
|
if orig_from == from_id and orig_to == to_id and orig_rel == rel_name:
|
||||||
|
original_props = orig_props
|
||||||
|
break
|
||||||
|
existing_edges.append((from_id, to_id, rel_name, original_props))
|
||||||
|
|
||||||
logger.debug(f"Found {len(existing_edges)} existing edges out of {len(edges)} checked")
|
logger.debug(f"Found {len(existing_edges)} existing edges out of {len(edges)} checked")
|
||||||
|
# TODO: otherwise, we can just return dummy properties since they are not used apparently
|
||||||
return existing_edges
|
return existing_edges
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -626,10 +657,10 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
|
|
||||||
async def add_edge(
|
async def add_edge(
|
||||||
self,
|
self,
|
||||||
from_node: str,
|
source_id: str,
|
||||||
to_node: str,
|
target_id: str,
|
||||||
relationship_name: str,
|
relationship_name: str,
|
||||||
edge_properties: Dict[str, Any] = {},
|
properties: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Add an edge between two nodes.
|
Add an edge between two nodes.
|
||||||
|
|
@ -641,23 +672,23 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
Parameters:
|
Parameters:
|
||||||
-----------
|
-----------
|
||||||
|
|
||||||
- from_node (str): The identifier of the source node from which the edge originates.
|
- source_id (str): The identifier of the source node from which the edge originates.
|
||||||
- to_node (str): The identifier of the target node to which the edge points.
|
- target_id (str): The identifier of the target node to which the edge points.
|
||||||
- relationship_name (str): The label of the edge to be created, representing the
|
- relationship_name (str): The label of the edge to be created, representing the
|
||||||
relationship name.
|
relationship name.
|
||||||
- edge_properties (Dict[str, Any]): A dictionary containing properties for the edge.
|
- properties (Optional[Dict[str, Any]]): A dictionary containing properties for the edge.
|
||||||
(default {})
|
(default None)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
query, params = self._edge_query_and_params(
|
query, params = self._edge_query_and_params(
|
||||||
from_node, to_node, relationship_name, edge_properties
|
source_id, target_id, relationship_name, properties or {}
|
||||||
)
|
)
|
||||||
await self.query(query, params)
|
await self.query(query, params)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to add edge: {e}")
|
logger.error(f"Failed to add edge: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@record_graph_changes
|
@record_graph_changes # type: ignore
|
||||||
async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
|
async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
|
||||||
"""
|
"""
|
||||||
Add multiple edges in a batch operation.
|
Add multiple edges in a batch operation.
|
||||||
|
|
@ -712,7 +743,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
logger.error(f"Failed to add edges in batch: {e}")
|
logger.error(f"Failed to add edges in batch: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_edges(self, node_id: str) -> List[Tuple[Dict[str, Any], str, Dict[str, Any]]]:
|
async def get_edges(self, node_id: str) -> List[Tuple[str, str, str, Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Get all edges connected to a node.
|
Get all edges connected to a node.
|
||||||
|
|
||||||
|
|
@ -727,9 +758,8 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
Returns:
|
Returns:
|
||||||
--------
|
--------
|
||||||
|
|
||||||
- List[Tuple[Dict[str, Any], str, Dict[str, Any]]]: A list of tuples where each
|
- List[Tuple[str, str, str, Dict[str, Any]]]: A list of tuples where each
|
||||||
tuple contains (source_node, relationship_name, target_node), with source_node and
|
tuple contains (source_id, relationship_name, target_id, edge_properties).
|
||||||
target_node as dictionaries of node properties.
|
|
||||||
"""
|
"""
|
||||||
query_str = """
|
query_str = """
|
||||||
MATCH (n:Node)-[r]-(m:Node)
|
MATCH (n:Node)-[r]-(m:Node)
|
||||||
|
|
@ -750,12 +780,14 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
results = await self.query(query_str, {"node_id": node_id})
|
results = await self.query(query_str, {"node_id": node_id})
|
||||||
edges = []
|
edges: List[Tuple[str, str, str, Dict[str, Any]]] = []
|
||||||
for row in results:
|
for row in results:
|
||||||
if row and len(row) == 3:
|
if row and len(row) == 3:
|
||||||
source_node = self._parse_node_properties(row[0])
|
source_node = self._parse_node_properties(row[0])
|
||||||
|
relationship_name = row[1]
|
||||||
target_node = self._parse_node_properties(row[2])
|
target_node = self._parse_node_properties(row[2])
|
||||||
edges.append((source_node, row[1], target_node))
|
# TODO: any edge properties we can add? Adding empty to avoid modifying query without reason
|
||||||
|
edges.append((source_node, relationship_name, target_node, {})) # type: ignore # currently each node is a dict, wihle typing expects nodes to be strings
|
||||||
return edges
|
return edges
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get edges for node {node_id}: {e}")
|
logger.error(f"Failed to get edges for node {node_id}: {e}")
|
||||||
|
|
@ -977,7 +1009,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def get_connections(
|
async def get_connections(
|
||||||
self, node_id: str
|
self, node_id: Union[str, UUID]
|
||||||
) -> List[Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]]:
|
) -> List[Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Get all nodes connected to a given node.
|
Get all nodes connected to a given node.
|
||||||
|
|
@ -1019,7 +1051,9 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
results = await self.query(query_str, {"node_id": node_id})
|
# Convert UUID to string if needed
|
||||||
|
node_id_str = str(node_id)
|
||||||
|
results = await self.query(query_str, {"node_id": node_id_str})
|
||||||
edges = []
|
edges = []
|
||||||
for row in results:
|
for row in results:
|
||||||
if row and len(row) == 3:
|
if row and len(row) == 3:
|
||||||
|
|
@ -1177,7 +1211,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
|
|
||||||
async def get_nodeset_subgraph(
|
async def get_nodeset_subgraph(
|
||||||
self, node_type: Type[Any], node_name: List[str]
|
self, node_type: Type[Any], node_name: List[str]
|
||||||
) -> Tuple[List[Tuple[str, dict]], List[Tuple[str, str, str, dict]]]:
|
) -> Tuple[List[Tuple[int, Dict[str, Any]]], List[Tuple[int, int, str, Dict[str, Any]]]]:
|
||||||
"""
|
"""
|
||||||
Get subgraph for a set of nodes based on type and names.
|
Get subgraph for a set of nodes based on type and names.
|
||||||
|
|
||||||
|
|
@ -1225,9 +1259,9 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
RETURN n.id, n.name, n.type, n.properties
|
RETURN n.id, n.name, n.type, n.properties
|
||||||
"""
|
"""
|
||||||
node_rows = await self.query(nodes_query, {"ids": all_ids})
|
node_rows = await self.query(nodes_query, {"ids": all_ids})
|
||||||
nodes: List[Tuple[str, dict]] = []
|
nodes: List[Tuple[str, Dict[str, Any]]] = []
|
||||||
for node_id, name, typ, props in node_rows:
|
for node_id, name, typ, props in node_rows:
|
||||||
data = {"id": node_id, "name": name, "type": typ}
|
data: Dict[str, Any] = {"id": node_id, "name": name, "type": typ}
|
||||||
if props:
|
if props:
|
||||||
try:
|
try:
|
||||||
data.update(json.loads(props))
|
data.update(json.loads(props))
|
||||||
|
|
@ -1241,22 +1275,22 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
RETURN a.id, b.id, r.relationship_name, r.properties
|
RETURN a.id, b.id, r.relationship_name, r.properties
|
||||||
"""
|
"""
|
||||||
edge_rows = await self.query(edges_query, {"ids": all_ids})
|
edge_rows = await self.query(edges_query, {"ids": all_ids})
|
||||||
edges: List[Tuple[str, str, str, dict]] = []
|
edges: List[Tuple[str, str, str, Dict[str, Any]]] = []
|
||||||
for from_id, to_id, rel_type, props in edge_rows:
|
for from_id, to_id, rel_type, props in edge_rows:
|
||||||
data = {}
|
edge_data: Dict[str, Any] = {}
|
||||||
if props:
|
if props:
|
||||||
try:
|
try:
|
||||||
data = json.loads(props)
|
edge_data = json.loads(props)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.warning(f"Failed to parse JSON props for edge {from_id}->{to_id}")
|
logger.warning(f"Failed to parse JSON props for edge {from_id}->{to_id}")
|
||||||
|
|
||||||
edges.append((from_id, to_id, rel_type, data))
|
edges.append((from_id, to_id, rel_type, edge_data))
|
||||||
|
|
||||||
return nodes, edges
|
return nodes, edges # type: ignore # Interface expects int IDs but string IDs are more natural for graph DBs
|
||||||
|
|
||||||
async def get_filtered_graph_data(
|
async def get_filtered_graph_data(
|
||||||
self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
|
self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
|
||||||
):
|
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Get filtered nodes and relationships based on attributes.
|
Get filtered nodes and relationships based on attributes.
|
||||||
|
|
||||||
|
|
@ -1299,7 +1333,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
)
|
)
|
||||||
return ([n[0] for n in nodes], [e[0] for e in edges])
|
return ([n[0] for n in nodes], [e[0] for e in edges])
|
||||||
|
|
||||||
async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]:
|
async def get_graph_metrics(self, include_optional: bool = False) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Get metrics on graph structure and connectivity.
|
Get metrics on graph structure and connectivity.
|
||||||
|
|
||||||
|
|
@ -1322,8 +1356,8 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
try:
|
try:
|
||||||
# Get basic graph data
|
# Get basic graph data
|
||||||
nodes, edges = await self.get_model_independent_graph_data()
|
nodes, edges = await self.get_model_independent_graph_data()
|
||||||
num_nodes = len(nodes[0]["nodes"]) if nodes else 0
|
num_nodes = len(nodes[0]["nodes"]) if nodes else 0 # type: ignore # nodes is type string?
|
||||||
num_edges = len(edges[0]["elements"]) if edges else 0
|
num_edges = len(edges[0]["elements"]) if edges else 0 # type: ignore # edges is type string?
|
||||||
|
|
||||||
# Calculate mandatory metrics
|
# Calculate mandatory metrics
|
||||||
mandatory_metrics = {
|
mandatory_metrics = {
|
||||||
|
|
@ -1531,9 +1565,16 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
|
|
||||||
# Reinitialize the database
|
# Reinitialize the database
|
||||||
self._initialize_connection()
|
self._initialize_connection()
|
||||||
|
|
||||||
|
if not self.connection:
|
||||||
|
raise RuntimeError("Failed to establish database connection")
|
||||||
|
|
||||||
# Verify the database is empty
|
# Verify the database is empty
|
||||||
result = self.connection.execute("MATCH (n:Node) RETURN COUNT(n)")
|
result = self.connection.execute("MATCH (n:Node) RETURN COUNT(n)")
|
||||||
count = result.get_next()[0] if result.has_next() else 0
|
if not isinstance(result, list):
|
||||||
|
result = [result]
|
||||||
|
for single_result in result:
|
||||||
|
count = single_result.get_next()[0] if single_result.has_next() else 0 # type: ignore
|
||||||
if count > 0:
|
if count > 0:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Database still contains {count} nodes after clearing, forcing deletion"
|
f"Database still contains {count} nodes after clearing, forcing deletion"
|
||||||
|
|
@ -1544,7 +1585,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
logger.error(f"Error during database clearing: {e}")
|
logger.error(f"Error during database clearing: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_document_subgraph(self, data_id: str):
|
async def get_document_subgraph(self, data_id: str) -> Optional[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Get all nodes that should be deleted when removing a document.
|
Get all nodes that should be deleted when removing a document.
|
||||||
|
|
||||||
|
|
@ -1616,7 +1657,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
"orphan_types": result[0][4],
|
"orphan_types": result[0][4],
|
||||||
}
|
}
|
||||||
|
|
||||||
async def get_degree_one_nodes(self, node_type: str):
|
async def get_degree_one_nodes(self, node_type: str) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Get all nodes that have only one connection.
|
Get all nodes that have only one connection.
|
||||||
|
|
||||||
|
|
@ -1769,8 +1810,8 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
ids: List[str] = []
|
ids: List[str] = []
|
||||||
|
|
||||||
if time_from and time_to:
|
if time_from and time_to:
|
||||||
time_from = date_to_int(time_from)
|
time_from_int = date_to_int(time_from)
|
||||||
time_to = date_to_int(time_to)
|
time_to_int = date_to_int(time_to)
|
||||||
|
|
||||||
cypher = f"""
|
cypher = f"""
|
||||||
MATCH (n:Node)
|
MATCH (n:Node)
|
||||||
|
|
@ -1782,13 +1823,13 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
WHEN t_str IS NULL OR t_str = '' THEN NULL
|
WHEN t_str IS NULL OR t_str = '' THEN NULL
|
||||||
ELSE CAST(t_str AS INT64)
|
ELSE CAST(t_str AS INT64)
|
||||||
END AS t
|
END AS t
|
||||||
WHERE t >= {time_from}
|
WHERE t >= {time_from_int}
|
||||||
AND t <= {time_to}
|
AND t <= {time_to_int}
|
||||||
RETURN n.id as id
|
RETURN n.id as id
|
||||||
"""
|
"""
|
||||||
|
|
||||||
elif time_from:
|
elif time_from:
|
||||||
time_from = date_to_int(time_from)
|
time_from_int = date_to_int(time_from)
|
||||||
|
|
||||||
cypher = f"""
|
cypher = f"""
|
||||||
MATCH (n:Node)
|
MATCH (n:Node)
|
||||||
|
|
@ -1800,12 +1841,12 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
WHEN t_str IS NULL OR t_str = '' THEN NULL
|
WHEN t_str IS NULL OR t_str = '' THEN NULL
|
||||||
ELSE CAST(t_str AS INT64)
|
ELSE CAST(t_str AS INT64)
|
||||||
END AS t
|
END AS t
|
||||||
WHERE t >= {time_from}
|
WHERE t >= {time_from_int}
|
||||||
RETURN n.id as id
|
RETURN n.id as id
|
||||||
"""
|
"""
|
||||||
|
|
||||||
elif time_to:
|
elif time_to:
|
||||||
time_to = date_to_int(time_to)
|
time_to_int = date_to_int(time_to)
|
||||||
|
|
||||||
cypher = f"""
|
cypher = f"""
|
||||||
MATCH (n:Node)
|
MATCH (n:Node)
|
||||||
|
|
@ -1817,12 +1858,12 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
WHEN t_str IS NULL OR t_str = '' THEN NULL
|
WHEN t_str IS NULL OR t_str = '' THEN NULL
|
||||||
ELSE CAST(t_str AS INT64)
|
ELSE CAST(t_str AS INT64)
|
||||||
END AS t
|
END AS t
|
||||||
WHERE t <= {time_to}
|
WHERE t <= {time_to_int}
|
||||||
RETURN n.id as id
|
RETURN n.id as id
|
||||||
"""
|
"""
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return ids
|
return ", ".join(f"'{uid}'" for uid in ids)
|
||||||
|
|
||||||
time_nodes = await self.query(cypher)
|
time_nodes = await self.query(cypher)
|
||||||
time_ids_list = [item[0] for item in time_nodes]
|
time_ids_list = [item[0] for item in time_nodes]
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue