feat(gemini): simplify config for Gemini clients (#679)

The cross_encoder for Gemini already supported passing in a custom client.

I replicated the same input pattern in the embedder and llm_client (a sketch for those follows the snippet below).

The value is that you can target custom API endpoints and set other client options, for example:

    cross_encoder=GeminiRerankerClient(
        client=genai.Client(
            api_key=os.environ.get('GOOGLE_GENAI_API_KEY'),
            http_options=types.HttpOptions(api_version='v1alpha'),
        ),
        config=LLMConfig(model='gemini-2.5-flash-lite-preview-06-17'),
    )
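The same pattern now applies to the embedder and LLM client; a minimal sketch, assuming the graphiti_core import paths below and sharing one genai.Client across components:

    import os

    from google import genai
    from google.genai import types

    # Import paths assumed from the graphiti_core package layout.
    from graphiti_core.embedder.gemini import GeminiEmbedder, GeminiEmbedderConfig
    from graphiti_core.llm_client.config import LLMConfig
    from graphiti_core.llm_client.gemini_client import GeminiClient

    # One shared client carries the custom endpoint/API version everywhere.
    client = genai.Client(
        api_key=os.environ.get('GOOGLE_GENAI_API_KEY'),
        http_options=types.HttpOptions(api_version='v1alpha'),
    )

    llm_client = GeminiClient(
        config=LLMConfig(model='gemini-2.5-flash-lite-preview-06-17'),
        client=client,
    )

    embedder = GeminiEmbedder(
        config=GeminiEmbedderConfig(),
        client=client,
    )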
alan blount authored on 2025-07-06 00:14:55 -04:00; committed by GitHub
parent 3ecedf8ebf
commit 432ff7577d
3 changed files with 27 additions and 11 deletions

graphiti_core/cross_encoder/gemini_reranker_client.py

@@ -41,6 +41,9 @@ DEFAULT_MODEL = 'gemini-2.5-flash-lite-preview-06-17'
 class GeminiRerankerClient(CrossEncoderClient):
+    """
+    Google Gemini Reranker Client
+    """
     def __init__(
         self,
         config: LLMConfig | None = None,
@@ -57,7 +60,6 @@ class GeminiRerankerClient(CrossEncoderClient):
             config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
             client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
         """
         if config is None:
             config = LLMConfig()

graphiti_core/embedder/gemini.py

@@ -46,16 +46,27 @@ class GeminiEmbedder(EmbedderClient):
     """
     Google Gemini Embedder Client
     """
-    def __init__(self, config: GeminiEmbedderConfig | None = None):
+    def __init__(
+        self,
+        config: GeminiEmbedderConfig | None = None,
+        client: 'genai.Client | None' = None,
+    ):
         """
         Initialize the GeminiEmbedder with the provided configuration and client.

         Args:
             config (GeminiEmbedderConfig | None): The configuration for the GeminiEmbedder, including API key, model, base URL, temperature, and max tokens.
+            client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
         """
         if config is None:
             config = GeminiEmbedderConfig()
         self.config = config
-        # Configure the Gemini API
-        self.client = genai.Client(
-            api_key=config.api_key,
-        )
+        if client is None:
+            self.client = genai.Client(api_key=config.api_key)
+        else:
+            self.client = client

     async def create(
         self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
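Both construction paths from this diff, as a hedged usage sketch (import path assumed as above):

    import os

    from google import genai

    from graphiti_core.embedder.gemini import GeminiEmbedder, GeminiEmbedderConfig

    # Default path: the embedder builds its own genai.Client from config.api_key.
    embedder = GeminiEmbedder(
        config=GeminiEmbedderConfig(api_key=os.environ.get('GOOGLE_GENAI_API_KEY')),
    )

    # Injected path: reuse an existing client, e.g. one with custom http_options.
    embedder = GeminiEmbedder(
        client=genai.Client(api_key=os.environ.get('GOOGLE_GENAI_API_KEY')),
    )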

graphiti_core/llm_client/gemini_client.py

@@ -76,6 +76,7 @@ class GeminiClient(LLMClient):
         cache: bool = False,
         max_tokens: int = DEFAULT_MAX_TOKENS,
         thinking_config: types.ThinkingConfig | None = None,
+        client: 'genai.Client | None' = None,
     ):
         """
         Initialize the GeminiClient with the provided configuration, cache setting, and optional thinking config.
@@ -85,7 +86,7 @@ class GeminiClient(LLMClient):
             cache (bool): Whether to use caching for responses. Defaults to False.
             thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
                 Only use with models that support thinking (gemini-2.5+). Defaults to None.
+            client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
         """
         if config is None:
             config = LLMConfig()
@@ -93,10 +94,12 @@ class GeminiClient(LLMClient):
         super().__init__(config, cache)
         self.model = config.model
-        # Configure the Gemini API
-        self.client = genai.Client(
-            api_key=config.api_key,
-        )
+        if client is None:
+            self.client = genai.Client(api_key=config.api_key)
+        else:
+            self.client = client
         self.max_tokens = max_tokens
         self.thinking_config = thinking_config
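The LLM client gets the same injection point next to its existing options; a sketch under the same import-path assumptions (the thinking_budget value here is illustrative):

    import os

    from google import genai
    from google.genai import types

    from graphiti_core.llm_client.config import LLMConfig
    from graphiti_core.llm_client.gemini_client import GeminiClient

    llm = GeminiClient(
        config=LLMConfig(model='gemini-2.5-flash-lite-preview-06-17'),
        # Per the docstring above: only for models that support thinking (gemini-2.5+).
        thinking_config=types.ThinkingConfig(thinking_budget=1024),
        client=genai.Client(api_key=os.environ.get('GOOGLE_GENAI_API_KEY')),
    )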