feature: text chunker with overlap (#1732)
<!-- .github/pull_request_template.md --> ## Description <!-- Please provide a clear, human-generated description of the changes in this PR. DO NOT use AI-generated descriptions. We want to understand your thought process and reasoning. --> - Implements `TextChunkerWithOverlap` with configurable `chunk_overlap_ratio` - Abstracts chunk_data generation via `get_chunk_data` callable (defaults to `chunk_by_paragraph`) - Parametrized tests verify `TextChunker` and `TextChunkerWithOverlap` (0% overlap) produce identical output for all edge cases. - Overlap-specific tests validate `TextChunkerWithOverlap` behavior ## Type of Change <!-- Please check the relevant option --> - [ ] Bug fix (non-breaking change that fixes an issue) - [x] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Screenshots/Videos (if applicable) <!-- Add screenshots or videos to help explain your changes --> ## Pre-submission Checklist <!-- Please check all boxes that apply before submitting your PR --> - [x] **I have tested my changes thoroughly before submitting this PR** - [x] **This PR contains minimal changes necessary to address the issue/feature** - [x] My code follows the project's coding standards and style guidelines - [x] I have added tests that prove my fix is effective or that my feature works - [ ] I have added necessary documentation (if applicable) - [x] All new and existing tests pass - [x] I have searched existing PRs to ensure this change hasn't been submitted already - [ ] I have linked any relevant issues in the description - [x] My commits have clear and descriptive messages ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. --------- Co-authored-by: hajdul88 <52442977+hajdul88@users.noreply.github.com>
This commit is contained in:
parent
c0e5ce04ce
commit
5bc83968f8
3 changed files with 696 additions and 0 deletions
124
cognee/modules/chunking/text_chunker_with_overlap.py
Normal file
124
cognee/modules/chunking/text_chunker_with_overlap.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
from cognee.shared.logging_utils import get_logger
|
||||
from uuid import NAMESPACE_OID, uuid5
|
||||
|
||||
from cognee.tasks.chunks import chunk_by_paragraph
|
||||
from cognee.modules.chunking.Chunker import Chunker
|
||||
from .models.DocumentChunk import DocumentChunk
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class TextChunkerWithOverlap(Chunker):
|
||||
def __init__(
|
||||
self,
|
||||
document,
|
||||
get_text: callable,
|
||||
max_chunk_size: int,
|
||||
chunk_overlap_ratio: float = 0.0,
|
||||
get_chunk_data: callable = None,
|
||||
):
|
||||
super().__init__(document, get_text, max_chunk_size)
|
||||
self._accumulated_chunk_data = []
|
||||
self._accumulated_size = 0
|
||||
self.chunk_overlap_ratio = chunk_overlap_ratio
|
||||
self.chunk_overlap = int(max_chunk_size * chunk_overlap_ratio)
|
||||
|
||||
if get_chunk_data is not None:
|
||||
self.get_chunk_data = get_chunk_data
|
||||
elif chunk_overlap_ratio > 0:
|
||||
paragraph_max_size = int(0.5 * chunk_overlap_ratio * max_chunk_size)
|
||||
self.get_chunk_data = lambda text: chunk_by_paragraph(
|
||||
text, paragraph_max_size, batch_paragraphs=True
|
||||
)
|
||||
else:
|
||||
self.get_chunk_data = lambda text: chunk_by_paragraph(
|
||||
text, self.max_chunk_size, batch_paragraphs=True
|
||||
)
|
||||
|
||||
def _accumulation_overflows(self, chunk_data):
|
||||
"""Check if adding chunk_data would exceed max_chunk_size."""
|
||||
return self._accumulated_size + chunk_data["chunk_size"] > self.max_chunk_size
|
||||
|
||||
def _accumulate_chunk_data(self, chunk_data):
|
||||
"""Add chunk_data to the current accumulation."""
|
||||
self._accumulated_chunk_data.append(chunk_data)
|
||||
self._accumulated_size += chunk_data["chunk_size"]
|
||||
|
||||
def _clear_accumulation(self):
|
||||
"""Reset accumulation, keeping overlap chunk_data based on chunk_overlap_ratio."""
|
||||
if self.chunk_overlap == 0:
|
||||
self._accumulated_chunk_data = []
|
||||
self._accumulated_size = 0
|
||||
return
|
||||
|
||||
# Keep chunk_data from the end that fit in overlap
|
||||
overlap_chunk_data = []
|
||||
overlap_size = 0
|
||||
|
||||
for chunk_data in reversed(self._accumulated_chunk_data):
|
||||
if overlap_size + chunk_data["chunk_size"] <= self.chunk_overlap:
|
||||
overlap_chunk_data.insert(0, chunk_data)
|
||||
overlap_size += chunk_data["chunk_size"]
|
||||
else:
|
||||
break
|
||||
|
||||
self._accumulated_chunk_data = overlap_chunk_data
|
||||
self._accumulated_size = overlap_size
|
||||
|
||||
def _create_chunk(self, text, size, cut_type, chunk_id=None):
|
||||
"""Create a DocumentChunk with standard metadata."""
|
||||
try:
|
||||
return DocumentChunk(
|
||||
id=chunk_id or uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
|
||||
text=text,
|
||||
chunk_size=size,
|
||||
is_part_of=self.document,
|
||||
chunk_index=self.chunk_index,
|
||||
cut_type=cut_type,
|
||||
contains=[],
|
||||
metadata={"index_fields": ["text"]},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
raise e
|
||||
|
||||
def _create_chunk_from_accumulation(self):
|
||||
"""Create a DocumentChunk from current accumulated chunk_data."""
|
||||
chunk_text = " ".join(chunk["text"] for chunk in self._accumulated_chunk_data)
|
||||
return self._create_chunk(
|
||||
text=chunk_text,
|
||||
size=self._accumulated_size,
|
||||
cut_type=self._accumulated_chunk_data[-1]["cut_type"],
|
||||
)
|
||||
|
||||
def _emit_chunk(self, chunk_data):
|
||||
"""Emit a chunk when accumulation overflows."""
|
||||
if len(self._accumulated_chunk_data) > 0:
|
||||
chunk = self._create_chunk_from_accumulation()
|
||||
self._clear_accumulation()
|
||||
self._accumulate_chunk_data(chunk_data)
|
||||
else:
|
||||
# Handle single chunk_data exceeding max_chunk_size
|
||||
chunk = self._create_chunk(
|
||||
text=chunk_data["text"],
|
||||
size=chunk_data["chunk_size"],
|
||||
cut_type=chunk_data["cut_type"],
|
||||
chunk_id=chunk_data["chunk_id"],
|
||||
)
|
||||
|
||||
self.chunk_index += 1
|
||||
return chunk
|
||||
|
||||
async def read(self):
|
||||
async for content_text in self.get_text():
|
||||
for chunk_data in self.get_chunk_data(content_text):
|
||||
if not self._accumulation_overflows(chunk_data):
|
||||
self._accumulate_chunk_data(chunk_data)
|
||||
continue
|
||||
|
||||
yield self._emit_chunk(chunk_data)
|
||||
|
||||
if len(self._accumulated_chunk_data) == 0:
|
||||
return
|
||||
|
||||
yield self._create_chunk_from_accumulation()
|
||||
248
cognee/tests/unit/modules/chunking/test_text_chunker.py
Normal file
248
cognee/tests/unit/modules/chunking/test_text_chunker.py
Normal file
|
|
@ -0,0 +1,248 @@
|
|||
"""Unit tests for TextChunker and TextChunkerWithOverlap behavioral equivalence."""
|
||||
|
||||
import pytest
|
||||
from uuid import uuid4
|
||||
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from cognee.modules.chunking.text_chunker_with_overlap import TextChunkerWithOverlap
|
||||
from cognee.modules.data.processing.document_types import Document
|
||||
|
||||
|
||||
@pytest.fixture(params=["TextChunker", "TextChunkerWithOverlap"])
|
||||
def chunker_class(request):
|
||||
"""Parametrize tests to run against both implementations."""
|
||||
return TextChunker if request.param == "TextChunker" else TextChunkerWithOverlap
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_text_generator():
|
||||
"""Factory for async text generators."""
|
||||
|
||||
def _factory(*texts):
|
||||
async def gen():
|
||||
for text in texts:
|
||||
yield text
|
||||
|
||||
return gen
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
async def collect_chunks(chunker):
|
||||
"""Consume async generator and return list of chunks."""
|
||||
chunks = []
|
||||
async for chunk in chunker.read():
|
||||
chunks.append(chunk)
|
||||
return chunks
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_input_produces_no_chunks(chunker_class, make_text_generator):
|
||||
"""Empty input should yield no chunks."""
|
||||
document = Document(
|
||||
id=uuid4(),
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
get_text = make_text_generator("")
|
||||
chunker = chunker_class(document, get_text, max_chunk_size=512)
|
||||
chunks = await collect_chunks(chunker)
|
||||
|
||||
assert len(chunks) == 0, "Empty input should produce no chunks"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_whitespace_only_input_emits_single_chunk(chunker_class, make_text_generator):
|
||||
"""Whitespace-only input should produce exactly one chunk with unchanged text."""
|
||||
whitespace_text = " \n\t \r\n "
|
||||
document = Document(
|
||||
id=uuid4(),
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
get_text = make_text_generator(whitespace_text)
|
||||
chunker = chunker_class(document, get_text, max_chunk_size=512)
|
||||
chunks = await collect_chunks(chunker)
|
||||
|
||||
assert len(chunks) == 1, "Whitespace-only input should produce exactly one chunk"
|
||||
assert chunks[0].text == whitespace_text, "Chunk text should equal input (whitespace preserved)"
|
||||
assert chunks[0].chunk_index == 0, "First chunk should have index 0"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_paragraph_below_limit_emits_one_chunk(chunker_class, make_text_generator):
|
||||
"""Single paragraph below limit should emit exactly one chunk."""
|
||||
text = "This is a short paragraph."
|
||||
document = Document(
|
||||
id=uuid4(),
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
get_text = make_text_generator(text)
|
||||
chunker = chunker_class(document, get_text, max_chunk_size=512)
|
||||
chunks = await collect_chunks(chunker)
|
||||
|
||||
assert len(chunks) == 1, "Single short paragraph should produce exactly one chunk"
|
||||
assert chunks[0].text == text, "Chunk text should match input"
|
||||
assert chunks[0].chunk_index == 0, "First chunk should have index 0"
|
||||
assert chunks[0].chunk_size > 0, "Chunk should have positive size"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oversized_paragraph_gets_emitted_as_a_single_chunk(
|
||||
chunker_class, make_text_generator
|
||||
):
|
||||
"""Oversized paragraph from chunk_by_paragraph should be emitted as single chunk."""
|
||||
text = ("A" * 1500) + ". Next sentence."
|
||||
document = Document(
|
||||
id=uuid4(),
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
get_text = make_text_generator(text)
|
||||
chunker = chunker_class(document, get_text, max_chunk_size=50)
|
||||
chunks = await collect_chunks(chunker)
|
||||
|
||||
assert len(chunks) == 2, "Should produce 2 chunks (oversized paragraph + next sentence)"
|
||||
assert chunks[0].chunk_size > 50, "First chunk should be oversized"
|
||||
assert chunks[0].chunk_index == 0, "First chunk should have index 0"
|
||||
assert chunks[1].chunk_index == 1, "Second chunk should have index 1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_overflow_on_next_paragraph_emits_separate_chunk(chunker_class, make_text_generator):
|
||||
"""First paragraph near limit plus small paragraph should produce two separate chunks."""
|
||||
first_para = " ".join(["word"] * 5)
|
||||
second_para = "Short text."
|
||||
text = first_para + " " + second_para
|
||||
document = Document(
|
||||
id=uuid4(),
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
get_text = make_text_generator(text)
|
||||
chunker = chunker_class(document, get_text, max_chunk_size=10)
|
||||
chunks = await collect_chunks(chunker)
|
||||
|
||||
assert len(chunks) == 2, "Should produce 2 chunks due to overflow"
|
||||
assert chunks[0].text.strip() == first_para, "First chunk should contain only first paragraph"
|
||||
assert chunks[1].text.strip() == second_para, (
|
||||
"Second chunk should contain only second paragraph"
|
||||
)
|
||||
assert chunks[0].chunk_index == 0, "First chunk should have index 0"
|
||||
assert chunks[1].chunk_index == 1, "Second chunk should have index 1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_small_paragraphs_batch_correctly(chunker_class, make_text_generator):
|
||||
"""Multiple small paragraphs should batch together with joiner spaces counted."""
|
||||
paragraphs = [" ".join(["word"] * 12) for _ in range(40)]
|
||||
text = " ".join(paragraphs)
|
||||
document = Document(
|
||||
id=uuid4(),
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
get_text = make_text_generator(text)
|
||||
chunker = chunker_class(document, get_text, max_chunk_size=49)
|
||||
chunks = await collect_chunks(chunker)
|
||||
|
||||
assert len(chunks) == 20, (
|
||||
"Should batch paragraphs (2 per chunk: 12 words × 2 tokens = 24, 24 + 1 joiner + 24 = 49)"
|
||||
)
|
||||
assert all(c.chunk_index == i for i, c in enumerate(chunks)), (
|
||||
"Chunk indices should be sequential"
|
||||
)
|
||||
all_text = " ".join(chunk.text.strip() for chunk in chunks)
|
||||
expected_text = " ".join(paragraphs)
|
||||
assert all_text == expected_text, "All paragraph text should be preserved with correct spacing"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_alternating_large_and_small_paragraphs_dont_batch(
|
||||
chunker_class, make_text_generator
|
||||
):
|
||||
"""Alternating near-max and small paragraphs should each become separate chunks."""
|
||||
large1 = "word" * 15 + "."
|
||||
small1 = "Short."
|
||||
large2 = "word" * 15 + "."
|
||||
small2 = "Tiny."
|
||||
text = large1 + " " + small1 + " " + large2 + " " + small2
|
||||
document = Document(
|
||||
id=uuid4(),
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
max_chunk_size = 10
|
||||
get_text = make_text_generator(text)
|
||||
chunker = chunker_class(document, get_text, max_chunk_size=max_chunk_size)
|
||||
chunks = await collect_chunks(chunker)
|
||||
|
||||
assert len(chunks) == 4, "Should produce multiple chunks"
|
||||
assert all(c.chunk_index == i for i, c in enumerate(chunks)), (
|
||||
"Chunk indices should be sequential"
|
||||
)
|
||||
assert chunks[0].chunk_size > max_chunk_size, (
|
||||
"First chunk should be oversized (large paragraph)"
|
||||
)
|
||||
assert chunks[1].chunk_size <= max_chunk_size, "Second chunk should be small (small paragraph)"
|
||||
assert chunks[2].chunk_size > max_chunk_size, (
|
||||
"Third chunk should be oversized (large paragraph)"
|
||||
)
|
||||
assert chunks[3].chunk_size <= max_chunk_size, "Fourth chunk should be small (small paragraph)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chunk_indices_and_ids_are_deterministic(chunker_class, make_text_generator):
|
||||
"""Running chunker twice on identical input should produce identical indices and IDs."""
|
||||
sentence1 = "one " * 4 + ". "
|
||||
sentence2 = "two " * 4 + ". "
|
||||
sentence3 = "one " * 4 + ". "
|
||||
sentence4 = "two " * 4 + ". "
|
||||
text = sentence1 + sentence2 + sentence3 + sentence4
|
||||
doc_id = uuid4()
|
||||
max_chunk_size = 20
|
||||
|
||||
document1 = Document(
|
||||
id=doc_id,
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
get_text1 = make_text_generator(text)
|
||||
chunker1 = chunker_class(document1, get_text1, max_chunk_size=max_chunk_size)
|
||||
chunks1 = await collect_chunks(chunker1)
|
||||
|
||||
document2 = Document(
|
||||
id=doc_id,
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
get_text2 = make_text_generator(text)
|
||||
chunker2 = chunker_class(document2, get_text2, max_chunk_size=max_chunk_size)
|
||||
chunks2 = await collect_chunks(chunker2)
|
||||
|
||||
assert len(chunks1) == 2, "Should produce exactly 2 chunks (4 sentences, 2 per chunk)"
|
||||
assert len(chunks2) == 2, "Should produce exactly 2 chunks (4 sentences, 2 per chunk)"
|
||||
assert [c.chunk_index for c in chunks1] == [0, 1], "First run indices should be [0, 1]"
|
||||
assert [c.chunk_index for c in chunks2] == [0, 1], "Second run indices should be [0, 1]"
|
||||
assert chunks1[0].id == chunks2[0].id, "First chunk ID should be deterministic"
|
||||
assert chunks1[1].id == chunks2[1].id, "Second chunk ID should be deterministic"
|
||||
assert chunks1[0].id != chunks1[1].id, "Chunk IDs should be unique within a run"
|
||||
|
|
@ -0,0 +1,324 @@
|
|||
"""Unit tests for TextChunkerWithOverlap overlap behavior."""
|
||||
|
||||
import sys
|
||||
import pytest
|
||||
from uuid import uuid4
|
||||
from unittest.mock import patch
|
||||
|
||||
from cognee.modules.chunking.text_chunker_with_overlap import TextChunkerWithOverlap
|
||||
from cognee.modules.data.processing.document_types import Document
|
||||
from cognee.tasks.chunks import chunk_by_paragraph
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_text_generator():
|
||||
"""Factory for async text generators."""
|
||||
|
||||
def _factory(*texts):
|
||||
async def gen():
|
||||
for text in texts:
|
||||
yield text
|
||||
|
||||
return gen
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_controlled_chunk_data():
|
||||
"""Factory for controlled chunk_data generators."""
|
||||
|
||||
def _factory(*sentences, chunk_size_per_sentence=10):
|
||||
def _chunk_data(text):
|
||||
return [
|
||||
{
|
||||
"text": sentence,
|
||||
"chunk_size": chunk_size_per_sentence,
|
||||
"cut_type": "sentence",
|
||||
"chunk_id": uuid4(),
|
||||
}
|
||||
for sentence in sentences
|
||||
]
|
||||
|
||||
return _chunk_data
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_half_overlap_preserves_content_across_chunks(
|
||||
make_text_generator, make_controlled_chunk_data
|
||||
):
|
||||
"""With 50% overlap, consecutive chunks should share half their content."""
|
||||
s1 = "one"
|
||||
s2 = "two"
|
||||
s3 = "three"
|
||||
s4 = "four"
|
||||
text = "dummy"
|
||||
document = Document(
|
||||
id=uuid4(),
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
get_text = make_text_generator(text)
|
||||
get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, chunk_size_per_sentence=10)
|
||||
chunker = TextChunkerWithOverlap(
|
||||
document,
|
||||
get_text,
|
||||
max_chunk_size=20,
|
||||
chunk_overlap_ratio=0.5,
|
||||
get_chunk_data=get_chunk_data,
|
||||
)
|
||||
chunks = [chunk async for chunk in chunker.read()]
|
||||
|
||||
assert len(chunks) == 3, "Should produce exactly 3 chunks (s1+s2, s2+s3, s3+s4)"
|
||||
assert [c.chunk_index for c in chunks] == [0, 1, 2], "Chunk indices should be [0, 1, 2]"
|
||||
assert "one" in chunks[0].text and "two" in chunks[0].text, "Chunk 0 should contain s1 and s2"
|
||||
assert "two" in chunks[1].text and "three" in chunks[1].text, (
|
||||
"Chunk 1 should contain s2 (overlap) and s3"
|
||||
)
|
||||
assert "three" in chunks[2].text and "four" in chunks[2].text, (
|
||||
"Chunk 2 should contain s3 (overlap) and s4"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zero_overlap_produces_no_duplicate_content(
|
||||
make_text_generator, make_controlled_chunk_data
|
||||
):
|
||||
"""With 0% overlap, no content should appear in multiple chunks."""
|
||||
s1 = "one"
|
||||
s2 = "two"
|
||||
s3 = "three"
|
||||
s4 = "four"
|
||||
text = "dummy"
|
||||
document = Document(
|
||||
id=uuid4(),
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
get_text = make_text_generator(text)
|
||||
get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, chunk_size_per_sentence=10)
|
||||
chunker = TextChunkerWithOverlap(
|
||||
document,
|
||||
get_text,
|
||||
max_chunk_size=20,
|
||||
chunk_overlap_ratio=0.0,
|
||||
get_chunk_data=get_chunk_data,
|
||||
)
|
||||
chunks = [chunk async for chunk in chunker.read()]
|
||||
|
||||
assert len(chunks) == 2, "Should produce exactly 2 chunks (s1+s2, s3+s4)"
|
||||
assert "one" in chunks[0].text and "two" in chunks[0].text, (
|
||||
"First chunk should contain s1 and s2"
|
||||
)
|
||||
assert "three" in chunks[1].text and "four" in chunks[1].text, (
|
||||
"Second chunk should contain s3 and s4"
|
||||
)
|
||||
assert "two" not in chunks[1].text and "three" not in chunks[0].text, (
|
||||
"No overlap: end of chunk 0 should not appear in chunk 1"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_small_overlap_ratio_creates_minimal_overlap(
|
||||
make_text_generator, make_controlled_chunk_data
|
||||
):
|
||||
"""With 25% overlap ratio, chunks should have minimal overlap."""
|
||||
s1 = "alpha"
|
||||
s2 = "beta"
|
||||
s3 = "gamma"
|
||||
s4 = "delta"
|
||||
s5 = "epsilon"
|
||||
text = "dummy"
|
||||
document = Document(
|
||||
id=uuid4(),
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
get_text = make_text_generator(text)
|
||||
get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, s5, chunk_size_per_sentence=10)
|
||||
chunker = TextChunkerWithOverlap(
|
||||
document,
|
||||
get_text,
|
||||
max_chunk_size=40,
|
||||
chunk_overlap_ratio=0.25,
|
||||
get_chunk_data=get_chunk_data,
|
||||
)
|
||||
chunks = [chunk async for chunk in chunker.read()]
|
||||
|
||||
assert len(chunks) == 2, "Should produce exactly 2 chunks"
|
||||
assert [c.chunk_index for c in chunks] == [0, 1], "Chunk indices should be [0, 1]"
|
||||
assert all(token in chunks[0].text for token in [s1, s2, s3, s4]), (
|
||||
"Chunk 0 should contain s1 through s4"
|
||||
)
|
||||
assert s4 in chunks[1].text and s5 in chunks[1].text, (
|
||||
"Chunk 1 should contain overlap s4 and new content s5"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_high_overlap_ratio_creates_significant_overlap(
|
||||
make_text_generator, make_controlled_chunk_data
|
||||
):
|
||||
"""With 75% overlap ratio, consecutive chunks should share most content."""
|
||||
s1 = "red"
|
||||
s2 = "blue"
|
||||
s3 = "green"
|
||||
s4 = "yellow"
|
||||
s5 = "purple"
|
||||
text = "dummy"
|
||||
document = Document(
|
||||
id=uuid4(),
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
get_text = make_text_generator(text)
|
||||
get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, s5, chunk_size_per_sentence=5)
|
||||
chunker = TextChunkerWithOverlap(
|
||||
document,
|
||||
get_text,
|
||||
max_chunk_size=20,
|
||||
chunk_overlap_ratio=0.75,
|
||||
get_chunk_data=get_chunk_data,
|
||||
)
|
||||
chunks = [chunk async for chunk in chunker.read()]
|
||||
|
||||
assert len(chunks) == 2, "Should produce exactly 2 chunks with 75% overlap"
|
||||
assert [c.chunk_index for c in chunks] == [0, 1], "Chunk indices should be [0, 1]"
|
||||
assert all(token in chunks[0].text for token in [s1, s2, s3, s4]), (
|
||||
"Chunk 0 should contain s1, s2, s3, s4"
|
||||
)
|
||||
assert all(token in chunks[1].text for token in [s2, s3, s4, s5]), (
|
||||
"Chunk 1 should contain s2, s3, s4 (overlap) and s5"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_chunk_no_dangling_overlap(make_text_generator, make_controlled_chunk_data):
|
||||
"""Text that fits in one chunk should produce exactly one chunk, no overlap artifact."""
|
||||
s1 = "alpha"
|
||||
s2 = "beta"
|
||||
text = "dummy"
|
||||
document = Document(
|
||||
id=uuid4(),
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
get_text = make_text_generator(text)
|
||||
get_chunk_data = make_controlled_chunk_data(s1, s2, chunk_size_per_sentence=10)
|
||||
chunker = TextChunkerWithOverlap(
|
||||
document,
|
||||
get_text,
|
||||
max_chunk_size=20,
|
||||
chunk_overlap_ratio=0.5,
|
||||
get_chunk_data=get_chunk_data,
|
||||
)
|
||||
chunks = [chunk async for chunk in chunker.read()]
|
||||
|
||||
assert len(chunks) == 1, (
|
||||
"Should produce exactly 1 chunk when content fits within max_chunk_size"
|
||||
)
|
||||
assert chunks[0].chunk_index == 0, "Single chunk should have index 0"
|
||||
assert "alpha" in chunks[0].text and "beta" in chunks[0].text, (
|
||||
"Single chunk should contain all content"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_paragraph_chunking_with_overlap(make_text_generator):
|
||||
"""Test that chunk_by_paragraph integration produces 25% overlap between chunks."""
|
||||
|
||||
def mock_get_embedding_engine():
|
||||
class MockEngine:
|
||||
tokenizer = None
|
||||
|
||||
return MockEngine()
|
||||
|
||||
chunk_by_sentence_module = sys.modules.get("cognee.tasks.chunks.chunk_by_sentence")
|
||||
|
||||
max_chunk_size = 20
|
||||
overlap_ratio = 0.25 # 5 token overlap
|
||||
paragraph_max_size = int(0.5 * overlap_ratio * max_chunk_size) # = 2
|
||||
|
||||
text = (
|
||||
"A0 A1. A2 A3. A4 A5. A6 A7. A8 A9. " # 10 tokens (0-9)
|
||||
"B0 B1. B2 B3. B4 B5. B6 B7. B8 B9. " # 10 tokens (10-19)
|
||||
"C0 C1. C2 C3. C4 C5. C6 C7. C8 C9. " # 10 tokens (20-29)
|
||||
"D0 D1. D2 D3. D4 D5. D6 D7. D8 D9. " # 10 tokens (30-39)
|
||||
"E0 E1. E2 E3. E4 E5. E6 E7. E8 E9." # 10 tokens (40-49)
|
||||
)
|
||||
|
||||
document = Document(
|
||||
id=uuid4(),
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
|
||||
get_text = make_text_generator(text)
|
||||
|
||||
def get_chunk_data(text_input):
|
||||
return chunk_by_paragraph(
|
||||
text_input, max_chunk_size=paragraph_max_size, batch_paragraphs=True
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
chunk_by_sentence_module, "get_embedding_engine", side_effect=mock_get_embedding_engine
|
||||
):
|
||||
chunker = TextChunkerWithOverlap(
|
||||
document,
|
||||
get_text,
|
||||
max_chunk_size=max_chunk_size,
|
||||
chunk_overlap_ratio=overlap_ratio,
|
||||
get_chunk_data=get_chunk_data,
|
||||
)
|
||||
chunks = [chunk async for chunk in chunker.read()]
|
||||
|
||||
assert len(chunks) == 3, f"Should produce exactly 3 chunks, got {len(chunks)}"
|
||||
|
||||
assert chunks[0].chunk_index == 0, "First chunk should have index 0"
|
||||
assert chunks[1].chunk_index == 1, "Second chunk should have index 1"
|
||||
assert chunks[2].chunk_index == 2, "Third chunk should have index 2"
|
||||
|
||||
assert "A0" in chunks[0].text, "Chunk 0 should start with A0"
|
||||
assert "A9" in chunks[0].text, "Chunk 0 should contain A9"
|
||||
assert "B0" in chunks[0].text, "Chunk 0 should contain B0"
|
||||
assert "B9" in chunks[0].text, "Chunk 0 should contain up to B9 (20 tokens)"
|
||||
|
||||
assert "B" in chunks[1].text, "Chunk 1 should have overlap from B section"
|
||||
assert "C" in chunks[1].text, "Chunk 1 should contain C section"
|
||||
assert "D" in chunks[1].text, "Chunk 1 should contain D section"
|
||||
|
||||
assert "D" in chunks[2].text, "Chunk 2 should have overlap from D section"
|
||||
assert "E0" in chunks[2].text, "Chunk 2 should contain E0"
|
||||
assert "E9" in chunks[2].text, "Chunk 2 should end with E9"
|
||||
|
||||
chunk_0_end_words = chunks[0].text.split()[-4:]
|
||||
chunk_1_words = chunks[1].text.split()
|
||||
overlap_0_1 = any(word in chunk_1_words for word in chunk_0_end_words)
|
||||
assert overlap_0_1, (
|
||||
f"No overlap detected between chunks 0 and 1. "
|
||||
f"Chunk 0 ends with: {chunk_0_end_words}, "
|
||||
f"Chunk 1 starts with: {chunk_1_words[:6]}"
|
||||
)
|
||||
|
||||
chunk_1_end_words = chunks[1].text.split()[-4:]
|
||||
chunk_2_words = chunks[2].text.split()
|
||||
overlap_1_2 = any(word in chunk_2_words for word in chunk_1_end_words)
|
||||
assert overlap_1_2, (
|
||||
f"No overlap detected between chunks 1 and 2. "
|
||||
f"Chunk 1 ends with: {chunk_1_end_words}, "
|
||||
f"Chunk 2 starts with: {chunk_2_words[:6]}"
|
||||
)
|
||||
Loading…
Add table
Reference in a new issue