feat: externalize chunkers [cog-1354] (#547)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Enhanced document chunk extraction for improved processing consistency
across multiple formats.

- **Refactor**
- Streamlined text-chunking configuration by replacing indirect
string-to-class mappings with direct chunker-class instantiation across document types.
- Updated method signatures across various document classes to accept
chunker class references instead of string identifiers.

- **Chores**
- Removed the legacy `ChunkerMapping`/`ChunkerConfig` lookup utility, since
chunker classes are now passed directly instead of resolved by name.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Boris <boris@topoteretes.com>
This commit is contained in:
alekszievr 2025-02-19 13:26:11 +01:00 committed by GitHub
parent b10aef7b25
commit 2a167fa1ab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 35 additions and 57 deletions

View file

@ -25,6 +25,7 @@ from cognee.tasks.documents import (
from cognee.tasks.graph import extract_graph_from_data from cognee.tasks.graph import extract_graph_from_data
from cognee.tasks.storage import add_data_points from cognee.tasks.storage import add_data_points
from cognee.tasks.summarization import summarize_text from cognee.tasks.summarization import summarize_text
from cognee.modules.chunking.TextChunker import TextChunker
logger = logging.getLogger("cognify.v2") logger = logging.getLogger("cognify.v2")
@ -123,7 +124,9 @@ async def get_default_tasks(
Task(classify_documents), Task(classify_documents),
Task(check_permissions_on_documents, user=user, permissions=["write"]), Task(check_permissions_on_documents, user=user, permissions=["write"]),
Task( Task(
extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens() extract_chunks_from_documents,
max_chunk_tokens=get_max_chunk_tokens(),
chunker=TextChunker,
), # Extract text chunks based on the document type. ), # Extract text chunks based on the document type.
Task( Task(
extract_graph_from_data, graph_model=graph_model, task_config={"batch_size": 10} extract_graph_from_data, graph_model=graph_model, task_config={"batch_size": 10}

View file

@ -1,8 +1,5 @@
from typing import Optional
from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.get_llm_client import get_llm_client
from .ChunkerMapping import ChunkerConfig
from .Document import Document from .Document import Document
@ -13,13 +10,12 @@ class AudioDocument(Document):
result = get_llm_client().create_transcript(self.raw_data_location) result = get_llm_client().create_transcript(self.raw_data_location)
return result.text return result.text
def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int): def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int):
# Transcribe the audio file # Transcribe the audio file
text = self.create_transcript() text = self.create_transcript()
chunker_func = ChunkerConfig.get_chunker(chunker) chunker = chunker_cls(
chunker = chunker_func(
self, chunk_size=chunk_size, get_text=lambda: [text], max_chunk_tokens=max_chunk_tokens self, chunk_size=chunk_size, get_text=lambda: [text], max_chunk_tokens=max_chunk_tokens
) )

View file

@ -1,14 +0,0 @@
from cognee.modules.chunking.TextChunker import TextChunker
class ChunkerConfig:
chunker_mapping = {"text_chunker": TextChunker}
@classmethod
def get_chunker(cls, chunker_name: str):
chunker_class = cls.chunker_mapping.get(chunker_name)
if chunker_class is None:
raise NotImplementedError(
f"Chunker '{chunker_name}' is not implemented. Available options: {list(cls.chunker_mapping.keys())}"
)
return chunker_class

View file

@ -9,5 +9,7 @@ class Document(DataPoint):
mime_type: str mime_type: str
metadata: dict = {"index_fields": ["name"]} metadata: dict = {"index_fields": ["name"]}
def read(self, chunk_size: int, chunker=str, max_chunk_tokens: Optional[int] = None) -> str: def read(
self, chunk_size: int, chunker_cls: type, max_chunk_tokens: Optional[int] = None
) -> str:
pass pass

View file

@ -1,8 +1,5 @@
from typing import Optional
from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.get_llm_client import get_llm_client
from .ChunkerMapping import ChunkerConfig
from .Document import Document from .Document import Document
@ -13,12 +10,11 @@ class ImageDocument(Document):
result = get_llm_client().transcribe_image(self.raw_data_location) result = get_llm_client().transcribe_image(self.raw_data_location)
return result.choices[0].message.content return result.choices[0].message.content
def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int): def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int):
# Transcribe the image file # Transcribe the image file
text = self.transcribe_image() text = self.transcribe_image()
chunker_func = ChunkerConfig.get_chunker(chunker) chunker = chunker_cls(
chunker = chunker_func(
self, chunk_size=chunk_size, get_text=lambda: [text], max_chunk_tokens=max_chunk_tokens self, chunk_size=chunk_size, get_text=lambda: [text], max_chunk_tokens=max_chunk_tokens
) )

View file

@ -1,15 +1,12 @@
from typing import Optional
from pypdf import PdfReader from pypdf import PdfReader
from .ChunkerMapping import ChunkerConfig
from .Document import Document from .Document import Document
class PdfDocument(Document): class PdfDocument(Document):
type: str = "pdf" type: str = "pdf"
def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int): def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int):
file = PdfReader(self.raw_data_location) file = PdfReader(self.raw_data_location)
def get_text(): def get_text():
@ -17,8 +14,7 @@ class PdfDocument(Document):
page_text = page.extract_text() page_text = page.extract_text()
yield page_text yield page_text
chunker_func = ChunkerConfig.get_chunker(chunker) chunker = chunker_cls(
chunker = chunker_func(
self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
) )

View file

@ -1,11 +1,10 @@
from .ChunkerMapping import ChunkerConfig
from .Document import Document from .Document import Document
class TextDocument(Document): class TextDocument(Document):
type: str = "text" type: str = "text"
def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int): def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int):
def get_text(): def get_text():
with open(self.raw_data_location, mode="r", encoding="utf-8") as file: with open(self.raw_data_location, mode="r", encoding="utf-8") as file:
while True: while True:
@ -16,9 +15,7 @@ class TextDocument(Document):
yield text yield text
chunker_func = ChunkerConfig.get_chunker(chunker) chunker = chunker_cls(
chunker = chunker_func(
self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
) )

View file

@ -1,5 +1,4 @@
from io import StringIO from io import StringIO
from typing import Optional
from cognee.modules.chunking.TextChunker import TextChunker from cognee.modules.chunking.TextChunker import TextChunker
from cognee.modules.data.exceptions import UnstructuredLibraryImportError from cognee.modules.data.exceptions import UnstructuredLibraryImportError
@ -10,7 +9,7 @@ from .Document import Document
class UnstructuredDocument(Document): class UnstructuredDocument(Document):
type: str = "unstructured" type: str = "unstructured"
def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int) -> str: def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int) -> str:
def get_text(): def get_text():
try: try:
from unstructured.partition.auto import partition from unstructured.partition.auto import partition
@ -29,7 +28,7 @@ class UnstructuredDocument(Document):
yield text yield text
chunker = TextChunker( chunker = chunker_cls(
self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
) )

View file

@ -5,6 +5,7 @@ from sqlalchemy import select
from cognee.modules.data.models import Data from cognee.modules.data.models import Data
from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.relational import get_relational_engine
from uuid import UUID from uuid import UUID
from cognee.modules.chunking.TextChunker import TextChunker
async def update_document_token_count(document_id: UUID, token_count: int) -> None: async def update_document_token_count(document_id: UUID, token_count: int) -> None:
@ -26,7 +27,7 @@ async def extract_chunks_from_documents(
documents: list[Document], documents: list[Document],
max_chunk_tokens: int, max_chunk_tokens: int,
chunk_size: int = 1024, chunk_size: int = 1024,
chunker="text_chunker", chunker=TextChunker,
) -> AsyncGenerator: ) -> AsyncGenerator:
""" """
Extracts chunks of data from a list of documents based on the specified chunking parameters. Extracts chunks of data from a list of documents based on the specified chunking parameters.
@ -38,7 +39,7 @@ async def extract_chunks_from_documents(
for document in documents: for document in documents:
document_token_count = 0 document_token_count = 0
for document_chunk in document.read( for document_chunk in document.read(
chunk_size=chunk_size, chunker=chunker, max_chunk_tokens=max_chunk_tokens chunk_size=chunk_size, chunker_cls=chunker, max_chunk_tokens=max_chunk_tokens
): ):
document_token_count += document_chunk.token_count document_token_count += document_chunk.token_count
yield document_chunk yield document_chunk

View file

@ -1,6 +1,6 @@
import uuid import uuid
from unittest.mock import patch from unittest.mock import patch
from cognee.modules.chunking.TextChunker import TextChunker
from cognee.modules.data.processing.document_types.AudioDocument import AudioDocument from cognee.modules.data.processing.document_types.AudioDocument import AudioDocument
GROUND_TRUTH = [ GROUND_TRUTH = [
@ -34,7 +34,8 @@ def test_AudioDocument():
) )
with patch.object(AudioDocument, "create_transcript", return_value=TEST_TEXT): with patch.object(AudioDocument, "create_transcript", return_value=TEST_TEXT):
for ground_truth, paragraph_data in zip( for ground_truth, paragraph_data in zip(
GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker", max_chunk_tokens=512) GROUND_TRUTH,
document.read(chunk_size=64, chunker_cls=TextChunker, max_chunk_tokens=512),
): ):
assert ground_truth["word_count"] == paragraph_data.word_count, ( assert ground_truth["word_count"] == paragraph_data.word_count, (
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }' f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'

View file

@ -1,6 +1,6 @@
import uuid import uuid
from unittest.mock import patch from unittest.mock import patch
from cognee.modules.chunking.TextChunker import TextChunker
from cognee.modules.data.processing.document_types.ImageDocument import ImageDocument from cognee.modules.data.processing.document_types.ImageDocument import ImageDocument
GROUND_TRUTH = [ GROUND_TRUTH = [
@ -23,7 +23,8 @@ def test_ImageDocument():
) )
with patch.object(ImageDocument, "transcribe_image", return_value=TEST_TEXT): with patch.object(ImageDocument, "transcribe_image", return_value=TEST_TEXT):
for ground_truth, paragraph_data in zip( for ground_truth, paragraph_data in zip(
GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker", max_chunk_tokens=512) GROUND_TRUTH,
document.read(chunk_size=64, chunker_cls=TextChunker, max_chunk_tokens=512),
): ):
assert ground_truth["word_count"] == paragraph_data.word_count, ( assert ground_truth["word_count"] == paragraph_data.word_count, (
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }' f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'

View file

@ -1,6 +1,6 @@
import os import os
import uuid import uuid
from cognee.modules.chunking.TextChunker import TextChunker
from cognee.modules.data.processing.document_types.PdfDocument import PdfDocument from cognee.modules.data.processing.document_types.PdfDocument import PdfDocument
GROUND_TRUTH = [ GROUND_TRUTH = [
@ -25,7 +25,7 @@ def test_PdfDocument():
) )
for ground_truth, paragraph_data in zip( for ground_truth, paragraph_data in zip(
GROUND_TRUTH, document.read(chunk_size=1024, chunker="text_chunker", max_chunk_tokens=2048) GROUND_TRUTH, document.read(chunk_size=1024, chunker_cls=TextChunker, max_chunk_tokens=2048)
): ):
assert ground_truth["word_count"] == paragraph_data.word_count, ( assert ground_truth["word_count"] == paragraph_data.word_count, (
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }' f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'

View file

@ -2,7 +2,7 @@ import os
import uuid import uuid
import pytest import pytest
from cognee.modules.chunking.TextChunker import TextChunker
from cognee.modules.data.processing.document_types.TextDocument import TextDocument from cognee.modules.data.processing.document_types.TextDocument import TextDocument
GROUND_TRUTH = { GROUND_TRUTH = {
@ -38,7 +38,7 @@ def test_TextDocument(input_file, chunk_size):
for ground_truth, paragraph_data in zip( for ground_truth, paragraph_data in zip(
GROUND_TRUTH[input_file], GROUND_TRUTH[input_file],
document.read(chunk_size=chunk_size, chunker="text_chunker", max_chunk_tokens=1024), document.read(chunk_size=chunk_size, chunker_cls=TextChunker, max_chunk_tokens=1024),
): ):
assert ground_truth["word_count"] == paragraph_data.word_count, ( assert ground_truth["word_count"] == paragraph_data.word_count, (
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }' f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'

View file

@ -1,6 +1,6 @@
import os import os
import uuid import uuid
from cognee.modules.chunking.TextChunker import TextChunker
from cognee.modules.data.processing.document_types.UnstructuredDocument import UnstructuredDocument from cognee.modules.data.processing.document_types.UnstructuredDocument import UnstructuredDocument
@ -69,7 +69,7 @@ def test_UnstructuredDocument():
# Test PPTX # Test PPTX
for paragraph_data in pptx_document.read( for paragraph_data in pptx_document.read(
chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024 chunk_size=1024, chunker_cls=TextChunker, max_chunk_tokens=1024
): ):
assert 19 == paragraph_data.word_count, f" 19 != {paragraph_data.word_count = }" assert 19 == paragraph_data.word_count, f" 19 != {paragraph_data.word_count = }"
assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }" assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
@ -79,7 +79,7 @@ def test_UnstructuredDocument():
# Test DOCX # Test DOCX
for paragraph_data in docx_document.read( for paragraph_data in docx_document.read(
chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024 chunk_size=1024, chunker_cls=TextChunker, max_chunk_tokens=1024
): ):
assert 16 == paragraph_data.word_count, f" 16 != {paragraph_data.word_count = }" assert 16 == paragraph_data.word_count, f" 16 != {paragraph_data.word_count = }"
assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }" assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
@ -89,7 +89,7 @@ def test_UnstructuredDocument():
# TEST CSV # TEST CSV
for paragraph_data in csv_document.read( for paragraph_data in csv_document.read(
chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024 chunk_size=1024, chunker_cls=TextChunker, max_chunk_tokens=1024
): ):
assert 15 == paragraph_data.word_count, f" 15 != {paragraph_data.word_count = }" assert 15 == paragraph_data.word_count, f" 15 != {paragraph_data.word_count = }"
assert "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text, ( assert "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text, (
@ -101,7 +101,7 @@ def test_UnstructuredDocument():
# Test XLSX # Test XLSX
for paragraph_data in xlsx_document.read( for paragraph_data in xlsx_document.read(
chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024 chunk_size=1024, chunker_cls=TextChunker, max_chunk_tokens=1024
): ):
assert 36 == paragraph_data.word_count, f" 36 != {paragraph_data.word_count = }" assert 36 == paragraph_data.word_count, f" 36 != {paragraph_data.word_count = }"
assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }" assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"