mypy: version Neo4j adapter

This commit is contained in:
Daulet Amirkhanov 2025-09-04 16:07:36 +01:00
parent b9cd847e9d
commit 0fb962e29a

View file

@ -8,11 +8,11 @@ from neo4j import AsyncSession
from neo4j import AsyncGraphDatabase
from neo4j.exceptions import Neo4jError
from contextlib import asynccontextmanager
from typing import Optional, Any, List, Dict, Type, Tuple
from typing import Optional, Any, List, Dict, Type, Tuple, Union, AsyncGenerator
from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
from cognee.tasks.temporal_graph.models import Timestamp
from cognee.modules.engine.models.Timestamp import Timestamp
from cognee.shared.logging_utils import get_logger, ERROR
from cognee.infrastructure.databases.graph.graph_db_interface import (
GraphDBInterface,
@ -79,14 +79,14 @@ class Neo4jAdapter(GraphDBInterface):
)
@asynccontextmanager
async def get_session(self) -> AsyncSession:
async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
"""
Get a session for database operations.
"""
async with self.driver.session(database=self.graph_database_name) as session:
yield session
@deadlock_retry()
@deadlock_retry() # type: ignore
async def query(
self,
query: str,
@ -112,6 +112,7 @@ class Neo4jAdapter(GraphDBInterface):
async with self.get_session() as session:
result = await session.run(query, parameters=params)
data = await result.data()
# TODO: why don't we return List[Dict[str, Any]]?
return data
except Neo4jError as error:
logger.error("Neo4j query error: %s", error, exc_info=True)
@ -141,21 +142,27 @@ class Neo4jAdapter(GraphDBInterface):
)
return results[0]["node_exists"] if len(results) > 0 else False
async def add_node(self, node: DataPoint):
async def add_node(self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None) -> None:
"""
Add a new node to the database based on the provided DataPoint object.
Add a new node to the database based on the provided DataPoint object or string ID.
Parameters:
-----------
- node (DataPoint): An instance of DataPoint representing the node to add.
Returns:
--------
The result of the query execution, typically the ID of the added node.
- node (Union[DataPoint, str]): An instance of DataPoint or string ID representing the node to add.
- properties (Optional[Dict[str, Any]]): Properties to set on the node when node is a string ID.
"""
serialized_properties = self.serialize_properties(node.model_dump())
if isinstance(node, str):
# TODO: string node IDs were not handled in the original code; verify that this branch is correct
# Handle string node ID with properties parameter
node_id = node
node_label = "Node" # Default label for string nodes
serialized_properties = self.serialize_properties(properties or {})
else:
# Handle DataPoint object
node_id = str(node.id)
node_label = type(node).__name__
serialized_properties = self.serialize_properties(node.model_dump())
query = dedent(
f"""MERGE (node: `{BASE_LABEL}`{{id: $node_id}})
@ -167,16 +174,16 @@ class Neo4jAdapter(GraphDBInterface):
)
params = {
"node_id": str(node.id),
"node_label": type(node).__name__,
"node_id": node_id,
"node_label": node_label,
"properties": serialized_properties,
}
return await self.query(query, params)
await self.query(query, params)
@record_graph_changes
@override_distributed(queued_add_nodes)
async def add_nodes(self, nodes: list[DataPoint]) -> None:
@record_graph_changes # type: ignore
@override_distributed(queued_add_nodes) # type: ignore
async def add_nodes(self, nodes: List[DataPoint]) -> None:
"""
Add multiple nodes to the database in a single query.
@ -201,7 +208,7 @@ class Neo4jAdapter(GraphDBInterface):
RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
"""
nodes = [
node_params = [
{
"node_id": str(node.id),
"label": type(node).__name__,
@ -210,10 +217,9 @@ class Neo4jAdapter(GraphDBInterface):
for node in nodes
]
results = await self.query(query, dict(nodes=nodes))
return results
await self.query(query, dict(nodes=node_params))
async def extract_node(self, node_id: str):
async def extract_node(self, node_id: str) -> Optional[Dict[str, Any]]:
"""
Retrieve a single node from the database by its ID.
@ -231,7 +237,7 @@ class Neo4jAdapter(GraphDBInterface):
return results[0] if len(results) > 0 else None
async def extract_nodes(self, node_ids: List[str]):
async def extract_nodes(self, node_ids: List[str]) -> List[Dict[str, Any]]:
"""
Retrieve multiple nodes from the database by their IDs.
@ -256,7 +262,7 @@ class Neo4jAdapter(GraphDBInterface):
return [result["node"] for result in results]
async def delete_node(self, node_id: str):
async def delete_node(self, node_id: str) -> None:
"""
Remove a node from the database identified by its ID.
@ -273,7 +279,7 @@ class Neo4jAdapter(GraphDBInterface):
query = f"MATCH (node: `{BASE_LABEL}`{{id: $node_id}}) DETACH DELETE node"
params = {"node_id": node_id}
return await self.query(query, params)
await self.query(query, params)
async def delete_nodes(self, node_ids: list[str]) -> None:
"""
@ -296,18 +302,18 @@ class Neo4jAdapter(GraphDBInterface):
params = {"node_ids": node_ids}
return await self.query(query, params)
await self.query(query, params)
async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool:
async def has_edge(self, source_id: str, target_id: str, relationship_name: str) -> bool:
"""
Check if an edge exists between two nodes with the specified IDs and edge label.
Parameters:
-----------
- from_node (UUID): The ID of the node from which the edge originates.
- to_node (UUID): The ID of the node to which the edge points.
- edge_label (str): The label of the edge to check for existence.
- source_id (str): The ID of the node from which the edge originates.
- target_id (str): The ID of the node to which the edge points.
- relationship_name (str): The label of the edge to check for existence.
Returns:
--------
@ -315,27 +321,28 @@ class Neo4jAdapter(GraphDBInterface):
- bool: True if the edge exists, otherwise False.
"""
query = f"""
MATCH (from_node: `{BASE_LABEL}`)-[:`{edge_label}`]->(to_node: `{BASE_LABEL}`)
WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id
MATCH (from_node: `{BASE_LABEL}`)-[:`{relationship_name}`]->(to_node: `{BASE_LABEL}`)
WHERE from_node.id = $source_id AND to_node.id = $target_id
RETURN COUNT(relationship) > 0 AS edge_exists
"""
params = {
"from_node_id": str(from_node),
"to_node_id": str(to_node),
"source_id": str(source_id),
"target_id": str(target_id),
}
edge_exists = await self.query(query, params)
assert isinstance(edge_exists, bool), "Edge existence check should return a boolean"
return edge_exists
async def has_edges(self, edges):
async def has_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> List[bool]:
"""
Check if multiple edges exist based on provided edge criteria.
Parameters:
-----------
- edges: A list of edge specifications to check for existence.
- edges: A list of edge specifications to check for existence, each a tuple of (source_id, target_id, relationship_name, properties).
Returns:
--------
@ -369,29 +376,24 @@ class Neo4jAdapter(GraphDBInterface):
async def add_edge(
self,
from_node: UUID,
to_node: UUID,
source_id: str,
target_id: str,
relationship_name: str,
edge_properties: Optional[Dict[str, Any]] = {},
):
properties: Optional[Dict[str, Any]] = None,
) -> None:
"""
Create a new edge between two nodes with specified properties.
Parameters:
-----------
- from_node (UUID): The ID of the source node of the edge.
- to_node (UUID): The ID of the target node of the edge.
- source_id (str): The ID of the source node of the edge.
- target_id (str): The ID of the target node of the edge.
- relationship_name (str): The type/label of the edge to create.
- edge_properties (Optional[Dict[str, Any]]): A dictionary of properties to assign
to the edge. (default {})
Returns:
--------
The result of the query execution, typically indicating the created edge.
- properties (Optional[Dict[str, Any]]): A dictionary of properties to assign
to the edge. (default None)
"""
serialized_properties = self.serialize_properties(edge_properties)
serialized_properties = self.serialize_properties(properties or {})
query = dedent(
f"""\
@ -405,13 +407,13 @@ class Neo4jAdapter(GraphDBInterface):
)
params = {
"from_node": str(from_node),
"to_node": str(to_node),
"from_node": str(source_id), # Adding str as callsites may still be passing UUID
"to_node": str(target_id),
"relationship_name": relationship_name,
"properties": serialized_properties,
}
return await self.query(query, params)
await self.query(query, params)
def _flatten_edge_properties(self, properties: Dict[str, Any]) -> Dict[str, Any]:
"""
@ -445,9 +447,9 @@ class Neo4jAdapter(GraphDBInterface):
return flattened
@record_graph_changes
@override_distributed(queued_add_edges)
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
@record_graph_changes # type: ignore
@override_distributed(queued_add_edges) # type: ignore
async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
"""
Add multiple edges between nodes in a single query.
@ -478,10 +480,10 @@ class Neo4jAdapter(GraphDBInterface):
) YIELD rel
RETURN rel"""
edges = [
edge_params = [
{
"from_node": str(edge[0]),
"to_node": str(edge[1]),
"from_node": str(edge[0]), # Adding str as callsites may still be passing UUID
"to_node": str(edge[1]), # Adding str as callsites may still be passing UUID
"relationship_name": edge[2],
"properties": self._flatten_edge_properties(
{
@ -495,13 +497,12 @@ class Neo4jAdapter(GraphDBInterface):
]
try:
results = await self.query(query, dict(edges=edges))
return results
await self.query(query, dict(edges=edge_params))
except Neo4jError as error:
logger.error("Neo4j query error: %s", error, exc_info=True)
raise error
async def get_edges(self, node_id: str):
async def get_edges(self, node_id: str) -> List[Tuple[str, str, str, Dict[str, Any]]]:
"""
Retrieve all edges connected to a specified node.