WIP nested pydantic structures

This commit is contained in:
Leon Luithlen 2024-11-15 11:57:26 +01:00
parent 0ea011ccd7
commit 7be613e2fc

View file

@ -1,3 +1,5 @@
import random
import string
from enum import Enum
from typing import Optional
@ -36,8 +38,8 @@ class Car(DataPoint):
class Person(DataPoint):
id: str
name: str
age: int
owns_car: list[Car]
age: Optional[int]
owns_car: Optional[list[Car]]
driving_license: Optional[dict]
_metadata: dict = dict(index_fields=["name"])
@ -66,3 +68,102 @@ def boris():
},
)
return boris
class Organization(DataPoint):
id: str
name: str
members: Optional[list["SocietyPerson"]]
class SocietyPerson(DataPoint):
id: str
name: str
memberships: Optional[list[Organization]]
Organization.model_rebuild()
SocietyPerson.model_rebuild()
ORGANIZATION_NAMES = [
"ChessClub",
"RowingClub",
"TheatreTroupe",
"PoliticalParty",
"Charity",
"FanClub",
"FilmClub",
"NeighborhoodGroup",
"LocalCouncil",
"Band",
]
PERSON_NAMES = ["Sarah", "Anna", "John", "Sam"]
def create_society_person_recursive(id, name, organization_names, max_depth, depth=0):
if depth < max_depth:
memberships = [
create_organization_recursive(
org_name, org_name.lower(), PERSON_NAMES, max_depth, depth + 1
)
for org_name in organization_names
]
else:
memberships = None
id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
return SocietyPerson(
id=f"{id}{depth}-{id_suffix}", name=f"{name}{depth}", memberships=memberships
)
def create_organization_recursive(id, name, member_names, max_depth, depth=0):
if depth < max_depth:
members = [
create_society_person_recursive(
member_name,
member_name.lower(),
ORGANIZATION_NAMES,
max_depth,
depth + 1,
)
for member_name in member_names
]
else:
members = None
id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
return Organization(
id=f"{id}{depth}-{id_suffix}", name=f"{id}{name}", members=members
)
def count_society(obj):
if isinstance(obj, SocietyPerson):
if obj.memberships is not None:
organization_counts, society_person_counts = zip(
*[count_society(organization) for organization in obj.memberships]
)
organization_count = sum(organization_counts)
society_person_count = sum(society_person_counts) + 1
return (organization_count, society_person_count)
else:
return (0, 1)
if isinstance(obj, Organization):
if obj.members is not None:
organization_counts, society_person_counts = zip(
*[count_society(organization) for organization in obj.members]
)
organization_count = sum(organization_counts) + 1
society_person_count = sum(society_person_counts)
return (organization_count, society_person_count)
else:
return (1, 0)
else:
return (0, 0)
@pytest.fixture(scope="function")
def society():
society = create_organization_recursive("society", "Society", PERSON_NAMES, 4)