From afae70f3b59cee1d13eaab152162556fbdd22278 Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Fri, 15 Nov 2024 15:10:42 +0100 Subject: [PATCH] Add get_graph_from_model_generative_test --- .../tests/unit/interfaces/graph/conftest.py | 101 ----------------- .../get_graph_from_model_generative_test.py | 28 +++++ cognee/tests/unit/interfaces/graph/util.py | 102 +++++++++++++++++- 3 files changed, 129 insertions(+), 102 deletions(-) create mode 100644 cognee/tests/unit/interfaces/graph/get_graph_from_model_generative_test.py diff --git a/cognee/tests/unit/interfaces/graph/conftest.py b/cognee/tests/unit/interfaces/graph/conftest.py index 83d64523d..8084fb2bf 100644 --- a/cognee/tests/unit/interfaces/graph/conftest.py +++ b/cognee/tests/unit/interfaces/graph/conftest.py @@ -1,5 +1,3 @@ -import random -import string from enum import Enum from typing import Optional @@ -68,102 +66,3 @@ def boris(): }, ) return boris - - -class Organization(DataPoint): - id: str - name: str - members: Optional[list["SocietyPerson"]] - - -class SocietyPerson(DataPoint): - id: str - name: str - memberships: Optional[list[Organization]] - - -Organization.model_rebuild() -SocietyPerson.model_rebuild() - - -ORGANIZATION_NAMES = [ - "ChessClub", - "RowingClub", - "TheatreTroupe", - "PoliticalParty", - "Charity", - "FanClub", - "FilmClub", - "NeighborhoodGroup", - "LocalCouncil", - "Band", -] -PERSON_NAMES = ["Sarah", "Anna", "John", "Sam"] - - -def create_society_person_recursive(id, name, organization_names, max_depth, depth=0): - if depth < max_depth: - memberships = [ - create_organization_recursive( - org_name, org_name.lower(), PERSON_NAMES, max_depth, depth + 1 - ) - for org_name in organization_names - ] - else: - memberships = None - - id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10)) - return SocietyPerson( - id=f"{id}{depth}-{id_suffix}", name=f"{name}{depth}", memberships=memberships - ) - - -def create_organization_recursive(id, name, member_names, max_depth, depth=0): - if depth < max_depth: - members = [ - create_society_person_recursive( - member_name, - member_name.lower(), - ORGANIZATION_NAMES, - max_depth, - depth + 1, - ) - for member_name in member_names - ] - else: - members = None - - id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10)) - return Organization( - id=f"{id}{depth}-{id_suffix}", name=f"{id}{name}", members=members - ) - - -def count_society(obj): - if isinstance(obj, SocietyPerson): - if obj.memberships is not None: - organization_counts, society_person_counts = zip( - *[count_society(organization) for organization in obj.memberships] - ) - organization_count = sum(organization_counts) - society_person_count = sum(society_person_counts) + 1 - return (organization_count, society_person_count) - else: - return (0, 1) - if isinstance(obj, Organization): - if obj.members is not None: - organization_counts, society_person_counts = zip( - *[count_society(organization) for organization in obj.members] - ) - organization_count = sum(organization_counts) + 1 - society_person_count = sum(society_person_counts) - return (organization_count, society_person_count) - else: - return (1, 0) - else: - return (0, 0) - - -@pytest.fixture(scope="function") -def society(): - society = create_organization_recursive("society", "Society", PERSON_NAMES, 4) diff --git a/cognee/tests/unit/interfaces/graph/get_graph_from_model_generative_test.py b/cognee/tests/unit/interfaces/graph/get_graph_from_model_generative_test.py new file mode 100644 index 000000000..2a5816ac3 --- /dev/null +++ b/cognee/tests/unit/interfaces/graph/get_graph_from_model_generative_test.py @@ -0,0 +1,28 @@ +import pytest + +from cognee.modules.graph.utils import get_graph_from_model +from cognee.tests.unit.interfaces.graph.util import ( + PERSON_NAMES, + count_society, + create_organization_recursive, +) + + +@pytest.mark.parametrize("recursive_depth", [1, 2, 3]) +def test_extracted_car_type(recursive_depth): + society = create_organization_recursive( + "society", "Society", PERSON_NAMES, recursive_depth + ) + + n_organizations, n_persons = count_society(society) + society_counts_total = n_organizations + n_persons + + nodes, edges = get_graph_from_model(society) + + assert ( + len(nodes) == society_counts_total + ), f"{society_counts_total = } != {len(nodes) = }, not all DataPoint instances were found" + + assert len(edges) == ( + len(nodes) - 1 + ), f"{(len(nodes) - 1) = } != {len(edges) = }, there have to be n_nodes - 1 edges, as each node has exactly one parent node, except for the root node" diff --git a/cognee/tests/unit/interfaces/graph/util.py b/cognee/tests/unit/interfaces/graph/util.py index e5da0201c..3bdd55fe7 100644 --- a/cognee/tests/unit/interfaces/graph/util.py +++ b/cognee/tests/unit/interfaces/graph/util.py @@ -1,5 +1,9 @@ +import random +import string from datetime import datetime, timezone -from typing import Any, Dict +from typing import Any, Dict, Optional + +from cognee.infrastructure.engine import DataPoint def run_test_against_ground_truth( @@ -30,3 +34,99 @@ def run_test_against_ground_truth( time_delta = datetime.now(timezone.utc) - getattr(test_target_item, "updated_at") assert time_delta.total_seconds() < 60, f"{ time_delta.total_seconds() = }" + + +class Organization(DataPoint): + id: str + name: str + members: Optional[list["SocietyPerson"]] + + +class SocietyPerson(DataPoint): + id: str + name: str + memberships: Optional[list[Organization]] + + +Organization.model_rebuild() +SocietyPerson.model_rebuild() + + +ORGANIZATION_NAMES = [ + "ChessClub", + "RowingClub", + "TheatreTroupe", + "PoliticalParty", + "Charity", + "FanClub", + "FilmClub", + "NeighborhoodGroup", + "LocalCouncil", + "Band", +] +PERSON_NAMES = ["Sarah", "Anna", "John", "Sam"] + + +def create_society_person_recursive(id, name, organization_names, max_depth, depth=0): + id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10)) + + if depth < max_depth: + memberships = [ + create_organization_recursive( + f"{org_name}-{depth}-{id_suffix}", + org_name.lower(), + PERSON_NAMES, + max_depth, + depth + 1, + ) + for org_name in organization_names + ] + else: + memberships = None + + return SocietyPerson(id=id, name=f"{name}{depth}", memberships=memberships) + + +def create_organization_recursive(id, name, member_names, max_depth, depth=0): + id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10)) + + if depth < max_depth: + members = [ + create_society_person_recursive( + f"{member_name}-{depth}-{id_suffix}", + member_name.lower(), + ORGANIZATION_NAMES, + max_depth, + depth + 1, + ) + for member_name in member_names + ] + else: + members = None + + return Organization(id=id, name=f"{name}{depth}", members=members) + + +def count_society(obj): + if isinstance(obj, SocietyPerson): + if obj.memberships is not None: + organization_counts, society_person_counts = zip( + *[count_society(organization) for organization in obj.memberships] + ) + organization_count = sum(organization_counts) + society_person_count = sum(society_person_counts) + 1 + return (organization_count, society_person_count) + else: + return (0, 1) + if isinstance(obj, Organization): + if obj.members is not None: + organization_counts, society_person_counts = zip( + *[count_society(organization) for organization in obj.members] + ) + organization_count = sum(organization_counts) + 1 + society_person_count = sum(society_person_counts) + return (organization_count, society_person_count) + else: + return (1, 0) + else: + raise Exception("Not allowed")