feat: externalize chunkers [cog-1354] (#547)
<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

- **New Features**
  - Enhanced document chunk extraction for improved processing consistency across multiple formats.
- **Refactor**
  - Streamlined the configuration for text chunking by replacing indirect mappings with a direct instantiation approach across document types.
  - Updated method signatures across various document classes to accept chunker class references instead of string identifiers.
- **Chores**
  - Removed legacy configuration utilities related to document chunking to simplify processing.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Boris <boris@topoteretes.com>
This commit is contained in:
parent
b10aef7b25
commit
2a167fa1ab
14 changed files with 35 additions and 57 deletions
|
|
@ -25,6 +25,7 @@ from cognee.tasks.documents import (
|
||||||
from cognee.tasks.graph import extract_graph_from_data
|
from cognee.tasks.graph import extract_graph_from_data
|
||||||
from cognee.tasks.storage import add_data_points
|
from cognee.tasks.storage import add_data_points
|
||||||
from cognee.tasks.summarization import summarize_text
|
from cognee.tasks.summarization import summarize_text
|
||||||
|
from cognee.modules.chunking.TextChunker import TextChunker
|
||||||
|
|
||||||
logger = logging.getLogger("cognify.v2")
|
logger = logging.getLogger("cognify.v2")
|
||||||
|
|
||||||
|
|
@ -123,7 +124,9 @@ async def get_default_tasks(
|
||||||
Task(classify_documents),
|
Task(classify_documents),
|
||||||
Task(check_permissions_on_documents, user=user, permissions=["write"]),
|
Task(check_permissions_on_documents, user=user, permissions=["write"]),
|
||||||
Task(
|
Task(
|
||||||
extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens()
|
extract_chunks_from_documents,
|
||||||
|
max_chunk_tokens=get_max_chunk_tokens(),
|
||||||
|
chunker=TextChunker,
|
||||||
), # Extract text chunks based on the document type.
|
), # Extract text chunks based on the document type.
|
||||||
Task(
|
Task(
|
||||||
extract_graph_from_data, graph_model=graph_model, task_config={"batch_size": 10}
|
extract_graph_from_data, graph_model=graph_model, task_config={"batch_size": 10}
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,5 @@
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||||
|
|
||||||
from .ChunkerMapping import ChunkerConfig
|
|
||||||
from .Document import Document
|
from .Document import Document
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -13,13 +10,12 @@ class AudioDocument(Document):
|
||||||
result = get_llm_client().create_transcript(self.raw_data_location)
|
result = get_llm_client().create_transcript(self.raw_data_location)
|
||||||
return result.text
|
return result.text
|
||||||
|
|
||||||
def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
|
def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int):
|
||||||
# Transcribe the audio file
|
# Transcribe the audio file
|
||||||
|
|
||||||
text = self.create_transcript()
|
text = self.create_transcript()
|
||||||
|
|
||||||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
chunker = chunker_cls(
|
||||||
chunker = chunker_func(
|
|
||||||
self, chunk_size=chunk_size, get_text=lambda: [text], max_chunk_tokens=max_chunk_tokens
|
self, chunk_size=chunk_size, get_text=lambda: [text], max_chunk_tokens=max_chunk_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,14 +0,0 @@
|
||||||
from cognee.modules.chunking.TextChunker import TextChunker
|
|
||||||
|
|
||||||
|
|
||||||
class ChunkerConfig:
|
|
||||||
chunker_mapping = {"text_chunker": TextChunker}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_chunker(cls, chunker_name: str):
|
|
||||||
chunker_class = cls.chunker_mapping.get(chunker_name)
|
|
||||||
if chunker_class is None:
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"Chunker '{chunker_name}' is not implemented. Available options: {list(cls.chunker_mapping.keys())}"
|
|
||||||
)
|
|
||||||
return chunker_class
|
|
||||||
|
|
@ -9,5 +9,7 @@ class Document(DataPoint):
|
||||||
mime_type: str
|
mime_type: str
|
||||||
metadata: dict = {"index_fields": ["name"]}
|
metadata: dict = {"index_fields": ["name"]}
|
||||||
|
|
||||||
def read(self, chunk_size: int, chunker=str, max_chunk_tokens: Optional[int] = None) -> str:
|
def read(
|
||||||
|
self, chunk_size: int, chunker_cls: type, max_chunk_tokens: Optional[int] = None
|
||||||
|
) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,5 @@
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||||
|
|
||||||
from .ChunkerMapping import ChunkerConfig
|
|
||||||
from .Document import Document
|
from .Document import Document
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -13,12 +10,11 @@ class ImageDocument(Document):
|
||||||
result = get_llm_client().transcribe_image(self.raw_data_location)
|
result = get_llm_client().transcribe_image(self.raw_data_location)
|
||||||
return result.choices[0].message.content
|
return result.choices[0].message.content
|
||||||
|
|
||||||
def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
|
def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int):
|
||||||
# Transcribe the image file
|
# Transcribe the image file
|
||||||
text = self.transcribe_image()
|
text = self.transcribe_image()
|
||||||
|
|
||||||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
chunker = chunker_cls(
|
||||||
chunker = chunker_func(
|
|
||||||
self, chunk_size=chunk_size, get_text=lambda: [text], max_chunk_tokens=max_chunk_tokens
|
self, chunk_size=chunk_size, get_text=lambda: [text], max_chunk_tokens=max_chunk_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,12 @@
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pypdf import PdfReader
|
from pypdf import PdfReader
|
||||||
|
|
||||||
from .ChunkerMapping import ChunkerConfig
|
|
||||||
from .Document import Document
|
from .Document import Document
|
||||||
|
|
||||||
|
|
||||||
class PdfDocument(Document):
|
class PdfDocument(Document):
|
||||||
type: str = "pdf"
|
type: str = "pdf"
|
||||||
|
|
||||||
def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
|
def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int):
|
||||||
file = PdfReader(self.raw_data_location)
|
file = PdfReader(self.raw_data_location)
|
||||||
|
|
||||||
def get_text():
|
def get_text():
|
||||||
|
|
@ -17,8 +14,7 @@ class PdfDocument(Document):
|
||||||
page_text = page.extract_text()
|
page_text = page.extract_text()
|
||||||
yield page_text
|
yield page_text
|
||||||
|
|
||||||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
chunker = chunker_cls(
|
||||||
chunker = chunker_func(
|
|
||||||
self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
|
self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,10 @@
|
||||||
from .ChunkerMapping import ChunkerConfig
|
|
||||||
from .Document import Document
|
from .Document import Document
|
||||||
|
|
||||||
|
|
||||||
class TextDocument(Document):
|
class TextDocument(Document):
|
||||||
type: str = "text"
|
type: str = "text"
|
||||||
|
|
||||||
def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
|
def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int):
|
||||||
def get_text():
|
def get_text():
|
||||||
with open(self.raw_data_location, mode="r", encoding="utf-8") as file:
|
with open(self.raw_data_location, mode="r", encoding="utf-8") as file:
|
||||||
while True:
|
while True:
|
||||||
|
|
@ -16,9 +15,7 @@ class TextDocument(Document):
|
||||||
|
|
||||||
yield text
|
yield text
|
||||||
|
|
||||||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
chunker = chunker_cls(
|
||||||
|
|
||||||
chunker = chunker_func(
|
|
||||||
self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
|
self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from cognee.modules.chunking.TextChunker import TextChunker
|
from cognee.modules.chunking.TextChunker import TextChunker
|
||||||
from cognee.modules.data.exceptions import UnstructuredLibraryImportError
|
from cognee.modules.data.exceptions import UnstructuredLibraryImportError
|
||||||
|
|
@ -10,7 +9,7 @@ from .Document import Document
|
||||||
class UnstructuredDocument(Document):
|
class UnstructuredDocument(Document):
|
||||||
type: str = "unstructured"
|
type: str = "unstructured"
|
||||||
|
|
||||||
def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int) -> str:
|
def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int) -> str:
|
||||||
def get_text():
|
def get_text():
|
||||||
try:
|
try:
|
||||||
from unstructured.partition.auto import partition
|
from unstructured.partition.auto import partition
|
||||||
|
|
@ -29,7 +28,7 @@ class UnstructuredDocument(Document):
|
||||||
|
|
||||||
yield text
|
yield text
|
||||||
|
|
||||||
chunker = TextChunker(
|
chunker = chunker_cls(
|
||||||
self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
|
self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from sqlalchemy import select
|
||||||
from cognee.modules.data.models import Data
|
from cognee.modules.data.models import Data
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
from cognee.modules.chunking.TextChunker import TextChunker
|
||||||
|
|
||||||
|
|
||||||
async def update_document_token_count(document_id: UUID, token_count: int) -> None:
|
async def update_document_token_count(document_id: UUID, token_count: int) -> None:
|
||||||
|
|
@ -26,7 +27,7 @@ async def extract_chunks_from_documents(
|
||||||
documents: list[Document],
|
documents: list[Document],
|
||||||
max_chunk_tokens: int,
|
max_chunk_tokens: int,
|
||||||
chunk_size: int = 1024,
|
chunk_size: int = 1024,
|
||||||
chunker="text_chunker",
|
chunker=TextChunker,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
"""
|
"""
|
||||||
Extracts chunks of data from a list of documents based on the specified chunking parameters.
|
Extracts chunks of data from a list of documents based on the specified chunking parameters.
|
||||||
|
|
@ -38,7 +39,7 @@ async def extract_chunks_from_documents(
|
||||||
for document in documents:
|
for document in documents:
|
||||||
document_token_count = 0
|
document_token_count = 0
|
||||||
for document_chunk in document.read(
|
for document_chunk in document.read(
|
||||||
chunk_size=chunk_size, chunker=chunker, max_chunk_tokens=max_chunk_tokens
|
chunk_size=chunk_size, chunker_cls=chunker, max_chunk_tokens=max_chunk_tokens
|
||||||
):
|
):
|
||||||
document_token_count += document_chunk.token_count
|
document_token_count += document_chunk.token_count
|
||||||
yield document_chunk
|
yield document_chunk
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import uuid
|
import uuid
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
from cognee.modules.chunking.TextChunker import TextChunker
|
||||||
from cognee.modules.data.processing.document_types.AudioDocument import AudioDocument
|
from cognee.modules.data.processing.document_types.AudioDocument import AudioDocument
|
||||||
|
|
||||||
GROUND_TRUTH = [
|
GROUND_TRUTH = [
|
||||||
|
|
@ -34,7 +34,8 @@ def test_AudioDocument():
|
||||||
)
|
)
|
||||||
with patch.object(AudioDocument, "create_transcript", return_value=TEST_TEXT):
|
with patch.object(AudioDocument, "create_transcript", return_value=TEST_TEXT):
|
||||||
for ground_truth, paragraph_data in zip(
|
for ground_truth, paragraph_data in zip(
|
||||||
GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker", max_chunk_tokens=512)
|
GROUND_TRUTH,
|
||||||
|
document.read(chunk_size=64, chunker_cls=TextChunker, max_chunk_tokens=512),
|
||||||
):
|
):
|
||||||
assert ground_truth["word_count"] == paragraph_data.word_count, (
|
assert ground_truth["word_count"] == paragraph_data.word_count, (
|
||||||
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
|
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import uuid
|
import uuid
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
from cognee.modules.chunking.TextChunker import TextChunker
|
||||||
from cognee.modules.data.processing.document_types.ImageDocument import ImageDocument
|
from cognee.modules.data.processing.document_types.ImageDocument import ImageDocument
|
||||||
|
|
||||||
GROUND_TRUTH = [
|
GROUND_TRUTH = [
|
||||||
|
|
@ -23,7 +23,8 @@ def test_ImageDocument():
|
||||||
)
|
)
|
||||||
with patch.object(ImageDocument, "transcribe_image", return_value=TEST_TEXT):
|
with patch.object(ImageDocument, "transcribe_image", return_value=TEST_TEXT):
|
||||||
for ground_truth, paragraph_data in zip(
|
for ground_truth, paragraph_data in zip(
|
||||||
GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker", max_chunk_tokens=512)
|
GROUND_TRUTH,
|
||||||
|
document.read(chunk_size=64, chunker_cls=TextChunker, max_chunk_tokens=512),
|
||||||
):
|
):
|
||||||
assert ground_truth["word_count"] == paragraph_data.word_count, (
|
assert ground_truth["word_count"] == paragraph_data.word_count, (
|
||||||
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
|
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
|
from cognee.modules.chunking.TextChunker import TextChunker
|
||||||
from cognee.modules.data.processing.document_types.PdfDocument import PdfDocument
|
from cognee.modules.data.processing.document_types.PdfDocument import PdfDocument
|
||||||
|
|
||||||
GROUND_TRUTH = [
|
GROUND_TRUTH = [
|
||||||
|
|
@ -25,7 +25,7 @@ def test_PdfDocument():
|
||||||
)
|
)
|
||||||
|
|
||||||
for ground_truth, paragraph_data in zip(
|
for ground_truth, paragraph_data in zip(
|
||||||
GROUND_TRUTH, document.read(chunk_size=1024, chunker="text_chunker", max_chunk_tokens=2048)
|
GROUND_TRUTH, document.read(chunk_size=1024, chunker_cls=TextChunker, max_chunk_tokens=2048)
|
||||||
):
|
):
|
||||||
assert ground_truth["word_count"] == paragraph_data.word_count, (
|
assert ground_truth["word_count"] == paragraph_data.word_count, (
|
||||||
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
|
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import os
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from cognee.modules.chunking.TextChunker import TextChunker
|
||||||
from cognee.modules.data.processing.document_types.TextDocument import TextDocument
|
from cognee.modules.data.processing.document_types.TextDocument import TextDocument
|
||||||
|
|
||||||
GROUND_TRUTH = {
|
GROUND_TRUTH = {
|
||||||
|
|
@ -38,7 +38,7 @@ def test_TextDocument(input_file, chunk_size):
|
||||||
|
|
||||||
for ground_truth, paragraph_data in zip(
|
for ground_truth, paragraph_data in zip(
|
||||||
GROUND_TRUTH[input_file],
|
GROUND_TRUTH[input_file],
|
||||||
document.read(chunk_size=chunk_size, chunker="text_chunker", max_chunk_tokens=1024),
|
document.read(chunk_size=chunk_size, chunker_cls=TextChunker, max_chunk_tokens=1024),
|
||||||
):
|
):
|
||||||
assert ground_truth["word_count"] == paragraph_data.word_count, (
|
assert ground_truth["word_count"] == paragraph_data.word_count, (
|
||||||
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
|
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
|
from cognee.modules.chunking.TextChunker import TextChunker
|
||||||
from cognee.modules.data.processing.document_types.UnstructuredDocument import UnstructuredDocument
|
from cognee.modules.data.processing.document_types.UnstructuredDocument import UnstructuredDocument
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -69,7 +69,7 @@ def test_UnstructuredDocument():
|
||||||
|
|
||||||
# Test PPTX
|
# Test PPTX
|
||||||
for paragraph_data in pptx_document.read(
|
for paragraph_data in pptx_document.read(
|
||||||
chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024
|
chunk_size=1024, chunker_cls=TextChunker, max_chunk_tokens=1024
|
||||||
):
|
):
|
||||||
assert 19 == paragraph_data.word_count, f" 19 != {paragraph_data.word_count = }"
|
assert 19 == paragraph_data.word_count, f" 19 != {paragraph_data.word_count = }"
|
||||||
assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
|
assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
|
||||||
|
|
@ -79,7 +79,7 @@ def test_UnstructuredDocument():
|
||||||
|
|
||||||
# Test DOCX
|
# Test DOCX
|
||||||
for paragraph_data in docx_document.read(
|
for paragraph_data in docx_document.read(
|
||||||
chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024
|
chunk_size=1024, chunker_cls=TextChunker, max_chunk_tokens=1024
|
||||||
):
|
):
|
||||||
assert 16 == paragraph_data.word_count, f" 16 != {paragraph_data.word_count = }"
|
assert 16 == paragraph_data.word_count, f" 16 != {paragraph_data.word_count = }"
|
||||||
assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
|
assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
|
||||||
|
|
@ -89,7 +89,7 @@ def test_UnstructuredDocument():
|
||||||
|
|
||||||
# TEST CSV
|
# TEST CSV
|
||||||
for paragraph_data in csv_document.read(
|
for paragraph_data in csv_document.read(
|
||||||
chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024
|
chunk_size=1024, chunker_cls=TextChunker, max_chunk_tokens=1024
|
||||||
):
|
):
|
||||||
assert 15 == paragraph_data.word_count, f" 15 != {paragraph_data.word_count = }"
|
assert 15 == paragraph_data.word_count, f" 15 != {paragraph_data.word_count = }"
|
||||||
assert "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text, (
|
assert "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text, (
|
||||||
|
|
@ -101,7 +101,7 @@ def test_UnstructuredDocument():
|
||||||
|
|
||||||
# Test XLSX
|
# Test XLSX
|
||||||
for paragraph_data in xlsx_document.read(
|
for paragraph_data in xlsx_document.read(
|
||||||
chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024
|
chunk_size=1024, chunker_cls=TextChunker, max_chunk_tokens=1024
|
||||||
):
|
):
|
||||||
assert 36 == paragraph_data.word_count, f" 36 != {paragraph_data.word_count = }"
|
assert 36 == paragraph_data.word_count, f" 36 != {paragraph_data.word_count = }"
|
||||||
assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"
|
assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue