Add maximum_length argument to chunk_sentences
parent ef7a19043d
commit f8e5b529c3

3 changed files with 48 additions and 4 deletions
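In short: chunk_by_sentence gains a maximum_length argument that force-cuts a sentence once it reaches that many words; chunk_by_paragraph forwards its paragraph_length as that bound and asserts that no sentence it receives exceeds it; and a new parametrized test checks the resulting chunk sizes across four input corpora, three length budgets, and both batch_paragraphs modes.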
cognee/tasks/chunks/chunk_by_paragraph.py

@@ -13,7 +13,8 @@ def chunk_by_paragraph(data: str, paragraph_length: int = 1024, batch_paragraphs
     last_paragraph_id = None
     last_cut_type = None
 
-    for paragraph_id, _, sentence, word_count, end_type in chunk_by_sentence(data):
+    for paragraph_id, _, sentence, word_count, end_type in chunk_by_sentence(data, maximum_length=paragraph_length):
+        assert word_count <= paragraph_length, f"{paragraph_length = } is smaller than {word_count = }"
         # Check if this sentence would exceed length limit
         if current_word_count > 0 and current_word_count + word_count > paragraph_length:
             # Yield current chunk
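The point of forwarding paragraph_length down as maximum_length is that a single over-long sentence can no longer blow past the paragraph budget; the new in-loop assert makes that invariant explicit. A minimal sketch of the guarantee, mirroring the new test further down (the input string is made up; the imports and the chunk["text"] key follow this commit's test file):

    from cognee.tasks.chunks import chunk_by_paragraph, chunk_by_word

    # Made-up worst case: a long run of words with no sentence-ending
    # punctuation, so only the new maximum_length cut can bound chunk size.
    text = "lorem " * 300

    for chunk in chunk_by_paragraph(text, paragraph_length=64, batch_paragraphs=False):
        # Same bound the new assert enforces, measured in the chunker's
        # own word units rather than characters.
        assert len(list(chunk_by_word(chunk["text"]))) <= 64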
cognee/tasks/chunks/chunk_by_sentence.py

@@ -2,9 +2,10 @@
 
 
 from uuid import uuid4
+from typing import Optional
 from .chunk_by_word import chunk_by_word
 
-def chunk_by_sentence(data: str):
+def chunk_by_sentence(data: str, maximum_length: Optional[int] = None):
     sentence = ""
     paragraph_id = uuid4()
     chunk_index = 0

@@ -14,7 +15,7 @@ def chunk_by_sentence(data: str):
         sentence += word
         word_count += 1
 
-        if word_type == "paragraph_end" or word_type == "sentence_end":
+        if word_type == "paragraph_end" or word_type == "sentence_end" or ((maximum_length is not None) and (word_count == maximum_length)):
             yield (paragraph_id, chunk_index, sentence, word_count, word_type)
             sentence = ""
             word_count = 0
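With the new parameter, a cut can now fire mid-sentence: once word_count reaches maximum_length, the buffered text is yielded with whatever word_type the last word carried, rather than waiting for a sentence_end. A hedged sketch of calling the generator directly (the sample text is made up; the tuple layout follows this diff, and chunk_by_sentence is assumed to be re-exported from cognee.tasks.chunks like its siblings):

    # Assumption: the chunks package re-exports chunk_by_sentence the same
    # way it re-exports chunk_by_paragraph and chunk_by_word.
    from cognee.tasks.chunks import chunk_by_sentence

    # Made-up run-on text with no terminating punctuation.
    text = "one two three four five six seven eight"

    for paragraph_id, chunk_index, sentence, word_count, word_type in chunk_by_sentence(
        text, maximum_length=3
    ):
        # word_count never exceeds 3: the forced cut fires exactly when
        # word_count == maximum_length.
        print(chunk_index, repr(sentence), word_count, word_type)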
cognee/tests/unit/processing/chunks/chunk_by_paragraph_test.py

@@ -1,7 +1,7 @@
 import numpy as np
 import pytest
 
-from cognee.tasks.chunks import chunk_by_paragraph
+from cognee.tasks.chunks import chunk_by_paragraph, chunk_by_word
 from cognee.tests.unit.processing.chunks.test_input import INPUT_TEXTS
 
 

@@ -40,3 +40,45 @@ def test_chunk_by_paragraph_isomorphism(input_text, paragraph_length, batch_para
     assert (
         reconstructed_text == input_text
     ), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
+
+
+@pytest.mark.parametrize(
+    "input_text,paragraph_length,batch_paragraphs",
+    [
+        (INPUT_TEXTS["english_text"], 64, True),
+        (INPUT_TEXTS["english_text"], 64, False),
+        (INPUT_TEXTS["english_text"], 256, True),
+        (INPUT_TEXTS["english_text"], 256, False),
+        (INPUT_TEXTS["english_text"], 1024, True),
+        (INPUT_TEXTS["english_text"], 1024, False),
+        (INPUT_TEXTS["english_lists"], 64, True),
+        (INPUT_TEXTS["english_lists"], 64, False),
+        (INPUT_TEXTS["english_lists"], 256, True),
+        (INPUT_TEXTS["english_lists"], 256, False),
+        (INPUT_TEXTS["english_lists"], 1024, True),
+        (INPUT_TEXTS["english_lists"], 1024, False),
+        (INPUT_TEXTS["python_code"], 64, True),
+        (INPUT_TEXTS["python_code"], 64, False),
+        (INPUT_TEXTS["python_code"], 256, True),
+        (INPUT_TEXTS["python_code"], 256, False),
+        (INPUT_TEXTS["python_code"], 1024, True),
+        (INPUT_TEXTS["python_code"], 1024, False),
+        (INPUT_TEXTS["chinese_text"], 64, True),
+        (INPUT_TEXTS["chinese_text"], 64, False),
+        (INPUT_TEXTS["chinese_text"], 256, True),
+        (INPUT_TEXTS["chinese_text"], 256, False),
+        (INPUT_TEXTS["chinese_text"], 1024, True),
+        (INPUT_TEXTS["chinese_text"], 1024, False),
+    ],
+)
+def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs):
+    chunks = list(chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs))
+
+    chunk_lengths = np.array(
+        [len(list(chunk_by_word(chunk["text"]))) for chunk in chunks]
+    )
+
+    larger_chunks = chunk_lengths[chunk_lengths > paragraph_length]
+    assert np.all(
+        chunk_lengths <= paragraph_length
+    ), f"{paragraph_length = }: {larger_chunks} are too large"
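A small design note on the failure message: the test filters chunk_lengths with a boolean mask before asserting, so a failing run reports exactly the offending sizes instead of the whole array. The same idiom in isolation (the numbers are made up):

    import numpy as np

    chunk_lengths = np.array([12, 80, 33, 91])
    paragraph_length = 64

    # Boolean masking keeps only the entries that break the budget,
    # which is what ends up in the assertion message.
    print(chunk_lengths[chunk_lengths > paragraph_length])  # [80 91]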