From f8e5b529c3825481e4f056994448ead8a780d4fa Mon Sep 17 00:00:00 2001
From: Leon Luithlen
Date: Wed, 13 Nov 2024 13:58:00 +0100
Subject: [PATCH] Add maximum_length argument to chunk_by_sentence

---
 cognee/tasks/chunks/chunk_by_paragraph.py     |  3 +-
 cognee/tasks/chunks/chunk_by_sentence.py      |  5 ++-
 .../chunks/chunk_by_paragraph_test2.py        | 44 ++++++++++++++++++-
 3 files changed, 48 insertions(+), 4 deletions(-)

diff --git a/cognee/tasks/chunks/chunk_by_paragraph.py b/cognee/tasks/chunks/chunk_by_paragraph.py
index 24f55b118..11ab8dd41 100644
--- a/cognee/tasks/chunks/chunk_by_paragraph.py
+++ b/cognee/tasks/chunks/chunk_by_paragraph.py
@@ -13,7 +13,8 @@ def chunk_by_paragraph(data: str, paragraph_length: int = 1024, batch_paragraphs
     last_paragraph_id = None
     last_cut_type = None
 
-    for paragraph_id, _, sentence, word_count, end_type in chunk_by_sentence(data):
+    for paragraph_id, _, sentence, word_count, end_type in chunk_by_sentence(data, maximum_length=paragraph_length):
+        assert word_count <= paragraph_length, f"{paragraph_length = } is smaller than {word_count = }"
         # Check if this sentence would exceed length limit
         if current_word_count > 0 and current_word_count + word_count > paragraph_length:
             # Yield current chunk
diff --git a/cognee/tasks/chunks/chunk_by_sentence.py b/cognee/tasks/chunks/chunk_by_sentence.py
index 6a752caee..7191a78c4 100644
--- a/cognee/tasks/chunks/chunk_by_sentence.py
+++ b/cognee/tasks/chunks/chunk_by_sentence.py
@@ -2,9 +2,10 @@
 
 from uuid import uuid4
+from typing import Optional
 from .chunk_by_word import chunk_by_word
 
 
-def chunk_by_sentence(data: str):
+def chunk_by_sentence(data: str, maximum_length: Optional[int] = None):
     sentence = ""
     paragraph_id = uuid4()
     chunk_index = 0
@@ -14,7 +15,7 @@
         sentence += word
         word_count += 1
 
-        if word_type == "paragraph_end" or word_type == "sentence_end":
+        if word_type == "paragraph_end" or word_type == "sentence_end" or ((maximum_length is not None) and (word_count == maximum_length)):
             yield (paragraph_id, chunk_index, sentence, word_count, word_type)
             sentence = ""
             word_count = 0
diff --git a/cognee/tests/unit/processing/chunks/chunk_by_paragraph_test2.py b/cognee/tests/unit/processing/chunks/chunk_by_paragraph_test2.py
index d846fdfa2..ef75094c4 100644
--- a/cognee/tests/unit/processing/chunks/chunk_by_paragraph_test2.py
+++ b/cognee/tests/unit/processing/chunks/chunk_by_paragraph_test2.py
@@ -1,7 +1,7 @@
 import numpy as np
 import pytest
 
-from cognee.tasks.chunks import chunk_by_paragraph
+from cognee.tasks.chunks import chunk_by_paragraph, chunk_by_word
 
 from cognee.tests.unit.processing.chunks.test_input import INPUT_TEXTS
 
@@ -40,3 +40,45 @@ def test_chunk_by_paragraph_isomorphism(input_text, paragraph_length, batch_paragraphs
     assert (
         reconstructed_text == input_text
     ), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
+
+
+@pytest.mark.parametrize(
+    "input_text,paragraph_length,batch_paragraphs",
+    [
+        (INPUT_TEXTS["english_text"], 64, True),
+        (INPUT_TEXTS["english_text"], 64, False),
+        (INPUT_TEXTS["english_text"], 256, True),
+        (INPUT_TEXTS["english_text"], 256, False),
+        (INPUT_TEXTS["english_text"], 1024, True),
+        (INPUT_TEXTS["english_text"], 1024, False),
+        (INPUT_TEXTS["english_lists"], 64, True),
+        (INPUT_TEXTS["english_lists"], 64, False),
+        (INPUT_TEXTS["english_lists"], 256, True),
+        (INPUT_TEXTS["english_lists"], 256, False),
+        (INPUT_TEXTS["english_lists"], 1024, True),
+        (INPUT_TEXTS["english_lists"], 1024, False),
+        (INPUT_TEXTS["python_code"], 64, True),
(INPUT_TEXTS["python_code"], 64, False), + (INPUT_TEXTS["python_code"], 256, True), + (INPUT_TEXTS["python_code"], 256, False), + (INPUT_TEXTS["python_code"], 1024, True), + (INPUT_TEXTS["python_code"], 1024, False), + (INPUT_TEXTS["chinese_text"], 64, True), + (INPUT_TEXTS["chinese_text"], 64, False), + (INPUT_TEXTS["chinese_text"], 256, True), + (INPUT_TEXTS["chinese_text"], 256, False), + (INPUT_TEXTS["chinese_text"], 1024, True), + (INPUT_TEXTS["chinese_text"], 1024, False), + ], +) +def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs): + chunks = list(chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs)) + + chunk_lengths = np.array( + [len(list(chunk_by_word(chunk["text"]))) for chunk in chunks] + ) + + larger_chunks = chunk_lengths[chunk_lengths > paragraph_length] + assert np.all( + chunk_lengths <= paragraph_length + ), f"{paragraph_length = }: {larger_chunks} are too large"