From 155f7edba9f796fb1e46522a020a500b30d753ad Mon Sep 17 00:00:00 2001
From: "estevez.sebastian@gmail.com"
Date: Thu, 31 Jul 2025 13:39:32 -0400
Subject: [PATCH] chunking token limit

---
 src/services/document_service.py | 87 ++++++++++++++++++++++++++++++--
 1 file changed, 83 insertions(+), 4 deletions(-)

diff --git a/src/services/document_service.py b/src/services/document_service.py
index 7e648789..98928a01 100644
--- a/src/services/document_service.py
+++ b/src/services/document_service.py
@@ -5,10 +5,77 @@ import os
 import aiofiles
 from io import BytesIO
 from docling_core.types.io import DocumentStream
+from typing import List
+import openai
+import tiktoken
 from config.settings import clients, INDEX_NAME, EMBED_MODEL
 from utils.document_processing import extract_relevant, process_document_sync
 
+def get_token_count(text: str, model: str = EMBED_MODEL) -> int:
+    """Get accurate token count using tiktoken"""
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+        return len(encoding.encode(text))
+    except KeyError:
+        # Fallback to cl100k_base for unknown models
+        encoding = tiktoken.get_encoding("cl100k_base")
+        return len(encoding.encode(text))
+
+def chunk_texts_for_embeddings(texts: List[str], max_tokens: int = None, model: str = EMBED_MODEL) -> List[List[str]]:
+    """
+    Split texts into batches that won't exceed token limits.
+    If max_tokens is None, returns texts as single batch (no splitting).
+    """
+    if max_tokens is None:
+        return [texts]
+
+    batches = []
+    current_batch = []
+    current_tokens = 0
+
+    for text in texts:
+        text_tokens = get_token_count(text, model)
+
+        # If single text exceeds limit, split it further
+        if text_tokens > max_tokens:
+            # If we have current batch, save it first
+            if current_batch:
+                batches.append(current_batch)
+                current_batch = []
+                current_tokens = 0
+
+            # Split the large text into smaller chunks
+            try:
+                encoding = tiktoken.encoding_for_model(model)
+            except KeyError:
+                encoding = tiktoken.get_encoding("cl100k_base")
+
+            tokens = encoding.encode(text)
+
+            for i in range(0, len(tokens), max_tokens):
+                chunk_tokens = tokens[i:i + max_tokens]
+                chunk_text = encoding.decode(chunk_tokens)
+                batches.append([chunk_text])
+
+        # If adding this text would exceed limit, start new batch
+        elif current_tokens + text_tokens > max_tokens:
+            if current_batch:  # Don't add empty batches
+                batches.append(current_batch)
+            current_batch = [text]
+            current_tokens = text_tokens
+
+        # Add to current batch
+        else:
+            current_batch.append(text)
+            current_tokens += text_tokens
+
+    # Add final batch if not empty
+    if current_batch:
+        batches.append(current_batch)
+
+    return batches
+
 
 class DocumentService:
     def __init__(self, process_pool=None):
         self.process_pool = process_pool
@@ -41,8 +108,14 @@ class DocumentService:
 
         slim_doc = extract_relevant(full_doc)
         texts = [c["text"] for c in slim_doc["chunks"]]
-        resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=texts)
-        embeddings = [d.embedding for d in resp.data]
+
+        # Split into batches to avoid token limits (8191 limit, use 8000 with buffer)
+        text_batches = chunk_texts_for_embeddings(texts, max_tokens=8000)
+        embeddings = []
+
+        for batch in text_batches:
+            resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=batch)
+            embeddings.extend([d.embedding for d in resp.data])
 
         # Index each chunk as a separate document
         for i, (chunk, vect) in enumerate(zip(slim_doc["chunks"], embeddings)):
@@ -145,8 +218,14 @@ class DocumentService:
         else:
             # Generate embeddings and index (I/O bound, keep in main process)
             texts = [c["text"] for c in slim_doc["chunks"]]
-            resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=texts)
-            embeddings = [d.embedding for d in resp.data]
+
+            # Split into batches to avoid token limits (8191 limit, use 8000 with buffer)
+            text_batches = chunk_texts_for_embeddings(texts, max_tokens=8000)
+            embeddings = []
+
+            for batch in text_batches:
+                resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=batch)
+                embeddings.extend([d.embedding for d in resp.data])
 
             # Index each chunk
             for i, (chunk, vect) in enumerate(zip(slim_doc["chunks"], embeddings)):
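
A minimal sketch of how the new helper behaves, for reviewers. This is illustrative only: the import path services.document_service and the sample texts are assumptions (adjust to the project layout), and it presumes tiktoken is installed and config.settings can be loaded in your environment.

from services.document_service import chunk_texts_for_embeddings, get_token_count

# One small chunk plus one text far larger than the per-request budget.
texts = ["short chunk", "lorem ipsum " * 5000]

# 8000 mirrors the buffer used in the patch (the patch comment cites an 8191-token limit).
batches = chunk_texts_for_embeddings(texts, max_tokens=8000)

for i, batch in enumerate(batches):
    # Each batch should stay within the 8000-token budget; an oversized text
    # is split into single-text batches of at most 8000 tokens each.
    total = sum(get_token_count(t) for t in batch)
    print(f"batch {i}: {len(batch)} texts, ~{total} tokens")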