Fix document processing embedding model bug

This commit is contained in:
phact 2025-10-13 11:31:56 -04:00
parent 5b32a0ce12
commit 3998014561
5 changed files with 27 additions and 16 deletions

View file

@@ -595,3 +595,8 @@ def get_knowledge_config():
def get_agent_config():
"""Get agent configuration."""
return get_openrag_config().agent
def get_embedding_model() -> str:
"""Return the currently configured embedding model."""
return get_openrag_config().knowledge.embedding_model

View file

@@ -53,11 +53,11 @@ from auth_middleware import optional_auth, require_auth
# Configuration and setup
from config.settings import (
DISABLE_INGEST_WITH_LANGFLOW,
EMBED_MODEL,
INDEX_BODY,
INDEX_NAME,
SESSION_SECRET,
clients,
get_embedding_model,
is_no_auth_mode,
get_openrag_config,
)
@@ -505,7 +505,7 @@ async def initialize_services():
openrag_connector_service = ConnectorService(
patched_async_client=clients.patched_async_client,
process_pool=process_pool,
embed_model=EMBED_MODEL,
embed_model=get_embedding_model(),
index_name=INDEX_NAME,
task_service=task_service,
session_manager=session_manager,
@@ -1228,4 +1228,4 @@ if __name__ == "__main__":
host="0.0.0.0",
port=8000,
reload=False, # Disable reload since we're running from main
)
)

View file

@@ -163,16 +163,17 @@ class TaskProcessor:
docling conversion + embeddings + OpenSearch indexing.
Args:
embedding_model: Embedding model to use (defaults to EMBED_MODEL from settings)
embedding_model: Embedding model to use (defaults to the current
embedding model from settings)
"""
import datetime
from config.settings import INDEX_NAME, EMBED_MODEL, clients
from config.settings import INDEX_NAME, clients, get_embedding_model
from services.document_service import chunk_texts_for_embeddings
from utils.document_processing import extract_relevant
from utils.embedding_fields import get_embedding_field_name, ensure_embedding_field_exists
# Use provided embedding model or fall back to default
embedding_model = embedding_model or EMBED_MODEL
embedding_model = embedding_model or get_embedding_model()
# Get user's OpenSearch client with JWT for OIDC auth
opensearch_client = self.document_service.session_manager.get_user_opensearch_client(
@@ -575,7 +576,7 @@ class S3FileProcessor(TaskProcessor):
import time
import asyncio
import datetime
from config.settings import INDEX_NAME, EMBED_MODEL, clients
from config.settings import INDEX_NAME, clients, get_embedding_model
from services.document_service import chunk_texts_for_embeddings
from utils.document_processing import process_document_sync

View file

@@ -12,12 +12,13 @@ from utils.logging_config import get_logger
logger = get_logger(__name__)
from config.settings import clients, INDEX_NAME, EMBED_MODEL
from config.settings import clients, INDEX_NAME, get_embedding_model
from utils.document_processing import extract_relevant, process_document_sync
def get_token_count(text: str, model: str = EMBED_MODEL) -> int:
def get_token_count(text: str, model: str = None) -> int:
"""Get accurate token count using tiktoken"""
model = model or get_embedding_model()
try:
encoding = tiktoken.encoding_for_model(model)
return len(encoding.encode(text))
@@ -28,12 +29,14 @@ def get_token_count(text: str, model: str = EMBED_MODEL) -> int:
def chunk_texts_for_embeddings(
texts: List[str], max_tokens: int = None, model: str = EMBED_MODEL
texts: List[str], max_tokens: int = None, model: str = None
) -> List[List[str]]:
"""
Split texts into batches that won't exceed token limits.
If max_tokens is None, returns texts as single batch (no splitting).
"""
model = model or get_embedding_model()
if max_tokens is None:
return [texts]

View file

@@ -1,7 +1,7 @@
import copy
from typing import Any, Dict
from agentd.tool_decorator import tool
from config.settings import clients, INDEX_NAME, EMBED_MODEL
from config.settings import clients, INDEX_NAME, get_embedding_model
from auth_context import get_auth_context
from utils.logging_config import get_logger
@@ -24,17 +24,18 @@ class SearchService:
Args:
query (str): query string to search the corpus
embedding_model (str): Optional override for embedding model.
If not provided, uses EMBED_MODEL from config.
If not provided, uses the current embedding
model from configuration.
Returns:
dict (str, Any): {"results": [chunks]} on success
"""
from utils.embedding_fields import get_embedding_field_name
# Strategy: Use provided model, or default to EMBED_MODEL
# This assumes documents are embedded with EMBED_MODEL by default
# Strategy: Use provided model, or default to the configured embedding
# model. This assumes documents are embedded with that model by default.
# Future enhancement: Could auto-detect available models in corpus
embedding_model = embedding_model or EMBED_MODEL
embedding_model = embedding_model or get_embedding_model()
embedding_field_name = get_embedding_field_name(embedding_model)
logger.info(
@@ -451,7 +452,8 @@ class SearchService:
"""Public search method for API endpoints
Args:
embedding_model: Embedding model to use for search (defaults to EMBED_MODEL)
embedding_model: Embedding model to use for search (defaults to the
currently configured embedding model)
"""
# Set auth context if provided (for direct API calls)
from config.settings import is_no_auth_mode