Adapt graph interfaces tests to debugged get_graph_from_model

This commit is contained in:
Leon Luithlen 2024-11-15 10:27:27 +01:00
parent 628f192b8d
commit 0ea011ccd7
4 changed files with 60 additions and 29 deletions

View file

@ -1,14 +1,9 @@
from datetime import datetime, timezone
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional
import pytest import pytest
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.utils import (
get_graph_from_model,
get_model_instance_from_graph,
)
class CarTypeName(Enum): class CarTypeName(Enum):
@ -47,8 +42,8 @@ class Person(DataPoint):
_metadata: dict = dict(index_fields=["name"]) _metadata: dict = dict(index_fields=["name"])
@pytest.fixture(scope="session") @pytest.fixture(scope="function")
def graph_outputs(): def boris():
boris = Person( boris = Person(
id="boris", id="boris",
name="Boris", name="Boris",
@ -70,11 +65,4 @@ def graph_outputs():
"expires_on": "2025-11-06", "expires_on": "2025-11-06",
}, },
) )
nodes, edges = get_graph_from_model(boris) return boris
car, person = nodes[0], nodes[1]
edge = edges[0]
parsed_person = get_model_instance_from_graph(nodes, edges, "boris")
return (car, person, edge, parsed_person)

View file

@ -1,6 +1,19 @@
from cognee.modules.graph.utils import get_graph_from_model
from cognee.tests.unit.interfaces.graph.util import run_test_against_ground_truth from cognee.tests.unit.interfaces.graph.util import run_test_against_ground_truth
EDGE_GROUND_TRUTH = ( CAR_SEDAN_EDGE = (
"car1",
"sedan",
"is_type",
{
"source_node_id": "car1",
"target_node_id": "sedan",
"relationship_name": "is_type",
},
)
BORIS_CAR_EDGE_GROUND_TRUTH = (
"boris", "boris",
"car1", "car1",
"owns_car", "owns_car",
@ -12,6 +25,8 @@ EDGE_GROUND_TRUTH = (
}, },
) )
CAR_TYPE_GROUND_TRUTH = {"id": "sedan"}
CAR_GROUND_TRUTH = { CAR_GROUND_TRUTH = {
"id": "car1", "id": "car1",
"brand": "Toyota", "brand": "Toyota",
@ -33,22 +48,42 @@ PERSON_GROUND_TRUTH = {
} }
def test_extracted_person(graph_outputs): def test_extracted_car_type(boris):
(_, person, _, _) = graph_outputs nodes, _ = get_graph_from_model(boris)
assert len(nodes) == 3
run_test_against_ground_truth("person", person, PERSON_GROUND_TRUTH) car_type = nodes[0]
run_test_against_ground_truth("car_type", car_type, CAR_TYPE_GROUND_TRUTH)
def test_extracted_car(graph_outputs): def test_extracted_car(boris):
(car, _, _, _) = graph_outputs nodes, _ = get_graph_from_model(boris)
assert len(nodes) == 3
car = nodes[1]
run_test_against_ground_truth("car", car, CAR_GROUND_TRUTH) run_test_against_ground_truth("car", car, CAR_GROUND_TRUTH)
def test_extracted_edge(graph_outputs): def test_extracted_person(boris):
(_, _, edge, _) = graph_outputs nodes, _ = get_graph_from_model(boris)
assert len(nodes) == 3
person = nodes[2]
run_test_against_ground_truth("person", person, PERSON_GROUND_TRUTH)
def test_extracted_car_sedan_edge(boris):
_, edges = get_graph_from_model(boris)
edge = edges[0]
assert CAR_SEDAN_EDGE[:3] == edge[:3], f"{CAR_SEDAN_EDGE[:3] = } != {edge[:3] = }"
for key, ground_truth in CAR_SEDAN_EDGE[3].items():
assert ground_truth == edge[3][key], f"{ground_truth = } != {edge[3][key] = }"
def test_extracted_boris_car_edge(boris):
_, edges = get_graph_from_model(boris)
edge = edges[1]
assert ( assert (
EDGE_GROUND_TRUTH[:3] == edge[:3] BORIS_CAR_EDGE_GROUND_TRUTH[:3] == edge[:3]
), f"{EDGE_GROUND_TRUTH[:3] = } != {edge[:3] = }" ), f"{BORIS_CAR_EDGE_GROUND_TRUTH[:3] = } != {edge[:3] = }"
for key, ground_truth in EDGE_GROUND_TRUTH[3].items(): for key, ground_truth in BORIS_CAR_EDGE_GROUND_TRUTH[3].items():
assert ground_truth == edge[3][key], f"{ground_truth = } != {edge[3][key] = }" assert ground_truth == edge[3][key], f"{ground_truth = } != {edge[3][key] = }"

View file

@ -1,3 +1,7 @@
from cognee.modules.graph.utils import (
get_graph_from_model,
get_model_instance_from_graph,
)
from cognee.tests.unit.interfaces.graph.util import run_test_against_ground_truth from cognee.tests.unit.interfaces.graph.util import run_test_against_ground_truth
PARSED_PERSON_GROUND_TRUTH = { PARSED_PERSON_GROUND_TRUTH = {
@ -21,8 +25,10 @@ CAR_GROUND_TRUTH = {
} }
def test_parsed_person(graph_outputs): def test_parsed_person(boris):
(_, _, _, parsed_person) = graph_outputs nodes, edges = get_graph_from_model(boris)
parsed_person = get_model_instance_from_graph(nodes, edges, "boris")
run_test_against_ground_truth( run_test_against_ground_truth(
"parsed_person", parsed_person, PARSED_PERSON_GROUND_TRUTH "parsed_person", parsed_person, PARSED_PERSON_GROUND_TRUTH
) )

View file

@ -21,6 +21,8 @@ def run_test_against_ground_truth(
assert ( assert (
ground_truth2 == getattr(test_target_item, key)[key2] ground_truth2 == getattr(test_target_item, key)[key2]
), f"{test_target_item_name}/{key = }/{key2 = }: {ground_truth2 = } != {getattr(test_target_item, key)[key2] = }" ), f"{test_target_item_name}/{key = }/{key2 = }: {ground_truth2 = } != {getattr(test_target_item, key)[key2] = }"
elif isinstance(ground_truth, list):
raise NotImplementedError("Currently not implemented for 'list'")
else: else:
assert ground_truth == getattr( assert ground_truth == getattr(
test_target_item, key test_target_item, key