Improve embedding generation timeout handling with retries and error handling

This commit is contained in:
phact 2025-10-11 01:06:14 -04:00
parent 0c696afef8
commit a424bb422a
7 changed files with 91 additions and 31 deletions

View file

@ -660,24 +660,47 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
msg = "Embedding handle is required to embed documents."
raise ValueError(msg)
# Generate embeddings (threaded for concurrency)
# Generate embeddings (threaded for concurrency) with retries
def embed_chunk(chunk_text: str) -> list[float]:
return self.embedding.embed_documents([chunk_text])[0]
try:
max_workers = min(max(len(texts), 1), 8)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = {executor.submit(embed_chunk, chunk): idx for idx, chunk in enumerate(texts)}
vectors = [None] * len(texts)
for future in as_completed(futures):
idx = futures[future]
vectors[idx] = future.result()
except Exception as exc:
logger.warning(
"Threaded embedding generation failed, falling back to synchronous mode: %s",
exc,
vectors: list[list[float]] | None = None
last_exception: Exception | None = None
delay = 1.0
attempts = 0
while attempts < 3:
attempts += 1
try:
max_workers = min(max(len(texts), 1), 8)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = {executor.submit(embed_chunk, chunk): idx for idx, chunk in enumerate(texts)}
vectors = [None] * len(texts)
for future in as_completed(futures):
idx = futures[future]
vectors[idx] = future.result()
break
except Exception as exc:
last_exception = exc
if attempts >= 3:
logger.error(
"Embedding generation failed after retries",
error=str(exc),
)
raise
logger.warning(
"Threaded embedding generation failed (attempt %s/%s), retrying in %.1fs",
attempts,
3,
delay,
)
time.sleep(delay)
delay = min(delay * 2, 8.0)
if vectors is None:
raise RuntimeError(
f"Embedding generation failed: {last_exception}" if last_exception else "Embedding generation failed"
)
vectors = self.embedding.embed_documents(texts)
if not vectors:
self.log("No vectors generated from documents.")

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -7,6 +7,10 @@ from utils.logging_config import get_logger
logger = get_logger(__name__)
MAX_EMBED_RETRIES = 3
EMBED_RETRY_INITIAL_DELAY = 1.0
EMBED_RETRY_MAX_DELAY = 8.0
class SearchService:
def __init__(self, session_manager=None):
@ -137,20 +141,53 @@ class SearchService:
import asyncio
async def embed_with_model(model_name):
try:
resp = await clients.patched_async_client.embeddings.create(
model=model_name, input=[query]
)
return model_name, resp.data[0].embedding
except Exception as e:
logger.error(f"Failed to embed with model {model_name}", error=str(e))
return model_name, None
delay = EMBED_RETRY_INITIAL_DELAY
attempts = 0
last_exception = None
while attempts < MAX_EMBED_RETRIES:
attempts += 1
try:
resp = await clients.patched_async_client.embeddings.create(
model=model_name, input=[query]
)
return model_name, resp.data[0].embedding
except Exception as e:
last_exception = e
if attempts >= MAX_EMBED_RETRIES:
logger.error(
"Failed to embed with model after retries",
model=model_name,
attempts=attempts,
error=str(e),
)
raise RuntimeError(
f"Failed to embed with model {model_name}"
) from e
logger.warning(
"Retrying embedding generation",
model=model_name,
attempt=attempts,
max_attempts=MAX_EMBED_RETRIES,
error=str(e),
)
await asyncio.sleep(delay)
delay = min(delay * 2, EMBED_RETRY_MAX_DELAY)
# Should not reach here, but guard in case
raise RuntimeError(
f"Failed to embed with model {model_name}"
) from last_exception
# Run all embeddings in parallel
embedding_results = await asyncio.gather(
*[embed_with_model(model) for model in available_models],
return_exceptions=True
)
try:
embedding_results = await asyncio.gather(
*[embed_with_model(model) for model in available_models]
)
except Exception as e:
logger.error("Embedding generation failed", error=str(e))
raise
# Collect successful embeddings
for result in embedding_results: