1373 lines
43 KiB
Python
1373 lines
43 KiB
Python
"""Neo4j Adapter for Graph Database"""
|
||
|
||
import json
|
||
import asyncio
|
||
from uuid import UUID
|
||
from textwrap import dedent
|
||
from neo4j import AsyncSession
|
||
from neo4j import AsyncGraphDatabase
|
||
from neo4j.exceptions import Neo4jError
|
||
from contextlib import asynccontextmanager
|
||
from typing import Optional, Any, List, Dict, Type, Tuple
|
||
|
||
from cognee.infrastructure.engine import DataPoint
|
||
from cognee.shared.logging_utils import get_logger, ERROR
|
||
from cognee.infrastructure.databases.graph.graph_db_interface import (
|
||
GraphDBInterface,
|
||
record_graph_changes,
|
||
)
|
||
from cognee.modules.storage.utils import JSONEncoder
|
||
|
||
from distributed.utils import override_distributed
|
||
from distributed.tasks.queued_add_nodes import queued_add_nodes
|
||
from distributed.tasks.queued_add_edges import queued_add_edges
|
||
|
||
from .neo4j_metrics_utils import (
|
||
get_avg_clustering,
|
||
get_edge_density,
|
||
get_num_connected_components,
|
||
get_shortest_path_lengths,
|
||
get_size_of_connected_components,
|
||
count_self_loops,
|
||
)
|
||
from .deadlock_retry import deadlock_retry
|
||
|
||
|
||
# Module-level logger used by the adapter methods below.
logger = get_logger("Neo4jAdapter")

# Label that add_node/add_nodes MERGE every node under; the uniqueness constraint on
# `id` created by `initialize` is declared against this label.
BASE_LABEL = "__Node__"
|
||
|
||
|
||
class Neo4jAdapter(GraphDBInterface):
|
||
"""
|
||
Adapter for interacting with a Neo4j graph database, implementing the GraphDBInterface.
|
||
This class provides methods for querying, adding, deleting nodes and edges, as well as
|
||
managing sessions and projecting graphs.
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
graph_database_url: str,
|
||
graph_database_username: Optional[str] = None,
|
||
graph_database_password: Optional[str] = None,
|
||
graph_database_name: Optional[str] = None,
|
||
driver: Optional[Any] = None,
|
||
):
|
||
# Only use auth if both username and password are provided
|
||
auth = None
|
||
if graph_database_username and graph_database_password:
|
||
auth = (graph_database_username, graph_database_password)
|
||
elif graph_database_username or graph_database_password:
|
||
logger = get_logger(__name__)
|
||
logger.warning("Neo4j credentials incomplete – falling back to anonymous connection.")
|
||
self.graph_database_name = graph_database_name
|
||
self.driver = driver or AsyncGraphDatabase.driver(
|
||
graph_database_url,
|
||
auth=auth,
|
||
max_connection_lifetime=120,
|
||
notifications_min_severity="OFF",
|
||
)
|
||
|
||
async def initialize(self) -> None:
    """
    Prepare the database by ensuring a uniqueness constraint on `id` exists for nodes
    carrying the base label.
    """
    constraint_statement = (
        f"CREATE CONSTRAINT IF NOT EXISTS FOR (n:`{BASE_LABEL}`) REQUIRE n.id IS UNIQUE;"
    )
    await self.query(constraint_statement)
|
||
|
||
@asynccontextmanager
async def get_session(self) -> AsyncSession:
    """
    Async context manager yielding a driver session bound to the configured database.
    """
    session_context = self.driver.session(database=self.graph_database_name)
    async with session_context as active_session:
        yield active_session
|
||
|
||
@deadlock_retry()
async def query(
    self,
    query: str,
    params: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
    """
    Run a Cypher statement in a fresh session and return all result rows.

    Parameters:
    -----------

        - query (str): The Cypher statement to execute.
        - params (Optional[Dict[str, Any]]): Parameters bound to the statement.
          (default None)

    Returns:
    --------

        - List[Dict[str, Any]]: One dictionary per returned record.

    Raises:
    -------

        - Neo4jError: Logged with traceback and re-raised.
    """
    try:
        async with self.get_session() as session:
            cursor = await session.run(query, parameters=params)
            return await cursor.data()
    except Neo4jError as error:
        logger.error("Neo4j query error: %s", error, exc_info=True)
        raise error
|
||
|
||
async def has_node(self, node_id: str) -> bool:
    """
    Check if a node with the specified ID exists in the database.

    Parameters:
    -----------

        - node_id (str): The ID of the node to check for existence.

    Returns:
    --------

        - bool: True if the node exists, otherwise False.
    """
    query = f"""
        MATCH (n:`{BASE_LABEL}`)
        WHERE n.id = $node_id
        RETURN COUNT(n) > 0 AS node_exists
    """

    # The coroutine must be awaited; otherwise `results` is a coroutine object and
    # the length/index access below raises TypeError.
    results = await self.query(query, {"node_id": node_id})

    return results[0]["node_exists"] if results else False
|
||
|
||
async def add_node(self, node: DataPoint):
    """
    Upsert a single node built from a DataPoint.

    The node is merged on its `id` under the base label, its properties are
    overwritten with the serialized model dump, and the DataPoint's class name is
    attached as an extra label.

    Parameters:
    -----------

        - node (DataPoint): The data point to store as a node.

    Returns:
    --------

        The rows returned by the query (internal id and node id of the stored node).
    """
    params = {
        "node_id": str(node.id),
        "node_label": type(node).__name__,
        "properties": self.serialize_properties(node.model_dump()),
    }

    statement = dedent(
        f"""MERGE (node: `{BASE_LABEL}`{{id: $node_id}})
        ON CREATE SET node += $properties, node.updated_at = timestamp()
        ON MATCH SET node += $properties, node.updated_at = timestamp()
        WITH node, $node_label AS label
        CALL apoc.create.addLabels(node, [label]) YIELD node AS labeledNode
        RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId"""
    )

    return await self.query(statement, params)
|
||
|
||
@record_graph_changes
@override_distributed(queued_add_nodes)
async def add_nodes(self, nodes: list[DataPoint]) -> None:
    """
    Upsert many nodes with one UNWIND query.

    Parameters:
    -----------

        - nodes (list[DataPoint]): The data points to store as nodes.

    Returns:
    --------

        The rows returned by the query. NOTE(review): the signature is annotated
        `-> None`, but the query result is returned.
    """
    node_payloads = []
    for data_point in nodes:
        node_payloads.append(
            {
                "node_id": str(data_point.id),
                "label": type(data_point).__name__,
                "properties": self.serialize_properties(data_point.model_dump()),
            }
        )

    statement = f"""
    UNWIND $nodes AS node
    MERGE (n: `{BASE_LABEL}`{{id: node.node_id}})
    ON CREATE SET n += node.properties, n.updated_at = timestamp()
    ON MATCH SET n += node.properties, n.updated_at = timestamp()
    WITH n, node.label AS label
    CALL apoc.create.addLabels(n, [label]) YIELD node AS labeledNode
    RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
    """

    return await self.query(statement, {"nodes": node_payloads})
|
||
|
||
async def extract_node(self, node_id: str):
    """
    Fetch one node by ID via `extract_nodes`.

    Parameters:
    -----------

        - node_id (str): The ID of the node to fetch.

    Returns:
    --------

        The first matching node, or None when nothing matched.
    """
    matches = await self.extract_nodes([node_id])

    if not matches:
        return None

    return matches[0]
|
||
|
||
async def extract_nodes(self, node_ids: List[str]):
    """
    Fetch every node whose ID appears in the given list.

    Parameters:
    -----------

        - node_ids (List[str]): IDs of the nodes to fetch.

    Returns:
    --------

        A list with the `node` value of each returned row.
    """
    statement = f"""
    UNWIND $node_ids AS id
    MATCH (node: `{BASE_LABEL}`{{id: id}})
    RETURN node"""

    rows = await self.query(statement, {"node_ids": node_ids})

    return [row["node"] for row in rows]
|
||
|
||
async def delete_node(self, node_id: str):
    """
    Delete the node with the given ID together with all of its relationships.

    Parameters:
    -----------

        - node_id (str): The ID of the node to delete.

    Returns:
    --------

        The rows returned by the delete query.
    """
    statement = f"MATCH (node: `{BASE_LABEL}`{{id: $node_id}}) DETACH DELETE node"

    return await self.query(statement, {"node_id": node_id})
|
||
|
||
async def delete_nodes(self, node_ids: list[str]) -> None:
    """
    Delete every node whose ID is listed, together with their relationships.

    Parameters:
    -----------

        - node_ids (list[str]): IDs of the nodes to delete.

    Returns:
    --------

        The rows returned by the delete query.
    """
    statement = f"""
    UNWIND $node_ids AS id
    MATCH (node: `{BASE_LABEL}`{{id: id}})
    DETACH DELETE node"""

    return await self.query(statement, {"node_ids": node_ids})
|
||
|
||
async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool:
    """
    Check if an edge exists between two nodes with the specified IDs and edge label.

    Parameters:
    -----------

        - from_node (UUID): The ID of the node from which the edge originates.
        - to_node (UUID): The ID of the node to which the edge points.
        - edge_label (str): The label of the edge to check for existence.

    Returns:
    --------

        - bool: True if the edge exists, otherwise False.
    """
    # The relationship must be bound to a variable so COUNT(relationship) refers to
    # something; with an anonymous pattern the query fails with an undefined variable.
    query = f"""
        MATCH (from_node: `{BASE_LABEL}`)-[relationship:`{edge_label}`]->(to_node: `{BASE_LABEL}`)
        WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id
        RETURN COUNT(relationship) > 0 AS edge_exists
    """

    params = {
        "from_node_id": str(from_node),
        "to_node_id": str(to_node),
    }

    results = await self.query(query, params)

    # Unwrap the single aggregate row; returning the row list itself would always be
    # truthy, even when the edge is absent.
    return bool(results) and bool(results[0]["edge_exists"])
|
||
|
||
async def has_edges(self, edges):
    """
    Check if multiple edges exist based on provided edge criteria.

    Parameters:
    -----------

        - edges: An iterable of edge specifications; positions 0, 1 and 2 of each item
          are used as source node ID, target node ID and relationship name.

    Returns:
    --------

        The `edge_exists` value of each returned row. NOTE(review): the query uses
        MATCH, so edges that do not exist produce no row and the result can be
        shorter than the input.

    Raises:
    -------

        - Neo4jError: Logged and re-raised.
    """
    # Endpoints are matched on the `id` property. The Cypher id() function returns
    # Neo4j's internal integer id, which never equals the string IDs sent below.
    query = f"""
        UNWIND $edges AS edge
        MATCH (a:`{BASE_LABEL}`)-[r]->(b)
        WHERE a.id = edge.from_node AND b.id = edge.to_node AND type(r) = edge.relationship_name
        RETURN edge.from_node AS from_node, edge.to_node AS to_node, edge.relationship_name AS relationship_name, count(r) > 0 AS edge_exists
    """

    params = {
        "edges": [
            {
                "from_node": str(edge[0]),
                "to_node": str(edge[1]),
                "relationship_name": edge[2],
            }
            for edge in edges
        ],
    }

    try:
        results = await self.query(query, params)
        return [result["edge_exists"] for result in results]
    except Neo4jError as error:
        logger.error("Neo4j query error: %s", error, exc_info=True)
        raise error
|
||
|
||
async def add_edge(
    self,
    from_node: UUID,
    to_node: UUID,
    relationship_name: str,
    edge_properties: Optional[Dict[str, Any]] = None,
):
    """
    Create a new edge between two nodes with specified properties.

    The edge is merged, so calling this again for the same endpoints and relationship
    type updates the existing edge instead of creating a duplicate.

    Parameters:
    -----------

        - from_node (UUID): The ID of the source node of the edge.
        - to_node (UUID): The ID of the target node of the edge.
        - relationship_name (str): The type/label of the edge to create.
        - edge_properties (Optional[Dict[str, Any]]): Properties to assign to the edge;
          None is treated as no properties. (default None)

    Returns:
    --------

        The rows returned by the query (the merged relationship).
    """
    # Accept None explicitly: the annotation allows it, and a None default avoids a
    # shared mutable default dictionary.
    serialized_properties = self.serialize_properties(edge_properties or {})

    query = dedent(
        f"""\
        MATCH (from_node :`{BASE_LABEL}`{{id: $from_node}}),
              (to_node :`{BASE_LABEL}`{{id: $to_node}})
        MERGE (from_node)-[r:`{relationship_name}`]->(to_node)
        ON CREATE SET r += $properties, r.updated_at = timestamp()
        ON MATCH SET r += $properties, r.updated_at = timestamp()
        RETURN r
        """
    )

    params = {
        "from_node": str(from_node),
        "to_node": str(to_node),
        "relationship_name": relationship_name,
        "properties": serialized_properties,
    }

    return await self.query(query, params)
|
||
|
||
def _flatten_edge_properties(self, properties: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""
|
||
Flatten edge properties to handle nested dictionaries like weights.
|
||
|
||
Neo4j doesn't support nested dictionaries as property values, so we need to
|
||
flatten the 'weights' dictionary into individual properties with prefixes.
|
||
|
||
Args:
|
||
properties: Dictionary of edge properties that may contain nested dicts
|
||
|
||
Returns:
|
||
Flattened properties dictionary suitable for Neo4j storage
|
||
"""
|
||
flattened = {}
|
||
|
||
for key, value in properties.items():
|
||
if key == "weights" and isinstance(value, dict):
|
||
# Flatten weights dictionary into individual properties
|
||
for weight_name, weight_value in value.items():
|
||
flattened[f"weight_{weight_name}"] = weight_value
|
||
elif isinstance(value, dict):
|
||
# For other nested dictionaries, serialize as JSON string
|
||
flattened[f"{key}_json"] = json.dumps(value, cls=JSONEncoder)
|
||
elif isinstance(value, list):
|
||
# For lists, serialize as JSON string
|
||
flattened[f"{key}_json"] = json.dumps(value, cls=JSONEncoder)
|
||
else:
|
||
# Keep primitive types as-is
|
||
flattened[key] = value
|
||
|
||
return flattened
|
||
|
||
@record_graph_changes
@override_distributed(queued_add_edges)
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
    """
    Merge many edges with one UNWIND query.

    Each edge is merged on its endpoints' IDs; its flattened properties (plus
    `source_node_id` / `target_node_id`) are passed to `apoc.merge.relationship`.

    Parameters:
    -----------

        - edges (list[tuple[str, str, str, dict[str, Any]]]): Tuples of source ID,
          target ID, relationship name and properties.

    Returns:
    --------

        The rows returned by the query. NOTE(review): the signature is annotated
        `-> None`, but the query result is returned.

    Raises:
    -------

        - Neo4jError: Logged and re-raised.
    """
    statement = f"""
    UNWIND $edges AS edge
    MATCH (from_node: `{BASE_LABEL}`{{id: edge.from_node}})
    MATCH (to_node: `{BASE_LABEL}`{{id: edge.to_node}})
    CALL apoc.merge.relationship(
        from_node,
        edge.relationship_name,
        {{
            source_node_id: edge.from_node,
            target_node_id: edge.to_node
        }},
        edge.properties,
        to_node
    ) YIELD rel
    RETURN rel"""

    edge_payloads = []
    for edge in edges:
        source_id = str(edge[0])
        target_id = str(edge[1])
        raw_properties = edge[3] if edge[3] else {}
        edge_payloads.append(
            {
                "from_node": source_id,
                "to_node": target_id,
                "relationship_name": edge[2],
                "properties": self._flatten_edge_properties(
                    {
                        **raw_properties,
                        "source_node_id": source_id,
                        "target_node_id": target_id,
                    }
                ),
            }
        )

    try:
        return await self.query(statement, {"edges": edge_payloads})
    except Neo4jError as error:
        logger.error("Neo4j query error: %s", error, exc_info=True)
        raise error
|
||
|
||
async def get_edges(self, node_id: str):
    """
    List the edges touching a node, in either direction.

    Parameters:
    -----------

        - node_id (str): The ID of the node whose edges are wanted.

    Returns:
    --------

        A list of `(node id, other node id, {"relationship_name": ...})` tuples, where
        the relationship name is element 1 of the returned relationship value.
    """
    statement = f"""
    MATCH (n: `{BASE_LABEL}`{{id: $node_id}})-[r]-(m)
    RETURN n, r, m
    """

    rows = await self.query(statement, {"node_id": node_id})

    edges = []
    for row in rows:
        relationship_details = {"relationship_name": row["r"][1]}
        edges.append((row["n"]["id"], row["m"]["id"], relationship_details))

    return edges
|
||
|
||
async def get_disconnected_nodes(self) -> list[str]:
    """
    Return the nodes that lie outside the largest connected group of nodes.

    The query collects, for every node, the distinct nodes reachable from it, keeps
    the largest such collection, and returns every node not contained in it.

    Returns:
    --------

        - list[str]: The collected `ID(n)` values of the nodes outside the largest
          group, or an empty list when the query returns no row.
    """
    statement = """
    // Step 1: Collect all nodes
    MATCH (n)
    WITH COLLECT(n) AS nodes

    // Step 2: Find all connected components
    WITH nodes
    CALL {
      WITH nodes
      UNWIND nodes AS startNode
      MATCH path = (startNode)-[*]-(connectedNode)
      WITH startNode, COLLECT(DISTINCT connectedNode) AS component
      RETURN component
    }

    // Step 3: Aggregate components
    WITH COLLECT(component) AS components

    // Step 4: Identify the largest connected component
    UNWIND components AS component
    WITH component
    ORDER BY SIZE(component) DESC
    LIMIT 1
    WITH component AS largestComponent

    // Step 5: Find nodes not in the largest connected component
    MATCH (n)
    WHERE NOT n IN largestComponent
    RETURN COLLECT(ID(n)) AS ids
    """

    rows = await self.query(statement)

    if len(rows) > 0:
        return rows[0]["ids"]

    return []
|
||
|
||
async def get_predecessors(self, node_id: str, edge_label: str = None) -> list[str]:
    """
    Fetch the nodes that have an edge pointing at the given node.

    Parameters:
    -----------

        - node_id (str): The ID of the node whose predecessors are wanted.
        - edge_label (str): When given, only edges of this type are followed.
          (default None)

    Returns:
    --------

        - list[str]: The `predecessor` value of each returned row.
    """
    # Only the relationship pattern differs between the filtered and unfiltered form.
    relationship_pattern = "[r]" if edge_label is None else f"[r:`{edge_label}`]"

    statement = f"""
    MATCH (node: `{BASE_LABEL}`)<-{relationship_pattern}-(predecessor)
    WHERE node.id = $node_id
    RETURN predecessor
    """

    rows = await self.query(statement, {"node_id": node_id})

    return [row["predecessor"] for row in rows]
|
||
|
||
async def get_successors(self, node_id: str, edge_label: str = None) -> list[str]:
    """
    Fetch the nodes that the given node has an edge pointing to.

    Parameters:
    -----------

        - node_id (str): The ID of the node whose successors are wanted.
        - edge_label (str): When given, only edges of this type are followed.
          (default None)

    Returns:
    --------

        - list[str]: The `successor` value of each returned row.
    """
    params = {"node_id": node_id}

    if edge_label is None:
        relationship_pattern = "[r]"
    else:
        relationship_pattern = f"[r:`{edge_label}`]"
        # The label is also sent as a parameter, although the statement interpolates it.
        params["edge_label"] = edge_label

    statement = f"""
    MATCH (node: `{BASE_LABEL}`)-{relationship_pattern}->(successor)
    WHERE node.id = $node_id
    RETURN successor
    """

    rows = await self.query(statement, params)

    return [row["successor"] for row in rows]
|
||
|
||
async def get_neighbors(self, node_id: str) -> List[Dict[str, Any]]:
    """
    Get all neighbors of a specified node by delegating to `get_neighbours`.

    Parameters:
    -----------

        - node_id (str): The ID of the node for which neighbors are retrieved.

    Returns:
    --------

        - List[Dict[str, Any]]: Whatever `get_neighbours` returns for the node.
    """
    # NOTE(review): `get_neighbours` (British spelling) is not defined in the part of
    # the class reviewed here — confirm it exists, otherwise this raises AttributeError.
    return await self.get_neighbours(node_id)
|
||
|
||
async def get_node(self, node_id: str) -> Optional[Dict[str, Any]]:
    """
    Fetch one node by its ID.

    Parameters:
    -----------

        - node_id (str): The ID of the node to fetch.

    Returns:
    --------

        - Optional[Dict[str, Any]]: The `node` value of the first returned row, or None
          when the query returns nothing.
    """
    statement = f"""
    MATCH (node: `{BASE_LABEL}`{{id: $node_id}})
    RETURN node
    """

    rows = await self.query(statement, {"node_id": node_id})

    if not rows:
        return None

    return rows[0]["node"]
|
||
|
||
async def get_nodes(self, node_ids: List[str]) -> List[Dict[str, Any]]:
    """
    Fetch every node whose ID appears in the given list.

    Parameters:
    -----------

        - node_ids (List[str]): IDs of the nodes to fetch.

    Returns:
    --------

        - List[Dict[str, Any]]: The `node` value of each returned row.
    """
    statement = f"""
    UNWIND $node_ids AS id
    MATCH (node:`{BASE_LABEL}` {{id: id}})
    RETURN node
    """

    rows = await self.query(statement, {"node_ids": node_ids})

    nodes = []
    for row in rows:
        nodes.append(row["node"])

    return nodes
|
||
|
||
async def get_connections(self, node_id: UUID) -> list:
    """
    Collect the incoming and outgoing connections of a node.

    Both directions are queried concurrently; incoming connections are listed before
    outgoing ones.

    Parameters:
    -----------

        - node_id (UUID): The ID of the node whose connections are wanted.

    Returns:
    --------

        - list: Tuples of `(relation[0], {"relationship_name": relation[1]},
          relation[2])`, taken from each returned relationship value.
    """
    incoming_query = f"""
    MATCH (node:`{BASE_LABEL}`)<-[relation]-(neighbour)
    WHERE node.id = $node_id
    RETURN neighbour, relation, node
    """
    outgoing_query = f"""
    MATCH (node:`{BASE_LABEL}`)-[relation]->(neighbour)
    WHERE node.id = $node_id
    RETURN node, relation, neighbour
    """

    node_key = str(node_id)
    incoming_rows, outgoing_rows = await asyncio.gather(
        self.query(incoming_query, {"node_id": node_key}),
        self.query(outgoing_query, {"node_id": node_key}),
    )

    relations = (row["relation"] for row in [*incoming_rows, *outgoing_rows])

    return [
        (relation[0], {"relationship_name": relation[1]}, relation[2])
        for relation in relations
    ]
|
||
|
||
async def remove_connection_to_predecessors_of(
    self, node_ids: list[str], edge_label: str
) -> None:
    """
    Remove connections (edges) to all predecessors of specified nodes based on edge label.

    Parameters:
    -----------

        - node_ids (list[str]): A list of IDs of nodes from which connections are to be
          removed.
        - edge_label (str): The label of the edges to remove.

    Returns:
    --------

        The rows returned by the delete query.
    """
    # Nodes are matched on their `id` property under the base label. Interpolating
    # `{id}` into the f-string would insert the repr of Python's builtin `id`
    # function as a node label, so the statement would match nothing.
    # NOTE(review): the edge direction follows get_predecessors (edges pointing INTO
    # the listed nodes) — confirm this is the intended direction.
    query = f"""
        UNWIND $node_ids AS id
        MATCH (node:`{BASE_LABEL}` {{id: id}})<-[r:`{edge_label}`]-(predecessor)
        DELETE r;
    """

    params = {"node_ids": node_ids}

    return await self.query(query, params)
|
||
|
||
async def remove_connection_to_successors_of(
    self, node_ids: list[str], edge_label: str
) -> None:
    """
    Remove connections (edges) to all successors of specified nodes based on edge label.

    Parameters:
    -----------

        - node_ids (list[str]): A list of IDs of nodes from which connections are to be
          removed.
        - edge_label (str): The label of the edges to remove.

    Returns:
    --------

        The rows returned by the delete query.
    """
    # Nodes are matched on their `id` property under the base label. Interpolating
    # `{id}` into the f-string would insert the repr of Python's builtin `id`
    # function as a node label, so the statement would match nothing.
    # NOTE(review): the edge direction follows get_successors (edges pointing OUT of
    # the listed nodes) — confirm this is the intended direction.
    query = f"""
        UNWIND $node_ids AS id
        MATCH (node:`{BASE_LABEL}` {{id: id}})-[r:`{edge_label}`]->(successor)
        DELETE r;
    """

    params = {"node_ids": node_ids}

    return await self.query(query, params)
|
||
|
||
async def delete_graph(self):
    """
    Delete the labelled nodes of the graph, one label at a time.

    For every label reported by `get_node_labels`, all nodes carrying that label are
    deleted together with their relationships.
    """
    for label in await self.get_node_labels():
        delete_statement = f"""
        MATCH (node:`{label}`)
        DETACH DELETE node;
        """

        await self.query(delete_statement)
|
||
|
||
def serialize_properties(self, properties=dict()):
    """
    Prepare node or edge properties for storage.

    UUID values become strings, dictionary values become JSON strings, and every
    other value is passed through unchanged.

    Parameters:
    -----------

        - properties: The properties to prepare. (default dict())

    Returns:
    --------

        A new dictionary holding the prepared values.
    """
    prepared = {}

    for name, raw_value in properties.items():
        if isinstance(raw_value, UUID):
            prepared[name] = str(raw_value)
        elif isinstance(raw_value, dict):
            prepared[name] = json.dumps(raw_value, cls=JSONEncoder)
        else:
            prepared[name] = raw_value

    return prepared
|
||
|
||
async def get_model_independent_graph_data(self):
    """
    Fetch every node and every directed edge without interpreting them.

    Returns:
    --------

        A `(nodes, edges)` tuple holding the raw rows of the two collecting queries.
    """
    nodes = await self.query("MATCH (n) RETURN collect(n) AS nodes")
    edges = await self.query("MATCH (n)-[r]->(m) RETURN collect([n, r, m]) AS elements")

    return (nodes, edges)
|
||
|
||
async def get_graph_data(self):
    """
    Fetch all nodes and directed edges together with their properties.

    Returns:
    --------

        A `(nodes, edges)` tuple. Each node is `(properties["id"], properties)`; each
        edge is `(source_node_id, target_node_id, type, properties)` with the two IDs
        read from the edge's own properties.

    Raises:
    -------

        Any exception raised while querying is logged and re-raised.
    """
    import time

    started_at = time.time()

    try:
        node_rows = await self.query(
            "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"
        )
        nodes = [(row["properties"]["id"], row["properties"]) for row in node_rows]

        edge_statement = """
        MATCH (n)-[r]->(m)
        RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
        """
        edge_rows = await self.query(edge_statement)
        edges = [
            (
                row["properties"]["source_node_id"],
                row["properties"]["target_node_id"],
                row["type"],
                row["properties"],
            )
            for row in edge_rows
        ]

        elapsed = time.time() - started_at
        logger.info(
            f"Retrieved {len(nodes)} nodes and {len(edges)} edges in {elapsed:.2f} seconds"
        )

        return (nodes, edges)

    except Exception as e:
        logger.error(f"Error during graph data retrieval: {str(e)}")
        raise
|
||
|
||
async def get_nodeset_subgraph(
    self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
    """
    Fetch the subgraph around the named nodes of one type.

    Nodes labelled with `node_type.__name__` whose `name` is in `node_name` are taken
    together with their direct neighbours; the relationships among that node set are
    returned as edges.

    Parameters:
    -----------

        - node_type (Type[Any]): Class whose name is used as the node label.
        - node_name (List[str]): Names of the nodes to start from.

    Returns:
    --------

        - Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]: Nodes as
          `(properties["id"], properties)` and edges as `(source_node_id,
          target_node_id, type, properties)`; two empty lists when the query returns
          no row.

    Raises:
    -------

        Any exception raised while querying is logged and re-raised.
    """
    import time

    started_at = time.time()

    try:
        label = node_type.__name__

        statement = f"""
        UNWIND $names AS wantedName
        MATCH (n:`{label}`)
        WHERE n.name = wantedName
        WITH collect(DISTINCT n) AS primary
        UNWIND primary AS p
        OPTIONAL MATCH (p)--(nbr)
        WITH primary, collect(DISTINCT nbr) AS nbrs
        WITH primary + nbrs AS nodelist
        UNWIND nodelist AS node
        WITH collect(DISTINCT node) AS nodes
        MATCH (a)-[r]-(b)
        WHERE a IN nodes AND b IN nodes
        WITH nodes, collect(DISTINCT r) AS rels
        RETURN
          [n IN nodes |
             {{ id: n.id,
                properties: properties(n) }}] AS rawNodes,
          [r IN rels |
             {{ type: type(r),
                properties: properties(r) }}] AS rawRels
        """

        rows = await self.query(statement, {"names": node_name})

        if not rows:
            return [], []

        first_row = rows[0]

        nodes = [
            (raw_node["properties"]["id"], raw_node["properties"])
            for raw_node in first_row["rawNodes"]
        ]
        edges = [
            (
                raw_rel["properties"]["source_node_id"],
                raw_rel["properties"]["target_node_id"],
                raw_rel["type"],
                raw_rel["properties"],
            )
            for raw_rel in first_row["rawRels"]
        ]

        elapsed = time.time() - started_at
        logger.info(
            f"Retrieved {len(nodes)} nodes and {len(edges)} edges for {node_type.__name__} in {elapsed:.2f} seconds"
        )

        return nodes, edges

    except Exception as e:
        logger.error(f"Error during nodeset subgraph retrieval: {str(e)}")
        raise
|
||
|
||
async def get_filtered_graph_data(self, attribute_filters):
    """
    Fetch nodes and edges filtered by specific attribute criteria.

    A node qualifies when, for every attribute in the filter, its value is one of the
    allowed values. An edge qualifies when both of its endpoints qualify.

    Parameters:
    -----------

        - attribute_filters: A list whose first element maps attribute names to the
          collection of values allowed for that attribute.

    Returns:
    --------

        A `(nodes, edges)` tuple. Each node is `(ID(n), properties)`; each edge is
        `(source_node_id, target_node_id, type, properties)` with the two IDs read
        from the edge's own properties.
    """
    filters = attribute_filters[0]

    # Allowed values travel as query parameters instead of being spliced into the
    # Cypher text, so quotes or other special characters in a value cannot break or
    # alter the statement.
    params = {}
    filter_terms = []
    for index, (attribute, values) in enumerate(filters.items()):
        param_name = f"filter_values_{index}"
        params[param_name] = list(values)
        # Property names cannot be parameterised; quote them and escape backticks.
        quoted_attribute = "`" + str(attribute).replace("`", "``") + "`"
        filter_terms.append((quoted_attribute, param_name))

    def build_condition(alias):
        # Built per alias rather than via text replacement on a finished clause,
        # which would also rewrite matching text inside the values.
        return " AND ".join(
            f"{alias}.{quoted_attribute} IN ${param_name}"
            for quoted_attribute, param_name in filter_terms
        )

    source_condition = build_condition("n")
    target_condition = build_condition("m")

    # With no filters there is nothing to restrict; emitting a bare WHERE would be a
    # syntax error.
    node_filter = f"WHERE {source_condition}" if filter_terms else ""
    edge_filter = f"WHERE {source_condition} AND {target_condition}" if filter_terms else ""

    query_nodes = f"""
    MATCH (n)
    {node_filter}
    RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties
    """
    result_nodes = await self.query(query_nodes, params)

    nodes = [(record["id"], record["properties"]) for record in result_nodes]

    query_edges = f"""
    MATCH (n)-[r]->(m)
    {edge_filter}
    RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
    """
    result_edges = await self.query(query_edges, params)

    edges = [
        (
            record["properties"]["source_node_id"],
            record["properties"]["target_node_id"],
            record["type"],
            record["properties"],
        )
        for record in result_edges
    ]

    return (nodes, edges)
|
||
|
||
async def graph_exists(self, graph_name="myGraph"):
    """
    Tell whether a projected graph with the given name is listed by GDS.

    Parameters:
    -----------

        - graph_name: Name of the projected graph to look for. (default 'myGraph')

    Returns:
    --------

        True if the name is among the listed graph names, otherwise False.
    """
    rows = await self.query(
        "CALL gds.graph.list() YIELD graphName RETURN collect(graphName) AS graphNames;"
    )

    if rows:
        known_names = rows[0]["graphNames"]
    else:
        known_names = []

    return graph_name in known_names
|
||
|
||
async def get_node_labels(self):
    """
    List the node labels known to the database.

    Returns:
    --------

        A list with the `label` value of each row returned by `db.labels()`.
    """
    label_rows = await self.query("CALL db.labels()")

    labels = []
    for row in label_rows:
        labels.append(row["label"])

    return labels
|
||
|
||
async def get_relationship_labels_string(self):
    """
    Build a Cypher map literal that marks every relationship type as UNDIRECTED.

    The relationship types are read with `CALL db.relationshipTypes()` and rendered as
    `{TYPE_A: {orientation: 'UNDIRECTED'}, TYPE_B: {orientation: 'UNDIRECTED'}}`.

    Returns:
    --------

        The formatted map literal as a string.

    Raises:
    -------

        ValueError: When the database reports no relationship types.
    """
    rows = await self.query(
        "CALL db.relationshipTypes() YIELD relationshipType RETURN collect(relationshipType) AS relationships;"
    )
    rel_types = rows[0]["relationships"] if rows else []

    if not rel_types:
        raise ValueError("No relationship types found in the database.")

    # One `name: {orientation: 'UNDIRECTED'}` entry per relationship type.
    entries = [f"{rel_type}: {{orientation: 'UNDIRECTED'}}" for rel_type in rel_types]
    return "{" + ", ".join(entries) + "}"
|
||
|
||
async def project_entire_graph(self, graph_name="myGraph"):
    """
    Project every node label and relationship type into an in-memory GDS graph.

    Nothing is done when a projection called `graph_name` already exists.

    Parameters:
    -----------

        - graph_name: Name to give the projected graph. (default 'myGraph')
    """
    already_projected = await self.graph_exists(graph_name)
    if already_projected:
        return

    labels = await self.get_node_labels()
    relationship_projection = await self.get_relationship_labels_string()

    # Rendered between "['" and "']" below, giving a Cypher list of quoted labels.
    quoted_labels = "', '".join(labels)

    projection_query = f"""
    CALL gds.graph.project(
        '{graph_name}',
        ['{quoted_labels}'],
        {relationship_projection}
    ) YIELD graphName;
    """

    await self.query(projection_query)
|
||
|
||
async def drop_graph(self, graph_name="myGraph"):
    """
    Remove the projected GDS graph called `graph_name`, if one exists.

    Parameters:
    -----------

        - graph_name: Name of the projected graph to drop. (default 'myGraph')
    """
    # Only issue the drop when the projection is actually listed.
    if not await self.graph_exists(graph_name):
        return

    await self.query(f"CALL gds.graph.drop('{graph_name}');")
|
||
|
||
async def get_graph_metrics(self, include_optional=False):
    """
    Compute summary metrics for the graph: node and edge counts, mean degree, edge
    density and connected-component statistics, plus optional path and clustering metrics.

    Parameters:
    -----------

        - include_optional: When True, also compute self-loop count, diameter, average
          shortest path length and average clustering; otherwise those four entries are
          reported as -1. (default False)

    Returns:
    --------

        A dictionary with the mandatory metrics followed by the optional ones.
    """
    graph_name = "myGraph"

    nodes, edges = await self.get_model_independent_graph_data()

    # Drop any existing projection first so the metrics run on a fresh one.
    await self.drop_graph(graph_name)
    await self.project_entire_graph(graph_name)

    num_nodes = len(nodes[0]["nodes"])
    num_edges = len(edges[0]["elements"])

    # Mean degree is undefined for an empty graph.
    mean_degree = None if num_nodes == 0 else (2 * num_edges) / num_nodes

    metrics = {
        "num_nodes": num_nodes,
        "num_edges": num_edges,
        "mean_degree": mean_degree,
        "edge_density": await get_edge_density(self),
        "num_connected_components": await get_num_connected_components(self, graph_name),
        "sizes_of_connected_components": await get_size_of_connected_components(
            self, graph_name
        ),
    }

    if not include_optional:
        # -1 marks an optional metric that was not computed.
        metrics["num_selfloops"] = -1
        metrics["diameter"] = -1
        metrics["avg_shortest_path_length"] = -1
        metrics["avg_clustering"] = -1
        return metrics

    path_lengths = await get_shortest_path_lengths(self, graph_name)
    metrics["num_selfloops"] = await count_self_loops(self)

    if path_lengths:
        metrics["diameter"] = max(path_lengths)
        metrics["avg_shortest_path_length"] = sum(path_lengths) / len(path_lengths)
    else:
        # No path lengths available: report the same -1 marker.
        metrics["diameter"] = -1
        metrics["avg_shortest_path_length"] = -1

    metrics["avg_clustering"] = await get_avg_clustering(self, graph_name)
    return metrics
|
||
|
||
async def get_document_subgraph(self, data_id: str):
    """
    Fetch the subgraph hanging off a single document node: its chunks, the summaries
    made from those chunks, and the entities / entity types that no other document uses.

    Parameters:
    -----------

        - data_id (str): Value matched against the `id` property of the document node
          (TextDocument, PdfDocument, UnstructuredDocument, AudioDocument or
          ImageDocument).

    Returns:
    --------

        The first result row, with keys `document`, `chunks`, `orphan_entities`,
        `made_from_nodes` and `orphan_types`, or None when the query returns no rows.
    """
    subgraph_query = """
    MATCH (doc)
    WHERE (doc:TextDocument OR doc:PdfDocument OR doc:UnstructuredDocument OR doc:AudioDocument or doc:ImageDocument)
    AND doc.id = $data_id

    OPTIONAL MATCH (doc)<-[:is_part_of]-(chunk:DocumentChunk)
    OPTIONAL MATCH (chunk)-[:contains]->(entity:Entity)
    WHERE NOT EXISTS {
        MATCH (entity)<-[:contains]-(otherChunk:DocumentChunk)-[:is_part_of]->(otherDoc)
        WHERE (otherDoc:TextDocument OR otherDoc:PdfDocument OR otherDoc:UnstructuredDocument OR otherDoc:AudioDocument or otherDoc:ImageDocument)
        AND otherDoc.id <> doc.id
    }
    OPTIONAL MATCH (chunk)<-[:made_from]-(made_node:TextSummary)
    OPTIONAL MATCH (entity)-[:is_a]->(type:EntityType)
    WHERE NOT EXISTS {
        MATCH (type)<-[:is_a]-(otherEntity:Entity)<-[:contains]-(otherChunk:DocumentChunk)-[:is_part_of]->(otherDoc)
        WHERE (otherDoc:TextDocument OR otherDoc:PdfDocument OR otherDoc:UnstructuredDocument OR otherDoc:AudioDocument or otherDoc:ImageDocument)
        AND otherDoc.id <> doc.id
    }

    RETURN
        collect(DISTINCT doc) as document,
        collect(DISTINCT chunk) as chunks,
        collect(DISTINCT entity) as orphan_entities,
        collect(DISTINCT made_node) as made_from_nodes,
        collect(DISTINCT type) as orphan_types
    """
    rows = await self.query(subgraph_query, {"data_id": data_id})

    if not rows:
        return None
    return rows[0]
|
||
|
||
async def get_degree_one_nodes(self, node_type: str):
    """
    Return the nodes labelled `node_type` that take part in exactly one relationship.

    Parameters:
    -----------

        - node_type (str): Label to search; only 'Entity' and 'EntityType' are accepted.

    Returns:
    --------

        A list with the `n` field of every matching record, or an empty list when
        nothing matches.

    Raises:
    -------

        ValueError: When `node_type` is empty or not one of the accepted labels.
    """
    accepted_labels = ("Entity", "EntityType")
    is_accepted = bool(node_type) and node_type in accepted_labels
    if not is_accepted:
        raise ValueError("node_type must be either 'Entity' or 'EntityType'")

    # The label is interpolated directly, which is safe only because of the check above.
    degree_query = f"""
    MATCH (n:{node_type})
    WHERE COUNT {{ MATCH (n)--() }} = 1
    RETURN n
    """
    rows = await self.query(degree_query)

    if not rows:
        return []
    return [row["n"] for row in rows]
|
||
|
||
async def get_last_user_interaction_ids(self, limit: int) -> List[str]:
    """
    Retrieve the IDs of the most recent CogneeUserInteraction nodes.

    Parameters:
    -----------

        - limit (int): The maximum number of rows requested from the database.

    Returns:
    --------

        - List[str]: Interaction IDs ordered by `created_at` descending. Rows whose
          `id` is missing or null are left out, so the list may be shorter than `limit`.
    """
    query = """
    MATCH (n)
    WHERE n.type = 'CogneeUserInteraction'
    RETURN n.id as id
    ORDER BY n.created_at DESC
    LIMIT $limit
    """
    rows = await self.query(query, {"limit": limit})

    # `RETURN n.id as id` always yields an "id" key, holding null when the node has no
    # such property, so a key-presence check filters nothing. Test the value instead.
    return [row["id"] for row in rows if row.get("id") is not None]
|
||
|
||
async def apply_feedback_weight(
    self,
    node_ids: List[str],
    weight: float,
) -> None:
    """
    Add `weight` to the `feedback_weight` of relationships leaving the given nodes.

    Only outgoing relationships whose `relationship_name` property equals
    'used_graph_element_to_answer' and whose start node `id` is in `node_ids`
    are updated.

    Args:
        node_ids: Node IDs to match; converted to a list before being sent.
        weight: Amount to add to `r.feedback_weight` (converted to float; can be negative).

    Side effects:
        Sets `feedback_weight` on each matched relationship, treating a missing value as 0.
    """
    feedback_query = """
    MATCH (n)-[r]->()
    WHERE n.id IN $node_ids AND r.relationship_name = 'used_graph_element_to_answer'
    SET r.feedback_weight = coalesce(r.feedback_weight, 0) + $weight
    """
    query_params = {"node_ids": list(node_ids), "weight": float(weight)}

    await self.query(feedback_query, params=query_params)
|