From 7be613e2fc01b0f77f227763280a954a67608f9a Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Fri, 15 Nov 2024 11:57:26 +0100 Subject: [PATCH] WIP nested pydantic structures --- .../tests/unit/interfaces/graph/conftest.py | 105 +++++++++++++++++- 1 file changed, 103 insertions(+), 2 deletions(-) diff --git a/cognee/tests/unit/interfaces/graph/conftest.py b/cognee/tests/unit/interfaces/graph/conftest.py index 45f977bd6..83d64523d 100644 --- a/cognee/tests/unit/interfaces/graph/conftest.py +++ b/cognee/tests/unit/interfaces/graph/conftest.py @@ -1,3 +1,5 @@ +import random +import string from enum import Enum from typing import Optional @@ -36,8 +38,8 @@ class Car(DataPoint): class Person(DataPoint): id: str name: str - age: int - owns_car: list[Car] + age: Optional[int] + owns_car: Optional[list[Car]] driving_license: Optional[dict] _metadata: dict = dict(index_fields=["name"]) @@ -66,3 +68,102 @@ def boris(): }, ) return boris + + +class Organization(DataPoint): + id: str + name: str + members: Optional[list["SocietyPerson"]] + + +class SocietyPerson(DataPoint): + id: str + name: str + memberships: Optional[list[Organization]] + + +Organization.model_rebuild() +SocietyPerson.model_rebuild() + + +ORGANIZATION_NAMES = [ + "ChessClub", + "RowingClub", + "TheatreTroupe", + "PoliticalParty", + "Charity", + "FanClub", + "FilmClub", + "NeighborhoodGroup", + "LocalCouncil", + "Band", +] +PERSON_NAMES = ["Sarah", "Anna", "John", "Sam"] + + +def create_society_person_recursive(id, name, organization_names, max_depth, depth=0): + if depth < max_depth: + memberships = [ + create_organization_recursive( + org_name, org_name.lower(), PERSON_NAMES, max_depth, depth + 1 + ) + for org_name in organization_names + ] + else: + memberships = None + + id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10)) + return SocietyPerson( + id=f"{id}{depth}-{id_suffix}", name=f"{name}{depth}", memberships=memberships + ) + + +def create_organization_recursive(id, name, member_names, max_depth, depth=0): + if depth < max_depth: + members = [ + create_society_person_recursive( + member_name, + member_name.lower(), + ORGANIZATION_NAMES, + max_depth, + depth + 1, + ) + for member_name in member_names + ] + else: + members = None + + id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10)) + return Organization( + id=f"{id}{depth}-{id_suffix}", name=f"{id}{name}", members=members + ) + + +def count_society(obj): + if isinstance(obj, SocietyPerson): + if obj.memberships is not None: + organization_counts, society_person_counts = zip( + *[count_society(organization) for organization in obj.memberships] + ) + organization_count = sum(organization_counts) + society_person_count = sum(society_person_counts) + 1 + return (organization_count, society_person_count) + else: + return (0, 1) + if isinstance(obj, Organization): + if obj.members is not None: + organization_counts, society_person_counts = zip( + *[count_society(organization) for organization in obj.members] + ) + organization_count = sum(organization_counts) + 1 + society_person_count = sum(society_person_counts) + return (organization_count, society_person_count) + else: + return (1, 0) + else: + return (0, 0) + + +@pytest.fixture(scope="function") +def society(): + society = create_organization_recursive("society", "Society", PERSON_NAMES, 4)