From 370b59b39a70086b7a341e63340a8cc1cb23610e Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Fri, 15 Nov 2024 15:58:03 +0100 Subject: [PATCH] Add get_graph_from_model_generative_test --- .../get_graph_from_model_generative_test.py | 2 +- ...del_instance_from_graph_generative_test.py | 24 +++++++++++++++++++ cognee/tests/unit/interfaces/graph/util.py | 19 +++++++++++++++ 3 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 cognee/tests/unit/interfaces/graph/get_model_instance_from_graph_generative_test.py diff --git a/cognee/tests/unit/interfaces/graph/get_graph_from_model_generative_test.py b/cognee/tests/unit/interfaces/graph/get_graph_from_model_generative_test.py index 2a5816ac3..dee4f5042 100644 --- a/cognee/tests/unit/interfaces/graph/get_graph_from_model_generative_test.py +++ b/cognee/tests/unit/interfaces/graph/get_graph_from_model_generative_test.py @@ -9,7 +9,7 @@ from cognee.tests.unit.interfaces.graph.util import ( @pytest.mark.parametrize("recursive_depth", [1, 2, 3]) -def test_extracted_car_type(recursive_depth): +def test_society_nodes_and_edges(recursive_depth): society = create_organization_recursive( "society", "Society", PERSON_NAMES, recursive_depth ) diff --git a/cognee/tests/unit/interfaces/graph/get_model_instance_from_graph_generative_test.py b/cognee/tests/unit/interfaces/graph/get_model_instance_from_graph_generative_test.py new file mode 100644 index 000000000..10578216b --- /dev/null +++ b/cognee/tests/unit/interfaces/graph/get_model_instance_from_graph_generative_test.py @@ -0,0 +1,24 @@ +import pytest + +from cognee.modules.graph.utils import ( + get_graph_from_model, + get_model_instance_from_graph, +) +from cognee.tests.unit.interfaces.graph.util import ( + PERSON_NAMES, + create_organization_recursive, + show_first_difference, +) + + +@pytest.mark.parametrize("recursive_depth", [1, 2, 3]) +def test_society_nodes_and_edges(recursive_depth): + society = create_organization_recursive( + "society", "Society", PERSON_NAMES, recursive_depth + ) + nodes, edges = get_graph_from_model(society) + parsed_society = get_model_instance_from_graph(nodes, edges, "society") + + assert str(society) == (str(parsed_society)), show_first_difference( + str(society), str(parsed_society), "society", "parsed_society" + ) diff --git a/cognee/tests/unit/interfaces/graph/util.py b/cognee/tests/unit/interfaces/graph/util.py index 3bdd55fe7..4a60c94fa 100644 --- a/cognee/tests/unit/interfaces/graph/util.py +++ b/cognee/tests/unit/interfaces/graph/util.py @@ -130,3 +130,22 @@ def count_society(obj): return (1, 0) else: raise Exception("Not allowed") + + +def show_first_difference(str1, str2, str1_name, str2_name, context=30): + """Shows where two strings first diverge, with surrounding context.""" + for i, (c1, c2) in enumerate(zip(str1, str2)): + if c1 != c2: + start = max(0, i - context) + end1 = min(len(str1), i + context + 1) + end2 = min(len(str2), i + context + 1) + if i > 0: + return f"identical: '{str1[start:i-1]}' | {str1_name}: '{str1[i-1:end1]}'... != {str2_name}: '{str2[i-1:end2]}'..." + else: + return f"{str1_name} and {str2_name} have no overlap in characters" + if len(str1) > len(str2): + return f"{str2_name} is identical up to the {i}th character, missing afterwards '{str1[i:i+context]}'..." + if len(str2) > len(str1): + return f"{str1_name} is identical up to the {i}th character, missing afterwards '{str2[i:i+context]}'..." + else: + return f"{str1_name} and {str2_name} are identical."