Transform into pytest tests

This commit is contained in:
Leon Luithlen 2024-11-12 09:46:58 +01:00
parent fbd011560a
commit f326a4daff
5 changed files with 48 additions and 33 deletions

View file

@ -1,3 +1,4 @@
import pytest
from enum import Enum
from datetime import datetime, timezone
from typing import Optional
@ -27,7 +28,6 @@ PERSON_GROUND_TRUTH = {
"driving_license": {'issued_by': "PU Vrsac", 'issued_on': '2025-11-06', 'number': '1234567890', 'expires_on': '2025-11-06'}
}
PARSED_PERSON_GROUND_TRUTH = {
"id": "boris",
"name": "Boris",
@ -69,10 +69,20 @@ class Person(DataPoint):
_metadata: dict = dict(index_fields = ["name"])
def run_test_agains_ground_truth(test_target_item_name, test_target_item, ground_truth_dict):
for key, ground_truth in ground_truth_dict.items():
if isinstance(ground_truth, dict):
for key2, ground_truth2 in ground_truth.items():
assert ground_truth2 == getattr(test_target_item, key)[key2], f'{test_target_item_name}/{key = }/{key2 = }: {ground_truth2 = } != {getattr(test_target_item, key)[key2] = }'
else:
assert ground_truth == getattr(test_target_item, key), f'{test_target_item_name}/{key = }: {ground_truth = } != {getattr(test_target_item, key) = }'
time_delta = datetime.now(timezone.utc) - getattr(test_target_item, "updated_at")
assert time_delta.total_seconds() < 20, f"{ time_delta.total_seconds() = }"
if __name__ == "__main__":
@pytest.fixture(scope="session")
def graph_outputs():
boris = Person(
id = "boris",
name = "Boris",
@ -92,31 +102,37 @@ if __name__ == "__main__":
"expires_on": "2025-11-06",
},
)
nodes, edges = get_graph_from_model(boris)
car, person = nodes[0], nodes[1]
edge = edges[0]
try:
car, person = nodes[0], nodes[1]
edge = edges[0]
except:
print(f"{nodes = }\n{edges = }")
def test_against_ground_truth(test_target_item_name, test_target_item, ground_truth_dict):
for key, ground_truth in ground_truth_dict.items():
if isinstance(ground_truth, dict):
for key2, ground_truth2 in ground_truth.items():
assert ground_truth2 == getattr(test_target_item, key)[key2], f'{test_target_item_name}/{key = }/{key2 = }: {ground_truth2 = } != {getattr(test_target_item, key)[key2] = }'
else:
assert ground_truth == getattr(test_target_item, key), f'{test_target_item_name}/{key = }: {ground_truth = } != {getattr(test_target_item, key) = }'
time_delta = datetime.now(timezone.utc) - getattr(test_target_item, "updated_at")
parsed_person = get_model_instance_from_graph(nodes, edges, 'boris')
assert time_delta.total_seconds() < 20, f"{ time_delta.total_seconds() = }"
return(car, person, edge, parsed_person)
test_against_ground_truth("car", car, CAR_GROUND_TRUTH)
test_against_ground_truth("person", person, PERSON_GROUND_TRUTH)
def test_extracted_person(graph_outputs):
(_, person, _, _) = graph_outputs
run_test_agains_ground_truth("person", person, PERSON_GROUND_TRUTH)
def test_extracted_car(graph_outputs):
(car, _, _, _) = graph_outputs
run_test_agains_ground_truth("car", car, CAR_GROUND_TRUTH)
def test_extracted_edge(graph_outputs):
(_, _, edge, _) = graph_outputs
assert EDGE_GROUND_TRUTH[:3] == edge[:3], f'{EDGE_GROUND_TRUTH[:3] = } != {edge[:3] = }'
for key, ground_truth in EDGE_GROUND_TRUTH[3].items():
assert ground_truth == edge[3][key], f'{ground_truth = } != {edge[3][key] = }'
parsed_person = get_model_instance_from_graph(nodes, edges, 'boris')
test_against_ground_truth("parsed_person", parsed_person, PARSED_PERSON_GROUND_TRUTH)
test_against_ground_truth("car", parsed_person.owns_car[0], CAR_GROUND_TRUTH)
def test_parsed_person(graph_outputs):
(_, _, _, parsed_person) = graph_outputs
run_test_agains_ground_truth("parsed_person", parsed_person, PARSED_PERSON_GROUND_TRUTH)
run_test_agains_ground_truth("car", parsed_person.owns_car[0], CAR_GROUND_TRUTH)

View file

@ -22,7 +22,7 @@ INPUT_TEXT = {
Third paragraph is cut and is missing the dot at the end"""
}
def test_chunking(test_text, ground_truth):
def run_chunking_test(test_text, ground_truth):
chunks = []
for chunk_data in chunk_by_paragraph(test_text, 12, batch_paragraphs = False):
chunks.append(chunk_data)
@ -34,7 +34,8 @@ def test_chunking(test_text, ground_truth):
assert chunk[key] == ground_truth_item[key], f'{key = }: {chunk[key] = } != {ground_truth_item[key] = }'
def test_chunking_whole_text():
run_chunking_test(INPUT_TEXT["whole_text"], GROUND_TRUTH["whole_text"])
if __name__ == "__main__":
test_chunking(INPUT_TEXT["whole_text"], GROUND_TRUTH["whole_text"])
test_chunking(INPUT_TEXT["cut_text"], GROUND_TRUTH["cut_text"])
def test_chunking_cut_text():
run_chunking_test(INPUT_TEXT["cut_text"], GROUND_TRUTH["cut_text"])

View file

@ -3,7 +3,7 @@ from cognee.modules.pipelines.operations.run_tasks import run_tasks
from cognee.modules.pipelines.tasks.Task import Task
async def main():
async def run_and_check_tasks():
def number_generator(num):
for i in range(num):
yield i + 1
@ -32,5 +32,5 @@ async def main():
assert result == results[index]
index += 1
if __name__ == "__main__":
asyncio.run(main())
def test_run_tasks():
asyncio.run(run_and_check_tasks())

View file

@ -30,7 +30,7 @@ async def pipeline(data_queue):
assert result == results[index]
index += 1
async def main():
async def run_queue():
data_queue = Queue()
data_queue.is_closed = False
@ -42,5 +42,5 @@ async def main():
await asyncio.gather(pipeline(data_queue), queue_producer())
if __name__ == "__main__":
asyncio.run(main())
def test_run_tasks_from_queue():
asyncio.run(run_queue())

View file

@ -1,2 +0,0 @@
[pytest]
addopts = tests/