WIP nested pydantic structures
This commit is contained in:
parent
0ea011ccd7
commit
7be613e2fc
1 changed files with 103 additions and 2 deletions
|
|
@ -1,3 +1,5 @@
|
|||
import random
|
||||
import string
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -36,8 +38,8 @@ class Car(DataPoint):
|
|||
class Person(DataPoint):
|
||||
id: str
|
||||
name: str
|
||||
age: int
|
||||
owns_car: list[Car]
|
||||
age: Optional[int]
|
||||
owns_car: Optional[list[Car]]
|
||||
driving_license: Optional[dict]
|
||||
_metadata: dict = dict(index_fields=["name"])
|
||||
|
||||
|
|
@ -66,3 +68,102 @@ def boris():
|
|||
},
|
||||
)
|
||||
return boris
|
||||
|
||||
|
||||
class Organization(DataPoint):
|
||||
id: str
|
||||
name: str
|
||||
members: Optional[list["SocietyPerson"]]
|
||||
|
||||
|
||||
class SocietyPerson(DataPoint):
|
||||
id: str
|
||||
name: str
|
||||
memberships: Optional[list[Organization]]
|
||||
|
||||
|
||||
Organization.model_rebuild()
|
||||
SocietyPerson.model_rebuild()
|
||||
|
||||
|
||||
ORGANIZATION_NAMES = [
|
||||
"ChessClub",
|
||||
"RowingClub",
|
||||
"TheatreTroupe",
|
||||
"PoliticalParty",
|
||||
"Charity",
|
||||
"FanClub",
|
||||
"FilmClub",
|
||||
"NeighborhoodGroup",
|
||||
"LocalCouncil",
|
||||
"Band",
|
||||
]
|
||||
PERSON_NAMES = ["Sarah", "Anna", "John", "Sam"]
|
||||
|
||||
|
||||
def create_society_person_recursive(id, name, organization_names, max_depth, depth=0):
|
||||
if depth < max_depth:
|
||||
memberships = [
|
||||
create_organization_recursive(
|
||||
org_name, org_name.lower(), PERSON_NAMES, max_depth, depth + 1
|
||||
)
|
||||
for org_name in organization_names
|
||||
]
|
||||
else:
|
||||
memberships = None
|
||||
|
||||
id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
|
||||
return SocietyPerson(
|
||||
id=f"{id}{depth}-{id_suffix}", name=f"{name}{depth}", memberships=memberships
|
||||
)
|
||||
|
||||
|
||||
def create_organization_recursive(id, name, member_names, max_depth, depth=0):
|
||||
if depth < max_depth:
|
||||
members = [
|
||||
create_society_person_recursive(
|
||||
member_name,
|
||||
member_name.lower(),
|
||||
ORGANIZATION_NAMES,
|
||||
max_depth,
|
||||
depth + 1,
|
||||
)
|
||||
for member_name in member_names
|
||||
]
|
||||
else:
|
||||
members = None
|
||||
|
||||
id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
|
||||
return Organization(
|
||||
id=f"{id}{depth}-{id_suffix}", name=f"{id}{name}", members=members
|
||||
)
|
||||
|
||||
|
||||
def count_society(obj):
|
||||
if isinstance(obj, SocietyPerson):
|
||||
if obj.memberships is not None:
|
||||
organization_counts, society_person_counts = zip(
|
||||
*[count_society(organization) for organization in obj.memberships]
|
||||
)
|
||||
organization_count = sum(organization_counts)
|
||||
society_person_count = sum(society_person_counts) + 1
|
||||
return (organization_count, society_person_count)
|
||||
else:
|
||||
return (0, 1)
|
||||
if isinstance(obj, Organization):
|
||||
if obj.members is not None:
|
||||
organization_counts, society_person_counts = zip(
|
||||
*[count_society(organization) for organization in obj.members]
|
||||
)
|
||||
organization_count = sum(organization_counts) + 1
|
||||
society_person_count = sum(society_person_counts)
|
||||
return (organization_count, society_person_count)
|
||||
else:
|
||||
return (1, 0)
|
||||
else:
|
||||
return (0, 0)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def society():
|
||||
society = create_organization_recursive("society", "Society", PERSON_NAMES, 4)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue