improve embedding generation timeout handling w/ retry and error handling

This commit is contained in:
phact 2025-10-11 01:06:14 -04:00
parent 0c696afef8
commit a424bb422a
7 changed files with 91 additions and 31 deletions

View file

@ -660,24 +660,47 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
msg = "Embedding handle is required to embed documents." msg = "Embedding handle is required to embed documents."
raise ValueError(msg) raise ValueError(msg)
# Generate embeddings (threaded for concurrency) # Generate embeddings (threaded for concurrency) with retries
def embed_chunk(chunk_text: str) -> list[float]: def embed_chunk(chunk_text: str) -> list[float]:
return self.embedding.embed_documents([chunk_text])[0] return self.embedding.embed_documents([chunk_text])[0]
try: vectors: list[list[float]] | None = None
max_workers = min(max(len(texts), 1), 8) last_exception: Exception | None = None
with ThreadPoolExecutor(max_workers=max_workers) as executor: delay = 1.0
futures = {executor.submit(embed_chunk, chunk): idx for idx, chunk in enumerate(texts)} attempts = 0
vectors = [None] * len(texts)
for future in as_completed(futures): while attempts < 3:
idx = futures[future] attempts += 1
vectors[idx] = future.result() try:
except Exception as exc: max_workers = min(max(len(texts), 1), 8)
logger.warning( with ThreadPoolExecutor(max_workers=max_workers) as executor:
"Threaded embedding generation failed, falling back to synchronous mode: %s", futures = {executor.submit(embed_chunk, chunk): idx for idx, chunk in enumerate(texts)}
exc, vectors = [None] * len(texts)
for future in as_completed(futures):
idx = futures[future]
vectors[idx] = future.result()
break
except Exception as exc:
last_exception = exc
if attempts >= 3:
logger.error(
"Embedding generation failed after retries",
error=str(exc),
)
raise
logger.warning(
"Threaded embedding generation failed (attempt %s/%s), retrying in %.1fs",
attempts,
3,
delay,
)
time.sleep(delay)
delay = min(delay * 2, 8.0)
if vectors is None:
raise RuntimeError(
f"Embedding generation failed: {last_exception}" if last_exception else "Embedding generation failed"
) )
vectors = self.embedding.embed_documents(texts)
if not vectors: if not vectors:
self.log("No vectors generated from documents.") self.log("No vectors generated from documents.")

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -7,6 +7,10 @@ from utils.logging_config import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
MAX_EMBED_RETRIES = 3
EMBED_RETRY_INITIAL_DELAY = 1.0
EMBED_RETRY_MAX_DELAY = 8.0
class SearchService: class SearchService:
def __init__(self, session_manager=None): def __init__(self, session_manager=None):
@ -137,20 +141,53 @@ class SearchService:
import asyncio import asyncio
async def embed_with_model(model_name): async def embed_with_model(model_name):
try: delay = EMBED_RETRY_INITIAL_DELAY
resp = await clients.patched_async_client.embeddings.create( attempts = 0
model=model_name, input=[query] last_exception = None
)
return model_name, resp.data[0].embedding while attempts < MAX_EMBED_RETRIES:
except Exception as e: attempts += 1
logger.error(f"Failed to embed with model {model_name}", error=str(e)) try:
return model_name, None resp = await clients.patched_async_client.embeddings.create(
model=model_name, input=[query]
)
return model_name, resp.data[0].embedding
except Exception as e:
last_exception = e
if attempts >= MAX_EMBED_RETRIES:
logger.error(
"Failed to embed with model after retries",
model=model_name,
attempts=attempts,
error=str(e),
)
raise RuntimeError(
f"Failed to embed with model {model_name}"
) from e
logger.warning(
"Retrying embedding generation",
model=model_name,
attempt=attempts,
max_attempts=MAX_EMBED_RETRIES,
error=str(e),
)
await asyncio.sleep(delay)
delay = min(delay * 2, EMBED_RETRY_MAX_DELAY)
# Should not reach here, but guard in case
raise RuntimeError(
f"Failed to embed with model {model_name}"
) from last_exception
# Run all embeddings in parallel # Run all embeddings in parallel
embedding_results = await asyncio.gather( try:
*[embed_with_model(model) for model in available_models], embedding_results = await asyncio.gather(
return_exceptions=True *[embed_with_model(model) for model in available_models]
) )
except Exception as e:
logger.error("Embedding generation failed", error=str(e))
raise
# Collect successful embeddings # Collect successful embeddings
for result in embedding_results: for result in embedding_results: