diff --git a/lightrag/llm/gemini.py b/lightrag/llm/gemini.py
index 14a1b238..4cec3e71 100644
--- a/lightrag/llm/gemini.py
+++ b/lightrag/llm/gemini.py
@@ -114,24 +114,44 @@ def _format_history_messages(history_messages: list[dict[str, Any]] | None) -> s
     return "\n".join(history_lines)
 
 
-def _extract_response_text(response: Any) -> str:
-    if getattr(response, "text", None):
-        return response.text
+def _extract_response_text(
+    response: Any, extract_thoughts: bool = False
+) -> tuple[str, str]:
+    """
+    Extract text content from Gemini response, separating regular content from thoughts.
+    Args:
+        response: Gemini API response object
+        extract_thoughts: Whether to extract thought content separately
+
+    Returns:
+        Tuple of (regular_text, thought_text)
+    """
 
     candidates = getattr(response, "candidates", None)
     if not candidates:
-        return ""
+        return ("", "")
+
+    regular_parts: list[str] = []
+    thought_parts: list[str] = []
 
-    parts: list[str] = []
     for candidate in candidates:
         if not getattr(candidate, "content", None):
             continue
-        for part in getattr(candidate.content, "parts", []):
+        # Use 'or []' to handle None values from parts attribute
+        for part in getattr(candidate.content, "parts", None) or []:
             text = getattr(part, "text", None)
-            if text:
-                parts.append(text)
+            if not text:
+                continue
 
-    return "\n".join(parts)
+            # Check if this part is thought content using the 'thought' attribute
+            is_thought = getattr(part, "thought", False)
+
+            if is_thought and extract_thoughts:
+                thought_parts.append(text)
+            elif not is_thought:
+                regular_parts.append(text)
+
+    return ("\n".join(regular_parts), "\n".join(thought_parts))
 
 
 async def gemini_complete_if_cache(
@@ -139,18 +159,51 @@ async def gemini_complete_if_cache(
     prompt: str,
     system_prompt: str | None = None,
     history_messages: list[dict[str, Any]] | None = None,
-    *,
-    api_key: str | None = None,
+    enable_cot: bool = False,
     base_url: str | None = None,
-    generation_config: dict[str, Any] | None = None,
-    keyword_extraction: bool = False,
+    api_key: str | None = None,
     token_tracker: Any | None = None,
-    hashing_kv: Any | None = None,  # noqa: ARG001 - present for interface parity
     stream: bool | None = None,
-    enable_cot: bool = False,  # noqa: ARG001 - not supported by Gemini currently
-    timeout: float | None = None,  # noqa: ARG001 - handled by caller if needed
+    keyword_extraction: bool = False,
+    generation_config: dict[str, Any] | None = None,
     **_: Any,
 ) -> str | AsyncIterator[str]:
+    """
+    Complete a prompt using Gemini's API with Chain of Thought (COT) support.
+
+    This function supports automatic integration of reasoning content from Gemini models
+    that provide Chain of Thought capabilities via the thinking_config API feature.
+
+    COT Integration:
+    - When enable_cot=True: Thought content is wrapped in <think>...</think> tags
+    - When enable_cot=False: Thought content is filtered out, only regular content returned
+    - Thought content is identified by the 'thought' attribute on response parts
+    - Requires thinking_config to be enabled in generation_config for API to return thoughts
+
+    Args:
+        model: The Gemini model to use.
+        prompt: The prompt to complete.
+        system_prompt: Optional system prompt to include.
+        history_messages: Optional list of previous messages in the conversation.
+        api_key: Optional Gemini API key. If None, uses environment variable.
+        base_url: Optional custom API endpoint.
+        generation_config: Optional generation configuration dict.
+        keyword_extraction: Whether to use JSON response format.
+        token_tracker: Optional token usage tracker for monitoring API usage.
+        hashing_kv: Storage interface (for interface parity with other bindings).
+        stream: Whether to stream the response.
+        enable_cot: Whether to include Chain of Thought content in the response.
+        timeout: Request timeout (handled by caller if needed).
+        **_: Additional keyword arguments (ignored).
+
+    Returns:
+        The completed text (with COT content if enable_cot=True) or an async iterator
+        of text chunks if streaming. COT content is wrapped in <think>...</think> tags.
+
+    Raises:
+        RuntimeError: If the response from Gemini is empty.
+        ValueError: If API key is not provided or configured.
+    """
     loop = asyncio.get_running_loop()
     key = _ensure_api_key(api_key)
 
@@ -184,6 +237,11 @@ async def gemini_complete_if_cache(
         usage_container: dict[str, Any] = {}
 
         def _stream_model() -> None:
+            # COT state tracking for streaming
+            cot_active = False
+            cot_started = False
+            initial_content_seen = False
+
             try:
                 stream_kwargs = dict(request_kwargs)
                 stream_iterator = client.models.generate_content_stream(**stream_kwargs)
@@ -191,18 +249,59 @@ async def gemini_complete_if_cache(
                     usage = getattr(chunk, "usage_metadata", None)
                     if usage is not None:
                         usage_container["usage"] = usage
-                    text_piece = getattr(chunk, "text", None) or _extract_response_text(chunk)
-                    if text_piece:
-                        loop.call_soon_threadsafe(queue.put_nowait, text_piece)
+
+                    # Extract both regular and thought content
+                    regular_text, thought_text = _extract_response_text(
+                        chunk, extract_thoughts=True
+                    )
+
+                    if enable_cot:
+                        # Process regular content
+                        if regular_text:
+                            if not initial_content_seen:
+                                initial_content_seen = True
+
+                            # Close COT section if it was active
+                            if cot_active:
+                                loop.call_soon_threadsafe(queue.put_nowait, "</think>")
+                                cot_active = False
+
+                            # Send regular content
+                            loop.call_soon_threadsafe(queue.put_nowait, regular_text)
+
+                        # Process thought content
+                        if thought_text:
+                            if not initial_content_seen and not cot_started:
+                                # Start COT section
+                                loop.call_soon_threadsafe(queue.put_nowait, "<think>")
+                                cot_active = True
+                                cot_started = True
+
+                            # Send thought content if COT is active
+                            if cot_active:
+                                loop.call_soon_threadsafe(queue.put_nowait, thought_text)
+                    else:
+                        # COT disabled - only send regular content
+                        if regular_text:
+                            loop.call_soon_threadsafe(queue.put_nowait, regular_text)
+
+                # Ensure COT is properly closed if still active
+                if cot_active:
+                    loop.call_soon_threadsafe(queue.put_nowait, "</think>")
+
                 loop.call_soon_threadsafe(queue.put_nowait, None)
             except Exception as exc:  # pragma: no cover - surface runtime issues
+                # Try to close COT tag before reporting error
+                if cot_active:
+                    try:
+                        loop.call_soon_threadsafe(queue.put_nowait, "</think>")
+                    except Exception:
+                        pass
                 loop.call_soon_threadsafe(queue.put_nowait, exc)
 
         loop.run_in_executor(None, _stream_model)
 
         async def _async_stream() -> AsyncIterator[str]:
-            accumulated = ""
-            emitted = ""
            try:
                while True:
                    item = await queue.get()
@@ -215,16 +314,9 @@ async def gemini_complete_if_cache(
                    if "\\u" in chunk_text:
                        chunk_text = safe_unicode_decode(chunk_text.encode("utf-8"))
 
-                    accumulated += chunk_text
-                    sanitized = remove_think_tags(accumulated)
-                    if sanitized.startswith(emitted):
-                        delta = sanitized[len(emitted) :]
-                    else:
-                        delta = sanitized
-                    emitted = sanitized
-
-                    if delta:
-                        yield delta
+                    # Yield the chunk directly without filtering
+                    # COT filtering is already handled in _stream_model()
+                    yield chunk_text
            finally:
                usage = usage_container.get("usage")
                if token_tracker and usage:
@@ -242,14 +334,33 @@ async def gemini_complete_if_cache(
 
     response = await asyncio.to_thread(_call_model)
 
-    text = _extract_response_text(response)
-    if not text:
+    # Extract both regular text and thought text
+    regular_text, thought_text = _extract_response_text(response, extract_thoughts=True)
+
+    # Apply COT filtering logic based on enable_cot parameter
+    if enable_cot:
+        # Include thought content wrapped in tags
+        if thought_text and thought_text.strip():
+            if not regular_text or regular_text.strip() == "":
+                # Only thought content available
+                final_text = f"<think>{thought_text}</think>"
+            else:
+                # Both content types present: prepend thought to regular content
+                final_text = f"<think>{thought_text}</think>{regular_text}"
+        else:
+            # No thought content, use regular content only
+            final_text = regular_text or ""
+    else:
+        # Filter out thought content, return only regular content
+        final_text = regular_text or ""
+
+    if not final_text:
         raise RuntimeError("Gemini response did not contain any text content.")
 
-    if "\\u" in text:
-        text = safe_unicode_decode(text.encode("utf-8"))
+    if "\\u" in final_text:
+        final_text = safe_unicode_decode(final_text.encode("utf-8"))
 
-    text = remove_think_tags(text)
+    final_text = remove_think_tags(final_text)
 
     usage = getattr(response, "usage_metadata", None)
     if token_tracker and usage:
@@ -261,8 +372,8 @@ async def gemini_complete_if_cache(
             }
         )
 
-    logger.debug("Gemini response length: %s", len(text))
-    return text
+    logger.debug("Gemini response length: %s", len(final_text))
+    return final_text
 
 
 async def gemini_model_complete(
diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
index e045aa13..2cdbb72b 100644
--- a/lightrag/llm/openai.py
+++ b/lightrag/llm/openai.py
@@ -47,7 +47,7 @@ try:
 
     # Only enable Langfuse if both keys are configured
     if langfuse_public_key and langfuse_secret_key:
-        from langfuse.openai import AsyncOpenAI  # type: ignore[import-untyped]
+        from langfuse.openai import AsyncOpenAI
 
         LANGFUSE_ENABLED = True
         logger.info("Langfuse observability enabled for OpenAI client")
@@ -77,73 +77,46 @@ class InvalidResponseError(Exception):
 def create_openai_async_client(
     api_key: str | None = None,
     base_url: str | None = None,
-    use_azure: bool = False,
-    azure_deployment: str | None = None,
-    api_version: str | None = None,
-    timeout: int | None = None,
     client_configs: dict[str, Any] | None = None,
 ) -> AsyncOpenAI:
-    """Create an AsyncOpenAI or AsyncAzureOpenAI client with the given configuration.
+    """Create an AsyncOpenAI client with the given configuration.
 
     Args:
         api_key: OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
         base_url: Base URL for the OpenAI API. If None, uses the default OpenAI API URL.
-        use_azure: Whether to create an Azure OpenAI client. Default is False.
-        azure_deployment: Azure OpenAI deployment name (only used when use_azure=True).
-        api_version: Azure OpenAI API version (only used when use_azure=True).
-        timeout: Request timeout in seconds.
         client_configs: Additional configuration options for the AsyncOpenAI client.
             These will override any default configurations but will be overridden by
             explicit parameters (api_key, base_url).
 
     Returns:
-        An AsyncOpenAI or AsyncAzureOpenAI client instance.
+        An AsyncOpenAI client instance.
""" - if use_azure: - from openai import AsyncAzureOpenAI + if not api_key: + api_key = os.environ["OPENAI_API_KEY"] - if not api_key: - api_key = os.environ.get("AZURE_OPENAI_API_KEY") or os.environ.get( - "LLM_BINDING_API_KEY" - ) + default_headers = { + "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}", + "Content-Type": "application/json", + } - return AsyncAzureOpenAI( - azure_endpoint=base_url, - azure_deployment=azure_deployment, - api_key=api_key, - api_version=api_version, - timeout=timeout, - ) + if client_configs is None: + client_configs = {} + + # Create a merged config dict with precedence: explicit params > client_configs > defaults + merged_configs = { + **client_configs, + "default_headers": default_headers, + "api_key": api_key, + } + + if base_url is not None: + merged_configs["base_url"] = base_url else: - if not api_key: - api_key = os.environ["OPENAI_API_KEY"] + merged_configs["base_url"] = os.environ.get( + "OPENAI_API_BASE", "https://api.openai.com/v1" + ) - default_headers = { - "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}", - "Content-Type": "application/json", - } - - if client_configs is None: - client_configs = {} - - # Create a merged config dict with precedence: explicit params > client_configs > defaults - merged_configs = { - **client_configs, - "default_headers": default_headers, - "api_key": api_key, - } - - if base_url is not None: - merged_configs["base_url"] = base_url - else: - merged_configs["base_url"] = os.environ.get( - "OPENAI_API_BASE", "https://api.openai.com/v1" - ) - - if timeout is not None: - merged_configs["timeout"] = timeout - - return AsyncOpenAI(**merged_configs) + return AsyncOpenAI(**merged_configs) @retry( @@ -165,12 +138,6 @@ async def openai_complete_if_cache( base_url: str | None = None, api_key: str | None = None, token_tracker: Any | None = None, - stream: bool | None = None, - timeout: int | None = None, - keyword_extraction: bool = False, - use_azure: bool = False, - azure_deployment: str | None = None, - api_version: str | None = None, **kwargs: Any, ) -> str: """Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration. @@ -200,15 +167,14 @@ async def openai_complete_if_cache( api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable. token_tracker: Optional token usage tracker for monitoring API usage. enable_cot: Whether to enable Chain of Thought (COT) processing. Default is False. - stream: Whether to stream the response. Default is False. - timeout: Request timeout in seconds. Default is None. - keyword_extraction: Whether to enable keyword extraction mode. When True, triggers - special response formatting for keyword extraction. Default is False. **kwargs: Additional keyword arguments to pass to the OpenAI API. Special kwargs: - openai_client_configs: Dict of configuration options for the AsyncOpenAI client. These will be passed to the client constructor but will be overridden by explicit parameters (api_key, base_url). + - hashing_kv: Will be removed from kwargs before passing to OpenAI. + - keyword_extraction: Will be removed from kwargs before passing to OpenAI. + - stream: Whether to stream the response. Default is False. 
 
     Returns:
         The completed text (with integrated COT content if available) or an async iterator
 
@@ -229,22 +195,15 @@ async def openai_complete_if_cache(
 
     # Remove special kwargs that shouldn't be passed to OpenAI
     kwargs.pop("hashing_kv", None)
+    kwargs.pop("keyword_extraction", None)
 
     # Extract client configuration options
     client_configs = kwargs.pop("openai_client_configs", {})
 
-    # Handle keyword extraction mode
-    if keyword_extraction:
-        kwargs["response_format"] = GPTKeywordExtractionFormat
-
-    # Create the OpenAI client (supports both OpenAI and Azure)
+    # Create the OpenAI client
     openai_async_client = create_openai_async_client(
         api_key=api_key,
         base_url=base_url,
-        use_azure=use_azure,
-        azure_deployment=azure_deployment,
-        api_version=api_version,
-        timeout=timeout,
         client_configs=client_configs,
     )
 
@@ -266,16 +225,10 @@ async def openai_complete_if_cache(
 
     messages = kwargs.pop("messages", messages)
 
-    # Add explicit parameters back to kwargs so they're passed to OpenAI API
-    if stream is not None:
-        kwargs["stream"] = stream
-    if timeout is not None:
-        kwargs["timeout"] = timeout
-
     try:
         # Don't use async with context manager, use client directly
         if "response_format" in kwargs:
-            response = await openai_async_client.chat.completions.parse(
+            response = await openai_async_client.beta.chat.completions.parse(
                 model=model, messages=messages, **kwargs
             )
         else:
@@ -487,57 +440,46 @@ async def openai_complete_if_cache(
         raise InvalidResponseError("Invalid response from OpenAI API")
 
     message = response.choices[0].message
+    content = getattr(message, "content", None)
+    reasoning_content = getattr(message, "reasoning_content", "")
 
-    # Handle parsed responses (structured output via response_format)
-    # When using beta.chat.completions.parse(), the response is in message.parsed
-    if hasattr(message, "parsed") and message.parsed is not None:
-        # Serialize the parsed structured response to JSON
-        final_content = message.parsed.model_dump_json()
-        logger.debug("Using parsed structured response from API")
-    else:
-        # Handle regular content responses
-        content = getattr(message, "content", None)
-        reasoning_content = getattr(message, "reasoning_content", "")
+    # Handle COT logic for non-streaming responses (only if enabled)
+    final_content = ""
 
-        # Handle COT logic for non-streaming responses (only if enabled)
-        final_content = ""
-
-        if enable_cot:
-            # Check if we should include reasoning content
-            should_include_reasoning = False
-            if reasoning_content and reasoning_content.strip():
-                if not content or content.strip() == "":
-                    # Case 1: Only reasoning content, should include COT
-                    should_include_reasoning = True
-                    final_content = (
-                        content or ""
-                    )  # Use empty string if content is None
-                else:
-                    # Case 3: Both content and reasoning_content present, ignore reasoning
-                    should_include_reasoning = False
-                    final_content = content
-            else:
-                # No reasoning content, use regular content
-                final_content = content or ""
-
-            # Apply COT wrapping if needed
-            if should_include_reasoning:
-                if r"\u" in reasoning_content:
-                    reasoning_content = safe_unicode_decode(
-                        reasoning_content.encode("utf-8")
-                    )
+    if enable_cot:
+        # Check if we should include reasoning content
+        should_include_reasoning = False
+        if reasoning_content and reasoning_content.strip():
+            if not content or content.strip() == "":
+                # Case 1: Only reasoning content, should include COT
+                should_include_reasoning = True
+                final_content = (
+                    content or ""
+                )  # Use empty string if content is None
+            else:
+                # Case 3: Both content and reasoning_content present, ignore reasoning
+                should_include_reasoning = False
+                final_content = content
         else:
-            # COT disabled, only use regular content
+            # No reasoning content, use regular content
            final_content = content or ""
 
-        # Validate final content
-        if not final_content or final_content.strip() == "":
-            logger.error("Received empty content from OpenAI API")
-            await openai_async_client.close()  # Ensure client is closed
-            raise InvalidResponseError("Received empty content from OpenAI API")
+        # Apply COT wrapping if needed
+        if should_include_reasoning:
+            if r"\u" in reasoning_content:
+                reasoning_content = safe_unicode_decode(
+                    reasoning_content.encode("utf-8")
+                )
+            final_content = f"<think>{reasoning_content}</think>{final_content}"
+    else:
+        # COT disabled, only use regular content
+        final_content = content or ""
+
+    # Validate final content
+    if not final_content or final_content.strip() == "":
+        logger.error("Received empty content from OpenAI API")
+        await openai_async_client.close()  # Ensure client is closed
+        raise InvalidResponseError("Received empty content from OpenAI API")
 
     # Apply Unicode decoding to final content if needed
     if r"\u" in final_content:
@@ -571,13 +513,15 @@ async def openai_complete(
 ) -> Union[str, AsyncIterator[str]]:
     if history_messages is None:
         history_messages = []
+    keyword_extraction = kwargs.pop("keyword_extraction", None)
+    if keyword_extraction:
+        kwargs["response_format"] = "json"
     model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
     return await openai_complete_if_cache(
         model_name,
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
-        keyword_extraction=keyword_extraction,
         **kwargs,
     )
 
@@ -592,13 +536,15 @@ async def gpt_4o_complete(
 ) -> str:
     if history_messages is None:
         history_messages = []
+    keyword_extraction = kwargs.pop("keyword_extraction", None)
+    if keyword_extraction:
+        kwargs["response_format"] = GPTKeywordExtractionFormat
     return await openai_complete_if_cache(
         "gpt-4o",
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
         enable_cot=enable_cot,
-        keyword_extraction=keyword_extraction,
         **kwargs,
     )
 
@@ -613,13 +559,15 @@ async def gpt_4o_mini_complete(
 ) -> str:
     if history_messages is None:
         history_messages = []
+    keyword_extraction = kwargs.pop("keyword_extraction", None)
+    if keyword_extraction:
+        kwargs["response_format"] = GPTKeywordExtractionFormat
     return await openai_complete_if_cache(
         "gpt-4o-mini",
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
         enable_cot=enable_cot,
-        keyword_extraction=keyword_extraction,
         **kwargs,
     )
 
@@ -634,20 +582,20 @@ async def nvidia_openai_complete(
 ) -> str:
     if history_messages is None:
         history_messages = []
+    kwargs.pop("keyword_extraction", None)
     result = await openai_complete_if_cache(
         "nvidia/llama-3.1-nemotron-70b-instruct",  # context length 128k
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
         enable_cot=enable_cot,
-        keyword_extraction=keyword_extraction,
         base_url="https://integrate.api.nvidia.com/v1",
         **kwargs,
     )
     return result
 
 
-@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
+@wrap_embedding_func_with_attrs(embedding_dim=1536)
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=60),
@@ -662,12 +610,8 @@ async def openai_embed(
     model: str = "text-embedding-3-small",
     base_url: str | None = None,
     api_key: str | None = None,
-    embedding_dim: int | None = None,
     client_configs: dict[str, Any] | None = None,
     token_tracker: Any | None = None,
-    use_azure: bool = False,
-    azure_deployment: str | None = None,
-    api_version: str | None = None,
 ) -> np.ndarray:
     """Generate embeddings for a list of texts using OpenAI's API.
 
@@ -676,12 +620,6 @@ async def openai_embed(
         model: The OpenAI embedding model to use.
         base_url: Optional base URL for the OpenAI API.
         api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
-        embedding_dim: Optional embedding dimension for dynamic dimension reduction.
-            **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
-            Do NOT manually pass this parameter when calling the function directly.
-            The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
-            Manually passing a different value will trigger a warning and be ignored.
-            When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction.
         client_configs: Additional configuration options for the AsyncOpenAI client.
             These will override any default configurations but will be overridden by
             explicit parameters (api_key, base_url).
@@ -695,30 +633,15 @@ async def openai_embed(
         RateLimitError: If the OpenAI API rate limit is exceeded.
         APITimeoutError: If the OpenAI API request times out.
     """
-    # Create the OpenAI client (supports both OpenAI and Azure)
+    # Create the OpenAI client
     openai_async_client = create_openai_async_client(
-        api_key=api_key,
-        base_url=base_url,
-        use_azure=use_azure,
-        azure_deployment=azure_deployment,
-        api_version=api_version,
-        client_configs=client_configs,
+        api_key=api_key, base_url=base_url, client_configs=client_configs
     )
 
     async with openai_async_client:
-        # Prepare API call parameters
-        api_params = {
-            "model": model,
-            "input": texts,
-            "encoding_format": "base64",
-        }
-
-        # Add dimensions parameter only if embedding_dim is provided
-        if embedding_dim is not None:
-            api_params["dimensions"] = embedding_dim
-
-        # Make API call
-        response = await openai_async_client.embeddings.create(**api_params)
+        response = await openai_async_client.embeddings.create(
+            model=model, input=texts, encoding_format="base64"
+        )
 
         if token_tracker and hasattr(response, "usage"):
             token_counts = {
@@ -735,158 +658,3 @@ async def openai_embed(
                 for dp in response.data
             ]
         )
-
-
-# Azure OpenAI wrapper functions for backward compatibility
-async def azure_openai_complete_if_cache(
-    model,
-    prompt,
-    system_prompt: str | None = None,
-    history_messages: list[dict[str, Any]] | None = None,
-    enable_cot: bool = False,
-    base_url: str | None = None,
-    api_key: str | None = None,
-    api_version: str | None = None,
-    keyword_extraction: bool = False,
-    **kwargs,
-):
-    """Azure OpenAI completion wrapper function.
-
-    This function provides backward compatibility by wrapping the unified
-    openai_complete_if_cache implementation with Azure-specific parameter handling.
- """ - # Handle Azure-specific environment variables and parameters - deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL") - base_url = ( - base_url or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("LLM_BINDING_HOST") - ) - api_key = ( - api_key or os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("LLM_BINDING_API_KEY") - ) - api_version = ( - api_version - or os.getenv("AZURE_OPENAI_API_VERSION") - or os.getenv("OPENAI_API_VERSION") - ) - - # Pop timeout from kwargs if present (will be handled by openai_complete_if_cache) - timeout = kwargs.pop("timeout", None) - - # Call the unified implementation with Azure-specific parameters - return await openai_complete_if_cache( - model=model, - prompt=prompt, - system_prompt=system_prompt, - history_messages=history_messages, - enable_cot=enable_cot, - base_url=base_url, - api_key=api_key, - timeout=timeout, - use_azure=True, - azure_deployment=deployment, - api_version=api_version, - keyword_extraction=keyword_extraction, - **kwargs, - ) - - -async def azure_openai_complete( - prompt, - system_prompt=None, - history_messages=None, - keyword_extraction=False, - **kwargs, -) -> str: - """Azure OpenAI complete wrapper function. - - Provides backward compatibility for azure_openai_complete calls. - """ - if history_messages is None: - history_messages = [] - result = await azure_openai_complete_if_cache( - os.getenv("LLM_MODEL", "gpt-4o-mini"), - prompt, - system_prompt=system_prompt, - history_messages=history_messages, - keyword_extraction=keyword_extraction, - **kwargs, - ) - return result - - -@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) -async def azure_openai_embed( - texts: list[str], - model: str | None = None, - base_url: str | None = None, - api_key: str | None = None, - api_version: str | None = None, -) -> np.ndarray: - """Azure OpenAI embedding wrapper function. - - This function provides backward compatibility by wrapping the unified - openai_embed implementation with Azure-specific parameter handling. - - IMPORTANT - Decorator Usage: - - 1. This function is decorated with @wrap_embedding_func_with_attrs to provide - the EmbeddingFunc interface for users who need to access embedding_dim - and other attributes. - - 2. This function does NOT use @retry decorator to avoid double-wrapping, - since the underlying openai_embed.func already has retry logic. - - 3. This function calls openai_embed.func (the unwrapped function) instead of - openai_embed (the EmbeddingFunc instance) to avoid double decoration issues: - - ✅ Correct: await openai_embed.func(...) # Calls unwrapped function with retry - ❌ Wrong: await openai_embed(...) # Would cause double EmbeddingFunc wrapping - - Double decoration causes: - - Double injection of embedding_dim parameter - - Incorrect parameter passing to the underlying implementation - - Runtime errors due to parameter conflicts - - The call chain with correct implementation: - azure_openai_embed(texts) - → EmbeddingFunc.__call__(texts) # azure's decorator - → azure_openai_embed_impl(texts, embedding_dim=1536) - → openai_embed.func(texts, ...) - → @retry_wrapper(texts, ...) # openai's retry (only one layer) - → openai_embed_impl(texts, ...) 
-    → actual embedding computation
-    """
-    # Handle Azure-specific environment variables and parameters
-    deployment = (
-        os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
-        or model
-        or os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
-    )
-    base_url = (
-        base_url
-        or os.getenv("AZURE_EMBEDDING_ENDPOINT")
-        or os.getenv("EMBEDDING_BINDING_HOST")
-    )
-    api_key = (
-        api_key
-        or os.getenv("AZURE_EMBEDDING_API_KEY")
-        or os.getenv("EMBEDDING_BINDING_API_KEY")
-    )
-    api_version = (
-        api_version
-        or os.getenv("AZURE_EMBEDDING_API_VERSION")
-        or os.getenv("AZURE_OPENAI_API_VERSION")
-        or os.getenv("OPENAI_API_VERSION")
-    )
-
-    # CRITICAL: Call openai_embed.func (unwrapped) to avoid double decoration
-    # openai_embed is an EmbeddingFunc instance, .func accesses the underlying function
-    return await openai_embed.func(
-        texts=texts,
-        model=model or deployment,
-        base_url=base_url,
-        api_key=api_key,
-        use_azure=True,
-        azure_deployment=deployment,
-        api_version=api_version,
-    )
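For reference, a minimal usage sketch (not part of the patch) exercising the new enable_cot streaming path of gemini_complete_if_cache. The model name and the shape of the thinking_config entry inside generation_config are assumptions for illustration only; the API key is read from the environment as the docstring describes.

import asyncio

from lightrag.llm.gemini import gemini_complete_if_cache


async def main() -> None:
    chunks = await gemini_complete_if_cache(
        "gemini-2.0-flash",  # placeholder model name
        "Why is the sky blue?",
        enable_cot=True,  # per the docstring, thoughts arrive wrapped in <think>...</think>
        generation_config={"thinking_config": {"include_thoughts": True}},  # assumed shape
        stream=True,
    )
    # With stream=True the call returns an async iterator of text pieces; the
    # <think> block (if any) is opened and closed by _stream_model itself.
    async for piece in chunks:
        print(piece, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(main())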
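Similarly, a short sketch (not part of the patch) of the relocated keyword-extraction handling on the OpenAI side: after this change the wrapper functions consume keyword_extraction and set response_format themselves, and openai_complete_if_cache simply pops hashing_kv/keyword_extraction before kwargs reach the client. The prompt text is a placeholder and OPENAI_API_KEY must be set in the environment.

import asyncio

from lightrag.llm.openai import gpt_4o_mini_complete


async def main() -> None:
    result = await gpt_4o_mini_complete(
        "Extract high-level and low-level keywords from: "
        "LightRAG combines graph-based and vector retrieval.",
        keyword_extraction=True,  # mapped by the wrapper to response_format=GPTKeywordExtractionFormat
    )
    print(result)


if __name__ == "__main__":
    asyncio.run(main())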