Refactor keyword_extraction from kwargs to explicit parameter
• Add keyword_extraction param to functions
• Remove kwargs.pop() calls
• Update function signatures
• Improve parameter documentation
• Make parameter handling consistent
(cherry picked from commit 2f16065256)
This commit is contained in:
parent
35dd68d767
commit
033ee5c0f5
1 changed files with 37 additions and 10 deletions
|
|
@ -138,6 +138,9 @@ async def openai_complete_if_cache(
|
||||||
base_url: str | None = None,
|
base_url: str | None = None,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
token_tracker: Any | None = None,
|
token_tracker: Any | None = None,
|
||||||
|
stream: bool | None = None,
|
||||||
|
timeout: int | None = None,
|
||||||
|
keyword_extraction: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration.
|
"""Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration.
|
||||||
|
|
@ -167,13 +170,15 @@ async def openai_complete_if_cache(
|
||||||
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
|
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
|
||||||
token_tracker: Optional token usage tracker for monitoring API usage.
|
token_tracker: Optional token usage tracker for monitoring API usage.
|
||||||
enable_cot: Whether to enable Chain of Thought (COT) processing. Default is False.
|
enable_cot: Whether to enable Chain of Thought (COT) processing. Default is False.
|
||||||
|
stream: Whether to stream the response. Default is False.
|
||||||
|
timeout: Request timeout in seconds. Default is None.
|
||||||
|
keyword_extraction: Whether to enable keyword extraction mode. When True, triggers
|
||||||
|
special response formatting for keyword extraction. Default is False.
|
||||||
**kwargs: Additional keyword arguments to pass to the OpenAI API.
|
**kwargs: Additional keyword arguments to pass to the OpenAI API.
|
||||||
Special kwargs:
|
Special kwargs:
|
||||||
- openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
|
- openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
|
||||||
These will be passed to the client constructor but will be overridden by
|
These will be passed to the client constructor but will be overridden by
|
||||||
explicit parameters (api_key, base_url).
|
explicit parameters (api_key, base_url).
|
||||||
- hashing_kv: Will be removed from kwargs before passing to OpenAI.
|
|
||||||
- keyword_extraction: Will be removed from kwargs before passing to OpenAI.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The completed text (with integrated COT content if available) or an async iterator
|
The completed text (with integrated COT content if available) or an async iterator
|
||||||
|
|
@ -194,7 +199,6 @@ async def openai_complete_if_cache(
|
||||||
|
|
||||||
# Remove special kwargs that shouldn't be passed to OpenAI
|
# Remove special kwargs that shouldn't be passed to OpenAI
|
||||||
kwargs.pop("hashing_kv", None)
|
kwargs.pop("hashing_kv", None)
|
||||||
kwargs.pop("keyword_extraction", None)
|
|
||||||
|
|
||||||
# Extract client configuration options
|
# Extract client configuration options
|
||||||
client_configs = kwargs.pop("openai_client_configs", {})
|
client_configs = kwargs.pop("openai_client_configs", {})
|
||||||
|
|
@ -224,6 +228,12 @@ async def openai_complete_if_cache(
|
||||||
|
|
||||||
messages = kwargs.pop("messages", messages)
|
messages = kwargs.pop("messages", messages)
|
||||||
|
|
||||||
|
# Add explicit parameters back to kwargs so they're passed to OpenAI API
|
||||||
|
if stream is not None:
|
||||||
|
kwargs["stream"] = stream
|
||||||
|
if timeout is not None:
|
||||||
|
kwargs["timeout"] = timeout
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Don't use async with context manager, use client directly
|
# Don't use async with context manager, use client directly
|
||||||
if "response_format" in kwargs:
|
if "response_format" in kwargs:
|
||||||
|
|
@ -512,7 +522,6 @@ async def openai_complete(
|
||||||
) -> Union[str, AsyncIterator[str]]:
|
) -> Union[str, AsyncIterator[str]]:
|
||||||
if history_messages is None:
|
if history_messages is None:
|
||||||
history_messages = []
|
history_messages = []
|
||||||
keyword_extraction = kwargs.pop("keyword_extraction", None)
|
|
||||||
if keyword_extraction:
|
if keyword_extraction:
|
||||||
kwargs["response_format"] = "json"
|
kwargs["response_format"] = "json"
|
||||||
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
||||||
|
|
@ -521,6 +530,7 @@ async def openai_complete(
|
||||||
prompt,
|
prompt,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
history_messages=history_messages,
|
history_messages=history_messages,
|
||||||
|
keyword_extraction=keyword_extraction,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -535,7 +545,6 @@ async def gpt_4o_complete(
|
||||||
) -> str:
|
) -> str:
|
||||||
if history_messages is None:
|
if history_messages is None:
|
||||||
history_messages = []
|
history_messages = []
|
||||||
keyword_extraction = kwargs.pop("keyword_extraction", None)
|
|
||||||
if keyword_extraction:
|
if keyword_extraction:
|
||||||
kwargs["response_format"] = GPTKeywordExtractionFormat
|
kwargs["response_format"] = GPTKeywordExtractionFormat
|
||||||
return await openai_complete_if_cache(
|
return await openai_complete_if_cache(
|
||||||
|
|
@ -544,6 +553,7 @@ async def gpt_4o_complete(
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
history_messages=history_messages,
|
history_messages=history_messages,
|
||||||
enable_cot=enable_cot,
|
enable_cot=enable_cot,
|
||||||
|
keyword_extraction=keyword_extraction,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -558,7 +568,6 @@ async def gpt_4o_mini_complete(
|
||||||
) -> str:
|
) -> str:
|
||||||
if history_messages is None:
|
if history_messages is None:
|
||||||
history_messages = []
|
history_messages = []
|
||||||
keyword_extraction = kwargs.pop("keyword_extraction", None)
|
|
||||||
if keyword_extraction:
|
if keyword_extraction:
|
||||||
kwargs["response_format"] = GPTKeywordExtractionFormat
|
kwargs["response_format"] = GPTKeywordExtractionFormat
|
||||||
return await openai_complete_if_cache(
|
return await openai_complete_if_cache(
|
||||||
|
|
@ -567,6 +576,7 @@ async def gpt_4o_mini_complete(
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
history_messages=history_messages,
|
history_messages=history_messages,
|
||||||
enable_cot=enable_cot,
|
enable_cot=enable_cot,
|
||||||
|
keyword_extraction=keyword_extraction,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -581,13 +591,13 @@ async def nvidia_openai_complete(
|
||||||
) -> str:
|
) -> str:
|
||||||
if history_messages is None:
|
if history_messages is None:
|
||||||
history_messages = []
|
history_messages = []
|
||||||
kwargs.pop("keyword_extraction", None)
|
|
||||||
result = await openai_complete_if_cache(
|
result = await openai_complete_if_cache(
|
||||||
"nvidia/llama-3.1-nemotron-70b-instruct", # context length 128k
|
"nvidia/llama-3.1-nemotron-70b-instruct", # context length 128k
|
||||||
prompt,
|
prompt,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
history_messages=history_messages,
|
history_messages=history_messages,
|
||||||
enable_cot=enable_cot,
|
enable_cot=enable_cot,
|
||||||
|
keyword_extraction=keyword_extraction,
|
||||||
base_url="https://integrate.api.nvidia.com/v1",
|
base_url="https://integrate.api.nvidia.com/v1",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
@ -609,6 +619,7 @@ async def openai_embed(
|
||||||
model: str = "text-embedding-3-small",
|
model: str = "text-embedding-3-small",
|
||||||
base_url: str | None = None,
|
base_url: str | None = None,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
|
embedding_dim: int | None = None,
|
||||||
client_configs: dict[str, Any] | None = None,
|
client_configs: dict[str, Any] | None = None,
|
||||||
token_tracker: Any | None = None,
|
token_tracker: Any | None = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
|
|
@ -619,6 +630,12 @@ async def openai_embed(
|
||||||
model: The OpenAI embedding model to use.
|
model: The OpenAI embedding model to use.
|
||||||
base_url: Optional base URL for the OpenAI API.
|
base_url: Optional base URL for the OpenAI API.
|
||||||
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
|
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
|
||||||
|
embedding_dim: Optional embedding dimension for dynamic dimension reduction.
|
||||||
|
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
|
||||||
|
Do NOT manually pass this parameter when calling the function directly.
|
||||||
|
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
|
||||||
|
Manually passing a different value will trigger a warning and be ignored.
|
||||||
|
When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction.
|
||||||
client_configs: Additional configuration options for the AsyncOpenAI client.
|
client_configs: Additional configuration options for the AsyncOpenAI client.
|
||||||
These will override any default configurations but will be overridden by
|
These will override any default configurations but will be overridden by
|
||||||
explicit parameters (api_key, base_url).
|
explicit parameters (api_key, base_url).
|
||||||
|
|
@ -638,9 +655,19 @@ async def openai_embed(
|
||||||
)
|
)
|
||||||
|
|
||||||
async with openai_async_client:
|
async with openai_async_client:
|
||||||
response = await openai_async_client.embeddings.create(
|
# Prepare API call parameters
|
||||||
model=model, input=texts, encoding_format="base64"
|
api_params = {
|
||||||
)
|
"model": model,
|
||||||
|
"input": texts,
|
||||||
|
"encoding_format": "base64",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add dimensions parameter only if embedding_dim is provided
|
||||||
|
if embedding_dim is not None:
|
||||||
|
api_params["dimensions"] = embedding_dim
|
||||||
|
|
||||||
|
# Make API call
|
||||||
|
response = await openai_async_client.embeddings.create(**api_params)
|
||||||
|
|
||||||
if token_tracker and hasattr(response, "usage"):
|
if token_tracker and hasattr(response, "usage"):
|
||||||
token_counts = {
|
token_counts = {
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue