<!-- .github/pull_request_template.md --> ## Description <!-- Provide a clear description of the changes in this PR --> ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. --------- Co-authored-by: vasilije <vas.markovic@gmail.com> Co-authored-by: Igor Ilic <30923996+dexters1@users.noreply.github.com> Co-authored-by: Vasilije <8619304+Vasilije1990@users.noreply.github.com> Co-authored-by: Igor Ilic <igorilic03@gmail.com> Co-authored-by: Hande <159312713+hande-k@users.noreply.github.com> Co-authored-by: Matea Pesic <80577904+matea16@users.noreply.github.com> Co-authored-by: hajdul88 <52442977+hajdul88@users.noreply.github.com> Co-authored-by: Daniel Molnar <soobrosa@gmail.com> Co-authored-by: Diego Baptista Theuerkauf <34717973+diegoabt@users.noreply.github.com>
636 lines
25 KiB
Python
636 lines
25 KiB
Python
"""Adapter for NetworkX graph database."""
|
|
|
|
from datetime import datetime, timezone
|
|
import os
|
|
import json
|
|
import asyncio
|
|
from cognee.shared.logging_utils import get_logger
|
|
from typing import Dict, Any, List, Union
|
|
from uuid import UUID
|
|
import aiofiles
|
|
import aiofiles.os as aiofiles_os
|
|
import networkx as nx
|
|
from cognee.infrastructure.databases.graph.graph_db_interface import (
|
|
GraphDBInterface,
|
|
record_graph_changes,
|
|
)
|
|
from cognee.infrastructure.engine import DataPoint
|
|
from cognee.infrastructure.engine.utils import parse_id
|
|
from cognee.modules.storage.utils import JSONEncoder
|
|
import numpy as np
|
|
|
|
# Module-level logger shared by the whole adapter.
logger = get_logger()
|
|
|
|
|
|
class NetworkXAdapter(GraphDBInterface):
|
|
    # Process-wide singleton instance, created lazily by __new__.
    _instance = None
    # Shared networkx.MultiDiGraph; class-level so all instantiations see the same graph.
    graph = None  # Class variable to store the singleton instance
|
|
|
|
    def __new__(cls, filename):
        """Return the process-wide singleton adapter instance.

        The first call creates the instance and records *filename*; later
        calls return the same object no matter what filename they pass.
        NOTE(review): __init__ still runs on every instantiation and
        overwrites self.filename — confirm this is intended.
        """
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.filename = filename
        return cls._instance
|
|
|
|
    def __init__(self, filename="cognee_graph.pkl"):
        """Record the path of the on-disk graph file.

        NOTE(review): because the class is a singleton, this runs on every
        instantiation and replaces the shared instance's filename.
        """
        self.filename = filename
|
|
|
|
async def get_graph_data(self):
|
|
await self.load_graph_from_file()
|
|
return (list(self.graph.nodes(data=True)), list(self.graph.edges(data=True, keys=True)))
|
|
|
|
    async def query(self, query: str, params: dict):
        """Placeholder for the GraphDBInterface query API.

        NetworkX has no query language, so this is intentionally a no-op
        that returns None.
        """
        pass
|
|
|
|
async def has_node(self, node_id: UUID) -> bool:
|
|
return self.graph.has_node(node_id)
|
|
|
|
async def add_node(self, node: DataPoint) -> None:
|
|
self.graph.add_node(node.id, **node.model_dump())
|
|
|
|
await self.save_graph_to_file(self.filename)
|
|
|
|
@record_graph_changes
|
|
async def add_nodes(self, nodes: list[DataPoint]) -> None:
|
|
nodes = [(node.id, node.model_dump()) for node in nodes]
|
|
self.graph.add_nodes_from(nodes)
|
|
await self.save_graph_to_file(self.filename)
|
|
|
|
    async def get_graph(self):
        """Return the in-memory networkx.MultiDiGraph instance."""
        return self.graph
|
|
|
|
async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool:
|
|
return self.graph.has_edge(from_node, to_node, key=edge_label)
|
|
|
|
async def has_edges(self, edges):
|
|
result = []
|
|
|
|
for from_node, to_node, edge_label in edges:
|
|
if self.graph.has_edge(from_node, to_node, edge_label):
|
|
result.append((from_node, to_node, edge_label))
|
|
|
|
return result
|
|
|
|
@record_graph_changes
|
|
async def add_edge(
|
|
self,
|
|
from_node: str,
|
|
to_node: str,
|
|
relationship_name: str,
|
|
edge_properties: Dict[str, Any] = {},
|
|
) -> None:
|
|
edge_properties["updated_at"] = datetime.now(timezone.utc)
|
|
self.graph.add_edge(
|
|
from_node,
|
|
to_node,
|
|
key=relationship_name,
|
|
**(edge_properties if edge_properties else {}),
|
|
)
|
|
|
|
await self.save_graph_to_file(self.filename)
|
|
|
|
@record_graph_changes
|
|
async def add_edges(self, edges: list[tuple[str, str, str, dict]]) -> None:
|
|
if not edges:
|
|
logger.debug("No edges to add")
|
|
return
|
|
|
|
try:
|
|
# Validate edge format and convert UUIDs to strings
|
|
processed_edges = []
|
|
for edge in edges:
|
|
if len(edge) < 3 or len(edge) > 4:
|
|
raise ValueError(
|
|
f"Invalid edge format: {edge}. Expected (from_node, to_node, relationship_name[, properties])"
|
|
)
|
|
|
|
# Convert UUIDs to strings if needed
|
|
from_node = str(edge[0]) if isinstance(edge[0], UUID) else edge[0]
|
|
to_node = str(edge[1]) if isinstance(edge[1], UUID) else edge[1]
|
|
relationship_name = edge[2]
|
|
|
|
if not all(isinstance(x, str) for x in [from_node, to_node, relationship_name]):
|
|
raise ValueError(
|
|
f"First three elements of edge must be strings or UUIDs: {edge}"
|
|
)
|
|
|
|
# Process edge with updated_at timestamp
|
|
processed_edge = (
|
|
from_node,
|
|
to_node,
|
|
relationship_name,
|
|
{
|
|
**(edge[3] if len(edge) == 4 else {}),
|
|
"updated_at": datetime.now(timezone.utc),
|
|
},
|
|
)
|
|
processed_edges.append(processed_edge)
|
|
|
|
# Add edges to graph
|
|
self.graph.add_edges_from(processed_edges)
|
|
logger.debug(f"Added {len(processed_edges)} edges to graph")
|
|
|
|
# Save changes
|
|
await self.save_graph_to_file(self.filename)
|
|
except Exception as e:
|
|
logger.error(f"Failed to add edges: {e}")
|
|
raise
|
|
|
|
async def get_edges(self, node_id: UUID):
|
|
return list(self.graph.in_edges(node_id, data=True)) + list(
|
|
self.graph.out_edges(node_id, data=True)
|
|
)
|
|
|
|
async def delete_node(self, node_id: UUID) -> None:
|
|
"""Asynchronously delete a node and all its relationships from the graph if it exists."""
|
|
|
|
if self.graph.has_node(node_id):
|
|
# First remove all edges connected to the node
|
|
for edge in list(self.graph.edges(node_id, data=True)):
|
|
source, target, data = edge
|
|
self.graph.remove_edge(source, target, key=data.get("relationship_name"))
|
|
|
|
# Then remove the node itself
|
|
self.graph.remove_node(node_id)
|
|
|
|
# Save the updated graph state
|
|
await self.save_graph_to_file(self.filename)
|
|
else:
|
|
logger.error(f"Node {node_id} not found in graph")
|
|
|
|
    async def delete_nodes(self, node_ids: List[UUID]) -> None:
        """Remove every node in *node_ids* (and all their incident edges), then persist the graph."""
        self.graph.remove_nodes_from(node_ids)
        await self.save_graph_to_file(self.filename)
|
|
|
|
async def get_disconnected_nodes(self) -> List[str]:
|
|
connected_components = list(nx.weakly_connected_components(self.graph))
|
|
|
|
disconnected_nodes = []
|
|
biggest_subgraph = max(connected_components, key=len)
|
|
|
|
for component in connected_components:
|
|
if component != biggest_subgraph:
|
|
disconnected_nodes.extend(list(component))
|
|
|
|
return disconnected_nodes
|
|
|
|
async def extract_node(self, node_id: UUID) -> dict:
|
|
if self.graph.has_node(node_id):
|
|
return self.graph.nodes[node_id]
|
|
|
|
return None
|
|
|
|
async def extract_nodes(self, node_ids: List[UUID]) -> List[dict]:
|
|
return [self.graph.nodes[node_id] for node_id in node_ids if self.graph.has_node(node_id)]
|
|
|
|
async def get_predecessors(self, node_id: UUID, edge_label: str = None) -> list:
|
|
if self.graph.has_node(node_id):
|
|
if edge_label is None:
|
|
return [
|
|
self.graph.nodes[predecessor]
|
|
for predecessor in list(self.graph.predecessors(node_id))
|
|
]
|
|
|
|
nodes = []
|
|
|
|
for predecessor_id in list(self.graph.predecessors(node_id)):
|
|
if self.graph.has_edge(predecessor_id, node_id, edge_label):
|
|
nodes.append(self.graph.nodes[predecessor_id])
|
|
|
|
return nodes
|
|
|
|
async def get_successors(self, node_id: UUID, edge_label: str = None) -> list:
|
|
if self.graph.has_node(node_id):
|
|
if edge_label is None:
|
|
return [
|
|
self.graph.nodes[successor]
|
|
for successor in list(self.graph.successors(node_id))
|
|
]
|
|
|
|
nodes = []
|
|
|
|
for successor_id in list(self.graph.successors(node_id)):
|
|
if self.graph.has_edge(node_id, successor_id, edge_label):
|
|
nodes.append(self.graph.nodes[successor_id])
|
|
|
|
return nodes
|
|
|
|
async def get_neighbors(self, node_id: UUID) -> list:
|
|
if not self.graph.has_node(node_id):
|
|
return []
|
|
|
|
predecessors, successors = await asyncio.gather(
|
|
self.get_predecessors(node_id),
|
|
self.get_successors(node_id),
|
|
)
|
|
|
|
neighbors = predecessors + successors
|
|
|
|
return neighbors
|
|
|
|
async def get_connections(self, node_id: UUID) -> list:
|
|
if not self.graph.has_node(node_id):
|
|
return []
|
|
|
|
node = self.graph.nodes[node_id]
|
|
|
|
if "id" not in node:
|
|
return []
|
|
|
|
predecessors, successors = await asyncio.gather(
|
|
self.get_predecessors(node_id),
|
|
self.get_successors(node_id),
|
|
)
|
|
|
|
connections = []
|
|
|
|
# Handle None values for predecessors and successors
|
|
if predecessors is not None:
|
|
for neighbor in predecessors:
|
|
if "id" in neighbor:
|
|
edge_data = self.graph.get_edge_data(neighbor["id"], node["id"])
|
|
if edge_data is not None:
|
|
for edge_properties in edge_data.values():
|
|
connections.append((neighbor, edge_properties, node))
|
|
|
|
if successors is not None:
|
|
for neighbor in successors:
|
|
if "id" in neighbor:
|
|
edge_data = self.graph.get_edge_data(node["id"], neighbor["id"])
|
|
if edge_data is not None:
|
|
for edge_properties in edge_data.values():
|
|
connections.append((node, edge_properties, neighbor))
|
|
|
|
return connections
|
|
|
|
async def remove_connection_to_predecessors_of(
|
|
self, node_ids: list[UUID], edge_label: str
|
|
) -> None:
|
|
for node_id in node_ids:
|
|
if self.graph.has_node(node_id):
|
|
for predecessor_id in list(self.graph.predecessors(node_id)):
|
|
if self.graph.has_edge(predecessor_id, node_id, edge_label):
|
|
self.graph.remove_edge(predecessor_id, node_id, edge_label)
|
|
|
|
await self.save_graph_to_file(self.filename)
|
|
|
|
async def remove_connection_to_successors_of(
|
|
self, node_ids: list[UUID], edge_label: str
|
|
) -> None:
|
|
for node_id in node_ids:
|
|
if self.graph.has_node(node_id):
|
|
for successor_id in list(self.graph.successors(node_id)):
|
|
if self.graph.has_edge(node_id, successor_id, edge_label):
|
|
self.graph.remove_edge(node_id, successor_id, edge_label)
|
|
|
|
await self.save_graph_to_file(self.filename)
|
|
|
|
async def create_empty_graph(self, file_path: str) -> None:
|
|
self.graph = nx.MultiDiGraph()
|
|
|
|
# Only create directory if file_path contains a directory
|
|
file_dir = os.path.dirname(file_path)
|
|
if file_dir and not os.path.exists(file_dir):
|
|
os.makedirs(file_dir, exist_ok=True)
|
|
|
|
await self.save_graph_to_file(file_path)
|
|
|
|
async def save_graph_to_file(self, file_path: str = None) -> None:
|
|
"""Asynchronously save the graph to a file in JSON format."""
|
|
if not file_path:
|
|
file_path = self.filename
|
|
|
|
graph_data = nx.readwrite.json_graph.node_link_data(self.graph, edges="links")
|
|
|
|
async with aiofiles.open(file_path, "w") as file:
|
|
json_data = json.dumps(graph_data, cls=JSONEncoder)
|
|
await file.write(json_data)
|
|
|
|
    async def load_graph_from_file(self, file_path: str = None):
        """Asynchronously load the graph from a file in JSON format.

        Reads node-link JSON written by save_graph_to_file, normalizes node
        ids back to UUID objects and "updated_at" fields back to datetimes,
        then rebuilds self.graph. Any failure — missing file, corrupt JSON,
        bad ids — falls back to initializing a fresh empty graph so the
        adapter stays usable.
        """
        if not file_path:
            file_path = self.filename
        try:
            if os.path.exists(file_path):
                async with aiofiles.open(file_path, "r") as file:
                    graph_data = json.loads(await file.read())
                    # Convert node ids back to UUID objects where possible.
                    for node in graph_data["nodes"]:
                        try:
                            if not isinstance(node["id"], UUID):
                                try:
                                    node["id"] = UUID(node["id"])
                                except Exception:
                                    # If conversion fails, keep the original id
                                    pass
                        except Exception as e:
                            logger.error(e)
                            raise e

                        # ints are epoch milliseconds; strings use the ISO-like
                        # "%Y-%m-%dT%H:%M:%S.%f%z" format.
                        # NOTE(review): strptime raises for strings without
                        # fractional seconds — confirm all writers emit them.
                        if isinstance(node.get("updated_at"), int):
                            node["updated_at"] = datetime.fromtimestamp(
                                node["updated_at"] / 1000, tz=timezone.utc
                            )
                        elif isinstance(node.get("updated_at"), str):
                            node["updated_at"] = datetime.strptime(
                                node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z"
                            )

                    # Normalize edge endpoints and mirror them into the
                    # source_node_id / target_node_id attributes.
                    for edge in graph_data["links"]:
                        try:
                            if not isinstance(edge["source"], UUID):
                                source_id = parse_id(edge["source"])
                            else:
                                source_id = edge["source"]

                            if not isinstance(edge["target"], UUID):
                                target_id = parse_id(edge["target"])
                            else:
                                target_id = edge["target"]

                            edge["source"] = source_id
                            edge["target"] = target_id
                            edge["source_node_id"] = source_id
                            edge["target_node_id"] = target_id
                        except Exception as e:
                            logger.error(e)
                            raise e

                        if isinstance(
                            edge.get("updated_at"), int
                        ):  # Handle timestamp in milliseconds
                            edge["updated_at"] = datetime.fromtimestamp(
                                edge["updated_at"] / 1000, tz=timezone.utc
                            )
                        elif isinstance(edge.get("updated_at"), str):
                            edge["updated_at"] = datetime.strptime(
                                edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z"
                            )

                    # Rebuild the MultiDiGraph from the normalized payload.
                    self.graph = nx.readwrite.json_graph.node_link_graph(graph_data, edges="links")

                    # Keep each node's "id" attribute in sync with its graph key.
                    for node_id, node_data in self.graph.nodes(data=True):
                        node_data["id"] = node_id
            else:
                # Log that the file does not exist and an empty graph is initialized
                logger.warning("File %s not found. Initializing an empty graph.", file_path)
                await self.create_empty_graph(file_path)

        except Exception:
            # Fall back to a fresh graph on any load failure.
            logger.error("Failed to load graph from file: %s", file_path)

            await self.create_empty_graph(file_path)
|
|
|
|
async def delete_graph(self, file_path: str = None):
|
|
"""Asynchronously delete the graph file from the filesystem."""
|
|
if file_path is None:
|
|
file_path = (
|
|
self.filename
|
|
) # Assuming self.filename is defined elsewhere and holds the default graph file path
|
|
try:
|
|
if os.path.exists(file_path):
|
|
await aiofiles_os.remove(file_path)
|
|
|
|
self.graph = None
|
|
logger.info("Graph deleted successfully.")
|
|
except Exception as error:
|
|
logger.error("Failed to delete graph: %s", error)
|
|
raise error
|
|
|
|
async def get_filtered_graph_data(
|
|
self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
|
|
):
|
|
"""
|
|
Fetches nodes and relationships filtered by specified attribute values.
|
|
|
|
Args:
|
|
attribute_filters (list of dict): A list of dictionaries where keys are attributes and values are lists of values to filter on.
|
|
Example: [{"community": ["1", "2"]}]
|
|
|
|
Returns:
|
|
tuple: A tuple containing two lists:
|
|
- Nodes: List of tuples (node_id, node_properties).
|
|
- Edges: List of tuples (source_id, target_id, relationship_type, edge_properties).
|
|
"""
|
|
# Create filters for nodes based on the attribute filters
|
|
where_clauses = []
|
|
for attribute, values in attribute_filters[0].items():
|
|
where_clauses.append((attribute, values))
|
|
|
|
# Filter nodes
|
|
filtered_nodes = [
|
|
(node, data)
|
|
for node, data in self.graph.nodes(data=True)
|
|
if all(data.get(attr) in values for attr, values in where_clauses)
|
|
]
|
|
|
|
# Filter edges where both source and target nodes satisfy the filters
|
|
filtered_edges = [
|
|
(source, target, data.get("relationship_type", "UNKNOWN"), data)
|
|
for source, target, data in self.graph.edges(data=True)
|
|
if (
|
|
all(self.graph.nodes[source].get(attr) in values for attr, values in where_clauses)
|
|
and all(
|
|
self.graph.nodes[target].get(attr) in values for attr, values in where_clauses
|
|
)
|
|
)
|
|
]
|
|
|
|
return filtered_nodes, filtered_edges
|
|
|
|
async def get_graph_metrics(self, include_optional=False):
|
|
graph = self.graph
|
|
|
|
def _get_mean_degree(graph):
|
|
degrees = [d for _, d in graph.degree()]
|
|
return np.mean(degrees) if degrees else 0
|
|
|
|
def _get_edge_density(graph):
|
|
num_nodes = graph.number_of_nodes()
|
|
num_edges = graph.number_of_edges()
|
|
num_possible_edges = num_nodes * (num_nodes - 1)
|
|
edge_density = num_edges / num_possible_edges if num_possible_edges > 0 else 0
|
|
return edge_density
|
|
|
|
def _get_diameter(graph):
|
|
try:
|
|
return nx.diameter(nx.DiGraph(graph.to_undirected()))
|
|
except Exception as e:
|
|
logger.warning("Failed to calculate diameter: %s", e)
|
|
return None
|
|
|
|
def _get_avg_shortest_path_length(graph):
|
|
try:
|
|
return nx.average_shortest_path_length(nx.DiGraph(graph.to_undirected()))
|
|
except Exception as e:
|
|
logger.warning("Failed to calculate average shortest path length: %s", e)
|
|
return None
|
|
|
|
def _get_avg_clustering(graph):
|
|
try:
|
|
return nx.average_clustering(nx.DiGraph(graph.to_undirected()))
|
|
except Exception as e:
|
|
logger.warning("Failed to calculate clustering coefficient: %s", e)
|
|
return None
|
|
|
|
mandatory_metrics = {
|
|
"num_nodes": graph.number_of_nodes(),
|
|
"num_edges": graph.number_of_edges(),
|
|
"mean_degree": _get_mean_degree(graph),
|
|
"edge_density": _get_edge_density(graph),
|
|
"num_connected_components": nx.number_weakly_connected_components(graph),
|
|
"sizes_of_connected_components": [
|
|
len(c) for c in nx.weakly_connected_components(graph)
|
|
],
|
|
}
|
|
|
|
if include_optional:
|
|
optional_metrics = {
|
|
"num_selfloops": sum(1 for u, v in graph.edges() if u == v),
|
|
"diameter": _get_diameter(graph),
|
|
"avg_shortest_path_length": _get_avg_shortest_path_length(graph),
|
|
"avg_clustering": _get_avg_clustering(graph),
|
|
}
|
|
else:
|
|
optional_metrics = {
|
|
"num_selfloops": -1,
|
|
"diameter": -1,
|
|
"avg_shortest_path_length": -1,
|
|
"avg_clustering": -1,
|
|
}
|
|
|
|
return mandatory_metrics | optional_metrics
|
|
|
|
    async def get_document_subgraph(self, content_hash: str):
        """Get all nodes that should be deleted when removing a document.

        Finds the TextDocument/PdfDocument node whose name equals
        ``text_<content_hash>`` and collects, via relationship_name-tagged
        edges: its chunks ("is_part_of"), entities that would be orphaned
        ("contains"), entity types left orphaned ("is_a"/"instance_of") and
        summary nodes ("made_from").

        Returns:
            dict | None: Mapping with keys "document", "chunks",
            "orphan_entities", "made_from_nodes" and "orphan_types" — each a
            list of node dicts carrying an "id" — or None when no matching
            document exists.
        """
        # Ensure graph is loaded
        if self.graph is None:
            await self.load_graph_from_file()

        # Find the document node by looking for content_hash in the name field
        document = None
        document_node_id = None
        for node_id, attrs in self.graph.nodes(data=True):
            if (
                attrs.get("type") in ["TextDocument", "PdfDocument"]
                and attrs.get("name") == f"text_{content_hash}"
            ):
                document = {"id": str(node_id), **attrs}  # Convert UUID to string for consistency
                document_node_id = node_id  # Keep the original UUID
                break

        if not document:
            return None

        # Find chunks connected via is_part_of (chunks point TO document)
        chunks = []
        for source, target, edge_data in self.graph.in_edges(document_node_id, data=True):
            if edge_data.get("relationship_name") == "is_part_of":
                chunks.append({"id": source, **self.graph.nodes[source]})  # Keep as UUID object

        # Find entities connected to chunks (chunks point TO entities via contains)
        entities = []
        for chunk in chunks:
            chunk_id = chunk["id"]  # Already a UUID object
            for source, target, edge_data in self.graph.out_edges(chunk_id, data=True):
                if edge_data.get("relationship_name") == "contains":
                    entities.append(
                        {"id": target, **self.graph.nodes[target]}
                    )  # Keep as UUID object

        # Find orphaned entities (entities only connected to chunks we're deleting)
        orphan_entities = []
        for entity in entities:
            entity_id = entity["id"]  # Already a UUID object
            # Get all chunks that contain this entity
            containing_chunks = []
            for source, target, edge_data in self.graph.in_edges(entity_id, data=True):
                if edge_data.get("relationship_name") == "contains":
                    containing_chunks.append(source)  # Keep as UUID object

            # Check if all containing chunks are in our chunks list
            chunk_ids = [chunk["id"] for chunk in chunks]
            if containing_chunks and all(c in chunk_ids for c in containing_chunks):
                orphan_entities.append(entity)

        # Find orphaned entity types
        orphan_types = []
        seen_types = set()  # Track seen types to avoid duplicates
        for entity in orphan_entities:
            entity_id = entity["id"]  # Already a UUID object
            for _, target, edge_data in self.graph.out_edges(entity_id, data=True):
                if edge_data.get("relationship_name") in ["is_a", "instance_of"]:
                    # Check if this type is only connected to entities we're deleting
                    type_node = self.graph.nodes[target]
                    if type_node.get("type") == "EntityType" and target not in seen_types:
                        is_orphaned = True
                        # Get all incoming edges to this type node
                        for source, _, edge_data in self.graph.in_edges(target, data=True):
                            if edge_data.get("relationship_name") in ["is_a", "instance_of"]:
                                # Check if the source entity is not in our orphan_entities list
                                if source not in [e["id"] for e in orphan_entities]:
                                    is_orphaned = False
                                    break
                        if is_orphaned:
                            orphan_types.append({"id": target, **type_node})  # Keep as UUID object
                            seen_types.add(target)  # Mark as seen

        # Find nodes connected via made_from (chunks point TO summaries)
        made_from_nodes = []
        for chunk in chunks:
            chunk_id = chunk["id"]  # Already a UUID object
            for source, target, edge_data in self.graph.in_edges(chunk_id, data=True):
                if edge_data.get("relationship_name") == "made_from":
                    made_from_nodes.append(
                        {"id": source, **self.graph.nodes[source]}
                    )  # Keep as UUID object

        # Return UUIDs directly without string conversion
        return {
            "document": [{"id": document["id"], **{k: v for k, v in document.items() if k != "id"}}]
            if document
            else [],
            "chunks": [
                {"id": chunk["id"], **{k: v for k, v in chunk.items() if k != "id"}}
                for chunk in chunks
            ],
            "orphan_entities": [
                {"id": entity["id"], **{k: v for k, v in entity.items() if k != "id"}}
                for entity in orphan_entities
            ],
            "made_from_nodes": [
                {"id": node["id"], **{k: v for k, v in node.items() if k != "id"}}
                for node in made_from_nodes
            ],
            "orphan_types": [
                {"id": type_node["id"], **{k: v for k, v in type_node.items() if k != "id"}}
                for type_node in orphan_types
            ],
        }
|
|
|
|
async def get_degree_one_nodes(self, node_type: str):
|
|
"""Get all nodes that have only one connection."""
|
|
if not node_type or node_type not in ["Entity", "EntityType"]:
|
|
raise ValueError("node_type must be either 'Entity' or 'EntityType'")
|
|
|
|
nodes = []
|
|
for node_id, node_data in self.graph.nodes(data=True):
|
|
if node_data.get("type") == node_type:
|
|
# Count both incoming and outgoing edges
|
|
degree = self.graph.degree(node_id)
|
|
if degree == 1:
|
|
nodes.append(node_data)
|
|
return nodes
|
|
|
|
async def get_node(self, node_id: UUID) -> dict:
|
|
if self.graph.has_node(node_id):
|
|
return self.graph.nodes[node_id]
|
|
return None
|
|
|
|
async def get_nodes(self, node_ids: List[UUID] = None) -> List[dict]:
|
|
if node_ids is None:
|
|
return [{"id": node_id, **data} for node_id, data in self.graph.nodes(data=True)]
|
|
return [
|
|
{"id": node_id, **self.graph.nodes[node_id]}
|
|
for node_id in node_ids
|
|
if self.graph.has_node(node_id)
|
|
]
|