diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
index 66c3bfe4..a2bbfa23 100644
--- a/lightrag/llm/openai.py
+++ b/lightrag/llm/openai.py
@@ -138,6 +138,9 @@ async def openai_complete_if_cache(
     base_url: str | None = None,
     api_key: str | None = None,
     token_tracker: Any | None = None,
+    stream: bool | None = None,
+    timeout: int | None = None,
+    keyword_extraction: bool = False,
     **kwargs: Any,
 ) -> str:
     """Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration.
@@ -167,13 +170,15 @@ async def openai_complete_if_cache(
         api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
         token_tracker: Optional token usage tracker for monitoring API usage.
         enable_cot: Whether to enable Chain of Thought (COT) processing. Default is False.
+        stream: Whether to stream the response. Default is None (no streaming).
+        timeout: Request timeout in seconds. Default is None.
+        keyword_extraction: Whether to enable keyword extraction mode. When True, triggers
+            special response formatting for keyword extraction. Default is False.
         **kwargs: Additional keyword arguments to pass to the OpenAI API.
             Special kwargs:
             - openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
                 These will be passed to the client constructor but will be overridden by
                 explicit parameters (api_key, base_url).
-            - hashing_kv: Will be removed from kwargs before passing to OpenAI.
-            - keyword_extraction: Will be removed from kwargs before passing to OpenAI.
 
     Returns:
         The completed text (with integrated COT content if available) or an async iterator
@@ -194,7 +199,6 @@ async def openai_complete_if_cache(
 
     # Remove special kwargs that shouldn't be passed to OpenAI
     kwargs.pop("hashing_kv", None)
-    kwargs.pop("keyword_extraction", None)
 
     # Extract client configuration options
     client_configs = kwargs.pop("openai_client_configs", {})
@@ -224,6 +228,12 @@ async def openai_complete_if_cache(
 
     messages = kwargs.pop("messages", messages)
 
+    # Add explicit parameters back to kwargs so they're passed to OpenAI API
+    if stream is not None:
+        kwargs["stream"] = stream
+    if timeout is not None:
+        kwargs["timeout"] = timeout
+
     try:
         # Don't use async with context manager, use client directly
         if "response_format" in kwargs:
@@ -512,7 +522,6 @@ async def openai_complete(
 ) -> Union[str, AsyncIterator[str]]:
     if history_messages is None:
         history_messages = []
-    keyword_extraction = kwargs.pop("keyword_extraction", None)
     if keyword_extraction:
         kwargs["response_format"] = "json"
     model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
@@ -521,6 +530,7 @@ async def openai_complete(
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        keyword_extraction=keyword_extraction,
         **kwargs,
     )
 
@@ -535,7 +545,6 @@ async def gpt_4o_complete(
 ) -> str:
     if history_messages is None:
         history_messages = []
-    keyword_extraction = kwargs.pop("keyword_extraction", None)
     if keyword_extraction:
         kwargs["response_format"] = GPTKeywordExtractionFormat
     return await openai_complete_if_cache(
@@ -544,6 +553,7 @@ async def gpt_4o_complete(
         system_prompt=system_prompt,
         history_messages=history_messages,
         enable_cot=enable_cot,
+        keyword_extraction=keyword_extraction,
         **kwargs,
     )
 
@@ -558,7 +568,6 @@ async def gpt_4o_mini_complete(
 ) -> str:
     if history_messages is None:
         history_messages = []
-    keyword_extraction = kwargs.pop("keyword_extraction", None)
     if keyword_extraction:
         kwargs["response_format"] = GPTKeywordExtractionFormat
     return await openai_complete_if_cache(
@@ -567,6 +576,7 @@ async def gpt_4o_mini_complete(
         system_prompt=system_prompt,
         history_messages=history_messages,
         enable_cot=enable_cot,
+        keyword_extraction=keyword_extraction,
         **kwargs,
     )
 
@@ -581,13 +591,13 @@ async def nvidia_openai_complete(
 ) -> str:
     if history_messages is None:
         history_messages = []
-    kwargs.pop("keyword_extraction", None)
     result = await openai_complete_if_cache(
         "nvidia/llama-3.1-nemotron-70b-instruct",  # context length 128k
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
         enable_cot=enable_cot,
+        keyword_extraction=keyword_extraction,
         base_url="https://integrate.api.nvidia.com/v1",
         **kwargs,
     )
@@ -609,6 +619,7 @@ async def openai_embed(
     model: str = "text-embedding-3-small",
     base_url: str | None = None,
     api_key: str | None = None,
+    embedding_dim: int | None = None,
     client_configs: dict[str, Any] | None = None,
     token_tracker: Any | None = None,
 ) -> np.ndarray:
@@ -619,6 +630,12 @@ async def openai_embed(
         model: The OpenAI embedding model to use.
         base_url: Optional base URL for the OpenAI API.
         api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
+        embedding_dim: Optional embedding dimension for dynamic dimension reduction.
+            **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
+            Do NOT manually pass this parameter when calling the function directly.
+            The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
+            Manually passing a different value will trigger a warning and be ignored.
+            When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction.
         client_configs: Additional configuration options for the AsyncOpenAI client.
             These will override any default configurations but will be overridden by
             explicit parameters (api_key, base_url).
@@ -638,9 +655,19 @@ async def openai_embed(
     )
 
     async with openai_async_client:
-        response = await openai_async_client.embeddings.create(
-            model=model, input=texts, encoding_format="base64"
-        )
+        # Prepare API call parameters
+        api_params = {
+            "model": model,
+            "input": texts,
+            "encoding_format": "base64",
+        }
+
+        # Add dimensions parameter only if embedding_dim is provided
+        if embedding_dim is not None:
+            api_params["dimensions"] = embedding_dim
+
+        # Make API call
+        response = await openai_async_client.embeddings.create(**api_params)
 
         if token_tracker and hasattr(response, "usage"):
             token_counts = {
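Usage notes (not part of the diff). A minimal sketch of the new explicit completion parameters, assuming `lightrag` is installed and `OPENAI_API_KEY` is set; the model name and prompt are illustrative. Per the `@@ -224,6 +228,12 @@` hunk, `stream` and `timeout` are forwarded to the OpenAI API only when they are not `None`.

```python
import asyncio

from lightrag.llm.openai import openai_complete_if_cache


async def main() -> None:
    # stream, timeout, and keyword_extraction are now declared parameters
    # instead of values smuggled through **kwargs; leaving stream/timeout
    # as None simply omits them from the OpenAI request.
    text = await openai_complete_if_cache(
        "gpt-4o-mini",  # illustrative model name
        "Summarize LightRAG in one sentence.",
        system_prompt="You are a concise assistant.",
        stream=False,   # forwarded only because it is not None
        timeout=30,     # request timeout in seconds
    )
    print(text)


asyncio.run(main())
```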
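A hedged sketch of the keyword-extraction path through one of the wrappers: with the redundant `kwargs.pop("keyword_extraction", ...)` removed, the flag reaches `gpt_4o_mini_complete` as a real parameter, which installs `GPTKeywordExtractionFormat` as the `response_format` before delegating (see the `@@ -558` and `@@ -567` hunks). The prompt here is illustrative, and the output shape depends on the model honoring the structured format.

```python
import asyncio

from lightrag.llm.openai import gpt_4o_mini_complete


async def main() -> None:
    # keyword_extraction=True makes the wrapper set
    # kwargs["response_format"] = GPTKeywordExtractionFormat and then pass
    # the flag explicitly into openai_complete_if_cache.
    keywords_json = await gpt_4o_mini_complete(
        "Extract keywords: graph-based retrieval-augmented generation.",
        keyword_extraction=True,
    )
    print(keywords_json)


asyncio.run(main())
```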
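Finally, a sketch of the intended `embedding_dim` flow. Per the docstring added above, callers should not pass `embedding_dim` to `openai_embed` directly; it is injected by LightRAG's `EmbeddingFunc` wrapper, which is what ultimately puts `dimensions` into the embeddings request. Field names follow the upstream repo and the 256-dim target is illustrative; depending on your LightRAG version, `EmbeddingFunc` may require additional fields such as `max_token_size`.

```python
import asyncio

from lightrag.llm.openai import openai_embed
from lightrag.utils import EmbeddingFunc  # upstream LightRAG helper

# The wrapper, not the caller, owns the target dimension: EmbeddingFunc
# injects embedding_dim=256 into openai_embed, and openai_embed adds
# "dimensions": 256 to the API call only because the value is not None.
embedding_func = EmbeddingFunc(
    embedding_dim=256,  # illustrative reduced dimension
    func=openai_embed,
)


async def main() -> None:
    vectors = await embedding_func(["hello world"])
    print(vectors.shape)  # expected (1, 256) if the model supports reduction


asyncio.run(main())
```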