Refactor keyword extraction handling to centralize response format logic
• Move response format to core function • Remove duplicate format assignments • Standardize keyword extraction flow • Clean up redundant parameter handling • Improve Azure OpenAI compatibility
This commit is contained in:
parent
46ce6d9a13
commit
c9e1c86e81
2 changed files with 11 additions and 8 deletions
|
|
@ -26,6 +26,7 @@ from lightrag.utils import (
|
||||||
safe_unicode_decode,
|
safe_unicode_decode,
|
||||||
logger,
|
logger,
|
||||||
)
|
)
|
||||||
|
from lightrag.types import GPTKeywordExtractionFormat
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
@ -46,6 +47,7 @@ async def azure_openai_complete_if_cache(
|
||||||
base_url: str | None = None,
|
base_url: str | None = None,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
api_version: str | None = None,
|
api_version: str | None = None,
|
||||||
|
keyword_extraction: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if enable_cot:
|
if enable_cot:
|
||||||
|
|
@ -66,9 +68,12 @@ async def azure_openai_complete_if_cache(
|
||||||
)
|
)
|
||||||
|
|
||||||
kwargs.pop("hashing_kv", None)
|
kwargs.pop("hashing_kv", None)
|
||||||
kwargs.pop("keyword_extraction", None)
|
|
||||||
timeout = kwargs.pop("timeout", None)
|
timeout = kwargs.pop("timeout", None)
|
||||||
|
|
||||||
|
# Handle keyword extraction mode
|
||||||
|
if keyword_extraction:
|
||||||
|
kwargs["response_format"] = GPTKeywordExtractionFormat
|
||||||
|
|
||||||
openai_async_client = AsyncAzureOpenAI(
|
openai_async_client = AsyncAzureOpenAI(
|
||||||
azure_endpoint=base_url,
|
azure_endpoint=base_url,
|
||||||
azure_deployment=deployment,
|
azure_deployment=deployment,
|
||||||
|
|
@ -117,12 +122,12 @@ async def azure_openai_complete_if_cache(
|
||||||
async def azure_openai_complete(
|
async def azure_openai_complete(
|
||||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
||||||
) -> str:
|
) -> str:
|
||||||
kwargs.pop("keyword_extraction", None)
|
|
||||||
result = await azure_openai_complete_if_cache(
|
result = await azure_openai_complete_if_cache(
|
||||||
os.getenv("LLM_MODEL", "gpt-4o-mini"),
|
os.getenv("LLM_MODEL", "gpt-4o-mini"),
|
||||||
prompt,
|
prompt,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
history_messages=history_messages,
|
history_messages=history_messages,
|
||||||
|
keyword_extraction=keyword_extraction,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
|
||||||
|
|
@ -203,6 +203,10 @@ async def openai_complete_if_cache(
|
||||||
# Extract client configuration options
|
# Extract client configuration options
|
||||||
client_configs = kwargs.pop("openai_client_configs", {})
|
client_configs = kwargs.pop("openai_client_configs", {})
|
||||||
|
|
||||||
|
# Handle keyword extraction mode
|
||||||
|
if keyword_extraction:
|
||||||
|
kwargs["response_format"] = GPTKeywordExtractionFormat
|
||||||
|
|
||||||
# Create the OpenAI client
|
# Create the OpenAI client
|
||||||
openai_async_client = create_openai_async_client(
|
openai_async_client = create_openai_async_client(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
|
@ -522,8 +526,6 @@ async def openai_complete(
|
||||||
) -> Union[str, AsyncIterator[str]]:
|
) -> Union[str, AsyncIterator[str]]:
|
||||||
if history_messages is None:
|
if history_messages is None:
|
||||||
history_messages = []
|
history_messages = []
|
||||||
if keyword_extraction:
|
|
||||||
kwargs["response_format"] = "json"
|
|
||||||
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
||||||
return await openai_complete_if_cache(
|
return await openai_complete_if_cache(
|
||||||
model_name,
|
model_name,
|
||||||
|
|
@ -545,8 +547,6 @@ async def gpt_4o_complete(
|
||||||
) -> str:
|
) -> str:
|
||||||
if history_messages is None:
|
if history_messages is None:
|
||||||
history_messages = []
|
history_messages = []
|
||||||
if keyword_extraction:
|
|
||||||
kwargs["response_format"] = GPTKeywordExtractionFormat
|
|
||||||
return await openai_complete_if_cache(
|
return await openai_complete_if_cache(
|
||||||
"gpt-4o",
|
"gpt-4o",
|
||||||
prompt,
|
prompt,
|
||||||
|
|
@ -568,8 +568,6 @@ async def gpt_4o_mini_complete(
|
||||||
) -> str:
|
) -> str:
|
||||||
if history_messages is None:
|
if history_messages is None:
|
||||||
history_messages = []
|
history_messages = []
|
||||||
if keyword_extraction:
|
|
||||||
kwargs["response_format"] = GPTKeywordExtractionFormat
|
|
||||||
return await openai_complete_if_cache(
|
return await openai_complete_if_cache(
|
||||||
"gpt-4o-mini",
|
"gpt-4o-mini",
|
||||||
prompt,
|
prompt,
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue