From f9e3ae1f909d896a1f05a44f2f4b7ae7140663fd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rapha=C3=ABl=20MANSUY?=
Date: Thu, 4 Dec 2025 19:19:05 +0800
Subject: [PATCH] cherry-pick 30e86fa3

---
 lightrag/llm/azure_openai.py | 189 +++++++++++++++++++++++++++++++----
 1 file changed, 172 insertions(+), 17 deletions(-)

diff --git a/lightrag/llm/azure_openai.py b/lightrag/llm/azure_openai.py
index 1fc6feef..c72826f8 100644
--- a/lightrag/llm/azure_openai.py
+++ b/lightrag/llm/azure_openai.py
@@ -1,22 +1,177 @@
-"""
-Azure OpenAI compatibility layer.
+from collections.abc import Iterable
+import os
+import pipmaster as pm  # Pipmaster for dynamic library install
 
-This module provides backward compatibility by re-exporting Azure OpenAI functions
-from the main openai module where the actual implementation resides.
+# install specific modules
+if not pm.is_installed("openai"):
+    pm.install("openai")
 
-All core logic for both OpenAI and Azure OpenAI now lives in lightrag.llm.openai,
-with this module serving as a thin compatibility wrapper for existing code that
-imports from lightrag.llm.azure_openai.
-"""
+from openai import (
+    AsyncAzureOpenAI,
+    APIConnectionError,
+    RateLimitError,
+    APITimeoutError,
+)
+from openai.types.chat import ChatCompletionMessageParam
 
-from lightrag.llm.openai import (
-    azure_openai_complete_if_cache,
-    azure_openai_complete,
-    azure_openai_embed,
+from tenacity import (
+    retry,
+    stop_after_attempt,
+    wait_exponential,
+    retry_if_exception_type,
 )
 
-__all__ = [
-    "azure_openai_complete_if_cache",
-    "azure_openai_complete",
-    "azure_openai_embed",
-]
+from lightrag.utils import (
+    wrap_embedding_func_with_attrs,
+    safe_unicode_decode,
+    logger,
+)
+
+import numpy as np
+
+
+@retry(
+    stop=stop_after_attempt(3),
+    wait=wait_exponential(multiplier=1, min=4, max=10),
+    retry=retry_if_exception_type(
+        (RateLimitError, APIConnectionError, APITimeoutError)
+    ),
+)
+async def azure_openai_complete_if_cache(
+    model,
+    prompt,
+    system_prompt: str | None = None,
+    history_messages: Iterable[ChatCompletionMessageParam] | None = None,
+    enable_cot: bool = False,
+    base_url: str | None = None,
+    api_key: str | None = None,
+    api_version: str | None = None,
+    **kwargs,
+):
+    if enable_cot:
+        logger.debug(
+            "enable_cot=True is not supported for the Azure OpenAI API and will be ignored."
+        )
+    deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL")
+    base_url = (
+        base_url or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("LLM_BINDING_HOST")
+    )
+    api_key = (
+        api_key or os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("LLM_BINDING_API_KEY")
+    )
+    api_version = (
+        api_version
+        or os.getenv("AZURE_OPENAI_API_VERSION")
+        or os.getenv("OPENAI_API_VERSION")
+    )
+
+    kwargs.pop("hashing_kv", None)
+    kwargs.pop("keyword_extraction", None)
+    timeout = kwargs.pop("timeout", None)
+
+    openai_async_client = AsyncAzureOpenAI(
+        azure_endpoint=base_url,
+        azure_deployment=deployment,
+        api_key=api_key,
+        api_version=api_version,
+        timeout=timeout,
+    )
+    messages = []
+    if system_prompt:
+        messages.append({"role": "system", "content": system_prompt})
+    if history_messages:
+        messages.extend(history_messages)
+    if prompt is not None:
+        messages.append({"role": "user", "content": prompt})
+
+    if "response_format" in kwargs:
+        response = await openai_async_client.beta.chat.completions.parse(
+            model=model, messages=messages, **kwargs
+        )
+    else:
+        response = await openai_async_client.chat.completions.create(
+            model=model, messages=messages, **kwargs
+        )
+
+    if hasattr(response, "__aiter__"):
+
+        async def inner():
+            async for chunk in response:
+                if len(chunk.choices) == 0:
+                    continue
+                content = chunk.choices[0].delta.content
+                if content is None:
+                    continue
+                if r"\u" in content:
+                    content = safe_unicode_decode(content.encode("utf-8"))
+                yield content
+
+        return inner()
+    else:
+        content = response.choices[0].message.content
+        if content and r"\u" in content:
+            content = safe_unicode_decode(content.encode("utf-8"))
+        return content
+
+
+async def azure_openai_complete(
+    prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs
+) -> str:
+    kwargs.pop("keyword_extraction", None)
+    result = await azure_openai_complete_if_cache(
+        os.getenv("LLM_MODEL", "gpt-4o-mini"),
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        **kwargs,
+    )
+    return result
+
+
+@wrap_embedding_func_with_attrs(embedding_dim=1536)
+@retry(
+    stop=stop_after_attempt(3),
+    wait=wait_exponential(multiplier=1, min=4, max=10),
+    retry=retry_if_exception_type(
+        (RateLimitError, APIConnectionError, APITimeoutError)
+    ),
+)
+async def azure_openai_embed(
+    texts: list[str],
+    model: str | None = None,
+    base_url: str | None = None,
+    api_key: str | None = None,
+    api_version: str | None = None,
+) -> np.ndarray:
+    deployment = (
+        os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
+        or model
+        or os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
+    )
+    base_url = (
+        base_url
+        or os.getenv("AZURE_EMBEDDING_ENDPOINT")
+        or os.getenv("EMBEDDING_BINDING_HOST")
+    )
+    api_key = (
+        api_key
+        or os.getenv("AZURE_EMBEDDING_API_KEY")
+        or os.getenv("EMBEDDING_BINDING_API_KEY")
+    )
+    api_version = (
+        api_version
+        or os.getenv("AZURE_EMBEDDING_API_VERSION")
+        or os.getenv("OPENAI_API_VERSION")
+    )
+
+    openai_async_client = AsyncAzureOpenAI(
+        azure_endpoint=base_url,
+        azure_deployment=deployment,
+        api_key=api_key,
+        api_version=api_version,
+    )
+
+    response = await openai_async_client.embeddings.create(
+        model=deployment, input=texts, encoding_format="float"
+    )
+    return np.array([dp.embedding for dp in response.data])
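Reviewer note: a minimal usage sketch of the two entry points this patch adds. The
endpoint, key, API version, and deployment names below are placeholder assumptions,
not values from the patch; the environment variable names are the fallbacks the
patched module actually reads. OPENAI_API_VERSION is set because it is the shared
fallback for both the completion and embedding paths.

    import asyncio
    import os

    # Placeholder configuration -- substitute your own Azure resource values (assumed).
    os.environ["AZURE_OPENAI_ENDPOINT"] = "https://my-resource.openai.azure.com"
    os.environ["AZURE_OPENAI_API_KEY"] = "..."
    os.environ["OPENAI_API_VERSION"] = "2024-02-01"
    os.environ["AZURE_OPENAI_DEPLOYMENT"] = "gpt-4o-mini"
    os.environ["AZURE_EMBEDDING_ENDPOINT"] = os.environ["AZURE_OPENAI_ENDPOINT"]
    os.environ["AZURE_EMBEDDING_API_KEY"] = os.environ["AZURE_OPENAI_API_KEY"]
    os.environ["AZURE_EMBEDDING_DEPLOYMENT"] = "text-embedding-3-small"

    from lightrag.llm.azure_openai import azure_openai_complete, azure_openai_embed

    async def main():
        # Non-streaming completion; the model name falls back to
        # AZURE_OPENAI_DEPLOYMENT / LLM_MODEL as in the patched resolution order.
        answer = await azure_openai_complete("Say hello in one word.")
        print(answer)

        # Embeddings come back as an (n_texts, 1536) numpy array,
        # matching the wrap_embedding_func_with_attrs(embedding_dim=1536) decorator.
        vectors = await azure_openai_embed(["hello", "world"])
        print(vectors.shape)

    asyncio.run(main())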