From 41544369af10756a3a76715ebb28206afdfcaab0 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Tue, 28 Jan 2025 14:47:17 +0100 Subject: [PATCH] test: Change test_by_paragraph tests to accommodate the change --- .../chunks/chunk_by_paragraph_2_test.py | 56 +++++++++++++++---- .../chunks/chunk_by_paragraph_test.py | 2 +- 2 files changed, 45 insertions(+), 13 deletions(-) diff --git a/cognee/tests/unit/processing/chunks/chunk_by_paragraph_2_test.py b/cognee/tests/unit/processing/chunks/chunk_by_paragraph_2_test.py index d8680a604..5555a7dc9 100644 --- a/cognee/tests/unit/processing/chunks/chunk_by_paragraph_2_test.py +++ b/cognee/tests/unit/processing/chunks/chunk_by_paragraph_2_test.py @@ -8,14 +8,24 @@ from cognee.tests.unit.processing.chunks.test_input import INPUT_TEXTS paragraph_lengths = [64, 256, 1024] batch_paragraphs_vals = [True, False] +max_chunk_tokens_vals = [512, 1024, 4096] @pytest.mark.parametrize( - "input_text,paragraph_length,batch_paragraphs", - list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)), + "input_text,max_chunk_tokens,paragraph_length,batch_paragraphs", + list( + product( + list(INPUT_TEXTS.values()), + max_chunk_tokens_vals, + paragraph_lengths, + batch_paragraphs_vals, + ) + ), ) -def test_chunk_by_paragraph_isomorphism(input_text, paragraph_length, batch_paragraphs): - chunks = chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs) +def test_chunk_by_paragraph_isomorphism( + input_text, max_chunk_tokens, paragraph_length, batch_paragraphs +): + chunks = chunk_by_paragraph(input_text, max_chunk_tokens, paragraph_length, batch_paragraphs) reconstructed_text = "".join([chunk["text"] for chunk in chunks]) assert reconstructed_text == input_text, ( f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }" ) @@ -23,13 +33,23 @@ def test_chunk_by_paragraph_isomorphism(input_text, paragraph_length, batch_para @pytest.mark.parametrize( - "input_text,paragraph_length,batch_paragraphs", - 
list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)), + "input_text,max_chunk_tokens,paragraph_length,batch_paragraphs", + list( + product( + list(INPUT_TEXTS.values()), + max_chunk_tokens_vals, + paragraph_lengths, + batch_paragraphs_vals, + ) + ), ) -def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs): +def test_paragraph_chunk_length(input_text, max_chunk_tokens, paragraph_length, batch_paragraphs): chunks = list( chunk_by_paragraph( - data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs + data=input_text, + max_chunk_tokens=max_chunk_tokens, + paragraph_length=paragraph_length, + batch_paragraphs=batch_paragraphs, ) ) @@ -42,12 +62,24 @@ def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs): @pytest.mark.parametrize( - "input_text,paragraph_length,batch_paragraphs", - list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)), + "input_text,max_chunk_tokens,paragraph_length,batch_paragraphs", + list( + product( + list(INPUT_TEXTS.values()), + max_chunk_tokens_vals, + paragraph_lengths, + batch_paragraphs_vals, + ) + ), ) -def test_chunk_by_paragraph_chunk_numbering(input_text, paragraph_length, batch_paragraphs): +def test_chunk_by_paragraph_chunk_numbering( + input_text, max_chunk_tokens, paragraph_length, batch_paragraphs +): chunks = chunk_by_paragraph( - data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs + data=input_text, + max_chunk_tokens=max_chunk_tokens, + paragraph_length=paragraph_length, + batch_paragraphs=batch_paragraphs, ) chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks]) assert np.all(chunk_indices == np.arange(len(chunk_indices))), ( diff --git a/cognee/tests/unit/processing/chunks/chunk_by_paragraph_test.py b/cognee/tests/unit/processing/chunks/chunk_by_paragraph_test.py index e420b2e9f..ed706830e 100644 --- 
a/cognee/tests/unit/processing/chunks/chunk_by_paragraph_test.py +++ b/cognee/tests/unit/processing/chunks/chunk_by_paragraph_test.py @@ -50,7 +50,7 @@ Third paragraph is cut and is missing the dot at the end""", def run_chunking_test(test_text, expected_chunks): chunks = [] for chunk_data in chunk_by_paragraph( - data=test_text, paragraph_length=12, batch_paragraphs=False + data=test_text, paragraph_length=12, batch_paragraphs=False, max_chunk_tokens=512 ): chunks.append(chunk_data)