This commit is contained in:
Raphaël MANSUY 2025-12-04 19:18:15 +08:00
parent 593b277945
commit fd7c3e269d

View file

@ -579,6 +579,7 @@ async def openai_embed(
base_url: str | None = None, base_url: str | None = None,
api_key: str | None = None, api_key: str | None = None,
client_configs: dict[str, Any] | None = None, client_configs: dict[str, Any] | None = None,
token_tracker: Any | None = None,
) -> np.ndarray: ) -> np.ndarray:
"""Generate embeddings for a list of texts using OpenAI's API. """Generate embeddings for a list of texts using OpenAI's API.
@ -590,6 +591,7 @@ async def openai_embed(
client_configs: Additional configuration options for the AsyncOpenAI client. client_configs: Additional configuration options for the AsyncOpenAI client.
These will override any default configurations but will be overridden by These will override any default configurations but will be overridden by
explicit parameters (api_key, base_url). explicit parameters (api_key, base_url).
token_tracker: Optional token usage tracker for monitoring API usage.
Returns: Returns:
A numpy array of embeddings, one per input text. A numpy array of embeddings, one per input text.
@ -608,6 +610,14 @@ async def openai_embed(
response = await openai_async_client.embeddings.create( response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="base64" model=model, input=texts, encoding_format="base64"
) )
if token_tracker and hasattr(response, "usage"):
token_counts = {
"prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
"total_tokens": getattr(response.usage, "total_tokens", 0),
}
token_tracker.add_usage(token_counts)
return np.array( return np.array(
[ [
np.array(dp.embedding, dtype=np.float32) np.array(dp.embedding, dtype=np.float32)