Add get_graph_from_model_generative_test

This commit is contained in:
Leon Luithlen 2024-11-15 15:10:42 +01:00
parent 3c8a52f4b0
commit afae70f3b5
3 changed files with 129 additions and 102 deletions

View file

@ -1,5 +1,3 @@
import random
import string
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional
@ -68,102 +66,3 @@ def boris():
}, },
) )
return boris return boris
class Organization(DataPoint):
id: str
name: str
members: Optional[list["SocietyPerson"]]
class SocietyPerson(DataPoint):
id: str
name: str
memberships: Optional[list[Organization]]
Organization.model_rebuild()
SocietyPerson.model_rebuild()
ORGANIZATION_NAMES = [
"ChessClub",
"RowingClub",
"TheatreTroupe",
"PoliticalParty",
"Charity",
"FanClub",
"FilmClub",
"NeighborhoodGroup",
"LocalCouncil",
"Band",
]
PERSON_NAMES = ["Sarah", "Anna", "John", "Sam"]
def create_society_person_recursive(id, name, organization_names, max_depth, depth=0):
if depth < max_depth:
memberships = [
create_organization_recursive(
org_name, org_name.lower(), PERSON_NAMES, max_depth, depth + 1
)
for org_name in organization_names
]
else:
memberships = None
id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
return SocietyPerson(
id=f"{id}{depth}-{id_suffix}", name=f"{name}{depth}", memberships=memberships
)
def create_organization_recursive(id, name, member_names, max_depth, depth=0):
if depth < max_depth:
members = [
create_society_person_recursive(
member_name,
member_name.lower(),
ORGANIZATION_NAMES,
max_depth,
depth + 1,
)
for member_name in member_names
]
else:
members = None
id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
return Organization(
id=f"{id}{depth}-{id_suffix}", name=f"{id}{name}", members=members
)
def count_society(obj):
if isinstance(obj, SocietyPerson):
if obj.memberships is not None:
organization_counts, society_person_counts = zip(
*[count_society(organization) for organization in obj.memberships]
)
organization_count = sum(organization_counts)
society_person_count = sum(society_person_counts) + 1
return (organization_count, society_person_count)
else:
return (0, 1)
if isinstance(obj, Organization):
if obj.members is not None:
organization_counts, society_person_counts = zip(
*[count_society(organization) for organization in obj.members]
)
organization_count = sum(organization_counts) + 1
society_person_count = sum(society_person_counts)
return (organization_count, society_person_count)
else:
return (1, 0)
else:
return (0, 0)
@pytest.fixture(scope="function")
def society():
society = create_organization_recursive("society", "Society", PERSON_NAMES, 4)

View file

@ -0,0 +1,28 @@
import pytest
from cognee.modules.graph.utils import get_graph_from_model
from cognee.tests.unit.interfaces.graph.util import (
PERSON_NAMES,
count_society,
create_organization_recursive,
)
@pytest.mark.parametrize("recursive_depth", [1, 2, 3])
def test_extracted_car_type(recursive_depth):
society = create_organization_recursive(
"society", "Society", PERSON_NAMES, recursive_depth
)
n_organizations, n_persons = count_society(society)
society_counts_total = n_organizations + n_persons
nodes, edges = get_graph_from_model(society)
assert (
len(nodes) == society_counts_total
), f"{society_counts_total = } != {len(nodes) = }, not all DataPoint instances were found"
assert len(edges) == (
len(nodes) - 1
), f"{(len(nodes) - 1) = } != {len(edges) = }, there have to be n_nodes - 1 edges, as each node has exactly one parent node, except for the root node"

View file

@ -1,5 +1,9 @@
import random
import string
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, Dict from typing import Any, Dict, Optional
from cognee.infrastructure.engine import DataPoint
def run_test_against_ground_truth( def run_test_against_ground_truth(
@ -30,3 +34,99 @@ def run_test_against_ground_truth(
time_delta = datetime.now(timezone.utc) - getattr(test_target_item, "updated_at") time_delta = datetime.now(timezone.utc) - getattr(test_target_item, "updated_at")
assert time_delta.total_seconds() < 60, f"{ time_delta.total_seconds() = }" assert time_delta.total_seconds() < 60, f"{ time_delta.total_seconds() = }"
class Organization(DataPoint):
id: str
name: str
members: Optional[list["SocietyPerson"]]
class SocietyPerson(DataPoint):
id: str
name: str
memberships: Optional[list[Organization]]
Organization.model_rebuild()
SocietyPerson.model_rebuild()
ORGANIZATION_NAMES = [
"ChessClub",
"RowingClub",
"TheatreTroupe",
"PoliticalParty",
"Charity",
"FanClub",
"FilmClub",
"NeighborhoodGroup",
"LocalCouncil",
"Band",
]
PERSON_NAMES = ["Sarah", "Anna", "John", "Sam"]
def create_society_person_recursive(id, name, organization_names, max_depth, depth=0):
id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
if depth < max_depth:
memberships = [
create_organization_recursive(
f"{org_name}-{depth}-{id_suffix}",
org_name.lower(),
PERSON_NAMES,
max_depth,
depth + 1,
)
for org_name in organization_names
]
else:
memberships = None
return SocietyPerson(id=id, name=f"{name}{depth}", memberships=memberships)
def create_organization_recursive(id, name, member_names, max_depth, depth=0):
id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
if depth < max_depth:
members = [
create_society_person_recursive(
f"{member_name}-{depth}-{id_suffix}",
member_name.lower(),
ORGANIZATION_NAMES,
max_depth,
depth + 1,
)
for member_name in member_names
]
else:
members = None
return Organization(id=id, name=f"{name}{depth}", members=members)
def count_society(obj):
if isinstance(obj, SocietyPerson):
if obj.memberships is not None:
organization_counts, society_person_counts = zip(
*[count_society(organization) for organization in obj.memberships]
)
organization_count = sum(organization_counts)
society_person_count = sum(society_person_counts) + 1
return (organization_count, society_person_count)
else:
return (0, 1)
if isinstance(obj, Organization):
if obj.members is not None:
organization_counts, society_person_counts = zip(
*[count_society(organization) for organization in obj.members]
)
organization_count = sum(organization_counts) + 1
society_person_count = sum(society_person_counts)
return (organization_count, society_person_count)
else:
return (1, 0)
else:
raise Exception("Not allowed")