mypy: version Neo4j adapter

This commit is contained in:
Daulet Amirkhanov 2025-09-04 16:07:36 +01:00
parent b9cd847e9d
commit 0fb962e29a

View file

@ -8,11 +8,11 @@ from neo4j import AsyncSession
from neo4j import AsyncGraphDatabase from neo4j import AsyncGraphDatabase
from neo4j.exceptions import Neo4jError from neo4j.exceptions import Neo4jError
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Optional, Any, List, Dict, Type, Tuple from typing import Optional, Any, List, Dict, Type, Tuple, Union, AsyncGenerator
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
from cognee.tasks.temporal_graph.models import Timestamp from cognee.modules.engine.models.Timestamp import Timestamp
from cognee.shared.logging_utils import get_logger, ERROR from cognee.shared.logging_utils import get_logger, ERROR
from cognee.infrastructure.databases.graph.graph_db_interface import ( from cognee.infrastructure.databases.graph.graph_db_interface import (
GraphDBInterface, GraphDBInterface,
@ -79,14 +79,14 @@ class Neo4jAdapter(GraphDBInterface):
) )
@asynccontextmanager
async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
    """
    Yield a driver session bound to this adapter's configured graph database.

    The session is opened via ``self.driver.session(...)`` and is closed when
    the surrounding ``async with`` block exits.
    """
    session_manager = self.driver.session(database=self.graph_database_name)
    async with session_manager as db_session:
        yield db_session
@deadlock_retry() @deadlock_retry() # type: ignore
async def query( async def query(
self, self,
query: str, query: str,
@ -112,6 +112,7 @@ class Neo4jAdapter(GraphDBInterface):
async with self.get_session() as session: async with self.get_session() as session:
result = await session.run(query, parameters=params) result = await session.run(query, parameters=params)
data = await result.data() data = await result.data()
# TODO: why we don't return List[Dict[str, Any]]?
return data return data
except Neo4jError as error: except Neo4jError as error:
logger.error("Neo4j query error: %s", error, exc_info=True) logger.error("Neo4j query error: %s", error, exc_info=True)
@ -141,21 +142,27 @@ class Neo4jAdapter(GraphDBInterface):
) )
return results[0]["node_exists"] if len(results) > 0 else False return results[0]["node_exists"] if len(results) > 0 else False
async def add_node(self, node: DataPoint): async def add_node(self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None) -> None:
""" """
Add a new node to the database based on the provided DataPoint object. Add a new node to the database based on the provided DataPoint object or string ID.
Parameters: Parameters:
----------- -----------
- node (DataPoint): An instance of DataPoint representing the node to add. - node (Union[DataPoint, str]): An instance of DataPoint or string ID representing the node to add.
- properties (Optional[Dict[str, Any]]): Properties to set on the node when node is a string ID.
Returns:
--------
The result of the query execution, typically the ID of the added node.
""" """
serialized_properties = self.serialize_properties(node.model_dump()) if isinstance(node, str):
# TODO: this was not handled in the original code, check if it is correct
# Handle string node ID with properties parameter
node_id = node
node_label = "Node" # Default label for string nodes
serialized_properties = self.serialize_properties(properties or {})
else:
# Handle DataPoint object
node_id = str(node.id)
node_label = type(node).__name__
serialized_properties = self.serialize_properties(node.model_dump())
query = dedent( query = dedent(
f"""MERGE (node: `{BASE_LABEL}`{{id: $node_id}}) f"""MERGE (node: `{BASE_LABEL}`{{id: $node_id}})
@ -167,16 +174,16 @@ class Neo4jAdapter(GraphDBInterface):
) )
params = { params = {
"node_id": str(node.id), "node_id": node_id,
"node_label": type(node).__name__, "node_label": node_label,
"properties": serialized_properties, "properties": serialized_properties,
} }
return await self.query(query, params) await self.query(query, params)
@record_graph_changes @record_graph_changes # type: ignore
@override_distributed(queued_add_nodes) @override_distributed(queued_add_nodes) # type: ignore
async def add_nodes(self, nodes: list[DataPoint]) -> None: async def add_nodes(self, nodes: List[DataPoint]) -> None:
""" """
Add multiple nodes to the database in a single query. Add multiple nodes to the database in a single query.
@ -201,7 +208,7 @@ class Neo4jAdapter(GraphDBInterface):
RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
""" """
nodes = [ node_params = [
{ {
"node_id": str(node.id), "node_id": str(node.id),
"label": type(node).__name__, "label": type(node).__name__,
@ -210,10 +217,9 @@ class Neo4jAdapter(GraphDBInterface):
for node in nodes for node in nodes
] ]
results = await self.query(query, dict(nodes=nodes)) await self.query(query, dict(nodes=node_params))
return results
async def extract_node(self, node_id: str): async def extract_node(self, node_id: str) -> Optional[Dict[str, Any]]:
""" """
Retrieve a single node from the database by its ID. Retrieve a single node from the database by its ID.
@ -231,7 +237,7 @@ class Neo4jAdapter(GraphDBInterface):
return results[0] if len(results) > 0 else None return results[0] if len(results) > 0 else None
async def extract_nodes(self, node_ids: List[str]): async def extract_nodes(self, node_ids: List[str]) -> List[Dict[str, Any]]:
""" """
Retrieve multiple nodes from the database by their IDs. Retrieve multiple nodes from the database by their IDs.
@ -256,7 +262,7 @@ class Neo4jAdapter(GraphDBInterface):
return [result["node"] for result in results] return [result["node"] for result in results]
async def delete_node(self, node_id: str): async def delete_node(self, node_id: str) -> None:
""" """
Remove a node from the database identified by its ID. Remove a node from the database identified by its ID.
@ -273,7 +279,7 @@ class Neo4jAdapter(GraphDBInterface):
query = f"MATCH (node: `{BASE_LABEL}`{{id: $node_id}}) DETACH DELETE node" query = f"MATCH (node: `{BASE_LABEL}`{{id: $node_id}}) DETACH DELETE node"
params = {"node_id": node_id} params = {"node_id": node_id}
return await self.query(query, params) await self.query(query, params)
async def delete_nodes(self, node_ids: list[str]) -> None: async def delete_nodes(self, node_ids: list[str]) -> None:
""" """
@ -296,18 +302,18 @@ class Neo4jAdapter(GraphDBInterface):
params = {"node_ids": node_ids} params = {"node_ids": node_ids}
return await self.query(query, params) await self.query(query, params)
async def has_edge(self, source_id: str, target_id: str, relationship_name: str) -> bool:
    """
    Check if an edge exists between two nodes with the specified IDs and edge label.

    Parameters:
    -----------

        - source_id (str): The ID of the node from which the edge originates.
        - target_id (str): The ID of the node to which the edge points.
        - relationship_name (str): The label of the edge to check for existence.

    Returns:
    --------

        - bool: True if the edge exists, otherwise False.
    """
    # The relationship has to be bound to a variable in the MATCH pattern;
    # otherwise COUNT(relationship) in the RETURN clause refers to an
    # undefined variable.
    # NOTE(review): relationship_name is interpolated into the query text
    # (only wrapped in backticks) - confirm callers never pass untrusted input.
    query = f"""
        MATCH (from_node: `{BASE_LABEL}`)-[relationship:`{relationship_name}`]->(to_node: `{BASE_LABEL}`)
        WHERE from_node.id = $source_id AND to_node.id = $target_id
        RETURN COUNT(relationship) > 0 AS edge_exists
    """

    params = {
        "source_id": str(source_id),  # str() because call sites may still pass UUID objects
        "target_id": str(target_id),
    }

    # self.query() hands back a list of row dicts rather than a bare boolean,
    # so the flag is read from the first row (no rows means no edge).
    records = await self.query(query, params)
    return bool(records[0]["edge_exists"]) if records else False
async def has_edges(self, edges): async def has_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> List[bool]:
""" """
Check if multiple edges exist based on provided edge criteria. Check if multiple edges exist based on provided edge criteria.
Parameters: Parameters:
----------- -----------
- edges: A list of edge specifications to check for existence. - edges: A list of edge specifications to check for existence. (source_id, target_id, relationship_name, properties)
Returns: Returns:
-------- --------
@ -369,29 +376,24 @@ class Neo4jAdapter(GraphDBInterface):
async def add_edge( async def add_edge(
self, self,
from_node: UUID, source_id: str,
to_node: UUID, target_id: str,
relationship_name: str, relationship_name: str,
edge_properties: Optional[Dict[str, Any]] = {}, properties: Optional[Dict[str, Any]] = None,
): ) -> None:
""" """
Create a new edge between two nodes with specified properties. Create a new edge between two nodes with specified properties.
Parameters: Parameters:
----------- -----------
- from_node (UUID): The ID of the source node of the edge. - source_id (str): The ID of the source node of the edge.
- to_node (UUID): The ID of the target node of the edge. - target_id (str): The ID of the target node of the edge.
- relationship_name (str): The type/label of the edge to create. - relationship_name (str): The type/label of the edge to create.
- edge_properties (Optional[Dict[str, Any]]): A dictionary of properties to assign - properties (Optional[Dict[str, Any]]): A dictionary of properties to assign
to the edge. (default {}) to the edge. (default None)
Returns:
--------
The result of the query execution, typically indicating the created edge.
""" """
serialized_properties = self.serialize_properties(edge_properties) serialized_properties = self.serialize_properties(properties or {})
query = dedent( query = dedent(
f"""\ f"""\
@ -405,13 +407,13 @@ class Neo4jAdapter(GraphDBInterface):
) )
params = { params = {
"from_node": str(from_node), "from_node": str(source_id), # Adding str as callsites may still be passing UUID
"to_node": str(to_node), "to_node": str(target_id),
"relationship_name": relationship_name, "relationship_name": relationship_name,
"properties": serialized_properties, "properties": serialized_properties,
} }
return await self.query(query, params) await self.query(query, params)
def _flatten_edge_properties(self, properties: Dict[str, Any]) -> Dict[str, Any]: def _flatten_edge_properties(self, properties: Dict[str, Any]) -> Dict[str, Any]:
""" """
@ -445,9 +447,9 @@ class Neo4jAdapter(GraphDBInterface):
return flattened return flattened
@record_graph_changes @record_graph_changes # type: ignore
@override_distributed(queued_add_edges) @override_distributed(queued_add_edges) # type: ignore
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None: async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
""" """
Add multiple edges between nodes in a single query. Add multiple edges between nodes in a single query.
@ -478,10 +480,10 @@ class Neo4jAdapter(GraphDBInterface):
) YIELD rel ) YIELD rel
RETURN rel""" RETURN rel"""
edges = [ edge_params = [
{ {
"from_node": str(edge[0]), "from_node": str(edge[0]), # Adding str as callsites may still be passing UUID
"to_node": str(edge[1]), "to_node": str(edge[1]), # Adding str as callsites may still be passing UUID
"relationship_name": edge[2], "relationship_name": edge[2],
"properties": self._flatten_edge_properties( "properties": self._flatten_edge_properties(
{ {
@ -495,13 +497,12 @@ class Neo4jAdapter(GraphDBInterface):
] ]
try: try:
results = await self.query(query, dict(edges=edges)) await self.query(query, dict(edges=edge_params))
return results
except Neo4jError as error: except Neo4jError as error:
logger.error("Neo4j query error: %s", error, exc_info=True) logger.error("Neo4j query error: %s", error, exc_info=True)
raise error raise error
async def get_edges(self, node_id: str): async def get_edges(self, node_id: str) -> List[Tuple[str, str, str, Dict[str, Any]]]:
""" """
Retrieve all edges connected to a specified node. Retrieve all edges connected to a specified node.