From a3342918d971d3c9a3d46dd923d0c0a5d166a4b4 Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Fri, 15 Nov 2024 16:53:32 +0100 Subject: [PATCH] Apply cosmetic changes and autoformat --- .../graph/utils/get_graph_from_model.py | 18 ++++++---- .../utils/get_model_instance_from_graph.py | 34 +++++++++++++++---- cognee/tests/unit/interfaces/graph/util.py | 1 - 3 files changed, 39 insertions(+), 14 deletions(-) diff --git a/cognee/modules/graph/utils/get_graph_from_model.py b/cognee/modules/graph/utils/get_graph_from_model.py index bd6480ba0..770e63d05 100644 --- a/cognee/modules/graph/utils/get_graph_from_model.py +++ b/cognee/modules/graph/utils/get_graph_from_model.py @@ -1,11 +1,10 @@ from datetime import datetime, timezone + from cognee.infrastructure.engine import DataPoint from cognee.modules.storage.utils import copy_model -def get_graph_from_model( - data_point: DataPoint, added_nodes=None, added_edges=None -): +def get_graph_from_model(data_point: DataPoint, added_nodes=None, added_edges=None): if not added_nodes: added_nodes = {} @@ -24,7 +23,13 @@ def get_graph_from_model( elif isinstance(field_value, DataPoint): excluded_properties.add(field_name) nodes, edges, added_nodes, added_edges = add_nodes_and_edges( - data_point, field_name, field_value, nodes, edges, added_nodes, added_edges + data_point, + field_name, + field_value, + nodes, + edges, + added_nodes, + added_edges, ) elif ( @@ -35,12 +40,13 @@ def get_graph_from_model( excluded_properties.add(field_name) for item in field_value: + n_edges_before = len(edges) nodes, edges, added_nodes, added_edges = add_nodes_and_edges( data_point, field_name, item, nodes, edges, added_nodes, added_edges ) - edges = [ + edges = edges[:n_edges_before] + [ (*edge[:3], {**edge[3], "metadata": {"type": "list"}}) - for edge in edges + for edge in edges[n_edges_before:] ] else: data_point_properties[field_name] = field_value diff --git a/cognee/modules/graph/utils/get_model_instance_from_graph.py b/cognee/modules/graph/utils/get_model_instance_from_graph.py index bdd0dface..87146111c 100644 --- a/cognee/modules/graph/utils/get_model_instance_from_graph.py +++ b/cognee/modules/graph/utils/get_model_instance_from_graph.py @@ -1,8 +1,12 @@ +from typing import Callable + from pydantic_core import PydanticUndefined + from cognee.infrastructure.engine import DataPoint from cognee.modules.storage.utils import copy_model -def merge_dicts(dict1, dict2, agg_fn): + +def merge_dicts(dict1: dict, dict2: dict, agg_fn: Callable) -> dict: merged_dict = {} for key, value in dict1.items(): if key in dict2: @@ -15,22 +19,38 @@ def merge_dicts(dict1, dict2, agg_fn): merged_dict[key] = value return merged_dict -def get_model_instance_from_graph(nodes: list[DataPoint], edges: list[tuple[str, str, str, dict[str, str]]], entity_id: str): + +def get_model_instance_from_graph( + nodes: list[DataPoint], + edges: list[tuple[str, str, str, dict[str, str]]], + entity_id: str, +): node_map = {node.id: node for node in nodes} for source_node_id, target_node_id, edge_label, edge_properties in edges: source_node = node_map[source_node_id] target_node = node_map[target_node_id] edge_metadata = edge_properties.get("metadata", {}) - edge_type = edge_metadata.get("type") + edge_type = edge_metadata.get("type", "default") if edge_type == "list": - NewModel = copy_model(type(source_node), { edge_label: (list[type(target_node)], PydanticUndefined) }) - new_model_dict = merge_dicts(source_node.model_dump(), { edge_label: [target_node] }, lambda a, b: a + b) + NewModel = copy_model( + type(source_node), + {edge_label: (list[type(target_node)], PydanticUndefined)}, + ) + new_model_dict = merge_dicts( + source_node.model_dump(), + {edge_label: [target_node]}, + lambda a, b: a + b, + ) node_map[source_node_id] = NewModel(**new_model_dict) else: - NewModel = copy_model(type(source_node), { edge_label: (type(target_node), PydanticUndefined) }) + NewModel = copy_model( + type(source_node), {edge_label: (type(target_node), PydanticUndefined)} + ) - node_map[target_node_id] = NewModel(**source_node.model_dump(), **{ edge_label: target_node }) + node_map[target_node_id] = NewModel( + **source_node.model_dump(), **{edge_label: target_node} + ) return node_map[entity_id] diff --git a/cognee/tests/unit/interfaces/graph/util.py b/cognee/tests/unit/interfaces/graph/util.py index c06023cc2..c8909d40d 100644 --- a/cognee/tests/unit/interfaces/graph/util.py +++ b/cognee/tests/unit/interfaces/graph/util.py @@ -132,7 +132,6 @@ def count_society(obj): def show_first_difference(str1, str2, str1_name, str2_name, context=30): - """Shows where two strings first diverge, with surrounding context.""" for i, (c1, c2) in enumerate(zip(str1, str2)): if c1 != c2: start = max(0, i - context)