This commit is contained in:
Raphaël MANSUY 2025-12-04 19:14:32 +08:00
parent 34b77e50ae
commit 65d26c6476

View file

@ -77,46 +77,86 @@ class InvalidResponseError(Exception):
def create_openai_async_client(
api_key: str | None = None,
base_url: str | None = None,
use_azure: bool = False,
azure_deployment: str | None = None,
api_version: str | None = None,
timeout: int | None = None,
client_configs: dict[str, Any] | None = None,
) -> AsyncOpenAI:
"""Create an AsyncOpenAI client with the given configuration.
"""Create an AsyncOpenAI or AsyncAzureOpenAI client with the given configuration.
Args:
api_key: OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
base_url: Base URL for the OpenAI API. If None, uses the default OpenAI API URL.
use_azure: Whether to create an Azure OpenAI client. Default is False.
azure_deployment: Azure OpenAI deployment name (only used when use_azure=True).
api_version: Azure OpenAI API version (only used when use_azure=True).
timeout: Request timeout in seconds.
client_configs: Additional configuration options for the AsyncOpenAI client.
These will override any default configurations but will be overridden by
explicit parameters (api_key, base_url).
Returns:
An AsyncOpenAI client instance.
An AsyncOpenAI or AsyncAzureOpenAI client instance.
"""
if not api_key:
api_key = os.environ["OPENAI_API_KEY"]
if use_azure:
from openai import AsyncAzureOpenAI
default_headers = {
"User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
"Content-Type": "application/json",
}
if not api_key:
api_key = os.environ.get("AZURE_OPENAI_API_KEY") or os.environ.get(
"LLM_BINDING_API_KEY"
)
if client_configs is None:
client_configs = {}
if client_configs is None:
client_configs = {}
# Create a merged config dict with precedence: explicit params > client_configs > defaults
merged_configs = {
**client_configs,
"default_headers": default_headers,
"api_key": api_key,
}
# Create a merged config dict with precedence: explicit params > client_configs
merged_configs = {
**client_configs,
"api_key": api_key,
}
if base_url is not None:
merged_configs["base_url"] = base_url
# Add explicit parameters (override client_configs)
if base_url is not None:
merged_configs["azure_endpoint"] = base_url
if azure_deployment is not None:
merged_configs["azure_deployment"] = azure_deployment
if api_version is not None:
merged_configs["api_version"] = api_version
if timeout is not None:
merged_configs["timeout"] = timeout
return AsyncAzureOpenAI(**merged_configs)
else:
merged_configs["base_url"] = os.environ.get(
"OPENAI_API_BASE", "https://api.openai.com/v1"
)
if not api_key:
api_key = os.environ["OPENAI_API_KEY"]
return AsyncOpenAI(**merged_configs)
default_headers = {
"User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
"Content-Type": "application/json",
}
if client_configs is None:
client_configs = {}
# Create a merged config dict with precedence: explicit params > client_configs > defaults
merged_configs = {
**client_configs,
"default_headers": default_headers,
"api_key": api_key,
}
if base_url is not None:
merged_configs["base_url"] = base_url
else:
merged_configs["base_url"] = os.environ.get(
"OPENAI_API_BASE", "https://api.openai.com/v1"
)
if timeout is not None:
merged_configs["timeout"] = timeout
return AsyncOpenAI(**merged_configs)
@retry(
@ -141,6 +181,9 @@ async def openai_complete_if_cache(
stream: bool | None = None,
timeout: int | None = None,
keyword_extraction: bool = False,
use_azure: bool = False,
azure_deployment: str | None = None,
api_version: str | None = None,
**kwargs: Any,
) -> str:
"""Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration.
@ -162,23 +205,33 @@ async def openai_complete_if_cache(
6. For non-streaming: COT content is prepended to regular content with <think> tags.
Args:
model: The OpenAI model to use.
model: The OpenAI model to use. For Azure, this can be the deployment name.
prompt: The prompt to complete.
system_prompt: Optional system prompt to include.
history_messages: Optional list of previous messages in the conversation.
base_url: Optional base URL for the OpenAI API.
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
token_tracker: Optional token usage tracker for monitoring API usage.
enable_cot: Whether to enable Chain of Thought (COT) processing. Default is False.
base_url: Optional base URL for the OpenAI API. For Azure, this should be the
Azure OpenAI endpoint (e.g., https://your-resource.openai.azure.com/).
api_key: Optional API key. For standard OpenAI, uses OPENAI_API_KEY environment
variable if None. For Azure, uses AZURE_OPENAI_API_KEY if None.
token_tracker: Optional token usage tracker for monitoring API usage.
stream: Whether to stream the response. Default is False.
timeout: Request timeout in seconds. Default is None.
keyword_extraction: Whether to enable keyword extraction mode. When True, triggers
special response formatting for keyword extraction. Default is False.
use_azure: Whether to use Azure OpenAI service instead of standard OpenAI.
When True, creates an AsyncAzureOpenAI client. Default is False.
azure_deployment: Azure OpenAI deployment name. Only used when use_azure=True.
If not specified, falls back to AZURE_OPENAI_DEPLOYMENT environment variable.
api_version: Azure OpenAI API version (e.g., "2024-02-15-preview"). Only used
when use_azure=True. If not specified, falls back to AZURE_OPENAI_API_VERSION
environment variable.
**kwargs: Additional keyword arguments to pass to the OpenAI API.
Special kwargs:
- openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
These will be passed to the client constructor but will be overridden by
explicit parameters (api_key, base_url).
explicit parameters (api_key, base_url). Supports proxy configuration,
custom headers, retry policies, etc.
Returns:
The completed text (with integrated COT content if available) or an async iterator
@ -207,10 +260,14 @@ async def openai_complete_if_cache(
if keyword_extraction:
kwargs["response_format"] = GPTKeywordExtractionFormat
# Create the OpenAI client
# Create the OpenAI client (supports both OpenAI and Azure)
openai_async_client = create_openai_async_client(
api_key=api_key,
base_url=base_url,
use_azure=use_azure,
azure_deployment=azure_deployment,
api_version=api_version,
timeout=timeout,
client_configs=client_configs,
)
@ -241,7 +298,7 @@ async def openai_complete_if_cache(
try:
# Don't use async with context manager, use client directly
if "response_format" in kwargs:
response = await openai_async_client.beta.chat.completions.parse(
response = await openai_async_client.chat.completions.parse(
model=model, messages=messages, **kwargs
)
else:
@ -453,7 +510,7 @@ async def openai_complete_if_cache(
raise InvalidResponseError("Invalid response from OpenAI API")
message = response.choices[0].message
# Handle parsed responses (structured output via response_format)
# When using beta.chat.completions.parse(), the response is in message.parsed
if hasattr(message, "parsed") and message.parsed is not None:
@ -492,7 +549,9 @@ async def openai_complete_if_cache(
reasoning_content = safe_unicode_decode(
reasoning_content.encode("utf-8")
)
final_content = f"<think>{reasoning_content}</think>{final_content}"
final_content = (
f"<think>{reasoning_content}</think>{final_content}"
)
else:
# COT disabled, only use regular content
final_content = content or ""
@ -629,24 +688,40 @@ async def openai_embed(
embedding_dim: int | None = None,
client_configs: dict[str, Any] | None = None,
token_tracker: Any | None = None,
use_azure: bool = False,
azure_deployment: str | None = None,
api_version: str | None = None,
) -> np.ndarray:
"""Generate embeddings for a list of texts using OpenAI's API.
This function supports both standard OpenAI and Azure OpenAI services.
Args:
texts: List of texts to embed.
model: The OpenAI embedding model to use.
base_url: Optional base URL for the OpenAI API.
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
model: The embedding model to use. For standard OpenAI (e.g., "text-embedding-3-small").
For Azure, this can be the deployment name.
base_url: Optional base URL for the API. For standard OpenAI, uses default OpenAI endpoint.
For Azure, this should be the Azure OpenAI endpoint (e.g., https://your-resource.openai.azure.com/).
api_key: Optional API key. For standard OpenAI, uses OPENAI_API_KEY environment variable if None.
For Azure, uses AZURE_EMBEDDING_API_KEY environment variable if None.
embedding_dim: Optional embedding dimension for dynamic dimension reduction.
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
Do NOT manually pass this parameter when calling the function directly.
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
Manually passing a different value will trigger a warning and be ignored.
When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction.
client_configs: Additional configuration options for the AsyncOpenAI client.
client_configs: Additional configuration options for the AsyncOpenAI/AsyncAzureOpenAI client.
These will override any default configurations but will be overridden by
explicit parameters (api_key, base_url).
explicit parameters (api_key, base_url). Supports proxy configuration,
custom headers, retry policies, etc.
token_tracker: Optional token usage tracker for monitoring API usage.
use_azure: Whether to use Azure OpenAI service instead of standard OpenAI.
When True, creates an AsyncAzureOpenAI client. Default is False.
azure_deployment: Azure OpenAI deployment name. Only used when use_azure=True.
If not specified, falls back to AZURE_EMBEDDING_DEPLOYMENT environment variable.
api_version: Azure OpenAI API version (e.g., "2024-02-15-preview"). Only used
when use_azure=True. If not specified, falls back to AZURE_EMBEDDING_API_VERSION
environment variable.
Returns:
A numpy array of embeddings, one per input text.
@ -656,9 +731,14 @@ async def openai_embed(
RateLimitError: If the OpenAI API rate limit is exceeded.
APITimeoutError: If the OpenAI API request times out.
"""
# Create the OpenAI client
# Create the OpenAI client (supports both OpenAI and Azure)
openai_async_client = create_openai_async_client(
api_key=api_key, base_url=base_url, client_configs=client_configs
api_key=api_key,
base_url=base_url,
use_azure=use_azure,
azure_deployment=azure_deployment,
api_version=api_version,
client_configs=client_configs,
)
async with openai_async_client:
@ -691,3 +771,169 @@ async def openai_embed(
for dp in response.data
]
)
# Azure OpenAI wrapper functions for backward compatibility
async def azure_openai_complete_if_cache(
model,
prompt,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
enable_cot: bool = False,
base_url: str | None = None,
api_key: str | None = None,
token_tracker: Any | None = None,
stream: bool | None = None,
timeout: int | None = None,
api_version: str | None = None,
keyword_extraction: bool = False,
**kwargs,
):
"""Azure OpenAI completion wrapper function.
This function provides backward compatibility by wrapping the unified
openai_complete_if_cache implementation with Azure-specific parameter handling.
All parameters from the underlying openai_complete_if_cache are exposed to ensure
full feature parity and API consistency.
"""
# Handle Azure-specific environment variables and parameters
deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL")
base_url = (
base_url or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("LLM_BINDING_HOST")
)
api_key = (
api_key or os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("LLM_BINDING_API_KEY")
)
api_version = (
api_version
or os.getenv("AZURE_OPENAI_API_VERSION")
or os.getenv("OPENAI_API_VERSION")
)
# Call the unified implementation with Azure-specific parameters
return await openai_complete_if_cache(
model=model,
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
base_url=base_url,
api_key=api_key,
token_tracker=token_tracker,
stream=stream,
timeout=timeout,
use_azure=True,
azure_deployment=deployment,
api_version=api_version,
keyword_extraction=keyword_extraction,
**kwargs,
)
async def azure_openai_complete(
prompt,
system_prompt=None,
history_messages=None,
keyword_extraction=False,
**kwargs,
) -> str:
"""Azure OpenAI complete wrapper function.
Provides backward compatibility for azure_openai_complete calls.
"""
if history_messages is None:
history_messages = []
result = await azure_openai_complete_if_cache(
os.getenv("LLM_MODEL", "gpt-4o-mini"),
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
keyword_extraction=keyword_extraction,
**kwargs,
)
return result
@wrap_embedding_func_with_attrs(embedding_dim=1536)
async def azure_openai_embed(
texts: list[str],
model: str | None = None,
base_url: str | None = None,
api_key: str | None = None,
token_tracker: Any | None = None,
client_configs: dict[str, Any] | None = None,
api_version: str | None = None,
) -> np.ndarray:
"""Azure OpenAI embedding wrapper function.
This function provides backward compatibility by wrapping the unified
openai_embed implementation with Azure-specific parameter handling.
All parameters from the underlying openai_embed are exposed to ensure
full feature parity and API consistency.
IMPORTANT - Decorator Usage:
1. This function is decorated with @wrap_embedding_func_with_attrs to provide
the EmbeddingFunc interface for users who need to access embedding_dim
and other attributes.
2. This function does NOT use @retry decorator to avoid double-wrapping,
since the underlying openai_embed.func already has retry logic.
3. This function calls openai_embed.func (the unwrapped function) instead of
openai_embed (the EmbeddingFunc instance) to avoid double decoration issues:
Correct: await openai_embed.func(...) # Calls unwrapped function with retry
Wrong: await openai_embed(...) # Would cause double EmbeddingFunc wrapping
Double decoration causes:
- Double injection of embedding_dim parameter
- Incorrect parameter passing to the underlying implementation
- Runtime errors due to parameter conflicts
The call chain with correct implementation:
azure_openai_embed(texts)
EmbeddingFunc.__call__(texts) # azure's decorator
azure_openai_embed_impl(texts, embedding_dim=1536)
openai_embed.func(texts, ...)
@retry_wrapper(texts, ...) # openai's retry (only one layer)
openai_embed_impl(texts, ...)
actual embedding computation
"""
# Handle Azure-specific environment variables and parameters
deployment = (
os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
or model
or os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
)
base_url = (
base_url
or os.getenv("AZURE_EMBEDDING_ENDPOINT")
or os.getenv("EMBEDDING_BINDING_HOST")
)
api_key = (
api_key
or os.getenv("AZURE_EMBEDDING_API_KEY")
or os.getenv("EMBEDDING_BINDING_API_KEY")
)
api_version = (
api_version
or os.getenv("AZURE_EMBEDDING_API_VERSION")
or os.getenv("OPENAI_API_VERSION")
)
# CRITICAL: Call openai_embed.func (unwrapped) to avoid double decoration
# openai_embed is an EmbeddingFunc instance, .func accesses the underlying function
return await openai_embed.func(
texts=texts,
model=model or deployment,
base_url=base_url,
api_key=api_key,
token_tracker=token_tracker,
client_configs=client_configs,
use_azure=True,
azure_deployment=deployment,
api_version=api_version,
)