Fix document processing embedding model bug (use the currently configured embedding model instead of the static EMBED_MODEL)
This commit is contained in:
parent
5b32a0ce12
commit
3998014561
5 changed files with 27 additions and 16 deletions
|
|
@ -595,3 +595,8 @@ def get_knowledge_config():
|
|||
def get_agent_config():
    """Return the ``agent`` section of the OpenRAG configuration."""
    # Fetch the configuration object first, then hand back its agent part.
    openrag_config = get_openrag_config()
    return openrag_config.agent
|
||||
|
||||
|
||||
def get_embedding_model() -> str:
    """Look up the embedding model named in the knowledge configuration.

    The value is read from ``get_openrag_config()`` on every call rather
    than captured once, so callers see whatever that function returns at
    call time.
    """
    knowledge_settings = get_openrag_config().knowledge
    return knowledge_settings.embedding_model
|
||||
|
|
|
|||
|
|
@ -53,11 +53,11 @@ from auth_middleware import optional_auth, require_auth
|
|||
# Configuration and setup
|
||||
from config.settings import (
|
||||
DISABLE_INGEST_WITH_LANGFLOW,
|
||||
EMBED_MODEL,
|
||||
INDEX_BODY,
|
||||
INDEX_NAME,
|
||||
SESSION_SECRET,
|
||||
clients,
|
||||
get_embedding_model,
|
||||
is_no_auth_mode,
|
||||
get_openrag_config,
|
||||
)
|
||||
|
|
@ -505,7 +505,7 @@ async def initialize_services():
|
|||
openrag_connector_service = ConnectorService(
|
||||
patched_async_client=clients.patched_async_client,
|
||||
process_pool=process_pool,
|
||||
embed_model=EMBED_MODEL,
|
||||
embed_model=get_embedding_model(),
|
||||
index_name=INDEX_NAME,
|
||||
task_service=task_service,
|
||||
session_manager=session_manager,
|
||||
|
|
@ -1228,4 +1228,4 @@ if __name__ == "__main__":
|
|||
host="0.0.0.0",
|
||||
port=8000,
|
||||
reload=False, # Disable reload since we're running from main
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -163,16 +163,17 @@ class TaskProcessor:
|
|||
docling conversion + embeddings + OpenSearch indexing.
|
||||
|
||||
Args:
|
||||
embedding_model: Embedding model to use (defaults to EMBED_MODEL from settings)
|
||||
embedding_model: Embedding model to use (defaults to the current
|
||||
embedding model from settings)
|
||||
"""
|
||||
import datetime
|
||||
from config.settings import INDEX_NAME, EMBED_MODEL, clients
|
||||
from config.settings import INDEX_NAME, clients, get_embedding_model
|
||||
from services.document_service import chunk_texts_for_embeddings
|
||||
from utils.document_processing import extract_relevant
|
||||
from utils.embedding_fields import get_embedding_field_name, ensure_embedding_field_exists
|
||||
|
||||
# Use provided embedding model or fall back to default
|
||||
embedding_model = embedding_model or EMBED_MODEL
|
||||
embedding_model = embedding_model or get_embedding_model()
|
||||
|
||||
# Get user's OpenSearch client with JWT for OIDC auth
|
||||
opensearch_client = self.document_service.session_manager.get_user_opensearch_client(
|
||||
|
|
@ -575,7 +576,7 @@ class S3FileProcessor(TaskProcessor):
|
|||
import time
|
||||
import asyncio
|
||||
import datetime
|
||||
from config.settings import INDEX_NAME, EMBED_MODEL, clients
|
||||
from config.settings import INDEX_NAME, clients, get_embedding_model
|
||||
from services.document_service import chunk_texts_for_embeddings
|
||||
from utils.document_processing import process_document_sync
|
||||
|
||||
|
|
|
|||
|
|
@ -12,12 +12,13 @@ from utils.logging_config import get_logger
|
|||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
from config.settings import clients, INDEX_NAME, EMBED_MODEL
|
||||
from config.settings import clients, INDEX_NAME, get_embedding_model
|
||||
from utils.document_processing import extract_relevant, process_document_sync
|
||||
|
||||
|
||||
def get_token_count(text: str, model: str = EMBED_MODEL) -> int:
|
||||
def get_token_count(text: str, model: str = None) -> int:
|
||||
"""Get accurate token count using tiktoken"""
|
||||
model = model or get_embedding_model()
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
return len(encoding.encode(text))
|
||||
|
|
@ -28,12 +29,14 @@ def get_token_count(text: str, model: str = EMBED_MODEL) -> int:
|
|||
|
||||
|
||||
def chunk_texts_for_embeddings(
|
||||
texts: List[str], max_tokens: int = None, model: str = EMBED_MODEL
|
||||
texts: List[str], max_tokens: int = None, model: str = None
|
||||
) -> List[List[str]]:
|
||||
"""
|
||||
Split texts into batches that won't exceed token limits.
|
||||
If max_tokens is None, returns texts as single batch (no splitting).
|
||||
"""
|
||||
model = model or get_embedding_model()
|
||||
|
||||
if max_tokens is None:
|
||||
return [texts]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import copy
|
||||
from typing import Any, Dict
|
||||
from agentd.tool_decorator import tool
|
||||
from config.settings import clients, INDEX_NAME, EMBED_MODEL
|
||||
from config.settings import clients, INDEX_NAME, get_embedding_model
|
||||
from auth_context import get_auth_context
|
||||
from utils.logging_config import get_logger
|
||||
|
||||
|
|
@ -24,17 +24,18 @@ class SearchService:
|
|||
Args:
|
||||
query (str): query string to search the corpus
|
||||
embedding_model (str): Optional override for embedding model.
|
||||
If not provided, uses EMBED_MODEL from config.
|
||||
If not provided, uses the current embedding
|
||||
model from configuration.
|
||||
|
||||
Returns:
|
||||
dict (str, Any): {"results": [chunks]} on success
|
||||
"""
|
||||
from utils.embedding_fields import get_embedding_field_name
|
||||
|
||||
# Strategy: Use provided model, or default to EMBED_MODEL
|
||||
# This assumes documents are embedded with EMBED_MODEL by default
|
||||
# Strategy: Use provided model, or default to the configured embedding
|
||||
# model. This assumes documents are embedded with that model by default.
|
||||
# Future enhancement: Could auto-detect available models in corpus
|
||||
embedding_model = embedding_model or EMBED_MODEL
|
||||
embedding_model = embedding_model or get_embedding_model()
|
||||
embedding_field_name = get_embedding_field_name(embedding_model)
|
||||
|
||||
logger.info(
|
||||
|
|
@ -451,7 +452,8 @@ class SearchService:
|
|||
"""Public search method for API endpoints
|
||||
|
||||
Args:
|
||||
embedding_model: Embedding model to use for search (defaults to EMBED_MODEL)
|
||||
embedding_model: Embedding model to use for search (defaults to the
|
||||
currently configured embedding model)
|
||||
"""
|
||||
# Set auth context if provided (for direct API calls)
|
||||
from config.settings import is_no_auth_mode
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue