From 2c0fce32d33a0684545325ad7c56791d62e0861b Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Fri, 15 Nov 2024 13:38:33 +0100 Subject: [PATCH] WIP get_graph_from_model --- .../graph/utils/get_graph_from_model.py | 86 ++++++++----------- 1 file changed, 34 insertions(+), 52 deletions(-) diff --git a/cognee/modules/graph/utils/get_graph_from_model.py b/cognee/modules/graph/utils/get_graph_from_model.py index 810be7ce8..7b05d2046 100644 --- a/cognee/modules/graph/utils/get_graph_from_model.py +++ b/cognee/modules/graph/utils/get_graph_from_model.py @@ -2,6 +2,37 @@ from datetime import datetime, timezone from cognee.infrastructure.engine import DataPoint from cognee.modules.storage.utils import copy_model +def add_nodes_and_edges(data_point, field_name, field_value, nodes, edges, added_nodes, added_edges): + + property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges) + + for node in property_nodes: + if str(node.id) not in added_nodes: + nodes.append(node) + added_nodes[str(node.id)] = True + + for edge in property_edges: + edge_key = str(edge[0]) + str(edge[1]) + edge[2] + + if str(edge_key) not in added_edges: + edges.append(edge) + added_edges[str(edge_key)] = True + + for property_node in get_own_properties(property_nodes, property_edges): + edge_key = str(data_point.id) + str(property_node.id) + field_name + + if str(edge_key) not in added_edges: + edges.append((data_point.id, property_node.id, field_name, { + "source_node_id": data_point.id, + "target_node_id": property_node.id, + "relationship_name": field_name, + "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), + })) + added_edges[str(edge_key)] = True + + return(nodes, edges, added_nodes, added_edges) + + def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = None, added_edges = None): if not added_nodes: @@ -22,65 +53,16 @@ def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes if isinstance(field_value, DataPoint): excluded_properties.add(field_name) - property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges) + nodes, edges, added_nodes, added_edges = add_nodes_and_edges(data_point, field_name, field_value, nodes, edges, added_nodes, added_edges) - for node in property_nodes: - if str(node.id) not in added_nodes: - nodes.append(node) - added_nodes[str(node.id)] = True - - for edge in property_edges: - edge_key = str(edge[0]) + str(edge[1]) + edge[2] - - if str(edge_key) not in added_edges: - edges.append(edge) - added_edges[str(edge_key)] = True - - for property_node in get_own_properties(property_nodes, property_edges): - edge_key = str(data_point.id) + str(property_node.id) + field_name - - if str(edge_key) not in added_edges: - edges.append((data_point.id, property_node.id, field_name, { - "source_node_id": data_point.id, - "target_node_id": property_node.id, - "relationship_name": field_name, - "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), - })) - added_edges[str(edge_key)] = True - continue if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint): excluded_properties.add(field_name) for item in field_value: - property_nodes, property_edges = get_graph_from_model(item, True, added_nodes, added_edges) + nodes, edges, added_nodes, added_edges = add_nodes_and_edges(data_point, field_name, item, nodes, edges, added_nodes, added_edges) + edges = [(*edge[:3],{**edge[3], "metadata": {"type": "list"}}) for edge in edges] - for node in property_nodes: - if str(node.id) not in added_nodes: - nodes.append(node) - added_nodes[str(node.id)] = True - - for edge in property_edges: - edge_key = str(edge[0]) + str(edge[1]) + edge[2] - - if str(edge_key) not in added_edges: - edges.append(edge) - added_edges[edge_key] = True - - for property_node in get_own_properties(property_nodes, property_edges): - edge_key = str(data_point.id) + str(property_node.id) + field_name - - if str(edge_key) not in added_edges: - edges.append((data_point.id, property_node.id, field_name, { - "source_node_id": data_point.id, - "target_node_id": property_node.id, - "relationship_name": field_name, - "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), - "metadata": { - "type": "list" - }, - })) - added_edges[edge_key] = True continue data_point_properties[field_name] = field_value