Improve embedding generation timeout handling with retry and error handling
This commit is contained in:
parent
0c696afef8
commit
a424bb422a
7 changed files with 91 additions and 31 deletions
|
|
@ -660,24 +660,47 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
|||
msg = "Embedding handle is required to embed documents."
|
||||
raise ValueError(msg)
|
||||
|
||||
# Generate embeddings (threaded for concurrency)
|
||||
# Generate embeddings (threaded for concurrency) with retries
|
||||
def embed_chunk(chunk_text: str) -> list[float]:
|
||||
return self.embedding.embed_documents([chunk_text])[0]
|
||||
|
||||
try:
|
||||
max_workers = min(max(len(texts), 1), 8)
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = {executor.submit(embed_chunk, chunk): idx for idx, chunk in enumerate(texts)}
|
||||
vectors = [None] * len(texts)
|
||||
for future in as_completed(futures):
|
||||
idx = futures[future]
|
||||
vectors[idx] = future.result()
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Threaded embedding generation failed, falling back to synchronous mode: %s",
|
||||
exc,
|
||||
vectors: list[list[float]] | None = None
|
||||
last_exception: Exception | None = None
|
||||
delay = 1.0
|
||||
attempts = 0
|
||||
|
||||
while attempts < 3:
|
||||
attempts += 1
|
||||
try:
|
||||
max_workers = min(max(len(texts), 1), 8)
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = {executor.submit(embed_chunk, chunk): idx for idx, chunk in enumerate(texts)}
|
||||
vectors = [None] * len(texts)
|
||||
for future in as_completed(futures):
|
||||
idx = futures[future]
|
||||
vectors[idx] = future.result()
|
||||
break
|
||||
except Exception as exc:
|
||||
last_exception = exc
|
||||
if attempts >= 3:
|
||||
logger.error(
|
||||
"Embedding generation failed after retries",
|
||||
error=str(exc),
|
||||
)
|
||||
raise
|
||||
logger.warning(
|
||||
"Threaded embedding generation failed (attempt %s/%s), retrying in %.1fs",
|
||||
attempts,
|
||||
3,
|
||||
delay,
|
||||
)
|
||||
time.sleep(delay)
|
||||
delay = min(delay * 2, 8.0)
|
||||
|
||||
if vectors is None:
|
||||
raise RuntimeError(
|
||||
f"Embedding generation failed: {last_exception}" if last_exception else "Embedding generation failed"
|
||||
)
|
||||
vectors = self.embedding.embed_documents(texts)
|
||||
|
||||
if not vectors:
|
||||
self.log("No vectors generated from documents.")
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
|
@ -7,6 +7,10 @@ from utils.logging_config import get_logger
|
|||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
MAX_EMBED_RETRIES = 3
|
||||
EMBED_RETRY_INITIAL_DELAY = 1.0
|
||||
EMBED_RETRY_MAX_DELAY = 8.0
|
||||
|
||||
|
||||
class SearchService:
|
||||
def __init__(self, session_manager=None):
|
||||
|
|
@ -137,20 +141,53 @@ class SearchService:
|
|||
import asyncio
|
||||
|
||||
async def embed_with_model(model_name):
|
||||
try:
|
||||
resp = await clients.patched_async_client.embeddings.create(
|
||||
model=model_name, input=[query]
|
||||
)
|
||||
return model_name, resp.data[0].embedding
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to embed with model {model_name}", error=str(e))
|
||||
return model_name, None
|
||||
delay = EMBED_RETRY_INITIAL_DELAY
|
||||
attempts = 0
|
||||
last_exception = None
|
||||
|
||||
while attempts < MAX_EMBED_RETRIES:
|
||||
attempts += 1
|
||||
try:
|
||||
resp = await clients.patched_async_client.embeddings.create(
|
||||
model=model_name, input=[query]
|
||||
)
|
||||
return model_name, resp.data[0].embedding
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
if attempts >= MAX_EMBED_RETRIES:
|
||||
logger.error(
|
||||
"Failed to embed with model after retries",
|
||||
model=model_name,
|
||||
attempts=attempts,
|
||||
error=str(e),
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Failed to embed with model {model_name}"
|
||||
) from e
|
||||
|
||||
logger.warning(
|
||||
"Retrying embedding generation",
|
||||
model=model_name,
|
||||
attempt=attempts,
|
||||
max_attempts=MAX_EMBED_RETRIES,
|
||||
error=str(e),
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
delay = min(delay * 2, EMBED_RETRY_MAX_DELAY)
|
||||
|
||||
# Should not reach here, but guard in case
|
||||
raise RuntimeError(
|
||||
f"Failed to embed with model {model_name}"
|
||||
) from last_exception
|
||||
|
||||
# Run all embeddings in parallel
|
||||
embedding_results = await asyncio.gather(
|
||||
*[embed_with_model(model) for model in available_models],
|
||||
return_exceptions=True
|
||||
)
|
||||
try:
|
||||
embedding_results = await asyncio.gather(
|
||||
*[embed_with_model(model) for model in available_models]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Embedding generation failed", error=str(e))
|
||||
raise
|
||||
|
||||
# Collect successful embeddings
|
||||
for result in embedding_results:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue