Add robust retry and mocking for embedding API calls
Introduces exponential backoff and Retry-After header handling for embedding API rate limits in both processors and search service. Adds CI fixture to mock OpenAI embeddings, avoiding real API calls during tests. Updates Makefile to document and set MOCK_EMBEDDINGS for integration and CI test targets.
This commit is contained in:
parent
5a6261d7e0
commit
13cc00043b
4 changed files with 200 additions and 17 deletions
8
Makefile
8
Makefile
|
|
@ -190,6 +190,7 @@ test:
|
|||
test-integration:
|
||||
@echo "🧪 Running integration tests (requires infrastructure)..."
|
||||
@echo "💡 Make sure to run 'make infra' first!"
|
||||
@echo "💡 Set MOCK_EMBEDDINGS=true to avoid hitting real API"
|
||||
uv run pytest tests/integration/ -v
|
||||
|
||||
# CI-friendly integration test target: brings up infra, waits, runs tests, tears down
|
||||
|
|
@ -207,6 +208,7 @@ test-ci:
|
|||
fi; \
|
||||
echo "Cleaning up old containers and volumes..."; \
|
||||
docker compose -f docker-compose-cpu.yml down -v 2>/dev/null || true; \
|
||||
export MOCK_EMBEDDINGS=true; \
|
||||
echo "Pulling latest images..."; \
|
||||
docker compose -f docker-compose-cpu.yml pull; \
|
||||
echo "Building OpenSearch image override..."; \
|
||||
|
|
@ -247,8 +249,9 @@ test-ci:
|
|||
for i in $$(seq 1 60); do \
|
||||
curl -s $${DOCLING_ENDPOINT}/health >/dev/null 2>&1 && break || sleep 2; \
|
||||
done; \
|
||||
echo "Running integration tests"; \
|
||||
echo "Running integration tests with mocked embeddings"; \
|
||||
LOG_LEVEL=$${LOG_LEVEL:-DEBUG} \
|
||||
MOCK_EMBEDDINGS=true \
|
||||
GOOGLE_OAUTH_CLIENT_ID="" \
|
||||
GOOGLE_OAUTH_CLIENT_SECRET="" \
|
||||
OPENSEARCH_HOST=localhost OPENSEARCH_PORT=9200 \
|
||||
|
|
@ -328,8 +331,9 @@ test-ci-local:
|
|||
for i in $$(seq 1 60); do \
|
||||
curl -s $${DOCLING_ENDPOINT}/health >/dev/null 2>&1 && break || sleep 2; \
|
||||
done; \
|
||||
echo "Running integration tests"; \
|
||||
echo "Running integration tests with mocked embeddings"; \
|
||||
LOG_LEVEL=$${LOG_LEVEL:-DEBUG} \
|
||||
MOCK_EMBEDDINGS=true \
|
||||
GOOGLE_OAUTH_CLIENT_ID="" \
|
||||
GOOGLE_OAUTH_CLIENT_SECRET="" \
|
||||
OPENSEARCH_HOST=localhost OPENSEARCH_PORT=9200 \
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import asyncio
|
||||
from typing import Any
|
||||
from .tasks import UploadTask, FileTask
|
||||
from utils.logging_config import get_logger
|
||||
|
|
@ -208,11 +209,76 @@ class TaskProcessor:
|
|||
text_batches = chunk_texts_for_embeddings(texts, max_tokens=8000)
|
||||
embeddings = []
|
||||
|
||||
for batch in text_batches:
|
||||
resp = await clients.patched_async_client.embeddings.create(
|
||||
model=embedding_model, input=batch
|
||||
)
|
||||
embeddings.extend([d.embedding for d in resp.data])
|
||||
# Embed batches with retry logic for rate limits
|
||||
for batch_idx, batch in enumerate(text_batches):
|
||||
max_retries = 3
|
||||
retry_delay = 1.0
|
||||
last_error = None
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
resp = await clients.patched_async_client.embeddings.create(
|
||||
model=embedding_model, input=batch
|
||||
)
|
||||
embeddings.extend([d.embedding for d in resp.data])
|
||||
break # Success, exit retry loop
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
error_str = str(e).lower()
|
||||
is_rate_limit = "429" in error_str or "rate" in error_str or "too many requests" in error_str
|
||||
|
||||
if attempt < max_retries - 1:
|
||||
# Extract Retry-After header if available
|
||||
retry_after = None
|
||||
if hasattr(e, 'response') and hasattr(e.response, 'headers'):
|
||||
retry_after = e.response.headers.get('Retry-After')
|
||||
|
||||
if retry_after and is_rate_limit:
|
||||
try:
|
||||
wait_time = float(retry_after)
|
||||
logger.warning(
|
||||
"Rate limited during document processing - respecting Retry-After",
|
||||
batch_idx=batch_idx,
|
||||
attempt=attempt + 1,
|
||||
retry_after=wait_time,
|
||||
)
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# Use exponential backoff with jitter
|
||||
wait_time = retry_delay * (2 ** attempt)
|
||||
if is_rate_limit:
|
||||
# Longer delays for rate limits
|
||||
wait_time = min(wait_time * 2, 16.0)
|
||||
else:
|
||||
wait_time = min(wait_time, 8.0)
|
||||
|
||||
# Add jitter
|
||||
import random
|
||||
jitter = random.uniform(0, 0.3 * wait_time)
|
||||
wait_time += jitter
|
||||
|
||||
logger.warning(
|
||||
"Retrying embedding generation for batch",
|
||||
batch_idx=batch_idx,
|
||||
attempt=attempt + 1,
|
||||
max_retries=max_retries,
|
||||
wait_time=wait_time,
|
||||
is_rate_limit=is_rate_limit,
|
||||
error=str(e),
|
||||
)
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
# Final attempt failed
|
||||
logger.error(
|
||||
"Failed to embed batch after retries",
|
||||
batch_idx=batch_idx,
|
||||
attempts=max_retries,
|
||||
error=str(e),
|
||||
)
|
||||
raise
|
||||
|
||||
# Index each chunk as a separate document
|
||||
for i, (chunk, vect) in enumerate(zip(slim_doc["chunks"], embeddings)):
|
||||
|
|
|
|||
|
|
@ -156,26 +156,72 @@ class SearchService:
|
|||
return model_name, resp.data[0].embedding
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
error_str = str(e).lower()
|
||||
|
||||
# Check if it's a rate limit error (429)
|
||||
is_rate_limit = "429" in error_str or "rate" in error_str or "too many requests" in error_str
|
||||
|
||||
if attempts >= MAX_EMBED_RETRIES:
|
||||
logger.error(
|
||||
"Failed to embed with model after retries",
|
||||
model=model_name,
|
||||
attempts=attempts,
|
||||
error=str(e),
|
||||
is_rate_limit=is_rate_limit,
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Failed to embed with model {model_name}"
|
||||
) from e
|
||||
|
||||
logger.warning(
|
||||
"Retrying embedding generation",
|
||||
model=model_name,
|
||||
attempt=attempts,
|
||||
max_attempts=MAX_EMBED_RETRIES,
|
||||
error=str(e),
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
delay = min(delay * 2, EMBED_RETRY_MAX_DELAY)
|
||||
# For rate limit errors, use longer delays
|
||||
if is_rate_limit:
|
||||
# Extract Retry-After header if available
|
||||
retry_after = None
|
||||
if hasattr(e, 'response') and hasattr(e.response, 'headers'):
|
||||
retry_after = e.response.headers.get('Retry-After')
|
||||
|
||||
if retry_after:
|
||||
try:
|
||||
wait_time = float(retry_after)
|
||||
logger.warning(
|
||||
"Rate limited - respecting Retry-After header",
|
||||
model=model_name,
|
||||
attempt=attempts,
|
||||
max_attempts=MAX_EMBED_RETRIES,
|
||||
retry_after=wait_time,
|
||||
)
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# Use exponential backoff with jitter for rate limits
|
||||
wait_time = min(delay * 2, EMBED_RETRY_MAX_DELAY)
|
||||
# Add jitter to avoid thundering herd
|
||||
import random
|
||||
jitter = random.uniform(0, 0.5 * wait_time)
|
||||
wait_time += jitter
|
||||
|
||||
logger.warning(
|
||||
"Rate limited - backing off with exponential delay",
|
||||
model=model_name,
|
||||
attempt=attempts,
|
||||
max_attempts=MAX_EMBED_RETRIES,
|
||||
wait_time=wait_time,
|
||||
)
|
||||
await asyncio.sleep(wait_time)
|
||||
delay = wait_time
|
||||
else:
|
||||
# Regular retry for other errors
|
||||
logger.warning(
|
||||
"Retrying embedding generation",
|
||||
model=model_name,
|
||||
attempt=attempts,
|
||||
max_attempts=MAX_EMBED_RETRIES,
|
||||
error=str(e),
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
delay = min(delay * 2, EMBED_RETRY_MAX_DELAY)
|
||||
|
||||
# Should not reach here, but guard in case
|
||||
raise RuntimeError(
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import asyncio
|
|||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
|
@ -20,8 +21,68 @@ from src.session_manager import SessionManager
|
|||
from src.main import generate_jwt_keys
|
||||
|
||||
|
||||
# Mock embeddings for CI environment to avoid rate limits
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def mock_openai_embeddings():
|
||||
"""Mock OpenAI embeddings API calls in CI environment to avoid rate limits."""
|
||||
# Only mock in CI environment
|
||||
if os.getenv("CI") or os.getenv("MOCK_EMBEDDINGS", "false").lower() in ("true", "1", "yes"):
|
||||
print("[DEBUG] Mocking OpenAI embeddings for CI environment")
|
||||
|
||||
def create_mock_embedding(texts, model="text-embedding-3-small", **kwargs):
|
||||
"""Create mock embeddings with proper dimensions based on model."""
|
||||
# Get dimensions based on model
|
||||
from src.config.settings import OPENAI_EMBEDDING_DIMENSIONS, WATSONX_EMBEDDING_DIMENSIONS
|
||||
|
||||
dimensions = OPENAI_EMBEDDING_DIMENSIONS.get(
|
||||
model,
|
||||
WATSONX_EMBEDDING_DIMENSIONS.get(model, 1536)
|
||||
)
|
||||
|
||||
# Handle both single string and list of strings
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
|
||||
# Create mock response
|
||||
mock_data = []
|
||||
for idx, text in enumerate(texts):
|
||||
# Create deterministic embeddings based on text hash for consistency
|
||||
import hashlib
|
||||
text_hash = int(hashlib.md5(text.encode()).hexdigest(), 16)
|
||||
# Use hash to seed pseudo-random values
|
||||
embedding = [(text_hash % 1000) / 1000.0 + i / dimensions for i in range(dimensions)]
|
||||
mock_data.append(MagicMock(embedding=embedding, index=idx))
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = mock_data
|
||||
return mock_response
|
||||
|
||||
async def async_create_mock_embedding(model, input, **kwargs):
|
||||
"""Async version of mock embedding creation."""
|
||||
return create_mock_embedding(input, model, **kwargs)
|
||||
|
||||
# Patch the OpenAI client's embeddings.create method
|
||||
with patch('openai.AsyncOpenAI') as mock_async_openai:
|
||||
# Create a mock client instance
|
||||
mock_client_instance = MagicMock()
|
||||
mock_embeddings = MagicMock()
|
||||
mock_embeddings.create = AsyncMock(side_effect=async_create_mock_embedding)
|
||||
mock_client_instance.embeddings = mock_embeddings
|
||||
mock_client_instance.close = AsyncMock()
|
||||
|
||||
# Make AsyncOpenAI() return our mock instance
|
||||
mock_async_openai.return_value = mock_client_instance
|
||||
|
||||
# Also patch the agentd patch function to return the mock
|
||||
with patch('agentd.patch.patch_openai_with_mcp', return_value=mock_client_instance):
|
||||
yield mock_client_instance
|
||||
else:
|
||||
# In non-CI environments, don't mock - use real API
|
||||
yield None
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session", autouse=True)
|
||||
async def onboard_system():
|
||||
async def onboard_system(mock_openai_embeddings):
|
||||
"""Perform initial onboarding once for all tests in the session.
|
||||
|
||||
This ensures the OpenRAG config is marked as edited and properly initialized
|
||||
|
|
@ -44,6 +105,12 @@ async def onboard_system():
|
|||
except Exception as e:
|
||||
print(f"[DEBUG] Could not clean OpenSearch data directory: {e}")
|
||||
|
||||
# If we're using mocks, patch the clients to use mock embeddings
|
||||
if mock_openai_embeddings is not None:
|
||||
print("[DEBUG] Using mock OpenAI embeddings client")
|
||||
# Replace the client's patched_async_client with our mock
|
||||
clients._patched_async_client = mock_openai_embeddings
|
||||
|
||||
# Initialize clients
|
||||
await clients.initialize()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue