fix: sync openai.py from upstream

This commit is contained in:
Raphaël MANSUY 2025-12-04 19:20:11 +08:00
parent 253cc3b9e4
commit f6f3ed93d3

View file

@ -77,46 +77,86 @@ class InvalidResponseError(Exception):
def create_openai_async_client(
api_key: str | None = None,
base_url: str | None = None,
use_azure: bool = False,
azure_deployment: str | None = None,
api_version: str | None = None,
timeout: int | None = None,
client_configs: dict[str, Any] | None = None,
) -> AsyncOpenAI:
"""Create an AsyncOpenAI client with the given configuration.
"""Create an AsyncOpenAI or AsyncAzureOpenAI client with the given configuration.
Args:
api_key: OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
base_url: Base URL for the OpenAI API. If None, uses the default OpenAI API URL.
use_azure: Whether to create an Azure OpenAI client. Default is False.
azure_deployment: Azure OpenAI deployment name (only used when use_azure=True).
api_version: Azure OpenAI API version (only used when use_azure=True).
timeout: Request timeout in seconds.
client_configs: Additional configuration options for the AsyncOpenAI client.
These will override any default configurations but will be overridden by
explicit parameters (api_key, base_url).
Returns:
An AsyncOpenAI client instance.
An AsyncOpenAI or AsyncAzureOpenAI client instance.
"""
if not api_key:
api_key = os.environ["OPENAI_API_KEY"]
if use_azure:
from openai import AsyncAzureOpenAI
default_headers = {
"User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
"Content-Type": "application/json",
}
if not api_key:
api_key = os.environ.get("AZURE_OPENAI_API_KEY") or os.environ.get(
"LLM_BINDING_API_KEY"
)
if client_configs is None:
client_configs = {}
if client_configs is None:
client_configs = {}
# Create a merged config dict with precedence: explicit params > client_configs > defaults
merged_configs = {
**client_configs,
"default_headers": default_headers,
"api_key": api_key,
}
# Create a merged config dict with precedence: explicit params > client_configs
merged_configs = {
**client_configs,
"api_key": api_key,
}
if base_url is not None:
merged_configs["base_url"] = base_url
# Add explicit parameters (override client_configs)
if base_url is not None:
merged_configs["azure_endpoint"] = base_url
if azure_deployment is not None:
merged_configs["azure_deployment"] = azure_deployment
if api_version is not None:
merged_configs["api_version"] = api_version
if timeout is not None:
merged_configs["timeout"] = timeout
return AsyncAzureOpenAI(**merged_configs)
else:
merged_configs["base_url"] = os.environ.get(
"OPENAI_API_BASE", "https://api.openai.com/v1"
)
if not api_key:
api_key = os.environ["OPENAI_API_KEY"]
return AsyncOpenAI(**merged_configs)
default_headers = {
"User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
"Content-Type": "application/json",
}
if client_configs is None:
client_configs = {}
# Create a merged config dict with precedence: explicit params > client_configs > defaults
merged_configs = {
**client_configs,
"default_headers": default_headers,
"api_key": api_key,
}
if base_url is not None:
merged_configs["base_url"] = base_url
else:
merged_configs["base_url"] = os.environ.get(
"OPENAI_API_BASE", "https://api.openai.com/v1"
)
if timeout is not None:
merged_configs["timeout"] = timeout
return AsyncOpenAI(**merged_configs)
@retry(
@ -141,6 +181,9 @@ async def openai_complete_if_cache(
stream: bool | None = None,
timeout: int | None = None,
keyword_extraction: bool = False,
use_azure: bool = False,
azure_deployment: str | None = None,
api_version: str | None = None,
**kwargs: Any,
) -> str:
"""Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration.
@ -162,23 +205,33 @@ async def openai_complete_if_cache(
6. For non-streaming: COT content is prepended to regular content with <think> tags.
Args:
model: The OpenAI model to use.
model: The OpenAI model to use. For Azure, this can be the deployment name.
prompt: The prompt to complete.
system_prompt: Optional system prompt to include.
history_messages: Optional list of previous messages in the conversation.
base_url: Optional base URL for the OpenAI API.
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
token_tracker: Optional token usage tracker for monitoring API usage.
enable_cot: Whether to enable Chain of Thought (COT) processing. Default is False.
base_url: Optional base URL for the OpenAI API. For Azure, this should be the
Azure OpenAI endpoint (e.g., https://your-resource.openai.azure.com/).
api_key: Optional API key. For standard OpenAI, uses OPENAI_API_KEY environment
variable if None. For Azure, uses AZURE_OPENAI_API_KEY if None.
token_tracker: Optional token usage tracker for monitoring API usage.
stream: Whether to stream the response. Default is False.
timeout: Request timeout in seconds. Default is None.
keyword_extraction: Whether to enable keyword extraction mode. When True, triggers
special response formatting for keyword extraction. Default is False.
use_azure: Whether to use Azure OpenAI service instead of standard OpenAI.
When True, creates an AsyncAzureOpenAI client. Default is False.
azure_deployment: Azure OpenAI deployment name. Only used when use_azure=True.
If not specified, falls back to AZURE_OPENAI_DEPLOYMENT environment variable.
api_version: Azure OpenAI API version (e.g., "2024-02-15-preview"). Only used
when use_azure=True. If not specified, falls back to AZURE_OPENAI_API_VERSION
environment variable.
**kwargs: Additional keyword arguments to pass to the OpenAI API.
Special kwargs:
- openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
These will be passed to the client constructor but will be overridden by
explicit parameters (api_key, base_url).
explicit parameters (api_key, base_url). Supports proxy configuration,
custom headers, retry policies, etc.
Returns:
The completed text (with integrated COT content if available) or an async iterator
@ -207,10 +260,14 @@ async def openai_complete_if_cache(
if keyword_extraction:
kwargs["response_format"] = GPTKeywordExtractionFormat
# Create the OpenAI client
# Create the OpenAI client (supports both OpenAI and Azure)
openai_async_client = create_openai_async_client(
api_key=api_key,
base_url=base_url,
use_azure=use_azure,
azure_deployment=azure_deployment,
api_version=api_version,
timeout=timeout,
client_configs=client_configs,
)
@ -238,15 +295,19 @@ async def openai_complete_if_cache(
if timeout is not None:
kwargs["timeout"] = timeout
# Determine the correct model identifier to use
# For Azure OpenAI, we must use the deployment name instead of the model name
api_model = azure_deployment if use_azure and azure_deployment else model
try:
# Don't use async with context manager, use client directly
if "response_format" in kwargs:
response = await openai_async_client.beta.chat.completions.parse(
model=model, messages=messages, **kwargs
response = await openai_async_client.chat.completions.parse(
model=api_model, messages=messages, **kwargs
)
else:
response = await openai_async_client.chat.completions.create(
model=model, messages=messages, **kwargs
model=api_model, messages=messages, **kwargs
)
except APITimeoutError as e:
logger.error(f"OpenAI API Timeout Error: {e}")
@ -291,7 +352,10 @@ async def openai_complete_if_cache(
# Check if choices exists and is not empty
if not hasattr(chunk, "choices") or not chunk.choices:
logger.warning(f"Received chunk without choices: {chunk}")
# Azure OpenAI sends content filter results in first chunk without choices
logger.debug(
f"Received chunk without choices (likely Azure content filter): {chunk}"
)
continue
# Check if delta exists
@ -453,46 +517,57 @@ async def openai_complete_if_cache(
raise InvalidResponseError("Invalid response from OpenAI API")
message = response.choices[0].message
content = getattr(message, "content", None)
reasoning_content = getattr(message, "reasoning_content", "")
# Handle COT logic for non-streaming responses (only if enabled)
final_content = ""
# Handle parsed responses (structured output via response_format)
# When using beta.chat.completions.parse(), the response is in message.parsed
if hasattr(message, "parsed") and message.parsed is not None:
# Serialize the parsed structured response to JSON
final_content = message.parsed.model_dump_json()
logger.debug("Using parsed structured response from API")
else:
# Handle regular content responses
content = getattr(message, "content", None)
reasoning_content = getattr(message, "reasoning_content", "")
if enable_cot:
# Check if we should include reasoning content
should_include_reasoning = False
if reasoning_content and reasoning_content.strip():
if not content or content.strip() == "":
# Case 1: Only reasoning content, should include COT
should_include_reasoning = True
final_content = (
content or ""
) # Use empty string if content is None
# Handle COT logic for non-streaming responses (only if enabled)
final_content = ""
if enable_cot:
# Check if we should include reasoning content
should_include_reasoning = False
if reasoning_content and reasoning_content.strip():
if not content or content.strip() == "":
# Case 1: Only reasoning content, should include COT
should_include_reasoning = True
final_content = (
content or ""
) # Use empty string if content is None
else:
# Case 3: Both content and reasoning_content present, ignore reasoning
should_include_reasoning = False
final_content = content
else:
# Case 3: Both content and reasoning_content present, ignore reasoning
should_include_reasoning = False
final_content = content
# No reasoning content, use regular content
final_content = content or ""
# Apply COT wrapping if needed
if should_include_reasoning:
if r"\u" in reasoning_content:
reasoning_content = safe_unicode_decode(
reasoning_content.encode("utf-8")
)
final_content = (
f"<think>{reasoning_content}</think>{final_content}"
)
else:
# No reasoning content, use regular content
# COT disabled, only use regular content
final_content = content or ""
# Apply COT wrapping if needed
if should_include_reasoning:
if r"\u" in reasoning_content:
reasoning_content = safe_unicode_decode(
reasoning_content.encode("utf-8")
)
final_content = f"<think>{reasoning_content}</think>{final_content}"
else:
# COT disabled, only use regular content
final_content = content or ""
# Validate final content
if not final_content or final_content.strip() == "":
logger.error("Received empty content from OpenAI API")
await openai_async_client.close() # Ensure client is closed
raise InvalidResponseError("Received empty content from OpenAI API")
# Validate final content
if not final_content or final_content.strip() == "":
logger.error("Received empty content from OpenAI API")
await openai_async_client.close() # Ensure client is closed
raise InvalidResponseError("Received empty content from OpenAI API")
# Apply Unicode decoding to final content if needed
if r"\u" in final_content:
@ -620,24 +695,40 @@ async def openai_embed(
embedding_dim: int | None = None,
client_configs: dict[str, Any] | None = None,
token_tracker: Any | None = None,
use_azure: bool = False,
azure_deployment: str | None = None,
api_version: str | None = None,
) -> np.ndarray:
"""Generate embeddings for a list of texts using OpenAI's API.
This function supports both standard OpenAI and Azure OpenAI services.
Args:
texts: List of texts to embed.
model: The OpenAI embedding model to use.
base_url: Optional base URL for the OpenAI API.
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
model: The embedding model to use. For standard OpenAI (e.g., "text-embedding-3-small").
For Azure, this can be the deployment name.
base_url: Optional base URL for the API. For standard OpenAI, uses default OpenAI endpoint.
For Azure, this should be the Azure OpenAI endpoint (e.g., https://your-resource.openai.azure.com/).
api_key: Optional API key. For standard OpenAI, uses OPENAI_API_KEY environment variable if None.
For Azure, uses AZURE_EMBEDDING_API_KEY environment variable if None.
embedding_dim: Optional embedding dimension for dynamic dimension reduction.
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
Do NOT manually pass this parameter when calling the function directly.
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
Manually passing a different value will trigger a warning and be ignored.
When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction.
client_configs: Additional configuration options for the AsyncOpenAI client.
client_configs: Additional configuration options for the AsyncOpenAI/AsyncAzureOpenAI client.
These will override any default configurations but will be overridden by
explicit parameters (api_key, base_url).
explicit parameters (api_key, base_url). Supports proxy configuration,
custom headers, retry policies, etc.
token_tracker: Optional token usage tracker for monitoring API usage.
use_azure: Whether to use Azure OpenAI service instead of standard OpenAI.
When True, creates an AsyncAzureOpenAI client. Default is False.
azure_deployment: Azure OpenAI deployment name. Only used when use_azure=True.
If not specified, falls back to AZURE_EMBEDDING_DEPLOYMENT environment variable.
api_version: Azure OpenAI API version (e.g., "2024-02-15-preview"). Only used
when use_azure=True. If not specified, falls back to AZURE_EMBEDDING_API_VERSION
environment variable.
Returns:
A numpy array of embeddings, one per input text.
@ -647,15 +738,24 @@ async def openai_embed(
RateLimitError: If the OpenAI API rate limit is exceeded.
APITimeoutError: If the OpenAI API request times out.
"""
# Create the OpenAI client
# Create the OpenAI client (supports both OpenAI and Azure)
openai_async_client = create_openai_async_client(
api_key=api_key, base_url=base_url, client_configs=client_configs
api_key=api_key,
base_url=base_url,
use_azure=use_azure,
azure_deployment=azure_deployment,
api_version=api_version,
client_configs=client_configs,
)
async with openai_async_client:
# Determine the correct model identifier to use
# For Azure OpenAI, we must use the deployment name instead of the model name
api_model = azure_deployment if use_azure and azure_deployment else model
# Prepare API call parameters
api_params = {
"model": model,
"model": api_model,
"input": texts,
"encoding_format": "base64",
}
@ -682,3 +782,172 @@ async def openai_embed(
for dp in response.data
]
)
# Azure OpenAI wrapper functions for backward compatibility
async def azure_openai_complete_if_cache(
model,
prompt,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
enable_cot: bool = False,
base_url: str | None = None,
api_key: str | None = None,
token_tracker: Any | None = None,
stream: bool | None = None,
timeout: int | None = None,
api_version: str | None = None,
keyword_extraction: bool = False,
**kwargs,
):
"""Azure OpenAI completion wrapper function.
This function provides backward compatibility by wrapping the unified
openai_complete_if_cache implementation with Azure-specific parameter handling.
All parameters from the underlying openai_complete_if_cache are exposed to ensure
full feature parity and API consistency.
"""
# Handle Azure-specific environment variables and parameters
deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL")
base_url = (
base_url or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("LLM_BINDING_HOST")
)
api_key = (
api_key or os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("LLM_BINDING_API_KEY")
)
api_version = (
api_version
or os.getenv("AZURE_OPENAI_API_VERSION")
or os.getenv("OPENAI_API_VERSION")
or "2024-08-01-preview"
)
# Call the unified implementation with Azure-specific parameters
return await openai_complete_if_cache(
model=deployment,
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
base_url=base_url,
api_key=api_key,
token_tracker=token_tracker,
stream=stream,
timeout=timeout,
use_azure=True,
azure_deployment=deployment,
api_version=api_version,
keyword_extraction=keyword_extraction,
**kwargs,
)
async def azure_openai_complete(
prompt,
system_prompt=None,
history_messages=None,
keyword_extraction=False,
**kwargs,
) -> str:
"""Azure OpenAI complete wrapper function.
Provides backward compatibility for azure_openai_complete calls.
"""
if history_messages is None:
history_messages = []
result = await azure_openai_complete_if_cache(
os.getenv("LLM_MODEL", "gpt-4o-mini"),
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
keyword_extraction=keyword_extraction,
**kwargs,
)
return result
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
async def azure_openai_embed(
texts: list[str],
model: str | None = None,
base_url: str | None = None,
api_key: str | None = None,
token_tracker: Any | None = None,
client_configs: dict[str, Any] | None = None,
api_version: str | None = None,
) -> np.ndarray:
"""Azure OpenAI embedding wrapper function.
This function provides backward compatibility by wrapping the unified
openai_embed implementation with Azure-specific parameter handling.
All parameters from the underlying openai_embed are exposed to ensure
full feature parity and API consistency.
IMPORTANT - Decorator Usage:
1. This function is decorated with @wrap_embedding_func_with_attrs to provide
the EmbeddingFunc interface for users who need to access embedding_dim
and other attributes.
2. This function does NOT use @retry decorator to avoid double-wrapping,
since the underlying openai_embed.func already has retry logic.
3. This function calls openai_embed.func (the unwrapped function) instead of
openai_embed (the EmbeddingFunc instance) to avoid double decoration issues:
Correct: await openai_embed.func(...) # Calls unwrapped function with retry
Wrong: await openai_embed(...) # Would cause double EmbeddingFunc wrapping
Double decoration causes:
- Double injection of embedding_dim parameter
- Incorrect parameter passing to the underlying implementation
- Runtime errors due to parameter conflicts
The call chain with correct implementation:
azure_openai_embed(texts)
EmbeddingFunc.__call__(texts) # azure's decorator
azure_openai_embed_impl(texts, embedding_dim=1536)
openai_embed.func(texts, ...)
@retry_wrapper(texts, ...) # openai's retry (only one layer)
openai_embed_impl(texts, ...)
actual embedding computation
"""
# Handle Azure-specific environment variables and parameters
deployment = (
os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
or model
or os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
)
base_url = (
base_url
or os.getenv("AZURE_EMBEDDING_ENDPOINT")
or os.getenv("EMBEDDING_BINDING_HOST")
)
api_key = (
api_key
or os.getenv("AZURE_EMBEDDING_API_KEY")
or os.getenv("EMBEDDING_BINDING_API_KEY")
)
api_version = (
api_version
or os.getenv("AZURE_EMBEDDING_API_VERSION")
or os.getenv("AZURE_OPENAI_API_VERSION")
or os.getenv("OPENAI_API_VERSION")
or "2024-08-01-preview"
)
# CRITICAL: Call openai_embed.func (unwrapped) to avoid double decoration
# openai_embed is an EmbeddingFunc instance, .func accesses the underlying function
return await openai_embed.func(
texts=texts,
model=deployment,
base_url=base_url,
api_key=api_key,
token_tracker=token_tracker,
client_configs=client_configs,
use_azure=True,
azure_deployment=deployment,
api_version=api_version,
)