From 3aa214b005f7dc94172b2fbf4edc9e2dca6a90fb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rapha=C3=ABl=20MANSUY?=
Date: Thu, 4 Dec 2025 19:14:30 +0800
Subject: [PATCH] cherry-pick 3d9de5ed

---
 lightrag/llm/gemini.py       | 179 ++++++++++++++++++++++++++++++++++-
 pyproject.toml               |   3 +
 requirements-offline-llm.txt |   2 +
 3 files changed, 181 insertions(+), 3 deletions(-)

diff --git a/lightrag/llm/gemini.py b/lightrag/llm/gemini.py
index f06ec6b3..983d6b9f 100644
--- a/lightrag/llm/gemini.py
+++ b/lightrag/llm/gemini.py
@@ -16,22 +16,44 @@ from collections.abc import AsyncIterator
 from functools import lru_cache
 from typing import Any

-from lightrag.utils import logger, remove_think_tags, safe_unicode_decode
+import numpy as np
+from tenacity import (
+    retry,
+    stop_after_attempt,
+    wait_exponential,
+    retry_if_exception_type,
+)
+
+from lightrag.utils import (
+    logger,
+    remove_think_tags,
+    safe_unicode_decode,
+    wrap_embedding_func_with_attrs,
+)

 import pipmaster as pm

-# Install the Google Gemini client on demand
+# Install the Google Gemini client and its dependencies on demand
 if not pm.is_installed("google-genai"):
     pm.install("google-genai")
+if not pm.is_installed("google-api-core"):
+    pm.install("google-api-core")

 from google import genai  # type: ignore
 from google.genai import types  # type: ignore
+from google.api_core import exceptions as google_api_exceptions  # type: ignore

 DEFAULT_GEMINI_ENDPOINT = "https://generativelanguage.googleapis.com"

 LOG = logging.getLogger(__name__)


+class InvalidResponseError(Exception):
+    """Custom exception for triggering the retry mechanism when Gemini returns an empty response."""
+
+    pass
+
+
 @lru_cache(maxsize=8)
 def _get_gemini_client(
     api_key: str, base_url: str | None, timeout: int | None = None
@@ -163,6 +185,21 @@ def _extract_response_text(
     return ("\n".join(regular_parts), "\n".join(thought_parts))


+@retry(
+    stop=stop_after_attempt(3),
+    wait=wait_exponential(multiplier=1, min=4, max=60),
+    retry=(
+        retry_if_exception_type(google_api_exceptions.InternalServerError)
+        | retry_if_exception_type(google_api_exceptions.ServiceUnavailable)
+        | retry_if_exception_type(google_api_exceptions.ResourceExhausted)
+        | retry_if_exception_type(google_api_exceptions.GatewayTimeout)
+        | retry_if_exception_type(google_api_exceptions.BadGateway)
+        | retry_if_exception_type(google_api_exceptions.DeadlineExceeded)
+        | retry_if_exception_type(google_api_exceptions.Aborted)
+        | retry_if_exception_type(google_api_exceptions.Unknown)
+        | retry_if_exception_type(InvalidResponseError)
+    ),
+)
 async def gemini_complete_if_cache(
     model: str,
     prompt: str,
@@ -369,7 +406,7 @@ async def gemini_complete_if_cache(
     final_text = regular_text or ""

     if not final_text:
-        raise RuntimeError("Gemini response did not contain any text content.")
+        raise InvalidResponseError("Gemini response did not contain any text content.")

     if "\\u" in final_text:
         final_text = safe_unicode_decode(final_text.encode("utf-8"))
@@ -416,7 +453,143 @@ async def gemini_model_complete(
     )


+@wrap_embedding_func_with_attrs(embedding_dim=1536)
+@retry(
+    stop=stop_after_attempt(3),
+    wait=wait_exponential(multiplier=1, min=4, max=60),
+    retry=(
+        retry_if_exception_type(google_api_exceptions.InternalServerError)
+        | retry_if_exception_type(google_api_exceptions.ServiceUnavailable)
+        | retry_if_exception_type(google_api_exceptions.ResourceExhausted)
+        | retry_if_exception_type(google_api_exceptions.GatewayTimeout)
+        | retry_if_exception_type(google_api_exceptions.BadGateway)
+        | retry_if_exception_type(google_api_exceptions.DeadlineExceeded)
+        | retry_if_exception_type(google_api_exceptions.Aborted)
+        | retry_if_exception_type(google_api_exceptions.Unknown)
+    ),
+)
+async def gemini_embed(
+    texts: list[str],
+    model: str = "gemini-embedding-001",
+    base_url: str | None = None,
+    api_key: str | None = None,
+    embedding_dim: int | None = None,
+    task_type: str = "RETRIEVAL_DOCUMENT",
+    timeout: int | None = None,
+    token_tracker: Any | None = None,
+) -> np.ndarray:
+    """Generate embeddings for a list of texts using Gemini's API.
+
+    This function uses Google's Gemini embedding model to generate text embeddings.
+    It supports dynamic dimension control and automatic normalization for dimensions
+    less than 3072.
+
+    Args:
+        texts: List of texts to embed.
+        model: The Gemini embedding model to use. Default is "gemini-embedding-001".
+        base_url: Optional custom API endpoint.
+        api_key: Optional Gemini API key. If None, uses environment variables.
+        embedding_dim: Optional embedding dimension for dynamic dimension reduction.
+            **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
+            Do NOT manually pass this parameter when calling the function directly.
+            The dimension is controlled by the @wrap_embedding_func_with_attrs decorator
+            or the EMBEDDING_DIM environment variable.
+            Supported range: 128-3072. Recommended values: 768, 1536, 3072.
+        task_type: Task type for embedding optimization. Default is "RETRIEVAL_DOCUMENT".
+            Supported types: SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING,
+            RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, CODE_RETRIEVAL_QUERY,
+            QUESTION_ANSWERING, FACT_VERIFICATION.
+        timeout: Request timeout in seconds (converted to milliseconds for the Gemini API).
+        token_tracker: Optional token usage tracker for monitoring API usage.
+
+    Returns:
+        A numpy array of embeddings, one per input text. For dimensions < 3072,
+        the embeddings are L2-normalized to ensure optimal semantic similarity performance.
+
+    Raises:
+        ValueError: If API key is not provided or configured.
+        RuntimeError: If the response from Gemini is invalid or empty.
+
+    Note:
+        - For dimension 3072: Embeddings are already normalized by the API
+        - For dimensions < 3072: Embeddings are L2-normalized after retrieval
+        - Normalization ensures accurate semantic similarity via cosine distance
+    """
+    loop = asyncio.get_running_loop()
+
+    key = _ensure_api_key(api_key)
+    # Convert timeout from seconds to milliseconds for the Gemini API
+    timeout_ms = timeout * 1000 if timeout else None
+    client = _get_gemini_client(key, base_url, timeout_ms)
+
+    # Prepare embedding configuration
+    config_kwargs: dict[str, Any] = {}
+
+    # Add task_type to config
+    if task_type:
+        config_kwargs["task_type"] = task_type
+
+    # Add output_dimensionality if embedding_dim is provided
+    if embedding_dim is not None:
+        config_kwargs["output_dimensionality"] = embedding_dim
+
+    # Create config object if we have parameters
+    config_obj = types.EmbedContentConfig(**config_kwargs) if config_kwargs else None
+
+    def _call_embed() -> Any:
+        """Call Gemini embedding API in executor thread."""
+        request_kwargs: dict[str, Any] = {
+            "model": model,
+            "contents": texts,
+        }
+        if config_obj is not None:
+            request_kwargs["config"] = config_obj
+
+        return client.models.embed_content(**request_kwargs)
+
+    # Execute API call in thread pool
+    response = await loop.run_in_executor(None, _call_embed)
+
+    # Extract embeddings from response
+    if not hasattr(response, "embeddings") or not response.embeddings:
+        raise RuntimeError("Gemini response did not contain embeddings.")
+
+    # Convert embeddings to numpy array
+    embeddings = np.array(
+        [np.array(e.values, dtype=np.float32) for e in response.embeddings]
+    )
+
+    # Apply L2 normalization for dimensions < 3072
+    # The 3072-dimensional embedding is already normalized by the Gemini API
+    if embedding_dim and embedding_dim < 3072:
+        # Normalize each embedding vector to unit length
+        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
+        # Avoid division by zero
+        norms = np.where(norms == 0, 1, norms)
+        embeddings = embeddings / norms
+        logger.debug(
+            f"Applied L2 normalization to {len(embeddings)} embeddings of dimension {embedding_dim}"
+        )
+
+    # Track token usage if tracker is provided
+    # Note: Gemini embedding API may not provide usage metadata
+    if token_tracker and hasattr(response, "usage_metadata"):
+        usage = response.usage_metadata
+        token_counts = {
+            "prompt_tokens": getattr(usage, "prompt_token_count", 0),
+            "total_tokens": getattr(usage, "total_token_count", 0),
+        }
+        token_tracker.add_usage(token_counts)
+
+    logger.debug(
+        f"Generated {len(embeddings)} Gemini embeddings with dimension {embeddings.shape[1]}"
+    )
+
+    return embeddings
+
+
 __all__ = [
     "gemini_complete_if_cache",
     "gemini_model_complete",
+    "gemini_embed",
 ]
diff --git a/pyproject.toml b/pyproject.toml
index df59b33b..ba1ba2cf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,6 +24,7 @@ dependencies = [
     "aiohttp",
     "configparser",
     "future",
+    "google-api-core>=2.0.0,<3.0.0",
     "google-genai>=1.0.0,<2.0.0",
     "json_repair",
     "nano-vectordb",
@@ -60,6 +61,7 @@ api = [
     "tenacity",
     "tiktoken",
     "xlsxwriter>=3.1.0",
+    "google-api-core>=2.0.0,<3.0.0",
     "google-genai>=1.0.0,<2.0.0",
     # API-specific dependencies
     "aiofiles",
@@ -107,6 +109,7 @@ offline-llm = [
     "aioboto3>=12.0.0,<16.0.0",
     "voyageai>=0.2.0,<1.0.0",
     "llama-index>=0.9.0,<1.0.0",
+    "google-api-core>=2.0.0,<3.0.0",
     "google-genai>=1.0.0,<2.0.0",
 ]

diff --git a/requirements-offline-llm.txt b/requirements-offline-llm.txt
index 269847a2..1539552a 100644
--- a/requirements-offline-llm.txt
+++ b/requirements-offline-llm.txt
@@ -10,6 +10,8 @@
 # LLM provider dependencies (with version constraints matching pyproject.toml)
 aioboto3>=12.0.0,<16.0.0
 anthropic>=0.18.0,<1.0.0
+google-api-core>=2.0.0,<3.0.0
+google-genai>=1.0.0,<2.0.0
 llama-index>=0.9.0,<1.0.0
 ollama>=0.1.0,<1.0.0
 openai>=1.0.0,<3.0.0
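
A minimal usage sketch for the new gemini_embed helper, assuming a Gemini API key is already configured in the environment (the patch resolves it through _ensure_api_key when api_key is None) and that the decorated function remains directly awaitable, as LightRAG's other embedding wrappers are:

import asyncio

from lightrag.llm.gemini import gemini_embed


async def main() -> None:
    # embedding_dim is injected by the EmbeddingFunc wrapper; per the
    # docstring above, do not pass it manually when calling directly.
    vectors = await gemini_embed(
        ["LightRAG builds a knowledge graph over indexed documents."],
        model="gemini-embedding-001",
        task_type="RETRIEVAL_DOCUMENT",
    )
    print(vectors.shape)  # (1, embedding_dim); (1, 1536) with the decorator default


asyncio.run(main())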
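
Because the patch L2-normalizes every embedding whose output_dimensionality is below 3072, cosine similarity between rows reduces to a plain dot product. A small sketch of that invariant (the helper name is illustrative, not part of the patch):

import numpy as np


def has_unit_norm(embeddings: np.ndarray, atol: float = 1e-3) -> bool:
    # After the normalization step in gemini_embed, each row should have
    # unit L2 norm, so row_i @ row_j equals their cosine similarity.
    norms = np.linalg.norm(embeddings, axis=1)
    return bool(np.allclose(norms, 1.0, atol=atol))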