From de4ed7365203de7a756493b82b609b4c193e8b52 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sat, 8 Nov 2025 03:34:30 +0800
Subject: [PATCH 1/2] Add Gemini embedding support

- Implement gemini_embed function
- Add gemini to embedding binding choices
- Add L2 normalization for dims < 3072
---
 lightrag/api/config.py          |  26 ++++--
 lightrag/api/lightrag_server.py |  39 +++++++++
 lightrag/llm/binding_options.py |  13 +++
 lightrag/llm/gemini.py          | 144 +++++++++++++++++++++++++++++++-
 4 files changed, 216 insertions(+), 6 deletions(-)

diff --git a/lightrag/api/config.py b/lightrag/api/config.py
index baaf9c52..1f46d147 100644
--- a/lightrag/api/config.py
+++ b/lightrag/api/config.py
@@ -8,6 +8,7 @@ import logging
 from dotenv import load_dotenv
 from lightrag.utils import get_env_value
 from lightrag.llm.binding_options import (
+    GeminiEmbeddingOptions,
     GeminiLLMOptions,
     OllamaEmbeddingOptions,
     OllamaLLMOptions,
@@ -238,7 +239,15 @@ def parse_args() -> argparse.Namespace:
         "--embedding-binding",
         type=str,
         default=get_env_value("EMBEDDING_BINDING", "ollama"),
-        choices=["lollms", "ollama", "openai", "azure_openai", "aws_bedrock", "jina"],
+        choices=[
+            "lollms",
+            "ollama",
+            "openai",
+            "azure_openai",
+            "aws_bedrock",
+            "jina",
+            "gemini",
+        ],
         help="Embedding binding type (default: from env or ollama)",
     )
     parser.add_argument(
@@ -265,12 +274,19 @@
     if "--embedding-binding" in sys.argv:
         try:
             idx = sys.argv.index("--embedding-binding")
-            if idx + 1 < len(sys.argv) and sys.argv[idx + 1] == "ollama":
-                OllamaEmbeddingOptions.add_args(parser)
+            if idx + 1 < len(sys.argv):
+                if sys.argv[idx + 1] == "ollama":
+                    OllamaEmbeddingOptions.add_args(parser)
+                elif sys.argv[idx + 1] == "gemini":
+                    GeminiEmbeddingOptions.add_args(parser)
         except IndexError:
             pass
-    elif os.environ.get("EMBEDDING_BINDING") == "ollama":
-        OllamaEmbeddingOptions.add_args(parser)
+    else:
+        env_embedding_binding = os.environ.get("EMBEDDING_BINDING")
+        if env_embedding_binding == "ollama":
+            OllamaEmbeddingOptions.add_args(parser)
+        elif env_embedding_binding == "gemini":
+            GeminiEmbeddingOptions.add_args(parser)

     # Add OpenAI LLM options when llm-binding is openai or azure_openai
     if "--llm-binding" in sys.argv:
diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index b3a439e8..7a291018 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -89,6 +89,7 @@ class LLMConfigCache:
         # Initialize configurations based on binding conditions
         self.openai_llm_options = None
         self.gemini_llm_options = None
+        self.gemini_embedding_options = None
         self.ollama_llm_options = None
         self.ollama_embedding_options = None

@@ -135,6 +136,23 @@
                 )
                 self.ollama_embedding_options = {}

+        # Only initialize and log Gemini Embedding options when using Gemini Embedding binding
+        if args.embedding_binding == "gemini":
+            try:
+                from lightrag.llm.binding_options import GeminiEmbeddingOptions
+
+                self.gemini_embedding_options = GeminiEmbeddingOptions.options_dict(
+                    args
+                )
+                logger.info(
+                    f"Gemini Embedding Options: {self.gemini_embedding_options}"
+                )
+            except ImportError:
+                logger.warning(
+                    "GeminiEmbeddingOptions not available, using default configuration"
+                )
+                self.gemini_embedding_options = {}
+

 def check_frontend_build():
     """Check if frontend is built and optionally check if source is up-to-date
@@ -296,6 +314,7 @@ def create_app(args):
         "azure_openai",
         "aws_bedrock",
         "jina",
+        "gemini",
     ]:
         raise Exception("embedding binding not supported")

@@ -649,6 +668,26 @@ def create_app(args):
                     base_url=host,
                     api_key=api_key,
                 )
+            elif binding == "gemini":
+                from lightrag.llm.gemini import gemini_embed
+
+                # Use pre-processed configuration if available, otherwise fallback to dynamic parsing
+                if config_cache.gemini_embedding_options is not None:
+                    gemini_options = config_cache.gemini_embedding_options
+                else:
+                    # Fallback for cases where config cache wasn't initialized properly
+                    from lightrag.llm.binding_options import GeminiEmbeddingOptions
+
+                    gemini_options = GeminiEmbeddingOptions.options_dict(args)
+
+                return await gemini_embed(
+                    texts,
+                    model=model,
+                    base_url=host,
+                    api_key=api_key,
+                    embedding_dim=embedding_dim,
+                    task_type=gemini_options.get("task_type", "RETRIEVAL_DOCUMENT"),
+                )
             else:  # openai and compatible
                 from lightrag.llm.openai import openai_embed
diff --git a/lightrag/llm/binding_options.py b/lightrag/llm/binding_options.py
index 8f69711a..1cb52a81 100644
--- a/lightrag/llm/binding_options.py
+++ b/lightrag/llm/binding_options.py
@@ -508,6 +508,19 @@ class GeminiLLMOptions(BindingOptions):
     }


+@dataclass
+class GeminiEmbeddingOptions(BindingOptions):
+    """Options for Google Gemini embedding models."""
+
+    _binding_name: ClassVar[str] = "gemini_embedding"
+
+    task_type: str = "RETRIEVAL_DOCUMENT"
+
+    _help: ClassVar[dict[str, str]] = {
+        "task_type": "Task type for embedding optimization (RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING, CODE_RETRIEVAL_QUERY, QUESTION_ANSWERING, FACT_VERIFICATION)",
+    }
+
+
 # =============================================================================
 # Binding Options for OpenAI
 # =============================================================================
diff --git a/lightrag/llm/gemini.py b/lightrag/llm/gemini.py
index f06ec6b3..3954e814 100644
--- a/lightrag/llm/gemini.py
+++ b/lightrag/llm/gemini.py
@@ -16,7 +16,20 @@ from collections.abc import AsyncIterator
 from functools import lru_cache
 from typing import Any

-from lightrag.utils import logger, remove_think_tags, safe_unicode_decode
+import numpy as np
+from tenacity import (
+    retry,
+    stop_after_attempt,
+    wait_exponential,
+    retry_if_exception_type,
+)
+
+from lightrag.utils import (
+    logger,
+    remove_think_tags,
+    safe_unicode_decode,
+    wrap_embedding_func_with_attrs,
+)

 import pipmaster as pm
@@ -416,7 +429,136 @@ async def gemini_model_complete(
     )


+@wrap_embedding_func_with_attrs(embedding_dim=1536)
+@retry(
+    stop=stop_after_attempt(3),
+    wait=wait_exponential(multiplier=1, min=4, max=60),
+    retry=(
+        retry_if_exception_type(Exception)  # Gemini uses generic exceptions
+    ),
+)
+async def gemini_embed(
+    texts: list[str],
+    model: str = "gemini-embedding-001",
+    base_url: str | None = None,
+    api_key: str | None = None,
+    embedding_dim: int | None = None,
+    task_type: str = "RETRIEVAL_DOCUMENT",
+    timeout: int | None = None,
+    token_tracker: Any | None = None,
+) -> np.ndarray:
+    """Generate embeddings for a list of texts using Gemini's API.
+
+    This function uses Google's Gemini embedding model to generate text embeddings.
+    It supports dynamic dimension control and automatic normalization for dimensions
+    less than 3072.
+
+    Args:
+        texts: List of texts to embed.
+        model: The Gemini embedding model to use. Default is "gemini-embedding-001".
+        base_url: Optional custom API endpoint.
+        api_key: Optional Gemini API key. If None, uses environment variables.
+        embedding_dim: Optional embedding dimension for dynamic dimension reduction.
+            **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
+            Do NOT manually pass this parameter when calling the function directly.
+            The dimension is controlled by the @wrap_embedding_func_with_attrs decorator
+            or the EMBEDDING_DIM environment variable.
+            Supported range: 128-3072. Recommended values: 768, 1536, 3072.
+        task_type: Task type for embedding optimization. Default is "RETRIEVAL_DOCUMENT".
+            Supported types: SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING,
+            RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, CODE_RETRIEVAL_QUERY,
+            QUESTION_ANSWERING, FACT_VERIFICATION.
+        timeout: Request timeout in seconds (will be converted to milliseconds for Gemini API).
+        token_tracker: Optional token usage tracker for monitoring API usage.
+
+    Returns:
+        A numpy array of embeddings, one per input text. For dimensions < 3072,
+        the embeddings are L2-normalized to ensure optimal semantic similarity performance.
+
+    Raises:
+        ValueError: If API key is not provided or configured.
+        RuntimeError: If the response from Gemini is invalid or empty.
+
+    Note:
+        - For dimension 3072: Embeddings are already normalized by the API
+        - For dimensions < 3072: Embeddings are L2-normalized after retrieval
+        - Normalization ensures accurate semantic similarity via cosine distance
+    """
+    loop = asyncio.get_running_loop()
+
+    key = _ensure_api_key(api_key)
+    # Convert timeout from seconds to milliseconds for Gemini API
+    timeout_ms = timeout * 1000 if timeout else None
+    client = _get_gemini_client(key, base_url, timeout_ms)
+
+    # Prepare embedding configuration
+    config_kwargs: dict[str, Any] = {}
+
+    # Add task_type to config
+    if task_type:
+        config_kwargs["task_type"] = task_type
+
+    # Add output_dimensionality if embedding_dim is provided
+    if embedding_dim is not None:
+        config_kwargs["output_dimensionality"] = embedding_dim
+
+    # Create config object if we have parameters
+    config_obj = types.EmbedContentConfig(**config_kwargs) if config_kwargs else None
+
+    def _call_embed() -> Any:
+        """Call Gemini embedding API in executor thread."""
+        request_kwargs: dict[str, Any] = {
+            "model": model,
+            "contents": texts,
+        }
+        if config_obj is not None:
+            request_kwargs["config"] = config_obj
+
+        return client.models.embed_content(**request_kwargs)
+
+    # Execute API call in thread pool
+    response = await loop.run_in_executor(None, _call_embed)
+
+    # Extract embeddings from response
+    if not hasattr(response, "embeddings") or not response.embeddings:
+        raise RuntimeError("Gemini response did not contain embeddings.")
+
+    # Convert embeddings to numpy array
+    embeddings = np.array(
+        [np.array(e.values, dtype=np.float32) for e in response.embeddings]
+    )
+
+    # Apply L2 normalization for dimensions < 3072
+    # The 3072 dimension embedding is already normalized by Gemini API
+    if embedding_dim and embedding_dim < 3072:
+        # Normalize each embedding vector to unit length
+        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
+        # Avoid division by zero
+        norms = np.where(norms == 0, 1, norms)
+        embeddings = embeddings / norms
+        logger.debug(
+            f"Applied L2 normalization to {len(embeddings)} embeddings of dimension {embedding_dim}"
+        )
+
+    # Track token usage if tracker is provided
+    # Note: Gemini embedding API may not provide usage metadata
+    if token_tracker and hasattr(response, "usage_metadata"):
+        usage = response.usage_metadata
+        token_counts = {
+            "prompt_tokens": getattr(usage, "prompt_token_count", 0),
+            "total_tokens": getattr(usage, "total_token_count", 0),
+        }
+        token_tracker.add_usage(token_counts)
+
+    logger.debug(
+        f"Generated {len(embeddings)} Gemini embeddings with dimension {embeddings.shape[1]}"
+    )
+
+    return embeddings
+
+
 __all__ = [
     "gemini_complete_if_cache",
     "gemini_model_complete",
+    "gemini_embed",
 ]

From a624a9508af9a2af164b11ceec397d2f64f4e0b9 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sat, 8 Nov 2025 03:54:50 +0800
Subject: [PATCH 2/2] Add Gemini to APIs requiring embedding dimension
 parameter

---
 lightrag/api/lightrag_server.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 7a291018..ded70d67 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -757,12 +757,12 @@ def create_app(args):
         has_embedding_dim_param = "embedding_dim" in sig.parameters

         # Determine send_dimensions value based on binding type
-        # Jina REQUIRES dimension parameter (forced to True)
+        # Jina and Gemini REQUIRE dimension parameter (forced to True)
         # OpenAI and others: controlled by EMBEDDING_SEND_DIM environment variable
-        if args.embedding_binding == "jina":
-            # Jina API requires dimension parameter - always send it
+        if args.embedding_binding in ["jina", "gemini"]:
+            # Jina and Gemini APIs require dimension parameter - always send it
             send_dimensions = has_embedding_dim_param
-            dimension_control = "forced by Jina API"
+            dimension_control = f"forced by {args.embedding_binding.title()} API"
         else:
             # For OpenAI and other bindings, respect EMBEDDING_SEND_DIM setting
             send_dimensions = embedding_send_dim and has_embedding_dim_param
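
Usage note: a minimal direct-call sketch of the new gemini_embed function
(illustration only, not part of the patches). It assumes the google-genai
dependency is installed and that a Gemini API key is exported in the
environment, taking the api_key=None path in _ensure_api_key. embedding_dim
is deliberately omitted, since the docstring reserves that parameter for the
EmbeddingFunc wrapper and the EMBEDDING_DIM environment variable.

    import asyncio

    import numpy as np

    from lightrag.llm.gemini import gemini_embed

    async def main() -> None:
        # task_type tunes the embedding for the retrieval role of the text
        vecs = await gemini_embed(
            ["What is LightRAG?", "LightRAG is a RAG framework."],
            model="gemini-embedding-001",
            task_type="RETRIEVAL_QUERY",
        )
        # Full 3072-dim output is already normalized by the API; reduced
        # dimensions are L2-normalized inside gemini_embed, so either way
        # rows are unit length and cosine similarity is a plain dot product.
        print(vecs.shape)
        print(np.linalg.norm(vecs, axis=1))  # ~[1.0 1.0]
        print(float(vecs[0] @ vecs[1]))      # cosine similarity

    asyncio.run(main())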
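Normalization note: the L2 normalization added in PATCH 1 is a plain numpy
idiom; a self-contained toy run with made-up vectors shows the effect and why
the zero-norm guard matters.

    import numpy as np

    # Toy stand-ins for API output: one ordinary vector, one all-zero vector
    embeddings = np.array([[3.0, 4.0], [0.0, 0.0]], dtype=np.float32)

    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)  # [[5.], [0.]]
    norms = np.where(norms == 0, 1, norms)  # avoid division by zero
    unit = embeddings / norms

    print(unit)                          # [[0.6 0.8] [0.  0. ]]
    print(np.linalg.norm(unit, axis=1))  # [1. 0.]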