feat: externalize chunkers [cog-1354] (#547)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.


<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

- **New Features**
  - Enhanced document chunk extraction for improved processing consistency across multiple formats.

- **Refactor**
  - Streamlined the configuration for text chunking by replacing indirect mappings with a direct instantiation approach across document types.
  - Updated method signatures across various document classes to accept chunker class references instead of string identifiers.

- **Chores**
  - Removed legacy configuration utilities related to document chunking to simplify processing.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
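
For reviewers, a minimal before/after sketch of the call-site change (illustrative only; `document` stands for any `Document` subclass instance from the diff below):

```python
from cognee.modules.chunking.TextChunker import TextChunker

# Before this PR: chunkers were resolved from a string key via ChunkerConfig.
# chunks = document.read(chunk_size=1024, chunker="text_chunker", max_chunk_tokens=2048)

# After this PR: the chunker class itself is passed in and instantiated by the document.
chunks = document.read(chunk_size=1024, chunker_cls=TextChunker, max_chunk_tokens=2048)
```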

---------

Co-authored-by: Boris <boris@topoteretes.com>
Committed by alekszievr on 2025-02-19 13:26:11 +01:00 via GitHub
parent b10aef7b25 · commit 2a167fa1ab
14 changed files with 35 additions and 57 deletions

@@ -25,6 +25,7 @@ from cognee.tasks.documents import (
 from cognee.tasks.graph import extract_graph_from_data
 from cognee.tasks.storage import add_data_points
 from cognee.tasks.summarization import summarize_text
+from cognee.modules.chunking.TextChunker import TextChunker

 logger = logging.getLogger("cognify.v2")

@@ -123,7 +124,9 @@ async def get_default_tasks(
         Task(classify_documents),
         Task(check_permissions_on_documents, user=user, permissions=["write"]),
         Task(
-            extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens()
+            extract_chunks_from_documents,
+            max_chunk_tokens=get_max_chunk_tokens(),
+            chunker=TextChunker,
         ),  # Extract text chunks based on the document type.
         Task(
             extract_graph_from_data, graph_model=graph_model, task_config={"batch_size": 10}

@@ -1,8 +1,5 @@
 from typing import Optional

 from cognee.infrastructure.llm.get_llm_client import get_llm_client
-from .ChunkerMapping import ChunkerConfig
 from .Document import Document

@@ -13,13 +10,12 @@ class AudioDocument(Document):
         result = get_llm_client().create_transcript(self.raw_data_location)
         return result.text

-    def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
+    def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int):
         # Transcribe the audio file
         text = self.create_transcript()

-        chunker_func = ChunkerConfig.get_chunker(chunker)
-        chunker = chunker_func(
+        chunker = chunker_cls(
             self, chunk_size=chunk_size, get_text=lambda: [text], max_chunk_tokens=max_chunk_tokens
         )

@@ -1,14 +0,0 @@
-from cognee.modules.chunking.TextChunker import TextChunker
-
-
-class ChunkerConfig:
-    chunker_mapping = {"text_chunker": TextChunker}
-
-    @classmethod
-    def get_chunker(cls, chunker_name: str):
-        chunker_class = cls.chunker_mapping.get(chunker_name)
-        if chunker_class is None:
-            raise NotImplementedError(
-                f"Chunker '{chunker_name}' is not implemented. Available options: {list(cls.chunker_mapping.keys())}"
-            )
-        return chunker_class
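
With `ChunkerConfig` deleted, chunkers are fully externalized: there is no registry to extend. Any class honoring the constructor shape visible at the call sites above — `chunker_cls(document, chunk_size=..., get_text=..., max_chunk_tokens=...)` — can be passed in. A hypothetical sketch (the `read()` generator method is an assumption here, mirroring what `TextChunker` exposes; a real chunker would yield DocumentChunk objects and respect `max_chunk_tokens`):

```python
# Hypothetical custom chunker, not part of this PR.
class LineChunker:
    def __init__(self, document, chunk_size, get_text, max_chunk_tokens):
        self.document = document
        self.chunk_size = chunk_size
        self.get_text = get_text  # zero-arg callable yielding text pieces
        self.max_chunk_tokens = max_chunk_tokens

    def read(self):
        # Naive strategy: one chunk per non-empty line of extracted text.
        for text in self.get_text():
            for line in text.splitlines():
                if line.strip():
                    yield line
```

Callers would then pass `chunker_cls=LineChunker` to any `Document.read` implementation.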

@@ -9,5 +9,7 @@ class Document(DataPoint):
     mime_type: str
     metadata: dict = {"index_fields": ["name"]}

-    def read(self, chunk_size: int, chunker=str, max_chunk_tokens: Optional[int] = None) -> str:
+    def read(
+        self, chunk_size: int, chunker_cls: type, max_chunk_tokens: Optional[int] = None
+    ) -> str:
         pass

@@ -1,8 +1,5 @@
 from typing import Optional

 from cognee.infrastructure.llm.get_llm_client import get_llm_client
-from .ChunkerMapping import ChunkerConfig
 from .Document import Document

@@ -13,12 +10,11 @@ class ImageDocument(Document):
         result = get_llm_client().transcribe_image(self.raw_data_location)
         return result.choices[0].message.content

-    def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
+    def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int):
         # Transcribe the image file
         text = self.transcribe_image()

-        chunker_func = ChunkerConfig.get_chunker(chunker)
-        chunker = chunker_func(
+        chunker = chunker_cls(
             self, chunk_size=chunk_size, get_text=lambda: [text], max_chunk_tokens=max_chunk_tokens
         )

@@ -1,15 +1,12 @@
 from typing import Optional

 from pypdf import PdfReader

-from .ChunkerMapping import ChunkerConfig
 from .Document import Document


 class PdfDocument(Document):
     type: str = "pdf"

-    def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
+    def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int):
         file = PdfReader(self.raw_data_location)

         def get_text():
@@ -17,8 +14,7 @@ class PdfDocument(Document):
             page_text = page.extract_text()
             yield page_text

-        chunker_func = ChunkerConfig.get_chunker(chunker)
-        chunker = chunker_func(
+        chunker = chunker_cls(
             self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
         )

@@ -1,11 +1,10 @@
-from .ChunkerMapping import ChunkerConfig
 from .Document import Document


 class TextDocument(Document):
     type: str = "text"

-    def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
+    def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int):
         def get_text():
             with open(self.raw_data_location, mode="r", encoding="utf-8") as file:
                 while True:
@@ -16,9 +15,7 @@ class TextDocument(Document):
             yield text

-        chunker_func = ChunkerConfig.get_chunker(chunker)
-        chunker = chunker_func(
+        chunker = chunker_cls(
             self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
         )

@@ -1,5 +1,4 @@
 from io import StringIO
 from typing import Optional

-from cognee.modules.chunking.TextChunker import TextChunker
 from cognee.modules.data.exceptions import UnstructuredLibraryImportError

@@ -10,7 +9,7 @@ from .Document import Document
 class UnstructuredDocument(Document):
     type: str = "unstructured"

-    def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int) -> str:
+    def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int) -> str:
         def get_text():
             try:
                 from unstructured.partition.auto import partition
@@ -29,7 +28,7 @@ class UnstructuredDocument(Document):
             yield text

-        chunker = TextChunker(
+        chunker = chunker_cls(
             self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
         )

@@ -5,6 +5,7 @@ from sqlalchemy import select
 from cognee.modules.data.models import Data
 from cognee.infrastructure.databases.relational import get_relational_engine
 from uuid import UUID
+from cognee.modules.chunking.TextChunker import TextChunker


 async def update_document_token_count(document_id: UUID, token_count: int) -> None:
@@ -26,7 +27,7 @@ async def extract_chunks_from_documents(
     documents: list[Document],
     max_chunk_tokens: int,
     chunk_size: int = 1024,
-    chunker="text_chunker",
+    chunker=TextChunker,
 ) -> AsyncGenerator:
     """
     Extracts chunks of data from a list of documents based on the specified chunking parameters.
@@ -38,7 +39,7 @@ async def extract_chunks_from_documents(
     for document in documents:
         document_token_count = 0
         for document_chunk in document.read(
-            chunk_size=chunk_size, chunker=chunker, max_chunk_tokens=max_chunk_tokens
+            chunk_size=chunk_size, chunker_cls=chunker, max_chunk_tokens=max_chunk_tokens
         ):
             document_token_count += document_chunk.token_count
             yield document_chunk
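
End to end, the task is an async generator that now threads the chunker class through to `Document.read`. A rough usage sketch (assuming `documents` is a list of already-classified `Document` instances):

```python
import asyncio

from cognee.modules.chunking.TextChunker import TextChunker
from cognee.tasks.documents import extract_chunks_from_documents

async def count_tokens(documents) -> int:
    total = 0
    # `chunker` now takes a class (default TextChunker) instead of the
    # old "text_chunker" string key.
    async for chunk in extract_chunks_from_documents(
        documents, max_chunk_tokens=512, chunker=TextChunker
    ):
        total += chunk.token_count
    return total

# asyncio.run(count_tokens(documents))
```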

@@ -1,6 +1,6 @@
 import uuid
 from unittest.mock import patch

+from cognee.modules.chunking.TextChunker import TextChunker
 from cognee.modules.data.processing.document_types.AudioDocument import AudioDocument

 GROUND_TRUTH = [
@@ -34,7 +34,8 @@ def test_AudioDocument():
     )
     with patch.object(AudioDocument, "create_transcript", return_value=TEST_TEXT):
         for ground_truth, paragraph_data in zip(
-            GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker", max_chunk_tokens=512)
+            GROUND_TRUTH,
+            document.read(chunk_size=64, chunker_cls=TextChunker, max_chunk_tokens=512),
         ):
             assert ground_truth["word_count"] == paragraph_data.word_count, (
                 f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'

@@ -1,6 +1,6 @@
 import uuid
 from unittest.mock import patch

+from cognee.modules.chunking.TextChunker import TextChunker
 from cognee.modules.data.processing.document_types.ImageDocument import ImageDocument

 GROUND_TRUTH = [
@@ -23,7 +23,8 @@ def test_ImageDocument():
     )
     with patch.object(ImageDocument, "transcribe_image", return_value=TEST_TEXT):
         for ground_truth, paragraph_data in zip(
-            GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker", max_chunk_tokens=512)
+            GROUND_TRUTH,
+            document.read(chunk_size=64, chunker_cls=TextChunker, max_chunk_tokens=512),
         ):
             assert ground_truth["word_count"] == paragraph_data.word_count, (
                 f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'

@@ -1,6 +1,6 @@
 import os
 import uuid

+from cognee.modules.chunking.TextChunker import TextChunker
 from cognee.modules.data.processing.document_types.PdfDocument import PdfDocument

 GROUND_TRUTH = [
@@ -25,7 +25,7 @@ def test_PdfDocument():
     )
     for ground_truth, paragraph_data in zip(
-        GROUND_TRUTH, document.read(chunk_size=1024, chunker="text_chunker", max_chunk_tokens=2048)
+        GROUND_TRUTH, document.read(chunk_size=1024, chunker_cls=TextChunker, max_chunk_tokens=2048)
     ):
         assert ground_truth["word_count"] == paragraph_data.word_count, (
             f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'

@@ -2,7 +2,7 @@ import os
 import uuid

 import pytest

+from cognee.modules.chunking.TextChunker import TextChunker
 from cognee.modules.data.processing.document_types.TextDocument import TextDocument

 GROUND_TRUTH = {
@@ -38,7 +38,7 @@ def test_TextDocument(input_file, chunk_size):
     for ground_truth, paragraph_data in zip(
         GROUND_TRUTH[input_file],
-        document.read(chunk_size=chunk_size, chunker="text_chunker", max_chunk_tokens=1024),
+        document.read(chunk_size=chunk_size, chunker_cls=TextChunker, max_chunk_tokens=1024),
     ):
         assert ground_truth["word_count"] == paragraph_data.word_count, (
             f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'

@@ -1,6 +1,6 @@
 import os
 import uuid

+from cognee.modules.chunking.TextChunker import TextChunker
 from cognee.modules.data.processing.document_types.UnstructuredDocument import UnstructuredDocument

@@ -69,7 +69,7 @@ def test_UnstructuredDocument():
     # Test PPTX
     for paragraph_data in pptx_document.read(
-        chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024
+        chunk_size=1024, chunker_cls=TextChunker, max_chunk_tokens=1024
     ):
         assert 19 == paragraph_data.word_count, f" 19 != {paragraph_data.word_count = }"
         assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
@@ -79,7 +79,7 @@ def test_UnstructuredDocument():
     # Test DOCX
     for paragraph_data in docx_document.read(
-        chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024
+        chunk_size=1024, chunker_cls=TextChunker, max_chunk_tokens=1024
     ):
         assert 16 == paragraph_data.word_count, f" 16 != {paragraph_data.word_count = }"
         assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
@@ -89,7 +89,7 @@ def test_UnstructuredDocument():
     # TEST CSV
     for paragraph_data in csv_document.read(
-        chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024
+        chunk_size=1024, chunker_cls=TextChunker, max_chunk_tokens=1024
     ):
         assert 15 == paragraph_data.word_count, f" 15 != {paragraph_data.word_count = }"
         assert "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text, (
@@ -101,7 +101,7 @@ def test_UnstructuredDocument():
     # Test XLSX
     for paragraph_data in xlsx_document.read(
-        chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024
+        chunk_size=1024, chunker_cls=TextChunker, max_chunk_tokens=1024
    ):
         assert 36 == paragraph_data.word_count, f" 36 != {paragraph_data.word_count = }"
         assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"