Add get_graph_from_model_generative_test
This commit is contained in:
parent
3c8a52f4b0
commit
afae70f3b5
3 changed files with 129 additions and 102 deletions
|
|
@ -1,5 +1,3 @@
|
||||||
import random
|
|
||||||
import string
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
@ -68,102 +66,3 @@ def boris():
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return boris
|
return boris
|
||||||
|
|
||||||
|
|
||||||
class Organization(DataPoint):
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
members: Optional[list["SocietyPerson"]]
|
|
||||||
|
|
||||||
|
|
||||||
class SocietyPerson(DataPoint):
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
memberships: Optional[list[Organization]]
|
|
||||||
|
|
||||||
|
|
||||||
Organization.model_rebuild()
|
|
||||||
SocietyPerson.model_rebuild()
|
|
||||||
|
|
||||||
|
|
||||||
ORGANIZATION_NAMES = [
|
|
||||||
"ChessClub",
|
|
||||||
"RowingClub",
|
|
||||||
"TheatreTroupe",
|
|
||||||
"PoliticalParty",
|
|
||||||
"Charity",
|
|
||||||
"FanClub",
|
|
||||||
"FilmClub",
|
|
||||||
"NeighborhoodGroup",
|
|
||||||
"LocalCouncil",
|
|
||||||
"Band",
|
|
||||||
]
|
|
||||||
PERSON_NAMES = ["Sarah", "Anna", "John", "Sam"]
|
|
||||||
|
|
||||||
|
|
||||||
def create_society_person_recursive(id, name, organization_names, max_depth, depth=0):
|
|
||||||
if depth < max_depth:
|
|
||||||
memberships = [
|
|
||||||
create_organization_recursive(
|
|
||||||
org_name, org_name.lower(), PERSON_NAMES, max_depth, depth + 1
|
|
||||||
)
|
|
||||||
for org_name in organization_names
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
memberships = None
|
|
||||||
|
|
||||||
id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
|
|
||||||
return SocietyPerson(
|
|
||||||
id=f"{id}{depth}-{id_suffix}", name=f"{name}{depth}", memberships=memberships
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_organization_recursive(id, name, member_names, max_depth, depth=0):
|
|
||||||
if depth < max_depth:
|
|
||||||
members = [
|
|
||||||
create_society_person_recursive(
|
|
||||||
member_name,
|
|
||||||
member_name.lower(),
|
|
||||||
ORGANIZATION_NAMES,
|
|
||||||
max_depth,
|
|
||||||
depth + 1,
|
|
||||||
)
|
|
||||||
for member_name in member_names
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
members = None
|
|
||||||
|
|
||||||
id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
|
|
||||||
return Organization(
|
|
||||||
id=f"{id}{depth}-{id_suffix}", name=f"{id}{name}", members=members
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def count_society(obj):
|
|
||||||
if isinstance(obj, SocietyPerson):
|
|
||||||
if obj.memberships is not None:
|
|
||||||
organization_counts, society_person_counts = zip(
|
|
||||||
*[count_society(organization) for organization in obj.memberships]
|
|
||||||
)
|
|
||||||
organization_count = sum(organization_counts)
|
|
||||||
society_person_count = sum(society_person_counts) + 1
|
|
||||||
return (organization_count, society_person_count)
|
|
||||||
else:
|
|
||||||
return (0, 1)
|
|
||||||
if isinstance(obj, Organization):
|
|
||||||
if obj.members is not None:
|
|
||||||
organization_counts, society_person_counts = zip(
|
|
||||||
*[count_society(organization) for organization in obj.members]
|
|
||||||
)
|
|
||||||
organization_count = sum(organization_counts) + 1
|
|
||||||
society_person_count = sum(society_person_counts)
|
|
||||||
return (organization_count, society_person_count)
|
|
||||||
else:
|
|
||||||
return (1, 0)
|
|
||||||
else:
|
|
||||||
return (0, 0)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
def society():
|
|
||||||
society = create_organization_recursive("society", "Society", PERSON_NAMES, 4)
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,28 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from cognee.modules.graph.utils import get_graph_from_model
|
||||||
|
from cognee.tests.unit.interfaces.graph.util import (
|
||||||
|
PERSON_NAMES,
|
||||||
|
count_society,
|
||||||
|
create_organization_recursive,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("recursive_depth", [1, 2, 3])
|
||||||
|
def test_extracted_car_type(recursive_depth):
|
||||||
|
society = create_organization_recursive(
|
||||||
|
"society", "Society", PERSON_NAMES, recursive_depth
|
||||||
|
)
|
||||||
|
|
||||||
|
n_organizations, n_persons = count_society(society)
|
||||||
|
society_counts_total = n_organizations + n_persons
|
||||||
|
|
||||||
|
nodes, edges = get_graph_from_model(society)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
len(nodes) == society_counts_total
|
||||||
|
), f"{society_counts_total = } != {len(nodes) = }, not all DataPoint instances were found"
|
||||||
|
|
||||||
|
assert len(edges) == (
|
||||||
|
len(nodes) - 1
|
||||||
|
), f"{(len(nodes) - 1) = } != {len(edges) = }, there have to be n_nodes - 1 edges, as each node has exactly one parent node, except for the root node"
|
||||||
|
|
@ -1,5 +1,9 @@
|
||||||
|
import random
|
||||||
|
import string
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
|
||||||
|
|
||||||
def run_test_against_ground_truth(
|
def run_test_against_ground_truth(
|
||||||
|
|
@ -30,3 +34,99 @@ def run_test_against_ground_truth(
|
||||||
time_delta = datetime.now(timezone.utc) - getattr(test_target_item, "updated_at")
|
time_delta = datetime.now(timezone.utc) - getattr(test_target_item, "updated_at")
|
||||||
|
|
||||||
assert time_delta.total_seconds() < 60, f"{ time_delta.total_seconds() = }"
|
assert time_delta.total_seconds() < 60, f"{ time_delta.total_seconds() = }"
|
||||||
|
|
||||||
|
|
||||||
|
class Organization(DataPoint):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
members: Optional[list["SocietyPerson"]]
|
||||||
|
|
||||||
|
|
||||||
|
class SocietyPerson(DataPoint):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
memberships: Optional[list[Organization]]
|
||||||
|
|
||||||
|
|
||||||
|
Organization.model_rebuild()
|
||||||
|
SocietyPerson.model_rebuild()
|
||||||
|
|
||||||
|
|
||||||
|
ORGANIZATION_NAMES = [
|
||||||
|
"ChessClub",
|
||||||
|
"RowingClub",
|
||||||
|
"TheatreTroupe",
|
||||||
|
"PoliticalParty",
|
||||||
|
"Charity",
|
||||||
|
"FanClub",
|
||||||
|
"FilmClub",
|
||||||
|
"NeighborhoodGroup",
|
||||||
|
"LocalCouncil",
|
||||||
|
"Band",
|
||||||
|
]
|
||||||
|
PERSON_NAMES = ["Sarah", "Anna", "John", "Sam"]
|
||||||
|
|
||||||
|
|
||||||
|
def create_society_person_recursive(id, name, organization_names, max_depth, depth=0):
|
||||||
|
id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
|
||||||
|
|
||||||
|
if depth < max_depth:
|
||||||
|
memberships = [
|
||||||
|
create_organization_recursive(
|
||||||
|
f"{org_name}-{depth}-{id_suffix}",
|
||||||
|
org_name.lower(),
|
||||||
|
PERSON_NAMES,
|
||||||
|
max_depth,
|
||||||
|
depth + 1,
|
||||||
|
)
|
||||||
|
for org_name in organization_names
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
memberships = None
|
||||||
|
|
||||||
|
return SocietyPerson(id=id, name=f"{name}{depth}", memberships=memberships)
|
||||||
|
|
||||||
|
|
||||||
|
def create_organization_recursive(id, name, member_names, max_depth, depth=0):
|
||||||
|
id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
|
||||||
|
|
||||||
|
if depth < max_depth:
|
||||||
|
members = [
|
||||||
|
create_society_person_recursive(
|
||||||
|
f"{member_name}-{depth}-{id_suffix}",
|
||||||
|
member_name.lower(),
|
||||||
|
ORGANIZATION_NAMES,
|
||||||
|
max_depth,
|
||||||
|
depth + 1,
|
||||||
|
)
|
||||||
|
for member_name in member_names
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
members = None
|
||||||
|
|
||||||
|
return Organization(id=id, name=f"{name}{depth}", members=members)
|
||||||
|
|
||||||
|
|
||||||
|
def count_society(obj):
|
||||||
|
if isinstance(obj, SocietyPerson):
|
||||||
|
if obj.memberships is not None:
|
||||||
|
organization_counts, society_person_counts = zip(
|
||||||
|
*[count_society(organization) for organization in obj.memberships]
|
||||||
|
)
|
||||||
|
organization_count = sum(organization_counts)
|
||||||
|
society_person_count = sum(society_person_counts) + 1
|
||||||
|
return (organization_count, society_person_count)
|
||||||
|
else:
|
||||||
|
return (0, 1)
|
||||||
|
if isinstance(obj, Organization):
|
||||||
|
if obj.members is not None:
|
||||||
|
organization_counts, society_person_counts = zip(
|
||||||
|
*[count_society(organization) for organization in obj.members]
|
||||||
|
)
|
||||||
|
organization_count = sum(organization_counts) + 1
|
||||||
|
society_person_count = sum(society_person_counts)
|
||||||
|
return (organization_count, society_person_count)
|
||||||
|
else:
|
||||||
|
return (1, 0)
|
||||||
|
else:
|
||||||
|
raise Exception("Not allowed")
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue