Merge pull request #2181 from yrangana/feat/openai-embedding-token-tracking
feat: Add token tracking support to openai_embed function
This commit is contained in:
commit
0f15fdc3e2
1 changed files with 10 additions and 0 deletions
|
|
@ -579,6 +579,7 @@ async def openai_embed(
|
|||
base_url: str | None = None,
|
||||
api_key: str | None = None,
|
||||
client_configs: dict[str, Any] | None = None,
|
||||
token_tracker: Any | None = None,
|
||||
) -> np.ndarray:
|
||||
"""Generate embeddings for a list of texts using OpenAI's API.
|
||||
|
||||
|
|
@ -590,6 +591,7 @@ async def openai_embed(
|
|||
client_configs: Additional configuration options for the AsyncOpenAI client.
|
||||
These will override any default configurations but will be overridden by
|
||||
explicit parameters (api_key, base_url).
|
||||
token_tracker: Optional token usage tracker for monitoring API usage.
|
||||
|
||||
Returns:
|
||||
A numpy array of embeddings, one per input text.
|
||||
|
|
@ -608,6 +610,14 @@ async def openai_embed(
|
|||
response = await openai_async_client.embeddings.create(
|
||||
model=model, input=texts, encoding_format="base64"
|
||||
)
|
||||
|
||||
if token_tracker and hasattr(response, "usage"):
|
||||
token_counts = {
|
||||
"prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
|
||||
"total_tokens": getattr(response.usage, "total_tokens", 0),
|
||||
}
|
||||
token_tracker.add_usage(token_counts)
|
||||
|
||||
return np.array(
|
||||
[
|
||||
np.array(dp.embedding, dtype=np.float32)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue