feat: Add token tracking support to openai_embed function

- Add optional token_tracker parameter to openai_embed()
- Track prompt_tokens and total_tokens for embedding API calls
- Enables monitoring of embedding token usage alongside LLM calls
- Maintains backward compatibility with existing code
Yasiru Rangana 2025-10-08 14:36:08 +11:00
parent f1e0110716
commit ec40b17eea
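
For reference, a minimal usage sketch (not part of this commit): the diff only requires that whatever is passed as token_tracker expose an add_usage(dict) method, so any such object works. The SimpleTokenTracker class, the example texts, and the import path below are assumptions for illustration, not code from the repository.

# Minimal sketch, assuming openai_embed lives at lightrag.llm.openai and that
# token_tracker only needs an add_usage(dict) method (the sole call in the diff).
import asyncio
from lightrag.llm.openai import openai_embed  # import path assumed

class SimpleTokenTracker:
    """Hypothetical tracker that accumulates the counts reported per call."""

    def __init__(self) -> None:
        self.prompt_tokens = 0
        self.total_tokens = 0

    def add_usage(self, token_counts: dict) -> None:
        # Accumulate the per-call counts passed in by openai_embed.
        self.prompt_tokens += token_counts.get("prompt_tokens", 0)
        self.total_tokens += token_counts.get("total_tokens", 0)

async def main() -> None:
    tracker = SimpleTokenTracker()
    embeddings = await openai_embed(
        ["hello world", "token tracking example"],
        token_tracker=tracker,
    )
    print(embeddings.shape, tracker.prompt_tokens, tracker.total_tokens)

asyncio.run(main())
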


@@ -579,6 +579,7 @@ async def openai_embed(
     base_url: str | None = None,
     api_key: str | None = None,
     client_configs: dict[str, Any] | None = None,
+    token_tracker: Any | None = None,
 ) -> np.ndarray:
     """Generate embeddings for a list of texts using OpenAI's API.
@@ -590,6 +591,7 @@ async def openai_embed(
         client_configs: Additional configuration options for the AsyncOpenAI client.
             These will override any default configurations but will be overridden by
             explicit parameters (api_key, base_url).
+        token_tracker: Optional token usage tracker for monitoring API usage.
 
     Returns:
         A numpy array of embeddings, one per input text.
@@ -608,6 +610,14 @@ async def openai_embed(
     response = await openai_async_client.embeddings.create(
         model=model, input=texts, encoding_format="base64"
     )
+
+    if token_tracker and hasattr(response, "usage"):
+        token_counts = {
+            "prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
+            "total_tokens": getattr(response.usage, "total_tokens", 0),
+        }
+        token_tracker.add_usage(token_counts)
+
     return np.array(
         [
             np.array(dp.embedding, dtype=np.float32)