undo changes for graph engines

This commit is contained in:
Daulet Amirkhanov 2025-09-07 21:01:25 +01:00
parent e68a89f737
commit e87b77fda6
3 changed files with 155 additions and 218 deletions

View file

@ -10,7 +10,7 @@ from kuzu.database import Database
from datetime import datetime, timezone from datetime import datetime, timezone
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any, List, Union, Optional, Tuple, Type, AsyncGenerator from typing import Dict, Any, List, Union, Optional, Tuple, Type
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.utils.run_sync import run_sync from cognee.infrastructure.utils.run_sync import run_sync
@ -22,7 +22,7 @@ from cognee.infrastructure.databases.graph.graph_db_interface import (
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import JSONEncoder from cognee.modules.storage.utils import JSONEncoder
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
from cognee.modules.engine.models.Timestamp import Timestamp from cognee.tasks.temporal_graph.models import Timestamp
logger = get_logger() logger = get_logger()
@ -146,21 +146,15 @@ class KuzuAdapter(GraphDBInterface):
logger.error(f"Failed to initialize Kuzu database: {e}") logger.error(f"Failed to initialize Kuzu database: {e}")
raise e raise e
def _get_connection(self) -> Connection:
"""Get the connection to the Kuzu database."""
if not self.connection:
raise RuntimeError("Kuzu database connection not initialized")
return self.connection
async def push_to_s3(self) -> None: async def push_to_s3(self) -> None:
if os.getenv("STORAGE_BACKEND", "").lower() == "s3" and hasattr(self, "temp_graph_file"): if os.getenv("STORAGE_BACKEND", "").lower() == "s3" and hasattr(self, "temp_graph_file"):
from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage
s3_file_storage = S3FileStorage("") s3_file_storage = S3FileStorage("")
if self._get_connection(): if self.connection:
async with self.KUZU_ASYNC_LOCK: async with self.KUZU_ASYNC_LOCK:
self._get_connection().execute("CHECKPOINT;") self.connection.execute("CHECKPOINT;")
s3_file_storage.s3.put(self.temp_graph_file, self.db_path, recursive=True) s3_file_storage.s3.put(self.temp_graph_file, self.db_path, recursive=True)
@ -173,9 +167,7 @@ class KuzuAdapter(GraphDBInterface):
except FileNotFoundError: except FileNotFoundError:
logger.warning(f"Kuzu S3 storage file not found: {self.db_path}") logger.warning(f"Kuzu S3 storage file not found: {self.db_path}")
async def query( async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]:
self, query: str, params: Optional[Dict[str, Any]] = None
) -> List[Tuple[Any, ...]]:
""" """
Execute a Kuzu query asynchronously with automatic reconnection. Execute a Kuzu query asynchronously with automatic reconnection.
@ -198,32 +190,23 @@ class KuzuAdapter(GraphDBInterface):
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
params = params or {} params = params or {}
def blocking_query() -> List[Tuple[Any, ...]]: def blocking_query():
try: try:
if not self._get_connection(): if not self.connection:
logger.debug("Reconnecting to Kuzu database...") logger.debug("Reconnecting to Kuzu database...")
self._initialize_connection() self._initialize_connection()
if not self._get_connection(): result = self.connection.execute(query, params)
raise RuntimeError("Failed to establish database connection")
result = self._get_connection().execute(query, params)
rows = [] rows = []
if not isinstance(result, list): while result.has_next():
result = [result] row = result.get_next()
processed_rows = []
# Handle QueryResult vs List[QueryResult] union type for val in row:
for single_result in result: if hasattr(val, "as_py"):
while single_result.has_next(): val = val.as_py()
row = single_result.get_next() processed_rows.append(val)
processed_rows = [] rows.append(tuple(processed_rows))
for val in row:
if hasattr(val, "as_py"):
val = val.as_py()
processed_rows.append(val)
rows.append(tuple(processed_rows))
return rows return rows
except Exception as e: except Exception as e:
logger.error(f"Query execution failed: {str(e)}") logger.error(f"Query execution failed: {str(e)}")
@ -232,7 +215,7 @@ class KuzuAdapter(GraphDBInterface):
return await loop.run_in_executor(self.executor, blocking_query) return await loop.run_in_executor(self.executor, blocking_query)
@asynccontextmanager @asynccontextmanager
async def get_session(self) -> AsyncGenerator[Optional[Connection], None]: async def get_session(self):
""" """
Get a database session. Get a database session.
@ -241,7 +224,7 @@ class KuzuAdapter(GraphDBInterface):
and on exit performs cleanup if necessary. and on exit performs cleanup if necessary.
""" """
try: try:
yield self._get_connection() yield self.connection
finally: finally:
pass pass
@ -272,7 +255,7 @@ class KuzuAdapter(GraphDBInterface):
def _edge_query_and_params( def _edge_query_and_params(
self, from_node: str, to_node: str, relationship_name: str, properties: Dict[str, Any] self, from_node: str, to_node: str, relationship_name: str, properties: Dict[str, Any]
) -> Tuple[str, Dict[str, Any]]: ) -> Tuple[str, dict]:
"""Build the edge creation query and parameters.""" """Build the edge creation query and parameters."""
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f") now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
query = """ query = """
@ -322,9 +305,7 @@ class KuzuAdapter(GraphDBInterface):
result = await self.query(query_str, {"id": node_id}) result = await self.query(query_str, {"id": node_id})
return result[0][0] if result else False return result[0][0] if result else False
async def add_node( async def add_node(self, node: DataPoint) -> None:
self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None
) -> None:
""" """
Add a single node to the graph if it doesn't exist. Add a single node to the graph if it doesn't exist.
@ -338,32 +319,20 @@ class KuzuAdapter(GraphDBInterface):
- node (DataPoint): The node to be added, represented as a DataPoint. - node (DataPoint): The node to be added, represented as a DataPoint.
""" """
try: try:
if isinstance(node, str): properties = node.model_dump() if hasattr(node, "model_dump") else vars(node)
# Handle string node ID with properties parameter
node_properties = properties or {}
core_properties = {
"id": node,
"name": str(node_properties.get("name", "")),
"type": str(node_properties.get("type", "")),
}
# Use the passed properties, excluding core fields
other_properties = {
k: v for k, v in node_properties.items() if k not in ["id", "name", "type"]
}
else:
# Handle DataPoint object
node_properties = node.model_dump()
core_properties = {
"id": str(node_properties.get("id", "")),
"name": str(node_properties.get("name", "")),
"type": str(node_properties.get("type", "")),
}
# Remove core fields from other properties
other_properties = {
k: v for k, v in node_properties.items() if k not in ["id", "name", "type"]
}
core_properties["properties"] = json.dumps(other_properties, cls=JSONEncoder) # Extract core fields with defaults if not present
core_properties = {
"id": str(properties.get("id", "")),
"name": str(properties.get("name", "")),
"type": str(properties.get("type", "")),
}
# Remove core fields from other properties
for key in core_properties:
properties.pop(key, None)
core_properties["properties"] = json.dumps(properties, cls=JSONEncoder)
# Add timestamps for new node # Add timestamps for new node
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f") now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
@ -391,7 +360,7 @@ class KuzuAdapter(GraphDBInterface):
logger.error(f"Failed to add node: {e}") logger.error(f"Failed to add node: {e}")
raise raise
@record_graph_changes # type: ignore @record_graph_changes
async def add_nodes(self, nodes: List[DataPoint]) -> None: async def add_nodes(self, nodes: List[DataPoint]) -> None:
""" """
Add multiple nodes to the graph in a batch operation. Add multiple nodes to the graph in a batch operation.
@ -599,9 +568,7 @@ class KuzuAdapter(GraphDBInterface):
) )
return result[0][0] if result else False return result[0][0] if result else False
async def has_edges( async def has_edges(self, edges: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str]]:
self, edges: List[Tuple[str, str, str, Dict[str, Any]]]
) -> List[Tuple[str, str, str, Dict[str, Any]]]:
""" """
Check if multiple edges exist in a batch operation. Check if multiple edges exist in a batch operation.
@ -632,7 +599,7 @@ class KuzuAdapter(GraphDBInterface):
"to_id": str(to_node), # Ensure string type "to_id": str(to_node), # Ensure string type
"relationship_name": str(edge_label), # Ensure string type "relationship_name": str(edge_label), # Ensure string type
} }
for from_node, to_node, edge_label, _ in edges for from_node, to_node, edge_label in edges
] ]
# Batch check query with direct string comparison # Batch check query with direct string comparison
@ -648,21 +615,9 @@ class KuzuAdapter(GraphDBInterface):
results = await self.query(query, {"edges": edge_params}) results = await self.query(query, {"edges": edge_params})
# Convert results back to tuples and ensure string types # Convert results back to tuples and ensure string types
# Find the original edge properties for each existing edge existing_edges = [(str(row[0]), str(row[1]), str(row[2])) for row in results]
# TODO: get review on this
existing_edges = []
for row in results:
from_id, to_id, rel_name = str(row[0]), str(row[1]), str(row[2])
# Find the original properties from the input edges
original_props = {}
for orig_from, orig_to, orig_rel, orig_props in edges:
if orig_from == from_id and orig_to == to_id and orig_rel == rel_name:
original_props = orig_props
break
existing_edges.append((from_id, to_id, rel_name, original_props))
logger.debug(f"Found {len(existing_edges)} existing edges out of {len(edges)} checked") logger.debug(f"Found {len(existing_edges)} existing edges out of {len(edges)} checked")
# TODO: otherwise, we can just return dummy properties since they are not used apparently
return existing_edges return existing_edges
except Exception as e: except Exception as e:
@ -671,10 +626,10 @@ class KuzuAdapter(GraphDBInterface):
async def add_edge( async def add_edge(
self, self,
source_id: str, from_node: str,
target_id: str, to_node: str,
relationship_name: str, relationship_name: str,
properties: Optional[Dict[str, Any]] = None, edge_properties: Dict[str, Any] = {},
) -> None: ) -> None:
""" """
Add an edge between two nodes. Add an edge between two nodes.
@ -686,23 +641,23 @@ class KuzuAdapter(GraphDBInterface):
Parameters: Parameters:
----------- -----------
- source_id (str): The identifier of the source node from which the edge originates. - from_node (str): The identifier of the source node from which the edge originates.
- target_id (str): The identifier of the target node to which the edge points. - to_node (str): The identifier of the target node to which the edge points.
- relationship_name (str): The label of the edge to be created, representing the - relationship_name (str): The label of the edge to be created, representing the
relationship name. relationship name.
- properties (Optional[Dict[str, Any]]): A dictionary containing properties for the edge. - edge_properties (Dict[str, Any]): A dictionary containing properties for the edge.
(default None) (default {})
""" """
try: try:
query, params = self._edge_query_and_params( query, params = self._edge_query_and_params(
source_id, target_id, relationship_name, properties or {} from_node, to_node, relationship_name, edge_properties
) )
await self.query(query, params) await self.query(query, params)
except Exception as e: except Exception as e:
logger.error(f"Failed to add edge: {e}") logger.error(f"Failed to add edge: {e}")
raise raise
@record_graph_changes # type: ignore @record_graph_changes
async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None: async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
""" """
Add multiple edges in a batch operation. Add multiple edges in a batch operation.
@ -757,7 +712,7 @@ class KuzuAdapter(GraphDBInterface):
logger.error(f"Failed to add edges in batch: {e}") logger.error(f"Failed to add edges in batch: {e}")
raise raise
async def get_edges(self, node_id: str) -> List[Tuple[str, str, str, Dict[str, Any]]]: async def get_edges(self, node_id: str) -> List[Tuple[Dict[str, Any], str, Dict[str, Any]]]:
""" """
Get all edges connected to a node. Get all edges connected to a node.
@ -772,8 +727,9 @@ class KuzuAdapter(GraphDBInterface):
Returns: Returns:
-------- --------
- List[Tuple[str, str, str, Dict[str, Any]]]: A list of tuples where each - List[Tuple[Dict[str, Any], str, Dict[str, Any]]]: A list of tuples where each
tuple contains (source_id, relationship_name, target_id, edge_properties). tuple contains (source_node, relationship_name, target_node), with source_node and
target_node as dictionaries of node properties.
""" """
query_str = """ query_str = """
MATCH (n:Node)-[r]-(m:Node) MATCH (n:Node)-[r]-(m:Node)
@ -794,14 +750,12 @@ class KuzuAdapter(GraphDBInterface):
""" """
try: try:
results = await self.query(query_str, {"node_id": node_id}) results = await self.query(query_str, {"node_id": node_id})
edges: List[Tuple[str, str, str, Dict[str, Any]]] = [] edges = []
for row in results: for row in results:
if row and len(row) == 3: if row and len(row) == 3:
source_node = self._parse_node_properties(row[0]) source_node = self._parse_node_properties(row[0])
relationship_name = row[1]
target_node = self._parse_node_properties(row[2]) target_node = self._parse_node_properties(row[2])
# TODO: any edge properties we can add? Adding empty to avoid modifying query without reason edges.append((source_node, row[1], target_node))
edges.append((source_node, relationship_name, target_node, {})) # type: ignore # currently each node is a dict, wihle typing expects nodes to be strings
return edges return edges
except Exception as e: except Exception as e:
logger.error(f"Failed to get edges for node {node_id}: {e}") logger.error(f"Failed to get edges for node {node_id}: {e}")
@ -1023,7 +977,7 @@ class KuzuAdapter(GraphDBInterface):
return [] return []
async def get_connections( async def get_connections(
self, node_id: Union[str, UUID] self, node_id: str
) -> List[Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]]: ) -> List[Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]]:
""" """
Get all nodes connected to a given node. Get all nodes connected to a given node.
@ -1065,9 +1019,7 @@ class KuzuAdapter(GraphDBInterface):
} }
""" """
try: try:
# Convert UUID to string if needed results = await self.query(query_str, {"node_id": node_id})
node_id_str = str(node_id)
results = await self.query(query_str, {"node_id": node_id_str})
edges = [] edges = []
for row in results: for row in results:
if row and len(row) == 3: if row and len(row) == 3:
@ -1225,7 +1177,7 @@ class KuzuAdapter(GraphDBInterface):
async def get_nodeset_subgraph( async def get_nodeset_subgraph(
self, node_type: Type[Any], node_name: List[str] self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[int, Dict[str, Any]]], List[Tuple[int, int, str, Dict[str, Any]]]]: ) -> Tuple[List[Tuple[str, dict]], List[Tuple[str, str, str, dict]]]:
""" """
Get subgraph for a set of nodes based on type and names. Get subgraph for a set of nodes based on type and names.
@ -1273,9 +1225,9 @@ class KuzuAdapter(GraphDBInterface):
RETURN n.id, n.name, n.type, n.properties RETURN n.id, n.name, n.type, n.properties
""" """
node_rows = await self.query(nodes_query, {"ids": all_ids}) node_rows = await self.query(nodes_query, {"ids": all_ids})
nodes: List[Tuple[str, Dict[str, Any]]] = [] nodes: List[Tuple[str, dict]] = []
for node_id, name, typ, props in node_rows: for node_id, name, typ, props in node_rows:
data: Dict[str, Any] = {"id": node_id, "name": name, "type": typ} data = {"id": node_id, "name": name, "type": typ}
if props: if props:
try: try:
data.update(json.loads(props)) data.update(json.loads(props))
@ -1289,22 +1241,22 @@ class KuzuAdapter(GraphDBInterface):
RETURN a.id, b.id, r.relationship_name, r.properties RETURN a.id, b.id, r.relationship_name, r.properties
""" """
edge_rows = await self.query(edges_query, {"ids": all_ids}) edge_rows = await self.query(edges_query, {"ids": all_ids})
edges: List[Tuple[str, str, str, Dict[str, Any]]] = [] edges: List[Tuple[str, str, str, dict]] = []
for from_id, to_id, rel_type, props in edge_rows: for from_id, to_id, rel_type, props in edge_rows:
edge_data: Dict[str, Any] = {} data = {}
if props: if props:
try: try:
edge_data = json.loads(props) data = json.loads(props)
except json.JSONDecodeError: except json.JSONDecodeError:
logger.warning(f"Failed to parse JSON props for edge {from_id}->{to_id}") logger.warning(f"Failed to parse JSON props for edge {from_id}->{to_id}")
edges.append((from_id, to_id, rel_type, edge_data)) edges.append((from_id, to_id, rel_type, data))
return nodes, edges # type: ignore # Interface expects int IDs but string IDs are more natural for graph DBs return nodes, edges
async def get_filtered_graph_data( async def get_filtered_graph_data(
self, attribute_filters: List[Dict[str, List[Union[str, int]]]] self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: ):
""" """
Get filtered nodes and relationships based on attributes. Get filtered nodes and relationships based on attributes.
@ -1347,7 +1299,7 @@ class KuzuAdapter(GraphDBInterface):
) )
return ([n[0] for n in nodes], [e[0] for e in edges]) return ([n[0] for n in nodes], [e[0] for e in edges])
async def get_graph_metrics(self, include_optional: bool = False) -> Dict[str, Any]: async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]:
""" """
Get metrics on graph structure and connectivity. Get metrics on graph structure and connectivity.
@ -1370,8 +1322,8 @@ class KuzuAdapter(GraphDBInterface):
try: try:
# Get basic graph data # Get basic graph data
nodes, edges = await self.get_model_independent_graph_data() nodes, edges = await self.get_model_independent_graph_data()
num_nodes = len(nodes[0]["nodes"]) if nodes else 0 # type: ignore # nodes is type string? num_nodes = len(nodes[0]["nodes"]) if nodes else 0
num_edges = len(edges[0]["elements"]) if edges else 0 # type: ignore # edges is type string? num_edges = len(edges[0]["elements"]) if edges else 0
# Calculate mandatory metrics # Calculate mandatory metrics
mandatory_metrics = { mandatory_metrics = {
@ -1531,8 +1483,8 @@ class KuzuAdapter(GraphDBInterface):
It raises exceptions for failures occurring during deletion processes. It raises exceptions for failures occurring during deletion processes.
""" """
try: try:
if self._get_connection(): if self.connection:
self._get_connection().close() self.connection.close()
self.connection = None self.connection = None
if self.db: if self.db:
self.db.close() self.db.close()
@ -1563,7 +1515,7 @@ class KuzuAdapter(GraphDBInterface):
occur during file deletions or initializations carefully. occur during file deletions or initializations carefully.
""" """
try: try:
if self._get_connection(): if self.connection:
self.connection = None self.connection = None
if self.db: if self.db:
self.db.close() self.db.close()
@ -1579,30 +1531,20 @@ class KuzuAdapter(GraphDBInterface):
# Reinitialize the database # Reinitialize the database
self._initialize_connection() self._initialize_connection()
if not self._get_connection():
raise RuntimeError("Failed to establish database connection")
# Verify the database is empty # Verify the database is empty
result = self._get_connection().execute("MATCH (n:Node) RETURN COUNT(n)") result = self.connection.execute("MATCH (n:Node) RETURN COUNT(n)")
if not isinstance(result, list): count = result.get_next()[0] if result.has_next() else 0
result = [result]
for single_result in result:
_next = single_result.get_next()
if not isinstance(_next, list):
raise RuntimeError("Expected list of results")
count = _next[0] if _next else 0
if count > 0: if count > 0:
logger.warning( logger.warning(
f"Database still contains {count} nodes after clearing, forcing deletion" f"Database still contains {count} nodes after clearing, forcing deletion"
) )
self._get_connection().execute("MATCH (n:Node) DETACH DELETE n") self.connection.execute("MATCH (n:Node) DETACH DELETE n")
logger.info("Database cleared successfully") logger.info("Database cleared successfully")
except Exception as e: except Exception as e:
logger.error(f"Error during database clearing: {e}") logger.error(f"Error during database clearing: {e}")
raise raise
async def get_document_subgraph(self, data_id: str) -> Optional[Dict[str, Any]]: async def get_document_subgraph(self, data_id: str):
""" """
Get all nodes that should be deleted when removing a document. Get all nodes that should be deleted when removing a document.
@ -1674,7 +1616,7 @@ class KuzuAdapter(GraphDBInterface):
"orphan_types": result[0][4], "orphan_types": result[0][4],
} }
async def get_degree_one_nodes(self, node_type: str) -> List[Dict[str, Any]]: async def get_degree_one_nodes(self, node_type: str):
""" """
Get all nodes that have only one connection. Get all nodes that have only one connection.
@ -1827,8 +1769,8 @@ class KuzuAdapter(GraphDBInterface):
ids: List[str] = [] ids: List[str] = []
if time_from and time_to: if time_from and time_to:
time_from_int = date_to_int(time_from) time_from = date_to_int(time_from)
time_to_int = date_to_int(time_to) time_to = date_to_int(time_to)
cypher = f""" cypher = f"""
MATCH (n:Node) MATCH (n:Node)
@ -1840,13 +1782,13 @@ class KuzuAdapter(GraphDBInterface):
WHEN t_str IS NULL OR t_str = '' THEN NULL WHEN t_str IS NULL OR t_str = '' THEN NULL
ELSE CAST(t_str AS INT64) ELSE CAST(t_str AS INT64)
END AS t END AS t
WHERE t >= {time_from_int} WHERE t >= {time_from}
AND t <= {time_to_int} AND t <= {time_to}
RETURN n.id as id RETURN n.id as id
""" """
elif time_from: elif time_from:
time_from_int = date_to_int(time_from) time_from = date_to_int(time_from)
cypher = f""" cypher = f"""
MATCH (n:Node) MATCH (n:Node)
@ -1858,12 +1800,12 @@ class KuzuAdapter(GraphDBInterface):
WHEN t_str IS NULL OR t_str = '' THEN NULL WHEN t_str IS NULL OR t_str = '' THEN NULL
ELSE CAST(t_str AS INT64) ELSE CAST(t_str AS INT64)
END AS t END AS t
WHERE t >= {time_from_int} WHERE t >= {time_from}
RETURN n.id as id RETURN n.id as id
""" """
elif time_to: elif time_to:
time_to_int = date_to_int(time_to) time_to = date_to_int(time_to)
cypher = f""" cypher = f"""
MATCH (n:Node) MATCH (n:Node)
@ -1875,12 +1817,12 @@ class KuzuAdapter(GraphDBInterface):
WHEN t_str IS NULL OR t_str = '' THEN NULL WHEN t_str IS NULL OR t_str = '' THEN NULL
ELSE CAST(t_str AS INT64) ELSE CAST(t_str AS INT64)
END AS t END AS t
WHERE t <= {time_to_int} WHERE t <= {time_to}
RETURN n.id as id RETURN n.id as id
""" """
else: else:
return ", ".join(f"'{uid}'" for uid in ids) return ids
time_nodes = await self.query(cypher) time_nodes = await self.query(cypher)
time_ids_list = [item[0] for item in time_nodes] time_ids_list = [item[0] for item in time_nodes]

View file

@ -2,7 +2,7 @@
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
import json import json
from typing import Dict, Any, List, Optional, Tuple, Union from typing import Dict, Any, List, Optional, Tuple
import aiohttp import aiohttp
from uuid import UUID from uuid import UUID
@ -14,7 +14,7 @@ logger = get_logger()
class UUIDEncoder(json.JSONEncoder): class UUIDEncoder(json.JSONEncoder):
"""Custom JSON encoder that handles UUID objects.""" """Custom JSON encoder that handles UUID objects."""
def default(self, obj: Union[UUID, Any]) -> Any: def default(self, obj):
if isinstance(obj, UUID): if isinstance(obj, UUID):
return str(obj) return str(obj)
return super().default(obj) return super().default(obj)
@ -36,7 +36,7 @@ class RemoteKuzuAdapter(KuzuAdapter):
self.api_url = api_url self.api_url = api_url
self.username = username self.username = username
self.password = password self.password = password
self._session: Optional[aiohttp.ClientSession] = None self._session = None
self._schema_initialized = False self._schema_initialized = False
async def _get_session(self) -> aiohttp.ClientSession: async def _get_session(self) -> aiohttp.ClientSession:
@ -45,13 +45,13 @@ class RemoteKuzuAdapter(KuzuAdapter):
self._session = aiohttp.ClientSession() self._session = aiohttp.ClientSession()
return self._session return self._session
async def close(self) -> None: async def close(self):
"""Close the adapter and its session.""" """Close the adapter and its session."""
if self._session and not self._session.closed: if self._session and not self._session.closed:
await self._session.close() await self._session.close()
self._session = None self._session = None
async def _make_request(self, endpoint: str, data: dict[str, Any]) -> dict[str, Any]: async def _make_request(self, endpoint: str, data: dict) -> dict:
"""Make a request to the Kuzu API.""" """Make a request to the Kuzu API."""
url = f"{self.api_url}{endpoint}" url = f"{self.api_url}{endpoint}"
session = await self._get_session() session = await self._get_session()
@ -73,15 +73,13 @@ class RemoteKuzuAdapter(KuzuAdapter):
status=response.status, status=response.status,
message=error_detail, message=error_detail,
) )
return await response.json() # type: ignore return await response.json()
except aiohttp.ClientError as e: except aiohttp.ClientError as e:
logger.error(f"API request failed: {str(e)}") logger.error(f"API request failed: {str(e)}")
logger.error(f"Request data: {data}") logger.error(f"Request data: {data}")
raise raise
async def query( async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]:
self, query: str, params: Optional[dict[str, Any]] = None
) -> List[Tuple[Any, ...]]:
"""Execute a Kuzu query via the REST API.""" """Execute a Kuzu query via the REST API."""
try: try:
# Initialize schema if needed # Initialize schema if needed
@ -128,7 +126,7 @@ class RemoteKuzuAdapter(KuzuAdapter):
logger.error(f"Failed to check schema: {e}") logger.error(f"Failed to check schema: {e}")
return False return False
async def _create_schema(self) -> None: async def _create_schema(self):
"""Create the required schema tables.""" """Create the required schema tables."""
try: try:
# Create Node table if it doesn't exist # Create Node table if it doesn't exist
@ -182,7 +180,7 @@ class RemoteKuzuAdapter(KuzuAdapter):
logger.error(f"Failed to create schema: {e}") logger.error(f"Failed to create schema: {e}")
raise raise
async def _initialize_schema(self) -> None: async def _initialize_schema(self):
"""Initialize the database schema if it doesn't exist.""" """Initialize the database schema if it doesn't exist."""
if self._schema_initialized: if self._schema_initialized:
return return

View file

@ -8,11 +8,11 @@ from neo4j import AsyncSession
from neo4j import AsyncGraphDatabase from neo4j import AsyncGraphDatabase
from neo4j.exceptions import Neo4jError from neo4j.exceptions import Neo4jError
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Optional, Any, List, Dict, Type, Tuple, Union, AsyncGenerator from typing import Optional, Any, List, Dict, Type, Tuple
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
from cognee.modules.engine.models.Timestamp import Timestamp from cognee.tasks.temporal_graph.models import Timestamp
from cognee.shared.logging_utils import get_logger, ERROR from cognee.shared.logging_utils import get_logger, ERROR
from cognee.infrastructure.databases.graph.graph_db_interface import ( from cognee.infrastructure.databases.graph.graph_db_interface import (
GraphDBInterface, GraphDBInterface,
@ -79,14 +79,14 @@ class Neo4jAdapter(GraphDBInterface):
) )
@asynccontextmanager @asynccontextmanager
async def get_session(self) -> AsyncGenerator[AsyncSession, None]: async def get_session(self) -> AsyncSession:
""" """
Get a session for database operations. Get a session for database operations.
""" """
async with self.driver.session(database=self.graph_database_name) as session: async with self.driver.session(database=self.graph_database_name) as session:
yield session yield session
@deadlock_retry() # type: ignore @deadlock_retry()
async def query( async def query(
self, self,
query: str, query: str,
@ -112,7 +112,6 @@ class Neo4jAdapter(GraphDBInterface):
async with self.get_session() as session: async with self.get_session() as session:
result = await session.run(query, parameters=params) result = await session.run(query, parameters=params)
data = await result.data() data = await result.data()
# TODO: why we don't return List[Dict[str, Any]]?
return data return data
except Neo4jError as error: except Neo4jError as error:
logger.error("Neo4j query error: %s", error, exc_info=True) logger.error("Neo4j query error: %s", error, exc_info=True)
@ -142,29 +141,21 @@ class Neo4jAdapter(GraphDBInterface):
) )
return results[0]["node_exists"] if len(results) > 0 else False return results[0]["node_exists"] if len(results) > 0 else False
async def add_node( async def add_node(self, node: DataPoint):
self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None
) -> None:
""" """
Add a new node to the database based on the provided DataPoint object or string ID. Add a new node to the database based on the provided DataPoint object.
Parameters: Parameters:
----------- -----------
- node (Union[DataPoint, str]): An instance of DataPoint or string ID representing the node to add. - node (DataPoint): An instance of DataPoint representing the node to add.
- properties (Optional[Dict[str, Any]]): Properties to set on the node when node is a string ID.
Returns:
--------
The result of the query execution, typically the ID of the added node.
""" """
if isinstance(node, str): serialized_properties = self.serialize_properties(node.model_dump())
# TODO: this was not handled in the original code, check if it is correct
# Handle string node ID with properties parameter
node_id = node
node_label = "Node" # Default label for string nodes
serialized_properties = self.serialize_properties(properties or {})
else:
# Handle DataPoint object
node_id = str(node.id)
node_label = type(node).__name__
serialized_properties = self.serialize_properties(node.model_dump())
query = dedent( query = dedent(
f"""MERGE (node: `{BASE_LABEL}`{{id: $node_id}}) f"""MERGE (node: `{BASE_LABEL}`{{id: $node_id}})
@ -176,16 +167,16 @@ class Neo4jAdapter(GraphDBInterface):
) )
params = { params = {
"node_id": node_id, "node_id": str(node.id),
"node_label": node_label, "node_label": type(node).__name__,
"properties": serialized_properties, "properties": serialized_properties,
} }
await self.query(query, params) return await self.query(query, params)
@record_graph_changes # type: ignore @record_graph_changes
@override_distributed(queued_add_nodes) # type: ignore @override_distributed(queued_add_nodes)
async def add_nodes(self, nodes: List[DataPoint]) -> None: async def add_nodes(self, nodes: list[DataPoint]) -> None:
""" """
Add multiple nodes to the database in a single query. Add multiple nodes to the database in a single query.
@ -210,7 +201,7 @@ class Neo4jAdapter(GraphDBInterface):
RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
""" """
node_params = [ nodes = [
{ {
"node_id": str(node.id), "node_id": str(node.id),
"label": type(node).__name__, "label": type(node).__name__,
@ -219,9 +210,10 @@ class Neo4jAdapter(GraphDBInterface):
for node in nodes for node in nodes
] ]
await self.query(query, dict(nodes=node_params)) results = await self.query(query, dict(nodes=nodes))
return results
async def extract_node(self, node_id: str) -> Optional[Dict[str, Any]]: async def extract_node(self, node_id: str):
""" """
Retrieve a single node from the database by its ID. Retrieve a single node from the database by its ID.
@ -239,7 +231,7 @@ class Neo4jAdapter(GraphDBInterface):
return results[0] if len(results) > 0 else None return results[0] if len(results) > 0 else None
async def extract_nodes(self, node_ids: List[str]) -> List[Dict[str, Any]]: async def extract_nodes(self, node_ids: List[str]):
""" """
Retrieve multiple nodes from the database by their IDs. Retrieve multiple nodes from the database by their IDs.
@ -264,7 +256,7 @@ class Neo4jAdapter(GraphDBInterface):
return [result["node"] for result in results] return [result["node"] for result in results]
async def delete_node(self, node_id: str) -> None: async def delete_node(self, node_id: str):
""" """
Remove a node from the database identified by its ID. Remove a node from the database identified by its ID.
@ -281,7 +273,7 @@ class Neo4jAdapter(GraphDBInterface):
query = f"MATCH (node: `{BASE_LABEL}`{{id: $node_id}}) DETACH DELETE node" query = f"MATCH (node: `{BASE_LABEL}`{{id: $node_id}}) DETACH DELETE node"
params = {"node_id": node_id} params = {"node_id": node_id}
await self.query(query, params) return await self.query(query, params)
async def delete_nodes(self, node_ids: list[str]) -> None: async def delete_nodes(self, node_ids: list[str]) -> None:
""" """
@ -304,18 +296,18 @@ class Neo4jAdapter(GraphDBInterface):
params = {"node_ids": node_ids} params = {"node_ids": node_ids}
await self.query(query, params) return await self.query(query, params)
async def has_edge(self, source_id: str, target_id: str, relationship_name: str) -> bool: async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool:
""" """
Check if an edge exists between two nodes with the specified IDs and edge label. Check if an edge exists between two nodes with the specified IDs and edge label.
Parameters: Parameters:
----------- -----------
- source_id (str): The ID of the node from which the edge originates. - from_node (UUID): The ID of the node from which the edge originates.
- target_id (str): The ID of the node to which the edge points. - to_node (UUID): The ID of the node to which the edge points.
- relationship_name (str): The label of the edge to check for existence. - edge_label (str): The label of the edge to check for existence.
Returns: Returns:
-------- --------
@ -323,28 +315,27 @@ class Neo4jAdapter(GraphDBInterface):
- bool: True if the edge exists, otherwise False. - bool: True if the edge exists, otherwise False.
""" """
query = f""" query = f"""
MATCH (from_node: `{BASE_LABEL}`)-[:`{relationship_name}`]->(to_node: `{BASE_LABEL}`) MATCH (from_node: `{BASE_LABEL}`)-[:`{edge_label}`]->(to_node: `{BASE_LABEL}`)
WHERE from_node.id = $source_id AND to_node.id = $target_id WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id
RETURN COUNT(relationship) > 0 AS edge_exists RETURN COUNT(relationship) > 0 AS edge_exists
""" """
params = { params = {
"source_id": str(source_id), "from_node_id": str(from_node),
"target_id": str(target_id), "to_node_id": str(to_node),
} }
edge_exists = await self.query(query, params) edge_exists = await self.query(query, params)
assert isinstance(edge_exists, bool), "Edge existence check should return a boolean"
return edge_exists return edge_exists
async def has_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> List[bool]: async def has_edges(self, edges):
""" """
Check if multiple edges exist based on provided edge criteria. Check if multiple edges exist based on provided edge criteria.
Parameters: Parameters:
----------- -----------
- edges: A list of edge specifications to check for existence. (source_id, target_id, relationship_name, properties) - edges: A list of edge specifications to check for existence.
Returns: Returns:
-------- --------
@ -378,24 +369,29 @@ class Neo4jAdapter(GraphDBInterface):
async def add_edge( async def add_edge(
self, self,
source_id: str, from_node: UUID,
target_id: str, to_node: UUID,
relationship_name: str, relationship_name: str,
properties: Optional[Dict[str, Any]] = None, edge_properties: Optional[Dict[str, Any]] = {},
) -> None: ):
""" """
Create a new edge between two nodes with specified properties. Create a new edge between two nodes with specified properties.
Parameters: Parameters:
----------- -----------
- source_id (str): The ID of the source node of the edge. - from_node (UUID): The ID of the source node of the edge.
- target_id (str): The ID of the target node of the edge. - to_node (UUID): The ID of the target node of the edge.
- relationship_name (str): The type/label of the edge to create. - relationship_name (str): The type/label of the edge to create.
- properties (Optional[Dict[str, Any]]): A dictionary of properties to assign - edge_properties (Optional[Dict[str, Any]]): A dictionary of properties to assign
to the edge. (default None) to the edge. (default {})
Returns:
--------
The result of the query execution, typically indicating the created edge.
""" """
serialized_properties = self.serialize_properties(properties or {}) serialized_properties = self.serialize_properties(edge_properties)
query = dedent( query = dedent(
f"""\ f"""\
@ -409,13 +405,13 @@ class Neo4jAdapter(GraphDBInterface):
) )
params = { params = {
"from_node": str(source_id), # Adding str as callsites may still be passing UUID "from_node": str(from_node),
"to_node": str(target_id), "to_node": str(to_node),
"relationship_name": relationship_name, "relationship_name": relationship_name,
"properties": serialized_properties, "properties": serialized_properties,
} }
await self.query(query, params) return await self.query(query, params)
def _flatten_edge_properties(self, properties: Dict[str, Any]) -> Dict[str, Any]: def _flatten_edge_properties(self, properties: Dict[str, Any]) -> Dict[str, Any]:
""" """
@ -449,9 +445,9 @@ class Neo4jAdapter(GraphDBInterface):
return flattened return flattened
@record_graph_changes # type: ignore @record_graph_changes
@override_distributed(queued_add_edges) # type: ignore @override_distributed(queued_add_edges)
async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None: async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
""" """
Add multiple edges between nodes in a single query. Add multiple edges between nodes in a single query.
@ -482,10 +478,10 @@ class Neo4jAdapter(GraphDBInterface):
) YIELD rel ) YIELD rel
RETURN rel""" RETURN rel"""
edge_params = [ edges = [
{ {
"from_node": str(edge[0]), # Adding str as callsites may still be passing UUID "from_node": str(edge[0]),
"to_node": str(edge[1]), # Adding str as callsites may still be passing UUID "to_node": str(edge[1]),
"relationship_name": edge[2], "relationship_name": edge[2],
"properties": self._flatten_edge_properties( "properties": self._flatten_edge_properties(
{ {
@ -499,12 +495,13 @@ class Neo4jAdapter(GraphDBInterface):
] ]
try: try:
await self.query(query, dict(edges=edge_params)) results = await self.query(query, dict(edges=edges))
return results
except Neo4jError as error: except Neo4jError as error:
logger.error("Neo4j query error: %s", error, exc_info=True) logger.error("Neo4j query error: %s", error, exc_info=True)
raise error raise error
async def get_edges(self, node_id: str) -> List[Tuple[str, str, str, Dict[str, Any]]]: async def get_edges(self, node_id: str):
""" """
Retrieve all edges connected to a specified node. Retrieve all edges connected to a specified node.