Overcome ContextWindowExceededError by checking token count while chunking (#413)
Parent: 5e79dc53c5
Commit: 4802567871
9 changed files with 84 additions and 24 deletions
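The fix threads two new optional parameters, embedding_model and max_tokens, from the pipeline through every Document.read implementation into the chunker, so a chunk is closed as soon as either its word budget or its token budget would overflow. Token counts come from tiktoken, looked up by the embedding model's name. A minimal stand-alone sketch of the idea (the sentence list and the model name are placeholders, not cognee code):

import tiktoken

def count_tokens(model_name: str, text: str) -> int:
    # tiktoken resolves known model names (e.g. "text-embedding-3-large") to their encoding.
    tokenizer = tiktoken.encoding_for_model(model_name)
    return len(tokenizer.encode(text))

def budgeted_chunks(sentences, max_words=1024, max_tokens=8192, model_name="text-embedding-3-large"):
    # Accumulate sentences until either the word budget or the token budget would overflow.
    chunk, words, tokens = [], 0, 0
    for sentence in sentences:
        sentence_words = len(sentence.split())
        sentence_tokens = count_tokens(model_name, sentence)
        if chunk and (words + sentence_words > max_words or tokens + sentence_tokens > max_tokens):
            yield " ".join(chunk)
            chunk, words, tokens = [], 0, 0
        chunk.append(sentence)
        words += sentence_words
        tokens += sentence_tokens
    if chunk:
        yield " ".join(chunk)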
@@ -71,7 +71,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
             Task(ingest_data_with_metadata, dataset_name="repo_docs", user=user),
             Task(get_data_list_for_user, dataset_name="repo_docs", user=user),
             Task(classify_documents),
-            Task(extract_chunks_from_documents),
+            Task(extract_chunks_from_documents, embedding_model=embedding_engine.model, max_tokens=8192),
             Task(extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50}),
             Task(
                 summarize_text,
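In the pipeline hunk above, the extra keyword arguments are stored on the task and passed along whenever extract_chunks_from_documents runs; embedding_engine.model supplies the model name and 8192 acts as the per-chunk token cap. A generic sketch of that keyword-binding pattern (illustrative only, not cognee's Task implementation; the run-time wiring is omitted):

from functools import partial

def make_task(func, **bound_kwargs):
    # Pre-bind configuration keywords; pipeline data is supplied when the task runs.
    return partial(func, **bound_kwargs)

# Calling chunk_task(documents) would then behave like
# extract_chunks_from_documents(documents, embedding_model=embedding_engine.model, max_tokens=8192).
# chunk_task = make_task(extract_chunks_from_documents,
#                        embedding_model=embedding_engine.model, max_tokens=8192)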
@@ -1,7 +1,10 @@
-from uuid import uuid5, NAMESPACE_OID
+from typing import Optional
+from uuid import NAMESPACE_OID, uuid5
+
+from cognee.tasks.chunks import chunk_by_paragraph
 
 from .models.DocumentChunk import DocumentChunk
-from cognee.tasks.chunks import chunk_by_paragraph
 
+
 class TextChunker():
     document = None
@@ -9,23 +12,34 @@ class TextChunker():
     chunk_index = 0
     chunk_size = 0
+    token_count = 0
 
-    def __init__(self, document, get_text: callable, chunk_size: int = 1024):
+    def __init__(self, document, get_text: callable, embedding_model: Optional[str] = None, max_tokens: Optional[int] = None, chunk_size: int = 1024):
         self.document = document
         self.max_chunk_size = chunk_size
         self.get_text = get_text
+        self.max_tokens = max_tokens if max_tokens else float("inf")
+        self.embedding_model = embedding_model
+
+    def check_word_count_and_token_count(self, word_count_before, token_count_before, chunk_data):
+        word_count_fits = word_count_before + chunk_data["word_count"] <= self.max_chunk_size
+        token_count_fits = token_count_before + chunk_data["token_count"] <= self.max_tokens
+        return word_count_fits and token_count_fits
 
     def read(self):
         paragraph_chunks = []
         for content_text in self.get_text():
             for chunk_data in chunk_by_paragraph(
                 content_text,
+                self.embedding_model,
+                self.max_tokens,
                 self.max_chunk_size,
                 batch_paragraphs = True,
             ):
-                if self.chunk_size + chunk_data["word_count"] <= self.max_chunk_size:
+                if self.check_word_count_and_token_count(self.chunk_size, self.token_count, chunk_data):
                     paragraph_chunks.append(chunk_data)
                     self.chunk_size += chunk_data["word_count"]
+                    self.token_count += chunk_data["token_count"]
                 else:
                     if len(paragraph_chunks) == 0:
                         yield DocumentChunk(
@@ -63,6 +77,7 @@ class TextChunker():
                             print(e)
                     paragraph_chunks = [chunk_data]
                     self.chunk_size = chunk_data["word_count"]
+                    self.token_count = chunk_data["token_count"]
 
         self.chunk_index += 1
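With these changes the chunker tracks a running word count and a running token count, and starts a new chunk as soon as either budget would be exceeded. A hedged usage sketch of the updated constructor (not from the commit; the document object and model name are placeholders, and a real call needs a cognee Document instance):

# Sketch only: `my_document` stands in for a cognee Document instance.
chunker = TextChunker(
    my_document,
    get_text=lambda: ["First paragraph...\n\nSecond paragraph..."],
    chunk_size=1024,                           # word budget per chunk
    embedding_model="text-embedding-3-large",  # assumed model name, used for the tiktoken lookup
    max_tokens=8192,                           # token budget per chunk
)

for chunk in chunker.read():
    print(chunk)  # each item is a DocumentChunk sized to fit both budgets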
@@ -1,3 +1,4 @@
+from typing import Optional
 from uuid import UUID
 
 from cognee.infrastructure.engine import DataPoint
@@ -13,5 +14,5 @@ class Document(DataPoint):
         "type": "Document"
     }
 
-    def read(self, chunk_size: int, chunker = str) -> str:
+    def read(self, chunk_size: int, embedding_model: Optional[str], max_tokens: Optional[int], chunker = str) -> str:
         pass
@@ -1,6 +1,10 @@
+from typing import Optional
+
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
-from .Document import Document
+
 from .ChunkerMapping import ChunkerConfig
+from .Document import Document
 
+
 class ImageDocument(Document):
     type: str = "image"
@@ -10,11 +14,11 @@ class ImageDocument(Document):
         result = get_llm_client().transcribe_image(self.raw_data_location)
         return(result.choices[0].message.content)
 
-    def read(self, chunk_size: int, chunker: str):
+    def read(self, chunk_size: int, chunker: str, embedding_model: Optional[str], max_tokens: Optional[int]):
         # Transcribe the image file
         text = self.transcribe_image()
 
         chunker_func = ChunkerConfig.get_chunker(chunker)
-        chunker = chunker_func(self, chunk_size = chunk_size, get_text = lambda: [text])
+        chunker = chunker_func(self, chunk_size = chunk_size, get_text = lambda: [text], embedding_model=embedding_model, max_tokens=max_tokens)
 
         yield from chunker.read()
@@ -1,11 +1,15 @@
+from typing import Optional
+
 from pypdf import PdfReader
-from .Document import Document
+
 from .ChunkerMapping import ChunkerConfig
+from .Document import Document
 
+
 class PdfDocument(Document):
     type: str = "pdf"
 
-    def read(self, chunk_size: int, chunker: str):
+    def read(self, chunk_size: int, chunker: str, embedding_model: Optional[str], max_tokens: Optional[int]):
         file = PdfReader(self.raw_data_location)
 
         def get_text():
@@ -14,7 +18,7 @@ class PdfDocument(Document):
             yield page_text
 
         chunker_func = ChunkerConfig.get_chunker(chunker)
-        chunker = chunker_func(self, chunk_size = chunk_size, get_text = get_text)
+        chunker = chunker_func(self, chunk_size = chunk_size, get_text = get_text, embedding_model=embedding_model, max_tokens=max_tokens)
 
         yield from chunker.read()
@@ -1,10 +1,13 @@
-from .Document import Document
+from typing import Optional
+
 from .ChunkerMapping import ChunkerConfig
+from .Document import Document
 
+
 class TextDocument(Document):
     type: str = "text"
 
-    def read(self, chunk_size: int, chunker: str):
+    def read(self, chunk_size: int, chunker: str, embedding_model: Optional[str], max_tokens: Optional[int]):
         def get_text():
             with open(self.raw_data_location, mode = "r", encoding = "utf-8") as file:
                 while True:
@@ -17,6 +20,6 @@ class TextDocument(Document):
 
         chunker_func = ChunkerConfig.get_chunker(chunker)
 
-        chunker = chunker_func(self, chunk_size = chunk_size, get_text = get_text)
+        chunker = chunker_func(self, chunk_size = chunk_size, get_text = get_text, embedding_model=embedding_model, max_tokens=max_tokens)
 
         yield from chunker.read()
@@ -1,14 +1,16 @@
 from io import StringIO
+from typing import Optional
 
 from cognee.modules.chunking.TextChunker import TextChunker
-from .Document import Document
 from cognee.modules.data.exceptions import UnstructuredLibraryImportError
 
+from .Document import Document
+
 
 class UnstructuredDocument(Document):
     type: str = "unstructured"
 
-    def read(self, chunk_size: int, chunker = str) -> str:
+    def read(self, chunk_size: int, chunker: str, embedding_model: Optional[str], max_tokens: Optional[int]) -> str:
         def get_text():
             try:
                 from unstructured.partition.auto import partition
@@ -27,6 +29,6 @@ class UnstructuredDocument(Document):
 
             yield text
 
-        chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text)
+        chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text, embedding_model=embedding_model, max_tokens=max_tokens)
 
         yield from chunker.read()
@@ -1,8 +1,18 @@
-from uuid import uuid5, NAMESPACE_OID
-from typing import Dict, Any, Iterator
+from typing import Any, Dict, Iterator, Optional, Union
+from uuid import NAMESPACE_OID, uuid5
+
+import tiktoken
+
 from .chunk_by_sentence import chunk_by_sentence
 
-def chunk_by_paragraph(data: str, paragraph_length: int = 1024, batch_paragraphs: bool = True) -> Iterator[Dict[str, Any]]:
+
+def chunk_by_paragraph(
+    data: str,
+    embedding_model: Optional[str],
+    max_tokens: Optional[Union[int, float]],
+    paragraph_length: int = 1024,
+    batch_paragraphs: bool = True
+) -> Iterator[Dict[str, Any]]:
     """
     Chunks text by paragraph while preserving exact text reconstruction capability.
     When chunks are joined with empty string "", they reproduce the original text exactly.
@@ -12,14 +22,22 @@ def chunk_by_paragraph(data: str, paragraph_length: int = 1024, batch_paragraphs
     chunk_index = 0
     paragraph_ids = []
     last_cut_type = None
+    current_token_count = 0
 
     for paragraph_id, sentence, word_count, end_type in chunk_by_sentence(data, maximum_length=paragraph_length):
         # Check if this sentence would exceed length limit
-        if current_word_count > 0 and current_word_count + word_count > paragraph_length:
+        if embedding_model:
+            tokenizer = tiktoken.encoding_for_model(embedding_model)
+            token_count = len(tokenizer.encode(sentence))
+        else:
+            token_count = 0
+
+        if current_word_count > 0 and (current_word_count + word_count > paragraph_length or current_token_count + token_count > max_tokens):
             # Yield current chunk
             chunk_dict = {
                 "text": current_chunk,
                 "word_count": current_word_count,
+                "token_count": current_token_count,
                 "chunk_id": uuid5(NAMESPACE_OID, current_chunk),
                 "paragraph_ids": paragraph_ids,
                 "chunk_index": chunk_index,
@@ -32,11 +50,13 @@ def chunk_by_paragraph(data: str, paragraph_length: int = 1024, batch_paragraphs
             paragraph_ids = []
             current_chunk = ""
             current_word_count = 0
+            current_token_count = 0
             chunk_index += 1
 
         paragraph_ids.append(paragraph_id)
         current_chunk += sentence
         current_word_count += word_count
+        current_token_count += token_count
 
         # Handle end of paragraph
         if end_type in ("paragraph_end", "sentence_cut") and not batch_paragraphs:
@@ -44,6 +64,7 @@ def chunk_by_paragraph(data: str, paragraph_length: int = 1024, batch_paragraphs
             chunk_dict = {
                 "text": current_chunk,
                 "word_count": current_word_count,
+                "token_count": current_token_count,
                 "paragraph_ids": paragraph_ids,
                 "chunk_id": uuid5(NAMESPACE_OID, current_chunk),
                 "chunk_index": chunk_index,
@@ -53,6 +74,7 @@ def chunk_by_paragraph(data: str, paragraph_length: int = 1024, batch_paragraphs
             paragraph_ids = []
             current_chunk = ""
             current_word_count = 0
+            current_token_count = 0
             chunk_index += 1
 
         last_cut_type = end_type
@@ -62,6 +84,7 @@ def chunk_by_paragraph(data: str, paragraph_length: int = 1024, batch_paragraphs
         chunk_dict = {
             "text": current_chunk,
            "word_count": current_word_count,
+            "token_count": current_token_count,
             "chunk_id": uuid5(NAMESPACE_OID, current_chunk),
             "paragraph_ids": paragraph_ids,
             "chunk_index": chunk_index,
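chunk_by_paragraph now asks tiktoken for the tokenizer that matches the embedding model and counts each sentence's tokens before deciding whether it still fits the current chunk. One detail worth noting: tiktoken.encoding_for_model raises a KeyError for model names it cannot map, so a guarded lookup is a reasonable pattern around this call (a sketch, not part of this commit; the fallback encoding name is an assumption):

import tiktoken

def get_tokenizer(embedding_model: str):
    # Known model names (e.g. "text-embedding-3-large") resolve directly;
    # anything tiktoken cannot map falls back to the cl100k_base encoding.
    try:
        return tiktoken.encoding_for_model(embedding_model)
    except KeyError:
        return tiktoken.get_encoding("cl100k_base")

tokenizer = get_tokenizer("text-embedding-3-large")
print(len(tokenizer.encode("How many tokens is this sentence?")))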
@@ -1,7 +1,15 @@
+from typing import Optional
+
 from cognee.modules.data.processing.document_types.Document import Document
 
 
-async def extract_chunks_from_documents(documents: list[Document], chunk_size: int = 1024, chunker = 'text_chunker'):
+async def extract_chunks_from_documents(
+    documents: list[Document],
+    chunk_size: int = 1024,
+    chunker='text_chunker',
+    embedding_model: Optional[str] = None,
+    max_tokens: Optional[int] = None,
+):
     for document in documents:
-        for document_chunk in document.read(chunk_size = chunk_size, chunker = chunker):
+        for document_chunk in document.read(chunk_size=chunk_size, chunker=chunker, embedding_model=embedding_model, max_tokens=max_tokens):
             yield document_chunk
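Because extract_chunks_from_documents is an async generator, callers consume it with async for. A hedged usage sketch outside the pipeline (the documents list and the model name are placeholders):

import asyncio

async def main(documents):
    # documents: a list of cognee Document instances prepared elsewhere.
    async for chunk in extract_chunks_from_documents(
        documents,
        chunk_size=1024,
        embedding_model="text-embedding-3-large",  # assumed model name
        max_tokens=8192,
    ):
        print(chunk)

# asyncio.run(main(documents))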