Add robust retry and mocking for embedding API calls

Introduces exponential backoff and Retry-After header handling for embedding API rate limits in both processors and search service. Adds CI fixture to mock OpenAI embeddings, avoiding real API calls during tests. Updates Makefile to document and set MOCK_EMBEDDINGS for integration and CI test targets.
This commit is contained in:
Edwin Jose 2025-11-26 15:54:38 -05:00
parent 5a6261d7e0
commit 13cc00043b
4 changed files with 200 additions and 17 deletions

View file

@ -190,6 +190,7 @@ test:
test-integration:
@echo "🧪 Running integration tests (requires infrastructure)..."
@echo "💡 Make sure to run 'make infra' first!"
@echo "💡 Set MOCK_EMBEDDINGS=true to avoid hitting real API"
uv run pytest tests/integration/ -v
# CI-friendly integration test target: brings up infra, waits, runs tests, tears down
@ -207,6 +208,7 @@ test-ci:
fi; \
echo "Cleaning up old containers and volumes..."; \
docker compose -f docker-compose-cpu.yml down -v 2>/dev/null || true; \
export MOCK_EMBEDDINGS=true; \
echo "Pulling latest images..."; \
docker compose -f docker-compose-cpu.yml pull; \
echo "Building OpenSearch image override..."; \
@ -247,8 +249,9 @@ test-ci:
for i in $$(seq 1 60); do \
curl -s $${DOCLING_ENDPOINT}/health >/dev/null 2>&1 && break || sleep 2; \
done; \
echo "Running integration tests"; \
echo "Running integration tests with mocked embeddings"; \
LOG_LEVEL=$${LOG_LEVEL:-DEBUG} \
MOCK_EMBEDDINGS=true \
GOOGLE_OAUTH_CLIENT_ID="" \
GOOGLE_OAUTH_CLIENT_SECRET="" \
OPENSEARCH_HOST=localhost OPENSEARCH_PORT=9200 \
@ -328,8 +331,9 @@ test-ci-local:
for i in $$(seq 1 60); do \
curl -s $${DOCLING_ENDPOINT}/health >/dev/null 2>&1 && break || sleep 2; \
done; \
echo "Running integration tests"; \
echo "Running integration tests with mocked embeddings"; \
LOG_LEVEL=$${LOG_LEVEL:-DEBUG} \
MOCK_EMBEDDINGS=true \
GOOGLE_OAUTH_CLIENT_ID="" \
GOOGLE_OAUTH_CLIENT_SECRET="" \
OPENSEARCH_HOST=localhost OPENSEARCH_PORT=9200 \

View file

@ -1,3 +1,4 @@
import asyncio
from typing import Any
from .tasks import UploadTask, FileTask
from utils.logging_config import get_logger
@ -208,11 +209,76 @@ class TaskProcessor:
text_batches = chunk_texts_for_embeddings(texts, max_tokens=8000)
embeddings = []
for batch in text_batches:
resp = await clients.patched_async_client.embeddings.create(
model=embedding_model, input=batch
)
embeddings.extend([d.embedding for d in resp.data])
# Embed batches with retry logic for rate limits
for batch_idx, batch in enumerate(text_batches):
max_retries = 3
retry_delay = 1.0
last_error = None
for attempt in range(max_retries):
try:
resp = await clients.patched_async_client.embeddings.create(
model=embedding_model, input=batch
)
embeddings.extend([d.embedding for d in resp.data])
break # Success, exit retry loop
except Exception as e:
last_error = e
error_str = str(e).lower()
is_rate_limit = "429" in error_str or "rate" in error_str or "too many requests" in error_str
if attempt < max_retries - 1:
# Extract Retry-After header if available
retry_after = None
if hasattr(e, 'response') and hasattr(e.response, 'headers'):
retry_after = e.response.headers.get('Retry-After')
if retry_after and is_rate_limit:
try:
wait_time = float(retry_after)
logger.warning(
"Rate limited during document processing - respecting Retry-After",
batch_idx=batch_idx,
attempt=attempt + 1,
retry_after=wait_time,
)
await asyncio.sleep(wait_time)
continue
except (ValueError, TypeError):
pass
# Use exponential backoff with jitter
wait_time = retry_delay * (2 ** attempt)
if is_rate_limit:
# Longer delays for rate limits
wait_time = min(wait_time * 2, 16.0)
else:
wait_time = min(wait_time, 8.0)
# Add jitter
import random
jitter = random.uniform(0, 0.3 * wait_time)
wait_time += jitter
logger.warning(
"Retrying embedding generation for batch",
batch_idx=batch_idx,
attempt=attempt + 1,
max_retries=max_retries,
wait_time=wait_time,
is_rate_limit=is_rate_limit,
error=str(e),
)
await asyncio.sleep(wait_time)
else:
# Final attempt failed
logger.error(
"Failed to embed batch after retries",
batch_idx=batch_idx,
attempts=max_retries,
error=str(e),
)
raise
# Index each chunk as a separate document
for i, (chunk, vect) in enumerate(zip(slim_doc["chunks"], embeddings)):

View file

@ -156,26 +156,72 @@ class SearchService:
return model_name, resp.data[0].embedding
except Exception as e:
last_exception = e
error_str = str(e).lower()
# Check if it's a rate limit error (429)
is_rate_limit = "429" in error_str or "rate" in error_str or "too many requests" in error_str
if attempts >= MAX_EMBED_RETRIES:
logger.error(
"Failed to embed with model after retries",
model=model_name,
attempts=attempts,
error=str(e),
is_rate_limit=is_rate_limit,
)
raise RuntimeError(
f"Failed to embed with model {model_name}"
) from e
logger.warning(
"Retrying embedding generation",
model=model_name,
attempt=attempts,
max_attempts=MAX_EMBED_RETRIES,
error=str(e),
)
await asyncio.sleep(delay)
delay = min(delay * 2, EMBED_RETRY_MAX_DELAY)
# For rate limit errors, use longer delays
if is_rate_limit:
# Extract Retry-After header if available
retry_after = None
if hasattr(e, 'response') and hasattr(e.response, 'headers'):
retry_after = e.response.headers.get('Retry-After')
if retry_after:
try:
wait_time = float(retry_after)
logger.warning(
"Rate limited - respecting Retry-After header",
model=model_name,
attempt=attempts,
max_attempts=MAX_EMBED_RETRIES,
retry_after=wait_time,
)
await asyncio.sleep(wait_time)
continue
except (ValueError, TypeError):
pass
# Use exponential backoff with jitter for rate limits
wait_time = min(delay * 2, EMBED_RETRY_MAX_DELAY)
# Add jitter to avoid thundering herd
import random
jitter = random.uniform(0, 0.5 * wait_time)
wait_time += jitter
logger.warning(
"Rate limited - backing off with exponential delay",
model=model_name,
attempt=attempts,
max_attempts=MAX_EMBED_RETRIES,
wait_time=wait_time,
)
await asyncio.sleep(wait_time)
delay = wait_time
else:
# Regular retry for other errors
logger.warning(
"Retrying embedding generation",
model=model_name,
attempt=attempts,
max_attempts=MAX_EMBED_RETRIES,
error=str(e),
)
await asyncio.sleep(delay)
delay = min(delay * 2, EMBED_RETRY_MAX_DELAY)
# Should not reach here, but guard in case
raise RuntimeError(

View file

@ -2,6 +2,7 @@ import asyncio
import os
import tempfile
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
@ -20,8 +21,68 @@ from src.session_manager import SessionManager
from src.main import generate_jwt_keys
# Mock embeddings for CI environment to avoid rate limits
@pytest.fixture(scope="session", autouse=True)
def mock_openai_embeddings():
    """Mock OpenAI embeddings API calls in CI environment to avoid rate limits.

    Yields:
        The mocked async OpenAI client when mocking is active (running under
        CI, or MOCK_EMBEDDINGS is truthy), otherwise ``None`` so tests hit
        the real API.
    """
    # Only mock in CI, or when explicitly requested via MOCK_EMBEDDINGS.
    if os.getenv("CI") or os.getenv("MOCK_EMBEDDINGS", "false").lower() in ("true", "1", "yes"):
        print("[DEBUG] Mocking OpenAI embeddings for CI environment")
        # Hoisted out of the per-text loop — importing inside the loop paid
        # the sys.modules lookup on every text.
        import hashlib

        def create_mock_embedding(texts, model="text-embedding-3-small", **kwargs):
            """Create mock embeddings with proper dimensions based on model."""
            # Resolve vector dimensions from the project's model tables,
            # falling back to 1536 for unknown models.
            from src.config.settings import OPENAI_EMBEDDING_DIMENSIONS, WATSONX_EMBEDDING_DIMENSIONS
            dimensions = OPENAI_EMBEDDING_DIMENSIONS.get(
                model,
                WATSONX_EMBEDDING_DIMENSIONS.get(model, 1536),
            )
            # The embeddings API accepts a single string or a list of strings.
            if isinstance(texts, str):
                texts = [texts]
            mock_data = []
            for idx, text in enumerate(texts):
                # Deterministic vector: a linear ramp offset by the text's
                # MD5 hash, so the same text always yields the same embedding
                # (nothing random is involved).
                text_hash = int(hashlib.md5(text.encode()).hexdigest(), 16)
                embedding = [(text_hash % 1000) / 1000.0 + i / dimensions for i in range(dimensions)]
                mock_data.append(MagicMock(embedding=embedding, index=idx))
            mock_response = MagicMock()
            mock_response.data = mock_data
            return mock_response

        async def async_create_mock_embedding(model, input, **kwargs):
            """Async wrapper matching the OpenAI embeddings.create signature."""
            return create_mock_embedding(input, model, **kwargs)

        # Patch the OpenAI client's embeddings.create method.
        with patch('openai.AsyncOpenAI') as mock_async_openai:
            # Create a mock client instance whose embeddings.create awaits
            # our deterministic generator.
            mock_client_instance = MagicMock()
            mock_embeddings = MagicMock()
            mock_embeddings.create = AsyncMock(side_effect=async_create_mock_embedding)
            mock_client_instance.embeddings = mock_embeddings
            mock_client_instance.close = AsyncMock()
            # Make AsyncOpenAI() return our mock instance.
            mock_async_openai.return_value = mock_client_instance
            # Also patch the agentd patch function to return the mock, so
            # code paths using the patched client get the same instance.
            with patch('agentd.patch.patch_openai_with_mcp', return_value=mock_client_instance):
                yield mock_client_instance
    else:
        # In non-CI environments, don't mock - use real API
        yield None
@pytest_asyncio.fixture(scope="session", autouse=True)
async def onboard_system():
async def onboard_system(mock_openai_embeddings):
"""Perform initial onboarding once for all tests in the session.
This ensures the OpenRAG config is marked as edited and properly initialized
@ -44,6 +105,12 @@ async def onboard_system():
except Exception as e:
print(f"[DEBUG] Could not clean OpenSearch data directory: {e}")
# If we're using mocks, patch the clients to use mock embeddings
if mock_openai_embeddings is not None:
print("[DEBUG] Using mock OpenAI embeddings client")
# Replace the client's patched_async_client with our mock
clients._patched_async_client = mock_openai_embeddings
# Initialize clients
await clients.initialize()