mypy: version Neo4j adapter
This commit is contained in:
parent
b9cd847e9d
commit
0fb962e29a
1 changed file with 64 additions and 63 deletions
|
|
@@ -8,11 +8,11 @@ from neo4j import AsyncSession
|
||||||
from neo4j import AsyncGraphDatabase
|
from neo4j import AsyncGraphDatabase
|
||||||
from neo4j.exceptions import Neo4jError
|
from neo4j.exceptions import Neo4jError
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Optional, Any, List, Dict, Type, Tuple
|
from typing import Optional, Any, List, Dict, Type, Tuple, Union, AsyncGenerator
|
||||||
|
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
|
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
|
||||||
from cognee.tasks.temporal_graph.models import Timestamp
|
from cognee.modules.engine.models.Timestamp import Timestamp
|
||||||
from cognee.shared.logging_utils import get_logger, ERROR
|
from cognee.shared.logging_utils import get_logger, ERROR
|
||||||
from cognee.infrastructure.databases.graph.graph_db_interface import (
|
from cognee.infrastructure.databases.graph.graph_db_interface import (
|
||||||
GraphDBInterface,
|
GraphDBInterface,
|
||||||
|
|
@@ -79,14 +79,14 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
)
|
)
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def get_session(self) -> AsyncSession:
|
async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
|
||||||
"""
|
"""
|
||||||
Get a session for database operations.
|
Get a session for database operations.
|
||||||
"""
|
"""
|
||||||
async with self.driver.session(database=self.graph_database_name) as session:
|
async with self.driver.session(database=self.graph_database_name) as session:
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
@deadlock_retry()
|
@deadlock_retry() # type: ignore
|
||||||
async def query(
|
async def query(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
|
@@ -112,6 +112,7 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
async with self.get_session() as session:
|
async with self.get_session() as session:
|
||||||
result = await session.run(query, parameters=params)
|
result = await session.run(query, parameters=params)
|
||||||
data = await result.data()
|
data = await result.data()
|
||||||
|
# TODO: why we don't return List[Dict[str, Any]]?
|
||||||
return data
|
return data
|
||||||
except Neo4jError as error:
|
except Neo4jError as error:
|
||||||
logger.error("Neo4j query error: %s", error, exc_info=True)
|
logger.error("Neo4j query error: %s", error, exc_info=True)
|
||||||
|
|
@@ -141,21 +142,27 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
)
|
)
|
||||||
return results[0]["node_exists"] if len(results) > 0 else False
|
return results[0]["node_exists"] if len(results) > 0 else False
|
||||||
|
|
||||||
async def add_node(self, node: DataPoint):
|
async def add_node(self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Add a new node to the database based on the provided DataPoint object.
|
Add a new node to the database based on the provided DataPoint object or string ID.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
-----------
|
-----------
|
||||||
|
|
||||||
- node (DataPoint): An instance of DataPoint representing the node to add.
|
- node (Union[DataPoint, str]): An instance of DataPoint or string ID representing the node to add.
|
||||||
|
- properties (Optional[Dict[str, Any]]): Properties to set on the node when node is a string ID.
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
|
|
||||||
The result of the query execution, typically the ID of the added node.
|
|
||||||
"""
|
"""
|
||||||
serialized_properties = self.serialize_properties(node.model_dump())
|
if isinstance(node, str):
|
||||||
|
# TODO: this was not handled in the original code, check if it is correct
|
||||||
|
# Handle string node ID with properties parameter
|
||||||
|
node_id = node
|
||||||
|
node_label = "Node" # Default label for string nodes
|
||||||
|
serialized_properties = self.serialize_properties(properties or {})
|
||||||
|
else:
|
||||||
|
# Handle DataPoint object
|
||||||
|
node_id = str(node.id)
|
||||||
|
node_label = type(node).__name__
|
||||||
|
serialized_properties = self.serialize_properties(node.model_dump())
|
||||||
|
|
||||||
query = dedent(
|
query = dedent(
|
||||||
f"""MERGE (node: `{BASE_LABEL}`{{id: $node_id}})
|
f"""MERGE (node: `{BASE_LABEL}`{{id: $node_id}})
|
||||||
|
|
@@ -167,16 +174,16 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
)
|
)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"node_id": str(node.id),
|
"node_id": node_id,
|
||||||
"node_label": type(node).__name__,
|
"node_label": node_label,
|
||||||
"properties": serialized_properties,
|
"properties": serialized_properties,
|
||||||
}
|
}
|
||||||
|
|
||||||
return await self.query(query, params)
|
await self.query(query, params)
|
||||||
|
|
||||||
@record_graph_changes
|
@record_graph_changes # type: ignore
|
||||||
@override_distributed(queued_add_nodes)
|
@override_distributed(queued_add_nodes) # type: ignore
|
||||||
async def add_nodes(self, nodes: list[DataPoint]) -> None:
|
async def add_nodes(self, nodes: List[DataPoint]) -> None:
|
||||||
"""
|
"""
|
||||||
Add multiple nodes to the database in a single query.
|
Add multiple nodes to the database in a single query.
|
||||||
|
|
||||||
|
|
@@ -201,7 +208,7 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
|
RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
|
||||||
"""
|
"""
|
||||||
|
|
||||||
nodes = [
|
node_params = [
|
||||||
{
|
{
|
||||||
"node_id": str(node.id),
|
"node_id": str(node.id),
|
||||||
"label": type(node).__name__,
|
"label": type(node).__name__,
|
||||||
|
|
@@ -210,10 +217,9 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
for node in nodes
|
for node in nodes
|
||||||
]
|
]
|
||||||
|
|
||||||
results = await self.query(query, dict(nodes=nodes))
|
await self.query(query, dict(nodes=node_params))
|
||||||
return results
|
|
||||||
|
|
||||||
async def extract_node(self, node_id: str):
|
async def extract_node(self, node_id: str) -> Optional[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Retrieve a single node from the database by its ID.
|
Retrieve a single node from the database by its ID.
|
||||||
|
|
||||||
|
|
@@ -231,7 +237,7 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
|
|
||||||
return results[0] if len(results) > 0 else None
|
return results[0] if len(results) > 0 else None
|
||||||
|
|
||||||
async def extract_nodes(self, node_ids: List[str]):
|
async def extract_nodes(self, node_ids: List[str]) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Retrieve multiple nodes from the database by their IDs.
|
Retrieve multiple nodes from the database by their IDs.
|
||||||
|
|
||||||
|
|
@@ -256,7 +262,7 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
|
|
||||||
return [result["node"] for result in results]
|
return [result["node"] for result in results]
|
||||||
|
|
||||||
async def delete_node(self, node_id: str):
|
async def delete_node(self, node_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
Remove a node from the database identified by its ID.
|
Remove a node from the database identified by its ID.
|
||||||
|
|
||||||
|
|
@@ -273,7 +279,7 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
query = f"MATCH (node: `{BASE_LABEL}`{{id: $node_id}}) DETACH DELETE node"
|
query = f"MATCH (node: `{BASE_LABEL}`{{id: $node_id}}) DETACH DELETE node"
|
||||||
params = {"node_id": node_id}
|
params = {"node_id": node_id}
|
||||||
|
|
||||||
return await self.query(query, params)
|
await self.query(query, params)
|
||||||
|
|
||||||
async def delete_nodes(self, node_ids: list[str]) -> None:
|
async def delete_nodes(self, node_ids: list[str]) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
@@ -296,18 +302,18 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
|
|
||||||
params = {"node_ids": node_ids}
|
params = {"node_ids": node_ids}
|
||||||
|
|
||||||
return await self.query(query, params)
|
await self.query(query, params)
|
||||||
|
|
||||||
async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool:
|
async def has_edge(self, source_id: str, target_id: str, relationship_name: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if an edge exists between two nodes with the specified IDs and edge label.
|
Check if an edge exists between two nodes with the specified IDs and edge label.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
-----------
|
-----------
|
||||||
|
|
||||||
- from_node (UUID): The ID of the node from which the edge originates.
|
- source_id (str): The ID of the node from which the edge originates.
|
||||||
- to_node (UUID): The ID of the node to which the edge points.
|
- target_id (str): The ID of the node to which the edge points.
|
||||||
- edge_label (str): The label of the edge to check for existence.
|
- relationship_name (str): The label of the edge to check for existence.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
--------
|
--------
|
||||||
|
|
@@ -315,27 +321,28 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
- bool: True if the edge exists, otherwise False.
|
- bool: True if the edge exists, otherwise False.
|
||||||
"""
|
"""
|
||||||
query = f"""
|
query = f"""
|
||||||
MATCH (from_node: `{BASE_LABEL}`)-[:`{edge_label}`]->(to_node: `{BASE_LABEL}`)
|
MATCH (from_node: `{BASE_LABEL}`)-[:`{relationship_name}`]->(to_node: `{BASE_LABEL}`)
|
||||||
WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id
|
WHERE from_node.id = $source_id AND to_node.id = $target_id
|
||||||
RETURN COUNT(relationship) > 0 AS edge_exists
|
RETURN COUNT(relationship) > 0 AS edge_exists
|
||||||
"""
|
"""
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"from_node_id": str(from_node),
|
"source_id": str(source_id),
|
||||||
"to_node_id": str(to_node),
|
"target_id": str(target_id),
|
||||||
}
|
}
|
||||||
|
|
||||||
edge_exists = await self.query(query, params)
|
edge_exists = await self.query(query, params)
|
||||||
|
assert isinstance(edge_exists, bool), "Edge existence check should return a boolean"
|
||||||
return edge_exists
|
return edge_exists
|
||||||
|
|
||||||
async def has_edges(self, edges):
|
async def has_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> List[bool]:
|
||||||
"""
|
"""
|
||||||
Check if multiple edges exist based on provided edge criteria.
|
Check if multiple edges exist based on provided edge criteria.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
-----------
|
-----------
|
||||||
|
|
||||||
- edges: A list of edge specifications to check for existence.
|
- edges: A list of edge specifications to check for existence. (source_id, target_id, relationship_name, properties)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
--------
|
--------
|
||||||
|
|
@@ -369,29 +376,24 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
|
|
||||||
async def add_edge(
|
async def add_edge(
|
||||||
self,
|
self,
|
||||||
from_node: UUID,
|
source_id: str,
|
||||||
to_node: UUID,
|
target_id: str,
|
||||||
relationship_name: str,
|
relationship_name: str,
|
||||||
edge_properties: Optional[Dict[str, Any]] = {},
|
properties: Optional[Dict[str, Any]] = None,
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Create a new edge between two nodes with specified properties.
|
Create a new edge between two nodes with specified properties.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
-----------
|
-----------
|
||||||
|
|
||||||
- from_node (UUID): The ID of the source node of the edge.
|
- source_id (str): The ID of the source node of the edge.
|
||||||
- to_node (UUID): The ID of the target node of the edge.
|
- target_id (str): The ID of the target node of the edge.
|
||||||
- relationship_name (str): The type/label of the edge to create.
|
- relationship_name (str): The type/label of the edge to create.
|
||||||
- edge_properties (Optional[Dict[str, Any]]): A dictionary of properties to assign
|
- properties (Optional[Dict[str, Any]]): A dictionary of properties to assign
|
||||||
to the edge. (default {})
|
to the edge. (default None)
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
|
|
||||||
The result of the query execution, typically indicating the created edge.
|
|
||||||
"""
|
"""
|
||||||
serialized_properties = self.serialize_properties(edge_properties)
|
serialized_properties = self.serialize_properties(properties or {})
|
||||||
|
|
||||||
query = dedent(
|
query = dedent(
|
||||||
f"""\
|
f"""\
|
||||||
|
|
@@ -405,13 +407,13 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
)
|
)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"from_node": str(from_node),
|
"from_node": str(source_id), # Adding str as callsites may still be passing UUID
|
||||||
"to_node": str(to_node),
|
"to_node": str(target_id),
|
||||||
"relationship_name": relationship_name,
|
"relationship_name": relationship_name,
|
||||||
"properties": serialized_properties,
|
"properties": serialized_properties,
|
||||||
}
|
}
|
||||||
|
|
||||||
return await self.query(query, params)
|
await self.query(query, params)
|
||||||
|
|
||||||
def _flatten_edge_properties(self, properties: Dict[str, Any]) -> Dict[str, Any]:
|
def _flatten_edge_properties(self, properties: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
|
|
@@ -445,9 +447,9 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
|
|
||||||
return flattened
|
return flattened
|
||||||
|
|
||||||
@record_graph_changes
|
@record_graph_changes # type: ignore
|
||||||
@override_distributed(queued_add_edges)
|
@override_distributed(queued_add_edges) # type: ignore
|
||||||
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
|
async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
|
||||||
"""
|
"""
|
||||||
Add multiple edges between nodes in a single query.
|
Add multiple edges between nodes in a single query.
|
||||||
|
|
||||||
|
|
@@ -478,10 +480,10 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
) YIELD rel
|
) YIELD rel
|
||||||
RETURN rel"""
|
RETURN rel"""
|
||||||
|
|
||||||
edges = [
|
edge_params = [
|
||||||
{
|
{
|
||||||
"from_node": str(edge[0]),
|
"from_node": str(edge[0]), # Adding str as callsites may still be passing UUID
|
||||||
"to_node": str(edge[1]),
|
"to_node": str(edge[1]), # Adding str as callsites may still be passing UUID
|
||||||
"relationship_name": edge[2],
|
"relationship_name": edge[2],
|
||||||
"properties": self._flatten_edge_properties(
|
"properties": self._flatten_edge_properties(
|
||||||
{
|
{
|
||||||
|
|
@@ -495,13 +497,12 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
results = await self.query(query, dict(edges=edges))
|
await self.query(query, dict(edges=edge_params))
|
||||||
return results
|
|
||||||
except Neo4jError as error:
|
except Neo4jError as error:
|
||||||
logger.error("Neo4j query error: %s", error, exc_info=True)
|
logger.error("Neo4j query error: %s", error, exc_info=True)
|
||||||
raise error
|
raise error
|
||||||
|
|
||||||
async def get_edges(self, node_id: str):
|
async def get_edges(self, node_id: str) -> List[Tuple[str, str, str, Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Retrieve all edges connected to a specified node.
|
Retrieve all edges connected to a specified node.
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue