Overcome ContextWindowExceededError by checking token count while chunking (#413)

This commit is contained in:
alekszievr 2025-01-07 11:46:46 +01:00 committed by GitHub
parent 5e79dc53c5
commit 4802567871
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 84 additions and 24 deletions

View file

@@ -71,7 +71,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
Task(ingest_data_with_metadata, dataset_name="repo_docs", user=user), Task(ingest_data_with_metadata, dataset_name="repo_docs", user=user),
Task(get_data_list_for_user, dataset_name="repo_docs", user=user), Task(get_data_list_for_user, dataset_name="repo_docs", user=user),
Task(classify_documents), Task(classify_documents),
Task(extract_chunks_from_documents), Task(extract_chunks_from_documents, embedding_model=embedding_engine.model, max_tokens=8192),
Task(extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50}), Task(extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50}),
Task( Task(
summarize_text, summarize_text,

View file

@@ -1,7 +1,10 @@
from uuid import uuid5, NAMESPACE_OID from typing import Optional
from uuid import NAMESPACE_OID, uuid5
from cognee.tasks.chunks import chunk_by_paragraph
from .models.DocumentChunk import DocumentChunk from .models.DocumentChunk import DocumentChunk
from cognee.tasks.chunks import chunk_by_paragraph
class TextChunker(): class TextChunker():
document = None document = None
@@ -9,23 +12,34 @@ class TextChunker():
chunk_index = 0 chunk_index = 0
chunk_size = 0 chunk_size = 0
token_count = 0
def __init__(self, document, get_text: callable, chunk_size: int = 1024): def __init__(self, document, get_text: callable, embedding_model: Optional[str] = None, max_tokens: Optional[int] = None, chunk_size: int = 1024):
self.document = document self.document = document
self.max_chunk_size = chunk_size self.max_chunk_size = chunk_size
self.get_text = get_text self.get_text = get_text
self.max_tokens = max_tokens if max_tokens else float("inf")
self.embedding_model = embedding_model
def check_word_count_and_token_count(self, word_count_before, token_count_before, chunk_data):
word_count_fits = word_count_before + chunk_data["word_count"] <= self.max_chunk_size
token_count_fits = token_count_before + chunk_data["token_count"] <= self.max_tokens
return word_count_fits and token_count_fits
def read(self): def read(self):
paragraph_chunks = [] paragraph_chunks = []
for content_text in self.get_text(): for content_text in self.get_text():
for chunk_data in chunk_by_paragraph( for chunk_data in chunk_by_paragraph(
content_text, content_text,
self.embedding_model,
self.max_tokens,
self.max_chunk_size, self.max_chunk_size,
batch_paragraphs = True, batch_paragraphs = True,
): ):
if self.chunk_size + chunk_data["word_count"] <= self.max_chunk_size: if self.check_word_count_and_token_count(self.chunk_size, self.token_count, chunk_data):
paragraph_chunks.append(chunk_data) paragraph_chunks.append(chunk_data)
self.chunk_size += chunk_data["word_count"] self.chunk_size += chunk_data["word_count"]
self.token_count += chunk_data["token_count"]
else: else:
if len(paragraph_chunks) == 0: if len(paragraph_chunks) == 0:
yield DocumentChunk( yield DocumentChunk(
@@ -63,6 +77,7 @@ class TextChunker():
print(e) print(e)
paragraph_chunks = [chunk_data] paragraph_chunks = [chunk_data]
self.chunk_size = chunk_data["word_count"] self.chunk_size = chunk_data["word_count"]
self.token_count = chunk_data["token_count"]
self.chunk_index += 1 self.chunk_index += 1

View file

@@ -1,3 +1,4 @@
from typing import Optional
from uuid import UUID from uuid import UUID
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
@@ -13,5 +14,5 @@ class Document(DataPoint):
"type": "Document" "type": "Document"
} }
def read(self, chunk_size: int, chunker = str) -> str: def read(self, chunk_size: int, embedding_model: Optional[str], max_tokens: Optional[int], chunker = str) -> str:
pass pass

View file

@@ -1,6 +1,10 @@
from typing import Optional
from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.get_llm_client import get_llm_client
from .Document import Document
from .ChunkerMapping import ChunkerConfig from .ChunkerMapping import ChunkerConfig
from .Document import Document
class ImageDocument(Document): class ImageDocument(Document):
type: str = "image" type: str = "image"
@@ -10,11 +14,11 @@ class ImageDocument(Document):
result = get_llm_client().transcribe_image(self.raw_data_location) result = get_llm_client().transcribe_image(self.raw_data_location)
return(result.choices[0].message.content) return(result.choices[0].message.content)
def read(self, chunk_size: int, chunker: str): def read(self, chunk_size: int, chunker: str, embedding_model:Optional[str], max_tokens: Optional[int]):
# Transcribe the image file # Transcribe the image file
text = self.transcribe_image() text = self.transcribe_image()
chunker_func = ChunkerConfig.get_chunker(chunker) chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size = chunk_size, get_text = lambda: [text]) chunker = chunker_func(self, chunk_size = chunk_size, get_text = lambda: [text], embedding_model=embedding_model, max_tokens=max_tokens)
yield from chunker.read() yield from chunker.read()

View file

@@ -1,11 +1,15 @@
from typing import Optional
from pypdf import PdfReader from pypdf import PdfReader
from .Document import Document
from .ChunkerMapping import ChunkerConfig from .ChunkerMapping import ChunkerConfig
from .Document import Document
class PdfDocument(Document): class PdfDocument(Document):
type: str = "pdf" type: str = "pdf"
def read(self, chunk_size: int, chunker: str): def read(self, chunk_size: int, chunker: str, embedding_model:Optional[str], max_tokens: Optional[int]):
file = PdfReader(self.raw_data_location) file = PdfReader(self.raw_data_location)
def get_text(): def get_text():
@@ -14,7 +18,7 @@ class PdfDocument(Document):
yield page_text yield page_text
chunker_func = ChunkerConfig.get_chunker(chunker) chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size = chunk_size, get_text = get_text) chunker = chunker_func(self, chunk_size = chunk_size, get_text = get_text, embedding_model=embedding_model, max_tokens=max_tokens)
yield from chunker.read() yield from chunker.read()

View file

@@ -1,10 +1,13 @@
from .Document import Document from typing import Optional
from .ChunkerMapping import ChunkerConfig from .ChunkerMapping import ChunkerConfig
from .Document import Document
class TextDocument(Document): class TextDocument(Document):
type: str = "text" type: str = "text"
def read(self, chunk_size: int, chunker: str): def read(self, chunk_size: int, chunker: str, embedding_model:Optional[str], max_tokens: Optional[int]):
def get_text(): def get_text():
with open(self.raw_data_location, mode = "r", encoding = "utf-8") as file: with open(self.raw_data_location, mode = "r", encoding = "utf-8") as file:
while True: while True:
@@ -17,6 +20,6 @@ class TextDocument(Document):
chunker_func = ChunkerConfig.get_chunker(chunker) chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size = chunk_size, get_text = get_text) chunker = chunker_func(self, chunk_size = chunk_size, get_text = get_text, embedding_model=embedding_model, max_tokens=max_tokens)
yield from chunker.read() yield from chunker.read()

View file

@@ -1,14 +1,16 @@
from io import StringIO from io import StringIO
from typing import Optional
from cognee.modules.chunking.TextChunker import TextChunker from cognee.modules.chunking.TextChunker import TextChunker
from .Document import Document
from cognee.modules.data.exceptions import UnstructuredLibraryImportError from cognee.modules.data.exceptions import UnstructuredLibraryImportError
from .Document import Document
class UnstructuredDocument(Document): class UnstructuredDocument(Document):
type: str = "unstructured" type: str = "unstructured"
def read(self, chunk_size: int, chunker = str) -> str: def read(self, chunk_size: int, chunker: str, embedding_model:Optional[str], max_tokens: Optional[int]) -> str:
def get_text(): def get_text():
try: try:
from unstructured.partition.auto import partition from unstructured.partition.auto import partition
@@ -27,6 +29,6 @@ class UnstructuredDocument(Document):
yield text yield text
chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text) chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text, embedding_model=embedding_model, max_tokens=max_tokens)
yield from chunker.read() yield from chunker.read()

View file

@@ -1,8 +1,18 @@
from uuid import uuid5, NAMESPACE_OID from typing import Any, Dict, Iterator, Optional, Union
from typing import Dict, Any, Iterator from uuid import NAMESPACE_OID, uuid5
import tiktoken
from .chunk_by_sentence import chunk_by_sentence from .chunk_by_sentence import chunk_by_sentence
def chunk_by_paragraph(data: str, paragraph_length: int = 1024, batch_paragraphs: bool = True) -> Iterator[Dict[str, Any]]:
def chunk_by_paragraph(
data: str,
embedding_model: Optional[str],
max_tokens: Optional[Union[int, float]],
paragraph_length: int = 1024,
batch_paragraphs: bool = True
) -> Iterator[Dict[str, Any]]:
""" """
Chunks text by paragraph while preserving exact text reconstruction capability. Chunks text by paragraph while preserving exact text reconstruction capability.
When chunks are joined with empty string "", they reproduce the original text exactly. When chunks are joined with empty string "", they reproduce the original text exactly.
@@ -12,14 +22,22 @@ def chunk_by_paragraph(data: str, paragraph_length: int = 1024, batch_paragraphs
chunk_index = 0 chunk_index = 0
paragraph_ids = [] paragraph_ids = []
last_cut_type = None last_cut_type = None
current_token_count = 0
for paragraph_id, sentence, word_count, end_type in chunk_by_sentence(data, maximum_length=paragraph_length): for paragraph_id, sentence, word_count, end_type in chunk_by_sentence(data, maximum_length=paragraph_length):
# Check if this sentence would exceed length limit # Check if this sentence would exceed length limit
if current_word_count > 0 and current_word_count + word_count > paragraph_length: if embedding_model:
tokenizer = tiktoken.encoding_for_model(embedding_model)
token_count = len(tokenizer.encode(sentence))
else:
token_count = 0
if current_word_count > 0 and (current_word_count + word_count > paragraph_length or current_token_count + token_count > max_tokens):
# Yield current chunk # Yield current chunk
chunk_dict = { chunk_dict = {
"text": current_chunk, "text": current_chunk,
"word_count": current_word_count, "word_count": current_word_count,
"token_count": current_token_count,
"chunk_id": uuid5(NAMESPACE_OID, current_chunk), "chunk_id": uuid5(NAMESPACE_OID, current_chunk),
"paragraph_ids": paragraph_ids, "paragraph_ids": paragraph_ids,
"chunk_index": chunk_index, "chunk_index": chunk_index,
@@ -32,11 +50,13 @@ def chunk_by_paragraph(data: str, paragraph_length: int = 1024, batch_paragraphs
paragraph_ids = [] paragraph_ids = []
current_chunk = "" current_chunk = ""
current_word_count = 0 current_word_count = 0
current_token_count = 0
chunk_index += 1 chunk_index += 1
paragraph_ids.append(paragraph_id) paragraph_ids.append(paragraph_id)
current_chunk += sentence current_chunk += sentence
current_word_count += word_count current_word_count += word_count
current_token_count += token_count
# Handle end of paragraph # Handle end of paragraph
if end_type in ("paragraph_end", "sentence_cut") and not batch_paragraphs: if end_type in ("paragraph_end", "sentence_cut") and not batch_paragraphs:
@@ -44,6 +64,7 @@ def chunk_by_paragraph(data: str, paragraph_length: int = 1024, batch_paragraphs
chunk_dict = { chunk_dict = {
"text": current_chunk, "text": current_chunk,
"word_count": current_word_count, "word_count": current_word_count,
"token_count": current_token_count,
"paragraph_ids": paragraph_ids, "paragraph_ids": paragraph_ids,
"chunk_id": uuid5(NAMESPACE_OID, current_chunk), "chunk_id": uuid5(NAMESPACE_OID, current_chunk),
"chunk_index": chunk_index, "chunk_index": chunk_index,
@@ -53,6 +74,7 @@ def chunk_by_paragraph(data: str, paragraph_length: int = 1024, batch_paragraphs
paragraph_ids = [] paragraph_ids = []
current_chunk = "" current_chunk = ""
current_word_count = 0 current_word_count = 0
current_token_count = 0
chunk_index += 1 chunk_index += 1
last_cut_type = end_type last_cut_type = end_type
@@ -62,6 +84,7 @@ def chunk_by_paragraph(data: str, paragraph_length: int = 1024, batch_paragraphs
chunk_dict = { chunk_dict = {
"text": current_chunk, "text": current_chunk,
"word_count": current_word_count, "word_count": current_word_count,
"token_count": current_token_count,
"chunk_id": uuid5(NAMESPACE_OID, current_chunk), "chunk_id": uuid5(NAMESPACE_OID, current_chunk),
"paragraph_ids": paragraph_ids, "paragraph_ids": paragraph_ids,
"chunk_index": chunk_index, "chunk_index": chunk_index,

View file

@@ -1,7 +1,15 @@
from typing import Optional
from cognee.modules.data.processing.document_types.Document import Document from cognee.modules.data.processing.document_types.Document import Document
async def extract_chunks_from_documents(documents: list[Document], chunk_size: int = 1024, chunker = 'text_chunker'): async def extract_chunks_from_documents(
documents: list[Document],
chunk_size: int = 1024,
chunker='text_chunker',
embedding_model: Optional[str] = None,
max_tokens: Optional[int] = None,
):
for document in documents: for document in documents:
for document_chunk in document.read(chunk_size = chunk_size, chunker = chunker): for document_chunk in document.read(chunk_size=chunk_size, chunker=chunker, embedding_model=embedding_model, max_tokens=max_tokens):
yield document_chunk yield document_chunk