cherry-pick 46ce6d9a
parent 7fa455ff07
commit 5a3c0c1499
1 changed file with 1 addition and 55 deletions
@@ -30,57 +30,6 @@ from lightrag.utils import (
 import numpy as np
 
 
-def _normalize_openai_kwargs_for_model(model: str, kwargs: dict) -> None:
-    """
-    Normalize OpenAI API parameters based on the model being used.
-
-    This function handles model-specific parameter requirements:
-    - gpt-5-nano uses 'max_completion_tokens' instead of 'max_tokens'
-    - gpt-5-nano uses reasoning tokens which consume from the token budget
-    - gpt-5-nano doesn't support custom temperature values
-    - Other models support both parameters
-
-    Args:
-        model: The model name (e.g., 'gpt-5-nano', 'gpt-4o', 'gpt-4o-mini')
-        kwargs: The API parameters dict to normalize (modified in-place)
-    """
-    # Handle max_tokens vs max_completion_tokens conversion for gpt-5 models
-    if model.startswith("gpt-5"):
-        # gpt-5-nano and variants use max_completion_tokens
-        if "max_tokens" in kwargs and "max_completion_tokens" not in kwargs:
-            # If only max_tokens is set, move it to max_completion_tokens
-            max_tokens = kwargs.pop("max_tokens")
-            # For gpt-5-nano, we need to account for reasoning tokens
-            # Increase buffer to ensure actual content is generated
-            # Reasoning typically uses 1.5-2x the actual content tokens needed
-            kwargs["max_completion_tokens"] = int(max(max_tokens * 2.5, 300))
-        else:
-            # If both are set, remove max_tokens (it's not supported)
-            max_tokens = kwargs.pop("max_tokens", None)
-            if max_tokens and "max_completion_tokens" in kwargs:
-                # If max_completion_tokens is already set and seems too small, increase it
-                if kwargs["max_completion_tokens"] < 300:
-                    kwargs["max_completion_tokens"] = int(max(kwargs["max_completion_tokens"] * 2.5, 300))
-
-        # Ensure a minimum token budget for gpt-5-nano due to reasoning overhead
-        if "max_completion_tokens" in kwargs:
-            if kwargs["max_completion_tokens"] < 300:
-                # Minimum 300 tokens to account for reasoning (reasoning can be expensive)
-                original = kwargs["max_completion_tokens"]
-                kwargs["max_completion_tokens"] = 300
-                logger.debug(f"Increased max_completion_tokens from {original} to 300 for {model} (reasoning overhead)")
-
-    # Handle temperature constraint for gpt-5 models
-    if model.startswith("gpt-5"):
-        # gpt-5-nano requires default temperature (doesn't support custom values)
-        # Remove any custom temperature setting
-        if "temperature" in kwargs:
-            kwargs.pop("temperature")
-            logger.debug(f"Removed custom temperature for {model}: uses default")
-
-    logger.debug(f"Normalized parameters for {model}: {kwargs}")
-
-
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=10),
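Before this cherry-pick, the helper above rewrote the request kwargs in place. A minimal sketch of its effect for a gpt-5 model (illustrative only, not part of the diff; the values follow from the int(max(max_tokens * 2.5, 300)) and temperature rules in the deleted code):

# Hypothetical call to the now-deleted helper, prior to this commit.
kwargs = {"max_tokens": 100, "temperature": 0.7}
_normalize_openai_kwargs_for_model("gpt-5-nano", kwargs)
# "max_tokens" is converted and padded: int(max(100 * 2.5, 300)) == 300,
# and the custom temperature is dropped for gpt-5 models.
assert kwargs == {"max_completion_tokens": 300}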
@@ -135,9 +84,6 @@ async def azure_openai_complete_if_cache(
     if prompt is not None:
         messages.append({"role": "user", "content": prompt})
 
-    # Normalize API parameters based on model requirements
-    _normalize_openai_kwargs_for_model(model, kwargs)
-
     if "response_format" in kwargs:
         response = await openai_async_client.beta.chat.completions.parse(
             model=model, messages=messages, **kwargs
@@ -226,6 +172,6 @@ async def azure_openai_embed(
     )
 
     response = await openai_async_client.embeddings.create(
-        model=model, input=texts, encoding_format="float"
+        model=model or deployment, input=texts, encoding_format="float"
     )
     return np.array([dp.embedding for dp in response.data])
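The single surviving addition falls back to the Azure deployment name when no model is given; with the Azure OpenAI client, the model argument carries the deployment name, so an unset model would otherwise fail the request. A trivial sketch of the fallback (the deployment value here is an assumed example, not from the diff):

# If the caller omits the model, the configured Azure deployment is used instead.
model = None
deployment = "text-embedding-3-large"  # assumed example deployment name
assert (model or deployment) == "text-embedding-3-large"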