chunking token limit
parent 93b72a19be
commit 155f7edba9
1 changed file with 83 additions and 4 deletions
@@ -5,10 +5,77 @@ import os
 import aiofiles
 from io import BytesIO
 from docling_core.types.io import DocumentStream
 from typing import List
 import openai
+import tiktoken

 from config.settings import clients, INDEX_NAME, EMBED_MODEL
 from utils.document_processing import extract_relevant, process_document_sync

+
+def get_token_count(text: str, model: str = EMBED_MODEL) -> int:
+    """Get accurate token count using tiktoken"""
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+        return len(encoding.encode(text))
+    except KeyError:
+        # Fallback to cl100k_base for unknown models
+        encoding = tiktoken.get_encoding("cl100k_base")
+        return len(encoding.encode(text))
+
+
+def chunk_texts_for_embeddings(texts: List[str], max_tokens: int = None, model: str = EMBED_MODEL) -> List[List[str]]:
+    """
+    Split texts into batches that won't exceed token limits.
+    If max_tokens is None, returns texts as single batch (no splitting).
+    """
+    if max_tokens is None:
+        return [texts]
+
+    batches = []
+    current_batch = []
+    current_tokens = 0
+
+    for text in texts:
+        text_tokens = get_token_count(text, model)
+
+        # If single text exceeds limit, split it further
+        if text_tokens > max_tokens:
+            # If we have current batch, save it first
+            if current_batch:
+                batches.append(current_batch)
+                current_batch = []
+                current_tokens = 0
+
+            # Split the large text into smaller chunks
+            try:
+                encoding = tiktoken.encoding_for_model(model)
+            except KeyError:
+                encoding = tiktoken.get_encoding("cl100k_base")
+
+            tokens = encoding.encode(text)
+
+            for i in range(0, len(tokens), max_tokens):
+                chunk_tokens = tokens[i:i + max_tokens]
+                chunk_text = encoding.decode(chunk_tokens)
+                batches.append([chunk_text])
+
+        # If adding this text would exceed limit, start new batch
+        elif current_tokens + text_tokens > max_tokens:
+            if current_batch:  # Don't add empty batches
+                batches.append(current_batch)
+            current_batch = [text]
+            current_tokens = text_tokens
+
+        # Add to current batch
+        else:
+            current_batch.append(text)
+            current_tokens += text_tokens
+
+    # Add final batch if not empty
+    if current_batch:
+        batches.append(current_batch)
+
+    return batches
+
+
 class DocumentService:
     def __init__(self, process_pool=None):
         self.process_pool = process_pool
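A minimal usage sketch of the new helpers, assuming they are importable from the changed module (the file name is not shown in this view, so the services.document_service path below is a placeholder; the sample texts and the 8000-token budget are illustrative):

# Sketch only: the import path is assumed, the diff does not name the changed file.
from services.document_service import chunk_texts_for_embeddings, get_token_count

texts = [
    "short chunk",
    "another short chunk",
    "x " * 20000,  # long enough to blow past an 8000-token budget on its own
]

# With max_tokens=None the helper is a no-op: one batch holding all texts.
assert chunk_texts_for_embeddings(texts, max_tokens=None) == [texts]

# With a budget, texts are packed greedily; an oversized text is first flushed
# out of the running batch, then split token-wise into single-text batches.
batches = chunk_texts_for_embeddings(texts, max_tokens=8000)
for batch in batches:
    total = sum(get_token_count(t) for t in batch)
    print(len(batch), "texts,", total, "tokens")  # each batch should fit the budget

The greedy packing keeps request sizes predictable; note that oversized inputs end up as their own one-text batches rather than being merged back with smaller texts.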
@@ -41,8 +108,14 @@ class DocumentService:
         slim_doc = extract_relevant(full_doc)

         texts = [c["text"] for c in slim_doc["chunks"]]
-        resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=texts)
-        embeddings = [d.embedding for d in resp.data]
+
+        # Split into batches to avoid token limits (8191 limit, use 8000 with buffer)
+        text_batches = chunk_texts_for_embeddings(texts, max_tokens=8000)
+        embeddings = []
+
+        for batch in text_batches:
+            resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=batch)
+            embeddings.extend([d.embedding for d in resp.data])

         # Index each chunk as a separate document
         for i, (chunk, vect) in enumerate(zip(slim_doc["chunks"], embeddings)):
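For reference, the same request pattern in isolation: a sketch of embedding texts one batch at a time with the stock OpenAI async client, assuming chunk_texts_for_embeddings is importable as above (the service itself uses clients.patched_async_client and EMBED_MODEL from config.settings; the client construction and model name here are stand-ins):

import asyncio
from typing import List

from openai import AsyncOpenAI

from services.document_service import chunk_texts_for_embeddings  # assumed path

async def embed_in_batches(texts: List[str], model: str = "text-embedding-3-small") -> List[List[float]]:
    """Embed texts one token-bounded batch at a time so that no single
    request exceeds the model's per-request token limit."""
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    embeddings: List[List[float]] = []
    for batch in chunk_texts_for_embeddings(texts, max_tokens=8000):
        resp = await client.embeddings.create(model=model, input=batch)
        embeddings.extend(d.embedding for d in resp.data)
    return embeddings

# Example: vectors = asyncio.run(embed_in_batches(["hello", "world"]))

Because the embeddings list is extended in batch order, it keeps the order of the input texts, which the zip over slim_doc["chunks"] relies on (provided no single chunk is itself split).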
@@ -145,8 +218,14 @@ class DocumentService:
         else:
             # Generate embeddings and index (I/O bound, keep in main process)
             texts = [c["text"] for c in slim_doc["chunks"]]
-            resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=texts)
-            embeddings = [d.embedding for d in resp.data]
+
+            # Split into batches to avoid token limits (8191 limit, use 8000 with buffer)
+            text_batches = chunk_texts_for_embeddings(texts, max_tokens=8000)
+            embeddings = []
+
+            for batch in text_batches:
+                resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=batch)
+                embeddings.extend([d.embedding for d in resp.data])

             # Index each chunk
             for i, (chunk, vect) in enumerate(zip(slim_doc["chunks"], embeddings)):