feat(gemini): simplify config for Gemini clients (#679)
The cross_encoder for Gemini already supported passing in a custom client. I replicated the same input pattern in the embedder and llm_client. The value is that you can use custom API endpoints and other client options, like below:
```python
cross_encoder=GeminiRerankerClient(
    client=genai.Client(
        api_key=os.environ.get('GOOGLE_GENAI_API_KEY'),
        http_options=types.HttpOptions(api_version='v1alpha'),
    ),
    config=LLMConfig(
        model='gemini-2.5-flash-lite-preview-06-17',
    ),
)
```
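For reference, a minimal sketch of the replicated pattern on the embedder and llm_client; the import paths and the bare `GeminiEmbedderConfig()` are assumptions inferred from the class names in the diff below, not shown verbatim there:

```python
import os

from google import genai
from google.genai import types

# Assumed import paths, inferred from the classes touched in this diff.
from graphiti_core.embedder.gemini import GeminiEmbedder, GeminiEmbedderConfig
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.gemini_client import GeminiClient

# One shared client targeting a custom endpoint / API version.
custom_client = genai.Client(
    api_key=os.environ.get('GOOGLE_GENAI_API_KEY'),
    http_options=types.HttpOptions(api_version='v1alpha'),
)

embedder = GeminiEmbedder(config=GeminiEmbedderConfig(), client=custom_client)
llm_client = GeminiClient(
    config=LLMConfig(model='gemini-2.5-flash-lite-preview-06-17'),
    client=custom_client,
)
```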
parent 3ecedf8ebf
commit 432ff7577d
3 changed files with 27 additions and 11 deletions
graphiti_core/cross_encoder/gemini_reranker_client.py

```diff
@@ -41,6 +41,9 @@ DEFAULT_MODEL = 'gemini-2.5-flash-lite-preview-06-17'
 
 
 class GeminiRerankerClient(CrossEncoderClient):
+    """
+    Google Gemini Reranker Client
+    """
     def __init__(
         self,
         config: LLMConfig | None = None,
@@ -57,7 +60,6 @@ class GeminiRerankerClient(CrossEncoderClient):
             config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
             client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
         """
-
         if config is None:
             config = LLMConfig()
 
```
graphiti_core/embedder/gemini.py
```diff
@@ -46,16 +46,27 @@ class GeminiEmbedder(EmbedderClient):
     """
     Google Gemini Embedder Client
     """
 
-    def __init__(self, config: GeminiEmbedderConfig | None = None):
+    def __init__(
+        self,
+        config: GeminiEmbedderConfig | None = None,
+        client: 'genai.Client | None' = None,
+    ):
         """
         Initialize the GeminiEmbedder with the provided configuration and client.
 
         Args:
             config (GeminiEmbedderConfig | None): The configuration for the GeminiEmbedder, including API key, model, base URL, temperature, and max tokens.
+            client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
         """
         if config is None:
             config = GeminiEmbedderConfig()
 
         self.config = config
 
-        # Configure the Gemini API
-        self.client = genai.Client(
-            api_key=config.api_key,
-        )
+        if client is None:
+            self.client = genai.Client(api_key=config.api_key)
+        else:
+            self.client = client
 
     async def create(
         self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
```
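The if/else keeps the old default intact: with no client argument, the embedder still builds its own genai.Client from config.api_key. A usage sketch of the custom path, assuming create() returns a single embedding vector for a single-string input:

```python
import asyncio
import os

from google import genai
from google.genai import types

# Assumed import path, inferred from the class name above.
from graphiti_core.embedder.gemini import GeminiEmbedder, GeminiEmbedderConfig


async def main() -> None:
    # Custom path: the embedder uses the supplied client instead of building one.
    embedder = GeminiEmbedder(
        config=GeminiEmbedderConfig(),
        client=genai.Client(
            api_key=os.environ.get('GOOGLE_GENAI_API_KEY'),
            http_options=types.HttpOptions(api_version='v1alpha'),
        ),
    )
    # create() accepts a single string or a batch, per the signature in the hunk above.
    embedding = await embedder.create('hello world')
    print(len(embedding))


asyncio.run(main())
```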
graphiti_core/llm_client/gemini_client.py
```diff
@@ -76,6 +76,7 @@ class GeminiClient(LLMClient):
         cache: bool = False,
         max_tokens: int = DEFAULT_MAX_TOKENS,
         thinking_config: types.ThinkingConfig | None = None,
+        client: 'genai.Client | None' = None,
     ):
         """
         Initialize the GeminiClient with the provided configuration, cache setting, and optional thinking config.
@@ -85,7 +86,7 @@ class GeminiClient(LLMClient):
             cache (bool): Whether to use caching for responses. Defaults to False.
             thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
                 Only use with models that support thinking (gemini-2.5+). Defaults to None.
-
+            client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
         """
         if config is None:
             config = LLMConfig()
@@ -93,10 +94,12 @@ class GeminiClient(LLMClient):
         super().__init__(config, cache)
 
         self.model = config.model
-        # Configure the Gemini API
-        self.client = genai.Client(
-            api_key=config.api_key,
-        )
+
+        if client is None:
+            self.client = genai.Client(api_key=config.api_key)
+        else:
+            self.client = client
 
         self.max_tokens = max_tokens
         self.thinking_config = thinking_config
```
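Putting the new parameter together with the existing thinking_config, a sketch of a fully customized GeminiClient; the import paths and the ThinkingConfig value are illustrative assumptions, and per the docstring above thinking_config should only be set for models that support thinking (gemini-2.5+):

```python
import os

from google import genai
from google.genai import types

# Assumed import paths, inferred from the classes touched in this diff.
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.gemini_client import GeminiClient

llm_client = GeminiClient(
    config=LLMConfig(
        api_key=os.environ.get('GOOGLE_GENAI_API_KEY'),
        model='gemini-2.5-flash-lite-preview-06-17',
    ),
    # The supplied client wins over the api_key-based default built in __init__.
    client=genai.Client(
        api_key=os.environ.get('GOOGLE_GENAI_API_KEY'),
        http_options=types.HttpOptions(api_version='v1alpha'),
    ),
    # Illustrative value; only for models that support thinking (gemini-2.5+).
    thinking_config=types.ThinkingConfig(thinking_budget=1024),
)
```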