undo changes for graph engines

This commit is contained in:
Daulet Amirkhanov 2025-09-07 21:01:25 +01:00
parent e68a89f737
commit e87b77fda6
3 changed files with 155 additions and 218 deletions

View file

@ -10,7 +10,7 @@ from kuzu.database import Database
from datetime import datetime, timezone
from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any, List, Union, Optional, Tuple, Type, AsyncGenerator
from typing import Dict, Any, List, Union, Optional, Tuple, Type
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.utils.run_sync import run_sync
@ -22,7 +22,7 @@ from cognee.infrastructure.databases.graph.graph_db_interface import (
from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import JSONEncoder
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
from cognee.modules.engine.models.Timestamp import Timestamp
from cognee.tasks.temporal_graph.models import Timestamp
logger = get_logger()
@ -146,21 +146,15 @@ class KuzuAdapter(GraphDBInterface):
logger.error(f"Failed to initialize Kuzu database: {e}")
raise e
def _get_connection(self) -> Connection:
"""Get the connection to the Kuzu database."""
if not self.connection:
raise RuntimeError("Kuzu database connection not initialized")
return self.connection
async def push_to_s3(self) -> None:
if os.getenv("STORAGE_BACKEND", "").lower() == "s3" and hasattr(self, "temp_graph_file"):
from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage
s3_file_storage = S3FileStorage("")
if self._get_connection():
if self.connection:
async with self.KUZU_ASYNC_LOCK:
self._get_connection().execute("CHECKPOINT;")
self.connection.execute("CHECKPOINT;")
s3_file_storage.s3.put(self.temp_graph_file, self.db_path, recursive=True)
@ -173,9 +167,7 @@ class KuzuAdapter(GraphDBInterface):
except FileNotFoundError:
logger.warning(f"Kuzu S3 storage file not found: {self.db_path}")
async def query(
self, query: str, params: Optional[Dict[str, Any]] = None
) -> List[Tuple[Any, ...]]:
async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]:
"""
Execute a Kuzu query asynchronously with automatic reconnection.
@ -198,32 +190,23 @@ class KuzuAdapter(GraphDBInterface):
loop = asyncio.get_running_loop()
params = params or {}
def blocking_query() -> List[Tuple[Any, ...]]:
def blocking_query():
try:
if not self._get_connection():
if not self.connection:
logger.debug("Reconnecting to Kuzu database...")
self._initialize_connection()
if not self._get_connection():
raise RuntimeError("Failed to establish database connection")
result = self._get_connection().execute(query, params)
result = self.connection.execute(query, params)
rows = []
if not isinstance(result, list):
result = [result]
# Handle QueryResult vs List[QueryResult] union type
for single_result in result:
while single_result.has_next():
row = single_result.get_next()
processed_rows = []
for val in row:
if hasattr(val, "as_py"):
val = val.as_py()
processed_rows.append(val)
rows.append(tuple(processed_rows))
while result.has_next():
row = result.get_next()
processed_rows = []
for val in row:
if hasattr(val, "as_py"):
val = val.as_py()
processed_rows.append(val)
rows.append(tuple(processed_rows))
return rows
except Exception as e:
logger.error(f"Query execution failed: {str(e)}")
@ -232,7 +215,7 @@ class KuzuAdapter(GraphDBInterface):
return await loop.run_in_executor(self.executor, blocking_query)
@asynccontextmanager
async def get_session(self) -> AsyncGenerator[Optional[Connection], None]:
async def get_session(self):
"""
Get a database session.
@ -241,7 +224,7 @@ class KuzuAdapter(GraphDBInterface):
and on exit performs cleanup if necessary.
"""
try:
yield self._get_connection()
yield self.connection
finally:
pass
@ -272,7 +255,7 @@ class KuzuAdapter(GraphDBInterface):
def _edge_query_and_params(
self, from_node: str, to_node: str, relationship_name: str, properties: Dict[str, Any]
) -> Tuple[str, Dict[str, Any]]:
) -> Tuple[str, dict]:
"""Build the edge creation query and parameters."""
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
query = """
@ -322,9 +305,7 @@ class KuzuAdapter(GraphDBInterface):
result = await self.query(query_str, {"id": node_id})
return result[0][0] if result else False
async def add_node(
self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None
) -> None:
async def add_node(self, node: DataPoint) -> None:
"""
Add a single node to the graph if it doesn't exist.
@ -338,32 +319,20 @@ class KuzuAdapter(GraphDBInterface):
- node (DataPoint): The node to be added, represented as a DataPoint.
"""
try:
if isinstance(node, str):
# Handle string node ID with properties parameter
node_properties = properties or {}
core_properties = {
"id": node,
"name": str(node_properties.get("name", "")),
"type": str(node_properties.get("type", "")),
}
# Use the passed properties, excluding core fields
other_properties = {
k: v for k, v in node_properties.items() if k not in ["id", "name", "type"]
}
else:
# Handle DataPoint object
node_properties = node.model_dump()
core_properties = {
"id": str(node_properties.get("id", "")),
"name": str(node_properties.get("name", "")),
"type": str(node_properties.get("type", "")),
}
# Remove core fields from other properties
other_properties = {
k: v for k, v in node_properties.items() if k not in ["id", "name", "type"]
}
properties = node.model_dump() if hasattr(node, "model_dump") else vars(node)
core_properties["properties"] = json.dumps(other_properties, cls=JSONEncoder)
# Extract core fields with defaults if not present
core_properties = {
"id": str(properties.get("id", "")),
"name": str(properties.get("name", "")),
"type": str(properties.get("type", "")),
}
# Remove core fields from other properties
for key in core_properties:
properties.pop(key, None)
core_properties["properties"] = json.dumps(properties, cls=JSONEncoder)
# Add timestamps for new node
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
@ -391,7 +360,7 @@ class KuzuAdapter(GraphDBInterface):
logger.error(f"Failed to add node: {e}")
raise
@record_graph_changes # type: ignore
@record_graph_changes
async def add_nodes(self, nodes: List[DataPoint]) -> None:
"""
Add multiple nodes to the graph in a batch operation.
@ -599,9 +568,7 @@ class KuzuAdapter(GraphDBInterface):
)
return result[0][0] if result else False
async def has_edges(
self, edges: List[Tuple[str, str, str, Dict[str, Any]]]
) -> List[Tuple[str, str, str, Dict[str, Any]]]:
async def has_edges(self, edges: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str]]:
"""
Check if multiple edges exist in a batch operation.
@ -632,7 +599,7 @@ class KuzuAdapter(GraphDBInterface):
"to_id": str(to_node), # Ensure string type
"relationship_name": str(edge_label), # Ensure string type
}
for from_node, to_node, edge_label, _ in edges
for from_node, to_node, edge_label in edges
]
# Batch check query with direct string comparison
@ -648,21 +615,9 @@ class KuzuAdapter(GraphDBInterface):
results = await self.query(query, {"edges": edge_params})
# Convert results back to tuples and ensure string types
# Find the original edge properties for each existing edge
# TODO: get review on this
existing_edges = []
for row in results:
from_id, to_id, rel_name = str(row[0]), str(row[1]), str(row[2])
# Find the original properties from the input edges
original_props = {}
for orig_from, orig_to, orig_rel, orig_props in edges:
if orig_from == from_id and orig_to == to_id and orig_rel == rel_name:
original_props = orig_props
break
existing_edges.append((from_id, to_id, rel_name, original_props))
existing_edges = [(str(row[0]), str(row[1]), str(row[2])) for row in results]
logger.debug(f"Found {len(existing_edges)} existing edges out of {len(edges)} checked")
# TODO: otherwise, we can just return dummy properties since they are not used apparently
return existing_edges
except Exception as e:
@ -671,10 +626,10 @@ class KuzuAdapter(GraphDBInterface):
async def add_edge(
self,
source_id: str,
target_id: str,
from_node: str,
to_node: str,
relationship_name: str,
properties: Optional[Dict[str, Any]] = None,
edge_properties: Dict[str, Any] = {},
) -> None:
"""
Add an edge between two nodes.
@ -686,23 +641,23 @@ class KuzuAdapter(GraphDBInterface):
Parameters:
-----------
- source_id (str): The identifier of the source node from which the edge originates.
- target_id (str): The identifier of the target node to which the edge points.
- from_node (str): The identifier of the source node from which the edge originates.
- to_node (str): The identifier of the target node to which the edge points.
- relationship_name (str): The label of the edge to be created, representing the
relationship name.
- properties (Optional[Dict[str, Any]]): A dictionary containing properties for the edge.
(default None)
- edge_properties (Dict[str, Any]): A dictionary containing properties for the edge.
(default {})
"""
try:
query, params = self._edge_query_and_params(
source_id, target_id, relationship_name, properties or {}
from_node, to_node, relationship_name, edge_properties
)
await self.query(query, params)
except Exception as e:
logger.error(f"Failed to add edge: {e}")
raise
@record_graph_changes # type: ignore
@record_graph_changes
async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
"""
Add multiple edges in a batch operation.
@ -757,7 +712,7 @@ class KuzuAdapter(GraphDBInterface):
logger.error(f"Failed to add edges in batch: {e}")
raise
async def get_edges(self, node_id: str) -> List[Tuple[str, str, str, Dict[str, Any]]]:
async def get_edges(self, node_id: str) -> List[Tuple[Dict[str, Any], str, Dict[str, Any]]]:
"""
Get all edges connected to a node.
@ -772,8 +727,9 @@ class KuzuAdapter(GraphDBInterface):
Returns:
--------
- List[Tuple[str, str, str, Dict[str, Any]]]: A list of tuples where each
tuple contains (source_id, relationship_name, target_id, edge_properties).
- List[Tuple[Dict[str, Any], str, Dict[str, Any]]]: A list of tuples where each
tuple contains (source_node, relationship_name, target_node), with source_node and
target_node as dictionaries of node properties.
"""
query_str = """
MATCH (n:Node)-[r]-(m:Node)
@ -794,14 +750,12 @@ class KuzuAdapter(GraphDBInterface):
"""
try:
results = await self.query(query_str, {"node_id": node_id})
edges: List[Tuple[str, str, str, Dict[str, Any]]] = []
edges = []
for row in results:
if row and len(row) == 3:
source_node = self._parse_node_properties(row[0])
relationship_name = row[1]
target_node = self._parse_node_properties(row[2])
# TODO: any edge properties we can add? Adding empty to avoid modifying query without reason
edges.append((source_node, relationship_name, target_node, {})) # type: ignore # currently each node is a dict, wihle typing expects nodes to be strings
edges.append((source_node, row[1], target_node))
return edges
except Exception as e:
logger.error(f"Failed to get edges for node {node_id}: {e}")
@ -1023,7 +977,7 @@ class KuzuAdapter(GraphDBInterface):
return []
async def get_connections(
self, node_id: Union[str, UUID]
self, node_id: str
) -> List[Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]]:
"""
Get all nodes connected to a given node.
@ -1065,9 +1019,7 @@ class KuzuAdapter(GraphDBInterface):
}
"""
try:
# Convert UUID to string if needed
node_id_str = str(node_id)
results = await self.query(query_str, {"node_id": node_id_str})
results = await self.query(query_str, {"node_id": node_id})
edges = []
for row in results:
if row and len(row) == 3:
@ -1225,7 +1177,7 @@ class KuzuAdapter(GraphDBInterface):
async def get_nodeset_subgraph(
self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[int, Dict[str, Any]]], List[Tuple[int, int, str, Dict[str, Any]]]]:
) -> Tuple[List[Tuple[str, dict]], List[Tuple[str, str, str, dict]]]:
"""
Get subgraph for a set of nodes based on type and names.
@ -1273,9 +1225,9 @@ class KuzuAdapter(GraphDBInterface):
RETURN n.id, n.name, n.type, n.properties
"""
node_rows = await self.query(nodes_query, {"ids": all_ids})
nodes: List[Tuple[str, Dict[str, Any]]] = []
nodes: List[Tuple[str, dict]] = []
for node_id, name, typ, props in node_rows:
data: Dict[str, Any] = {"id": node_id, "name": name, "type": typ}
data = {"id": node_id, "name": name, "type": typ}
if props:
try:
data.update(json.loads(props))
@ -1289,22 +1241,22 @@ class KuzuAdapter(GraphDBInterface):
RETURN a.id, b.id, r.relationship_name, r.properties
"""
edge_rows = await self.query(edges_query, {"ids": all_ids})
edges: List[Tuple[str, str, str, Dict[str, Any]]] = []
edges: List[Tuple[str, str, str, dict]] = []
for from_id, to_id, rel_type, props in edge_rows:
edge_data: Dict[str, Any] = {}
data = {}
if props:
try:
edge_data = json.loads(props)
data = json.loads(props)
except json.JSONDecodeError:
logger.warning(f"Failed to parse JSON props for edge {from_id}->{to_id}")
edges.append((from_id, to_id, rel_type, edge_data))
edges.append((from_id, to_id, rel_type, data))
return nodes, edges # type: ignore # Interface expects int IDs but string IDs are more natural for graph DBs
return nodes, edges
async def get_filtered_graph_data(
self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
):
"""
Get filtered nodes and relationships based on attributes.
@ -1347,7 +1299,7 @@ class KuzuAdapter(GraphDBInterface):
)
return ([n[0] for n in nodes], [e[0] for e in edges])
async def get_graph_metrics(self, include_optional: bool = False) -> Dict[str, Any]:
async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]:
"""
Get metrics on graph structure and connectivity.
@ -1370,8 +1322,8 @@ class KuzuAdapter(GraphDBInterface):
try:
# Get basic graph data
nodes, edges = await self.get_model_independent_graph_data()
num_nodes = len(nodes[0]["nodes"]) if nodes else 0 # type: ignore # nodes is type string?
num_edges = len(edges[0]["elements"]) if edges else 0 # type: ignore # edges is type string?
num_nodes = len(nodes[0]["nodes"]) if nodes else 0
num_edges = len(edges[0]["elements"]) if edges else 0
# Calculate mandatory metrics
mandatory_metrics = {
@ -1531,8 +1483,8 @@ class KuzuAdapter(GraphDBInterface):
It raises exceptions for failures occurring during deletion processes.
"""
try:
if self._get_connection():
self._get_connection().close()
if self.connection:
self.connection.close()
self.connection = None
if self.db:
self.db.close()
@ -1563,7 +1515,7 @@ class KuzuAdapter(GraphDBInterface):
occur during file deletions or initializations carefully.
"""
try:
if self._get_connection():
if self.connection:
self.connection = None
if self.db:
self.db.close()
@ -1579,30 +1531,20 @@ class KuzuAdapter(GraphDBInterface):
# Reinitialize the database
self._initialize_connection()
if not self._get_connection():
raise RuntimeError("Failed to establish database connection")
# Verify the database is empty
result = self._get_connection().execute("MATCH (n:Node) RETURN COUNT(n)")
if not isinstance(result, list):
result = [result]
for single_result in result:
_next = single_result.get_next()
if not isinstance(_next, list):
raise RuntimeError("Expected list of results")
count = _next[0] if _next else 0
result = self.connection.execute("MATCH (n:Node) RETURN COUNT(n)")
count = result.get_next()[0] if result.has_next() else 0
if count > 0:
logger.warning(
f"Database still contains {count} nodes after clearing, forcing deletion"
)
self._get_connection().execute("MATCH (n:Node) DETACH DELETE n")
self.connection.execute("MATCH (n:Node) DETACH DELETE n")
logger.info("Database cleared successfully")
except Exception as e:
logger.error(f"Error during database clearing: {e}")
raise
async def get_document_subgraph(self, data_id: str) -> Optional[Dict[str, Any]]:
async def get_document_subgraph(self, data_id: str):
"""
Get all nodes that should be deleted when removing a document.
@ -1674,7 +1616,7 @@ class KuzuAdapter(GraphDBInterface):
"orphan_types": result[0][4],
}
async def get_degree_one_nodes(self, node_type: str) -> List[Dict[str, Any]]:
async def get_degree_one_nodes(self, node_type: str):
"""
Get all nodes that have only one connection.
@ -1827,8 +1769,8 @@ class KuzuAdapter(GraphDBInterface):
ids: List[str] = []
if time_from and time_to:
time_from_int = date_to_int(time_from)
time_to_int = date_to_int(time_to)
time_from = date_to_int(time_from)
time_to = date_to_int(time_to)
cypher = f"""
MATCH (n:Node)
@ -1840,13 +1782,13 @@ class KuzuAdapter(GraphDBInterface):
WHEN t_str IS NULL OR t_str = '' THEN NULL
ELSE CAST(t_str AS INT64)
END AS t
WHERE t >= {time_from_int}
AND t <= {time_to_int}
WHERE t >= {time_from}
AND t <= {time_to}
RETURN n.id as id
"""
elif time_from:
time_from_int = date_to_int(time_from)
time_from = date_to_int(time_from)
cypher = f"""
MATCH (n:Node)
@ -1858,12 +1800,12 @@ class KuzuAdapter(GraphDBInterface):
WHEN t_str IS NULL OR t_str = '' THEN NULL
ELSE CAST(t_str AS INT64)
END AS t
WHERE t >= {time_from_int}
WHERE t >= {time_from}
RETURN n.id as id
"""
elif time_to:
time_to_int = date_to_int(time_to)
time_to = date_to_int(time_to)
cypher = f"""
MATCH (n:Node)
@ -1875,12 +1817,12 @@ class KuzuAdapter(GraphDBInterface):
WHEN t_str IS NULL OR t_str = '' THEN NULL
ELSE CAST(t_str AS INT64)
END AS t
WHERE t <= {time_to_int}
WHERE t <= {time_to}
RETURN n.id as id
"""
else:
return ", ".join(f"'{uid}'" for uid in ids)
return ids
time_nodes = await self.query(cypher)
time_ids_list = [item[0] for item in time_nodes]

View file

@ -2,7 +2,7 @@
from cognee.shared.logging_utils import get_logger
import json
from typing import Dict, Any, List, Optional, Tuple, Union
from typing import Dict, Any, List, Optional, Tuple
import aiohttp
from uuid import UUID
@ -14,7 +14,7 @@ logger = get_logger()
class UUIDEncoder(json.JSONEncoder):
"""Custom JSON encoder that handles UUID objects."""
def default(self, obj: Union[UUID, Any]) -> Any:
def default(self, obj):
if isinstance(obj, UUID):
return str(obj)
return super().default(obj)
@ -36,7 +36,7 @@ class RemoteKuzuAdapter(KuzuAdapter):
self.api_url = api_url
self.username = username
self.password = password
self._session: Optional[aiohttp.ClientSession] = None
self._session = None
self._schema_initialized = False
async def _get_session(self) -> aiohttp.ClientSession:
@ -45,13 +45,13 @@ class RemoteKuzuAdapter(KuzuAdapter):
self._session = aiohttp.ClientSession()
return self._session
async def close(self) -> None:
async def close(self):
"""Close the adapter and its session."""
if self._session and not self._session.closed:
await self._session.close()
self._session = None
async def _make_request(self, endpoint: str, data: dict[str, Any]) -> dict[str, Any]:
async def _make_request(self, endpoint: str, data: dict) -> dict:
"""Make a request to the Kuzu API."""
url = f"{self.api_url}{endpoint}"
session = await self._get_session()
@ -73,15 +73,13 @@ class RemoteKuzuAdapter(KuzuAdapter):
status=response.status,
message=error_detail,
)
return await response.json() # type: ignore
return await response.json()
except aiohttp.ClientError as e:
logger.error(f"API request failed: {str(e)}")
logger.error(f"Request data: {data}")
raise
async def query(
self, query: str, params: Optional[dict[str, Any]] = None
) -> List[Tuple[Any, ...]]:
async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]:
"""Execute a Kuzu query via the REST API."""
try:
# Initialize schema if needed
@ -128,7 +126,7 @@ class RemoteKuzuAdapter(KuzuAdapter):
logger.error(f"Failed to check schema: {e}")
return False
async def _create_schema(self) -> None:
async def _create_schema(self):
"""Create the required schema tables."""
try:
# Create Node table if it doesn't exist
@ -182,7 +180,7 @@ class RemoteKuzuAdapter(KuzuAdapter):
logger.error(f"Failed to create schema: {e}")
raise
async def _initialize_schema(self) -> None:
async def _initialize_schema(self):
"""Initialize the database schema if it doesn't exist."""
if self._schema_initialized:
return

View file

@ -8,11 +8,11 @@ from neo4j import AsyncSession
from neo4j import AsyncGraphDatabase
from neo4j.exceptions import Neo4jError
from contextlib import asynccontextmanager
from typing import Optional, Any, List, Dict, Type, Tuple, Union, AsyncGenerator
from typing import Optional, Any, List, Dict, Type, Tuple
from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
from cognee.modules.engine.models.Timestamp import Timestamp
from cognee.tasks.temporal_graph.models import Timestamp
from cognee.shared.logging_utils import get_logger, ERROR
from cognee.infrastructure.databases.graph.graph_db_interface import (
GraphDBInterface,
@ -79,14 +79,14 @@ class Neo4jAdapter(GraphDBInterface):
)
@asynccontextmanager
async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
async def get_session(self) -> AsyncSession:
"""
Get a session for database operations.
"""
async with self.driver.session(database=self.graph_database_name) as session:
yield session
@deadlock_retry() # type: ignore
@deadlock_retry()
async def query(
self,
query: str,
@ -112,7 +112,6 @@ class Neo4jAdapter(GraphDBInterface):
async with self.get_session() as session:
result = await session.run(query, parameters=params)
data = await result.data()
# TODO: why we don't return List[Dict[str, Any]]?
return data
except Neo4jError as error:
logger.error("Neo4j query error: %s", error, exc_info=True)
@ -142,29 +141,21 @@ class Neo4jAdapter(GraphDBInterface):
)
return results[0]["node_exists"] if len(results) > 0 else False
async def add_node(
self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None
) -> None:
async def add_node(self, node: DataPoint):
"""
Add a new node to the database based on the provided DataPoint object or string ID.
Add a new node to the database based on the provided DataPoint object.
Parameters:
-----------
- node (Union[DataPoint, str]): An instance of DataPoint or string ID representing the node to add.
- properties (Optional[Dict[str, Any]]): Properties to set on the node when node is a string ID.
- node (DataPoint): An instance of DataPoint representing the node to add.
Returns:
--------
The result of the query execution, typically the ID of the added node.
"""
if isinstance(node, str):
# TODO: this was not handled in the original code, check if it is correct
# Handle string node ID with properties parameter
node_id = node
node_label = "Node" # Default label for string nodes
serialized_properties = self.serialize_properties(properties or {})
else:
# Handle DataPoint object
node_id = str(node.id)
node_label = type(node).__name__
serialized_properties = self.serialize_properties(node.model_dump())
serialized_properties = self.serialize_properties(node.model_dump())
query = dedent(
f"""MERGE (node: `{BASE_LABEL}`{{id: $node_id}})
@ -176,16 +167,16 @@ class Neo4jAdapter(GraphDBInterface):
)
params = {
"node_id": node_id,
"node_label": node_label,
"node_id": str(node.id),
"node_label": type(node).__name__,
"properties": serialized_properties,
}
await self.query(query, params)
return await self.query(query, params)
@record_graph_changes # type: ignore
@override_distributed(queued_add_nodes) # type: ignore
async def add_nodes(self, nodes: List[DataPoint]) -> None:
@record_graph_changes
@override_distributed(queued_add_nodes)
async def add_nodes(self, nodes: list[DataPoint]) -> None:
"""
Add multiple nodes to the database in a single query.
@ -210,7 +201,7 @@ class Neo4jAdapter(GraphDBInterface):
RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
"""
node_params = [
nodes = [
{
"node_id": str(node.id),
"label": type(node).__name__,
@ -219,9 +210,10 @@ class Neo4jAdapter(GraphDBInterface):
for node in nodes
]
await self.query(query, dict(nodes=node_params))
results = await self.query(query, dict(nodes=nodes))
return results
async def extract_node(self, node_id: str) -> Optional[Dict[str, Any]]:
async def extract_node(self, node_id: str):
"""
Retrieve a single node from the database by its ID.
@ -239,7 +231,7 @@ class Neo4jAdapter(GraphDBInterface):
return results[0] if len(results) > 0 else None
async def extract_nodes(self, node_ids: List[str]) -> List[Dict[str, Any]]:
async def extract_nodes(self, node_ids: List[str]):
"""
Retrieve multiple nodes from the database by their IDs.
@ -264,7 +256,7 @@ class Neo4jAdapter(GraphDBInterface):
return [result["node"] for result in results]
async def delete_node(self, node_id: str) -> None:
async def delete_node(self, node_id: str):
"""
Remove a node from the database identified by its ID.
@ -281,7 +273,7 @@ class Neo4jAdapter(GraphDBInterface):
query = f"MATCH (node: `{BASE_LABEL}`{{id: $node_id}}) DETACH DELETE node"
params = {"node_id": node_id}
await self.query(query, params)
return await self.query(query, params)
async def delete_nodes(self, node_ids: list[str]) -> None:
"""
@ -304,18 +296,18 @@ class Neo4jAdapter(GraphDBInterface):
params = {"node_ids": node_ids}
await self.query(query, params)
return await self.query(query, params)
async def has_edge(self, source_id: str, target_id: str, relationship_name: str) -> bool:
async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool:
"""
Check if an edge exists between two nodes with the specified IDs and edge label.
Parameters:
-----------
- source_id (str): The ID of the node from which the edge originates.
- target_id (str): The ID of the node to which the edge points.
- relationship_name (str): The label of the edge to check for existence.
- from_node (UUID): The ID of the node from which the edge originates.
- to_node (UUID): The ID of the node to which the edge points.
- edge_label (str): The label of the edge to check for existence.
Returns:
--------
@ -323,28 +315,27 @@ class Neo4jAdapter(GraphDBInterface):
- bool: True if the edge exists, otherwise False.
"""
query = f"""
MATCH (from_node: `{BASE_LABEL}`)-[:`{relationship_name}`]->(to_node: `{BASE_LABEL}`)
WHERE from_node.id = $source_id AND to_node.id = $target_id
MATCH (from_node: `{BASE_LABEL}`)-[:`{edge_label}`]->(to_node: `{BASE_LABEL}`)
WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id
RETURN COUNT(relationship) > 0 AS edge_exists
"""
params = {
"source_id": str(source_id),
"target_id": str(target_id),
"from_node_id": str(from_node),
"to_node_id": str(to_node),
}
edge_exists = await self.query(query, params)
assert isinstance(edge_exists, bool), "Edge existence check should return a boolean"
return edge_exists
async def has_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> List[bool]:
async def has_edges(self, edges):
"""
Check if multiple edges exist based on provided edge criteria.
Parameters:
-----------
- edges: A list of edge specifications to check for existence. (source_id, target_id, relationship_name, properties)
- edges: A list of edge specifications to check for existence.
Returns:
--------
@ -378,24 +369,29 @@ class Neo4jAdapter(GraphDBInterface):
async def add_edge(
self,
source_id: str,
target_id: str,
from_node: UUID,
to_node: UUID,
relationship_name: str,
properties: Optional[Dict[str, Any]] = None,
) -> None:
edge_properties: Optional[Dict[str, Any]] = {},
):
"""
Create a new edge between two nodes with specified properties.
Parameters:
-----------
- source_id (str): The ID of the source node of the edge.
- target_id (str): The ID of the target node of the edge.
- from_node (UUID): The ID of the source node of the edge.
- to_node (UUID): The ID of the target node of the edge.
- relationship_name (str): The type/label of the edge to create.
- properties (Optional[Dict[str, Any]]): A dictionary of properties to assign
to the edge. (default None)
- edge_properties (Optional[Dict[str, Any]]): A dictionary of properties to assign
to the edge. (default {})
Returns:
--------
The result of the query execution, typically indicating the created edge.
"""
serialized_properties = self.serialize_properties(properties or {})
serialized_properties = self.serialize_properties(edge_properties)
query = dedent(
f"""\
@ -409,13 +405,13 @@ class Neo4jAdapter(GraphDBInterface):
)
params = {
"from_node": str(source_id), # Adding str as callsites may still be passing UUID
"to_node": str(target_id),
"from_node": str(from_node),
"to_node": str(to_node),
"relationship_name": relationship_name,
"properties": serialized_properties,
}
await self.query(query, params)
return await self.query(query, params)
def _flatten_edge_properties(self, properties: Dict[str, Any]) -> Dict[str, Any]:
"""
@ -449,9 +445,9 @@ class Neo4jAdapter(GraphDBInterface):
return flattened
@record_graph_changes # type: ignore
@override_distributed(queued_add_edges) # type: ignore
async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
@record_graph_changes
@override_distributed(queued_add_edges)
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
"""
Add multiple edges between nodes in a single query.
@ -482,10 +478,10 @@ class Neo4jAdapter(GraphDBInterface):
) YIELD rel
RETURN rel"""
edge_params = [
edges = [
{
"from_node": str(edge[0]), # Adding str as callsites may still be passing UUID
"to_node": str(edge[1]), # Adding str as callsites may still be passing UUID
"from_node": str(edge[0]),
"to_node": str(edge[1]),
"relationship_name": edge[2],
"properties": self._flatten_edge_properties(
{
@ -499,12 +495,13 @@ class Neo4jAdapter(GraphDBInterface):
]
try:
await self.query(query, dict(edges=edge_params))
results = await self.query(query, dict(edges=edges))
return results
except Neo4jError as error:
logger.error("Neo4j query error: %s", error, exc_info=True)
raise error
async def get_edges(self, node_id: str) -> List[Tuple[str, str, str, Dict[str, Any]]]:
async def get_edges(self, node_id: str):
"""
Retrieve all edges connected to a specified node.