fix document processing embedding model bug
This commit is contained in:
parent
5b32a0ce12
commit
3998014561
5 changed files with 27 additions and 16 deletions
|
|
@ -595,3 +595,8 @@ def get_knowledge_config():
|
||||||
def get_agent_config():
|
def get_agent_config():
|
||||||
"""Get agent configuration."""
|
"""Get agent configuration."""
|
||||||
return get_openrag_config().agent
|
return get_openrag_config().agent
|
||||||
|
|
||||||
|
|
||||||
|
def get_embedding_model() -> str:
|
||||||
|
"""Return the currently configured embedding model."""
|
||||||
|
return get_openrag_config().knowledge.embedding_model
|
||||||
|
|
|
||||||
|
|
@ -53,11 +53,11 @@ from auth_middleware import optional_auth, require_auth
|
||||||
# Configuration and setup
|
# Configuration and setup
|
||||||
from config.settings import (
|
from config.settings import (
|
||||||
DISABLE_INGEST_WITH_LANGFLOW,
|
DISABLE_INGEST_WITH_LANGFLOW,
|
||||||
EMBED_MODEL,
|
|
||||||
INDEX_BODY,
|
INDEX_BODY,
|
||||||
INDEX_NAME,
|
INDEX_NAME,
|
||||||
SESSION_SECRET,
|
SESSION_SECRET,
|
||||||
clients,
|
clients,
|
||||||
|
get_embedding_model,
|
||||||
is_no_auth_mode,
|
is_no_auth_mode,
|
||||||
get_openrag_config,
|
get_openrag_config,
|
||||||
)
|
)
|
||||||
|
|
@ -505,7 +505,7 @@ async def initialize_services():
|
||||||
openrag_connector_service = ConnectorService(
|
openrag_connector_service = ConnectorService(
|
||||||
patched_async_client=clients.patched_async_client,
|
patched_async_client=clients.patched_async_client,
|
||||||
process_pool=process_pool,
|
process_pool=process_pool,
|
||||||
embed_model=EMBED_MODEL,
|
embed_model=get_embedding_model(),
|
||||||
index_name=INDEX_NAME,
|
index_name=INDEX_NAME,
|
||||||
task_service=task_service,
|
task_service=task_service,
|
||||||
session_manager=session_manager,
|
session_manager=session_manager,
|
||||||
|
|
@ -1228,4 +1228,4 @@ if __name__ == "__main__":
|
||||||
host="0.0.0.0",
|
host="0.0.0.0",
|
||||||
port=8000,
|
port=8000,
|
||||||
reload=False, # Disable reload since we're running from main
|
reload=False, # Disable reload since we're running from main
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -163,16 +163,17 @@ class TaskProcessor:
|
||||||
docling conversion + embeddings + OpenSearch indexing.
|
docling conversion + embeddings + OpenSearch indexing.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
embedding_model: Embedding model to use (defaults to EMBED_MODEL from settings)
|
embedding_model: Embedding model to use (defaults to the current
|
||||||
|
embedding model from settings)
|
||||||
"""
|
"""
|
||||||
import datetime
|
import datetime
|
||||||
from config.settings import INDEX_NAME, EMBED_MODEL, clients
|
from config.settings import INDEX_NAME, clients, get_embedding_model
|
||||||
from services.document_service import chunk_texts_for_embeddings
|
from services.document_service import chunk_texts_for_embeddings
|
||||||
from utils.document_processing import extract_relevant
|
from utils.document_processing import extract_relevant
|
||||||
from utils.embedding_fields import get_embedding_field_name, ensure_embedding_field_exists
|
from utils.embedding_fields import get_embedding_field_name, ensure_embedding_field_exists
|
||||||
|
|
||||||
# Use provided embedding model or fall back to default
|
# Use provided embedding model or fall back to default
|
||||||
embedding_model = embedding_model or EMBED_MODEL
|
embedding_model = embedding_model or get_embedding_model()
|
||||||
|
|
||||||
# Get user's OpenSearch client with JWT for OIDC auth
|
# Get user's OpenSearch client with JWT for OIDC auth
|
||||||
opensearch_client = self.document_service.session_manager.get_user_opensearch_client(
|
opensearch_client = self.document_service.session_manager.get_user_opensearch_client(
|
||||||
|
|
@ -575,7 +576,7 @@ class S3FileProcessor(TaskProcessor):
|
||||||
import time
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
import datetime
|
import datetime
|
||||||
from config.settings import INDEX_NAME, EMBED_MODEL, clients
|
from config.settings import INDEX_NAME, clients, get_embedding_model
|
||||||
from services.document_service import chunk_texts_for_embeddings
|
from services.document_service import chunk_texts_for_embeddings
|
||||||
from utils.document_processing import process_document_sync
|
from utils.document_processing import process_document_sync
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,12 +12,13 @@ from utils.logging_config import get_logger
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
from config.settings import clients, INDEX_NAME, EMBED_MODEL
|
from config.settings import clients, INDEX_NAME, get_embedding_model
|
||||||
from utils.document_processing import extract_relevant, process_document_sync
|
from utils.document_processing import extract_relevant, process_document_sync
|
||||||
|
|
||||||
|
|
||||||
def get_token_count(text: str, model: str = EMBED_MODEL) -> int:
|
def get_token_count(text: str, model: str = None) -> int:
|
||||||
"""Get accurate token count using tiktoken"""
|
"""Get accurate token count using tiktoken"""
|
||||||
|
model = model or get_embedding_model()
|
||||||
try:
|
try:
|
||||||
encoding = tiktoken.encoding_for_model(model)
|
encoding = tiktoken.encoding_for_model(model)
|
||||||
return len(encoding.encode(text))
|
return len(encoding.encode(text))
|
||||||
|
|
@ -28,12 +29,14 @@ def get_token_count(text: str, model: str = EMBED_MODEL) -> int:
|
||||||
|
|
||||||
|
|
||||||
def chunk_texts_for_embeddings(
|
def chunk_texts_for_embeddings(
|
||||||
texts: List[str], max_tokens: int = None, model: str = EMBED_MODEL
|
texts: List[str], max_tokens: int = None, model: str = None
|
||||||
) -> List[List[str]]:
|
) -> List[List[str]]:
|
||||||
"""
|
"""
|
||||||
Split texts into batches that won't exceed token limits.
|
Split texts into batches that won't exceed token limits.
|
||||||
If max_tokens is None, returns texts as single batch (no splitting).
|
If max_tokens is None, returns texts as single batch (no splitting).
|
||||||
"""
|
"""
|
||||||
|
model = model or get_embedding_model()
|
||||||
|
|
||||||
if max_tokens is None:
|
if max_tokens is None:
|
||||||
return [texts]
|
return [texts]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import copy
|
import copy
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
from agentd.tool_decorator import tool
|
from agentd.tool_decorator import tool
|
||||||
from config.settings import clients, INDEX_NAME, EMBED_MODEL
|
from config.settings import clients, INDEX_NAME, get_embedding_model
|
||||||
from auth_context import get_auth_context
|
from auth_context import get_auth_context
|
||||||
from utils.logging_config import get_logger
|
from utils.logging_config import get_logger
|
||||||
|
|
||||||
|
|
@ -24,17 +24,18 @@ class SearchService:
|
||||||
Args:
|
Args:
|
||||||
query (str): query string to search the corpus
|
query (str): query string to search the corpus
|
||||||
embedding_model (str): Optional override for embedding model.
|
embedding_model (str): Optional override for embedding model.
|
||||||
If not provided, uses EMBED_MODEL from config.
|
If not provided, uses the current embedding
|
||||||
|
model from configuration.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict (str, Any): {"results": [chunks]} on success
|
dict (str, Any): {"results": [chunks]} on success
|
||||||
"""
|
"""
|
||||||
from utils.embedding_fields import get_embedding_field_name
|
from utils.embedding_fields import get_embedding_field_name
|
||||||
|
|
||||||
# Strategy: Use provided model, or default to EMBED_MODEL
|
# Strategy: Use provided model, or default to the configured embedding
|
||||||
# This assumes documents are embedded with EMBED_MODEL by default
|
# model. This assumes documents are embedded with that model by default.
|
||||||
# Future enhancement: Could auto-detect available models in corpus
|
# Future enhancement: Could auto-detect available models in corpus
|
||||||
embedding_model = embedding_model or EMBED_MODEL
|
embedding_model = embedding_model or get_embedding_model()
|
||||||
embedding_field_name = get_embedding_field_name(embedding_model)
|
embedding_field_name = get_embedding_field_name(embedding_model)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
@ -451,7 +452,8 @@ class SearchService:
|
||||||
"""Public search method for API endpoints
|
"""Public search method for API endpoints
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
embedding_model: Embedding model to use for search (defaults to EMBED_MODEL)
|
embedding_model: Embedding model to use for search (defaults to the
|
||||||
|
currently configured embedding model)
|
||||||
"""
|
"""
|
||||||
# Set auth context if provided (for direct API calls)
|
# Set auth context if provided (for direct API calls)
|
||||||
from config.settings import is_no_auth_mode
|
from config.settings import is_no_auth_mode
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue