Transform into pytest tests
This commit is contained in:
parent
fbd011560a
commit
f326a4daff
5 changed files with 48 additions and 33 deletions
|
|
@ -1,3 +1,4 @@
|
||||||
|
import pytest
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
@ -27,7 +28,6 @@ PERSON_GROUND_TRUTH = {
|
||||||
"driving_license": {'issued_by': "PU Vrsac", 'issued_on': '2025-11-06', 'number': '1234567890', 'expires_on': '2025-11-06'}
|
"driving_license": {'issued_by': "PU Vrsac", 'issued_on': '2025-11-06', 'number': '1234567890', 'expires_on': '2025-11-06'}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
PARSED_PERSON_GROUND_TRUTH = {
|
PARSED_PERSON_GROUND_TRUTH = {
|
||||||
"id": "boris",
|
"id": "boris",
|
||||||
"name": "Boris",
|
"name": "Boris",
|
||||||
|
|
@ -69,10 +69,20 @@ class Person(DataPoint):
|
||||||
_metadata: dict = dict(index_fields = ["name"])
|
_metadata: dict = dict(index_fields = ["name"])
|
||||||
|
|
||||||
|
|
||||||
|
def run_test_agains_ground_truth(test_target_item_name, test_target_item, ground_truth_dict):
|
||||||
|
for key, ground_truth in ground_truth_dict.items():
|
||||||
|
if isinstance(ground_truth, dict):
|
||||||
|
for key2, ground_truth2 in ground_truth.items():
|
||||||
|
assert ground_truth2 == getattr(test_target_item, key)[key2], f'{test_target_item_name}/{key = }/{key2 = }: {ground_truth2 = } != {getattr(test_target_item, key)[key2] = }'
|
||||||
|
else:
|
||||||
|
assert ground_truth == getattr(test_target_item, key), f'{test_target_item_name}/{key = }: {ground_truth = } != {getattr(test_target_item, key) = }'
|
||||||
|
time_delta = datetime.now(timezone.utc) - getattr(test_target_item, "updated_at")
|
||||||
|
|
||||||
|
assert time_delta.total_seconds() < 20, f"{ time_delta.total_seconds() = }"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
@pytest.fixture(scope="session")
|
||||||
|
def graph_outputs():
|
||||||
boris = Person(
|
boris = Person(
|
||||||
id = "boris",
|
id = "boris",
|
||||||
name = "Boris",
|
name = "Boris",
|
||||||
|
|
@ -92,31 +102,37 @@ if __name__ == "__main__":
|
||||||
"expires_on": "2025-11-06",
|
"expires_on": "2025-11-06",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
nodes, edges = get_graph_from_model(boris)
|
nodes, edges = get_graph_from_model(boris)
|
||||||
|
|
||||||
car, person = nodes[0], nodes[1]
|
try:
|
||||||
edge = edges[0]
|
car, person = nodes[0], nodes[1]
|
||||||
|
edge = edges[0]
|
||||||
|
except:
|
||||||
|
print(f"{nodes = }\n{edges = }")
|
||||||
|
|
||||||
def test_against_ground_truth(test_target_item_name, test_target_item, ground_truth_dict):
|
parsed_person = get_model_instance_from_graph(nodes, edges, 'boris')
|
||||||
for key, ground_truth in ground_truth_dict.items():
|
|
||||||
if isinstance(ground_truth, dict):
|
|
||||||
for key2, ground_truth2 in ground_truth.items():
|
|
||||||
assert ground_truth2 == getattr(test_target_item, key)[key2], f'{test_target_item_name}/{key = }/{key2 = }: {ground_truth2 = } != {getattr(test_target_item, key)[key2] = }'
|
|
||||||
else:
|
|
||||||
assert ground_truth == getattr(test_target_item, key), f'{test_target_item_name}/{key = }: {ground_truth = } != {getattr(test_target_item, key) = }'
|
|
||||||
time_delta = datetime.now(timezone.utc) - getattr(test_target_item, "updated_at")
|
|
||||||
|
|
||||||
assert time_delta.total_seconds() < 20, f"{ time_delta.total_seconds() = }"
|
return(car, person, edge, parsed_person)
|
||||||
|
|
||||||
test_against_ground_truth("car", car, CAR_GROUND_TRUTH)
|
def test_extracted_person(graph_outputs):
|
||||||
test_against_ground_truth("person", person, PERSON_GROUND_TRUTH)
|
(_, person, _, _) = graph_outputs
|
||||||
|
|
||||||
|
run_test_agains_ground_truth("person", person, PERSON_GROUND_TRUTH)
|
||||||
|
|
||||||
|
|
||||||
|
def test_extracted_car(graph_outputs):
|
||||||
|
(car, _, _, _) = graph_outputs
|
||||||
|
run_test_agains_ground_truth("car", car, CAR_GROUND_TRUTH)
|
||||||
|
|
||||||
|
|
||||||
|
def test_extracted_edge(graph_outputs):
|
||||||
|
(_, _, edge, _) = graph_outputs
|
||||||
|
|
||||||
assert EDGE_GROUND_TRUTH[:3] == edge[:3], f'{EDGE_GROUND_TRUTH[:3] = } != {edge[:3] = }'
|
assert EDGE_GROUND_TRUTH[:3] == edge[:3], f'{EDGE_GROUND_TRUTH[:3] = } != {edge[:3] = }'
|
||||||
for key, ground_truth in EDGE_GROUND_TRUTH[3].items():
|
for key, ground_truth in EDGE_GROUND_TRUTH[3].items():
|
||||||
assert ground_truth == edge[3][key], f'{ground_truth = } != {edge[3][key] = }'
|
assert ground_truth == edge[3][key], f'{ground_truth = } != {edge[3][key] = }'
|
||||||
|
|
||||||
parsed_person = get_model_instance_from_graph(nodes, edges, 'boris')
|
def test_parsed_person(graph_outputs):
|
||||||
|
(_, _, _, parsed_person) = graph_outputs
|
||||||
test_against_ground_truth("parsed_person", parsed_person, PARSED_PERSON_GROUND_TRUTH)
|
run_test_agains_ground_truth("parsed_person", parsed_person, PARSED_PERSON_GROUND_TRUTH)
|
||||||
test_against_ground_truth("car", parsed_person.owns_car[0], CAR_GROUND_TRUTH)
|
run_test_agains_ground_truth("car", parsed_person.owns_car[0], CAR_GROUND_TRUTH)
|
||||||
|
|
@ -22,7 +22,7 @@ INPUT_TEXT = {
|
||||||
Third paragraph is cut and is missing the dot at the end"""
|
Third paragraph is cut and is missing the dot at the end"""
|
||||||
}
|
}
|
||||||
|
|
||||||
def test_chunking(test_text, ground_truth):
|
def run_chunking_test(test_text, ground_truth):
|
||||||
chunks = []
|
chunks = []
|
||||||
for chunk_data in chunk_by_paragraph(test_text, 12, batch_paragraphs = False):
|
for chunk_data in chunk_by_paragraph(test_text, 12, batch_paragraphs = False):
|
||||||
chunks.append(chunk_data)
|
chunks.append(chunk_data)
|
||||||
|
|
@ -34,7 +34,8 @@ def test_chunking(test_text, ground_truth):
|
||||||
assert chunk[key] == ground_truth_item[key], f'{key = }: {chunk[key] = } != {ground_truth_item[key] = }'
|
assert chunk[key] == ground_truth_item[key], f'{key = }: {chunk[key] = } != {ground_truth_item[key] = }'
|
||||||
|
|
||||||
|
|
||||||
|
def test_chunking_whole_text():
|
||||||
|
run_chunking_test(INPUT_TEXT["whole_text"], GROUND_TRUTH["whole_text"])
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def test_chunking_cut_text():
|
||||||
test_chunking(INPUT_TEXT["whole_text"], GROUND_TRUTH["whole_text"])
|
run_chunking_test(INPUT_TEXT["cut_text"], GROUND_TRUTH["cut_text"])
|
||||||
test_chunking(INPUT_TEXT["cut_text"], GROUND_TRUTH["cut_text"])
|
|
||||||
|
|
@ -3,7 +3,7 @@ from cognee.modules.pipelines.operations.run_tasks import run_tasks
|
||||||
from cognee.modules.pipelines.tasks.Task import Task
|
from cognee.modules.pipelines.tasks.Task import Task
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def run_and_check_tasks():
|
||||||
def number_generator(num):
|
def number_generator(num):
|
||||||
for i in range(num):
|
for i in range(num):
|
||||||
yield i + 1
|
yield i + 1
|
||||||
|
|
@ -32,5 +32,5 @@ async def main():
|
||||||
assert result == results[index]
|
assert result == results[index]
|
||||||
index += 1
|
index += 1
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def test_run_tasks():
|
||||||
asyncio.run(main())
|
asyncio.run(run_and_check_tasks())
|
||||||
|
|
@ -30,7 +30,7 @@ async def pipeline(data_queue):
|
||||||
assert result == results[index]
|
assert result == results[index]
|
||||||
index += 1
|
index += 1
|
||||||
|
|
||||||
async def main():
|
async def run_queue():
|
||||||
data_queue = Queue()
|
data_queue = Queue()
|
||||||
data_queue.is_closed = False
|
data_queue.is_closed = False
|
||||||
|
|
||||||
|
|
@ -42,5 +42,5 @@ async def main():
|
||||||
|
|
||||||
await asyncio.gather(pipeline(data_queue), queue_producer())
|
await asyncio.gather(pipeline(data_queue), queue_producer())
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def test_run_tasks_from_queue():
|
||||||
asyncio.run(main())
|
asyncio.run(run_queue())
|
||||||
|
|
@ -1,2 +0,0 @@
|
||||||
[pytest]
|
|
||||||
addopts = tests/
|
|
||||||
Loading…
Add table
Reference in a new issue