feat: add kwargs to openai adapter functions

This commit is contained in:
Andrej Milicevic 2025-11-17 17:42:22 +01:00
parent de6b842d02
commit 0a4b1068a2

View file

@ -108,7 +108,7 @@ class OpenAIAdapter(LLMInterface):
reraise=True, reraise=True,
) )
async def acreate_structured_output( async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel] self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel: ) -> BaseModel:
""" """
Generate a response from a user query. Generate a response from a user query.
@ -149,6 +149,7 @@ class OpenAIAdapter(LLMInterface):
api_version=self.api_version, api_version=self.api_version,
response_model=response_model, response_model=response_model,
max_retries=self.MAX_RETRIES, max_retries=self.MAX_RETRIES,
**kwargs,
) )
except ( except (
ContentFilterFinishReasonError, ContentFilterFinishReasonError,
@ -174,6 +175,7 @@ class OpenAIAdapter(LLMInterface):
# api_base=self.fallback_endpoint, # api_base=self.fallback_endpoint,
response_model=response_model, response_model=response_model,
max_retries=self.MAX_RETRIES, max_retries=self.MAX_RETRIES,
**kwargs,
) )
except ( except (
ContentFilterFinishReasonError, ContentFilterFinishReasonError,
@ -199,7 +201,7 @@ class OpenAIAdapter(LLMInterface):
reraise=True, reraise=True,
) )
def create_structured_output( def create_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel] self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel: ) -> BaseModel:
""" """
Generate a response from a user query. Generate a response from a user query.
@ -239,6 +241,7 @@ class OpenAIAdapter(LLMInterface):
api_version=self.api_version, api_version=self.api_version,
response_model=response_model, response_model=response_model,
max_retries=self.MAX_RETRIES, max_retries=self.MAX_RETRIES,
**kwargs,
) )
@retry( @retry(
@ -248,7 +251,7 @@ class OpenAIAdapter(LLMInterface):
before_sleep=before_sleep_log(logger, logging.DEBUG), before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True, reraise=True,
) )
async def create_transcript(self, input): async def create_transcript(self, input, **kwargs):
""" """
Generate an audio transcript from a user query. Generate an audio transcript from a user query.
@ -275,6 +278,7 @@ class OpenAIAdapter(LLMInterface):
api_base=self.endpoint, api_base=self.endpoint,
api_version=self.api_version, api_version=self.api_version,
max_retries=self.MAX_RETRIES, max_retries=self.MAX_RETRIES,
**kwargs,
) )
return transcription return transcription
@ -286,7 +290,7 @@ class OpenAIAdapter(LLMInterface):
before_sleep=before_sleep_log(logger, logging.DEBUG), before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True, reraise=True,
) )
async def transcribe_image(self, input) -> BaseModel: async def transcribe_image(self, input, **kwargs) -> BaseModel:
""" """
Generate a transcription of an image from a user query. Generate a transcription of an image from a user query.
@ -331,4 +335,5 @@ class OpenAIAdapter(LLMInterface):
api_version=self.api_version, api_version=self.api_version,
max_completion_tokens=300, max_completion_tokens=300,
max_retries=self.MAX_RETRIES, max_retries=self.MAX_RETRIES,
**kwargs,
) )