cognee/cognee/infrastructure/databases/graph/kuzu/adapter.py
2025-10-20 11:56:15 +01:00

1895 lines
69 KiB
Python

"""Adapter for Kuzu graph database."""
import os
import json
import asyncio
import tempfile
from uuid import UUID, uuid5, NAMESPACE_OID
from kuzu import Connection
from kuzu.database import Database
from datetime import datetime, timezone
from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any, List, Union, Optional, Tuple, Type
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.utils.run_sync import run_sync
from cognee.infrastructure.files.storage import get_file_storage
from cognee.infrastructure.databases.graph.graph_db_interface import (
GraphDBInterface,
record_graph_changes,
)
from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import JSONEncoder
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
from cognee.tasks.temporal_graph.models import Timestamp
from cognee.infrastructure.databases.cache.config import get_cache_config
# Module-level singletons shared by every adapter instance.
logger = get_logger()
cache_config = get_cache_config()
# The cache engine is only needed (and only imported) when a shared
# cross-process Kuzu lock is configured; avoids a hard dependency otherwise.
if cache_config.shared_kuzu_lock:
    from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine
class KuzuAdapter(GraphDBInterface):
"""
Adapter for Kuzu graph database operations with improved consistency and async support.
This class facilitates operations for working with the Kuzu graph database, supporting
both direct database queries and a structured asynchronous interface for node and edge
management. It contains methods for querying, adding, and deleting nodes and edges as
well as for graph metrics and data extraction.
"""
def __init__(self, db_path: str):
    """Initialize Kuzu database connection and schema.

    Parameters:
    -----------
    - db_path (str): Path of the Kuzu database directory; may be an s3:// URI.
    """
    # Number of in-flight query() calls (only tracked in shared-lock mode).
    self.open_connections = 0
    self._is_closed = False
    self.db_path = db_path  # Path for the database directory
    self.db: Optional[Database] = None
    self.connection: Optional[Connection] = None
    if cache_config.shared_kuzu_lock:
        # Cross-process lock keyed deterministically on the db path, so every
        # adapter pointing at the same database contends on the same lock.
        self.redis_lock = get_cache_engine(
            lock_key="kuzu-lock-" + str(uuid5(NAMESPACE_OID, db_path))
        )
    else:
        # NOTE(review): self.executor only exists in non-shared-lock mode;
        # query() checks the same config flag before using it.
        self.executor = ThreadPoolExecutor()
    self._initialize_connection()
    self.KUZU_ASYNC_LOCK = asyncio.Lock()
    self._connection_change_lock = asyncio.Lock()
def _initialize_connection(self) -> None:
    """Initialize the Kuzu database connection and schema.

    Handles three setup concerns in order: installing the JSON extension
    against a throwaway database, opening (and if needed migrating or pulling
    from S3) the real database, and creating the Node/EDGE tables if absent.

    Raises:
    -------
    - Exception: re-raised after logging if any step of initialization fails.
    """

    def _install_json_extension():
        """
        Function handles installing of the json extension for the current Kuzu version.
        This has to be done with an empty graph db before connecting to an existing database otherwise
        missing json extension errors will be raised.
        """
        try:
            # Use a throwaway temp database purely to run INSTALL JSON once.
            with tempfile.NamedTemporaryFile(mode="w", delete=True) as temp_file:
                temp_graph_file = temp_file.name
            tmp_db = Database(
                temp_graph_file,
                buffer_pool_size=2048 * 1024 * 1024,  # 2048MB buffer pool
                max_db_size=4096 * 1024 * 1024,
            )
            tmp_db.init_database()
            connection = Connection(tmp_db)
            connection.execute("INSTALL JSON;")
        except Exception as e:
            logger.info(f"JSON extension already installed or not needed: {e}")

    _install_json_extension()
    try:
        if "s3://" in self.db_path:
            # S3-backed database: download it into a local temp file and open that.
            with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
                self.temp_graph_file = temp_file.name
            run_sync(self.pull_from_s3())
            self.db = Database(
                self.temp_graph_file,
                buffer_pool_size=2048 * 1024 * 1024,  # 2048MB buffer pool
                max_db_size=4096 * 1024 * 1024,
            )
        else:
            # Ensure the parent directory exists before creating the database
            db_dir = os.path.dirname(self.db_path)
            # If db_path is just a filename, db_dir will be empty string
            # In this case, use the directory containing the db_path or current directory
            if not db_dir:
                # If no directory in path, use the absolute path's directory
                abs_path = os.path.abspath(self.db_path)
                db_dir = os.path.dirname(abs_path)
            file_storage = get_file_storage(db_dir)
            run_sync(file_storage.ensure_directory_exists())
            try:
                self.db = Database(
                    self.db_path,
                    buffer_pool_size=2048 * 1024 * 1024,  # 2048MB buffer pool
                    max_db_size=4096 * 1024 * 1024,
                )
            except RuntimeError:
                # Opening can fail when the on-disk storage format is from an
                # older Kuzu release; attempt an in-place migration.
                from .kuzu_migrate import read_kuzu_storage_version
                import kuzu

                kuzu_db_version = read_kuzu_storage_version(self.db_path)
                if (
                    kuzu_db_version == "0.9.0" or kuzu_db_version == "0.8.2"
                ) and kuzu_db_version != kuzu.__version__:
                    # Try to migrate kuzu database to latest version
                    from .kuzu_migrate import kuzu_migration

                    kuzu_migration(
                        new_db=self.db_path + "_new",
                        old_db=self.db_path,
                        new_version=kuzu.__version__,
                        old_version=kuzu_db_version,
                        overwrite=True,
                    )
                    self.db = Database(
                        self.db_path,
                        buffer_pool_size=2048 * 1024 * 1024,  # 2048MB buffer pool
                        max_db_size=4096 * 1024 * 1024,
                    )
                # NOTE(review): for storage versions other than 0.9.0/0.8.2 the
                # RuntimeError is swallowed and self.db stays None — confirm intended.
        self.db.init_database()
        self.connection = Connection(self.db)
        try:
            self.connection.execute("LOAD EXTENSION JSON;")
            logger.info("Loaded JSON extension")
        except Exception as e:
            logger.info(f"JSON extension already loaded or unavailable: {e}")
        # Create node table with essential fields and timestamp
        self.connection.execute("""
            CREATE NODE TABLE IF NOT EXISTS Node(
                id STRING PRIMARY KEY,
                name STRING,
                type STRING,
                created_at TIMESTAMP,
                updated_at TIMESTAMP,
                properties STRING
            )
        """)
        # Create relationship table with timestamp
        self.connection.execute("""
            CREATE REL TABLE IF NOT EXISTS EDGE(
                FROM Node TO Node,
                relationship_name STRING,
                created_at TIMESTAMP,
                updated_at TIMESTAMP,
                properties STRING
            )
        """)
        logger.debug("Kuzu database initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize Kuzu database: {e}")
        raise e
async def push_to_s3(self) -> None:
    """Upload the local temp database file to its S3 location.

    No-op unless STORAGE_BACKEND=s3 and a temp file was created during
    initialization. Runs CHECKPOINT first so the on-disk file is consistent
    before the copy; both steps happen under the async lock.
    """
    if os.getenv("STORAGE_BACKEND", "").lower() == "s3" and hasattr(self, "temp_graph_file"):
        from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage

        s3_file_storage = S3FileStorage("")
        if self.connection:
            async with self.KUZU_ASYNC_LOCK:
                # Flush the WAL into the main database file before uploading.
                self.connection.execute("CHECKPOINT;")
                s3_file_storage.s3.put(self.temp_graph_file, self.db_path, recursive=True)
async def pull_from_s3(self) -> None:
    """Download the database files from S3 into the local temp location.

    A missing remote file is treated as a fresh database and only logged.
    """
    from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage

    storage = S3FileStorage("")
    try:
        storage.s3.get(self.db_path, self.temp_graph_file, recursive=True)
    except FileNotFoundError:
        logger.warning(f"Kuzu S3 storage file not found: {self.db_path}")
async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]:
    """
    Execute a Kuzu query asynchronously with automatic reconnection.

    This method runs a database query while managing potential reconnections. It handles
    parameters in a dictionary and processes results to return structured data. The method
    raises any exceptions encountered during query execution.

    Parameters:
    -----------
    - query (str): The Kuzu query string to be executed.
    - params (Optional[dict]): A dictionary of parameters for the query, if applicable.
      (default None)

    Returns:
    --------
    - List[Tuple]: A list of tuples representing the query results.
    """
    loop = asyncio.get_running_loop()
    params = params or {}

    def blocking_query():
        # Synchronous worker: in shared-lock mode it acquires the
        # cross-process lock, runs the query, then closes the connection
        # and releases the lock so other processes can open the database.
        lock_acquired = False
        try:
            if cache_config.shared_kuzu_lock:
                self.redis_lock.acquire_lock()
                lock_acquired = True
            if not self.connection:
                logger.info("Reconnecting to Kuzu database...")
                self._initialize_connection()
            result = self.connection.execute(query, params)
            rows = []
            while result.has_next():
                row = result.get_next()
                processed_rows = []
                for val in row:
                    # Arrow values expose as_py(); convert them to plain Python.
                    if hasattr(val, "as_py"):
                        val = val.as_py()
                    processed_rows.append(val)
                rows.append(tuple(processed_rows))
            return rows
        except Exception as e:
            logger.error(f"Query execution failed: {str(e)}")
            raise
        finally:
            # Release order matters: close the DB first so the lock is only
            # freed once this process no longer holds the files open.
            if cache_config.shared_kuzu_lock and lock_acquired:
                try:
                    self.close()
                finally:
                    self.redis_lock.release_lock()

    if cache_config.shared_kuzu_lock:
        # Track in-flight queries; the counter increment is serialized via
        # the async lock, the decrement happens in finally below.
        async with self._connection_change_lock:
            self.open_connections += 1
            logger.info(f"Open connections after open: {self.open_connections}")
        try:
            # NOTE(review): runs on the event-loop thread (blocking) in
            # shared-lock mode, unlike the executor path below — confirm intended.
            result = blocking_query()
        finally:
            self.open_connections -= 1
            logger.info(f"Open connections after close: {self.open_connections}")
        return result
    else:
        result = await loop.run_in_executor(self.executor, blocking_query)
        return result
def close(self):
    """Drop the connection and database handles and mark the adapter closed.

    Releasing the references lets Kuzu flush and free its file locks; the
    adapter can be revived later via reopen().
    """
    connection, self.connection = self.connection, None
    if connection:
        del connection
    database, self.db = self.db, None
    if database:
        del database
    self._is_closed = True
    logger.info("Kuzu database closed successfully")
def reopen(self):
    """Re-establish the database connection if the adapter was closed; no-op otherwise."""
    if not self._is_closed:
        return
    self._is_closed = False
    self._initialize_connection()
    logger.info("Kuzu database re-opened successfully")
@asynccontextmanager
async def get_session(self):
    """Yield the shared connection as an API-compatible "session".

    Kuzu has no native session concept; this context manager exists so code
    written against session-based graph backends works unchanged.
    """
    yield self.connection
def _parse_node(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""Convert a raw node result (with JSON properties) into a dictionary."""
if data.get("properties"):
try:
props = json.loads(data["properties"])
# Remove the JSON field and merge its contents
data.pop("properties")
data.update(props)
except json.JSONDecodeError:
logger.warning(f"Failed to parse properties JSON for node {data.get('id')}")
return data
def _parse_node_properties(self, data: Dict[str, Any]) -> Dict[str, Any]:
try:
if isinstance(data, dict) and "properties" in data and data["properties"]:
props = json.loads(data["properties"])
data.update(props)
del data["properties"]
return data
except json.JSONDecodeError:
logger.warning(f"Failed to parse properties JSON for node {data.get('id')}")
return data
# Helper method for building edge queries
def _edge_query_and_params(
    self, from_node: str, to_node: str, relationship_name: str, properties: Dict[str, Any]
) -> Tuple[str, dict]:
    """Produce the MERGE statement plus bound parameters for upserting one edge.

    The same UTC timestamp is used for created_at and updated_at; MERGE keeps
    the original created_at on existing edges.
    """
    timestamp_str = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
    upsert_query = """
        MATCH (from:Node), (to:Node)
        WHERE from.id = $from_id AND to.id = $to_id
        MERGE (from)-[r:EDGE {
            relationship_name: $relationship_name
        }]->(to)
        ON CREATE SET
            r.created_at = timestamp($created_at),
            r.updated_at = timestamp($updated_at),
            r.properties = $properties
        ON MATCH SET
            r.updated_at = timestamp($updated_at),
            r.properties = $properties
    """
    bound_params = {
        "from_id": from_node,
        "to_id": to_node,
        "relationship_name": relationship_name,
        "created_at": timestamp_str,
        "updated_at": timestamp_str,
        "properties": json.dumps(properties, cls=JSONEncoder),
    }
    return upsert_query, bound_params
# Node Operations
async def has_node(self, node_id: str) -> bool:
    """Return True when a node with the given identifier exists.

    Parameters:
    -----------
    - node_id (str): The identifier of the node to check.

    Returns:
    --------
    - bool: True if the node exists, False otherwise.
    """
    rows = await self.query(
        "MATCH (n:Node) WHERE n.id = $id RETURN COUNT(n) > 0", {"id": node_id}
    )
    if not rows:
        return False
    return rows[0][0]
async def add_node(self, node: DataPoint) -> None:
    """
    Add a single node to the graph if it doesn't exist.

    Core fields (id, name, type) are stored as columns; everything else is
    serialized into the JSON ``properties`` column. MERGE on id guarantees
    no duplicates; ON CREATE leaves an existing node untouched.

    Parameters:
    -----------
    - node (DataPoint): The node to be added, represented as a DataPoint.

    Raises:
    -------
    - Exception: re-raised after logging if the insert fails.
    """
    try:
        # Copy the attribute dict: vars() returns the object's live __dict__,
        # so popping core fields below would otherwise mutate the caller's node.
        properties = node.model_dump() if hasattr(node, "model_dump") else dict(vars(node))
        # Extract core fields with defaults if not present
        core_properties = {
            "id": str(properties.get("id", "")),
            "name": str(properties.get("name", "")),
            "type": str(properties.get("type", "")),
        }
        # Remove core fields from other properties
        for key in core_properties:
            properties.pop(key, None)
        core_properties["properties"] = json.dumps(properties, cls=JSONEncoder)
        # Add timestamps for new node
        now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
        fields = []
        params = {}
        for key, value in core_properties.items():
            if value is not None:
                param_name = f"param_{key}"
                fields.append(f"{key}: ${param_name}")
                params[param_name] = value
        # Add timestamp fields
        fields.extend(
            ["created_at: timestamp($created_at)", "updated_at: timestamp($updated_at)"]
        )
        params.update({"created_at": now, "updated_at": now})
        merge_query = f"""
            MERGE (n:Node {{id: $param_id}})
            ON CREATE SET n += {{{", ".join(fields)}}}
        """
        await self.query(merge_query, params)
    except Exception as e:
        logger.error(f"Failed to add node: {e}")
        raise
@record_graph_changes
async def add_nodes(self, nodes: List[DataPoint]) -> None:
    """
    Add multiple nodes to the graph in a batch operation.

    Uses a single UNWIND + MERGE statement so the whole batch is one round
    trip. Existing nodes (matched on id) are updated in place; created_at is
    only set on newly created nodes.

    Parameters:
    -----------
    - nodes (List[DataPoint]): A list of nodes to be added to the graph, each
      represented as a DataPoint.

    Raises:
    -------
    - Exception: re-raised after logging if the batch insert fails.
    """
    if not nodes:
        return
    try:
        now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
        # Prepare all nodes data
        node_params = []
        for node in nodes:
            # Copy the attribute dict: vars() returns the live __dict__, so the
            # pops below would otherwise mutate the caller's DataPoint.
            properties = node.model_dump() if hasattr(node, "model_dump") else dict(vars(node))
            core_properties = {
                "id": str(properties.get("id", "")),
                "name": str(properties.get("name", "")),
                "type": str(properties.get("type", "")),
            }
            # Remove core fields from other properties
            for key in core_properties:
                properties.pop(key, None)
            node_params.append(
                {
                    **core_properties,
                    "properties": json.dumps(properties, cls=JSONEncoder),
                    "created_at": now,
                    "updated_at": now,
                }
            )
        if node_params:
            # Batch merge nodes
            merge_query = """
                UNWIND $nodes AS node
                MERGE (n:Node {id: node.id})
                ON CREATE SET
                    n.name = node.name,
                    n.type = node.type,
                    n.properties = node.properties,
                    n.created_at = timestamp(node.created_at),
                    n.updated_at = timestamp(node.updated_at)
                ON MATCH SET
                    n.name = node.name,
                    n.type = node.type,
                    n.properties = node.properties,
                    n.updated_at = timestamp(node.updated_at)
            """
            await self.query(merge_query, {"nodes": node_params})
            logger.debug(f"Processed {len(node_params)} nodes in batch")
    except Exception as e:
        logger.error(f"Failed to add nodes in batch: {e}")
        raise
async def delete_node(self, node_id: str) -> None:
    """Delete the node with the given id together with all of its edges.

    Parameters:
    -----------
    - node_id (str): The identifier of the node to be deleted.
    """
    await self.query(
        "MATCH (n:Node) WHERE n.id = $id DETACH DELETE n", {"id": node_id}
    )
async def delete_nodes(self, node_ids: List[str]) -> None:
    """Delete every node whose id appears in ``node_ids`` (and their edges)
    in a single DETACH DELETE statement.

    Parameters:
    -----------
    - node_ids (List[str]): A list of identifiers for the nodes to be deleted.
    """
    await self.query(
        "MATCH (n:Node) WHERE n.id IN $ids DETACH DELETE n", {"ids": node_ids}
    )
async def extract_node(self, node_id: str) -> Optional[Dict[str, Any]]:
    """Fetch a single node by id and return its flattened property dict.

    Parameters:
    -----------
    - node_id (str): The identifier of the node to be extracted.

    Returns:
    --------
    - Optional[Dict[str, Any]]: The node's properties, or None when the node
      is missing or the lookup fails.
    """
    lookup_query = """
        MATCH (n:Node)
        WHERE n.id = $id
        RETURN {
            id: n.id,
            name: n.name,
            type: n.type,
            properties: n.properties
        }
    """
    try:
        rows = await self.query(lookup_query, {"id": node_id})
        if rows and rows[0]:
            return self._parse_node(rows[0][0])
        return None
    except Exception as e:
        logger.error(f"Failed to extract node {node_id}: {e}")
        return None
async def extract_nodes(self, node_ids: List[str]) -> List[Dict[str, Any]]:
    """Fetch multiple nodes by id and return their flattened property dicts.

    Parameters:
    -----------
    - node_ids (List[str]): A list of identifiers for the nodes to be extracted.

    Returns:
    --------
    - List[Dict[str, Any]]: One dict per found node; empty on error or when
      nothing matches.
    """
    lookup_query = """
        MATCH (n:Node)
        WHERE n.id IN $node_ids
        RETURN {
            id: n.id,
            name: n.name,
            type: n.type,
            properties: n.properties
        }
    """
    try:
        rows = await self.query(lookup_query, {"node_ids": node_ids})
        # Flatten each row through the shared node parser.
        return [self._parse_node(row[0]) for row in rows if row[0]]
    except Exception as e:
        logger.error(f"Failed to extract nodes: {e}")
        return []
# Edge Operations
async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool:
    """Report whether a directed EDGE with the given relationship name links
    ``from_node`` to ``to_node``.

    Parameters:
    -----------
    - from_node (str): The identifier of the source node.
    - to_node (str): The identifier of the target node.
    - edge_label (str): The relationship name carried by the edge.

    Returns:
    --------
    - bool: True if the edge exists, False otherwise.
    """
    rows = await self.query(
        """
        MATCH (from:Node)-[r:EDGE]->(to:Node)
        WHERE from.id = $from_id AND to.id = $to_id AND r.relationship_name = $edge_label
        RETURN COUNT(r) > 0
        """,
        {"from_id": from_node, "to_id": to_node, "edge_label": edge_label},
    )
    if not rows:
        return False
    return rows[0][0]
async def has_edges(self, edges: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str]]:
    """Batch-check which of the given (from, to, label) edges exist.

    Parameters:
    -----------
    - edges (List[Tuple[str, str, str]]): Candidate edges as
      (from_node, to_node, edge_label) tuples.

    Returns:
    --------
    - List[Tuple[str, str, str]]: The subset of candidates found in the graph;
      empty on error.
    """
    if not edges:
        return []
    try:
        # Coerce every field to str so comparisons in the query are uniform.
        candidates = []
        for source_id, target_id, label in edges:
            candidates.append(
                {
                    "from_id": str(source_id),
                    "to_id": str(target_id),
                    "relationship_name": str(label),
                }
            )
        batch_query = """
            UNWIND $edges AS edge
            MATCH (from:Node)-[r:EDGE]->(to:Node)
            WHERE from.id = edge.from_id
            AND to.id = edge.to_id
            AND r.relationship_name = edge.relationship_name
            RETURN from.id, to.id, r.relationship_name
        """
        rows = await self.query(batch_query, {"edges": candidates})
        found = [(str(row[0]), str(row[1]), str(row[2])) for row in rows]
        logger.debug(f"Found {len(found)} existing edges out of {len(edges)} checked")
        return found
    except Exception as e:
        logger.error(f"Failed to check edges in batch: {e}")
        return []
async def add_edge(
    self,
    from_node: str,
    to_node: str,
    relationship_name: str,
    edge_properties: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Add an edge between two nodes.

    Builds an upsert (MERGE) statement via _edge_query_and_params and executes
    it; an existing edge with the same relationship name has its properties
    and updated_at refreshed.

    Parameters:
    -----------
    - from_node (str): The identifier of the source node from which the edge originates.
    - to_node (str): The identifier of the target node to which the edge points.
    - relationship_name (str): The label of the edge to be created, representing the
      relationship name.
    - edge_properties (Optional[Dict[str, Any]]): Properties for the edge;
      None is treated as an empty dict. (The former mutable `{}` default was
      replaced with the None sentinel — behavior is unchanged for callers.)

    Raises:
    -------
    - Exception: re-raised after logging if the upsert fails.
    """
    try:
        query, params = self._edge_query_and_params(
            from_node, to_node, relationship_name, edge_properties if edge_properties is not None else {}
        )
        await self.query(query, params)
    except Exception as e:
        logger.error(f"Failed to add edge: {e}")
        raise
@record_graph_changes
async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
    """Upsert a batch of edges with one UNWIND + MERGE round trip.

    Every edge in the batch shares a single UTC timestamp; MERGE preserves
    created_at on edges that already exist.

    Parameters:
    -----------
    - edges (List[Tuple[str, str, str, Dict[str, Any]]]): Edges as
      (from_node, to_node, relationship_name, edge_properties) tuples.

    Raises:
    -------
    - Exception: re-raised after logging if the batch upsert fails.
    """
    if not edges:
        return
    try:
        timestamp_str = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
        payload = []
        for source_id, target_id, label, edge_props in edges:
            payload.append(
                {
                    "from_id": source_id,
                    "to_id": target_id,
                    "relationship_name": label,
                    "properties": json.dumps(edge_props, cls=JSONEncoder),
                    "created_at": timestamp_str,
                    "updated_at": timestamp_str,
                }
            )
        batch_query = """
            UNWIND $edges AS edge
            MATCH (from:Node), (to:Node)
            WHERE from.id = edge.from_id AND to.id = edge.to_id
            MERGE (from)-[r:EDGE {
                relationship_name: edge.relationship_name
            }]->(to)
            ON CREATE SET
                r.created_at = timestamp(edge.created_at),
                r.updated_at = timestamp(edge.updated_at),
                r.properties = edge.properties
            ON MATCH SET
                r.updated_at = timestamp(edge.updated_at),
                r.properties = edge.properties
        """
        await self.query(batch_query, {"edges": payload})
    except Exception as e:
        logger.error(f"Failed to add edges in batch: {e}")
        raise
async def get_edges(self, node_id: str) -> List[Tuple[Dict[str, Any], str, Dict[str, Any]]]:
    """Return every edge touching ``node_id`` (either direction).

    Parameters:
    -----------
    - node_id (str): The identifier of the node for which to retrieve edges.

    Returns:
    --------
    - List[Tuple[Dict[str, Any], str, Dict[str, Any]]]: Tuples of
      (source_node, relationship_name, target_node); empty on error.
    """
    edges_query = """
        MATCH (n:Node)-[r]-(m:Node)
        WHERE n.id = $node_id
        RETURN {
            id: n.id,
            name: n.name,
            type: n.type,
            properties: n.properties
        },
        r.relationship_name,
        {
            id: m.id,
            name: m.name,
            type: m.type,
            properties: m.properties
        }
    """
    try:
        rows = await self.query(edges_query, {"node_id": node_id})
        result = []
        for row in rows:
            # Each well-formed row is (source dict, relationship name, target dict).
            if not row or len(row) != 3:
                continue
            source = self._parse_node_properties(row[0])
            target = self._parse_node_properties(row[2])
            result.append((source, row[1], target))
        return result
    except Exception as e:
        logger.error(f"Failed to get edges for node {node_id}: {e}")
        return []
# Neighbor Operations
async def get_neighbors(self, node_id: str) -> List[Dict[str, Any]]:
    """American-spelling alias kept for API compatibility; delegates to
    get_neighbours.

    Parameters:
    -----------
    - node_id (str): The identifier of the node for which to find neighbors.

    Returns:
    --------
    - List[Dict[str, Any]]: Property dicts of the neighbouring nodes.
    """
    return await self.get_neighbours(node_id)
async def get_node(self, node_id: str) -> Optional[Dict[str, Any]]:
    """Return the flattened property dict of one node, or None if absent.

    Parameters:
    -----------
    - node_id (str): The identifier of the node to retrieve.

    Returns:
    --------
    - Optional[Dict[str, Any]]: The node's properties when found, else None
      (also None on lookup failure).
    """
    lookup_query = """
        MATCH (n:Node)
        WHERE n.id = $id
        RETURN {
            id: n.id,
            name: n.name,
            type: n.type,
            properties: n.properties
        }
    """
    try:
        rows = await self.query(lookup_query, {"id": node_id})
        if rows and rows[0]:
            return self._parse_node(rows[0][0])
        return None
    except Exception as e:
        logger.error(f"Failed to get node {node_id}: {e}")
        return None
async def get_nodes(self, node_ids: List[str]) -> List[Dict[str, Any]]:
    """Return flattened property dicts for every id found in ``node_ids``.

    Parameters:
    -----------
    - node_ids (List[str]): A list of identifiers for the nodes to be retrieved.

    Returns:
    --------
    - List[Dict[str, Any]]: One dict per found node; empty on error.
    """
    lookup_query = """
        MATCH (n:Node)
        WHERE n.id IN $node_ids
        RETURN {
            id: n.id,
            name: n.name,
            type: n.type,
            properties: n.properties
        }
    """
    try:
        rows = await self.query(lookup_query, {"node_ids": node_ids})
    except Exception as e:
        logger.error(f"Failed to get nodes: {e}")
        return []
    return [self._parse_node(row[0]) for row in rows if row[0]]
async def get_neighbours(self, node_id: str) -> List[Dict[str, Any]]:
    """Return the property maps of every node adjacent to ``node_id``.

    Parameters:
    -----------
    - node_id (str): The identifier of the node for which to find neighbors.

    Returns:
    --------
    - List[Dict[str, Any]]: Neighbour property dicts; empty when there are
      none or the lookup fails.
    """
    try:
        rows = await self.query(
            """
            MATCH (n)-[r]-(m)
            WHERE n.id = $id
            RETURN DISTINCT properties(m)
            """,
            {"id": node_id},
        )
        return [row[0] for row in rows] if rows else []
    except Exception as e:
        logger.error(f"Failed to get neighbours for node {node_id}: {e}")
        return []
async def get_predecessors(
    self, node_id: Union[str, UUID], edge_label: Optional[str] = None
) -> List[Dict[str, Any]]:
    """Return the property maps of all nodes with an EDGE pointing at ``node_id``.

    Parameters:
    -----------
    - node_id (Union[str, UUID]): The identifier of the target node.
    - edge_label (Optional[str]): When given, only edges carrying this
      relationship name are followed. (default None)

    Returns:
    --------
    - List[Dict[str, Any]]: Predecessor property dicts; empty on error.
    """
    try:
        if edge_label:
            rows = await self.query(
                """
                MATCH (n)<-[r:EDGE]-(m)
                WHERE n.id = $id AND r.relationship_name = $edge_label
                RETURN properties(m)
                """,
                {"id": str(node_id), "edge_label": edge_label},
            )
        else:
            rows = await self.query(
                """
                MATCH (n)<-[r:EDGE]-(m)
                WHERE n.id = $id
                RETURN properties(m)
                """,
                {"id": str(node_id)},
            )
        return [row[0] for row in rows] if rows else []
    except Exception as e:
        logger.error(f"Failed to get predecessors for node {node_id}: {e}")
        return []
async def get_successors(
    self, node_id: Union[str, UUID], edge_label: Optional[str] = None
) -> List[Dict[str, Any]]:
    """Return the property maps of all nodes reachable from ``node_id`` over
    one outgoing EDGE.

    Parameters:
    -----------
    - node_id (Union[str, UUID]): The identifier of the source node.
    - edge_label (Optional[str]): When given, only edges carrying this
      relationship name are followed. (default None)

    Returns:
    --------
    - List[Dict[str, Any]]: Successor property dicts; empty on error.
    """
    try:
        if edge_label:
            rows = await self.query(
                """
                MATCH (n)-[r:EDGE]->(m)
                WHERE n.id = $id AND r.relationship_name = $edge_label
                RETURN properties(m)
                """,
                {"id": str(node_id), "edge_label": edge_label},
            )
        else:
            rows = await self.query(
                """
                MATCH (n)-[r:EDGE]->(m)
                WHERE n.id = $id
                RETURN properties(m)
                """,
                {"id": str(node_id)},
            )
        return [row[0] for row in rows] if rows else []
    except Exception as e:
        logger.error(f"Failed to get successors for node {node_id}: {e}")
        return []
async def get_connections(
    self, node_id: str
) -> List[Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]]:
    """Return every (node, relationship, node) triple incident to ``node_id``.

    Each element of a triple has its serialized ``properties`` JSON merged
    into the dict itself.

    Parameters:
    -----------
    - node_id (str): The identifier of the node for which to retrieve connections.

    Returns:
    --------
    - List[Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]]: Tuples of
      (source_node, relationship, target_node) dicts; empty on error.
    """
    connections_query = """
        MATCH (n:Node)-[r:EDGE]-(m:Node)
        WHERE n.id = $node_id
        RETURN {
            id: n.id,
            name: n.name,
            type: n.type,
            properties: n.properties
        },
        {
            relationship_name: r.relationship_name,
            properties: r.properties
        },
        {
            id: m.id,
            name: m.name,
            type: m.type,
            properties: m.properties
        }
    """

    def _flatten(item, position):
        # Merge the JSON properties blob into the dict; non-dict values pass through.
        if isinstance(item, dict) and item.get("properties"):
            try:
                item.update(json.loads(item["properties"]))
                del item["properties"]
            except json.JSONDecodeError:
                logger.warning(
                    f"Failed to parse JSON properties for node/edge {position}"
                )
        return item

    try:
        rows = await self.query(connections_query, {"node_id": node_id})
        triples = []
        for row in rows:
            if row and len(row) == 3:
                triples.append(tuple(_flatten(item, i) for i, item in enumerate(row)))
        return triples if triples else []  # Always return a list, even if empty
    except Exception as e:
        logger.error(f"Failed to get connections for node {node_id}: {e}")
        return []  # Return empty list on error
async def remove_connection_to_predecessors_of(
    self, node_ids: List[str], edge_label: str
) -> None:
    """Delete every incoming EDGE carrying ``edge_label`` that points at any
    of the given nodes, in one statement.

    Parameters:
    -----------
    - node_ids (List[str]): Identifiers of the nodes whose incoming edges are removed.
    - edge_label (str): The relationship name of the edges to delete.
    """
    delete_query = """
        MATCH (n)<-[r:EDGE]-(m)
        WHERE n.id IN $node_ids AND r.relationship_name = $edge_label
        DELETE r
    """
    await self.query(delete_query, {"node_ids": node_ids, "edge_label": edge_label})
async def remove_connection_to_successors_of(
    self, node_ids: List[str], edge_label: str
) -> None:
    """Delete every outgoing EDGE carrying ``edge_label`` that leaves any of
    the given nodes, in one statement.

    Parameters:
    -----------
    - node_ids (List[str]): Identifiers of the nodes whose outgoing edges are removed.
    - edge_label (str): The relationship name of the edges to delete.
    """
    delete_query = """
        MATCH (n)-[r:EDGE]->(m)
        WHERE n.id IN $node_ids AND r.relationship_name = $edge_label
        DELETE r
    """
    await self.query(delete_query, {"node_ids": node_ids, "edge_label": edge_label})
# Graph-wide Operations
async def get_graph_data(
    self,
) -> Tuple[List[Tuple[str, Dict[str, Any]]], List[Tuple[str, str, str, Dict[str, Any]]]]:
    """
    Return every node and edge stored in the graph.

    Nodes come back as (id, properties) pairs and edges as
    (source_id, target_id, relationship_name, properties) tuples. The JSON blob
    kept in each node's `properties` column is flattened into its property dict.
    When the graph contains nodes but no edges, synthetic SELF loops are emitted
    so callers always receive a non-empty edge list. Query failures are logged
    and re-raised.

    Returns:
    --------

        - Tuple[List[Tuple[str, Dict[str, Any]]], List[Tuple[str, str, str, Dict[str, Any]]]]:
          (nodes, edges) as described above.
    """
    try:
        all_nodes_query = """
        MATCH (n:Node)
        RETURN n.id, {
            name: n.name,
            type: n.type,
            properties: n.properties
        }
        """
        node_rows = await self.query(all_nodes_query)

        parsed_nodes = []
        for raw_id, attrs in node_rows:
            if not raw_id:
                continue
            node_id = str(raw_id)
            packed = attrs.get("properties")
            if packed:
                try:
                    attrs.update(json.loads(packed))
                    del attrs["properties"]
                except json.JSONDecodeError:
                    logger.warning(f"Failed to parse properties JSON for node {node_id}")
            parsed_nodes.append((node_id, attrs))

        if not parsed_nodes:
            logger.warning("No nodes found in the database")
            return [], []

        all_edges_query = """
        MATCH (n:Node)-[r]->(m:Node)
        RETURN n.id, m.id, r.relationship_name, r.properties
        """
        edge_rows = await self.query(all_edges_query)

        parsed_edges = []
        for row in edge_rows:
            if not row or len(row) < 3:
                continue
            source_id = str(row[0])
            target_id = str(row[1])
            rel_name = str(row[2])
            edge_props: Dict[str, Any] = {}
            if len(row) > 3 and row[3]:
                try:
                    edge_props = json.loads(row[3])
                except (json.JSONDecodeError, TypeError):
                    logger.warning(
                        f"Failed to parse edge properties for {source_id}->{target_id}"
                    )
            parsed_edges.append((source_id, target_id, rel_name, edge_props))

        if parsed_nodes and not parsed_edges:
            # Downstream consumers expect at least one edge per node; fabricate
            # zero-distance self loops when the graph has no real edges.
            logger.debug("No edges found, creating self-referential edges for nodes")
            for node_id, _ in parsed_nodes:
                parsed_edges.append(
                    (
                        node_id,
                        node_id,
                        "SELF",
                        {
                            "relationship_name": "SELF",
                            "relationship_type": "SELF",
                            "vector_distance": 0.0,
                        },
                    )
                )

        return parsed_nodes, parsed_edges
    except Exception as e:
        logger.error(f"Failed to get graph data: {e}")
        raise
async def get_nodeset_subgraph(
    self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[str, dict]], List[Tuple[str, str, str, dict]]]:
    """
    Fetch the subgraph spanned by the named nodes of one type plus their neighbors.

    Anchor nodes are resolved by matching `node_type.__name__` against `n.type`
    and the given names against `n.name`; their direct neighbors are then pulled
    in, and all nodes/edges within that ID set are returned.

    Parameters:
    -----------

        - node_type (Type[Any]): Class whose __name__ is matched against the node
          type column.
        - node_name (List[str]): Names of the anchor nodes to look up.

    Returns:
    --------

        - Tuple[List[Tuple[str, dict]], List[Tuple[str, str, str, dict]]]: A tuple
          of (nodes, edges) restricted to the anchors and their direct neighbors.
    """
    type_label = node_type.__name__

    anchor_rows = await self.query(
        """
        UNWIND $names AS wantedName
        MATCH (n:Node)
        WHERE n.type = $label AND n.name = wantedName
        RETURN DISTINCT n.id
        """,
        {"names": node_name, "label": type_label},
    )
    anchor_ids = [row[0] for row in anchor_rows]
    if not anchor_ids:
        return [], []

    neighbor_rows = await self.query(
        """
        MATCH (n:Node)-[:EDGE]-(nbr:Node)
        WHERE n.id IN $ids
        RETURN DISTINCT nbr.id
        """,
        {"ids": anchor_ids},
    )
    subgraph_ids = list({*anchor_ids, *(row[0] for row in neighbor_rows)})

    node_rows = await self.query(
        """
        MATCH (n:Node)
        WHERE n.id IN $ids
        RETURN n.id, n.name, n.type, n.properties
        """,
        {"ids": subgraph_ids},
    )
    nodes: List[Tuple[str, dict]] = []
    for node_id, name, node_kind, raw_props in node_rows:
        record = {"id": node_id, "name": name, "type": node_kind}
        if raw_props:
            try:
                record.update(json.loads(raw_props))
            except json.JSONDecodeError:
                logger.warning(f"Failed to parse JSON props for node {node_id}")
        nodes.append((node_id, record))

    edge_rows = await self.query(
        """
        MATCH (a:Node)-[r:EDGE]-(b:Node)
        WHERE a.id IN $ids AND b.id IN $ids
        RETURN a.id, b.id, r.relationship_name, r.properties
        """,
        {"ids": subgraph_ids},
    )
    edges: List[Tuple[str, str, str, dict]] = []
    for from_id, to_id, rel_name, raw_props in edge_rows:
        payload = {}
        if raw_props:
            try:
                payload = json.loads(raw_props)
            except json.JSONDecodeError:
                logger.warning(f"Failed to parse JSON props for edge {from_id}->{to_id}")
        edges.append((from_id, to_id, rel_name, payload))

    return nodes, edges
async def get_filtered_graph_data(
    self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
):
    """
    Get nodes and relationships matching the given attribute filters.

    Each filter dict maps a node attribute name to the list of accepted values;
    all filters must match (AND semantics). Edges are returned only when both
    endpoints satisfy the filters. With an empty filter list the whole graph is
    returned (previously this produced a syntactically invalid `WHERE` clause).

    Parameters:
    -----------

        - attribute_filters (List[Dict[str, List[Union[str, int]]]]): Attribute
          name -> accepted values, applied conjunctively.

    Returns:
    --------

        A tuple containing a list of filtered node properties and a list of
        filtered edge properties.
    """
    where_clauses = []
    params: Dict[str, Any] = {}
    for i, filter_dict in enumerate(attribute_filters):
        for attr, values in filter_dict.items():
            # NOTE: attr is interpolated into the query text; callers must pass
            # trusted attribute names only. Values go through query parameters.
            param_name = f"values_{i}_{attr}"
            where_clauses.append(f"n.{attr} IN ${param_name}")
            params[param_name] = values

    if where_clauses:
        where_clause = " AND ".join(where_clauses)
        nodes_query = (
            f"MATCH (n:Node) WHERE {where_clause} RETURN n.id, {{properties: n.properties}}"
        )
        edges_query = f"""
        MATCH (n1:Node)-[r:EDGE]->(n2:Node)
        WHERE {where_clause.replace("n.", "n1.")} AND {where_clause.replace("n.", "n2.")}
        RETURN n1.id, n2.id, r.relationship_name, r.properties
        """
    else:
        # No filters supplied: emitting "WHERE " with no predicate is invalid
        # Cypher, so fall back to unfiltered queries.
        nodes_query = "MATCH (n:Node) RETURN n.id, {properties: n.properties}"
        edges_query = """
        MATCH (n1:Node)-[r:EDGE]->(n2:Node)
        RETURN n1.id, n2.id, r.relationship_name, r.properties
        """

    nodes, edges = await asyncio.gather(
        self.query(nodes_query, params), self.query(edges_query, params)
    )

    formatted_nodes = []
    for n in nodes:
        if n[0]:
            node_id = str(n[0])
            props = n[1]
            # Flatten the JSON properties blob into the node's property dict.
            if props.get("properties"):
                try:
                    additional_props = json.loads(props["properties"])
                    props.update(additional_props)
                    del props["properties"]
                except json.JSONDecodeError:
                    logger.warning(f"Failed to parse properties JSON for node {node_id}")
            formatted_nodes.append((node_id, props))

    if not formatted_nodes:
        logger.warning("No nodes found in the database")
        return [], []

    formatted_edges = []
    for e in edges:
        if e and len(e) >= 3:
            source_id = str(e[0])
            target_id = str(e[1])
            rel_type = str(e[2])
            props = {}
            if len(e) > 3 and e[3]:
                try:
                    props = json.loads(e[3])
                except (json.JSONDecodeError, TypeError):
                    logger.warning(
                        f"Failed to parse edge properties for {source_id}->{target_id}"
                    )
            formatted_edges.append((source_id, target_id, rel_type, props))

    return formatted_nodes, formatted_edges
async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]:
    """
    Compute structural metrics for the graph.

    Always reports node/edge counts, mean degree, edge density and connected
    component statistics. With include_optional=True it additionally computes
    self-loop count, diameter, average shortest path length and average
    clustering; otherwise those optional fields are reported as -1. On failure
    the error is logged and a zeroed metrics dictionary is returned.

    Parameters:
    -----------

        - include_optional: Whether to compute the expensive optional metrics.
          (default False)

    Returns:
    --------

        - Dict[str, Any]: Metric name mapped to its value.
    """
    try:
        # Count nodes and edges with dedicated COUNT queries. The previous
        # implementation unpacked the Dict returned by
        # get_model_independent_graph_data() as if it were two row lists
        # (nodes[0]["nodes"] / edges[0]["elements"]); that always raised and
        # silently forced the zeroed fallback below.
        node_count_rows = await self.query("MATCH (n:Node) RETURN COUNT(n)")
        edge_count_rows = await self.query("MATCH ()-[r:EDGE]->() RETURN COUNT(r)")
        num_nodes = node_count_rows[0][0] if node_count_rows else 0
        num_edges = edge_count_rows[0][0] if edge_count_rows else 0

        # Calculate mandatory metrics
        mandatory_metrics = {
            "num_nodes": num_nodes,
            "num_edges": num_edges,
            "mean_degree": (2 * num_edges) / num_nodes if num_nodes != 0 else None,
            "edge_density": num_edges / (num_nodes * (num_nodes - 1)) if num_nodes > 1 else 0,
            "num_connected_components": await self._get_num_connected_components(),
            "sizes_of_connected_components": await self._get_size_of_connected_components(),
        }

        if include_optional:
            # Calculate optional metrics
            shortest_path_lengths = await self._get_shortest_path_lengths()
            optional_metrics = {
                "num_selfloops": await self._count_self_loops(),
                "diameter": max(shortest_path_lengths) if shortest_path_lengths else -1,
                "avg_shortest_path_length": sum(shortest_path_lengths)
                / len(shortest_path_lengths)
                if shortest_path_lengths
                else -1,
                "avg_clustering": await self._get_avg_clustering(),
            }
        else:
            optional_metrics = {
                "num_selfloops": -1,
                "diameter": -1,
                "avg_shortest_path_length": -1,
                "avg_clustering": -1,
            }

        return {**mandatory_metrics, **optional_metrics}
    except Exception as e:
        logger.error(f"Failed to get graph metrics: {e}")
        return {
            "num_nodes": 0,
            "num_edges": 0,
            "mean_degree": 0,
            "edge_density": 0,
            "num_connected_components": 0,
            "sizes_of_connected_components": [],
            "num_selfloops": -1,
            "diameter": -1,
            "avg_shortest_path_length": -1,
            "avg_clustering": -1,
        }
async def _get_num_connected_components(self) -> int:
    """
    Return an approximate count of connected components in the graph.

    NOTE(review): the traversal is capped at 1..3 hops, so nodes more than three
    edges apart may be counted as separate components, and isolated nodes (no
    edges at all) never match the MATCH pattern — confirm this approximation is
    acceptable for the metrics consumers.
    """
    query = """
    MATCH (n:Node)
    WITH n.id AS node_id
    MATCH path = (n)-[:EDGE*1..3]-(m)
    WITH node_id, COLLECT(DISTINCT m.id) AS connected_nodes
    WITH COLLECT(DISTINCT connected_nodes + [node_id]) AS components
    RETURN SIZE(components) AS num_components
    """
    result = await self.query(query)
    # Single aggregate row: result[0][0] holds the component count; 0 when empty.
    return result[0][0] if result else 0
async def _get_size_of_connected_components(self) -> List[int]:
    """
    Return the sizes of connected components in the graph.

    NOTE(review): shares the 1..3-hop cap of _get_num_connected_components, so
    the reported "components" are neighborhoods rather than true components for
    graphs with longer paths, and isolated nodes are not represented — verify
    this is the intended approximation.
    """
    query = """
    MATCH (n:Node)
    WITH n.id AS node_id
    MATCH path = (n)-[:EDGE*1..3]-(m)
    WITH node_id, COLLECT(DISTINCT m.id) AS connected_nodes
    WITH COLLECT(DISTINCT connected_nodes + [node_id]) AS components
    UNWIND components AS component
    RETURN SIZE(component) AS component_size
    """
    result = await self.query(query)
    # One row per component; each row's first column is that component's size.
    return [row[0] for row in result] if result else []
async def _get_shortest_path_lengths(self) -> List[int]:
    """
    Return shortest-path lengths between all pairs of connected nodes.

    NOTE(review): the unbounded variable-length pattern (-[:EDGE*]-) enumerates
    every path between each pair before taking MIN, which can be extremely
    expensive on dense graphs — confirm this is only invoked via the optional
    metrics path.
    """
    query = """
    MATCH (n:Node), (m:Node)
    WHERE n.id < m.id
    MATCH path = (n)-[:EDGE*]-(m)
    RETURN MIN(LENGTH(path)) AS length
    """
    result = await self.query(query)
    # Drop NULL rows (unreachable pairs contribute no length).
    return [row[0] for row in result if row[0] is not None] if result else []
async def _count_self_loops(self) -> int:
    """Return how many EDGE relationships start and end on the same node."""
    rows = await self.query(
        """
        MATCH (n:Node)-[r:EDGE]->(n)
        RETURN COUNT(r) AS count
        """
    )
    if not rows:
        return 0
    return rows[0][0]
async def _get_avg_clustering(self) -> float:
    """
    Return the average clustering coefficient of the graph, or -1 when the
    query yields no value.

    NOTE(review): the coefficient is computed inside a single aggregate with a
    nested COUNT over triangle-closing neighbors; whether this matches the
    textbook local-clustering definition depends on how the engine scopes the
    aggregation — verify against a known graph before relying on the value.
    """
    query = """
    MATCH (n:Node)-[:EDGE]-(neighbor)
    WITH n, COUNT(DISTINCT neighbor) as degree
    MATCH (n)-[:EDGE]-(n1)-[:EDGE]-(n2)-[:EDGE]-(n)
    WHERE n1 <> n2
    RETURN AVG(CASE WHEN degree <= 1 THEN 0 ELSE COUNT(DISTINCT n2) / (degree * (degree-1)) END) AS avg_clustering
    """
    result = await self.query(query)
    # -1 signals "not computable" (empty graph or NULL aggregate).
    return result[0][0] if result and result[0][0] is not None else -1
async def get_disconnected_nodes(self) -> List[str]:
    """
    Return the IDs of nodes that have no relationships at all.

    A node qualifies when no edge in either direction touches it; an empty list
    is returned when every node is connected.

    Returns:
    --------

        - List[str]: Identifiers of standalone nodes.
    """
    rows = await self.query(
        """
        MATCH (n:Node)
        WHERE NOT EXISTS((n)-[]-())
        RETURN n.id
        """
    )
    standalone_ids = []
    for row in rows:
        standalone_ids.append(str(row[0]))
    return standalone_ids
# Graph Meta-Data Operations
async def get_model_independent_graph_data(self) -> Dict[str, List[str]]:
    """
    Summarize the graph schema without reference to any data model.

    Runs two DISTINCT queries — one for node labels, one for relationship
    names — and packages the results into a dictionary.

    Returns:
    --------

        - Dict[str, List[str]]: Keys "node_labels" and "relationship_types"
          mapped to the distinct values found in the graph.
    """
    label_rows = await self.query("MATCH (n:Node) RETURN DISTINCT labels(n)")
    relationship_rows = await self.query(
        "MATCH ()-[r:EDGE]->() RETURN DISTINCT r.relationship_name"
    )
    return {
        "node_labels": [row[0] for row in label_rows],
        "relationship_types": [row[0] for row in relationship_rows],
    }
async def delete_graph(self) -> None:
    """
    Remove the entire graph database from disk.

    Closes the live connection and database handle first, then deletes the
    backing storage: a single-file database is removed together with its
    `.lock` sidecar, while a directory-based database is wiped recursively.
    Failures are logged and re-raised.
    """
    try:
        if self.connection:
            self.connection.close()
            self.connection = None
        if self.db:
            self.db.close()
            self.db = None

        parent_dir = os.path.dirname(self.db_path)
        base_name = os.path.basename(self.db_path)
        storage = get_file_storage(parent_dir)

        if await storage.is_file(base_name):
            # Single-file layout: delete the database file and its lock sidecar.
            await storage.remove(base_name)
            await storage.remove(f"{base_name}.lock")
        else:
            # Directory layout: delete the whole tree.
            await storage.remove_all(base_name)

        logger.info(f"Deleted Kuzu database files at {self.db_path}")
    except Exception as e:
        logger.error(f"Failed to delete graph data: {e}")
        raise
async def get_document_subgraph(self, data_id: str):
    """
    Collect every node that should be deleted when removing a document.

    Resolves the document node itself, its chunks ('is_part_of'), entities that
    would become orphaned (contained only by this document's chunks), summaries
    made from those chunks, and entity types that no other document's entities
    reference. The NOT EXISTS sub-patterns guard against deleting entities or
    types still reachable from other documents.

    Parameters:
    -----------

        - data_id (str): The identifier of the document to query against.

    Returns:
    --------

        A dictionary with keys 'document', 'chunks', 'orphan_entities',
        'made_from_nodes' and 'orphan_types', or None when the query returns
        no rows.
    """
    query = """
    MATCH (doc:Node)
    WHERE (doc.type = 'TextDocument' OR doc.type = 'PdfDocument' OR doc.type = 'AudioDocument' OR doc.type = 'ImageDocument' OR doc.type = 'UnstructuredDocument') AND doc.id = $data_id
    OPTIONAL MATCH (doc)<-[e1:EDGE]-(chunk:Node)
    WHERE e1.relationship_name = 'is_part_of' AND chunk.type = 'DocumentChunk'
    OPTIONAL MATCH (chunk)-[e2:EDGE]->(entity:Node)
    WHERE e2.relationship_name = 'contains' AND entity.type = 'Entity'
    AND NOT EXISTS {
        MATCH (entity)<-[e3:EDGE]-(otherChunk:Node)-[e4:EDGE]->(otherDoc:Node)
        WHERE e3.relationship_name = 'contains'
        AND e4.relationship_name = 'is_part_of'
        AND (otherDoc.type = 'TextDocument' OR otherDoc.type = 'PdfDocument' OR otherDoc.type = 'AudioDocument' OR otherDoc.type = 'ImageDocument' OR otherDoc.type = 'UnstructuredDocument')
        AND otherDoc.id <> doc.id
    }
    OPTIONAL MATCH (chunk)<-[e5:EDGE]-(made_node:Node)
    WHERE e5.relationship_name = 'made_from' AND made_node.type = 'TextSummary'
    OPTIONAL MATCH (entity)-[e6:EDGE]->(type:Node)
    WHERE e6.relationship_name = 'is_a' AND type.type = 'EntityType'
    AND NOT EXISTS {
        MATCH (type)<-[e7:EDGE]-(otherEntity:Node)-[e8:EDGE]-(otherChunk:Node)-[e9:EDGE]-(otherDoc:Node)
        WHERE e7.relationship_name = 'is_a'
        AND e8.relationship_name = 'contains'
        AND e9.relationship_name = 'is_part_of'
        AND otherEntity.type = 'Entity'
        AND otherChunk.type = 'DocumentChunk'
        AND (otherDoc.type = 'TextDocument' OR otherDoc.type = 'PdfDocument' OR otherDoc.type = 'AudioDocument' OR otherDoc.type = 'ImageDocument' OR otherDoc.type = 'UnstructuredDocument')
        AND otherDoc.id <> doc.id
    }
    RETURN
        COLLECT(DISTINCT doc) as document,
        COLLECT(DISTINCT chunk) as chunks,
        COLLECT(DISTINCT entity) as orphan_entities,
        COLLECT(DISTINCT made_node) as made_from_nodes,
        COLLECT(DISTINCT type) as orphan_types
    """
    result = await self.query(query, {"data_id": f"{data_id}"})
    if not result or not result[0]:
        return None
    # Convert tuple to dictionary
    # Column order mirrors the RETURN clause above.
    return {
        "document": result[0][0],
        "chunks": result[0][1],
        "orphan_entities": result[0][2],
        "made_from_nodes": result[0][3],
        "orphan_types": result[0][4],
    }
async def get_degree_one_nodes(self, node_type: str):
    """
    Get all nodes of the given type that have exactly one connection.

    Parameters:
    -----------

        - node_type (str): The type of nodes to filter by; must be 'Entity' or
          'EntityType'.

    Returns:
    --------

        A list of nodes of the requested type that have exactly one connection.

    Raises:
    -------

        - ValueError: If node_type is not 'Entity' or 'EntityType'.
    """
    if not node_type or node_type not in ["Entity", "EntityType"]:
        raise ValueError("node_type must be either 'Entity' or 'EntityType'")

    # Pass node_type as a query parameter instead of interpolating it into the
    # query text, consistent with every other query in this adapter. The
    # whitelist above already prevents injection, but parameterization removes
    # the string-building entirely.
    query = """
    MATCH (n:Node)
    WHERE n.type = $node_type
    WITH n, COUNT { MATCH (n)--() } as degree
    WHERE degree = 1
    RETURN n
    """
    result = await self.query(query, {"node_type": node_type})
    return [record[0] for record in result] if result else []
async def get_last_user_interaction_ids(self, limit: int) -> List[str]:
    """
    Return the IDs of the newest CogneeUserInteraction nodes.

    Parameters:
    -----------

        - limit (int): Maximum number of interaction IDs to return.

    Returns:
    --------

        - List[str]: Interaction IDs ordered by created_at, newest first.
    """
    rows = await self.query(
        """
        MATCH (n)
        WHERE n.type = 'CogneeUserInteraction'
        RETURN n.id as id
        ORDER BY n.created_at DESC
        LIMIT $limit
        """,
        {"limit": limit},
    )
    return [row[0] for row in rows]
async def apply_feedback_weight(
    self,
    node_ids: List[str],
    weight: float,
) -> None:
    """
    Increment `feedback_weight` inside r.properties JSON for edges where
    relationship_name = 'used_graph_element_to_answer'.

    Parameters:
    -----------

        - node_ids (List[str]): Source-node IDs whose qualifying outgoing edges
          get the weight increment.
        - weight (float): Amount to add to each edge's feedback_weight.
    """
    # Step 1: fetch matching edges together with both endpoints so the
    # write-back can target each edge individually. (Previously only the
    # source ID was returned, so the SET below updated *every* qualifying edge
    # of that node and the last computed properties clobbered the rest.)
    query = """
    MATCH (n:Node)-[r:EDGE]->(m:Node)
    WHERE n.id IN $node_ids AND r.relationship_name = 'used_graph_element_to_answer'
    RETURN n.id, m.id, r.properties
    """
    results = await self.query(query, {"node_ids": node_ids})

    # Step 2: update JSON client-side (the weight lives inside a JSON blob the
    # query engine cannot increment in place).
    updates = []
    for source_id, target_id, props_json in results:
        try:
            props = json.loads(props_json) if props_json else {}
        except json.JSONDecodeError:
            props = {}
        props["feedback_weight"] = props.get("feedback_weight", 0) + weight
        updates.append((source_id, target_id, json.dumps(props)))

    # Step 3: write back, matching on both endpoints so edges that share a
    # source no longer overwrite each other. NOTE: parallel edges between the
    # same pair of nodes would still receive identical properties.
    for source_id, target_id, new_props in updates:
        update_query = """
        MATCH (n:Node)-[r:EDGE]->(m:Node)
        WHERE n.id = $source_id AND m.id = $target_id
          AND r.relationship_name = 'used_graph_element_to_answer'
        SET r.properties = $props
        """
        await self.query(
            update_query,
            {"source_id": source_id, "target_id": target_id, "props": new_props},
        )
async def collect_events(self, ids: List[str]) -> Any:
    """
    Collect all Event-type nodes reachable within 1..2 hops of the given nodes.

    Parameters:
    -----------

        - ids (List[str]): Node IDs to start from. For backward compatibility a
          pre-quoted string such as "'id1', 'id2'" (as produced by
          collect_time_ids) is also accepted.

    Returns:
    --------

        A single-element list: [{"events": [...]}], each event carrying id,
        name, description and (when present) location.
    """
    event_collection_cypher = """UNWIND [{quoted}] AS uid
    MATCH (start {{id: uid}})
    MATCH (start)-[*1..2]-(event)
    WHERE event.type = 'Event'
    WITH DISTINCT event
    RETURN collect(event) AS events;
    """
    # Normalize the ID list into a quoted, comma-separated fragment. Passing a
    # raw Python list to str.format interpolated its repr — brackets included —
    # producing `UNWIND [['a', 'b']]`, which never matched any node.
    if isinstance(ids, str):
        quoted_ids = ids
    else:
        quoted_ids = ", ".join(f"'{uid}'" for uid in ids)
    query = event_collection_cypher.format(quoted=quoted_ids)
    result = await self.query(query)

    events = []
    for node in result[0][0]:
        props = json.loads(node["properties"])
        event = {
            "id": node["id"],
            "name": node["name"],
            "description": props.get("description"),
        }
        if props.get("location"):
            event["location"] = props["location"]
        events.append(event)
    return [{"events": events}]
async def collect_time_ids(
    self,
    time_from: Optional[Timestamp] = None,
    time_to: Optional[Timestamp] = None,
) -> str:
    """
    Collect IDs of Timestamp nodes whose time_at lies between time_from and time_to.

    Parameters:
    -----------

        - time_from: Inclusive lower bound, optional.
        - time_to: Inclusive upper bound, optional.

    Returns:
    --------

        - str: Quoted IDs joined by commas ("'id1', 'id2'"), ready for use in a
          Cypher UNWIND clause. Empty string when no bound is given or nothing
          matches. (Previously the no-bounds branch returned an empty *list*,
          contradicting the declared return type.)
    """
    # Build the time predicates once instead of three near-identical queries.
    # date_to_int yields plain integers, so inlining them is injection-safe.
    bounds: List[str] = []
    if time_from is not None:
        bounds.append(f"t >= {date_to_int(time_from)}")
    if time_to is not None:
        bounds.append(f"t <= {date_to_int(time_to)}")
    if not bounds:
        return ""

    cypher = f"""
    MATCH (n:Node)
    WHERE n.type = 'Timestamp'
    // Extract time_at from the JSON string and cast to INT64
    WITH n, json_extract(n.properties, '$.time_at') AS t_str
    WITH n,
         CASE
            WHEN t_str IS NULL OR t_str = '' THEN NULL
            ELSE CAST(t_str AS INT64)
         END AS t
    WHERE {" AND ".join(bounds)}
    RETURN n.id as id
    """
    time_nodes = await self.query(cypher)
    return ", ".join(f"'{row[0]}'" for row in time_nodes)