chunking token limit
This commit is contained in:
parent 6c2cc907dd
commit 8e947c3965
1 changed file with 83 additions and 4 deletions
@@ -5,10 +5,77 @@ import os
 import aiofiles
 from io import BytesIO
 from docling_core.types.io import DocumentStream
+from typing import List
+import openai
+import tiktoken
+
 from config.settings import clients, INDEX_NAME, EMBED_MODEL
 from utils.document_processing import extract_relevant, process_document_sync
+
+
+def get_token_count(text: str, model: str = EMBED_MODEL) -> int:
+    """Get accurate token count using tiktoken"""
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+        return len(encoding.encode(text))
+    except KeyError:
+        # Fallback to cl100k_base for unknown models
+        encoding = tiktoken.get_encoding("cl100k_base")
+        return len(encoding.encode(text))
+
+
+def chunk_texts_for_embeddings(texts: List[str], max_tokens: int = None, model: str = EMBED_MODEL) -> List[List[str]]:
+    """
+    Split texts into batches that won't exceed token limits.
+    If max_tokens is None, returns texts as a single batch (no splitting).
+    """
+    if max_tokens is None:
+        return [texts]
+
+    batches = []
+    current_batch = []
+    current_tokens = 0
+
+    for text in texts:
+        text_tokens = get_token_count(text, model)
+
+        # If a single text exceeds the limit, split it further
+        if text_tokens > max_tokens:
+            # If we have a current batch, save it first
+            if current_batch:
+                batches.append(current_batch)
+                current_batch = []
+                current_tokens = 0
+
+            # Split the large text into smaller chunks
+            try:
+                encoding = tiktoken.encoding_for_model(model)
+            except KeyError:
+                encoding = tiktoken.get_encoding("cl100k_base")
+
+            tokens = encoding.encode(text)
+
+            for i in range(0, len(tokens), max_tokens):
+                chunk_tokens = tokens[i:i + max_tokens]
+                chunk_text = encoding.decode(chunk_tokens)
+                batches.append([chunk_text])
+
+        # If adding this text would exceed the limit, start a new batch
+        elif current_tokens + text_tokens > max_tokens:
+            if current_batch:  # Don't add empty batches
+                batches.append(current_batch)
+            current_batch = [text]
+            current_tokens = text_tokens
+
+        # Otherwise add to the current batch
+        else:
+            current_batch.append(text)
+            current_tokens += text_tokens
+
+    # Add the final batch if not empty
+    if current_batch:
+        batches.append(current_batch)
+
+    return batches
+
+
 class DocumentService:
     def __init__(self, process_pool=None):
         self.process_pool = process_pool
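For orientation, here is a minimal usage sketch of the new helper, outside the diff. It is not part of the commit; the sample strings and the explicit "text-embedding-3-small" model name are assumptions chosen so the snippet does not depend on config.settings.

# Hypothetical usage, not part of the commit. The model name is an assumption
# so tiktoken can resolve an encoding without importing EMBED_MODEL.
sample_texts = ["short text"] * 5 + ["a much longer document " * 4000]

batches = chunk_texts_for_embeddings(sample_texts, max_tokens=8000,
                                     model="text-embedding-3-small")

# The five short texts share one batch; the oversized text is re-encoded and
# split into <=8000-token slices, each appended as its own single-item batch.
for batch in batches:
    sizes = [get_token_count(t, "text-embedding-3-small") for t in batch]
    print(len(batch), sum(sizes))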
@@ -41,8 +108,14 @@ class DocumentService:
         slim_doc = extract_relevant(full_doc)
 
         texts = [c["text"] for c in slim_doc["chunks"]]
-        resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=texts)
-        embeddings = [d.embedding for d in resp.data]
+        # Split into batches to avoid token limits (8191 limit, use 8000 with buffer)
+        text_batches = chunk_texts_for_embeddings(texts, max_tokens=8000)
+        embeddings = []
+
+        for batch in text_batches:
+            resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=batch)
+            embeddings.extend([d.embedding for d in resp.data])
 
         # Index each chunk as a separate document
         for i, (chunk, vect) in enumerate(zip(slim_doc["chunks"], embeddings)):
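Because the batches preserve the original order of texts, extending embeddings batch by batch keeps them aligned with slim_doc["chunks"] for the zip below, provided no single chunk exceeds max_tokens on its own (an oversized chunk would be split and yield extra embeddings). A rough illustration of that invariant, reusing names from the diff and not part of the commit:

# Illustrative check only (not in the commit): flattening the batches should
# reproduce the original text order whenever no single text had to be split.
flattened = [t for batch in text_batches for t in batch]
assert flattened == texts  # holds as long as every chunk fits within max_tokens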
@@ -145,8 +218,14 @@ class DocumentService:
         else:
             # Generate embeddings and index (I/O bound, keep in main process)
             texts = [c["text"] for c in slim_doc["chunks"]]
-            resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=texts)
-            embeddings = [d.embedding for d in resp.data]
+            # Split into batches to avoid token limits (8191 limit, use 8000 with buffer)
+            text_batches = chunk_texts_for_embeddings(texts, max_tokens=8000)
+            embeddings = []
+
+            for batch in text_batches:
+                resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=batch)
+                embeddings.extend([d.embedding for d in resp.data])
 
             # Index each chunk
             for i, (chunk, vect) in enumerate(zip(slim_doc["chunks"], embeddings)):