Merge dicts directly

This commit is contained in:
Leon Luithlen 2024-11-19 10:56:21 +01:00
parent fde56f0c3b
commit b18f748c9e

View file

@ -6,20 +6,6 @@ from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import copy_model
def merge_dicts(dict1: dict, dict2: dict, agg_fn: Callable) -> dict:
merged_dict = {}
for key, value in dict1.items():
if key in dict2:
merged_dict[key] = agg_fn(value, dict2[key])
else:
merged_dict[key] = value
for key, value in dict2.items():
if key not in merged_dict:
merged_dict[key] = value
return merged_dict
def get_model_instance_from_graph(
nodes: list[DataPoint],
edges: list[tuple[str, str, str, dict[str, str]]],
@ -38,12 +24,11 @@ def get_model_instance_from_graph(
type(source_node),
{edge_label: (list[type(target_node)], PydanticUndefined)},
)
new_model_dict = merge_dicts(
source_node.model_dump(),
{edge_label: [target_node]},
lambda a, b: a + b,
)
node_map[source_node_id] = NewModel(**new_model_dict)
source_node_dict = source_node.model_dump()
source_node_edge_label_values = source_node_dict.get(edge_label, [])
source_node_dict[edge_label] = source_node_edge_label_values + [target_node]
node_map[source_node_id] = NewModel(**source_node_dict)
else:
NewModel = copy_model(
type(source_node), {edge_label: (type(target_node), PydanticUndefined)}