test: Change test_by_paragraph tests to accommodate the change

This commit is contained in:
Igor Ilic 2025-01-28 14:47:17 +01:00
parent 3db7f85c9c
commit 41544369af
2 changed files with 45 additions and 13 deletions

View file

@ -8,14 +8,24 @@ from cognee.tests.unit.processing.chunks.test_input import INPUT_TEXTS
paragraph_lengths = [64, 256, 1024]
batch_paragraphs_vals = [True, False]
max_chunk_tokens_vals = [512, 1024, 4096]
@pytest.mark.parametrize(
"input_text,paragraph_length,batch_paragraphs",
list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)),
"input_text,max_chunk_tokens,paragraph_length,batch_paragraphs",
list(
product(
list(INPUT_TEXTS.values()),
max_chunk_tokens_vals,
paragraph_lengths,
batch_paragraphs_vals,
)
),
)
def test_chunk_by_paragraph_isomorphism(input_text, paragraph_length, batch_paragraphs):
chunks = chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs)
def test_chunk_by_paragraph_isomorphism(
input_text, max_chunk_tokens, paragraph_length, batch_paragraphs
):
chunks = chunk_by_paragraph(input_text, max_chunk_tokens, paragraph_length, batch_paragraphs)
reconstructed_text = "".join([chunk["text"] for chunk in chunks])
assert reconstructed_text == input_text, (
f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
@ -23,13 +33,23 @@ def test_chunk_by_paragraph_isomorphism(input_text, paragraph_length, batch_para
@pytest.mark.parametrize(
"input_text,paragraph_length,batch_paragraphs",
list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)),
"input_text,max_chunk_tokens,paragraph_length,batch_paragraphs",
list(
product(
list(INPUT_TEXTS.values()),
max_chunk_tokens_vals,
paragraph_lengths,
batch_paragraphs_vals,
)
),
)
def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs):
def test_paragraph_chunk_length(input_text, max_chunk_tokens, paragraph_length, batch_paragraphs):
chunks = list(
chunk_by_paragraph(
data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs
data=input_text,
max_chunk_tokens=max_chunk_tokens,
paragraph_length=paragraph_length,
batch_paragraphs=batch_paragraphs,
)
)
@ -42,12 +62,24 @@ def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs):
@pytest.mark.parametrize(
"input_text,paragraph_length,batch_paragraphs",
list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)),
"input_text,max_chunk_tokens,paragraph_length,batch_paragraphs",
list(
product(
list(INPUT_TEXTS.values()),
max_chunk_tokens_vals,
paragraph_lengths,
batch_paragraphs_vals,
)
),
)
def test_chunk_by_paragraph_chunk_numbering(input_text, paragraph_length, batch_paragraphs):
def test_chunk_by_paragraph_chunk_numbering(
input_text, max_chunk_tokens, paragraph_length, batch_paragraphs
):
chunks = chunk_by_paragraph(
data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs
data=input_text,
max_chunk_tokens=max_chunk_tokens,
paragraph_length=paragraph_length,
batch_paragraphs=batch_paragraphs,
)
chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
assert np.all(chunk_indices == np.arange(len(chunk_indices))), (

View file

@ -50,7 +50,7 @@ Third paragraph is cut and is missing the dot at the end""",
def run_chunking_test(test_text, expected_chunks):
chunks = []
for chunk_data in chunk_by_paragraph(
data=test_text, paragraph_length=12, batch_paragraphs=False
data=test_text, paragraph_length=12, batch_paragraphs=False, max_chunk_tokens=512
):
chunks.append(chunk_data)