openrag/src/services/search_service.py
Edwin Jose 13cc00043b Add robust retry and mocking for embedding API calls
Introduces exponential backoff and Retry-After header handling for embedding API rate limits in both processors and search service. Adds CI fixture to mock OpenAI embeddings, avoiding real API calls during tests. Updates Makefile to document and set MOCK_EMBEDDINGS for integration and CI test targets.
2025-11-26 15:54:38 -05:00

523 lines
22 KiB
Python

import asyncio
import copy
import random
from typing import Any, Dict

from agentd.tool_decorator import tool
from config.settings import EMBED_MODEL, clients, INDEX_NAME, get_embedding_model
from auth_context import get_auth_context
from utils.logging_config import get_logger
logger = get_logger(__name__)
MAX_EMBED_RETRIES = 3
EMBED_RETRY_INITIAL_DELAY = 1.0
EMBED_RETRY_MAX_DELAY = 8.0
class SearchService:
def __init__(self, session_manager=None):
self.session_manager = session_manager
@tool
async def search_tool(self, query: str, embedding_model: str = None) -> Dict[str, Any]:
"""
Use this tool to search for documents relevant to the query.
Args:
query (str): query string to search the corpus
embedding_model (str): Optional override for embedding model.
If not provided, uses the current embedding
model from configuration.
Returns:
dict (str, Any): {"results": [chunks]} on success
"""
from utils.embedding_fields import get_embedding_field_name
# Strategy: Use provided model, or default to the configured embedding
# model. This assumes documents are embedded with that model by default.
# Future enhancement: Could auto-detect available models in corpus.
embedding_model = embedding_model or get_embedding_model() or EMBED_MODEL
embedding_field_name = get_embedding_field_name(embedding_model)
logger.info(
"Search with embedding model",
embedding_model=embedding_model,
embedding_field=embedding_field_name,
query_preview=query[:50] if query else None,
)
# Get authentication context from the current async context
user_id, jwt_token = get_auth_context()
# Get search filters, limit, and score threshold from context
from auth_context import (
get_search_filters,
get_search_limit,
get_score_threshold,
)
filters = get_search_filters() or {}
limit = get_search_limit()
score_threshold = get_score_threshold()
# Detect wildcard request ("*") to return global facets/stats without semantic search
is_wildcard_match_all = isinstance(query, str) and query.strip() == "*"
# Get available embedding models from corpus
query_embeddings = {}
available_models = []
opensearch_client = self.session_manager.get_user_opensearch_client(
user_id, jwt_token
)
if not is_wildcard_match_all:
# Build filter clauses first so we can use them in model detection
filter_clauses = []
if filters:
# Map frontend filter names to backend field names
field_mapping = {
"data_sources": "filename",
"document_types": "mimetype",
"owners": "owner_name.keyword",
"connector_types": "connector_type",
}
for filter_key, values in filters.items():
if values is not None and isinstance(values, list):
# Map frontend key to backend field name
field_name = field_mapping.get(filter_key, filter_key)
if len(values) == 0:
# Empty array means "match nothing" - use impossible filter
filter_clauses.append(
{"term": {field_name: "__IMPOSSIBLE_VALUE__"}}
)
elif len(values) == 1:
# Single value filter
filter_clauses.append({"term": {field_name: values[0]}})
else:
# Multiple values filter
filter_clauses.append({"terms": {field_name: values}})
try:
# Build aggregation query with filters applied
agg_query = {
"size": 0,
"aggs": {
"embedding_models": {
"terms": {
"field": "embedding_model",
"size": 10
}
}
}
}
# Apply filters to model detection if any exist
if filter_clauses:
agg_query["query"] = {
"bool": {
"filter": filter_clauses
}
}
agg_result = await opensearch_client.search(
index=INDEX_NAME, body=agg_query, params={"terminate_after": 0}
)
buckets = agg_result.get("aggregations", {}).get("embedding_models", {}).get("buckets", [])
available_models = [b["key"] for b in buckets if b["key"]]
if not available_models:
# Fallback to configured model if no documents indexed yet
available_models = [embedding_model]
logger.info(
"Detected embedding models in corpus",
available_models=available_models,
model_counts={b["key"]: b["doc_count"] for b in buckets},
with_filters=len(filter_clauses) > 0
)
except Exception as e:
logger.warning("Failed to detect embedding models, using configured model", error=str(e))
available_models = [embedding_model]
# Parallelize embedding generation for all models
import asyncio
async def embed_with_model(model_name):
delay = EMBED_RETRY_INITIAL_DELAY
attempts = 0
last_exception = None
while attempts < MAX_EMBED_RETRIES:
attempts += 1
try:
resp = await clients.patched_async_client.embeddings.create(
model=model_name, input=[query]
)
return model_name, resp.data[0].embedding
except Exception as e:
last_exception = e
error_str = str(e).lower()
# Check if it's a rate limit error (429)
is_rate_limit = "429" in error_str or "rate" in error_str or "too many requests" in error_str
if attempts >= MAX_EMBED_RETRIES:
logger.error(
"Failed to embed with model after retries",
model=model_name,
attempts=attempts,
error=str(e),
is_rate_limit=is_rate_limit,
)
raise RuntimeError(
f"Failed to embed with model {model_name}"
) from e
# For rate limit errors, use longer delays
if is_rate_limit:
# Extract Retry-After header if available
retry_after = None
if hasattr(e, 'response') and hasattr(e.response, 'headers'):
retry_after = e.response.headers.get('Retry-After')
if retry_after:
try:
wait_time = float(retry_after)
logger.warning(
"Rate limited - respecting Retry-After header",
model=model_name,
attempt=attempts,
max_attempts=MAX_EMBED_RETRIES,
retry_after=wait_time,
)
await asyncio.sleep(wait_time)
continue
except (ValueError, TypeError):
pass
# Use exponential backoff with jitter for rate limits
wait_time = min(delay * 2, EMBED_RETRY_MAX_DELAY)
# Add jitter to avoid thundering herd
import random
jitter = random.uniform(0, 0.5 * wait_time)
wait_time += jitter
logger.warning(
"Rate limited - backing off with exponential delay",
model=model_name,
attempt=attempts,
max_attempts=MAX_EMBED_RETRIES,
wait_time=wait_time,
)
await asyncio.sleep(wait_time)
delay = wait_time
else:
# Regular retry for other errors
logger.warning(
"Retrying embedding generation",
model=model_name,
attempt=attempts,
max_attempts=MAX_EMBED_RETRIES,
error=str(e),
)
await asyncio.sleep(delay)
delay = min(delay * 2, EMBED_RETRY_MAX_DELAY)
# Should not reach here, but guard in case
raise RuntimeError(
f"Failed to embed with model {model_name}"
) from last_exception
# Run all embeddings in parallel
try:
embedding_results = await asyncio.gather(
*[embed_with_model(model) for model in available_models]
)
except Exception as e:
logger.error("Embedding generation failed", error=str(e))
raise
# Collect successful embeddings
for result in embedding_results:
if isinstance(result, tuple) and result[1] is not None:
model_name, embedding = result
query_embeddings[model_name] = embedding
logger.info(
"Generated query embeddings",
models=list(query_embeddings.keys()),
query_preview=query[:50]
)
else:
# Wildcard query - no embedding needed
filter_clauses = []
if filters:
# Map frontend filter names to backend field names
field_mapping = {
"data_sources": "filename",
"document_types": "mimetype",
"owners": "owner_name.keyword",
"connector_types": "connector_type",
}
for filter_key, values in filters.items():
if values is not None and isinstance(values, list):
# Map frontend key to backend field name
field_name = field_mapping.get(filter_key, filter_key)
if len(values) == 0:
# Empty array means "match nothing" - use impossible filter
filter_clauses.append(
{"term": {field_name: "__IMPOSSIBLE_VALUE__"}}
)
elif len(values) == 1:
# Single value filter
filter_clauses.append({"term": {field_name: values[0]}})
else:
# Multiple values filter
filter_clauses.append({"terms": {field_name: values}})
# Build query body
if is_wildcard_match_all:
# Match all documents; still allow filters to narrow scope
if filter_clauses:
query_block = {"bool": {"filter": filter_clauses}}
else:
query_block = {"match_all": {}}
else:
# Build multi-model KNN queries
knn_queries = []
embedding_fields_to_check = []
for model_name, embedding_vector in query_embeddings.items():
field_name = get_embedding_field_name(model_name)
embedding_fields_to_check.append(field_name)
knn_queries.append({
"knn": {
field_name: {
"vector": embedding_vector,
"k": 50,
"num_candidates": 1000,
}
}
})
# Build exists filter - doc must have at least one embedding field
exists_any_embedding = {
"bool": {
"should": [{"exists": {"field": f}} for f in embedding_fields_to_check],
"minimum_should_match": 1
}
}
# Add exists filter to existing filters
all_filters = [*filter_clauses, exists_any_embedding]
logger.debug(
"Building hybrid query with filters",
user_filters_count=len(filter_clauses),
total_filters_count=len(all_filters),
filter_types=[type(f).__name__ for f in all_filters]
)
# Hybrid search query structure (semantic + keyword)
# Use dis_max to pick best score across multiple embedding fields
query_block = {
"bool": {
"should": [
{
"dis_max": {
"tie_breaker": 0.0, # Take only the best match, no blending
"boost": 0.7, # 70% weight for semantic search
"queries": knn_queries
}
},
{
"multi_match": {
"query": query,
"fields": ["text^2", "filename^1.5"],
"type": "best_fields",
"fuzziness": "AUTO",
"boost": 0.3, # 30% weight for keyword search
}
},
],
"minimum_should_match": 1,
"filter": all_filters,
}
}
search_body = {
"query": query_block,
"aggs": {
"data_sources": {"terms": {"field": "filename", "size": 20}},
"document_types": {"terms": {"field": "mimetype", "size": 10}},
"owners": {"terms": {"field": "owner_name.keyword", "size": 10}},
"connector_types": {"terms": {"field": "connector_type", "size": 10}},
"embedding_models": {"terms": {"field": "embedding_model", "size": 10}},
},
"_source": [
"filename",
"mimetype",
"page",
"text",
"source_url",
"owner",
"owner_name",
"owner_email",
"file_size",
"connector_type",
"embedding_model", # Include embedding model in results
"embedding_dimensions",
"allowed_users",
"allowed_groups",
],
"size": limit,
}
# Add score threshold only for hybrid (not meaningful for match_all)
if not is_wildcard_match_all and score_threshold > 0:
search_body["min_score"] = score_threshold
# Prepare fallback search body without num_candidates for clusters that don't support it
fallback_search_body = None
if not is_wildcard_match_all:
try:
fallback_search_body = copy.deepcopy(search_body)
knn_query_blocks = (
fallback_search_body["query"]["bool"]["should"][0]["dis_max"]["queries"]
)
for query_candidate in knn_query_blocks:
knn_section = query_candidate.get("knn")
if isinstance(knn_section, dict):
for params in knn_section.values():
if isinstance(params, dict):
params.pop("num_candidates", None)
except (KeyError, IndexError, AttributeError, TypeError):
fallback_search_body = None
# Authentication required - DLS will handle document filtering automatically
logger.debug(
"search_service authentication info",
user_id=user_id,
has_jwt_token=jwt_token is not None,
)
if not user_id:
logger.debug("search_service: user_id is None/empty, returning auth error")
return {"results": [], "error": "Authentication required"}
# Get user's OpenSearch client with JWT for OIDC auth through session manager
opensearch_client = self.session_manager.get_user_opensearch_client(
user_id, jwt_token
)
from opensearchpy.exceptions import RequestError
search_params = {"terminate_after": 0}
try:
results = await opensearch_client.search(
index=INDEX_NAME, body=search_body, params=search_params
)
except RequestError as e:
error_message = str(e)
if (
fallback_search_body is not None
and "unknown field [num_candidates]" in error_message.lower()
):
logger.warning(
"OpenSearch cluster does not support num_candidates; retrying without it"
)
try:
results = await opensearch_client.search(
index=INDEX_NAME,
body=fallback_search_body,
params=search_params,
)
except RequestError as retry_error:
logger.error(
"OpenSearch retry without num_candidates failed",
error=str(retry_error),
search_body=fallback_search_body,
)
raise
else:
logger.error(
"OpenSearch query failed", error=error_message, search_body=search_body
)
raise
except Exception as e:
logger.error(
"OpenSearch query failed", error=str(e), search_body=search_body
)
# Re-raise the exception so the API returns the error to frontend
raise
# Transform results (keep for backward compatibility)
chunks = []
for hit in results["hits"]["hits"]:
chunks.append(
{
"filename": hit["_source"].get("filename"),
"mimetype": hit["_source"].get("mimetype"),
"page": hit["_source"].get("page"),
"text": hit["_source"].get("text"),
"score": hit.get("_score"),
"source_url": hit["_source"].get("source_url"),
"owner": hit["_source"].get("owner"),
"owner_name": hit["_source"].get("owner_name"),
"owner_email": hit["_source"].get("owner_email"),
"file_size": hit["_source"].get("file_size"),
"connector_type": hit["_source"].get("connector_type"),
"embedding_model": hit["_source"].get("embedding_model"), # Include in results
"embedding_dimensions": hit["_source"].get("embedding_dimensions"),
}
)
# Return both transformed results and aggregations
return {
"results": chunks,
"aggregations": results.get("aggregations", {}),
"total": (
results.get("hits", {}).get("total", {}).get("value")
if isinstance(results.get("hits", {}).get("total"), dict)
else results.get("hits", {}).get("total")
),
}
async def search(
self,
query: str,
user_id: str = None,
jwt_token: str = None,
filters: Dict[str, Any] = None,
limit: int = 10,
score_threshold: float = 0,
embedding_model: str = None,
) -> Dict[str, Any]:
"""Public search method for API endpoints
Args:
embedding_model: Embedding model to use for search (defaults to the
currently configured embedding model)
"""
# Set auth context if provided (for direct API calls)
from config.settings import is_no_auth_mode
if user_id and (jwt_token or is_no_auth_mode()):
from auth_context import set_auth_context
set_auth_context(user_id, jwt_token)
# Set filters and limit in context if provided
if filters:
from auth_context import set_search_filters
set_search_filters(filters)
from auth_context import set_search_limit, set_score_threshold
set_search_limit(limit)
set_score_threshold(score_threshold)
return await self.search_tool(query, embedding_model=embedding_model)