feat(gemini): simplify config for Gemini clients (#679)
The cross_encoder for Gemini already supported passing in a custom client.
I replicated the same input pattern to embedder and llm_client.
The benefit: you can point at custom API endpoints and pass other client options, for example:

cross_encoder=GeminiRerankerClient(
    client=genai.Client(
        api_key=os.environ.get('GOOGLE_GENAI_API_KEY'),
        http_options=types.HttpOptions(api_version='v1alpha'),
    ),
    config=LLMConfig(
        model='gemini-2.5-flash-lite-preview-06-17',
    ),
)
This commit is contained in:
parent
3ecedf8ebf
commit
432ff7577d
3 changed files with 27 additions and 11 deletions
|
|
@@ -41,6 +41,9 @@ DEFAULT_MODEL = 'gemini-2.5-flash-lite-preview-06-17'
|
||||||
|
|
||||||
|
|
||||||
class GeminiRerankerClient(CrossEncoderClient):
|
class GeminiRerankerClient(CrossEncoderClient):
|
||||||
|
"""
|
||||||
|
Google Gemini Reranker Client
|
||||||
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: LLMConfig | None = None,
|
config: LLMConfig | None = None,
|
||||||
|
|
@@ -57,7 +60,6 @@ class GeminiRerankerClient(CrossEncoderClient):
|
||||||
config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
|
config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
|
||||||
client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
|
client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
config = LLMConfig()
|
config = LLMConfig()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@@ -46,16 +46,27 @@ class GeminiEmbedder(EmbedderClient):
     """
     Google Gemini Embedder Client
     """

-    def __init__(self, config: GeminiEmbedderConfig | None = None):
+    def __init__(
+        self,
+        config: GeminiEmbedderConfig | None = None,
+        client: 'genai.Client | None' = None,
+    ):
+        """
+        Initialize the GeminiEmbedder with the provided configuration and client.
+
+        Args:
+            config (GeminiEmbedderConfig | None): The configuration for the GeminiEmbedder, including API key, model, base URL, temperature, and max tokens.
+            client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
+        """
         if config is None:
             config = GeminiEmbedderConfig()

         self.config = config

-        # Configure the Gemini API
-        self.client = genai.Client(
-            api_key=config.api_key,
-        )
+        if client is None:
+            self.client = genai.Client(api_key=config.api_key)
+        else:
+            self.client = client

     async def create(
         self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
|
|
@@ -76,6 +76,7 @@ class GeminiClient(LLMClient):
         cache: bool = False,
         max_tokens: int = DEFAULT_MAX_TOKENS,
         thinking_config: types.ThinkingConfig | None = None,
+        client: 'genai.Client | None' = None,
     ):
         """
         Initialize the GeminiClient with the provided configuration, cache setting, and optional thinking config.
|
|
@@ -85,7 +86,7 @@ class GeminiClient(LLMClient):
         cache (bool): Whether to use caching for responses. Defaults to False.
         thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
             Only use with models that support thinking (gemini-2.5+). Defaults to None.
+        client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
     """
     if config is None:
         config = LLMConfig()
|
|
@@ -93,10 +94,12 @@ class GeminiClient(LLMClient):
         super().__init__(config, cache)

         self.model = config.model
-        # Configure the Gemini API
-        self.client = genai.Client(
-            api_key=config.api_key,
-        )
+
+        if client is None:
+            self.client = genai.Client(api_key=config.api_key)
+        else:
+            self.client = client

         self.max_tokens = max_tokens
         self.thinking_config = thinking_config
|
|
||||||
Loading…
Add table
Reference in a new issue