From 39980145618b048bdb639a1248428afc9d78c47f Mon Sep 17 00:00:00 2001 From: phact Date: Mon, 13 Oct 2025 11:31:56 -0400 Subject: [PATCH] fix document processing embedding model bug --- src/config/settings.py | 5 +++++ src/main.py | 6 +++--- src/models/processors.py | 9 +++++---- src/services/document_service.py | 9 ++++++--- src/services/search_service.py | 14 ++++++++------ 5 files changed, 27 insertions(+), 16 deletions(-) diff --git a/src/config/settings.py b/src/config/settings.py index 2edb0d69..53025937 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -595,3 +595,8 @@ def get_knowledge_config(): def get_agent_config(): """Get agent configuration.""" return get_openrag_config().agent + + +def get_embedding_model() -> str: + """Return the currently configured embedding model.""" + return get_openrag_config().knowledge.embedding_model diff --git a/src/main.py b/src/main.py index b16a50d5..b1052874 100644 --- a/src/main.py +++ b/src/main.py @@ -53,11 +53,11 @@ from auth_middleware import optional_auth, require_auth # Configuration and setup from config.settings import ( DISABLE_INGEST_WITH_LANGFLOW, - EMBED_MODEL, INDEX_BODY, INDEX_NAME, SESSION_SECRET, clients, + get_embedding_model, is_no_auth_mode, get_openrag_config, ) @@ -505,7 +505,7 @@ async def initialize_services(): openrag_connector_service = ConnectorService( patched_async_client=clients.patched_async_client, process_pool=process_pool, - embed_model=EMBED_MODEL, + embed_model=get_embedding_model(), index_name=INDEX_NAME, task_service=task_service, session_manager=session_manager, @@ -1228,4 +1228,4 @@ if __name__ == "__main__": host="0.0.0.0", port=8000, reload=False, # Disable reload since we're running from main - ) \ No newline at end of file + ) diff --git a/src/models/processors.py b/src/models/processors.py index 972aa640..d6fb7a72 100644 --- a/src/models/processors.py +++ b/src/models/processors.py @@ -163,16 +163,17 @@ class TaskProcessor: docling conversion + embeddings + OpenSearch indexing. Args: - embedding_model: Embedding model to use (defaults to EMBED_MODEL from settings) + embedding_model: Embedding model to use (defaults to the current + embedding model from settings) """ import datetime - from config.settings import INDEX_NAME, EMBED_MODEL, clients + from config.settings import INDEX_NAME, clients, get_embedding_model from services.document_service import chunk_texts_for_embeddings from utils.document_processing import extract_relevant from utils.embedding_fields import get_embedding_field_name, ensure_embedding_field_exists # Use provided embedding model or fall back to default - embedding_model = embedding_model or EMBED_MODEL + embedding_model = embedding_model or get_embedding_model() # Get user's OpenSearch client with JWT for OIDC auth opensearch_client = self.document_service.session_manager.get_user_opensearch_client( @@ -575,7 +576,7 @@ class S3FileProcessor(TaskProcessor): import time import asyncio import datetime - from config.settings import INDEX_NAME, EMBED_MODEL, clients + from config.settings import INDEX_NAME, clients, get_embedding_model from services.document_service import chunk_texts_for_embeddings from utils.document_processing import process_document_sync diff --git a/src/services/document_service.py b/src/services/document_service.py index d596fb25..882b5eaf 100644 --- a/src/services/document_service.py +++ b/src/services/document_service.py @@ -12,12 +12,13 @@ from utils.logging_config import get_logger logger = get_logger(__name__) -from config.settings import clients, INDEX_NAME, EMBED_MODEL +from config.settings import clients, INDEX_NAME, get_embedding_model from utils.document_processing import extract_relevant, process_document_sync -def get_token_count(text: str, model: str = EMBED_MODEL) -> int: +def get_token_count(text: str, model: str = None) -> int: """Get accurate token count using tiktoken""" + model = model or get_embedding_model() try: encoding = tiktoken.encoding_for_model(model) return len(encoding.encode(text)) @@ -28,12 +29,14 @@ def get_token_count(text: str, model: str = EMBED_MODEL) -> int: def chunk_texts_for_embeddings( - texts: List[str], max_tokens: int = None, model: str = EMBED_MODEL + texts: List[str], max_tokens: int = None, model: str = None ) -> List[List[str]]: """ Split texts into batches that won't exceed token limits. If max_tokens is None, returns texts as single batch (no splitting). """ + model = model or get_embedding_model() + if max_tokens is None: return [texts] diff --git a/src/services/search_service.py b/src/services/search_service.py index 8afa11be..8d12d375 100644 --- a/src/services/search_service.py +++ b/src/services/search_service.py @@ -1,7 +1,7 @@ import copy from typing import Any, Dict from agentd.tool_decorator import tool -from config.settings import clients, INDEX_NAME, EMBED_MODEL +from config.settings import clients, INDEX_NAME, get_embedding_model from auth_context import get_auth_context from utils.logging_config import get_logger @@ -24,17 +24,18 @@ class SearchService: Args: query (str): query string to search the corpus embedding_model (str): Optional override for embedding model. - If not provided, uses EMBED_MODEL from config. + If not provided, uses the current embedding + model from configuration. Returns: dict (str, Any): {"results": [chunks]} on success """ from utils.embedding_fields import get_embedding_field_name - # Strategy: Use provided model, or default to EMBED_MODEL - # This assumes documents are embedded with EMBED_MODEL by default + # Strategy: Use provided model, or default to the configured embedding + # model. This assumes documents are embedded with that model by default. # Future enhancement: Could auto-detect available models in corpus - embedding_model = embedding_model or EMBED_MODEL + embedding_model = embedding_model or get_embedding_model() embedding_field_name = get_embedding_field_name(embedding_model) logger.info( @@ -451,7 +452,8 @@ class SearchService: """Public search method for API endpoints Args: - embedding_model: Embedding model to use for search (defaults to EMBED_MODEL) + embedding_model: Embedding model to use for search (defaults to the + currently configured embedding model) """ # Set auth context if provided (for direct API calls) from config.settings import is_no_auth_mode