diff --git a/cognee/infrastructure/databases/graph/kuzu/adapter.py b/cognee/infrastructure/databases/graph/kuzu/adapter.py index a3430d370..085d7cd00 100644 --- a/cognee/infrastructure/databases/graph/kuzu/adapter.py +++ b/cognee/infrastructure/databases/graph/kuzu/adapter.py @@ -10,7 +10,7 @@ from kuzu.database import Database from datetime import datetime, timezone from contextlib import asynccontextmanager from concurrent.futures import ThreadPoolExecutor -from typing import Dict, Any, List, Union, Optional, Tuple, Type, AsyncGenerator +from typing import Dict, Any, List, Union, Optional, Tuple, Type from cognee.shared.logging_utils import get_logger from cognee.infrastructure.utils.run_sync import run_sync @@ -22,7 +22,7 @@ from cognee.infrastructure.databases.graph.graph_db_interface import ( from cognee.infrastructure.engine import DataPoint from cognee.modules.storage.utils import JSONEncoder from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int -from cognee.modules.engine.models.Timestamp import Timestamp +from cognee.tasks.temporal_graph.models import Timestamp logger = get_logger() @@ -146,21 +146,15 @@ class KuzuAdapter(GraphDBInterface): logger.error(f"Failed to initialize Kuzu database: {e}") raise e - def _get_connection(self) -> Connection: - """Get the connection to the Kuzu database.""" - if not self.connection: - raise RuntimeError("Kuzu database connection not initialized") - return self.connection - async def push_to_s3(self) -> None: if os.getenv("STORAGE_BACKEND", "").lower() == "s3" and hasattr(self, "temp_graph_file"): from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage s3_file_storage = S3FileStorage("") - if self._get_connection(): + if self.connection: async with self.KUZU_ASYNC_LOCK: - self._get_connection().execute("CHECKPOINT;") + self.connection.execute("CHECKPOINT;") s3_file_storage.s3.put(self.temp_graph_file, self.db_path, recursive=True) @@ -173,9 +167,7 @@ class KuzuAdapter(GraphDBInterface): except FileNotFoundError: logger.warning(f"Kuzu S3 storage file not found: {self.db_path}") - async def query( - self, query: str, params: Optional[Dict[str, Any]] = None - ) -> List[Tuple[Any, ...]]: + async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]: """ Execute a Kuzu query asynchronously with automatic reconnection. @@ -198,32 +190,23 @@ class KuzuAdapter(GraphDBInterface): loop = asyncio.get_running_loop() params = params or {} - def blocking_query() -> List[Tuple[Any, ...]]: + def blocking_query(): try: - if not self._get_connection(): + if not self.connection: logger.debug("Reconnecting to Kuzu database...") self._initialize_connection() - if not self._get_connection(): - raise RuntimeError("Failed to establish database connection") - - result = self._get_connection().execute(query, params) + result = self.connection.execute(query, params) rows = [] - if not isinstance(result, list): - result = [result] - - # Handle QueryResult vs List[QueryResult] union type - for single_result in result: - while single_result.has_next(): - row = single_result.get_next() - processed_rows = [] - for val in row: - if hasattr(val, "as_py"): - val = val.as_py() - processed_rows.append(val) - rows.append(tuple(processed_rows)) - + while result.has_next(): + row = result.get_next() + processed_rows = [] + for val in row: + if hasattr(val, "as_py"): + val = val.as_py() + processed_rows.append(val) + rows.append(tuple(processed_rows)) return rows except Exception as e: logger.error(f"Query execution failed: {str(e)}") @@ -232,7 +215,7 @@ class KuzuAdapter(GraphDBInterface): return await loop.run_in_executor(self.executor, blocking_query) @asynccontextmanager - async def get_session(self) -> AsyncGenerator[Optional[Connection], None]: + async def get_session(self): """ Get a database session. @@ -241,7 +224,7 @@ class KuzuAdapter(GraphDBInterface): and on exit performs cleanup if necessary. """ try: - yield self._get_connection() + yield self.connection finally: pass @@ -272,7 +255,7 @@ class KuzuAdapter(GraphDBInterface): def _edge_query_and_params( self, from_node: str, to_node: str, relationship_name: str, properties: Dict[str, Any] - ) -> Tuple[str, Dict[str, Any]]: + ) -> Tuple[str, dict]: """Build the edge creation query and parameters.""" now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f") query = """ @@ -322,9 +305,7 @@ class KuzuAdapter(GraphDBInterface): result = await self.query(query_str, {"id": node_id}) return result[0][0] if result else False - async def add_node( - self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None - ) -> None: + async def add_node(self, node: DataPoint) -> None: """ Add a single node to the graph if it doesn't exist. @@ -338,32 +319,20 @@ class KuzuAdapter(GraphDBInterface): - node (DataPoint): The node to be added, represented as a DataPoint. """ try: - if isinstance(node, str): - # Handle string node ID with properties parameter - node_properties = properties or {} - core_properties = { - "id": node, - "name": str(node_properties.get("name", "")), - "type": str(node_properties.get("type", "")), - } - # Use the passed properties, excluding core fields - other_properties = { - k: v for k, v in node_properties.items() if k not in ["id", "name", "type"] - } - else: - # Handle DataPoint object - node_properties = node.model_dump() - core_properties = { - "id": str(node_properties.get("id", "")), - "name": str(node_properties.get("name", "")), - "type": str(node_properties.get("type", "")), - } - # Remove core fields from other properties - other_properties = { - k: v for k, v in node_properties.items() if k not in ["id", "name", "type"] - } + properties = node.model_dump() if hasattr(node, "model_dump") else vars(node) - core_properties["properties"] = json.dumps(other_properties, cls=JSONEncoder) + # Extract core fields with defaults if not present + core_properties = { + "id": str(properties.get("id", "")), + "name": str(properties.get("name", "")), + "type": str(properties.get("type", "")), + } + + # Remove core fields from other properties + for key in core_properties: + properties.pop(key, None) + + core_properties["properties"] = json.dumps(properties, cls=JSONEncoder) # Add timestamps for new node now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f") @@ -391,7 +360,7 @@ class KuzuAdapter(GraphDBInterface): logger.error(f"Failed to add node: {e}") raise - @record_graph_changes # type: ignore + @record_graph_changes async def add_nodes(self, nodes: List[DataPoint]) -> None: """ Add multiple nodes to the graph in a batch operation. @@ -599,9 +568,7 @@ class KuzuAdapter(GraphDBInterface): ) return result[0][0] if result else False - async def has_edges( - self, edges: List[Tuple[str, str, str, Dict[str, Any]]] - ) -> List[Tuple[str, str, str, Dict[str, Any]]]: + async def has_edges(self, edges: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str]]: """ Check if multiple edges exist in a batch operation. @@ -632,7 +599,7 @@ class KuzuAdapter(GraphDBInterface): "to_id": str(to_node), # Ensure string type "relationship_name": str(edge_label), # Ensure string type } - for from_node, to_node, edge_label, _ in edges + for from_node, to_node, edge_label in edges ] # Batch check query with direct string comparison @@ -648,21 +615,9 @@ class KuzuAdapter(GraphDBInterface): results = await self.query(query, {"edges": edge_params}) # Convert results back to tuples and ensure string types - # Find the original edge properties for each existing edge - # TODO: get review on this - existing_edges = [] - for row in results: - from_id, to_id, rel_name = str(row[0]), str(row[1]), str(row[2]) - # Find the original properties from the input edges - original_props = {} - for orig_from, orig_to, orig_rel, orig_props in edges: - if orig_from == from_id and orig_to == to_id and orig_rel == rel_name: - original_props = orig_props - break - existing_edges.append((from_id, to_id, rel_name, original_props)) + existing_edges = [(str(row[0]), str(row[1]), str(row[2])) for row in results] logger.debug(f"Found {len(existing_edges)} existing edges out of {len(edges)} checked") - # TODO: otherwise, we can just return dummy properties since they are not used apparently return existing_edges except Exception as e: @@ -671,10 +626,10 @@ class KuzuAdapter(GraphDBInterface): async def add_edge( self, - source_id: str, - target_id: str, + from_node: str, + to_node: str, relationship_name: str, - properties: Optional[Dict[str, Any]] = None, + edge_properties: Dict[str, Any] = {}, ) -> None: """ Add an edge between two nodes. @@ -686,23 +641,23 @@ class KuzuAdapter(GraphDBInterface): Parameters: ----------- - - source_id (str): The identifier of the source node from which the edge originates. - - target_id (str): The identifier of the target node to which the edge points. + - from_node (str): The identifier of the source node from which the edge originates. + - to_node (str): The identifier of the target node to which the edge points. - relationship_name (str): The label of the edge to be created, representing the relationship name. - - properties (Optional[Dict[str, Any]]): A dictionary containing properties for the edge. - (default None) + - edge_properties (Dict[str, Any]): A dictionary containing properties for the edge. + (default {}) """ try: query, params = self._edge_query_and_params( - source_id, target_id, relationship_name, properties or {} + from_node, to_node, relationship_name, edge_properties ) await self.query(query, params) except Exception as e: logger.error(f"Failed to add edge: {e}") raise - @record_graph_changes # type: ignore + @record_graph_changes async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None: """ Add multiple edges in a batch operation. @@ -757,7 +712,7 @@ class KuzuAdapter(GraphDBInterface): logger.error(f"Failed to add edges in batch: {e}") raise - async def get_edges(self, node_id: str) -> List[Tuple[str, str, str, Dict[str, Any]]]: + async def get_edges(self, node_id: str) -> List[Tuple[Dict[str, Any], str, Dict[str, Any]]]: """ Get all edges connected to a node. @@ -772,8 +727,9 @@ class KuzuAdapter(GraphDBInterface): Returns: -------- - - List[Tuple[str, str, str, Dict[str, Any]]]: A list of tuples where each - tuple contains (source_id, relationship_name, target_id, edge_properties). + - List[Tuple[Dict[str, Any], str, Dict[str, Any]]]: A list of tuples where each + tuple contains (source_node, relationship_name, target_node), with source_node and + target_node as dictionaries of node properties. """ query_str = """ MATCH (n:Node)-[r]-(m:Node) @@ -794,14 +750,12 @@ class KuzuAdapter(GraphDBInterface): """ try: results = await self.query(query_str, {"node_id": node_id}) - edges: List[Tuple[str, str, str, Dict[str, Any]]] = [] + edges = [] for row in results: if row and len(row) == 3: source_node = self._parse_node_properties(row[0]) - relationship_name = row[1] target_node = self._parse_node_properties(row[2]) - # TODO: any edge properties we can add? Adding empty to avoid modifying query without reason - edges.append((source_node, relationship_name, target_node, {})) # type: ignore # currently each node is a dict, wihle typing expects nodes to be strings + edges.append((source_node, row[1], target_node)) return edges except Exception as e: logger.error(f"Failed to get edges for node {node_id}: {e}") @@ -1023,7 +977,7 @@ class KuzuAdapter(GraphDBInterface): return [] async def get_connections( - self, node_id: Union[str, UUID] + self, node_id: str ) -> List[Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]]: """ Get all nodes connected to a given node. @@ -1065,9 +1019,7 @@ class KuzuAdapter(GraphDBInterface): } """ try: - # Convert UUID to string if needed - node_id_str = str(node_id) - results = await self.query(query_str, {"node_id": node_id_str}) + results = await self.query(query_str, {"node_id": node_id}) edges = [] for row in results: if row and len(row) == 3: @@ -1225,7 +1177,7 @@ class KuzuAdapter(GraphDBInterface): async def get_nodeset_subgraph( self, node_type: Type[Any], node_name: List[str] - ) -> Tuple[List[Tuple[int, Dict[str, Any]]], List[Tuple[int, int, str, Dict[str, Any]]]]: + ) -> Tuple[List[Tuple[str, dict]], List[Tuple[str, str, str, dict]]]: """ Get subgraph for a set of nodes based on type and names. @@ -1273,9 +1225,9 @@ class KuzuAdapter(GraphDBInterface): RETURN n.id, n.name, n.type, n.properties """ node_rows = await self.query(nodes_query, {"ids": all_ids}) - nodes: List[Tuple[str, Dict[str, Any]]] = [] + nodes: List[Tuple[str, dict]] = [] for node_id, name, typ, props in node_rows: - data: Dict[str, Any] = {"id": node_id, "name": name, "type": typ} + data = {"id": node_id, "name": name, "type": typ} if props: try: data.update(json.loads(props)) @@ -1289,22 +1241,22 @@ class KuzuAdapter(GraphDBInterface): RETURN a.id, b.id, r.relationship_name, r.properties """ edge_rows = await self.query(edges_query, {"ids": all_ids}) - edges: List[Tuple[str, str, str, Dict[str, Any]]] = [] + edges: List[Tuple[str, str, str, dict]] = [] for from_id, to_id, rel_type, props in edge_rows: - edge_data: Dict[str, Any] = {} + data = {} if props: try: - edge_data = json.loads(props) + data = json.loads(props) except json.JSONDecodeError: logger.warning(f"Failed to parse JSON props for edge {from_id}->{to_id}") - edges.append((from_id, to_id, rel_type, edge_data)) + edges.append((from_id, to_id, rel_type, data)) - return nodes, edges # type: ignore # Interface expects int IDs but string IDs are more natural for graph DBs + return nodes, edges async def get_filtered_graph_data( self, attribute_filters: List[Dict[str, List[Union[str, int]]]] - ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + ): """ Get filtered nodes and relationships based on attributes. @@ -1347,7 +1299,7 @@ class KuzuAdapter(GraphDBInterface): ) return ([n[0] for n in nodes], [e[0] for e in edges]) - async def get_graph_metrics(self, include_optional: bool = False) -> Dict[str, Any]: + async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]: """ Get metrics on graph structure and connectivity. @@ -1370,8 +1322,8 @@ class KuzuAdapter(GraphDBInterface): try: # Get basic graph data nodes, edges = await self.get_model_independent_graph_data() - num_nodes = len(nodes[0]["nodes"]) if nodes else 0 # type: ignore # nodes is type string? - num_edges = len(edges[0]["elements"]) if edges else 0 # type: ignore # edges is type string? + num_nodes = len(nodes[0]["nodes"]) if nodes else 0 + num_edges = len(edges[0]["elements"]) if edges else 0 # Calculate mandatory metrics mandatory_metrics = { @@ -1531,8 +1483,8 @@ class KuzuAdapter(GraphDBInterface): It raises exceptions for failures occurring during deletion processes. """ try: - if self._get_connection(): - self._get_connection().close() + if self.connection: + self.connection.close() self.connection = None if self.db: self.db.close() @@ -1563,7 +1515,7 @@ class KuzuAdapter(GraphDBInterface): occur during file deletions or initializations carefully. """ try: - if self._get_connection(): + if self.connection: self.connection = None if self.db: self.db.close() @@ -1579,30 +1531,20 @@ class KuzuAdapter(GraphDBInterface): # Reinitialize the database self._initialize_connection() - - if not self._get_connection(): - raise RuntimeError("Failed to establish database connection") - # Verify the database is empty - result = self._get_connection().execute("MATCH (n:Node) RETURN COUNT(n)") - if not isinstance(result, list): - result = [result] - for single_result in result: - _next = single_result.get_next() - if not isinstance(_next, list): - raise RuntimeError("Expected list of results") - count = _next[0] if _next else 0 + result = self.connection.execute("MATCH (n:Node) RETURN COUNT(n)") + count = result.get_next()[0] if result.has_next() else 0 if count > 0: logger.warning( f"Database still contains {count} nodes after clearing, forcing deletion" ) - self._get_connection().execute("MATCH (n:Node) DETACH DELETE n") + self.connection.execute("MATCH (n:Node) DETACH DELETE n") logger.info("Database cleared successfully") except Exception as e: logger.error(f"Error during database clearing: {e}") raise - async def get_document_subgraph(self, data_id: str) -> Optional[Dict[str, Any]]: + async def get_document_subgraph(self, data_id: str): """ Get all nodes that should be deleted when removing a document. @@ -1674,7 +1616,7 @@ class KuzuAdapter(GraphDBInterface): "orphan_types": result[0][4], } - async def get_degree_one_nodes(self, node_type: str) -> List[Dict[str, Any]]: + async def get_degree_one_nodes(self, node_type: str): """ Get all nodes that have only one connection. @@ -1827,8 +1769,8 @@ class KuzuAdapter(GraphDBInterface): ids: List[str] = [] if time_from and time_to: - time_from_int = date_to_int(time_from) - time_to_int = date_to_int(time_to) + time_from = date_to_int(time_from) + time_to = date_to_int(time_to) cypher = f""" MATCH (n:Node) @@ -1840,13 +1782,13 @@ class KuzuAdapter(GraphDBInterface): WHEN t_str IS NULL OR t_str = '' THEN NULL ELSE CAST(t_str AS INT64) END AS t - WHERE t >= {time_from_int} - AND t <= {time_to_int} + WHERE t >= {time_from} + AND t <= {time_to} RETURN n.id as id """ elif time_from: - time_from_int = date_to_int(time_from) + time_from = date_to_int(time_from) cypher = f""" MATCH (n:Node) @@ -1858,12 +1800,12 @@ class KuzuAdapter(GraphDBInterface): WHEN t_str IS NULL OR t_str = '' THEN NULL ELSE CAST(t_str AS INT64) END AS t - WHERE t >= {time_from_int} + WHERE t >= {time_from} RETURN n.id as id """ elif time_to: - time_to_int = date_to_int(time_to) + time_to = date_to_int(time_to) cypher = f""" MATCH (n:Node) @@ -1875,12 +1817,12 @@ class KuzuAdapter(GraphDBInterface): WHEN t_str IS NULL OR t_str = '' THEN NULL ELSE CAST(t_str AS INT64) END AS t - WHERE t <= {time_to_int} + WHERE t <= {time_to} RETURN n.id as id """ else: - return ", ".join(f"'{uid}'" for uid in ids) + return ids time_nodes = await self.query(cypher) time_ids_list = [item[0] for item in time_nodes] diff --git a/cognee/infrastructure/databases/graph/kuzu/remote_kuzu_adapter.py b/cognee/infrastructure/databases/graph/kuzu/remote_kuzu_adapter.py index 7dcb5e2a6..c75b70f75 100644 --- a/cognee/infrastructure/databases/graph/kuzu/remote_kuzu_adapter.py +++ b/cognee/infrastructure/databases/graph/kuzu/remote_kuzu_adapter.py @@ -2,7 +2,7 @@ from cognee.shared.logging_utils import get_logger import json -from typing import Dict, Any, List, Optional, Tuple, Union +from typing import Dict, Any, List, Optional, Tuple import aiohttp from uuid import UUID @@ -14,7 +14,7 @@ logger = get_logger() class UUIDEncoder(json.JSONEncoder): """Custom JSON encoder that handles UUID objects.""" - def default(self, obj: Union[UUID, Any]) -> Any: + def default(self, obj): if isinstance(obj, UUID): return str(obj) return super().default(obj) @@ -36,7 +36,7 @@ class RemoteKuzuAdapter(KuzuAdapter): self.api_url = api_url self.username = username self.password = password - self._session: Optional[aiohttp.ClientSession] = None + self._session = None self._schema_initialized = False async def _get_session(self) -> aiohttp.ClientSession: @@ -45,13 +45,13 @@ class RemoteKuzuAdapter(KuzuAdapter): self._session = aiohttp.ClientSession() return self._session - async def close(self) -> None: + async def close(self): """Close the adapter and its session.""" if self._session and not self._session.closed: await self._session.close() self._session = None - async def _make_request(self, endpoint: str, data: dict[str, Any]) -> dict[str, Any]: + async def _make_request(self, endpoint: str, data: dict) -> dict: """Make a request to the Kuzu API.""" url = f"{self.api_url}{endpoint}" session = await self._get_session() @@ -73,15 +73,13 @@ class RemoteKuzuAdapter(KuzuAdapter): status=response.status, message=error_detail, ) - return await response.json() # type: ignore + return await response.json() except aiohttp.ClientError as e: logger.error(f"API request failed: {str(e)}") logger.error(f"Request data: {data}") raise - async def query( - self, query: str, params: Optional[dict[str, Any]] = None - ) -> List[Tuple[Any, ...]]: + async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]: """Execute a Kuzu query via the REST API.""" try: # Initialize schema if needed @@ -128,7 +126,7 @@ class RemoteKuzuAdapter(KuzuAdapter): logger.error(f"Failed to check schema: {e}") return False - async def _create_schema(self) -> None: + async def _create_schema(self): """Create the required schema tables.""" try: # Create Node table if it doesn't exist @@ -182,7 +180,7 @@ class RemoteKuzuAdapter(KuzuAdapter): logger.error(f"Failed to create schema: {e}") raise - async def _initialize_schema(self) -> None: + async def _initialize_schema(self): """Initialize the database schema if it doesn't exist.""" if self._schema_initialized: return diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index 58f859576..03b16eb33 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -8,11 +8,11 @@ from neo4j import AsyncSession from neo4j import AsyncGraphDatabase from neo4j.exceptions import Neo4jError from contextlib import asynccontextmanager -from typing import Optional, Any, List, Dict, Type, Tuple, Union, AsyncGenerator +from typing import Optional, Any, List, Dict, Type, Tuple from cognee.infrastructure.engine import DataPoint from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int -from cognee.modules.engine.models.Timestamp import Timestamp +from cognee.tasks.temporal_graph.models import Timestamp from cognee.shared.logging_utils import get_logger, ERROR from cognee.infrastructure.databases.graph.graph_db_interface import ( GraphDBInterface, @@ -79,14 +79,14 @@ class Neo4jAdapter(GraphDBInterface): ) @asynccontextmanager - async def get_session(self) -> AsyncGenerator[AsyncSession, None]: + async def get_session(self) -> AsyncSession: """ Get a session for database operations. """ async with self.driver.session(database=self.graph_database_name) as session: yield session - @deadlock_retry() # type: ignore + @deadlock_retry() async def query( self, query: str, @@ -112,7 +112,6 @@ class Neo4jAdapter(GraphDBInterface): async with self.get_session() as session: result = await session.run(query, parameters=params) data = await result.data() - # TODO: why we don't return List[Dict[str, Any]]? return data except Neo4jError as error: logger.error("Neo4j query error: %s", error, exc_info=True) @@ -142,29 +141,21 @@ class Neo4jAdapter(GraphDBInterface): ) return results[0]["node_exists"] if len(results) > 0 else False - async def add_node( - self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None - ) -> None: + async def add_node(self, node: DataPoint): """ - Add a new node to the database based on the provided DataPoint object or string ID. + Add a new node to the database based on the provided DataPoint object. Parameters: ----------- - - node (Union[DataPoint, str]): An instance of DataPoint or string ID representing the node to add. - - properties (Optional[Dict[str, Any]]): Properties to set on the node when node is a string ID. + - node (DataPoint): An instance of DataPoint representing the node to add. + + Returns: + -------- + + The result of the query execution, typically the ID of the added node. """ - if isinstance(node, str): - # TODO: this was not handled in the original code, check if it is correct - # Handle string node ID with properties parameter - node_id = node - node_label = "Node" # Default label for string nodes - serialized_properties = self.serialize_properties(properties or {}) - else: - # Handle DataPoint object - node_id = str(node.id) - node_label = type(node).__name__ - serialized_properties = self.serialize_properties(node.model_dump()) + serialized_properties = self.serialize_properties(node.model_dump()) query = dedent( f"""MERGE (node: `{BASE_LABEL}`{{id: $node_id}}) @@ -176,16 +167,16 @@ class Neo4jAdapter(GraphDBInterface): ) params = { - "node_id": node_id, - "node_label": node_label, + "node_id": str(node.id), + "node_label": type(node).__name__, "properties": serialized_properties, } - await self.query(query, params) + return await self.query(query, params) - @record_graph_changes # type: ignore - @override_distributed(queued_add_nodes) # type: ignore - async def add_nodes(self, nodes: List[DataPoint]) -> None: + @record_graph_changes + @override_distributed(queued_add_nodes) + async def add_nodes(self, nodes: list[DataPoint]) -> None: """ Add multiple nodes to the database in a single query. @@ -210,7 +201,7 @@ class Neo4jAdapter(GraphDBInterface): RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId """ - node_params = [ + nodes = [ { "node_id": str(node.id), "label": type(node).__name__, @@ -219,9 +210,10 @@ class Neo4jAdapter(GraphDBInterface): for node in nodes ] - await self.query(query, dict(nodes=node_params)) + results = await self.query(query, dict(nodes=nodes)) + return results - async def extract_node(self, node_id: str) -> Optional[Dict[str, Any]]: + async def extract_node(self, node_id: str): """ Retrieve a single node from the database by its ID. @@ -239,7 +231,7 @@ class Neo4jAdapter(GraphDBInterface): return results[0] if len(results) > 0 else None - async def extract_nodes(self, node_ids: List[str]) -> List[Dict[str, Any]]: + async def extract_nodes(self, node_ids: List[str]): """ Retrieve multiple nodes from the database by their IDs. @@ -264,7 +256,7 @@ class Neo4jAdapter(GraphDBInterface): return [result["node"] for result in results] - async def delete_node(self, node_id: str) -> None: + async def delete_node(self, node_id: str): """ Remove a node from the database identified by its ID. @@ -281,7 +273,7 @@ class Neo4jAdapter(GraphDBInterface): query = f"MATCH (node: `{BASE_LABEL}`{{id: $node_id}}) DETACH DELETE node" params = {"node_id": node_id} - await self.query(query, params) + return await self.query(query, params) async def delete_nodes(self, node_ids: list[str]) -> None: """ @@ -304,18 +296,18 @@ class Neo4jAdapter(GraphDBInterface): params = {"node_ids": node_ids} - await self.query(query, params) + return await self.query(query, params) - async def has_edge(self, source_id: str, target_id: str, relationship_name: str) -> bool: + async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool: """ Check if an edge exists between two nodes with the specified IDs and edge label. Parameters: ----------- - - source_id (str): The ID of the node from which the edge originates. - - target_id (str): The ID of the node to which the edge points. - - relationship_name (str): The label of the edge to check for existence. + - from_node (UUID): The ID of the node from which the edge originates. + - to_node (UUID): The ID of the node to which the edge points. + - edge_label (str): The label of the edge to check for existence. Returns: -------- @@ -323,28 +315,27 @@ class Neo4jAdapter(GraphDBInterface): - bool: True if the edge exists, otherwise False. """ query = f""" - MATCH (from_node: `{BASE_LABEL}`)-[:`{relationship_name}`]->(to_node: `{BASE_LABEL}`) - WHERE from_node.id = $source_id AND to_node.id = $target_id + MATCH (from_node: `{BASE_LABEL}`)-[:`{edge_label}`]->(to_node: `{BASE_LABEL}`) + WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id RETURN COUNT(relationship) > 0 AS edge_exists """ params = { - "source_id": str(source_id), - "target_id": str(target_id), + "from_node_id": str(from_node), + "to_node_id": str(to_node), } edge_exists = await self.query(query, params) - assert isinstance(edge_exists, bool), "Edge existence check should return a boolean" return edge_exists - async def has_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> List[bool]: + async def has_edges(self, edges): """ Check if multiple edges exist based on provided edge criteria. Parameters: ----------- - - edges: A list of edge specifications to check for existence. (source_id, target_id, relationship_name, properties) + - edges: A list of edge specifications to check for existence. Returns: -------- @@ -378,24 +369,29 @@ class Neo4jAdapter(GraphDBInterface): async def add_edge( self, - source_id: str, - target_id: str, + from_node: UUID, + to_node: UUID, relationship_name: str, - properties: Optional[Dict[str, Any]] = None, - ) -> None: + edge_properties: Optional[Dict[str, Any]] = {}, + ): """ Create a new edge between two nodes with specified properties. Parameters: ----------- - - source_id (str): The ID of the source node of the edge. - - target_id (str): The ID of the target node of the edge. + - from_node (UUID): The ID of the source node of the edge. + - to_node (UUID): The ID of the target node of the edge. - relationship_name (str): The type/label of the edge to create. - - properties (Optional[Dict[str, Any]]): A dictionary of properties to assign - to the edge. (default None) + - edge_properties (Optional[Dict[str, Any]]): A dictionary of properties to assign + to the edge. (default {}) + + Returns: + -------- + + The result of the query execution, typically indicating the created edge. """ - serialized_properties = self.serialize_properties(properties or {}) + serialized_properties = self.serialize_properties(edge_properties) query = dedent( f"""\ @@ -409,13 +405,13 @@ class Neo4jAdapter(GraphDBInterface): ) params = { - "from_node": str(source_id), # Adding str as callsites may still be passing UUID - "to_node": str(target_id), + "from_node": str(from_node), + "to_node": str(to_node), "relationship_name": relationship_name, "properties": serialized_properties, } - await self.query(query, params) + return await self.query(query, params) def _flatten_edge_properties(self, properties: Dict[str, Any]) -> Dict[str, Any]: """ @@ -449,9 +445,9 @@ class Neo4jAdapter(GraphDBInterface): return flattened - @record_graph_changes # type: ignore - @override_distributed(queued_add_edges) # type: ignore - async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None: + @record_graph_changes + @override_distributed(queued_add_edges) + async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None: """ Add multiple edges between nodes in a single query. @@ -482,10 +478,10 @@ class Neo4jAdapter(GraphDBInterface): ) YIELD rel RETURN rel""" - edge_params = [ + edges = [ { - "from_node": str(edge[0]), # Adding str as callsites may still be passing UUID - "to_node": str(edge[1]), # Adding str as callsites may still be passing UUID + "from_node": str(edge[0]), + "to_node": str(edge[1]), "relationship_name": edge[2], "properties": self._flatten_edge_properties( { @@ -499,12 +495,13 @@ class Neo4jAdapter(GraphDBInterface): ] try: - await self.query(query, dict(edges=edge_params)) + results = await self.query(query, dict(edges=edges)) + return results except Neo4jError as error: logger.error("Neo4j query error: %s", error, exc_info=True) raise error - async def get_edges(self, node_id: str) -> List[Tuple[str, str, str, Dict[str, Any]]]: + async def get_edges(self, node_id: str): """ Retrieve all edges connected to a specified node.