feat: use external chunker [cog-1354] (#551)
## Description

Introduce a `Chunker` base class and a Langchain-based chunker, and thread the chunker type through the `Document.read` APIs so chunking backends are pluggable.

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

## Summary by CodeRabbit

- **New Features**
  - Introduced a modular content chunking interface that offers flexible text segmentation with configurable chunk size and overlap.
  - Added new chunkers for enhanced text processing, including `LangchainChunker` and an improved `TextChunker`.
- **Refactor**
  - Unified the chunk extraction mechanism across document types for improved consistency and type safety.
  - Updated method signatures to clarify chunker usage.
  - Improved error handling and logging during text segmentation to guide adjustments when content exceeds token limits.
- **Bug Fixes**
  - Adjusted expected test output to reflect changes in chunking logic and configuration.
Parent: eba1515127
Commit: a61df966c6

11 changed files with 91 additions and 26 deletions
cognee/modules/chunking/Chunker.py (new file, +13)

```diff
@@ -0,0 +1,13 @@
+class Chunker:
+    def __init__(self, document, get_text: callable, max_chunk_tokens: int, chunk_size: int = 1024):
+        self.chunk_index = 0
+        self.chunk_size = 0
+        self.token_count = 0
+
+        self.document = document
+        self.max_chunk_size = chunk_size
+        self.get_text = get_text
+        self.max_chunk_tokens = max_chunk_tokens
+
+    def read(self):
+        raise NotImplementedError
```
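For orientation: a concrete chunker only has to consume `self.get_text()` and yield chunks from `read()`. A minimal sketch of that contract, using a hypothetical `SentenceChunker` that is not part of this PR:

```python
from cognee.modules.chunking.Chunker import Chunker


# Hypothetical example (not in this PR): the base class supplies the
# document, the get_text callable, and the size limits; a subclass only
# implements read() as a generator.
class SentenceChunker(Chunker):
    def read(self):
        for content_text in self.get_text():
            # Naive sentence split, purely for illustration.
            for sentence in content_text.split(". "):
                if sentence.strip():
                    self.chunk_index += 1
                    yield sentence.strip()
```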
cognee/modules/chunking/LangchainChunker.py (new file, +59)

```diff
@@ -0,0 +1,59 @@
+import logging
+from uuid import NAMESPACE_OID, uuid5
+
+from cognee.modules.chunking.Chunker import Chunker
+from .models.DocumentChunk import DocumentChunk
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from cognee.infrastructure.databases.vector import get_vector_engine
+
+logger = logging.getLogger(__name__)
+
+
+class LangchainChunker(Chunker):
+    """
+    A Chunker that splits text into chunks using Langchain's RecursiveCharacterTextSplitter.
+
+    The chunker will split the text into chunks of approximately the given size, but will not split
+    a chunk if the split would result in a chunk with fewer than the given overlap tokens.
+    """
+
+    def __init__(
+        self,
+        document,
+        get_text: callable,
+        max_chunk_tokens: int,
+        chunk_size: int = 1024,
+        chunk_overlap=10,
+    ):
+        super().__init__(document, get_text, max_chunk_tokens, chunk_size)
+
+        self.splitter = RecursiveCharacterTextSplitter(
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap,
+            length_function=lambda text: len(text.split()),
+        )
+
+    def read(self):
+        for content_text in self.get_text():
+            for chunk in self.splitter.split_text(content_text):
+                embedding_engine = get_vector_engine().embedding_engine
+                token_count = embedding_engine.tokenizer.count_tokens(chunk)
+                if token_count <= self.max_chunk_tokens:
+                    yield DocumentChunk(
+                        id=uuid5(NAMESPACE_OID, chunk),
+                        text=chunk,
+                        word_count=len(chunk.split()),
+                        token_count=token_count,
+                        is_part_of=self.document,
+                        chunk_index=self.chunk_index,
+                        cut_type="missing",
+                        contains=[],
+                        metadata={
+                            "index_fields": ["text"],
+                        },
+                    )
+                    self.chunk_index += 1
+                else:
+                    raise ValueError(
+                        f"Chunk of {token_count} tokens is larger than the maximum of {self.max_chunk_tokens} tokens. Please reduce chunk_size in RecursiveCharacterTextSplitter."
+                    )
```
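Note that the custom `length_function` makes `chunk_size` and `chunk_overlap` counts of whitespace-delimited words, not characters. A standalone sketch of that behavior using only `langchain_text_splitters` (the sample text and sizes are made up for illustration):

```python
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Same configuration shape as LangchainChunker: sizes are measured in
# words because of the custom length_function.
splitter = RecursiveCharacterTextSplitter(
    chunk_size=20,     # at most ~20 words per chunk (illustrative value)
    chunk_overlap=5,   # ~5 words shared between neighboring chunks
    length_function=lambda text: len(text.split()),
)

sample = " ".join(f"word{i}" for i in range(100))  # made-up input
for i, chunk in enumerate(splitter.split_text(sample)):
    print(i, len(chunk.split()), "words")
```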
cognee/modules/chunking/TextChunker.py

```diff
@@ -2,26 +2,13 @@ import logging
 from uuid import NAMESPACE_OID, uuid5
 
 from cognee.tasks.chunks import chunk_by_paragraph
 
+from cognee.modules.chunking.Chunker import Chunker
 from .models.DocumentChunk import DocumentChunk
 
 logger = logging.getLogger(__name__)
 
 
-class TextChunker:
-    document = None
-    max_chunk_size: int
-
-    chunk_index = 0
-    chunk_size = 0
-    token_count = 0
-
-    def __init__(self, document, get_text: callable, max_chunk_tokens: int, chunk_size: int = 1024):
-        self.document = document
-        self.max_chunk_size = chunk_size
-        self.get_text = get_text
-        self.max_chunk_tokens = max_chunk_tokens
-
+class TextChunker(Chunker):
     def check_word_count_and_token_count(self, word_count_before, token_count_before, chunk_data):
         word_count_fits = word_count_before + chunk_data["word_count"] <= self.max_chunk_size
         token_count_fits = token_count_before + chunk_data["token_count"] <= self.max_chunk_tokens
```
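The visible check gates both budgets at once; the rest of the method is cut off by the hunk, so the sketch below only illustrates the two predicates shown (combining them with `and` is an assumption):

```python
def fits(word_count_before: int, token_count_before: int, chunk_data: dict,
         max_chunk_size: int = 1024, max_chunk_tokens: int = 512) -> bool:
    # Mirrors the two checks visible in the hunk: a paragraph only joins
    # the current chunk if it keeps both the word and token budgets intact.
    word_count_fits = word_count_before + chunk_data["word_count"] <= max_chunk_size
    token_count_fits = token_count_before + chunk_data["token_count"] <= max_chunk_tokens
    return word_count_fits and token_count_fits  # assumed combination
```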
cognee/modules/data/processing/document_types/AudioDocument.py

```diff
@@ -1,4 +1,5 @@
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
+from cognee.modules.chunking.Chunker import Chunker
 
 from .Document import Document
 
@@ -10,7 +11,7 @@ class AudioDocument(Document):
         result = get_llm_client().create_transcript(self.raw_data_location)
         return result.text
 
-    def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int):
+    def read(self, chunk_size: int, chunker_cls: Chunker, max_chunk_tokens: int):
         # Transcribe the audio file
 
         text = self.create_transcript()
```
cognee/modules/data/processing/document_types/Document.py

```diff
@@ -1,5 +1,6 @@
 from typing import Optional
 from cognee.infrastructure.engine import DataPoint
+from cognee.modules.chunking.Chunker import Chunker
 
 
 class Document(DataPoint):
@@ -10,6 +11,6 @@ class Document(DataPoint):
     metadata: dict = {"index_fields": ["name"]}
 
     def read(
-        self, chunk_size: int, chunker_cls: type, max_chunk_tokens: Optional[int] = None
+        self, chunk_size: int, chunker_cls: Chunker, max_chunk_tokens: Optional[int] = None
    ) -> str:
         pass
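The full bodies of the concrete `read` implementations are not shown in these hunks, but the signatures imply a common pattern: each document builds a `get_text` callable and hands it to `chunker_cls`. A hedged sketch of that wiring, inferred from the diff rather than copied from the source (`load_text` is a hypothetical helper):

```python
# Assumed wiring, not verbatim source: the document instantiates
# chunker_cls and delegates chunk production to its read() generator.
def read(self, chunk_size: int, chunker_cls, max_chunk_tokens: int):
    def get_text():
        yield self.load_text()  # hypothetical helper for this sketch

    chunker = chunker_cls(self, get_text, max_chunk_tokens, chunk_size=chunk_size)
    yield from chunker.read()
```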
cognee/modules/data/processing/document_types/ImageDocument.py

```diff
@@ -1,4 +1,5 @@
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
+from cognee.modules.chunking.Chunker import Chunker
 
 from .Document import Document
 
@@ -10,7 +11,7 @@ class ImageDocument(Document):
         result = get_llm_client().transcribe_image(self.raw_data_location)
         return result.choices[0].message.content
 
-    def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int):
+    def read(self, chunk_size: int, chunker_cls: Chunker, max_chunk_tokens: int):
         # Transcribe the image file
         text = self.transcribe_image()
```
cognee/modules/data/processing/document_types/PdfDocument.py

```diff
@@ -1,4 +1,5 @@
 from pypdf import PdfReader
+from cognee.modules.chunking.Chunker import Chunker
 
 from .Document import Document
 
@@ -6,7 +7,7 @@ from .Document import Document
 class PdfDocument(Document):
     type: str = "pdf"
 
-    def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int):
+    def read(self, chunk_size: int, chunker_cls: Chunker, max_chunk_tokens: int):
         file = PdfReader(self.raw_data_location)
 
         def get_text():
```
cognee/modules/data/processing/document_types/TextDocument.py

```diff
@@ -1,14 +1,15 @@
 from .Document import Document
+from cognee.modules.chunking.Chunker import Chunker
 
 
 class TextDocument(Document):
     type: str = "text"
 
-    def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int):
+    def read(self, chunk_size: int, chunker_cls: Chunker, max_chunk_tokens: int):
         def get_text():
             with open(self.raw_data_location, mode="r", encoding="utf-8") as file:
                 while True:
-                    text = file.read(1024)
+                    text = file.read(1000000)
 
                     if len(text.strip()) == 0:
                         break
```
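The switch from `file.read(1024)` to `file.read(1000000)` changes the streaming granularity: chunk boundaries are now decided by the chunker, not by 1 KB read windows. A self-contained sketch of the same generator pattern (the wrapper name and file path are made up; the loop body mirrors the diff, including its stop condition on whitespace-only reads):

```python
def make_get_text(path: str, block_size: int = 1_000_000):
    """Yield a file's contents in large blocks, as TextDocument.read does."""
    def get_text():
        with open(path, mode="r", encoding="utf-8") as file:
            while True:
                text = file.read(block_size)
                if len(text.strip()) == 0:
                    break  # empty or whitespace-only block ends the stream
                yield text
    return get_text

# Usage (hypothetical path):
# for block in make_get_text("data/sample.txt")():
#     print(len(block))
```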
cognee/modules/data/processing/document_types/UnstructuredDocument.py

```diff
@@ -1,6 +1,6 @@
 from io import StringIO
 
-from cognee.modules.chunking.TextChunker import TextChunker
+from cognee.modules.chunking.Chunker import Chunker
 from cognee.modules.data.exceptions import UnstructuredLibraryImportError
 
 from .Document import Document
@@ -9,7 +9,7 @@ from .Document import Document
 class UnstructuredDocument(Document):
     type: str = "unstructured"
 
-    def read(self, chunk_size: int, chunker_cls: type, max_chunk_tokens: int) -> str:
+    def read(self, chunk_size: int, chunker_cls: Chunker, max_chunk_tokens: int) -> str:
         def get_text():
             try:
                 from unstructured.partition.auto import partition
```
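For context, a minimal sketch of what a `partition`-based `get_text` presumably does: parse any supported file into elements and yield their combined text. The exact cognee body is not visible in this hunk, so treat this as an assumption (the path is made up; `partition` and stringifiable elements are real `unstructured` APIs):

```python
from unstructured.partition.auto import partition  # optional dependency


def get_text(path: str):
    """Sketch: parse a document of any supported format and yield its text."""
    elements = partition(filename=path)  # auto-detects the file type
    yield "\n\n".join(str(element) for element in elements)
```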
extract_chunks_from_documents.py

```diff
@@ -6,6 +6,7 @@ from cognee.modules.data.models import Data
 from cognee.infrastructure.databases.relational import get_relational_engine
 from uuid import UUID
 from cognee.modules.chunking.TextChunker import TextChunker
+from cognee.modules.chunking.Chunker import Chunker
 
 
 async def update_document_token_count(document_id: UUID, token_count: int) -> None:
@@ -27,7 +28,7 @@ async def extract_chunks_from_documents(
     documents: list[Document],
     max_chunk_tokens: int,
     chunk_size: int = 1024,
-    chunker=TextChunker,
+    chunker: Chunker = TextChunker,
 ) -> AsyncGenerator:
     """
     Extracts chunks of data from a list of documents based on the specified chunking parameters.
```
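With the `chunker` parameter typed, swapping chunking strategies becomes a call-site change. A hedged usage sketch based on the signature above (document preparation and the task's import path are assumed, not shown in this diff):

```python
import asyncio

from cognee.modules.chunking.LangchainChunker import LangchainChunker

# `extract_chunks_from_documents` and `documents` are assumed to be
# importable/prepared elsewhere; only the call shape is shown here.
async def main(documents):
    async for chunk in extract_chunks_from_documents(
        documents,
        max_chunk_tokens=512,       # illustrative budget
        chunk_size=1024,
        chunker=LangchainChunker,   # default remains TextChunker
    ):
        print(chunk.chunk_index, chunk.word_count)

# asyncio.run(main(documents))
```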
Chunker test ground truth

```diff
@@ -7,8 +7,8 @@ from cognee.modules.data.processing.document_types.TextDocument import TextDocument
 
 GROUND_TRUTH = {
     "code.txt": [
-        {"word_count": 205, "len_text": 1024, "cut_type": "sentence_cut"},
-        {"word_count": 104, "len_text": 833, "cut_type": "paragraph_end"},
+        {"word_count": 252, "len_text": 1376, "cut_type": "paragraph_end"},
+        {"word_count": 56, "len_text": 481, "cut_type": "paragraph_end"},
     ],
     "Natural_language_processing.txt": [
         {"word_count": 128, "len_text": 984, "cut_type": "paragraph_end"},
```