Add dimensions parameter support to openai_embed()

This commit is contained in:
Yasiru Rangana 2025-11-07 09:55:06 +11:00
parent 366a1e0f5f
commit d94aae9c5e

View file

@ -609,6 +609,7 @@ async def openai_embed(
model: str = "text-embedding-3-small", model: str = "text-embedding-3-small",
base_url: str | None = None, base_url: str | None = None,
api_key: str | None = None, api_key: str | None = None,
embedding_dim: int = None,
client_configs: dict[str, Any] | None = None, client_configs: dict[str, Any] | None = None,
token_tracker: Any | None = None, token_tracker: Any | None = None,
) -> np.ndarray: ) -> np.ndarray:
@ -619,6 +620,7 @@ async def openai_embed(
model: The OpenAI embedding model to use. model: The OpenAI embedding model to use.
base_url: Optional base URL for the OpenAI API. base_url: Optional base URL for the OpenAI API.
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable. api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
embedding_dim: Optional embedding dimension. If None, uses the default embedding dimension for the model. (will be passed to API for dimension reduction).
client_configs: Additional configuration options for the AsyncOpenAI client. client_configs: Additional configuration options for the AsyncOpenAI client.
These will override any default configurations but will be overridden by These will override any default configurations but will be overridden by
explicit parameters (api_key, base_url). explicit parameters (api_key, base_url).
@ -638,9 +640,19 @@ async def openai_embed(
) )
async with openai_async_client: async with openai_async_client:
response = await openai_async_client.embeddings.create( # Prepare API call parameters
model=model, input=texts, encoding_format="base64" api_params = {
) "model": model,
"input": texts,
"encoding_format": "base64",
}
# Add dimensions parameter only if embedding_dim is provided
if embedding_dim is not None:
api_params["dimensions"] = embedding_dim
# Make API call
response = await openai_async_client.embeddings.create(**api_params)
if token_tracker and hasattr(response, "usage"): if token_tracker and hasattr(response, "usage"):
token_counts = { token_counts = {