Merge pull request #2181 from yrangana/feat/openai-embedding-token-tracking

feat: Add token tracking support to openai_embed function
This commit is contained in:
Daniel.y 2025-10-09 12:15:29 +08:00 committed by GitHub
commit 0f15fdc3e2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -579,6 +579,7 @@ async def openai_embed(
 base_url: str | None = None,
 api_key: str | None = None,
 client_configs: dict[str, Any] | None = None,
+token_tracker: Any | None = None,
 ) -> np.ndarray:
 """Generate embeddings for a list of texts using OpenAI's API.
@@ -590,6 +591,7 @@ async def openai_embed(
 client_configs: Additional configuration options for the AsyncOpenAI client.
 These will override any default configurations but will be overridden by
 explicit parameters (api_key, base_url).
+token_tracker: Optional token usage tracker for monitoring API usage.
 Returns:
 A numpy array of embeddings, one per input text.
@@ -608,6 +610,14 @@ async def openai_embed(
 response = await openai_async_client.embeddings.create(
 model=model, input=texts, encoding_format="base64"
 )
+if token_tracker and hasattr(response, "usage"):
+token_counts = {
+"prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
+"total_tokens": getattr(response.usage, "total_tokens", 0),
+}
+token_tracker.add_usage(token_counts)
 return np.array(
 [
 np.array(dp.embedding, dtype=np.float32)