cognee/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
lxobr a65ded6283 Merge branch 'dev' into sf_demo
# Conflicts:
#	cognee/api/v1/add/add.py
#	cognee/api/v1/search/search.py
#	cognee/infrastructure/databases/graph/graph_db_interface.py
#	cognee/infrastructure/engine/models/DataPoint.py
#	cognee/modules/retrieval/graph_completion_retriever.py
#	cognee/modules/search/methods/search.py
#	cognee/modules/visualization/cognee_network_visualization.py
#	cognee/tasks/documents/classify_documents.py
#	cognee/tasks/ingestion/ingest_data.py
#	examples/python/simple_node_set_example.py
2025-04-23 00:05:21 +02:00

769 lines
25 KiB
Python

#
"""Neo4j Adapter for Graph Database"""
import json
from cognee.shared.logging_utils import get_logger, ERROR
import asyncio
from textwrap import dedent
from typing import Optional, Any, List, Dict, Type, Tuple
from contextlib import asynccontextmanager
from uuid import UUID
from neo4j import AsyncSession
from neo4j import AsyncGraphDatabase
from neo4j.exceptions import Neo4jError
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.graph.graph_db_interface import (
GraphDBInterface,
record_graph_changes,
)
from cognee.modules.storage.utils import JSONEncoder
from .neo4j_metrics_utils import (
get_avg_clustering,
get_edge_density,
get_num_connected_components,
get_shortest_path_lengths,
get_size_of_connected_components,
count_self_loops,
)
# Module logger for this adapter, created with level=ERROR
# (presumably suppresses lower-severity messages — confirm in get_logger).
logger = get_logger("Neo4jAdapter", level=ERROR)
class Neo4jAdapter(GraphDBInterface):
    """Graph database adapter that stores cognee data points in Neo4j.

    Every query goes through :meth:`query`, which opens a new session from the
    async driver for each call. Node writes rely on the APOC plugin
    (``apoc.create.addLabels``, ``apoc.merge.relationship``). Graph metrics and
    projections rely on the Graph Data Science plugin (``gds.*``).
    Nodes are addressed by their ``id`` property, not by Neo4j's internal ID.
    """

    def __init__(
        self,
        graph_database_url: str,
        graph_database_username: str,
        graph_database_password: str,
        driver: Optional[Any] = None,
    ):
        """Create the adapter.

        Args:
            graph_database_url: URI passed to ``AsyncGraphDatabase.driver``.
            graph_database_username: Username for basic auth.
            graph_database_password: Password for basic auth.
            driver: Optional pre-built async driver. When it is truthy, no new
                driver is created from the URL and credentials.
        """
        self.driver = driver or AsyncGraphDatabase.driver(
            graph_database_url,
            auth=(graph_database_username, graph_database_password),
            max_connection_lifetime=120,
        )

    @asynccontextmanager
    async def get_session(self) -> AsyncSession:
        """Yield a driver session scoped to the ``async with`` block."""
        async with self.driver.session() as session:
            yield session

    async def query(
        self,
        query: str,
        params: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Run a Cypher statement and return every record as a dict.

        Args:
            query: Cypher text. Pass values as ``$name`` parameters.
            params: Parameter map for the statement.

        Returns:
            The list produced by ``Result.data()``: nodes become property dicts
            and relationships become ``(start, type, end)`` tuples.

        Raises:
            Neo4jError: Logged here, then re-raised.
        """
        try:
            async with self.get_session() as session:
                result = await session.run(query, parameters=params)
                return await result.data()
        except Neo4jError as error:
            logger.error("Neo4j query error: %s", error, exc_info=True)
            raise

    async def has_node(self, node_id: str) -> bool:
        """Return True if some node's ``id`` property equals ``node_id``."""
        results = await self.query(
            """
            MATCH (n)
            WHERE n.id = $node_id
            RETURN COUNT(n) > 0 AS node_exists
            """,
            {"node_id": str(node_id)},
        )
        return results[0]["node_exists"] if results else False

    async def add_node(self, node: DataPoint):
        """Upsert one data point as a node, labelled with its class name.

        The node is merged on ``id``. Its serialized properties are merged in
        and ``updated_at`` is stamped with the server timestamp.
        """
        serialized_properties = self.serialize_properties(node.model_dump())
        query = dedent(
            """MERGE (node {id: $node_id})
            ON CREATE SET node += $properties, node.updated_at = timestamp()
            ON MATCH SET node += $properties, node.updated_at = timestamp()
            WITH node, $node_label AS label
            CALL apoc.create.addLabels(node, [label]) YIELD node AS labeledNode
            RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId"""
        )
        params = {
            "node_id": str(node.id),
            "node_label": type(node).__name__,
            "properties": serialized_properties,
        }
        return await self.query(query, params)

    @record_graph_changes
    async def add_nodes(self, nodes: list[DataPoint]) -> None:
        """Upsert many data points in one ``UNWIND`` query.

        This has the same merge semantics as :meth:`add_node`. It returns the
        query records (internal id and ``id`` of each labelled node).
        """
        query = """
        UNWIND $nodes AS node
        MERGE (n {id: node.node_id})
        ON CREATE SET n += node.properties, n.updated_at = timestamp()
        ON MATCH SET n += node.properties, n.updated_at = timestamp()
        WITH n, node.label AS label
        CALL apoc.create.addLabels(n, [label]) YIELD node AS labeledNode
        RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
        """
        node_rows = [
            {
                "node_id": str(node.id),
                "label": type(node).__name__,
                "properties": self.serialize_properties(node.model_dump()),
            }
            for node in nodes
        ]
        return await self.query(query, dict(nodes=node_rows))

    async def extract_node(self, node_id: str):
        """Return the properties of the node with ``id`` ``node_id``, or None."""
        results = await self.extract_nodes([node_id])
        return results[0] if results else None

    async def extract_nodes(self, node_ids: List[str]):
        """Return the property dicts of every node whose ``id`` is in ``node_ids``."""
        query = """
        UNWIND $node_ids AS id
        MATCH (node {id: id})
        RETURN node"""
        results = await self.query(query, {"node_ids": node_ids})
        return [result["node"] for result in results]

    async def delete_node(self, node_id: str):
        """Delete the node with ``id`` ``node_id`` together with its relationships."""
        query = "MATCH (node {id: $node_id}) DETACH DELETE node"
        return await self.query(query, {"node_id": node_id})

    async def delete_nodes(self, node_ids: list[str]) -> None:
        """Delete every node whose ``id`` is in ``node_ids``, with its relationships."""
        query = """
        UNWIND $node_ids AS id
        MATCH (node {id: id})
        DETACH DELETE node"""
        return await self.query(query, {"node_ids": node_ids})

    async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool:
        """Return True if a relationship of type ``edge_label`` goes from
        ``from_node`` to ``to_node``. Endpoints are matched on the ``id`` property."""
        query = """
        MATCH (from_node)-[relationship]->(to_node)
        WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id AND type(relationship) = $edge_label
        RETURN COUNT(relationship) > 0 AS edge_exists
        """
        params = {
            "from_node_id": str(from_node),
            "to_node_id": str(to_node),
            "edge_label": edge_label,
        }
        results = await self.query(query, params)
        # COUNT() always yields exactly one row. Unwrap it so callers get a bool
        # rather than a record list, which would always be truthy.
        return results[0]["edge_exists"] if results else False

    async def has_edges(self, edges):
        """Check which ``(from_node, to_node, relationship_name)`` triples exist.

        Endpoints are matched on their ``id`` property, not Neo4j's internal ID.

        Args:
            edges: Iterable of ``(from_node, to_node, relationship_name)``.

        Returns:
            One boolean per input edge, in input order.
        """
        # Index the input so the result order is guaranteed by ORDER BY.
        # EXISTS {} yields one row per edge even when the edge is missing.
        query = """
        UNWIND range(0, size($edges) - 1) AS index
        WITH index, $edges[index] AS edge
        RETURN
            index,
            edge.from_node AS from_node,
            edge.to_node AS to_node,
            edge.relationship_name AS relationship_name,
            EXISTS {
                MATCH (a {id: edge.from_node})-[r]->(b {id: edge.to_node})
                WHERE type(r) = edge.relationship_name
            } AS edge_exists
        ORDER BY index
        """
        params = {
            "edges": [
                {
                    "from_node": str(edge[0]),
                    "to_node": str(edge[1]),
                    "relationship_name": edge[2],
                }
                for edge in edges
            ],
        }
        # query() already logs and re-raises Neo4jError.
        results = await self.query(query, params)
        return [result["edge_exists"] for result in results]

    async def add_edge(
        self,
        from_node: UUID,
        to_node: UUID,
        relationship_name: str,
        edge_properties: Optional[Dict[str, Any]] = None,
    ):
        """Upsert a single relationship of type ``relationship_name``.

        Both endpoints must already exist; they are matched on ``id``. As in
        :meth:`add_edges`, ``source_node_id`` and ``target_node_id`` are stored
        on the relationship, because :meth:`get_graph_data` reads them.

        Args:
            from_node: ``id`` of the start node.
            to_node: ``id`` of the end node.
            relationship_name: Relationship type.
            edge_properties: Extra properties, serialized via
                :meth:`serialize_properties`.
        """
        properties = {
            **self.serialize_properties(edge_properties or {}),
            "source_node_id": str(from_node),
            "target_node_id": str(to_node),
        }
        # A relationship type cannot be a query parameter. Quote it as an
        # identifier; backticks are escaped by doubling.
        relationship_type = relationship_name.replace("`", "``")
        query = dedent(
            f"""\
            MATCH (from_node {{id: $from_node}}),
                  (to_node {{id: $to_node}})
            MERGE (from_node)-[r:`{relationship_type}`]->(to_node)
            ON CREATE SET r += $properties, r.updated_at = timestamp()
            ON MATCH SET r += $properties, r.updated_at = timestamp()
            RETURN r
            """
        )
        params = {
            "from_node": str(from_node),
            "to_node": str(to_node),
            "properties": properties,
        }
        return await self.query(query, params)

    @record_graph_changes
    async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
        """Upsert many relationships with ``apoc.merge.relationship``.

        Each relationship is identified by its type, ``source_node_id`` and
        ``target_node_id``. The remaining properties are applied only when the
        relationship is created.

        Args:
            edges: ``(from_node, to_node, relationship_name, properties)`` tuples.
                ``properties`` may be empty or None.
        """
        query = """
        UNWIND $edges AS edge
        MATCH (from_node {id: edge.from_node})
        MATCH (to_node {id: edge.to_node})
        CALL apoc.merge.relationship(
            from_node,
            edge.relationship_name,
            {
                source_node_id: edge.from_node,
                target_node_id: edge.to_node
            },
            edge.properties,
            to_node
        ) YIELD rel
        RETURN rel"""
        edge_rows = [
            {
                "from_node": str(edge[0]),
                "to_node": str(edge[1]),
                "relationship_name": edge[2],
                "properties": {
                    # Serialize as add_edge does: Neo4j rejects UUID and map values.
                    **self.serialize_properties(edge[3] if edge[3] else {}),
                    "source_node_id": str(edge[0]),
                    "target_node_id": str(edge[1]),
                },
            }
            for edge in edges
        ]
        # query() already logs and re-raises Neo4jError.
        return await self.query(query, dict(edges=edge_rows))

    async def get_edges(self, node_id: str):
        """Return ``(node_id, neighbour_id, {"relationship_name": type})`` for
        every relationship touching the node, in either direction."""
        query = """
        MATCH (n {id: $node_id})-[r]-(m)
        RETURN n, r, m
        """
        results = await self.query(query, dict(node_id=node_id))
        # Result.data() turns relationships into (start, type, end), so [1] is the type.
        return [
            (result["n"]["id"], result["m"]["id"], {"relationship_name": result["r"][1]})
            for result in results
        ]

    async def get_disconnected_nodes(self) -> list[str]:
        """Return nodes that lie outside the largest connected component.

        NOTE(review): this returns Neo4j internal ids (``ID(n)``), while
        :meth:`delete_nodes` matches on the ``id`` property. Confirm which one
        callers expect. Also, with no relationships in the graph the component
        step yields no rows, so this returns an empty list.
        """
        query = """
        // Step 1: Collect all nodes
        MATCH (n)
        WITH COLLECT(n) AS nodes
        // Step 2: Find all connected components
        WITH nodes
        CALL {
            WITH nodes
            UNWIND nodes AS startNode
            MATCH path = (startNode)-[*]-(connectedNode)
            WITH startNode, COLLECT(DISTINCT connectedNode) AS component
            RETURN component
        }
        // Step 3: Aggregate components
        WITH COLLECT(component) AS components
        // Step 4: Identify the largest connected component
        UNWIND components AS component
        WITH component
        ORDER BY SIZE(component) DESC
        LIMIT 1
        WITH component AS largestComponent
        // Step 5: Find nodes not in the largest connected component
        MATCH (n)
        WHERE NOT n IN largestComponent
        RETURN COLLECT(ID(n)) AS ids
        """
        results = await self.query(query)
        return results[0]["ids"] if results else []

    async def get_predecessors(
        self, node_id: str, edge_label: str = None
    ) -> List[Dict[str, Any]]:
        """Return the property dicts of nodes with a relationship pointing into
        ``node_id``, optionally restricted to relationship type ``edge_label``."""
        query = """
        MATCH (node)<-[r]-(predecessor)
        WHERE node.id = $node_id AND ($edge_label IS NULL OR type(r) = $edge_label)
        RETURN predecessor
        """
        results = await self.query(query, dict(node_id=node_id, edge_label=edge_label))
        return [result["predecessor"] for result in results]

    async def get_successors(
        self, node_id: str, edge_label: str = None
    ) -> List[Dict[str, Any]]:
        """Return the property dicts of nodes that ``node_id`` points to,
        optionally restricted to relationship type ``edge_label``."""
        query = """
        MATCH (node)-[r]->(successor)
        WHERE node.id = $node_id AND ($edge_label IS NULL OR type(r) = $edge_label)
        RETURN successor
        """
        results = await self.query(query, dict(node_id=node_id, edge_label=edge_label))
        return [result["successor"] for result in results]

    async def get_neighbors(self, node_id: str) -> List[Dict[str, Any]]:
        """Return all neighbouring nodes: predecessors followed by successors.

        A node connected in both directions appears twice.
        """
        predecessors, successors = await asyncio.gather(
            self.get_predecessors(node_id),
            self.get_successors(node_id),
        )
        return predecessors + successors

    async def get_node(self, node_id: str) -> Optional[Dict[str, Any]]:
        """Return the property dict of the node with ``id`` ``node_id``, or None."""
        query = """
        MATCH (node {id: $node_id})
        RETURN node
        """
        results = await self.query(query, {"node_id": node_id})
        return results[0]["node"] if results else None

    async def get_nodes(self, node_ids: List[str]) -> List[Dict[str, Any]]:
        """Return the property dicts of every node whose ``id`` is in ``node_ids``."""
        query = """
        UNWIND $node_ids AS id
        MATCH (node {id: id})
        RETURN node
        """
        results = await self.query(query, {"node_ids": node_ids})
        return [result["node"] for result in results]

    async def get_connections(self, node_id: UUID) -> list:
        """Return ``(start, {"relationship_name": type}, end)`` for every
        relationship touching the node. Incoming relationships come first, then
        outgoing ones. ``start`` and ``end`` are node property dicts."""
        predecessors_query = """
        MATCH (node)<-[relation]-(neighbour)
        WHERE node.id = $node_id
        RETURN neighbour, relation, node
        """
        successors_query = """
        MATCH (node)-[relation]->(neighbour)
        WHERE node.id = $node_id
        RETURN node, relation, neighbour
        """
        params = dict(node_id=str(node_id))
        predecessors, successors = await asyncio.gather(
            self.query(predecessors_query, params),
            self.query(successors_query, params),
        )
        connections = []
        for record in [*predecessors, *successors]:
            # Result.data() renders a relationship as (start_props, type, end_props).
            start_node, relationship_type, end_node = record["relation"]
            connections.append((start_node, {"relationship_name": relationship_type}, end_node))
        return connections

    async def remove_connection_to_predecessors_of(
        self, node_ids: list[str], edge_label: str
    ) -> None:
        """Delete ``edge_label`` relationships that point *into* the given nodes.

        The direction matches :meth:`get_predecessors`.
        """
        relationship_type = edge_label.replace("`", "``")
        query = f"""
        UNWIND $node_ids AS id
        MATCH (node {{id: id}})<-[r:`{relationship_type}`]-(predecessor)
        DELETE r;
        """
        return await self.query(query, {"node_ids": node_ids})

    async def remove_connection_to_successors_of(
        self, node_ids: list[str], edge_label: str
    ) -> None:
        """Delete ``edge_label`` relationships that point *out of* the given nodes.

        The direction matches :meth:`get_successors`.
        """
        relationship_type = edge_label.replace("`", "``")
        query = f"""
        UNWIND $node_ids AS id
        MATCH (node {{id: id}})-[r:`{relationship_type}`]->(successor)
        DELETE r;
        """
        return await self.query(query, {"node_ids": node_ids})

    async def delete_graph(self):
        """Delete every node and relationship in the database."""
        query = """MATCH (node)
        DETACH DELETE node;"""
        return await self.query(query)

    def serialize_properties(self, properties: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Return a copy of ``properties`` that Neo4j can store.

        UUIDs become strings and dicts become JSON strings (via ``JSONEncoder``).
        Every other value is passed through unchanged. The input is not mutated.
        """
        serialized_properties = {}
        for property_key, property_value in (properties or {}).items():
            if isinstance(property_value, UUID):
                serialized_properties[property_key] = str(property_value)
            elif isinstance(property_value, dict):
                serialized_properties[property_key] = json.dumps(property_value, cls=JSONEncoder)
            else:
                serialized_properties[property_key] = property_value
        return serialized_properties

    async def get_model_independent_graph_data(self):
        """Return ``(nodes, edges)`` as raw query records.

        ``nodes[0]["nodes"]`` lists all nodes and ``edges[0]["elements"]`` lists
        ``[n, r, m]`` triples for every directed relationship.
        """
        query_nodes = "MATCH (n) RETURN collect(n) AS nodes"
        nodes = await self.query(query_nodes)
        query_edges = "MATCH (n)-[r]->(m) RETURN collect([n, r, m]) AS elements"
        edges = await self.query(query_edges)
        return (nodes, edges)

    async def get_graph_data(self):
        """Return the whole graph as ``(nodes, edges)``.

        Nodes are ``(id, properties)``. Edges are
        ``(source_node_id, target_node_id, type, properties)``, read from the
        relationship's properties. A relationship lacking those keys raises
        KeyError.
        """
        query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"
        result = await self.query(query)
        nodes = [(record["properties"]["id"], record["properties"]) for record in result]
        query = """
        MATCH (n)-[r]->(m)
        RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
        """
        result = await self.query(query)
        edges = [
            (
                record["properties"]["source_node_id"],
                record["properties"]["target_node_id"],
                record["type"],
                record["properties"],
            )
            for record in result
        ]
        return (nodes, edges)

    async def get_subgraph(
        self, node_type: Type[Any], node_name: List[str]
    ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
        """Return the subgraph around the named nodes of one type.

        The subgraph contains the nodes labelled ``node_type.__name__`` whose
        ``name`` is in ``node_name``, plus their direct neighbours, and every
        relationship among that node set. Nodes are returned even when no
        relationship connects them.

        Returns:
            ``(nodes, edges)`` in the same tuple shapes as :meth:`get_graph_data`.
            Both are empty when no node matches.
        """
        label = node_type.__name__
        # OPTIONAL MATCH on the relationships: a plain MATCH would discard every
        # node when none of them are connected to each other.
        query = f"""
        UNWIND $names AS wantedName
        MATCH (n:`{label}`)
        WHERE n.name = wantedName
        WITH collect(DISTINCT n) AS primary
        UNWIND primary AS p
        OPTIONAL MATCH (p)--(nbr)
        WITH primary, collect(DISTINCT nbr) AS nbrs
        WITH primary + nbrs AS nodelist
        UNWIND nodelist AS node
        WITH collect(DISTINCT node) AS nodes
        OPTIONAL MATCH (a)-[r]-(b)
        WHERE a IN nodes AND b IN nodes
        WITH nodes, collect(DISTINCT r) AS rels
        RETURN
        [n IN nodes |
            {{ id: n.id,
            properties: properties(n) }}] AS rawNodes,
        [r IN rels |
            {{ type: type(r),
            properties: properties(r) }}] AS rawRels
        """
        result = await self.query(query, {"names": node_name})
        if not result:
            return [], []
        raw_nodes = result[0]["rawNodes"]
        raw_rels = result[0]["rawRels"]
        nodes = [(n["properties"]["id"], n["properties"]) for n in raw_nodes]
        edges = [
            (
                r["properties"]["source_node_id"],
                r["properties"]["target_node_id"],
                r["type"],
                r["properties"],
            )
            for r in raw_rels
        ]
        return nodes, edges

    async def get_filtered_graph_data(self, attribute_filters):
        """
        Fetch nodes and relationships filtered by attribute values.

        Args:
            attribute_filters (list of dict): A list of dictionaries where keys are
                attributes and values are lists of values to filter on. Only the
                first dictionary is used, and its conditions are combined with AND.
                With no filters, every node and relationship matches.
                Example: [{"community": ["1", "2"]}]

        Returns:
            tuple: ``(nodes, edges)``. Nodes are ``(internal_id, properties)``.
            Edges are ``(source_internal_id, target_internal_id, type, properties)``
            and are included only when both endpoints satisfy the filters.
        """
        filters = attribute_filters[0] if attribute_filters else {}
        params: Dict[str, Any] = {}
        source_conditions = []
        target_conditions = []
        for index, (attribute, values) in enumerate(filters.items()):
            # Values go in as parameters instead of being spliced into the Cypher
            # text, so quotes (or "n.") inside a value cannot corrupt the query.
            param_name = f"filter_values_{index}"
            params[param_name] = list(values)
            # Property names cannot be parameters, so quote them as identifiers.
            quoted_attribute = "`" + str(attribute).replace("`", "``") + "`"
            source_conditions.append(f"n.{quoted_attribute} IN ${param_name}")
            target_conditions.append(f"m.{quoted_attribute} IN ${param_name}")
        node_where = " AND ".join(source_conditions) or "true"
        edge_where = " AND ".join(source_conditions + target_conditions) or "true"

        query_nodes = f"""
        MATCH (n)
        WHERE {node_where}
        RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties
        """
        result_nodes = await self.query(query_nodes, params)
        nodes = [(record["id"], record["properties"]) for record in result_nodes]

        query_edges = f"""
        MATCH (n)-[r]->(m)
        WHERE {edge_where}
        RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
        """
        result_edges = await self.query(query_edges, params)
        edges = [
            (
                record["source"],
                record["target"],
                record["type"],
                record["properties"],
            )
            for record in result_edges
        ]
        return (nodes, edges)

    async def graph_exists(self, graph_name="myGraph"):
        """Return True if a GDS in-memory graph called ``graph_name`` exists."""
        query = "CALL gds.graph.list() YIELD graphName RETURN collect(graphName) AS graphNames;"
        result = await self.query(query)
        graph_names = result[0]["graphNames"] if result else []
        return graph_name in graph_names

    async def get_node_labels_string(self):
        """Return all node labels as a Cypher list literal, e.g. ``['A', 'B']``.

        Raises:
            ValueError: If the database has no node labels.
        """
        node_labels_query = "CALL db.labels() YIELD label RETURN collect(label) AS labels;"
        node_labels_result = await self.query(node_labels_query)
        node_labels = node_labels_result[0]["labels"] if node_labels_result else []
        if not node_labels:
            raise ValueError("No node labels found in the database")
        return "[" + ", ".join(f"'{label}'" for label in node_labels) + "]"

    async def get_relationship_labels_string(self):
        """Return a GDS relationship projection map that marks every
        relationship type as ``UNDIRECTED``.

        Raises:
            ValueError: If the database has no relationship types.
        """
        relationship_types_query = "CALL db.relationshipTypes() YIELD relationshipType RETURN collect(relationshipType) AS relationships;"
        relationship_types_result = await self.query(relationship_types_query)
        relationship_types = (
            relationship_types_result[0]["relationships"] if relationship_types_result else []
        )
        if not relationship_types:
            raise ValueError("No relationship types found in the database.")
        return (
            "{"
            + ", ".join(f"{rel}" + ": {orientation: 'UNDIRECTED'}" for rel in relationship_types)
            + "}"
        )

    async def project_entire_graph(self, graph_name="myGraph"):
        """
        Projects all node labels and all relationship types into an undirected in-memory GDS graph.

        Does nothing if a projection named ``graph_name`` already exists.
        """
        if await self.graph_exists(graph_name):
            return
        node_labels_str = await self.get_node_labels_string()
        relationship_types_undirected_str = await self.get_relationship_labels_string()
        query = f"""
        CALL gds.graph.project(
            $graph_name,
            {node_labels_str},
            {relationship_types_undirected_str}
        ) YIELD graphName;
        """
        await self.query(query, {"graph_name": graph_name})

    async def drop_graph(self, graph_name="myGraph"):
        """Drop the GDS in-memory graph ``graph_name`` if it exists."""
        if await self.graph_exists(graph_name):
            await self.query("CALL gds.graph.drop($graph_name);", {"graph_name": graph_name})

    async def get_graph_metrics(self, include_optional=False):
        """For the definition of these metrics, please refer to
        https://docs.cognee.ai/core_concepts/graph_generation/descriptive_metrics

        Re-projects the graph as ``myGraph`` before computing the metrics. When
        ``include_optional`` is False, the optional metrics are reported as -1.
        """
        nodes, edges = await self.get_model_independent_graph_data()
        graph_name = "myGraph"
        await self.drop_graph(graph_name)
        await self.project_entire_graph(graph_name)
        num_nodes = len(nodes[0]["nodes"])
        num_edges = len(edges[0]["elements"])
        mandatory_metrics = {
            "num_nodes": num_nodes,
            "num_edges": num_edges,
            "mean_degree": (2 * num_edges) / num_nodes if num_nodes != 0 else None,
            "edge_density": await get_edge_density(self),
            "num_connected_components": await get_num_connected_components(self, graph_name),
            "sizes_of_connected_components": await get_size_of_connected_components(
                self, graph_name
            ),
        }
        if include_optional:
            shortest_path_lengths = await get_shortest_path_lengths(self, graph_name)
            optional_metrics = {
                "num_selfloops": await count_self_loops(self),
                "diameter": max(shortest_path_lengths) if shortest_path_lengths else -1,
                "avg_shortest_path_length": sum(shortest_path_lengths) / len(shortest_path_lengths)
                if shortest_path_lengths
                else -1,
                "avg_clustering": await get_avg_clustering(self, graph_name),
            }
        else:
            optional_metrics = {
                "num_selfloops": -1,
                "diameter": -1,
                "avg_shortest_path_length": -1,
                "avg_clustering": -1,
            }
        return mandatory_metrics | optional_metrics

    async def get_document_subgraph(self, content_hash: str):
        """Collect the nodes belonging to the document named ``'text_' + content_hash``.

        Returns a single record with these keys, or None:

        * ``document`` — the document node.
        * ``chunks`` — chunks linked to it by ``is_part_of``.
        * ``orphan_entities`` — entities not contained in any other document's chunks.
        * ``made_from_nodes`` — ``TextSummary`` nodes built from those chunks.
        * ``orphan_types`` — entity types not used by any other document's entities.
        """
        query = """
        MATCH (doc)
        WHERE (doc:TextDocument OR doc:PdfDocument)
        AND doc.name = 'text_' + $content_hash
        OPTIONAL MATCH (doc)<-[:is_part_of]-(chunk:DocumentChunk)
        OPTIONAL MATCH (chunk)-[:contains]->(entity:Entity)
        WHERE NOT EXISTS {
            MATCH (entity)<-[:contains]-(otherChunk:DocumentChunk)-[:is_part_of]->(otherDoc)
            WHERE (otherDoc:TextDocument OR otherDoc:PdfDocument)
            AND otherDoc.id <> doc.id
        }
        OPTIONAL MATCH (chunk)<-[:made_from]-(made_node:TextSummary)
        OPTIONAL MATCH (entity)-[:is_a]->(type:EntityType)
        WHERE NOT EXISTS {
            MATCH (type)<-[:is_a]-(otherEntity:Entity)<-[:contains]-(otherChunk:DocumentChunk)-[:is_part_of]->(otherDoc)
            WHERE (otherDoc:TextDocument OR otherDoc:PdfDocument)
            AND otherDoc.id <> doc.id
        }
        RETURN
            collect(DISTINCT doc) as document,
            collect(DISTINCT chunk) as chunks,
            collect(DISTINCT entity) as orphan_entities,
            collect(DISTINCT made_node) as made_from_nodes,
            collect(DISTINCT type) as orphan_types
        """
        result = await self.query(query, {"content_hash": content_hash})
        return result[0] if result else None

    async def get_degree_one_nodes(self, node_type: str):
        """Return ``Entity`` or ``EntityType`` nodes that have exactly one relationship.

        Raises:
            ValueError: If ``node_type`` is not ``"Entity"`` or ``"EntityType"``.
                The allow-list also keeps the interpolated label safe.
        """
        if not node_type or node_type not in ["Entity", "EntityType"]:
            raise ValueError("node_type must be either 'Entity' or 'EntityType'")
        query = f"""
        MATCH (n:{node_type})
        WHERE COUNT {{ MATCH (n)--() }} = 1
        RETURN n
        """
        result = await self.query(query)
        return [record["n"] for record in result] if result else []