chunking token limit

Author: estevez.sebastian@gmail.com
Date:   2025-07-31 13:39:32 -04:00
Parent: 93b72a19be
Commit: 155f7edba9


@@ -5,10 +5,77 @@ import os
 import aiofiles
 from io import BytesIO
 from docling_core.types.io import DocumentStream
+from typing import List
+import openai
+import tiktoken
 from config.settings import clients, INDEX_NAME, EMBED_MODEL
 from utils.document_processing import extract_relevant, process_document_sync
+
+
+def get_token_count(text: str, model: str = EMBED_MODEL) -> int:
+    """Get accurate token count using tiktoken"""
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+        return len(encoding.encode(text))
+    except KeyError:
+        # Fallback to cl100k_base for unknown models
+        encoding = tiktoken.get_encoding("cl100k_base")
+        return len(encoding.encode(text))
+
+
+def chunk_texts_for_embeddings(texts: List[str], max_tokens: int = None, model: str = EMBED_MODEL) -> List[List[str]]:
+    """
+    Split texts into batches that won't exceed token limits.
+    If max_tokens is None, returns texts as single batch (no splitting).
+    """
+    if max_tokens is None:
+        return [texts]
+
+    batches = []
+    current_batch = []
+    current_tokens = 0
+
+    for text in texts:
+        text_tokens = get_token_count(text, model)
+
+        # If single text exceeds limit, split it further
+        if text_tokens > max_tokens:
+            # If we have current batch, save it first
+            if current_batch:
+                batches.append(current_batch)
+                current_batch = []
+                current_tokens = 0
+
+            # Split the large text into smaller chunks
+            try:
+                encoding = tiktoken.encoding_for_model(model)
+            except KeyError:
+                encoding = tiktoken.get_encoding("cl100k_base")
+            tokens = encoding.encode(text)
+            for i in range(0, len(tokens), max_tokens):
+                chunk_tokens = tokens[i:i + max_tokens]
+                chunk_text = encoding.decode(chunk_tokens)
+                batches.append([chunk_text])
+
+        # If adding this text would exceed limit, start new batch
+        elif current_tokens + text_tokens > max_tokens:
+            if current_batch:  # Don't add empty batches
+                batches.append(current_batch)
+            current_batch = [text]
+            current_tokens = text_tokens
+
+        # Add to current batch
+        else:
+            current_batch.append(text)
+            current_tokens += text_tokens
+
+    # Add final batch if not empty
+    if current_batch:
+        batches.append(current_batch)
+
+    return batches
 
 
 class DocumentService:
     def __init__(self, process_pool=None):
         self.process_pool = process_pool
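
A quick illustration of what the new helper does (a sketch only: the sample strings and the tiny 8-token budget are invented for readability, and it assumes the module context above, with EMBED_MODEL resolving to a model tiktoken recognizes):

    sample_texts = [
        "first short chunk",
        "second short chunk",
        "a much longer chunk of text " * 50,  # well over the 8-token budget
    ]

    # Short texts are packed into one batch; the oversized text is carved
    # into single-item batches of at most 8 tokens each.
    for batch in chunk_texts_for_embeddings(sample_texts, max_tokens=8):
        print(len(batch), [get_token_count(t) for t in batch])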
@@ -41,8 +108,14 @@ class DocumentService:
         slim_doc = extract_relevant(full_doc)
         texts = [c["text"] for c in slim_doc["chunks"]]
-        resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=texts)
-        embeddings = [d.embedding for d in resp.data]
+        # Split into batches to avoid token limits (8191 limit, use 8000 with buffer)
+        text_batches = chunk_texts_for_embeddings(texts, max_tokens=8000)
+        embeddings = []
+        for batch in text_batches:
+            resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=batch)
+            embeddings.extend([d.embedding for d in resp.data])
 
         # Index each chunk as a separate document
         for i, (chunk, vect) in enumerate(zip(slim_doc["chunks"], embeddings)):
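
The batch-and-embed loop added here reappears verbatim in the second code path further down. If that duplication becomes a maintenance burden, one option is a small shared method on the service; this is only a sketch, the name _embed_texts is hypothetical, and it assumes the clients.patched_async_client interface and chunk_texts_for_embeddings helper shown in this diff:

    async def _embed_texts(self, texts: List[str], max_tokens: int = 8000) -> List[List[float]]:
        """Embed texts in token-bounded batches, returning vectors in batch order."""
        embeddings: List[List[float]] = []
        for batch in chunk_texts_for_embeddings(texts, max_tokens=max_tokens):
            resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=batch)
            embeddings.extend(d.embedding for d in resp.data)
        return embeddings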
@@ -145,8 +218,14 @@ class DocumentService:
         else:
             # Generate embeddings and index (I/O bound, keep in main process)
             texts = [c["text"] for c in slim_doc["chunks"]]
-            resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=texts)
-            embeddings = [d.embedding for d in resp.data]
+            # Split into batches to avoid token limits (8191 limit, use 8000 with buffer)
+            text_batches = chunk_texts_for_embeddings(texts, max_tokens=8000)
+            embeddings = []
+            for batch in text_batches:
+                resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=batch)
+                embeddings.extend([d.embedding for d in resp.data])
 
             # Index each chunk
             for i, (chunk, vect) in enumerate(zip(slim_doc["chunks"], embeddings)):
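
One caveat worth flagging: when a single chunk exceeds max_tokens, chunk_texts_for_embeddings splits it into several embedding inputs, so the flattened embeddings list can grow longer than slim_doc["chunks"] and the zip above would pair later chunks with the wrong vectors. Below is a sketch of one way to keep the mapping one-to-one, assuming it is acceptable to truncate oversized chunks to the token budget; truncate_to_token_limit is a hypothetical helper, not part of this commit:

    def truncate_to_token_limit(text: str, max_tokens: int = 8000, model: str = EMBED_MODEL) -> str:
        """Clip a text to the embedding token budget so each chunk yields exactly one vector."""
        try:
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            encoding = tiktoken.get_encoding("cl100k_base")
        tokens = encoding.encode(text)
        return encoding.decode(tokens[:max_tokens]) if len(tokens) > max_tokens else text

    # Applied before batching, e.g.:
    # texts = [truncate_to_token_limit(c["text"]) for c in slim_doc["chunks"]]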