Compare commits
1 commit: main...fix-integr

| Author | SHA1 | Date |
|---|---|---|
|  | 13cc00043b |  |

4 changed files with 200 additions and 17 deletions
Makefile (8 changes)
```diff
@@ -190,6 +190,7 @@ test:
 test-integration:
 	@echo "🧪 Running integration tests (requires infrastructure)..."
 	@echo "💡 Make sure to run 'make infra' first!"
+	@echo "💡 Set MOCK_EMBEDDINGS=true to avoid hitting real API"
 	uv run pytest tests/integration/ -v
 
 # CI-friendly integration test target: brings up infra, waits, runs tests, tears down
@@ -207,6 +208,7 @@ test-ci:
 	fi; \
 	echo "Cleaning up old containers and volumes..."; \
 	docker compose -f docker-compose-cpu.yml down -v 2>/dev/null || true; \
+	export MOCK_EMBEDDINGS=true; \
 	echo "Pulling latest images..."; \
 	docker compose -f docker-compose-cpu.yml pull; \
 	echo "Building OpenSearch image override..."; \
@@ -247,8 +249,9 @@ test-ci:
 	for i in $$(seq 1 60); do \
 	  curl -s $${DOCLING_ENDPOINT}/health >/dev/null 2>&1 && break || sleep 2; \
 	done; \
-	echo "Running integration tests"; \
+	echo "Running integration tests with mocked embeddings"; \
 	LOG_LEVEL=$${LOG_LEVEL:-DEBUG} \
+	MOCK_EMBEDDINGS=true \
 	GOOGLE_OAUTH_CLIENT_ID="" \
 	GOOGLE_OAUTH_CLIENT_SECRET="" \
 	OPENSEARCH_HOST=localhost OPENSEARCH_PORT=9200 \
@@ -328,8 +331,9 @@ test-ci-local:
 	for i in $$(seq 1 60); do \
 	  curl -s $${DOCLING_ENDPOINT}/health >/dev/null 2>&1 && break || sleep 2; \
 	done; \
-	echo "Running integration tests"; \
+	echo "Running integration tests with mocked embeddings"; \
 	LOG_LEVEL=$${LOG_LEVEL:-DEBUG} \
+	MOCK_EMBEDDINGS=true \
 	GOOGLE_OAUTH_CLIENT_ID="" \
 	GOOGLE_OAUTH_CLIENT_SECRET="" \
 	OPENSEARCH_HOST=localhost OPENSEARCH_PORT=9200 \
```
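The new `MOCK_EMBEDDINGS` flag is consumed on the Python side by the test fixtures further down, which treat `"true"`, `"1"`, or `"yes"` as enabled and mock unconditionally under CI. A minimal sketch of that check, assuming only the `MOCK_EMBEDDINGS` and `CI` variables that appear in this diff; the helper name is hypothetical:

```python
import os


def embeddings_mocked() -> bool:
    """Return True when embedding calls should be mocked (hypothetical helper).

    Mirrors the check used in the test fixtures below: any of "true"/"1"/"yes"
    enables mocking, and a CI environment enables it unconditionally.
    """
    if os.getenv("CI"):
        return True
    return os.getenv("MOCK_EMBEDDINGS", "false").lower() in ("true", "1", "yes")
```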
Task processor module

```diff
@@ -1,3 +1,4 @@
+import asyncio
 from typing import Any
 from .tasks import UploadTask, FileTask
 from utils.logging_config import get_logger
@@ -208,11 +209,76 @@ class TaskProcessor:
         text_batches = chunk_texts_for_embeddings(texts, max_tokens=8000)
         embeddings = []
 
-        for batch in text_batches:
-            resp = await clients.patched_async_client.embeddings.create(
-                model=embedding_model, input=batch
-            )
-            embeddings.extend([d.embedding for d in resp.data])
+        # Embed batches with retry logic for rate limits
+        for batch_idx, batch in enumerate(text_batches):
+            max_retries = 3
+            retry_delay = 1.0
+            last_error = None
+
+            for attempt in range(max_retries):
+                try:
+                    resp = await clients.patched_async_client.embeddings.create(
+                        model=embedding_model, input=batch
+                    )
+                    embeddings.extend([d.embedding for d in resp.data])
+                    break  # Success, exit retry loop
+                except Exception as e:
+                    last_error = e
+                    error_str = str(e).lower()
+                    is_rate_limit = "429" in error_str or "rate" in error_str or "too many requests" in error_str
+
+                    if attempt < max_retries - 1:
+                        # Extract Retry-After header if available
+                        retry_after = None
+                        if hasattr(e, 'response') and hasattr(e.response, 'headers'):
+                            retry_after = e.response.headers.get('Retry-After')
+
+                        if retry_after and is_rate_limit:
+                            try:
+                                wait_time = float(retry_after)
+                                logger.warning(
+                                    "Rate limited during document processing - respecting Retry-After",
+                                    batch_idx=batch_idx,
+                                    attempt=attempt + 1,
+                                    retry_after=wait_time,
+                                )
+                                await asyncio.sleep(wait_time)
+                                continue
+                            except (ValueError, TypeError):
+                                pass
+
+                        # Use exponential backoff with jitter
+                        wait_time = retry_delay * (2 ** attempt)
+                        if is_rate_limit:
+                            # Longer delays for rate limits
+                            wait_time = min(wait_time * 2, 16.0)
+                        else:
+                            wait_time = min(wait_time, 8.0)
+
+                        # Add jitter
+                        import random
+                        jitter = random.uniform(0, 0.3 * wait_time)
+                        wait_time += jitter
+
+                        logger.warning(
+                            "Retrying embedding generation for batch",
+                            batch_idx=batch_idx,
+                            attempt=attempt + 1,
+                            max_retries=max_retries,
+                            wait_time=wait_time,
+                            is_rate_limit=is_rate_limit,
+                            error=str(e),
+                        )
+                        await asyncio.sleep(wait_time)
+                    else:
+                        # Final attempt failed
+                        logger.error(
+                            "Failed to embed batch after retries",
+                            batch_idx=batch_idx,
+                            attempts=max_retries,
+                            error=str(e),
+                        )
+                        raise
 
         # Index each chunk as a separate document
         for i, (chunk, vect) in enumerate(zip(slim_doc["chunks"], embeddings)):
```
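The retry policy above is inlined in the processing loop. A sketch of the same policy factored into a reusable helper, under this diff's assumptions (three attempts, doubled and capped backoff for rate limits, up to 30% jitter); the function name and signature are hypothetical:

```python
import asyncio
import random


async def call_with_backoff(coro_factory, max_retries=3, base_delay=1.0):
    """Run an awaitable factory with capped exponential backoff (hypothetical helper).

    Rate-limit errors (429) back off twice as far, up to 16s; other errors
    cap at 8s. Jitter spreads out retries from concurrent workers.
    """
    for attempt in range(max_retries):
        try:
            return await coro_factory()
        except Exception as e:
            if attempt == max_retries - 1:
                raise  # final attempt failed, propagate
            msg = str(e).lower()
            is_rate_limit = "429" in msg or "rate" in msg or "too many requests" in msg
            wait = base_delay * (2 ** attempt)
            wait = min(wait * 2, 16.0) if is_rate_limit else min(wait, 8.0)
            wait += random.uniform(0, 0.3 * wait)  # jitter, as in the diff
            await asyncio.sleep(wait)
```

A caller would wrap the embeddings request, e.g. `await call_with_backoff(lambda: client.embeddings.create(model=model, input=batch))`.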
Search service module

```diff
@@ -156,26 +156,72 @@ class SearchService:
                 return model_name, resp.data[0].embedding
             except Exception as e:
                 last_exception = e
+                error_str = str(e).lower()
+
+                # Check if it's a rate limit error (429)
+                is_rate_limit = "429" in error_str or "rate" in error_str or "too many requests" in error_str
+
                 if attempts >= MAX_EMBED_RETRIES:
                     logger.error(
                         "Failed to embed with model after retries",
                         model=model_name,
                         attempts=attempts,
                         error=str(e),
+                        is_rate_limit=is_rate_limit,
                     )
                     raise RuntimeError(
                         f"Failed to embed with model {model_name}"
                     ) from e
 
-                logger.warning(
-                    "Retrying embedding generation",
-                    model=model_name,
-                    attempt=attempts,
-                    max_attempts=MAX_EMBED_RETRIES,
-                    error=str(e),
-                )
-                await asyncio.sleep(delay)
-                delay = min(delay * 2, EMBED_RETRY_MAX_DELAY)
+                # For rate limit errors, use longer delays
+                if is_rate_limit:
+                    # Extract Retry-After header if available
+                    retry_after = None
+                    if hasattr(e, 'response') and hasattr(e.response, 'headers'):
+                        retry_after = e.response.headers.get('Retry-After')
+
+                    if retry_after:
+                        try:
+                            wait_time = float(retry_after)
+                            logger.warning(
+                                "Rate limited - respecting Retry-After header",
+                                model=model_name,
+                                attempt=attempts,
+                                max_attempts=MAX_EMBED_RETRIES,
+                                retry_after=wait_time,
+                            )
+                            await asyncio.sleep(wait_time)
+                            continue
+                        except (ValueError, TypeError):
+                            pass
+
+                    # Use exponential backoff with jitter for rate limits
+                    wait_time = min(delay * 2, EMBED_RETRY_MAX_DELAY)
+                    # Add jitter to avoid thundering herd
+                    import random
+                    jitter = random.uniform(0, 0.5 * wait_time)
+                    wait_time += jitter
+
+                    logger.warning(
+                        "Rate limited - backing off with exponential delay",
+                        model=model_name,
+                        attempt=attempts,
+                        max_attempts=MAX_EMBED_RETRIES,
+                        wait_time=wait_time,
+                    )
+                    await asyncio.sleep(wait_time)
+                    delay = wait_time
+                else:
+                    # Regular retry for other errors
+                    logger.warning(
+                        "Retrying embedding generation",
+                        model=model_name,
+                        attempt=attempts,
+                        max_attempts=MAX_EMBED_RETRIES,
+                        error=str(e),
+                    )
+                    await asyncio.sleep(delay)
+                    delay = min(delay * 2, EMBED_RETRY_MAX_DELAY)
 
         # Should not reach here, but guard in case
         raise RuntimeError(
```
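One caveat worth noting: `float(retry_after)` only accepts the delta-seconds form of `Retry-After`, while RFC 7231 also allows an HTTP-date. A sketch of a parser covering both forms, falling back to `None` when the value is unparseable; the helper name is hypothetical:

```python
import email.utils
from datetime import datetime, timezone


def parse_retry_after(value: str) -> float | None:
    """Parse a Retry-After header value into seconds (hypothetical helper).

    Accepts both forms allowed by RFC 7231: delta-seconds ("120")
    and an HTTP-date ("Wed, 21 Oct 2025 07:28:00 GMT").
    """
    try:
        return float(value)  # delta-seconds form
    except (TypeError, ValueError):
        pass
    try:
        when = email.utils.parsedate_to_datetime(value)  # HTTP-date form
        return max(0.0, (when - datetime.now(timezone.utc)).total_seconds())
    except (TypeError, ValueError):
        return None
```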
Integration test fixtures

```diff
@@ -2,6 +2,7 @@ import asyncio
 import os
 import tempfile
 from pathlib import Path
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 import pytest_asyncio
@@ -20,8 +21,68 @@ from src.session_manager import SessionManager
 from src.main import generate_jwt_keys
 
 
+# Mock embeddings for CI environment to avoid rate limits
+@pytest.fixture(scope="session", autouse=True)
+def mock_openai_embeddings():
+    """Mock OpenAI embeddings API calls in CI environment to avoid rate limits."""
+    # Only mock in CI environment
+    if os.getenv("CI") or os.getenv("MOCK_EMBEDDINGS", "false").lower() in ("true", "1", "yes"):
+        print("[DEBUG] Mocking OpenAI embeddings for CI environment")
+
+        def create_mock_embedding(texts, model="text-embedding-3-small", **kwargs):
+            """Create mock embeddings with proper dimensions based on model."""
+            # Get dimensions based on model
+            from src.config.settings import OPENAI_EMBEDDING_DIMENSIONS, WATSONX_EMBEDDING_DIMENSIONS
+
+            dimensions = OPENAI_EMBEDDING_DIMENSIONS.get(
+                model,
+                WATSONX_EMBEDDING_DIMENSIONS.get(model, 1536)
+            )
+
+            # Handle both single string and list of strings
+            if isinstance(texts, str):
+                texts = [texts]
+
+            # Create mock response
+            mock_data = []
+            for idx, text in enumerate(texts):
+                # Create deterministic embeddings based on text hash for consistency
+                import hashlib
+                text_hash = int(hashlib.md5(text.encode()).hexdigest(), 16)
+                # Use hash to seed pseudo-random values
+                embedding = [(text_hash % 1000) / 1000.0 + i / dimensions for i in range(dimensions)]
+                mock_data.append(MagicMock(embedding=embedding, index=idx))
+
+            mock_response = MagicMock()
+            mock_response.data = mock_data
+            return mock_response
+
+        async def async_create_mock_embedding(model, input, **kwargs):
+            """Async version of mock embedding creation."""
+            return create_mock_embedding(input, model, **kwargs)
+
+        # Patch the OpenAI client's embeddings.create method
+        with patch('openai.AsyncOpenAI') as mock_async_openai:
+            # Create a mock client instance
+            mock_client_instance = MagicMock()
+            mock_embeddings = MagicMock()
+            mock_embeddings.create = AsyncMock(side_effect=async_create_mock_embedding)
+            mock_client_instance.embeddings = mock_embeddings
+            mock_client_instance.close = AsyncMock()
+
+            # Make AsyncOpenAI() return our mock instance
+            mock_async_openai.return_value = mock_client_instance
+
+            # Also patch the agentd patch function to return the mock
+            with patch('agentd.patch.patch_openai_with_mcp', return_value=mock_client_instance):
+                yield mock_client_instance
+    else:
+        # In non-CI environments, don't mock - use real API
+        yield None
+
+
 @pytest_asyncio.fixture(scope="session", autouse=True)
-async def onboard_system():
+async def onboard_system(mock_openai_embeddings):
     """Perform initial onboarding once for all tests in the session.
 
     This ensures the OpenRAG config is marked as edited and properly initialized
@@ -44,6 +105,12 @@ async def onboard_system():
     except Exception as e:
         print(f"[DEBUG] Could not clean OpenSearch data directory: {e}")
 
+    # If we're using mocks, patch the clients to use mock embeddings
+    if mock_openai_embeddings is not None:
+        print("[DEBUG] Using mock OpenAI embeddings client")
+        # Replace the client's patched_async_client with our mock
+        clients._patched_async_client = mock_openai_embeddings
+
     # Initialize clients
     await clients.initialize()
```
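The mocked embeddings are deterministic: a text's MD5 hash fixes a constant offset that is added to a linear ramp, so the same input always yields the same vector across runs. A small sketch reproducing that construction outside the fixture (the 1536 default dimension is an assumption here):

```python
import hashlib


def mock_embedding(text: str, dimensions: int = 1536) -> list[float]:
    """Rebuild the fixture's deterministic embedding for a given text."""
    text_hash = int(hashlib.md5(text.encode()).hexdigest(), 16)
    # Constant per-text offset plus a linear ramp, exactly as in the fixture
    return [(text_hash % 1000) / 1000.0 + i / dimensions for i in range(dimensions)]


# Same input always yields the same vector, so repeated test runs
# index identical documents.
assert mock_embedding("hello") == mock_embedding("hello")
```

Since only the offset, quantized to steps of 0.001, varies between texts, distinct texts can collide and every vector shares the same ramp; that is fine for exercising the indexing path, but not for asserting on similarity rankings.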