This commit is contained in:
Raphaël MANSUY 2025-12-04 19:19:21 +08:00
parent 79698f6fae
commit 621621786a
2 changed files with 69 additions and 248 deletions

View file

@ -26,6 +26,7 @@ from lightrag.utils import (
safe_unicode_decode,
logger,
)
from lightrag.types import GPTKeywordExtractionFormat
import numpy as np
@ -46,6 +47,7 @@ async def azure_openai_complete_if_cache(
base_url: str | None = None,
api_key: str | None = None,
api_version: str | None = None,
keyword_extraction: bool = False,
**kwargs,
):
if enable_cot:
@ -66,9 +68,12 @@ async def azure_openai_complete_if_cache(
)
kwargs.pop("hashing_kv", None)
kwargs.pop("keyword_extraction", None)
timeout = kwargs.pop("timeout", None)
# Handle keyword extraction mode
if keyword_extraction:
kwargs["response_format"] = GPTKeywordExtractionFormat
openai_async_client = AsyncAzureOpenAI(
azure_endpoint=base_url,
azure_deployment=deployment,
@ -117,12 +122,12 @@ async def azure_openai_complete_if_cache(
async def azure_openai_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
kwargs.pop("keyword_extraction", None)
result = await azure_openai_complete_if_cache(
os.getenv("LLM_MODEL", "gpt-4o-mini"),
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
keyword_extraction=keyword_extraction,
**kwargs,
)
return result

View file

@ -77,73 +77,46 @@ class InvalidResponseError(Exception):
def create_openai_async_client(
api_key: str | None = None,
base_url: str | None = None,
use_azure: bool = False,
azure_deployment: str | None = None,
api_version: str | None = None,
timeout: int | None = None,
client_configs: dict[str, Any] | None = None,
) -> AsyncOpenAI:
"""Create an AsyncOpenAI or AsyncAzureOpenAI client with the given configuration.
"""Create an AsyncOpenAI client with the given configuration.
Args:
api_key: OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
base_url: Base URL for the OpenAI API. If None, uses the default OpenAI API URL.
use_azure: Whether to create an Azure OpenAI client. Default is False.
azure_deployment: Azure OpenAI deployment name (only used when use_azure=True).
api_version: Azure OpenAI API version (only used when use_azure=True).
timeout: Request timeout in seconds.
client_configs: Additional configuration options for the AsyncOpenAI client.
These will override any default configurations but will be overridden by
explicit parameters (api_key, base_url).
Returns:
An AsyncOpenAI or AsyncAzureOpenAI client instance.
An AsyncOpenAI client instance.
"""
if use_azure:
from openai import AsyncAzureOpenAI
if not api_key:
api_key = os.environ["OPENAI_API_KEY"]
if not api_key:
api_key = os.environ.get("AZURE_OPENAI_API_KEY") or os.environ.get(
"LLM_BINDING_API_KEY"
)
default_headers = {
"User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
"Content-Type": "application/json",
}
return AsyncAzureOpenAI(
azure_endpoint=base_url,
azure_deployment=azure_deployment,
api_key=api_key,
api_version=api_version,
timeout=timeout,
)
if client_configs is None:
client_configs = {}
# Create a merged config dict with precedence: explicit params > client_configs > defaults
merged_configs = {
**client_configs,
"default_headers": default_headers,
"api_key": api_key,
}
if base_url is not None:
merged_configs["base_url"] = base_url
else:
if not api_key:
api_key = os.environ["OPENAI_API_KEY"]
merged_configs["base_url"] = os.environ.get(
"OPENAI_API_BASE", "https://api.openai.com/v1"
)
default_headers = {
"User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
"Content-Type": "application/json",
}
if client_configs is None:
client_configs = {}
# Create a merged config dict with precedence: explicit params > client_configs > defaults
merged_configs = {
**client_configs,
"default_headers": default_headers,
"api_key": api_key,
}
if base_url is not None:
merged_configs["base_url"] = base_url
else:
merged_configs["base_url"] = os.environ.get(
"OPENAI_API_BASE", "https://api.openai.com/v1"
)
if timeout is not None:
merged_configs["timeout"] = timeout
return AsyncOpenAI(**merged_configs)
return AsyncOpenAI(**merged_configs)
@retry(
@ -168,9 +141,6 @@ async def openai_complete_if_cache(
stream: bool | None = None,
timeout: int | None = None,
keyword_extraction: bool = False,
use_azure: bool = False,
azure_deployment: str | None = None,
api_version: str | None = None,
**kwargs: Any,
) -> str:
"""Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration.
@ -237,14 +207,10 @@ async def openai_complete_if_cache(
if keyword_extraction:
kwargs["response_format"] = GPTKeywordExtractionFormat
# Create the OpenAI client (supports both OpenAI and Azure)
# Create the OpenAI client
openai_async_client = create_openai_async_client(
api_key=api_key,
base_url=base_url,
use_azure=use_azure,
azure_deployment=azure_deployment,
api_version=api_version,
timeout=timeout,
client_configs=client_configs,
)
@ -275,7 +241,7 @@ async def openai_complete_if_cache(
try:
# Don't use async with context manager, use client directly
if "response_format" in kwargs:
response = await openai_async_client.chat.completions.parse(
response = await openai_async_client.beta.chat.completions.parse(
model=model, messages=messages, **kwargs
)
else:
@ -487,57 +453,46 @@ async def openai_complete_if_cache(
raise InvalidResponseError("Invalid response from OpenAI API")
message = response.choices[0].message
content = getattr(message, "content", None)
reasoning_content = getattr(message, "reasoning_content", "")
# Handle parsed responses (structured output via response_format)
# When using beta.chat.completions.parse(), the response is in message.parsed
if hasattr(message, "parsed") and message.parsed is not None:
# Serialize the parsed structured response to JSON
final_content = message.parsed.model_dump_json()
logger.debug("Using parsed structured response from API")
else:
# Handle regular content responses
content = getattr(message, "content", None)
reasoning_content = getattr(message, "reasoning_content", "")
# Handle COT logic for non-streaming responses (only if enabled)
final_content = ""
# Handle COT logic for non-streaming responses (only if enabled)
final_content = ""
if enable_cot:
# Check if we should include reasoning content
should_include_reasoning = False
if reasoning_content and reasoning_content.strip():
if not content or content.strip() == "":
# Case 1: Only reasoning content, should include COT
should_include_reasoning = True
final_content = (
content or ""
) # Use empty string if content is None
else:
# Case 3: Both content and reasoning_content present, ignore reasoning
should_include_reasoning = False
final_content = content
else:
# No reasoning content, use regular content
final_content = content or ""
# Apply COT wrapping if needed
if should_include_reasoning:
if r"\u" in reasoning_content:
reasoning_content = safe_unicode_decode(
reasoning_content.encode("utf-8")
)
if enable_cot:
# Check if we should include reasoning content
should_include_reasoning = False
if reasoning_content and reasoning_content.strip():
if not content or content.strip() == "":
# Case 1: Only reasoning content, should include COT
should_include_reasoning = True
final_content = (
f"<think>{reasoning_content}</think>{final_content}"
)
content or ""
) # Use empty string if content is None
else:
# Case 3: Both content and reasoning_content present, ignore reasoning
should_include_reasoning = False
final_content = content
else:
# COT disabled, only use regular content
# No reasoning content, use regular content
final_content = content or ""
# Validate final content
if not final_content or final_content.strip() == "":
logger.error("Received empty content from OpenAI API")
await openai_async_client.close() # Ensure client is closed
raise InvalidResponseError("Received empty content from OpenAI API")
# Apply COT wrapping if needed
if should_include_reasoning:
if r"\u" in reasoning_content:
reasoning_content = safe_unicode_decode(
reasoning_content.encode("utf-8")
)
final_content = f"<think>{reasoning_content}</think>{final_content}"
else:
# COT disabled, only use regular content
final_content = content or ""
# Validate final content
if not final_content or final_content.strip() == "":
logger.error("Received empty content from OpenAI API")
await openai_async_client.close() # Ensure client is closed
raise InvalidResponseError("Received empty content from OpenAI API")
# Apply Unicode decoding to final content if needed
if r"\u" in final_content:
@ -665,9 +620,6 @@ async def openai_embed(
embedding_dim: int | None = None,
client_configs: dict[str, Any] | None = None,
token_tracker: Any | None = None,
use_azure: bool = False,
azure_deployment: str | None = None,
api_version: str | None = None,
) -> np.ndarray:
"""Generate embeddings for a list of texts using OpenAI's API.
@ -695,14 +647,9 @@ async def openai_embed(
RateLimitError: If the OpenAI API rate limit is exceeded.
APITimeoutError: If the OpenAI API request times out.
"""
# Create the OpenAI client (supports both OpenAI and Azure)
# Create the OpenAI client
openai_async_client = create_openai_async_client(
api_key=api_key,
base_url=base_url,
use_azure=use_azure,
azure_deployment=azure_deployment,
api_version=api_version,
client_configs=client_configs,
api_key=api_key, base_url=base_url, client_configs=client_configs
)
async with openai_async_client:
@ -735,134 +682,3 @@ async def openai_embed(
for dp in response.data
]
)
# Azure OpenAI wrapper functions for backward compatibility
async def azure_openai_complete_if_cache(
model,
prompt,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
enable_cot: bool = False,
base_url: str | None = None,
api_key: str | None = None,
api_version: str | None = None,
keyword_extraction: bool = False,
**kwargs,
):
"""Azure OpenAI completion wrapper function.
This function provides backward compatibility by wrapping the unified
openai_complete_if_cache implementation with Azure-specific parameter handling.
"""
# Handle Azure-specific environment variables and parameters
deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL")
base_url = (
base_url or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("LLM_BINDING_HOST")
)
api_key = (
api_key or os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("LLM_BINDING_API_KEY")
)
api_version = (
api_version
or os.getenv("AZURE_OPENAI_API_VERSION")
or os.getenv("OPENAI_API_VERSION")
)
# Pop timeout from kwargs if present (will be handled by openai_complete_if_cache)
timeout = kwargs.pop("timeout", None)
# Call the unified implementation with Azure-specific parameters
return await openai_complete_if_cache(
model=model,
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
base_url=base_url,
api_key=api_key,
timeout=timeout,
use_azure=True,
azure_deployment=deployment,
api_version=api_version,
keyword_extraction=keyword_extraction,
**kwargs,
)
async def azure_openai_complete(
prompt,
system_prompt=None,
history_messages=None,
keyword_extraction=False,
**kwargs,
) -> str:
"""Azure OpenAI complete wrapper function.
Provides backward compatibility for azure_openai_complete calls.
"""
if history_messages is None:
history_messages = []
result = await azure_openai_complete_if_cache(
os.getenv("LLM_MODEL", "gpt-4o-mini"),
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
keyword_extraction=keyword_extraction,
**kwargs,
)
return result
@wrap_embedding_func_with_attrs(embedding_dim=1536)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def azure_openai_embed(
texts: list[str],
model: str | None = None,
base_url: str | None = None,
api_key: str | None = None,
api_version: str | None = None,
) -> np.ndarray:
"""Azure OpenAI embedding wrapper function.
This function provides backward compatibility by wrapping the unified
openai_embed implementation with Azure-specific parameter handling.
"""
# Handle Azure-specific environment variables and parameters
deployment = (
os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
or model
or os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
)
base_url = (
base_url
or os.getenv("AZURE_EMBEDDING_ENDPOINT")
or os.getenv("EMBEDDING_BINDING_HOST")
)
api_key = (
api_key
or os.getenv("AZURE_EMBEDDING_API_KEY")
or os.getenv("EMBEDDING_BINDING_API_KEY")
)
api_version = (
api_version
or os.getenv("AZURE_EMBEDDING_API_VERSION")
or os.getenv("OPENAI_API_VERSION")
)
# Call the unified implementation with Azure-specific parameters
return await openai_embed(
texts=texts,
model=model or deployment,
base_url=base_url,
api_key=api_key,
use_azure=True,
azure_deployment=deployment,
api_version=api_version,
)