diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index 03b16eb33..9284ebbbb 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -8,11 +8,11 @@ from neo4j import AsyncSession from neo4j import AsyncGraphDatabase from neo4j.exceptions import Neo4jError from contextlib import asynccontextmanager -from typing import Optional, Any, List, Dict, Type, Tuple +from typing import Optional, Any, List, Dict, Type, Tuple, Union, AsyncGenerator from cognee.infrastructure.engine import DataPoint from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int -from cognee.tasks.temporal_graph.models import Timestamp +from cognee.modules.engine.models.Timestamp import Timestamp from cognee.shared.logging_utils import get_logger, ERROR from cognee.infrastructure.databases.graph.graph_db_interface import ( GraphDBInterface, @@ -79,14 +79,14 @@ class Neo4jAdapter(GraphDBInterface): ) @asynccontextmanager - async def get_session(self) -> AsyncSession: + async def get_session(self) -> AsyncGenerator[AsyncSession, None]: """ Get a session for database operations. """ async with self.driver.session(database=self.graph_database_name) as session: yield session - @deadlock_retry() + @deadlock_retry() # type: ignore async def query( self, query: str, @@ -112,6 +112,7 @@ class Neo4jAdapter(GraphDBInterface): async with self.get_session() as session: result = await session.run(query, parameters=params) data = await result.data() + # TODO: why we don't return List[Dict[str, Any]]? return data except Neo4jError as error: logger.error("Neo4j query error: %s", error, exc_info=True) @@ -141,21 +142,27 @@ class Neo4jAdapter(GraphDBInterface): ) return results[0]["node_exists"] if len(results) > 0 else False - async def add_node(self, node: DataPoint): + async def add_node(self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None) -> None: """ - Add a new node to the database based on the provided DataPoint object. + Add a new node to the database based on the provided DataPoint object or string ID. Parameters: ----------- - - node (DataPoint): An instance of DataPoint representing the node to add. - - Returns: - -------- - - The result of the query execution, typically the ID of the added node. + - node (Union[DataPoint, str]): An instance of DataPoint or string ID representing the node to add. + - properties (Optional[Dict[str, Any]]): Properties to set on the node when node is a string ID. """ - serialized_properties = self.serialize_properties(node.model_dump()) + if isinstance(node, str): + # TODO: this was not handled in the original code, check if it is correct + # Handle string node ID with properties parameter + node_id = node + node_label = "Node" # Default label for string nodes + serialized_properties = self.serialize_properties(properties or {}) + else: + # Handle DataPoint object + node_id = str(node.id) + node_label = type(node).__name__ + serialized_properties = self.serialize_properties(node.model_dump()) query = dedent( f"""MERGE (node: `{BASE_LABEL}`{{id: $node_id}}) @@ -167,16 +174,16 @@ class Neo4jAdapter(GraphDBInterface): ) params = { - "node_id": str(node.id), - "node_label": type(node).__name__, + "node_id": node_id, + "node_label": node_label, "properties": serialized_properties, } - return await self.query(query, params) + await self.query(query, params) - @record_graph_changes - @override_distributed(queued_add_nodes) - async def add_nodes(self, nodes: list[DataPoint]) -> None: + @record_graph_changes # type: ignore + @override_distributed(queued_add_nodes) # type: ignore + async def add_nodes(self, nodes: List[DataPoint]) -> None: """ Add multiple nodes to the database in a single query. @@ -201,7 +208,7 @@ class Neo4jAdapter(GraphDBInterface): RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId """ - nodes = [ + node_params = [ { "node_id": str(node.id), "label": type(node).__name__, @@ -210,10 +217,9 @@ class Neo4jAdapter(GraphDBInterface): for node in nodes ] - results = await self.query(query, dict(nodes=nodes)) - return results + await self.query(query, dict(nodes=node_params)) - async def extract_node(self, node_id: str): + async def extract_node(self, node_id: str) -> Optional[Dict[str, Any]]: """ Retrieve a single node from the database by its ID. @@ -231,7 +237,7 @@ class Neo4jAdapter(GraphDBInterface): return results[0] if len(results) > 0 else None - async def extract_nodes(self, node_ids: List[str]): + async def extract_nodes(self, node_ids: List[str]) -> List[Dict[str, Any]]: """ Retrieve multiple nodes from the database by their IDs. @@ -256,7 +262,7 @@ class Neo4jAdapter(GraphDBInterface): return [result["node"] for result in results] - async def delete_node(self, node_id: str): + async def delete_node(self, node_id: str) -> None: """ Remove a node from the database identified by its ID. @@ -273,7 +279,7 @@ class Neo4jAdapter(GraphDBInterface): query = f"MATCH (node: `{BASE_LABEL}`{{id: $node_id}}) DETACH DELETE node" params = {"node_id": node_id} - return await self.query(query, params) + await self.query(query, params) async def delete_nodes(self, node_ids: list[str]) -> None: """ @@ -296,18 +302,18 @@ class Neo4jAdapter(GraphDBInterface): params = {"node_ids": node_ids} - return await self.query(query, params) + await self.query(query, params) - async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool: + async def has_edge(self, source_id: str, target_id: str, relationship_name: str) -> bool: """ Check if an edge exists between two nodes with the specified IDs and edge label. Parameters: ----------- - - from_node (UUID): The ID of the node from which the edge originates. - - to_node (UUID): The ID of the node to which the edge points. - - edge_label (str): The label of the edge to check for existence. + - source_id (str): The ID of the node from which the edge originates. + - target_id (str): The ID of the node to which the edge points. + - relationship_name (str): The label of the edge to check for existence. Returns: -------- @@ -315,27 +321,28 @@ class Neo4jAdapter(GraphDBInterface): - bool: True if the edge exists, otherwise False. """ query = f""" - MATCH (from_node: `{BASE_LABEL}`)-[:`{edge_label}`]->(to_node: `{BASE_LABEL}`) - WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id + MATCH (from_node: `{BASE_LABEL}`)-[:`{relationship_name}`]->(to_node: `{BASE_LABEL}`) + WHERE from_node.id = $source_id AND to_node.id = $target_id RETURN COUNT(relationship) > 0 AS edge_exists """ params = { - "from_node_id": str(from_node), - "to_node_id": str(to_node), + "source_id": str(source_id), + "target_id": str(target_id), } edge_exists = await self.query(query, params) + assert isinstance(edge_exists, bool), "Edge existence check should return a boolean" return edge_exists - async def has_edges(self, edges): + async def has_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> List[bool]: """ Check if multiple edges exist based on provided edge criteria. Parameters: ----------- - - edges: A list of edge specifications to check for existence. + - edges: A list of edge specifications to check for existence. (source_id, target_id, relationship_name, properties) Returns: -------- @@ -369,29 +376,24 @@ class Neo4jAdapter(GraphDBInterface): async def add_edge( self, - from_node: UUID, - to_node: UUID, + source_id: str, + target_id: str, relationship_name: str, - edge_properties: Optional[Dict[str, Any]] = {}, - ): + properties: Optional[Dict[str, Any]] = None, + ) -> None: """ Create a new edge between two nodes with specified properties. Parameters: ----------- - - from_node (UUID): The ID of the source node of the edge. - - to_node (UUID): The ID of the target node of the edge. + - source_id (str): The ID of the source node of the edge. + - target_id (str): The ID of the target node of the edge. - relationship_name (str): The type/label of the edge to create. - - edge_properties (Optional[Dict[str, Any]]): A dictionary of properties to assign - to the edge. (default {}) - - Returns: - -------- - - The result of the query execution, typically indicating the created edge. + - properties (Optional[Dict[str, Any]]): A dictionary of properties to assign + to the edge. (default None) """ - serialized_properties = self.serialize_properties(edge_properties) + serialized_properties = self.serialize_properties(properties or {}) query = dedent( f"""\ @@ -405,13 +407,13 @@ class Neo4jAdapter(GraphDBInterface): ) params = { - "from_node": str(from_node), - "to_node": str(to_node), + "from_node": str(source_id), # Adding str as callsites may still be passing UUID + "to_node": str(target_id), "relationship_name": relationship_name, "properties": serialized_properties, } - return await self.query(query, params) + await self.query(query, params) def _flatten_edge_properties(self, properties: Dict[str, Any]) -> Dict[str, Any]: """ @@ -445,9 +447,9 @@ class Neo4jAdapter(GraphDBInterface): return flattened - @record_graph_changes - @override_distributed(queued_add_edges) - async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None: + @record_graph_changes # type: ignore + @override_distributed(queued_add_edges) # type: ignore + async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None: """ Add multiple edges between nodes in a single query. @@ -478,10 +480,10 @@ class Neo4jAdapter(GraphDBInterface): ) YIELD rel RETURN rel""" - edges = [ + edge_params = [ { - "from_node": str(edge[0]), - "to_node": str(edge[1]), + "from_node": str(edge[0]), # Adding str as callsites may still be passing UUID + "to_node": str(edge[1]), # Adding str as callsites may still be passing UUID "relationship_name": edge[2], "properties": self._flatten_edge_properties( { @@ -495,13 +497,12 @@ class Neo4jAdapter(GraphDBInterface): ] try: - results = await self.query(query, dict(edges=edges)) - return results + await self.query(query, dict(edges=edge_params)) except Neo4jError as error: logger.error("Neo4j query error: %s", error, exc_info=True) raise error - async def get_edges(self, node_id: str): + async def get_edges(self, node_id: str) -> List[Tuple[str, str, str, Dict[str, Any]]]: """ Retrieve all edges connected to a specified node.