diff --git a/cognee/modules/graph/utils/get_model_instance_from_graph.py b/cognee/modules/graph/utils/get_model_instance_from_graph.py index 82cdfa150..bdd0dface 100644 --- a/cognee/modules/graph/utils/get_model_instance_from_graph.py +++ b/cognee/modules/graph/utils/get_model_instance_from_graph.py @@ -2,28 +2,35 @@ from pydantic_core import PydanticUndefined from cognee.infrastructure.engine import DataPoint from cognee.modules.storage.utils import copy_model +def merge_dicts(dict1, dict2, agg_fn): + merged_dict = {} + for key, value in dict1.items(): + if key in dict2: + merged_dict[key] = agg_fn(value, dict2[key]) + else: + merged_dict[key] = value -def get_model_instance_from_graph(nodes: list[DataPoint], edges: list, entity_id: str): - node_map = {} + for key, value in dict2.items(): + if key not in merged_dict: + merged_dict[key] = value + return merged_dict - for node in nodes: - node_map[node.id] = node +def get_model_instance_from_graph(nodes: list[DataPoint], edges: list[tuple[str, str, str, dict[str, str]]], entity_id: str): + node_map = {node.id: node for node in nodes} - for edge in edges: - source_node = node_map[edge[0]] - target_node = node_map[edge[1]] - edge_label = edge[2] - edge_properties = edge[3] if len(edge) == 4 else {} + for source_node_id, target_node_id, edge_label, edge_properties in edges: + source_node = node_map[source_node_id] + target_node = node_map[target_node_id] edge_metadata = edge_properties.get("metadata", {}) edge_type = edge_metadata.get("type") if edge_type == "list": NewModel = copy_model(type(source_node), { edge_label: (list[type(target_node)], PydanticUndefined) }) - - node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: [target_node] }) + new_model_dict = merge_dicts(source_node.model_dump(), { edge_label: [target_node] }, lambda a, b: a + b) + node_map[source_node_id] = NewModel(**new_model_dict) else: NewModel = copy_model(type(source_node), { edge_label: (type(target_node), PydanticUndefined) }) - node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: target_node }) + node_map[target_node_id] = NewModel(**source_node.model_dump(), **{ edge_label: target_node }) return node_map[entity_id]