This commit is contained in:
Raphaël MANSUY 2025-12-04 19:14:32 +08:00
parent 65d26c6476
commit fd486c287a
2 changed files with 52 additions and 298 deletions

View file

@ -1,193 +1,22 @@
from collections.abc import Iterable
import os
import pipmaster as pm # Pipmaster for dynamic library install
"""
Azure OpenAI compatibility layer.
# install specific modules
if not pm.is_installed("openai"):
pm.install("openai")
This module provides backward compatibility by re-exporting Azure OpenAI functions
from the main openai module where the actual implementation resides.
from openai import (
AsyncAzureOpenAI,
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from openai.types.chat import ChatCompletionMessageParam
All core logic for both OpenAI and Azure OpenAI now lives in lightrag.llm.openai,
with this module serving as a thin compatibility wrapper for existing code that
imports from lightrag.llm.azure_openai.
"""
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
from lightrag.llm.openai import (
azure_openai_complete_if_cache,
azure_openai_complete,
azure_openai_embed,
)
from lightrag.utils import (
wrap_embedding_func_with_attrs,
safe_unicode_decode,
logger,
)
from lightrag.types import GPTKeywordExtractionFormat
import numpy as np
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APIConnectionError)
),
)
async def azure_openai_complete_if_cache(
model,
prompt,
system_prompt: str | None = None,
history_messages: Iterable[ChatCompletionMessageParam] | None = None,
enable_cot: bool = False,
base_url: str | None = None,
api_key: str | None = None,
api_version: str | None = None,
keyword_extraction: bool = False,
**kwargs,
):
if enable_cot:
logger.debug(
"enable_cot=True is not supported for the Azure OpenAI API and will be ignored."
)
deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL")
base_url = (
base_url or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("LLM_BINDING_HOST")
)
api_key = (
api_key or os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("LLM_BINDING_API_KEY")
)
api_version = (
api_version
or os.getenv("AZURE_OPENAI_API_VERSION")
or os.getenv("OPENAI_API_VERSION")
)
kwargs.pop("hashing_kv", None)
timeout = kwargs.pop("timeout", None)
# Handle keyword extraction mode
if keyword_extraction:
kwargs["response_format"] = GPTKeywordExtractionFormat
openai_async_client = AsyncAzureOpenAI(
azure_endpoint=base_url,
azure_deployment=deployment,
api_key=api_key,
api_version=api_version,
timeout=timeout,
)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
if history_messages:
messages.extend(history_messages)
if prompt is not None:
messages.append({"role": "user", "content": prompt})
if "response_format" in kwargs:
response = await openai_async_client.beta.chat.completions.parse(
model=model, messages=messages, **kwargs
)
else:
response = await openai_async_client.chat.completions.create(
model=model, messages=messages, **kwargs
)
if hasattr(response, "__aiter__"):
async def inner():
async for chunk in response:
if len(chunk.choices) == 0:
continue
content = chunk.choices[0].delta.content
if content is None:
continue
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
yield content
return inner()
else:
message = response.choices[0].message
# Handle parsed responses (structured output via response_format)
# When using beta.chat.completions.parse(), the response is in message.parsed
if hasattr(message, "parsed") and message.parsed is not None:
# Serialize the parsed structured response to JSON
content = message.parsed.model_dump_json()
logger.debug("Using parsed structured response from API")
else:
# Handle regular content responses
content = message.content
if content and r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
return content
async def azure_openai_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
result = await azure_openai_complete_if_cache(
os.getenv("LLM_MODEL", "gpt-4o-mini"),
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
keyword_extraction=keyword_extraction,
**kwargs,
)
return result
@wrap_embedding_func_with_attrs(embedding_dim=1536)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def azure_openai_embed(
texts: list[str],
model: str | None = None,
base_url: str | None = None,
api_key: str | None = None,
api_version: str | None = None,
) -> np.ndarray:
deployment = (
os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
or model
or os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
)
base_url = (
base_url
or os.getenv("AZURE_EMBEDDING_ENDPOINT")
or os.getenv("EMBEDDING_BINDING_HOST")
)
api_key = (
api_key
or os.getenv("AZURE_EMBEDDING_API_KEY")
or os.getenv("EMBEDDING_BINDING_API_KEY")
)
api_version = (
api_version
or os.getenv("AZURE_EMBEDDING_API_VERSION")
or os.getenv("OPENAI_API_VERSION")
)
openai_async_client = AsyncAzureOpenAI(
azure_endpoint=base_url,
azure_deployment=deployment,
api_key=api_key,
api_version=api_version,
)
response = await openai_async_client.embeddings.create(
model=model or deployment, input=texts, encoding_format="float"
)
return np.array([dp.embedding for dp in response.data])
__all__ = [
"azure_openai_complete_if_cache",
"azure_openai_complete",
"azure_openai_embed",
]

View file

@ -107,26 +107,13 @@ def create_openai_async_client(
"LLM_BINDING_API_KEY"
)
if client_configs is None:
client_configs = {}
# Create a merged config dict with precedence: explicit params > client_configs
merged_configs = {
**client_configs,
"api_key": api_key,
}
# Add explicit parameters (override client_configs)
if base_url is not None:
merged_configs["azure_endpoint"] = base_url
if azure_deployment is not None:
merged_configs["azure_deployment"] = azure_deployment
if api_version is not None:
merged_configs["api_version"] = api_version
if timeout is not None:
merged_configs["timeout"] = timeout
return AsyncAzureOpenAI(**merged_configs)
return AsyncAzureOpenAI(
azure_endpoint=base_url,
azure_deployment=azure_deployment,
api_key=api_key,
api_version=api_version,
timeout=timeout,
)
else:
if not api_key:
api_key = os.environ["OPENAI_API_KEY"]
@ -205,33 +192,23 @@ async def openai_complete_if_cache(
6. For non-streaming: COT content is prepended to regular content with <think> tags.
Args:
model: The OpenAI model to use. For Azure, this can be the deployment name.
model: The OpenAI model to use.
prompt: The prompt to complete.
system_prompt: Optional system prompt to include.
history_messages: Optional list of previous messages in the conversation.
enable_cot: Whether to enable Chain of Thought (COT) processing. Default is False.
base_url: Optional base URL for the OpenAI API. For Azure, this should be the
Azure OpenAI endpoint (e.g., https://your-resource.openai.azure.com/).
api_key: Optional API key. For standard OpenAI, uses OPENAI_API_KEY environment
variable if None. For Azure, uses AZURE_OPENAI_API_KEY if None.
base_url: Optional base URL for the OpenAI API.
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
token_tracker: Optional token usage tracker for monitoring API usage.
enable_cot: Whether to enable Chain of Thought (COT) processing. Default is False.
stream: Whether to stream the response. Default is False.
timeout: Request timeout in seconds. Default is None.
keyword_extraction: Whether to enable keyword extraction mode. When True, triggers
special response formatting for keyword extraction. Default is False.
use_azure: Whether to use Azure OpenAI service instead of standard OpenAI.
When True, creates an AsyncAzureOpenAI client. Default is False.
azure_deployment: Azure OpenAI deployment name. Only used when use_azure=True.
If not specified, falls back to AZURE_OPENAI_DEPLOYMENT environment variable.
api_version: Azure OpenAI API version (e.g., "2024-02-15-preview"). Only used
when use_azure=True. If not specified, falls back to AZURE_OPENAI_API_VERSION
environment variable.
**kwargs: Additional keyword arguments to pass to the OpenAI API.
Special kwargs:
- openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
These will be passed to the client constructor but will be overridden by
explicit parameters (api_key, base_url). Supports proxy configuration,
custom headers, retry policies, etc.
explicit parameters (api_key, base_url).
Returns:
The completed text (with integrated COT content if available) or an async iterator
@ -694,34 +671,21 @@ async def openai_embed(
) -> np.ndarray:
"""Generate embeddings for a list of texts using OpenAI's API.
This function supports both standard OpenAI and Azure OpenAI services.
Args:
texts: List of texts to embed.
model: The embedding model to use. For standard OpenAI (e.g., "text-embedding-3-small").
For Azure, this can be the deployment name.
base_url: Optional base URL for the API. For standard OpenAI, uses default OpenAI endpoint.
For Azure, this should be the Azure OpenAI endpoint (e.g., https://your-resource.openai.azure.com/).
api_key: Optional API key. For standard OpenAI, uses OPENAI_API_KEY environment variable if None.
For Azure, uses AZURE_EMBEDDING_API_KEY environment variable if None.
model: The OpenAI embedding model to use.
base_url: Optional base URL for the OpenAI API.
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
embedding_dim: Optional embedding dimension for dynamic dimension reduction.
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
Do NOT manually pass this parameter when calling the function directly.
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
Manually passing a different value will trigger a warning and be ignored.
When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction.
client_configs: Additional configuration options for the AsyncOpenAI/AsyncAzureOpenAI client.
client_configs: Additional configuration options for the AsyncOpenAI client.
These will override any default configurations but will be overridden by
explicit parameters (api_key, base_url). Supports proxy configuration,
custom headers, retry policies, etc.
explicit parameters (api_key, base_url).
token_tracker: Optional token usage tracker for monitoring API usage.
use_azure: Whether to use Azure OpenAI service instead of standard OpenAI.
When True, creates an AsyncAzureOpenAI client. Default is False.
azure_deployment: Azure OpenAI deployment name. Only used when use_azure=True.
If not specified, falls back to AZURE_EMBEDDING_DEPLOYMENT environment variable.
api_version: Azure OpenAI API version (e.g., "2024-02-15-preview"). Only used
when use_azure=True. If not specified, falls back to AZURE_EMBEDDING_API_VERSION
environment variable.
Returns:
A numpy array of embeddings, one per input text.
@ -782,20 +746,14 @@ async def azure_openai_complete_if_cache(
enable_cot: bool = False,
base_url: str | None = None,
api_key: str | None = None,
token_tracker: Any | None = None,
stream: bool | None = None,
timeout: int | None = None,
api_version: str | None = None,
keyword_extraction: bool = False,
**kwargs,
):
"""Azure OpenAI completion wrapper function.
This function provides backward compatibility by wrapping the unified
openai_complete_if_cache implementation with Azure-specific parameter handling.
All parameters from the underlying openai_complete_if_cache are exposed to ensure
full feature parity and API consistency.
"""
# Handle Azure-specific environment variables and parameters
deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL")
@ -810,7 +768,10 @@ async def azure_openai_complete_if_cache(
or os.getenv("AZURE_OPENAI_API_VERSION")
or os.getenv("OPENAI_API_VERSION")
)
# Pop timeout from kwargs if present (will be handled by openai_complete_if_cache)
timeout = kwargs.pop("timeout", None)
# Call the unified implementation with Azure-specific parameters
return await openai_complete_if_cache(
model=model,
@ -820,8 +781,6 @@ async def azure_openai_complete_if_cache(
enable_cot=enable_cot,
base_url=base_url,
api_key=api_key,
token_tracker=token_tracker,
stream=stream,
timeout=timeout,
use_azure=True,
azure_deployment=deployment,
@ -832,14 +791,10 @@ async def azure_openai_complete_if_cache(
async def azure_openai_complete(
prompt,
system_prompt=None,
history_messages=None,
keyword_extraction=False,
**kwargs,
prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs
) -> str:
"""Azure OpenAI complete wrapper function.
Provides backward compatibility for azure_openai_complete calls.
"""
if history_messages is None:
@ -856,51 +811,24 @@ async def azure_openai_complete(
@wrap_embedding_func_with_attrs(embedding_dim=1536)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def azure_openai_embed(
texts: list[str],
model: str | None = None,
base_url: str | None = None,
api_key: str | None = None,
token_tracker: Any | None = None,
client_configs: dict[str, Any] | None = None,
api_version: str | None = None,
) -> np.ndarray:
"""Azure OpenAI embedding wrapper function.
This function provides backward compatibility by wrapping the unified
openai_embed implementation with Azure-specific parameter handling.
All parameters from the underlying openai_embed are exposed to ensure
full feature parity and API consistency.
IMPORTANT - Decorator Usage:
1. This function is decorated with @wrap_embedding_func_with_attrs to provide
the EmbeddingFunc interface for users who need to access embedding_dim
and other attributes.
2. This function does NOT use @retry decorator to avoid double-wrapping,
since the underlying openai_embed.func already has retry logic.
3. This function calls openai_embed.func (the unwrapped function) instead of
openai_embed (the EmbeddingFunc instance) to avoid double decoration issues:
Correct: await openai_embed.func(...) # Calls unwrapped function with retry
Wrong: await openai_embed(...) # Would cause double EmbeddingFunc wrapping
Double decoration causes:
- Double injection of embedding_dim parameter
- Incorrect parameter passing to the underlying implementation
- Runtime errors due to parameter conflicts
The call chain with correct implementation:
azure_openai_embed(texts)
EmbeddingFunc.__call__(texts) # azure's decorator
azure_openai_embed_impl(texts, embedding_dim=1536)
openai_embed.func(texts, ...)
@retry_wrapper(texts, ...) # openai's retry (only one layer)
openai_embed_impl(texts, ...)
actual embedding computation
"""
# Handle Azure-specific environment variables and parameters
deployment = (
@ -923,16 +851,13 @@ async def azure_openai_embed(
or os.getenv("AZURE_EMBEDDING_API_VERSION")
or os.getenv("OPENAI_API_VERSION")
)
# CRITICAL: Call openai_embed.func (unwrapped) to avoid double decoration
# openai_embed is an EmbeddingFunc instance, .func accesses the underlying function
return await openai_embed.func(
# Call the unified implementation with Azure-specific parameters
return await openai_embed(
texts=texts,
model=model or deployment,
base_url=base_url,
api_key=api_key,
token_tracker=token_tracker,
client_configs=client_configs,
use_azure=True,
azure_deployment=deployment,
api_version=api_version,