mypy: version Neo4j adapter
This commit is contained in:
parent
b9cd847e9d
commit
0fb962e29a
1 changed file with 64 additions and 63 deletions
|
|
@ -8,11 +8,11 @@ from neo4j import AsyncSession
|
|||
from neo4j import AsyncGraphDatabase
|
||||
from neo4j.exceptions import Neo4jError
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional, Any, List, Dict, Type, Tuple
|
||||
from typing import Optional, Any, List, Dict, Type, Tuple, Union, AsyncGenerator
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
|
||||
from cognee.tasks.temporal_graph.models import Timestamp
|
||||
from cognee.modules.engine.models.Timestamp import Timestamp
|
||||
from cognee.shared.logging_utils import get_logger, ERROR
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import (
|
||||
GraphDBInterface,
|
||||
|
|
@ -79,14 +79,14 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_session(self) -> AsyncSession:
|
||||
async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
Get a session for database operations.
|
||||
"""
|
||||
async with self.driver.session(database=self.graph_database_name) as session:
|
||||
yield session
|
||||
|
||||
@deadlock_retry()
|
||||
@deadlock_retry() # type: ignore
|
||||
async def query(
|
||||
self,
|
||||
query: str,
|
||||
|
|
@ -112,6 +112,7 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
async with self.get_session() as session:
|
||||
result = await session.run(query, parameters=params)
|
||||
data = await result.data()
|
||||
# TODO: why don't we return List[Dict[str, Any]]?
|
||||
return data
|
||||
except Neo4jError as error:
|
||||
logger.error("Neo4j query error: %s", error, exc_info=True)
|
||||
|
|
@ -141,21 +142,27 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
)
|
||||
return results[0]["node_exists"] if len(results) > 0 else False
|
||||
|
||||
async def add_node(self, node: DataPoint):
|
||||
async def add_node(self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""
|
||||
Add a new node to the database based on the provided DataPoint object.
|
||||
Add a new node to the database based on the provided DataPoint object or string ID.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- node (DataPoint): An instance of DataPoint representing the node to add.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
The result of the query execution, typically the ID of the added node.
|
||||
- node (Union[DataPoint, str]): An instance of DataPoint or string ID representing the node to add.
|
||||
- properties (Optional[Dict[str, Any]]): Properties to set on the node when node is a string ID.
|
||||
"""
|
||||
serialized_properties = self.serialize_properties(node.model_dump())
|
||||
if isinstance(node, str):
|
||||
# TODO: this was not handled in the original code, check if it is correct
|
||||
# Handle string node ID with properties parameter
|
||||
node_id = node
|
||||
node_label = "Node" # Default label for string nodes
|
||||
serialized_properties = self.serialize_properties(properties or {})
|
||||
else:
|
||||
# Handle DataPoint object
|
||||
node_id = str(node.id)
|
||||
node_label = type(node).__name__
|
||||
serialized_properties = self.serialize_properties(node.model_dump())
|
||||
|
||||
query = dedent(
|
||||
f"""MERGE (node: `{BASE_LABEL}`{{id: $node_id}})
|
||||
|
|
@ -167,16 +174,16 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
)
|
||||
|
||||
params = {
|
||||
"node_id": str(node.id),
|
||||
"node_label": type(node).__name__,
|
||||
"node_id": node_id,
|
||||
"node_label": node_label,
|
||||
"properties": serialized_properties,
|
||||
}
|
||||
|
||||
return await self.query(query, params)
|
||||
await self.query(query, params)
|
||||
|
||||
@record_graph_changes
|
||||
@override_distributed(queued_add_nodes)
|
||||
async def add_nodes(self, nodes: list[DataPoint]) -> None:
|
||||
@record_graph_changes # type: ignore
|
||||
@override_distributed(queued_add_nodes) # type: ignore
|
||||
async def add_nodes(self, nodes: List[DataPoint]) -> None:
|
||||
"""
|
||||
Add multiple nodes to the database in a single query.
|
||||
|
||||
|
|
@ -201,7 +208,7 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
|
||||
"""
|
||||
|
||||
nodes = [
|
||||
node_params = [
|
||||
{
|
||||
"node_id": str(node.id),
|
||||
"label": type(node).__name__,
|
||||
|
|
@ -210,10 +217,9 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
for node in nodes
|
||||
]
|
||||
|
||||
results = await self.query(query, dict(nodes=nodes))
|
||||
return results
|
||||
await self.query(query, dict(nodes=node_params))
|
||||
|
||||
async def extract_node(self, node_id: str):
|
||||
async def extract_node(self, node_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve a single node from the database by its ID.
|
||||
|
||||
|
|
@ -231,7 +237,7 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
return results[0] if len(results) > 0 else None
|
||||
|
||||
async def extract_nodes(self, node_ids: List[str]):
|
||||
async def extract_nodes(self, node_ids: List[str]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve multiple nodes from the database by their IDs.
|
||||
|
||||
|
|
@ -256,7 +262,7 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
return [result["node"] for result in results]
|
||||
|
||||
async def delete_node(self, node_id: str):
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
"""
|
||||
Remove a node from the database identified by its ID.
|
||||
|
||||
|
|
@ -273,7 +279,7 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
query = f"MATCH (node: `{BASE_LABEL}`{{id: $node_id}}) DETACH DELETE node"
|
||||
params = {"node_id": node_id}
|
||||
|
||||
return await self.query(query, params)
|
||||
await self.query(query, params)
|
||||
|
||||
async def delete_nodes(self, node_ids: list[str]) -> None:
|
||||
"""
|
||||
|
|
@ -296,18 +302,18 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
params = {"node_ids": node_ids}
|
||||
|
||||
return await self.query(query, params)
|
||||
await self.query(query, params)
|
||||
|
||||
async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool:
|
||||
async def has_edge(self, source_id: str, target_id: str, relationship_name: str) -> bool:
|
||||
"""
|
||||
Check if an edge exists between two nodes with the specified IDs and edge label.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- from_node (UUID): The ID of the node from which the edge originates.
|
||||
- to_node (UUID): The ID of the node to which the edge points.
|
||||
- edge_label (str): The label of the edge to check for existence.
|
||||
- source_id (str): The ID of the node from which the edge originates.
|
||||
- target_id (str): The ID of the node to which the edge points.
|
||||
- relationship_name (str): The label of the edge to check for existence.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
@ -315,27 +321,28 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
- bool: True if the edge exists, otherwise False.
|
||||
"""
|
||||
query = f"""
|
||||
MATCH (from_node: `{BASE_LABEL}`)-[:`{edge_label}`]->(to_node: `{BASE_LABEL}`)
|
||||
WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id
|
||||
MATCH (from_node: `{BASE_LABEL}`)-[:`{relationship_name}`]->(to_node: `{BASE_LABEL}`)
|
||||
WHERE from_node.id = $source_id AND to_node.id = $target_id
|
||||
RETURN COUNT(relationship) > 0 AS edge_exists
|
||||
"""
|
||||
|
||||
params = {
|
||||
"from_node_id": str(from_node),
|
||||
"to_node_id": str(to_node),
|
||||
"source_id": str(source_id),
|
||||
"target_id": str(target_id),
|
||||
}
|
||||
|
||||
edge_exists = await self.query(query, params)
|
||||
assert isinstance(edge_exists, bool), "Edge existence check should return a boolean"
|
||||
return edge_exists
|
||||
|
||||
async def has_edges(self, edges):
|
||||
async def has_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> List[bool]:
|
||||
"""
|
||||
Check if multiple edges exist based on provided edge criteria.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- edges: A list of edge specifications to check for existence.
|
||||
- edges: A list of edge specifications to check for existence, each given as a (source_id, target_id, relationship_name, properties) tuple.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
@ -369,29 +376,24 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
async def add_edge(
|
||||
self,
|
||||
from_node: UUID,
|
||||
to_node: UUID,
|
||||
source_id: str,
|
||||
target_id: str,
|
||||
relationship_name: str,
|
||||
edge_properties: Optional[Dict[str, Any]] = {},
|
||||
):
|
||||
properties: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Create a new edge between two nodes with specified properties.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- from_node (UUID): The ID of the source node of the edge.
|
||||
- to_node (UUID): The ID of the target node of the edge.
|
||||
- source_id (str): The ID of the source node of the edge.
|
||||
- target_id (str): The ID of the target node of the edge.
|
||||
- relationship_name (str): The type/label of the edge to create.
|
||||
- edge_properties (Optional[Dict[str, Any]]): A dictionary of properties to assign
|
||||
to the edge. (default {})
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
The result of the query execution, typically indicating the created edge.
|
||||
- properties (Optional[Dict[str, Any]]): A dictionary of properties to assign
|
||||
to the edge. (default None)
|
||||
"""
|
||||
serialized_properties = self.serialize_properties(edge_properties)
|
||||
serialized_properties = self.serialize_properties(properties or {})
|
||||
|
||||
query = dedent(
|
||||
f"""\
|
||||
|
|
@ -405,13 +407,13 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
)
|
||||
|
||||
params = {
|
||||
"from_node": str(from_node),
|
||||
"to_node": str(to_node),
|
||||
"from_node": str(source_id), # Adding str as callsites may still be passing UUID
|
||||
"to_node": str(target_id),
|
||||
"relationship_name": relationship_name,
|
||||
"properties": serialized_properties,
|
||||
}
|
||||
|
||||
return await self.query(query, params)
|
||||
await self.query(query, params)
|
||||
|
||||
def _flatten_edge_properties(self, properties: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
|
|
@ -445,9 +447,9 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
return flattened
|
||||
|
||||
@record_graph_changes
|
||||
@override_distributed(queued_add_edges)
|
||||
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
|
||||
@record_graph_changes # type: ignore
|
||||
@override_distributed(queued_add_edges) # type: ignore
|
||||
async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
|
||||
"""
|
||||
Add multiple edges between nodes in a single query.
|
||||
|
||||
|
|
@ -478,10 +480,10 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
) YIELD rel
|
||||
RETURN rel"""
|
||||
|
||||
edges = [
|
||||
edge_params = [
|
||||
{
|
||||
"from_node": str(edge[0]),
|
||||
"to_node": str(edge[1]),
|
||||
"from_node": str(edge[0]), # Adding str as callsites may still be passing UUID
|
||||
"to_node": str(edge[1]), # Adding str as callsites may still be passing UUID
|
||||
"relationship_name": edge[2],
|
||||
"properties": self._flatten_edge_properties(
|
||||
{
|
||||
|
|
@ -495,13 +497,12 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
]
|
||||
|
||||
try:
|
||||
results = await self.query(query, dict(edges=edges))
|
||||
return results
|
||||
await self.query(query, dict(edges=edge_params))
|
||||
except Neo4jError as error:
|
||||
logger.error("Neo4j query error: %s", error, exc_info=True)
|
||||
raise error
|
||||
|
||||
async def get_edges(self, node_id: str):
|
||||
async def get_edges(self, node_id: str) -> List[Tuple[str, str, str, Dict[str, Any]]]:
|
||||
"""
|
||||
Retrieve all edges connected to a specified node.
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue