diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index a16ba830..01ccbbdf 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -89,6 +89,7 @@ class LLMConfigCache:
# Initialize configurations based on binding conditions
self.openai_llm_options = None
self.gemini_llm_options = None
+ self.gemini_embedding_options = None
self.ollama_llm_options = None
self.ollama_embedding_options = None
@@ -135,6 +136,23 @@ class LLMConfigCache:
)
self.ollama_embedding_options = {}
+ # Only initialize and log Gemini Embedding options when using Gemini Embedding binding
+ if args.embedding_binding == "gemini":
+ try:
+ from lightrag.llm.binding_options import GeminiEmbeddingOptions
+
+ self.gemini_embedding_options = GeminiEmbeddingOptions.options_dict(
+ args
+ )
+ logger.info(
+ f"Gemini Embedding Options: {self.gemini_embedding_options}"
+ )
+ except ImportError:
+ logger.warning(
+ "GeminiEmbeddingOptions not available, using default configuration"
+ )
+ self.gemini_embedding_options = {}
+
def check_frontend_build():
"""Check if frontend is built and optionally check if source is up-to-date
@@ -296,6 +314,7 @@ def create_app(args):
"azure_openai",
"aws_bedrock",
"jina",
+ "gemini",
]:
raise Exception("embedding binding not supported")
@@ -646,6 +665,26 @@ def create_app(args):
return await jina_embed(
texts, embedding_dim=embedding_dim, base_url=host, api_key=api_key
)
+ elif binding == "gemini":
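+            # Illustrative .env for this branch (values are examples only, not
+            # defaults; the GEMINI_EMBEDDING_* name follows the BindingOptions
+            # env-var pattern and should be checked against --help):
+            #   EMBEDDING_BINDING=gemini
+            #   EMBEDDING_MODEL=gemini-embedding-001
+            #   EMBEDDING_DIM=1536
+            #   GEMINI_EMBEDDING_TASK_TYPE=RETRIEVAL_DOCUMENT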
+ from lightrag.llm.gemini import gemini_embed
+
+            # Use pre-processed configuration if available, otherwise fall back to dynamic parsing
+ if config_cache.gemini_embedding_options is not None:
+ gemini_options = config_cache.gemini_embedding_options
+ else:
+ # Fallback for cases where config cache wasn't initialized properly
+ from lightrag.llm.binding_options import GeminiEmbeddingOptions
+
+ gemini_options = GeminiEmbeddingOptions.options_dict(args)
+
+ return await gemini_embed(
+ texts,
+ model=model,
+ base_url=host,
+ api_key=api_key,
+ embedding_dim=embedding_dim,
+ task_type=gemini_options.get("task_type", "RETRIEVAL_DOCUMENT"),
+ )
else: # openai and compatible
from lightrag.llm.openai import openai_embed
diff --git a/lightrag/llm/binding_options.py b/lightrag/llm/binding_options.py
index f17ba0f8..1cb52a81 100644
--- a/lightrag/llm/binding_options.py
+++ b/lightrag/llm/binding_options.py
@@ -9,12 +9,26 @@ from argparse import ArgumentParser, Namespace
import argparse
import json
from dataclasses import asdict, dataclass, field
-from typing import Any, ClassVar, List
+from typing import Any, ClassVar, List, get_args, get_origin
from lightrag.utils import get_env_value
from lightrag.constants import DEFAULT_TEMPERATURE
+def _resolve_optional_type(field_type: Any) -> Any:
+ """Return the concrete type for Optional/Union annotations."""
+ origin = get_origin(field_type)
+ if origin in (list, dict, tuple):
+ return field_type
+
+ args = get_args(field_type)
+ if args:
+ non_none_args = [arg for arg in args if arg is not type(None)]
+ if len(non_none_args) == 1:
+ return non_none_args[0]
+ return field_type
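+
+# Illustrative behavior (a sketch, not part of the module API):
+#   _resolve_optional_type(int | None)  -> int        (unwraps Optional)
+#   _resolve_optional_type(str)         -> str        (returned unchanged)
+#   _resolve_optional_type(List[str])   -> List[str]  (containers pass through)
+# This keeps argparse from receiving a typing.Union object as its ``type``
+# callable, which would fail when the parser tries to call it on a string.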
+
+
# =============================================================================
# BindingOptions Base Class
# =============================================================================
@@ -177,9 +191,13 @@ class BindingOptions:
help=arg_item["help"],
)
else:
+ resolved_type = arg_item["type"]
+ if resolved_type is not None:
+ resolved_type = _resolve_optional_type(resolved_type)
+
group.add_argument(
f"--{arg_item['argname']}",
- type=arg_item["type"],
+ type=resolved_type,
default=get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS),
help=arg_item["help"],
)
@@ -210,7 +228,7 @@ class BindingOptions:
argdef = {
"argname": f"{args_prefix}-{field.name}",
"env_name": f"{env_var_prefix}{field.name.upper()}",
- "type": field.type,
+ "type": _resolve_optional_type(field.type),
"default": default_value,
"help": f"{cls._binding_name} -- " + help.get(field.name, ""),
}
@@ -454,6 +472,55 @@ class OllamaLLMOptions(_OllamaOptionsMixin, BindingOptions):
_binding_name: ClassVar[str] = "ollama_llm"
+# =============================================================================
+# Binding Options for Gemini
+# =============================================================================
+@dataclass
+class GeminiLLMOptions(BindingOptions):
+ """Options for Google Gemini models."""
+
+ _binding_name: ClassVar[str] = "gemini_llm"
+
+ temperature: float = DEFAULT_TEMPERATURE
+ top_p: float = 0.95
+ top_k: int = 40
+ max_output_tokens: int | None = None
+ candidate_count: int = 1
+ presence_penalty: float = 0.0
+ frequency_penalty: float = 0.0
+ stop_sequences: List[str] = field(default_factory=list)
+ seed: int | None = None
+ thinking_config: dict | None = None
+ safety_settings: dict | None = None
+
+ _help: ClassVar[dict[str, str]] = {
+ "temperature": "Controls randomness (0.0-2.0, higher = more creative)",
+ "top_p": "Nucleus sampling parameter (0.0-1.0)",
+ "top_k": "Limits sampling to the top K tokens (1 disables the limit)",
+ "max_output_tokens": "Maximum tokens generated in the response",
+ "candidate_count": "Number of candidates returned per request",
+ "presence_penalty": "Penalty for token presence (-2.0 to 2.0)",
+ "frequency_penalty": "Penalty for token frequency (-2.0 to 2.0)",
+ "stop_sequences": "Stop sequences (JSON array of strings, e.g., '[\"END\"]')",
+ "seed": "Random seed for reproducible generation (leave empty for random)",
+ "thinking_config": "Thinking configuration (JSON dict, e.g., '{\"thinking_budget\": 1024}' or '{\"include_thoughts\": true}')",
+ "safety_settings": "JSON object with Gemini safety settings overrides",
+ }
+
+
+@dataclass
+class GeminiEmbeddingOptions(BindingOptions):
+ """Options for Google Gemini embedding models."""
+
+ _binding_name: ClassVar[str] = "gemini_embedding"
+
+ task_type: str = "RETRIEVAL_DOCUMENT"
+
+ _help: ClassVar[dict[str, str]] = {
+ "task_type": "Task type for embedding optimization (RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING, CODE_RETRIEVAL_QUERY, QUESTION_ANSWERING, FACT_VERIFICATION)",
+ }
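+
+# Per the BindingOptions naming pattern above, these dataclasses should surface
+# as CLI flags and environment variables such as --gemini-llm-temperature /
+# GEMINI_LLM_TEMPERATURE and --gemini-embedding-task_type /
+# GEMINI_EMBEDDING_TASK_TYPE (names inferred from the
+# "{args_prefix}-{field.name}" and "{env_var_prefix}{field.name.upper()}"
+# templates; confirm against the generated --help output).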
+
+
# =============================================================================
# Binding Options for OpenAI
# =============================================================================
diff --git a/lightrag/llm/gemini.py b/lightrag/llm/gemini.py
new file mode 100644
index 00000000..3954e814
--- /dev/null
+++ b/lightrag/llm/gemini.py
@@ -0,0 +1,564 @@
+"""
+Gemini LLM binding for LightRAG.
+
+This module provides asynchronous helpers that adapt Google's Gemini models
+to the same interface used by the rest of the LightRAG LLM bindings. The
+implementation mirrors the OpenAI helpers while relying on the official
+``google-genai`` client under the hood.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+import os
+from collections.abc import AsyncIterator
+from functools import lru_cache
+from typing import Any
+
+import numpy as np
+from tenacity import (
+ retry,
+ stop_after_attempt,
+ wait_exponential,
+ retry_if_exception_type,
+)
+
+from lightrag.utils import (
+ logger,
+ remove_think_tags,
+ safe_unicode_decode,
+ wrap_embedding_func_with_attrs,
+)
+
+import pipmaster as pm
+
+# Install the Google Gemini client on demand
+if not pm.is_installed("google-genai"):
+ pm.install("google-genai")
+
+from google import genai # type: ignore
+from google.genai import types # type: ignore
+
+DEFAULT_GEMINI_ENDPOINT = "https://generativelanguage.googleapis.com"
+
+LOG = logging.getLogger(__name__)
+
+
+@lru_cache(maxsize=8)
+def _get_gemini_client(
+ api_key: str, base_url: str | None, timeout: int | None = None
+) -> genai.Client:
+ """
+ Create (or fetch cached) Gemini client.
+
+ Args:
+ api_key: Google Gemini API key.
+ base_url: Optional custom API endpoint.
+ timeout: Optional request timeout in milliseconds.
+
+ Returns:
+ genai.Client: Configured Gemini client instance.
+ """
+ client_kwargs: dict[str, Any] = {"api_key": api_key}
+
+    if (base_url and base_url != DEFAULT_GEMINI_ENDPOINT) or timeout is not None:
+ try:
+ http_options_kwargs = {}
+ if base_url and base_url != DEFAULT_GEMINI_ENDPOINT:
+ http_options_kwargs["api_endpoint"] = base_url
+ if timeout is not None:
+ http_options_kwargs["timeout"] = timeout
+
+ client_kwargs["http_options"] = types.HttpOptions(**http_options_kwargs)
+ except Exception as exc: # pragma: no cover - defensive
+ LOG.warning("Failed to apply custom Gemini http_options: %s", exc)
+
+ try:
+ return genai.Client(**client_kwargs)
+ except TypeError:
+ # Older google-genai releases don't accept http_options; retry without it.
+ client_kwargs.pop("http_options", None)
+ return genai.Client(**client_kwargs)
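+
+# Note: lru_cache keys on (api_key, base_url, timeout), so repeated calls with
+# the same settings reuse a single client instance instead of re-creating one
+# per request. If credentials are rotated at runtime,
+# _get_gemini_client.cache_clear() discards the cached clients.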
+
+
+def _ensure_api_key(api_key: str | None) -> str:
+ key = api_key or os.getenv("LLM_BINDING_API_KEY") or os.getenv("GEMINI_API_KEY")
+ if not key:
+ raise ValueError(
+ "Gemini API key not provided. "
+ "Set LLM_BINDING_API_KEY or GEMINI_API_KEY in the environment."
+ )
+ return key
+
+
+def _build_generation_config(
+ base_config: dict[str, Any] | None,
+ system_prompt: str | None,
+ keyword_extraction: bool,
+) -> types.GenerateContentConfig | None:
+ config_data = dict(base_config or {})
+
+ if system_prompt:
+ if config_data.get("system_instruction"):
+ config_data["system_instruction"] = (
+ f"{config_data['system_instruction']}\n{system_prompt}"
+ )
+ else:
+ config_data["system_instruction"] = system_prompt
+
+ if keyword_extraction and not config_data.get("response_mime_type"):
+ config_data["response_mime_type"] = "application/json"
+
+    # Drop entries that are None or empty strings so GenerateContentConfig doesn't reject them
+ sanitized = {
+ key: value
+ for key, value in config_data.items()
+ if value is not None and value != ""
+ }
+
+ if not sanitized:
+ return None
+
+ return types.GenerateContentConfig(**sanitized)
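+
+# Sketch of the merge behavior: base_config {"temperature": 0.7,
+# "max_output_tokens": None} plus system_prompt "Be terse" yields roughly
+# types.GenerateContentConfig(temperature=0.7, system_instruction="Be terse");
+# the None entry is dropped by the sanitization step above.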
+
+
+def _format_history_messages(history_messages: list[dict[str, Any]] | None) -> str:
+ if not history_messages:
+ return ""
+
+ history_lines: list[str] = []
+ for message in history_messages:
+ role = message.get("role", "user")
+ content = message.get("content", "")
+ history_lines.append(f"[{role}] {content}")
+
+ return "\n".join(history_lines)
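+
+# For example, [{"role": "user", "content": "hi"}, {"role": "assistant",
+# "content": "hello"}] flattens to "[user] hi\n[assistant] hello"; Gemini then
+# receives the transcript as a single text prompt rather than structured turns.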
+
+
+def _extract_response_text(
+ response: Any, extract_thoughts: bool = False
+) -> tuple[str, str]:
+ """
+ Extract text content from Gemini response, separating regular content from thoughts.
+
+ Args:
+ response: Gemini API response object
+ extract_thoughts: Whether to extract thought content separately
+
+ Returns:
+ Tuple of (regular_text, thought_text)
+ """
+ candidates = getattr(response, "candidates", None)
+ if not candidates:
+ return ("", "")
+
+ regular_parts: list[str] = []
+ thought_parts: list[str] = []
+
+ for candidate in candidates:
+ if not getattr(candidate, "content", None):
+ continue
+ # Use 'or []' to handle None values from parts attribute
+ for part in getattr(candidate.content, "parts", None) or []:
+ text = getattr(part, "text", None)
+ if not text:
+ continue
+
+ # Check if this part is thought content using the 'thought' attribute
+ is_thought = getattr(part, "thought", False)
+
+ if is_thought and extract_thoughts:
+ thought_parts.append(text)
+ elif not is_thought:
+ regular_parts.append(text)
+
+ return ("\n".join(regular_parts), "\n".join(thought_parts))
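+
+# Sketch: for candidate parts [Part(text="step one", thought=True),
+# Part(text="Answer")], extract_thoughts=True returns ("Answer", "step one"),
+# while extract_thoughts=False returns ("Answer", "") and drops the thought.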
+
+
+async def gemini_complete_if_cache(
+ model: str,
+ prompt: str,
+ system_prompt: str | None = None,
+ history_messages: list[dict[str, Any]] | None = None,
+ enable_cot: bool = False,
+ base_url: str | None = None,
+ api_key: str | None = None,
+ token_tracker: Any | None = None,
+ stream: bool | None = None,
+ keyword_extraction: bool = False,
+ generation_config: dict[str, Any] | None = None,
+ timeout: int | None = None,
+ **_: Any,
+) -> str | AsyncIterator[str]:
+ """
+ Complete a prompt using Gemini's API with Chain of Thought (COT) support.
+
+ This function supports automatic integration of reasoning content from Gemini models
+ that provide Chain of Thought capabilities via the thinking_config API feature.
+
+ COT Integration:
+    - When enable_cot=True: Thought content is wrapped in <think>...</think> tags
+ - When enable_cot=False: Thought content is filtered out, only regular content returned
+ - Thought content is identified by the 'thought' attribute on response parts
+ - Requires thinking_config to be enabled in generation_config for API to return thoughts
+
+ Args:
+ model: The Gemini model to use.
+ prompt: The prompt to complete.
+ system_prompt: Optional system prompt to include.
+ history_messages: Optional list of previous messages in the conversation.
+ api_key: Optional Gemini API key. If None, uses environment variable.
+ base_url: Optional custom API endpoint.
+ generation_config: Optional generation configuration dict.
+ keyword_extraction: Whether to use JSON response format.
+ token_tracker: Optional token usage tracker for monitoring API usage.
+ stream: Whether to stream the response.
+        hashing_kv: Storage interface, accepted via **_ for interface parity with other bindings.
+ enable_cot: Whether to include Chain of Thought content in the response.
+ timeout: Request timeout in seconds (will be converted to milliseconds for Gemini API).
+ **_: Additional keyword arguments (ignored).
+
+ Returns:
+ The completed text (with COT content if enable_cot=True) or an async iterator
+    of text chunks if streaming. COT content is wrapped in <think>...</think> tags.
+
+ Raises:
+ RuntimeError: If the response from Gemini is empty.
+ ValueError: If API key is not provided or configured.
+ """
+ loop = asyncio.get_running_loop()
+
+ key = _ensure_api_key(api_key)
+ # Convert timeout from seconds to milliseconds for Gemini API
+ timeout_ms = timeout * 1000 if timeout else None
+ client = _get_gemini_client(key, base_url, timeout_ms)
+
+ history_block = _format_history_messages(history_messages)
+ prompt_sections = []
+ if history_block:
+ prompt_sections.append(history_block)
+ prompt_sections.append(f"[user] {prompt}")
+ combined_prompt = "\n".join(prompt_sections)
+
+ config_obj = _build_generation_config(
+ generation_config,
+ system_prompt=system_prompt,
+ keyword_extraction=keyword_extraction,
+ )
+
+ request_kwargs: dict[str, Any] = {
+ "model": model,
+ "contents": [combined_prompt],
+ }
+ if config_obj is not None:
+ request_kwargs["config"] = config_obj
+
+ def _call_model():
+ return client.models.generate_content(**request_kwargs)
+
+ if stream:
+ queue: asyncio.Queue[Any] = asyncio.Queue()
+ usage_container: dict[str, Any] = {}
+
+ def _stream_model() -> None:
+ # COT state tracking for streaming
+ cot_active = False
+ cot_started = False
+ initial_content_seen = False
+
+ try:
+ stream_kwargs = dict(request_kwargs)
+ stream_iterator = client.models.generate_content_stream(**stream_kwargs)
+ for chunk in stream_iterator:
+ usage = getattr(chunk, "usage_metadata", None)
+ if usage is not None:
+ usage_container["usage"] = usage
+
+ # Extract both regular and thought content
+ regular_text, thought_text = _extract_response_text(
+ chunk, extract_thoughts=True
+ )
+
+ if enable_cot:
+ # Process regular content
+ if regular_text:
+ if not initial_content_seen:
+ initial_content_seen = True
+
+ # Close COT section if it was active
+ if cot_active:
+                                loop.call_soon_threadsafe(queue.put_nowait, "</think>")
+ cot_active = False
+
+ # Send regular content
+ loop.call_soon_threadsafe(queue.put_nowait, regular_text)
+
+ # Process thought content
+ if thought_text:
+ if not initial_content_seen and not cot_started:
+ # Start COT section
+                                loop.call_soon_threadsafe(queue.put_nowait, "<think>")
+ cot_active = True
+ cot_started = True
+
+ # Send thought content if COT is active
+ if cot_active:
+ loop.call_soon_threadsafe(
+ queue.put_nowait, thought_text
+ )
+ else:
+ # COT disabled - only send regular content
+ if regular_text:
+ loop.call_soon_threadsafe(queue.put_nowait, regular_text)
+
+ # Ensure COT is properly closed if still active
+ if cot_active:
+                    loop.call_soon_threadsafe(queue.put_nowait, "</think>")
+
+ loop.call_soon_threadsafe(queue.put_nowait, None)
+ except Exception as exc: # pragma: no cover - surface runtime issues
+ # Try to close COT tag before reporting error
+ if cot_active:
+ try:
+                        loop.call_soon_threadsafe(queue.put_nowait, "</think>")
+ except Exception:
+ pass
+ loop.call_soon_threadsafe(queue.put_nowait, exc)
+
+        # Fire-and-forget: the executor thread pushes chunks onto the queue
+        # that _async_stream() drains below
+        loop.run_in_executor(None, _stream_model)
+
+ async def _async_stream() -> AsyncIterator[str]:
+ try:
+ while True:
+ item = await queue.get()
+ if item is None:
+ break
+ if isinstance(item, Exception):
+ raise item
+
+ chunk_text = str(item)
+ if "\\u" in chunk_text:
+ chunk_text = safe_unicode_decode(chunk_text.encode("utf-8"))
+
+ # Yield the chunk directly without filtering
+ # COT filtering is already handled in _stream_model()
+ yield chunk_text
+ finally:
+ usage = usage_container.get("usage")
+ if token_tracker and usage:
+ token_tracker.add_usage(
+ {
+ "prompt_tokens": getattr(usage, "prompt_token_count", 0),
+ "completion_tokens": getattr(
+ usage, "candidates_token_count", 0
+ ),
+ "total_tokens": getattr(usage, "total_token_count", 0),
+ }
+ )
+
+ return _async_stream()
+
+ response = await asyncio.to_thread(_call_model)
+
+ # Extract both regular text and thought text
+ regular_text, thought_text = _extract_response_text(response, extract_thoughts=True)
+
+ # Apply COT filtering logic based on enable_cot parameter
+ if enable_cot:
+ # Include thought content wrapped in tags
+ if thought_text and thought_text.strip():
+ if not regular_text or regular_text.strip() == "":
+ # Only thought content available
+                final_text = f"<think>{thought_text}</think>"
+ else:
+ # Both content types present: prepend thought to regular content
+                final_text = f"<think>{thought_text}</think>{regular_text}"
+ else:
+ # No thought content, use regular content only
+ final_text = regular_text or ""
+ else:
+ # Filter out thought content, return only regular content
+ final_text = regular_text or ""
+
+ if not final_text:
+ raise RuntimeError("Gemini response did not contain any text content.")
+
+ if "\\u" in final_text:
+ final_text = safe_unicode_decode(final_text.encode("utf-8"))
+
+    if not enable_cot:
+        # Strip model-emitted <think> blocks only when COT output is disabled,
+        # so the tags added above for enable_cot=True survive
+        final_text = remove_think_tags(final_text)
+
+ usage = getattr(response, "usage_metadata", None)
+ if token_tracker and usage:
+ token_tracker.add_usage(
+ {
+ "prompt_tokens": getattr(usage, "prompt_token_count", 0),
+ "completion_tokens": getattr(usage, "candidates_token_count", 0),
+ "total_tokens": getattr(usage, "total_token_count", 0),
+ }
+ )
+
+ logger.debug("Gemini response length: %s", len(final_text))
+ return final_text
+
+
+async def gemini_model_complete(
+ prompt: str,
+ system_prompt: str | None = None,
+ history_messages: list[dict[str, Any]] | None = None,
+ keyword_extraction: bool = False,
+ **kwargs: Any,
+) -> str | AsyncIterator[str]:
+ hashing_kv = kwargs.get("hashing_kv")
+ model_name = None
+ if hashing_kv is not None:
+ model_name = hashing_kv.global_config.get("llm_model_name")
+ if model_name is None:
+ model_name = kwargs.pop("model_name", None)
+ if model_name is None:
+ raise ValueError("Gemini model name not provided in configuration.")
+
+ return await gemini_complete_if_cache(
+ model_name,
+ prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ keyword_extraction=keyword_extraction,
+ **kwargs,
+ )
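+
+# Minimal usage sketch (assumes GEMINI_API_KEY is set; the model name is
+# illustrative, not a default of this module):
+#
+#   answer = await gemini_complete_if_cache(
+#       "gemini-2.0-flash",
+#       "Summarize LightRAG in one sentence.",
+#       system_prompt="You are concise.",
+#       generation_config={"temperature": 0.2},
+#   )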
+
+
+@wrap_embedding_func_with_attrs(embedding_dim=1536)
+@retry(
+ stop=stop_after_attempt(3),
+ wait=wait_exponential(multiplier=1, min=4, max=60),
+ retry=(
+        retry_if_exception_type(Exception)  # google-genai raises generic exceptions, so retry broadly
+ ),
+)
+async def gemini_embed(
+ texts: list[str],
+ model: str = "gemini-embedding-001",
+ base_url: str | None = None,
+ api_key: str | None = None,
+ embedding_dim: int | None = None,
+ task_type: str = "RETRIEVAL_DOCUMENT",
+ timeout: int | None = None,
+ token_tracker: Any | None = None,
+) -> np.ndarray:
+ """Generate embeddings for a list of texts using Gemini's API.
+
+ This function uses Google's Gemini embedding model to generate text embeddings.
+ It supports dynamic dimension control and automatic normalization for dimensions
+ less than 3072.
+
+ Args:
+ texts: List of texts to embed.
+ model: The Gemini embedding model to use. Default is "gemini-embedding-001".
+ base_url: Optional custom API endpoint.
+ api_key: Optional Gemini API key. If None, uses environment variables.
+ embedding_dim: Optional embedding dimension for dynamic dimension reduction.
+ **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
+ Do NOT manually pass this parameter when calling the function directly.
+ The dimension is controlled by the @wrap_embedding_func_with_attrs decorator
+ or the EMBEDDING_DIM environment variable.
+ Supported range: 128-3072. Recommended values: 768, 1536, 3072.
+ task_type: Task type for embedding optimization. Default is "RETRIEVAL_DOCUMENT".
+ Supported types: SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING,
+ RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, CODE_RETRIEVAL_QUERY,
+ QUESTION_ANSWERING, FACT_VERIFICATION.
+ timeout: Request timeout in seconds (will be converted to milliseconds for Gemini API).
+ token_tracker: Optional token usage tracker for monitoring API usage.
+
+ Returns:
+ A numpy array of embeddings, one per input text. For dimensions < 3072,
+ the embeddings are L2-normalized to ensure optimal semantic similarity performance.
+
+ Raises:
+ ValueError: If API key is not provided or configured.
+ RuntimeError: If the response from Gemini is invalid or empty.
+
+ Note:
+ - For dimension 3072: Embeddings are already normalized by the API
+ - For dimensions < 3072: Embeddings are L2-normalized after retrieval
+ - Normalization ensures accurate semantic similarity via cosine distance
+ """
+ loop = asyncio.get_running_loop()
+
+ key = _ensure_api_key(api_key)
+ # Convert timeout from seconds to milliseconds for Gemini API
+ timeout_ms = timeout * 1000 if timeout else None
+ client = _get_gemini_client(key, base_url, timeout_ms)
+
+ # Prepare embedding configuration
+ config_kwargs: dict[str, Any] = {}
+
+ # Add task_type to config
+ if task_type:
+ config_kwargs["task_type"] = task_type
+
+ # Add output_dimensionality if embedding_dim is provided
+ if embedding_dim is not None:
+ config_kwargs["output_dimensionality"] = embedding_dim
+
+ # Create config object if we have parameters
+ config_obj = types.EmbedContentConfig(**config_kwargs) if config_kwargs else None
+
+ def _call_embed() -> Any:
+ """Call Gemini embedding API in executor thread."""
+ request_kwargs: dict[str, Any] = {
+ "model": model,
+ "contents": texts,
+ }
+ if config_obj is not None:
+ request_kwargs["config"] = config_obj
+
+ return client.models.embed_content(**request_kwargs)
+
+ # Execute API call in thread pool
+ response = await loop.run_in_executor(None, _call_embed)
+
+ # Extract embeddings from response
+ if not hasattr(response, "embeddings") or not response.embeddings:
+ raise RuntimeError("Gemini response did not contain embeddings.")
+
+ # Convert embeddings to numpy array
+ embeddings = np.array(
+ [np.array(e.values, dtype=np.float32) for e in response.embeddings]
+ )
+
+ # Apply L2 normalization for dimensions < 3072
+ # The 3072 dimension embedding is already normalized by Gemini API
+ if embedding_dim and embedding_dim < 3072:
+ # Normalize each embedding vector to unit length
+ norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
+ # Avoid division by zero
+ norms = np.where(norms == 0, 1, norms)
+ embeddings = embeddings / norms
+ logger.debug(
+ f"Applied L2 normalization to {len(embeddings)} embeddings of dimension {embedding_dim}"
+ )
+
+ # Track token usage if tracker is provided
+ # Note: Gemini embedding API may not provide usage metadata
+ if token_tracker and hasattr(response, "usage_metadata"):
+ usage = response.usage_metadata
+ token_counts = {
+ "prompt_tokens": getattr(usage, "prompt_token_count", 0),
+ "total_tokens": getattr(usage, "total_token_count", 0),
+ }
+ token_tracker.add_usage(token_counts)
+
+ logger.debug(
+ f"Generated {len(embeddings)} Gemini embeddings with dimension {embeddings.shape[1]}"
+ )
+
+ return embeddings
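+
+# Usage sketch (the EmbeddingFunc wrapper normally injects embedding_dim, so a
+# direct call can omit it; model and task_type shown here are illustrative):
+#
+#   vectors = await gemini_embed(
+#       ["hello world"],
+#       model="gemini-embedding-001",
+#       task_type="RETRIEVAL_QUERY",
+#   )
+#   vectors.shape  # -> (1, 3072): the model's full dimension when none is requested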
+
+
+__all__ = [
+ "gemini_complete_if_cache",
+ "gemini_model_complete",
+ "gemini_embed",
+]