Merge pull request #245 from topoteretes/fixing-faulty-cognify-unit-tests

fix: deletes get_graph_from_model test of the faulty old implementation
This commit is contained in:
hajdul88 2024-12-03 14:50:52 +01:00 committed by GitHub
commit bd0227b551
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 0 additions and 412 deletions

View file

@ -1,68 +0,0 @@
from enum import Enum
from typing import Optional
import pytest
from cognee.infrastructure.engine import DataPoint
class CarTypeName(Enum):
Pickup = "Pickup"
Sedan = "Sedan"
SUV = "SUV"
Coupe = "Coupe"
Convertible = "Convertible"
Hatchback = "Hatchback"
Wagon = "Wagon"
Minivan = "Minivan"
Van = "Van"
class CarType(DataPoint):
id: str
name: CarTypeName
_metadata: dict = dict(index_fields=["name"])
class Car(DataPoint):
id: str
brand: str
model: str
year: int
color: str
is_type: CarType
class Person(DataPoint):
id: str
name: str
age: int
owns_car: list[Car]
driving_license: Optional[dict]
_metadata: dict = dict(index_fields=["name"])
@pytest.fixture(scope="function")
def boris():
boris = Person(
id="boris",
name="Boris",
age=30,
owns_car=[
Car(
id="car1",
brand="Toyota",
model="Camry",
year=2020,
color="Blue",
is_type=CarType(id="sedan", name=CarTypeName.Sedan),
)
],
driving_license={
"issued_by": "PU Vrsac",
"issued_on": "2025-11-06",
"number": "1234567890",
"expires_on": "2025-11-06",
},
)
return boris

View file

@ -1,37 +0,0 @@
import warnings
import pytest
from cognee.modules.graph.utils import get_graph_from_model
from cognee.tests.unit.interfaces.graph.util import (
PERSON_NAMES,
count_society,
create_organization_recursive,
)
@pytest.mark.parametrize("recursive_depth", [1, 2, 3])
def test_society_nodes_and_edges(recursive_depth):
import sys
if sys.version_info[0] == 3 and sys.version_info[1] >= 11:
society = create_organization_recursive(
"society", "Society", PERSON_NAMES, recursive_depth
)
n_organizations, n_persons = count_society(society)
society_counts_total = n_organizations + n_persons
nodes, edges = get_graph_from_model(society)
assert (
len(nodes) == society_counts_total
), f"{society_counts_total = } != {len(nodes) = }, not all DataPoint instances were found"
assert len(edges) == (
len(nodes) - 1
), f"{(len(nodes) - 1) = } != {len(edges) = }, there have to be n_nodes - 1 edges, as each node has exactly one parent node, except for the root node"
else:
warnings.warn(
"The recursive pydantic data structure cannot be reconstructed from the graph because the 'inner' pydantic class is not defined. Hence this test is skipped. This problem is solved in Python 3.11"
)

View file

@ -1,89 +0,0 @@
from cognee.modules.graph.utils import get_graph_from_model
from cognee.tests.unit.interfaces.graph.util import run_test_against_ground_truth
CAR_SEDAN_EDGE = (
"car1",
"sedan",
"is_type",
{
"source_node_id": "car1",
"target_node_id": "sedan",
"relationship_name": "is_type",
},
)
BORIS_CAR_EDGE_GROUND_TRUTH = (
"boris",
"car1",
"owns_car",
{
"source_node_id": "boris",
"target_node_id": "car1",
"relationship_name": "owns_car",
"metadata": {"type": "list"},
},
)
CAR_TYPE_GROUND_TRUTH = {"id": "sedan"}
CAR_GROUND_TRUTH = {
"id": "car1",
"brand": "Toyota",
"model": "Camry",
"year": 2020,
"color": "Blue",
}
PERSON_GROUND_TRUTH = {
"id": "boris",
"name": "Boris",
"age": 30,
"driving_license": {
"issued_by": "PU Vrsac",
"issued_on": "2025-11-06",
"number": "1234567890",
"expires_on": "2025-11-06",
},
}
def test_extracted_car_type(boris):
nodes, _ = get_graph_from_model(boris)
assert len(nodes) == 3
car_type = nodes[0]
run_test_against_ground_truth("car_type", car_type, CAR_TYPE_GROUND_TRUTH)
def test_extracted_car(boris):
nodes, _ = get_graph_from_model(boris)
assert len(nodes) == 3
car = nodes[1]
run_test_against_ground_truth("car", car, CAR_GROUND_TRUTH)
def test_extracted_person(boris):
nodes, _ = get_graph_from_model(boris)
assert len(nodes) == 3
person = nodes[2]
run_test_against_ground_truth("person", person, PERSON_GROUND_TRUTH)
def test_extracted_car_sedan_edge(boris):
_, edges = get_graph_from_model(boris)
edge = edges[0]
assert CAR_SEDAN_EDGE[:3] == edge[:3], f"{CAR_SEDAN_EDGE[:3] = } != {edge[:3] = }"
for key, ground_truth in CAR_SEDAN_EDGE[3].items():
assert ground_truth == edge[3][key], f"{ground_truth = } != {edge[3][key] = }"
def test_extracted_boris_car_edge(boris):
_, edges = get_graph_from_model(boris)
edge = edges[1]
assert (
BORIS_CAR_EDGE_GROUND_TRUTH[:3] == edge[:3]
), f"{BORIS_CAR_EDGE_GROUND_TRUTH[:3] = } != {edge[:3] = }"
for key, ground_truth in BORIS_CAR_EDGE_GROUND_TRUTH[3].items():
assert ground_truth == edge[3][key], f"{ground_truth = } != {edge[3][key] = }"

View file

@ -1,33 +0,0 @@
import warnings
import pytest
from cognee.modules.graph.utils import (
get_graph_from_model,
get_model_instance_from_graph,
)
from cognee.tests.unit.interfaces.graph.util import (
PERSON_NAMES,
create_organization_recursive,
show_first_difference,
)
@pytest.mark.parametrize("recursive_depth", [1, 2, 3])
def test_society_nodes_and_edges(recursive_depth):
import sys
if sys.version_info[0] == 3 and sys.version_info[1] >= 11:
society = create_organization_recursive(
"society", "Society", PERSON_NAMES, recursive_depth
)
nodes, edges = get_graph_from_model(society)
parsed_society = get_model_instance_from_graph(nodes, edges, "society")
assert str(society) == (str(parsed_society)), show_first_difference(
str(society), str(parsed_society), "society", "parsed_society"
)
else:
warnings.warn(
"The recursive pydantic data structure cannot be reconstructed from the graph because the 'inner' pydantic class is not defined. Hence this test is skipped. This problem is solved in Python 3.11"
)

View file

@ -1,35 +0,0 @@
from cognee.modules.graph.utils import (
get_graph_from_model,
get_model_instance_from_graph,
)
from cognee.tests.unit.interfaces.graph.util import run_test_against_ground_truth
PARSED_PERSON_GROUND_TRUTH = {
"id": "boris",
"name": "Boris",
"age": 30,
"driving_license": {
"issued_by": "PU Vrsac",
"issued_on": "2025-11-06",
"number": "1234567890",
"expires_on": "2025-11-06",
},
}
CAR_GROUND_TRUTH = {
"id": "car1",
"brand": "Toyota",
"model": "Camry",
"year": 2020,
"color": "Blue",
}
def test_parsed_person(boris):
nodes, edges = get_graph_from_model(boris)
parsed_person = get_model_instance_from_graph(nodes, edges, "boris")
run_test_against_ground_truth(
"parsed_person", parsed_person, PARSED_PERSON_GROUND_TRUTH
)
run_test_against_ground_truth("car", parsed_person.owns_car[0], CAR_GROUND_TRUTH)

View file

@ -1,150 +0,0 @@
import random
import string
from datetime import datetime, timezone
from typing import Any, Dict, Optional
from cognee.infrastructure.engine import DataPoint
def run_test_against_ground_truth(
test_target_item_name: str, test_target_item: Any, ground_truth_dict: Dict[str, Any]
):
"""Validates test target item attributes against ground truth values.
Args:
test_target_item_name: Name of the item being tested (for error messages)
test_target_item: Object whose attributes are being validated
ground_truth_dict: Dictionary containing expected values
Raises:
AssertionError: If any attribute doesn't match ground truth or if update timestamp is too old
"""
for key, ground_truth in ground_truth_dict.items():
if isinstance(ground_truth, dict):
for key2, ground_truth2 in ground_truth.items():
assert (
ground_truth2 == getattr(test_target_item, key)[key2]
), f"{test_target_item_name}/{key = }/{key2 = }: {ground_truth2 = } != {getattr(test_target_item, key)[key2] = }"
elif isinstance(ground_truth, list):
raise NotImplementedError("Currently not implemented for 'list'")
else:
assert ground_truth == getattr(
test_target_item, key
), f"{test_target_item_name}/{key = }: {ground_truth = } != {getattr(test_target_item, key) = }"
time_delta = datetime.now(timezone.utc) - getattr(test_target_item, "updated_at")
assert time_delta.total_seconds() < 60, f"{ time_delta.total_seconds() = }"
class Organization(DataPoint):
id: str
name: str
members: Optional[list["SocietyPerson"]]
class SocietyPerson(DataPoint):
id: str
name: str
memberships: Optional[list[Organization]]
SocietyPerson.model_rebuild()
Organization.model_rebuild()
ORGANIZATION_NAMES = [
"ChessClub",
"RowingClub",
"TheatreTroupe",
"PoliticalParty",
"Charity",
"FanClub",
"FilmClub",
"NeighborhoodGroup",
"LocalCouncil",
"Band",
]
PERSON_NAMES = ["Sarah", "Anna", "John", "Sam"]
def create_society_person_recursive(id, name, organization_names, max_depth, depth=0):
id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
if depth < max_depth:
memberships = [
create_organization_recursive(
f"{org_name}-{depth}-{id_suffix}",
org_name.lower(),
PERSON_NAMES,
max_depth,
depth + 1,
)
for org_name in organization_names
]
else:
memberships = None
return SocietyPerson(id=id, name=f"{name}{depth}", memberships=memberships)
def create_organization_recursive(id, name, member_names, max_depth, depth=0):
id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
if depth < max_depth:
members = [
create_society_person_recursive(
f"{member_name}-{depth}-{id_suffix}",
member_name.lower(),
ORGANIZATION_NAMES,
max_depth,
depth + 1,
)
for member_name in member_names
]
else:
members = None
return Organization(id=id, name=f"{name}{depth}", members=members)
def count_society(obj):
if isinstance(obj, SocietyPerson):
if obj.memberships is not None:
organization_counts, society_person_counts = zip(
*[count_society(organization) for organization in obj.memberships]
)
organization_count = sum(organization_counts)
society_person_count = sum(society_person_counts) + 1
return (organization_count, society_person_count)
else:
return (0, 1)
if isinstance(obj, Organization):
if obj.members is not None:
organization_counts, society_person_counts = zip(
*[count_society(organization) for organization in obj.members]
)
organization_count = sum(organization_counts) + 1
society_person_count = sum(society_person_counts)
return (organization_count, society_person_count)
else:
return (1, 0)
else:
raise Exception("Not allowed")
def show_first_difference(str1, str2, str1_name, str2_name, context=30):
for i, (c1, c2) in enumerate(zip(str1, str2)):
if c1 != c2:
start = max(0, i - context)
end1 = min(len(str1), i + context + 1)
end2 = min(len(str2), i + context + 1)
if i > 0:
return f"identical: '{str1[start:i-1]}' | {str1_name}: '{str1[i-1:end1]}'... != {str2_name}: '{str2[i-1:end2]}'..."
else:
return f"{str1_name} and {str2_name} have no overlap in characters"
if len(str1) > len(str2):
return f"{str2_name} is identical up to the {i}th character, missing afterwards '{str1[i:i+context]}'..."
if len(str2) > len(str1):
return f"{str1_name} is identical up to the {i}th character, missing afterwards '{str2[i:i+context]}'..."
else:
return f"{str1_name} and {str2_name} are identical."