diff --git a/cognee/modules/graph/utils/get_graph_from_model.py b/cognee/modules/graph/utils/get_graph_from_model.py index 7b05d2046..d1e14c878 100644 --- a/cognee/modules/graph/utils/get_graph_from_model.py +++ b/cognee/modules/graph/utils/get_graph_from_model.py @@ -2,9 +2,70 @@ from datetime import datetime, timezone from cognee.infrastructure.engine import DataPoint from cognee.modules.storage.utils import copy_model -def add_nodes_and_edges(data_point, field_name, field_value, nodes, edges, added_nodes, added_edges): - property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges) +def get_graph_from_model( + data_point: DataPoint, include_root=True, added_nodes=None, added_edges=None +): + + if not added_nodes: + added_nodes = {} + if not added_edges: + added_edges = {} + + nodes = [] + edges = [] + + data_point_properties = {} + excluded_properties = set() + + for field_name, field_value in data_point: + if field_name == "_metadata": + continue + elif isinstance(field_value, DataPoint): + excluded_properties.add(field_name) + nodes, edges, added_nodes, added_edges = add_nodes_and_edges( + data_point, field_name, field_value, nodes, edges, added_nodes, added_edges + ) + + elif ( + isinstance(field_value, list) + and len(field_value) > 0 + and isinstance(field_value[0], DataPoint) + ): + excluded_properties.add(field_name) + + for item in field_value: + nodes, edges, added_nodes, added_edges = add_nodes_and_edges( + data_point, field_name, item, nodes, edges, added_nodes, added_edges + ) + edges = [ + (*edge[:3], {**edge[3], "metadata": {"type": "list"}}) + for edge in edges + ] + else: + data_point_properties[field_name] = field_value + + SimpleDataPointModel = copy_model( + type(data_point), + include_fields={ + "_metadata": (dict, data_point._metadata), + }, + exclude_fields=excluded_properties, + ) + + if include_root: + nodes.append(SimpleDataPointModel(**data_point_properties)) + + return nodes, edges + + +def add_nodes_and_edges( + data_point, field_name, field_value, nodes, edges, added_nodes, added_edges +): + + property_nodes, property_edges = get_graph_from_model( + field_value, True, added_nodes, added_edges + ) for node in property_nodes: if str(node.id) not in added_nodes: @@ -22,63 +83,24 @@ def add_nodes_and_edges(data_point, field_name, field_value, nodes, edges, added edge_key = str(data_point.id) + str(property_node.id) + field_name if str(edge_key) not in added_edges: - edges.append((data_point.id, property_node.id, field_name, { - "source_node_id": data_point.id, - "target_node_id": property_node.id, - "relationship_name": field_name, - "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), - })) + edges.append( + ( + data_point.id, + property_node.id, + field_name, + { + "source_node_id": data_point.id, + "target_node_id": property_node.id, + "relationship_name": field_name, + "updated_at": datetime.now(timezone.utc).strftime( + "%Y-%m-%d %H:%M:%S" + ), + }, + ) + ) added_edges[str(edge_key)] = True - - return(nodes, edges, added_nodes, added_edges) - -def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = None, added_edges = None): - - if not added_nodes: - added_nodes = {} - if not added_edges: - added_edges = {} - - nodes = [] - edges = [] - - data_point_properties = {} - excluded_properties = set() - - for field_name, field_value in data_point: - if field_name == "_metadata": - continue - - if isinstance(field_value, DataPoint): - excluded_properties.add(field_name) - - nodes, edges, added_nodes, added_edges = add_nodes_and_edges(data_point, field_name, field_value, nodes, edges, added_nodes, added_edges) - - - if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint): - excluded_properties.add(field_name) - - for item in field_value: - nodes, edges, added_nodes, added_edges = add_nodes_and_edges(data_point, field_name, item, nodes, edges, added_nodes, added_edges) - edges = [(*edge[:3],{**edge[3], "metadata": {"type": "list"}}) for edge in edges] - - continue - - data_point_properties[field_name] = field_value - - SimpleDataPointModel = copy_model( - type(data_point), - include_fields = { - "_metadata": (dict, data_point._metadata), - }, - exclude_fields = excluded_properties, - ) - - if include_root: - nodes.append(SimpleDataPointModel(**data_point_properties)) - - return nodes, edges + return (nodes, edges, added_nodes, added_edges) def get_own_properties(property_nodes, property_edges):