feat: Add token tracking support to openai_embed function

- Add optional token_tracker parameter to openai_embed()
- Track prompt_tokens and total_tokens from each embedding API response, reported via token_tracker.add_usage()
- Enables monitoring of embedding token usage alongside LLM calls
- Maintains backward compatibility: token_tracker defaults to None, so existing callers are unaffected
This commit is contained in:
Yasiru Rangana 2025-10-08 14:36:08 +11:00
parent f1e0110716
commit ec40b17eea

View file

@ -579,6 +579,7 @@ async def openai_embed(
base_url: str | None = None,
api_key: str | None = None,
client_configs: dict[str, Any] | None = None,
token_tracker: Any | None = None,
) -> np.ndarray:
"""Generate embeddings for a list of texts using OpenAI's API.
@ -590,6 +591,7 @@ async def openai_embed(
client_configs: Additional configuration options for the AsyncOpenAI client.
These will override any default configurations but will be overridden by
explicit parameters (api_key, base_url).
token_tracker: Optional token usage tracker for monitoring API usage.
Returns:
A numpy array of embeddings, one per input text.
@ -608,6 +610,14 @@ async def openai_embed(
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="base64"
)
if token_tracker and hasattr(response, "usage"):
token_counts = {
"prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
"total_tokens": getattr(response.usage, "total_tokens", 0),
}
token_tracker.add_usage(token_counts)
return np.array(
[
np.array(dp.embedding, dtype=np.float32)