feat: externalize chunkers [cog-1354] (#547)
<!-- .github/pull_request_template.md --> ## Description <!-- Provide a clear description of the changes in this PR --> ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Enhanced document chunk extraction for improved processing consistency across multiple formats. - **Refactor** - Streamlined the configuration for text chunking by replacing indirect mappings with a direct instantiation approach across document types. - Updated method signatures across various document classes to accept chunker class references instead of string identifiers. - **Chores** - Removed legacy configuration utilities related to document chunking to simplify processing. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Boris <boris@topoteretes.com>
This commit is contained in:
parent
b10aef7b25
commit
2a167fa1ab
14 changed files with 35 additions and 57 deletions
|
|
@ -25,6 +25,7 @@ from cognee.tasks.documents import (
|
|||
from cognee.tasks.graph import extract_graph_from_data
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.tasks.summarization import summarize_text
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
|
||||
logger = logging.getLogger("cognify.v2")
|
||||
|
||||
|
|
@ -123,7 +124,9 @@ async def get_default_tasks(
|
|||
Task(classify_documents),
|
||||
Task(check_permissions_on_documents, user=user, permissions=["write"]),
|
||||
Task(
|
||||
extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens()
|
||||
extract_chunks_from_documents,
|
||||
max_chunk_tokens=get_max_chunk_tokens(),
|
||||
chunker=TextChunker,
|
||||
), # Extract text chunks based on the document type.
|
||||
Task(
|
||||
extract_graph_from_data, graph_model=graph_model, task_config={"batch_size": 10}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,5 @@
|
|||
from typing import Optional
|
||||
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
|
||||
from .ChunkerMapping import ChunkerConfig
|
||||
from .Document import Document
|
||||
|
||||
|
||||
|
|
@ -13,13 +10,12 @@ class AudioDocument(Document):
|
|||
result = get_llm_client().create_transcript(self.raw_data_location)
|
||||
return result.text
|
||||
|
||||
def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
|
||||
def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int):
|
||||
# Transcribe the audio file
|
||||
|
||||
text = self.create_transcript()
|
||||
|
||||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
||||
chunker = chunker_func(
|
||||
chunker = chunker_cls(
|
||||
self, chunk_size=chunk_size, get_text=lambda: [text], max_chunk_tokens=max_chunk_tokens
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,14 +0,0 @@
|
|||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
|
||||
|
||||
class ChunkerConfig:
|
||||
chunker_mapping = {"text_chunker": TextChunker}
|
||||
|
||||
@classmethod
|
||||
def get_chunker(cls, chunker_name: str):
|
||||
chunker_class = cls.chunker_mapping.get(chunker_name)
|
||||
if chunker_class is None:
|
||||
raise NotImplementedError(
|
||||
f"Chunker '{chunker_name}' is not implemented. Available options: {list(cls.chunker_mapping.keys())}"
|
||||
)
|
||||
return chunker_class
|
||||
|
|
@ -9,5 +9,7 @@ class Document(DataPoint):
|
|||
mime_type: str
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
def read(self, chunk_size: int, chunker=str, max_chunk_tokens: Optional[int] = None) -> str:
|
||||
def read(
|
||||
self, chunk_size: int, chunker_cls: type, max_chunk_tokens: Optional[int] = None
|
||||
) -> str:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,8 +1,5 @@
|
|||
from typing import Optional
|
||||
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
|
||||
from .ChunkerMapping import ChunkerConfig
|
||||
from .Document import Document
|
||||
|
||||
|
||||
|
|
@ -13,12 +10,11 @@ class ImageDocument(Document):
|
|||
result = get_llm_client().transcribe_image(self.raw_data_location)
|
||||
return result.choices[0].message.content
|
||||
|
||||
def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
|
||||
def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int):
|
||||
# Transcribe the image file
|
||||
text = self.transcribe_image()
|
||||
|
||||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
||||
chunker = chunker_func(
|
||||
chunker = chunker_cls(
|
||||
self, chunk_size=chunk_size, get_text=lambda: [text], max_chunk_tokens=max_chunk_tokens
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,15 +1,12 @@
|
|||
from typing import Optional
|
||||
|
||||
from pypdf import PdfReader
|
||||
|
||||
from .ChunkerMapping import ChunkerConfig
|
||||
from .Document import Document
|
||||
|
||||
|
||||
class PdfDocument(Document):
|
||||
type: str = "pdf"
|
||||
|
||||
def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
|
||||
def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int):
|
||||
file = PdfReader(self.raw_data_location)
|
||||
|
||||
def get_text():
|
||||
|
|
@ -17,8 +14,7 @@ class PdfDocument(Document):
|
|||
page_text = page.extract_text()
|
||||
yield page_text
|
||||
|
||||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
||||
chunker = chunker_func(
|
||||
chunker = chunker_cls(
|
||||
self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
from .ChunkerMapping import ChunkerConfig
|
||||
from .Document import Document
|
||||
|
||||
|
||||
class TextDocument(Document):
|
||||
type: str = "text"
|
||||
|
||||
def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
|
||||
def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int):
|
||||
def get_text():
|
||||
with open(self.raw_data_location, mode="r", encoding="utf-8") as file:
|
||||
while True:
|
||||
|
|
@ -16,9 +15,7 @@ class TextDocument(Document):
|
|||
|
||||
yield text
|
||||
|
||||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
||||
|
||||
chunker = chunker_func(
|
||||
chunker = chunker_cls(
|
||||
self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from io import StringIO
|
||||
from typing import Optional
|
||||
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from cognee.modules.data.exceptions import UnstructuredLibraryImportError
|
||||
|
|
@ -10,7 +9,7 @@ from .Document import Document
|
|||
class UnstructuredDocument(Document):
|
||||
type: str = "unstructured"
|
||||
|
||||
def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int) -> str:
|
||||
def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int) -> str:
|
||||
def get_text():
|
||||
try:
|
||||
from unstructured.partition.auto import partition
|
||||
|
|
@ -29,7 +28,7 @@ class UnstructuredDocument(Document):
|
|||
|
||||
yield text
|
||||
|
||||
chunker = TextChunker(
|
||||
chunker = chunker_cls(
|
||||
self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from sqlalchemy import select
|
|||
from cognee.modules.data.models import Data
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from uuid import UUID
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
|
||||
|
||||
async def update_document_token_count(document_id: UUID, token_count: int) -> None:
|
||||
|
|
@ -26,7 +27,7 @@ async def extract_chunks_from_documents(
|
|||
documents: list[Document],
|
||||
max_chunk_tokens: int,
|
||||
chunk_size: int = 1024,
|
||||
chunker="text_chunker",
|
||||
chunker=TextChunker,
|
||||
) -> AsyncGenerator:
|
||||
"""
|
||||
Extracts chunks of data from a list of documents based on the specified chunking parameters.
|
||||
|
|
@ -38,7 +39,7 @@ async def extract_chunks_from_documents(
|
|||
for document in documents:
|
||||
document_token_count = 0
|
||||
for document_chunk in document.read(
|
||||
chunk_size=chunk_size, chunker=chunker, max_chunk_tokens=max_chunk_tokens
|
||||
chunk_size=chunk_size, chunker_cls=chunker, max_chunk_tokens=max_chunk_tokens
|
||||
):
|
||||
document_token_count += document_chunk.token_count
|
||||
yield document_chunk
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from cognee.modules.data.processing.document_types.AudioDocument import AudioDocument
|
||||
|
||||
GROUND_TRUTH = [
|
||||
|
|
@ -34,7 +34,8 @@ def test_AudioDocument():
|
|||
)
|
||||
with patch.object(AudioDocument, "create_transcript", return_value=TEST_TEXT):
|
||||
for ground_truth, paragraph_data in zip(
|
||||
GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker", max_chunk_tokens=512)
|
||||
GROUND_TRUTH,
|
||||
document.read(chunk_size=64, chunker_cls=TextChunker, max_chunk_tokens=512),
|
||||
):
|
||||
assert ground_truth["word_count"] == paragraph_data.word_count, (
|
||||
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from cognee.modules.data.processing.document_types.ImageDocument import ImageDocument
|
||||
|
||||
GROUND_TRUTH = [
|
||||
|
|
@ -23,7 +23,8 @@ def test_ImageDocument():
|
|||
)
|
||||
with patch.object(ImageDocument, "transcribe_image", return_value=TEST_TEXT):
|
||||
for ground_truth, paragraph_data in zip(
|
||||
GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker", max_chunk_tokens=512)
|
||||
GROUND_TRUTH,
|
||||
document.read(chunk_size=64, chunker_cls=TextChunker, max_chunk_tokens=512),
|
||||
):
|
||||
assert ground_truth["word_count"] == paragraph_data.word_count, (
|
||||
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import os
|
||||
import uuid
|
||||
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from cognee.modules.data.processing.document_types.PdfDocument import PdfDocument
|
||||
|
||||
GROUND_TRUTH = [
|
||||
|
|
@ -25,7 +25,7 @@ def test_PdfDocument():
|
|||
)
|
||||
|
||||
for ground_truth, paragraph_data in zip(
|
||||
GROUND_TRUTH, document.read(chunk_size=1024, chunker="text_chunker", max_chunk_tokens=2048)
|
||||
GROUND_TRUTH, document.read(chunk_size=1024, chunker_cls=TextChunker, max_chunk_tokens=2048)
|
||||
):
|
||||
assert ground_truth["word_count"] == paragraph_data.word_count, (
|
||||
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import os
|
|||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from cognee.modules.data.processing.document_types.TextDocument import TextDocument
|
||||
|
||||
GROUND_TRUTH = {
|
||||
|
|
@ -38,7 +38,7 @@ def test_TextDocument(input_file, chunk_size):
|
|||
|
||||
for ground_truth, paragraph_data in zip(
|
||||
GROUND_TRUTH[input_file],
|
||||
document.read(chunk_size=chunk_size, chunker="text_chunker", max_chunk_tokens=1024),
|
||||
document.read(chunk_size=chunk_size, chunker_cls=TextChunker, max_chunk_tokens=1024),
|
||||
):
|
||||
assert ground_truth["word_count"] == paragraph_data.word_count, (
|
||||
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import os
|
||||
import uuid
|
||||
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from cognee.modules.data.processing.document_types.UnstructuredDocument import UnstructuredDocument
|
||||
|
||||
|
||||
|
|
@ -69,7 +69,7 @@ def test_UnstructuredDocument():
|
|||
|
||||
# Test PPTX
|
||||
for paragraph_data in pptx_document.read(
|
||||
chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024
|
||||
chunk_size=1024, chunker_cls=TextChunker, max_chunk_tokens=1024
|
||||
):
|
||||
assert 19 == paragraph_data.word_count, f" 19 != {paragraph_data.word_count = }"
|
||||
assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
|
||||
|
|
@ -79,7 +79,7 @@ def test_UnstructuredDocument():
|
|||
|
||||
# Test DOCX
|
||||
for paragraph_data in docx_document.read(
|
||||
chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024
|
||||
chunk_size=1024, chunker_cls=TextChunker, max_chunk_tokens=1024
|
||||
):
|
||||
assert 16 == paragraph_data.word_count, f" 16 != {paragraph_data.word_count = }"
|
||||
assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
|
||||
|
|
@ -89,7 +89,7 @@ def test_UnstructuredDocument():
|
|||
|
||||
# TEST CSV
|
||||
for paragraph_data in csv_document.read(
|
||||
chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024
|
||||
chunk_size=1024, chunker_cls=TextChunker, max_chunk_tokens=1024
|
||||
):
|
||||
assert 15 == paragraph_data.word_count, f" 15 != {paragraph_data.word_count = }"
|
||||
assert "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text, (
|
||||
|
|
@ -101,7 +101,7 @@ def test_UnstructuredDocument():
|
|||
|
||||
# Test XLSX
|
||||
for paragraph_data in xlsx_document.read(
|
||||
chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024
|
||||
chunk_size=1024, chunker_cls=TextChunker, max_chunk_tokens=1024
|
||||
):
|
||||
assert 36 == paragraph_data.word_count, f" 36 != {paragraph_data.word_count = }"
|
||||
assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue