From 3b986f046f5c84678b2bf156e4d677a7de4703cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20MANSUY?= Date: Thu, 4 Dec 2025 19:19:01 +0800 Subject: [PATCH] cherry-pick 6e36ff41 --- lightrag/api/lightrag_server.py | 44 ++++ lightrag/llm/binding_options.py | 11 +- lightrag/llm/gemini.py | 392 ++++---------------------------- requirements-offline-llm.txt | 3 +- requirements-offline.txt | 7 +- 5 files changed, 102 insertions(+), 355 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index fc1e0484..70e17bb6 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -89,6 +89,7 @@ class LLMConfigCache: # Initialize configurations based on binding conditions self.openai_llm_options = None + self.gemini_llm_options = None self.ollama_llm_options = None self.ollama_embedding_options = None @@ -99,6 +100,12 @@ class LLMConfigCache: self.openai_llm_options = OpenAILLMOptions.options_dict(args) logger.info(f"OpenAI LLM Options: {self.openai_llm_options}") + if args.llm_binding == "gemini": + from lightrag.llm.binding_options import GeminiLLMOptions + + self.gemini_llm_options = GeminiLLMOptions.options_dict(args) + logger.info(f"Gemini LLM Options: {self.gemini_llm_options}") + # Only initialize and log Ollama LLM options when using Ollama LLM binding if args.llm_binding == "ollama": try: @@ -279,6 +286,7 @@ def create_app(args): "openai", "azure_openai", "aws_bedrock", + "gemini", ]: raise Exception("llm binding not supported") @@ -504,6 +512,40 @@ def create_app(args): return optimized_azure_openai_model_complete + def create_optimized_gemini_llm_func(config_cache: LLMConfigCache, args): + """Create optimized Gemini LLM function with cached configuration""" + + async def optimized_gemini_model_complete( + prompt, + system_prompt=None, + history_messages=None, + keyword_extraction=False, + **kwargs, + ) -> str: + from lightrag.llm.gemini import gemini_complete_if_cache + + if history_messages is None: + history_messages = [] + + if ( + config_cache.gemini_llm_options is not None + and "generation_config" not in kwargs + ): + kwargs["generation_config"] = dict(config_cache.gemini_llm_options) + + return await gemini_complete_if_cache( + args.llm_model, + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key=args.llm_binding_api_key, + base_url=args.llm_binding_host, + keyword_extraction=keyword_extraction, + **kwargs, + ) + + return optimized_gemini_model_complete + def create_llm_model_func(binding: str): """ Create LLM model function based on binding type. 
@@ -525,6 +567,8 @@ def create_app(args): return create_optimized_azure_openai_llm_func( config_cache, args, llm_timeout ) + elif binding == "gemini": + return create_optimized_gemini_llm_func(config_cache, args) else: # openai and compatible # Use optimized function with pre-processed configuration return create_optimized_openai_llm_func(config_cache, args, llm_timeout) diff --git a/lightrag/llm/binding_options.py b/lightrag/llm/binding_options.py index 8f69711a..44ab5d2f 100644 --- a/lightrag/llm/binding_options.py +++ b/lightrag/llm/binding_options.py @@ -472,9 +472,6 @@ class OllamaLLMOptions(_OllamaOptionsMixin, BindingOptions): _binding_name: ClassVar[str] = "ollama_llm" -# ============================================================================= -# Binding Options for Gemini -# ============================================================================= @dataclass class GeminiLLMOptions(BindingOptions): """Options for Google Gemini models.""" @@ -489,9 +486,9 @@ class GeminiLLMOptions(BindingOptions): presence_penalty: float = 0.0 frequency_penalty: float = 0.0 stop_sequences: List[str] = field(default_factory=list) - seed: int | None = None - thinking_config: dict | None = None + response_mime_type: str | None = None safety_settings: dict | None = None + system_instruction: str | None = None _help: ClassVar[dict[str, str]] = { "temperature": "Controls randomness (0.0-2.0, higher = more creative)", @@ -502,9 +499,9 @@ class GeminiLLMOptions(BindingOptions): "presence_penalty": "Penalty for token presence (-2.0 to 2.0)", "frequency_penalty": "Penalty for token frequency (-2.0 to 2.0)", "stop_sequences": "Stop sequences (JSON array of strings, e.g., '[\"END\"]')", - "seed": "Random seed for reproducible generation (leave empty for random)", - "thinking_config": "Thinking configuration (JSON dict, e.g., '{\"thinking_budget\": 1024}' or '{\"include_thoughts\": true}')", + "response_mime_type": "Desired MIME type for the response (e.g., application/json)", "safety_settings": "JSON object with Gemini safety settings overrides", + "system_instruction": "Default system instruction applied to every request", } diff --git a/lightrag/llm/gemini.py b/lightrag/llm/gemini.py index 983d6b9f..b8c64b31 100644 --- a/lightrag/llm/gemini.py +++ b/lightrag/llm/gemini.py @@ -16,72 +16,41 @@ from collections.abc import AsyncIterator from functools import lru_cache from typing import Any -import numpy as np -from tenacity import ( - retry, - stop_after_attempt, - wait_exponential, - retry_if_exception_type, -) - -from lightrag.utils import ( - logger, - remove_think_tags, - safe_unicode_decode, - wrap_embedding_func_with_attrs, -) +from lightrag.utils import logger, remove_think_tags, safe_unicode_decode import pipmaster as pm -# Install the Google Gemini client and its dependencies on demand +# Install the Google Gemini client on demand if not pm.is_installed("google-genai"): pm.install("google-genai") -if not pm.is_installed("google-api-core"): - pm.install("google-api-core") from google import genai # type: ignore from google.genai import types # type: ignore -from google.api_core import exceptions as google_api_exceptions # type: ignore DEFAULT_GEMINI_ENDPOINT = "https://generativelanguage.googleapis.com" LOG = logging.getLogger(__name__) -class InvalidResponseError(Exception): - """Custom exception class for triggering retry mechanism when Gemini returns empty responses""" - - pass - - @lru_cache(maxsize=8) -def _get_gemini_client( - api_key: str, base_url: str | None, timeout: int | None = None 
-) -> genai.Client: +def _get_gemini_client(api_key: str, base_url: str | None) -> genai.Client: """ Create (or fetch cached) Gemini client. Args: api_key: Google Gemini API key. base_url: Optional custom API endpoint. - timeout: Optional request timeout in milliseconds. Returns: genai.Client: Configured Gemini client instance. """ client_kwargs: dict[str, Any] = {"api_key": api_key} - if base_url and base_url != DEFAULT_GEMINI_ENDPOINT or timeout is not None: + if base_url and base_url != DEFAULT_GEMINI_ENDPOINT: try: - http_options_kwargs = {} - if base_url and base_url != DEFAULT_GEMINI_ENDPOINT: - http_options_kwargs["api_endpoint"] = base_url - if timeout is not None: - http_options_kwargs["timeout"] = timeout - - client_kwargs["http_options"] = types.HttpOptions(**http_options_kwargs) + client_kwargs["http_options"] = types.HttpOptions(api_endpoint=base_url) except Exception as exc: # pragma: no cover - defensive - LOG.warning("Failed to apply custom Gemini http_options: %s", exc) + LOG.warning("Failed to apply custom Gemini endpoint %s: %s", base_url, exc) try: return genai.Client(**client_kwargs) @@ -145,118 +114,47 @@ def _format_history_messages(history_messages: list[dict[str, Any]] | None) -> s return "\n".join(history_lines) -def _extract_response_text( - response: Any, extract_thoughts: bool = False -) -> tuple[str, str]: - """ - Extract text content from Gemini response, separating regular content from thoughts. +def _extract_response_text(response: Any) -> str: + if getattr(response, "text", None): + return response.text - Args: - response: Gemini API response object - extract_thoughts: Whether to extract thought content separately - - Returns: - Tuple of (regular_text, thought_text) - """ candidates = getattr(response, "candidates", None) if not candidates: - return ("", "") - - regular_parts: list[str] = [] - thought_parts: list[str] = [] + return "" + parts: list[str] = [] for candidate in candidates: if not getattr(candidate, "content", None): continue - # Use 'or []' to handle None values from parts attribute - for part in getattr(candidate.content, "parts", None) or []: + for part in getattr(candidate.content, "parts", []): text = getattr(part, "text", None) - if not text: - continue + if text: + parts.append(text) - # Check if this part is thought content using the 'thought' attribute - is_thought = getattr(part, "thought", False) - - if is_thought and extract_thoughts: - thought_parts.append(text) - elif not is_thought: - regular_parts.append(text) - - return ("\n".join(regular_parts), "\n".join(thought_parts)) + return "\n".join(parts) -@retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=60), - retry=( - retry_if_exception_type(google_api_exceptions.InternalServerError) - | retry_if_exception_type(google_api_exceptions.ServiceUnavailable) - | retry_if_exception_type(google_api_exceptions.ResourceExhausted) - | retry_if_exception_type(google_api_exceptions.GatewayTimeout) - | retry_if_exception_type(google_api_exceptions.BadGateway) - | retry_if_exception_type(google_api_exceptions.DeadlineExceeded) - | retry_if_exception_type(google_api_exceptions.Aborted) - | retry_if_exception_type(google_api_exceptions.Unknown) - | retry_if_exception_type(InvalidResponseError) - ), -) async def gemini_complete_if_cache( model: str, prompt: str, system_prompt: str | None = None, history_messages: list[dict[str, Any]] | None = None, - enable_cot: bool = False, - base_url: str | None = None, + *, api_key: str | None = None, - token_tracker: Any | 
None = None, - stream: bool | None = None, - keyword_extraction: bool = False, + base_url: str | None = None, generation_config: dict[str, Any] | None = None, - timeout: int | None = None, + keyword_extraction: bool = False, + token_tracker: Any | None = None, + hashing_kv: Any | None = None, # noqa: ARG001 - present for interface parity + stream: bool | None = None, + enable_cot: bool = False, # noqa: ARG001 - not supported by Gemini currently + timeout: float | None = None, # noqa: ARG001 - handled by caller if needed **_: Any, ) -> str | AsyncIterator[str]: - """ - Complete a prompt using Gemini's API with Chain of Thought (COT) support. - - This function supports automatic integration of reasoning content from Gemini models - that provide Chain of Thought capabilities via the thinking_config API feature. - - COT Integration: - - When enable_cot=True: Thought content is wrapped in ... tags - - When enable_cot=False: Thought content is filtered out, only regular content returned - - Thought content is identified by the 'thought' attribute on response parts - - Requires thinking_config to be enabled in generation_config for API to return thoughts - - Args: - model: The Gemini model to use. - prompt: The prompt to complete. - system_prompt: Optional system prompt to include. - history_messages: Optional list of previous messages in the conversation. - api_key: Optional Gemini API key. If None, uses environment variable. - base_url: Optional custom API endpoint. - generation_config: Optional generation configuration dict. - keyword_extraction: Whether to use JSON response format. - token_tracker: Optional token usage tracker for monitoring API usage. - stream: Whether to stream the response. - hashing_kv: Storage interface (for interface parity with other bindings). - enable_cot: Whether to include Chain of Thought content in the response. - timeout: Request timeout in seconds (will be converted to milliseconds for Gemini API). - **_: Additional keyword arguments (ignored). - - Returns: - The completed text (with COT content if enable_cot=True) or an async iterator - of text chunks if streaming. COT content is wrapped in ... tags. - - Raises: - RuntimeError: If the response from Gemini is empty. - ValueError: If API key is not provided or configured. 
- """ loop = asyncio.get_running_loop() key = _ensure_api_key(api_key) - # Convert timeout from seconds to milliseconds for Gemini API - timeout_ms = timeout * 1000 if timeout else None - client = _get_gemini_client(key, base_url, timeout_ms) + client = _get_gemini_client(key, base_url) history_block = _format_history_messages(history_messages) prompt_sections = [] @@ -286,11 +184,6 @@ async def gemini_complete_if_cache( usage_container: dict[str, Any] = {} def _stream_model() -> None: - # COT state tracking for streaming - cot_active = False - cot_started = False - initial_content_seen = False - try: stream_kwargs = dict(request_kwargs) stream_iterator = client.models.generate_content_stream(**stream_kwargs) @@ -298,61 +191,20 @@ async def gemini_complete_if_cache( usage = getattr(chunk, "usage_metadata", None) if usage is not None: usage_container["usage"] = usage - - # Extract both regular and thought content - regular_text, thought_text = _extract_response_text( - chunk, extract_thoughts=True + text_piece = getattr(chunk, "text", None) or _extract_response_text( + chunk ) - - if enable_cot: - # Process regular content - if regular_text: - if not initial_content_seen: - initial_content_seen = True - - # Close COT section if it was active - if cot_active: - loop.call_soon_threadsafe(queue.put_nowait, "") - cot_active = False - - # Send regular content - loop.call_soon_threadsafe(queue.put_nowait, regular_text) - - # Process thought content - if thought_text: - if not initial_content_seen and not cot_started: - # Start COT section - loop.call_soon_threadsafe(queue.put_nowait, "") - cot_active = True - cot_started = True - - # Send thought content if COT is active - if cot_active: - loop.call_soon_threadsafe( - queue.put_nowait, thought_text - ) - else: - # COT disabled - only send regular content - if regular_text: - loop.call_soon_threadsafe(queue.put_nowait, regular_text) - - # Ensure COT is properly closed if still active - if cot_active: - loop.call_soon_threadsafe(queue.put_nowait, "") - + if text_piece: + loop.call_soon_threadsafe(queue.put_nowait, text_piece) loop.call_soon_threadsafe(queue.put_nowait, None) except Exception as exc: # pragma: no cover - surface runtime issues - # Try to close COT tag before reporting error - if cot_active: - try: - loop.call_soon_threadsafe(queue.put_nowait, "") - except Exception: - pass loop.call_soon_threadsafe(queue.put_nowait, exc) loop.run_in_executor(None, _stream_model) async def _async_stream() -> AsyncIterator[str]: + accumulated = "" + emitted = "" try: while True: item = await queue.get() @@ -365,9 +217,16 @@ async def gemini_complete_if_cache( if "\\u" in chunk_text: chunk_text = safe_unicode_decode(chunk_text.encode("utf-8")) - # Yield the chunk directly without filtering - # COT filtering is already handled in _stream_model() - yield chunk_text + accumulated += chunk_text + sanitized = remove_think_tags(accumulated) + if sanitized.startswith(emitted): + delta = sanitized[len(emitted) :] + else: + delta = sanitized + emitted = sanitized + + if delta: + yield delta finally: usage = usage_container.get("usage") if token_tracker and usage: @@ -385,33 +244,14 @@ async def gemini_complete_if_cache( response = await asyncio.to_thread(_call_model) - # Extract both regular text and thought text - regular_text, thought_text = _extract_response_text(response, extract_thoughts=True) + text = _extract_response_text(response) + if not text: + raise RuntimeError("Gemini response did not contain any text content.") - # Apply COT filtering logic 
based on enable_cot parameter - if enable_cot: - # Include thought content wrapped in tags - if thought_text and thought_text.strip(): - if not regular_text or regular_text.strip() == "": - # Only thought content available - final_text = f"{thought_text}" - else: - # Both content types present: prepend thought to regular content - final_text = f"{thought_text}{regular_text}" - else: - # No thought content, use regular content only - final_text = regular_text or "" - else: - # Filter out thought content, return only regular content - final_text = regular_text or "" + if "\\u" in text: + text = safe_unicode_decode(text.encode("utf-8")) - if not final_text: - raise InvalidResponseError("Gemini response did not contain any text content.") - - if "\\u" in final_text: - final_text = safe_unicode_decode(final_text.encode("utf-8")) - - final_text = remove_think_tags(final_text) + text = remove_think_tags(text) usage = getattr(response, "usage_metadata", None) if token_tracker and usage: @@ -423,8 +263,8 @@ async def gemini_complete_if_cache( } ) - logger.debug("Gemini response length: %s", len(final_text)) - return final_text + logger.debug("Gemini response length: %s", len(text)) + return text async def gemini_model_complete( @@ -453,143 +293,7 @@ async def gemini_model_complete( ) -@wrap_embedding_func_with_attrs(embedding_dim=1536) -@retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=60), - retry=( - retry_if_exception_type(google_api_exceptions.InternalServerError) - | retry_if_exception_type(google_api_exceptions.ServiceUnavailable) - | retry_if_exception_type(google_api_exceptions.ResourceExhausted) - | retry_if_exception_type(google_api_exceptions.GatewayTimeout) - | retry_if_exception_type(google_api_exceptions.BadGateway) - | retry_if_exception_type(google_api_exceptions.DeadlineExceeded) - | retry_if_exception_type(google_api_exceptions.Aborted) - | retry_if_exception_type(google_api_exceptions.Unknown) - ), -) -async def gemini_embed( - texts: list[str], - model: str = "gemini-embedding-001", - base_url: str | None = None, - api_key: str | None = None, - embedding_dim: int | None = None, - task_type: str = "RETRIEVAL_DOCUMENT", - timeout: int | None = None, - token_tracker: Any | None = None, -) -> np.ndarray: - """Generate embeddings for a list of texts using Gemini's API. - - This function uses Google's Gemini embedding model to generate text embeddings. - It supports dynamic dimension control and automatic normalization for dimensions - less than 3072. - - Args: - texts: List of texts to embed. - model: The Gemini embedding model to use. Default is "gemini-embedding-001". - base_url: Optional custom API endpoint. - api_key: Optional Gemini API key. If None, uses environment variables. - embedding_dim: Optional embedding dimension for dynamic dimension reduction. - **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper. - Do NOT manually pass this parameter when calling the function directly. - The dimension is controlled by the @wrap_embedding_func_with_attrs decorator - or the EMBEDDING_DIM environment variable. - Supported range: 128-3072. Recommended values: 768, 1536, 3072. - task_type: Task type for embedding optimization. Default is "RETRIEVAL_DOCUMENT". - Supported types: SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING, - RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, CODE_RETRIEVAL_QUERY, - QUESTION_ANSWERING, FACT_VERIFICATION. - timeout: Request timeout in seconds (will be converted to milliseconds for Gemini API). 
- token_tracker: Optional token usage tracker for monitoring API usage. - - Returns: - A numpy array of embeddings, one per input text. For dimensions < 3072, - the embeddings are L2-normalized to ensure optimal semantic similarity performance. - - Raises: - ValueError: If API key is not provided or configured. - RuntimeError: If the response from Gemini is invalid or empty. - - Note: - - For dimension 3072: Embeddings are already normalized by the API - - For dimensions < 3072: Embeddings are L2-normalized after retrieval - - Normalization ensures accurate semantic similarity via cosine distance - """ - loop = asyncio.get_running_loop() - - key = _ensure_api_key(api_key) - # Convert timeout from seconds to milliseconds for Gemini API - timeout_ms = timeout * 1000 if timeout else None - client = _get_gemini_client(key, base_url, timeout_ms) - - # Prepare embedding configuration - config_kwargs: dict[str, Any] = {} - - # Add task_type to config - if task_type: - config_kwargs["task_type"] = task_type - - # Add output_dimensionality if embedding_dim is provided - if embedding_dim is not None: - config_kwargs["output_dimensionality"] = embedding_dim - - # Create config object if we have parameters - config_obj = types.EmbedContentConfig(**config_kwargs) if config_kwargs else None - - def _call_embed() -> Any: - """Call Gemini embedding API in executor thread.""" - request_kwargs: dict[str, Any] = { - "model": model, - "contents": texts, - } - if config_obj is not None: - request_kwargs["config"] = config_obj - - return client.models.embed_content(**request_kwargs) - - # Execute API call in thread pool - response = await loop.run_in_executor(None, _call_embed) - - # Extract embeddings from response - if not hasattr(response, "embeddings") or not response.embeddings: - raise RuntimeError("Gemini response did not contain embeddings.") - - # Convert embeddings to numpy array - embeddings = np.array( - [np.array(e.values, dtype=np.float32) for e in response.embeddings] - ) - - # Apply L2 normalization for dimensions < 3072 - # The 3072 dimension embedding is already normalized by Gemini API - if embedding_dim and embedding_dim < 3072: - # Normalize each embedding vector to unit length - norms = np.linalg.norm(embeddings, axis=1, keepdims=True) - # Avoid division by zero - norms = np.where(norms == 0, 1, norms) - embeddings = embeddings / norms - logger.debug( - f"Applied L2 normalization to {len(embeddings)} embeddings of dimension {embedding_dim}" - ) - - # Track token usage if tracker is provided - # Note: Gemini embedding API may not provide usage metadata - if token_tracker and hasattr(response, "usage_metadata"): - usage = response.usage_metadata - token_counts = { - "prompt_tokens": getattr(usage, "prompt_token_count", 0), - "total_tokens": getattr(usage, "total_token_count", 0), - } - token_tracker.add_usage(token_counts) - - logger.debug( - f"Generated {len(embeddings)} Gemini embeddings with dimension {embeddings.shape[1]}" - ) - - return embeddings - - __all__ = [ "gemini_complete_if_cache", "gemini_model_complete", - "gemini_embed", ] diff --git a/requirements-offline-llm.txt b/requirements-offline-llm.txt index fe3fc747..4e8b7168 100644 --- a/requirements-offline-llm.txt +++ b/requirements-offline-llm.txt @@ -10,8 +10,9 @@ # LLM provider dependencies (with version constraints matching pyproject.toml) aioboto3>=12.0.0,<16.0.0 anthropic>=0.18.0,<1.0.0 +google-genai>=1.0.0,<2.0.0 llama-index>=0.9.0,<1.0.0 ollama>=0.1.0,<1.0.0 -openai>=1.0.0,<2.0.0 +openai>=1.0.0,<3.0.0 
voyageai>=0.2.0,<1.0.0 zhipuai>=2.0.0,<3.0.0 diff --git a/requirements-offline.txt b/requirements-offline.txt index fe063e88..8dfb1b01 100644 --- a/requirements-offline.txt +++ b/requirements-offline.txt @@ -13,20 +13,21 @@ anthropic>=0.18.0,<1.0.0 # Storage backend dependencies asyncpg>=0.29.0,<1.0.0 +google-genai>=1.0.0,<2.0.0 # Document processing dependencies -docling>=1.0.0,<3.0.0 llama-index>=0.9.0,<1.0.0 neo4j>=5.0.0,<7.0.0 ollama>=0.1.0,<1.0.0 -openai>=1.0.0,<2.0.0 +openai>=1.0.0,<3.0.0 openpyxl>=3.0.0,<4.0.0 +pycryptodome>=3.0.0,<4.0.0 pymilvus>=2.6.2,<3.0.0 pymongo>=4.0.0,<5.0.0 pypdf2>=3.0.0 python-docx>=0.8.11,<2.0.0 python-pptx>=0.6.21,<2.0.0 qdrant-client>=1.7.0,<2.0.0 -redis>=5.0.0,<7.0.0 +redis>=5.0.0,<8.0.0 voyageai>=0.2.0,<1.0.0 zhipuai>=2.0.0,<3.0.0
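
Usage sketch (illustrative, not part of the applied diff): with this change, the server's "gemini" binding routes completions through gemini_complete_if_cache() and forwards GeminiLLMOptions.options_dict(args) as generation_config. The snippet below shows one way the new entry point could be exercised directly; the model name, prompt text, and generation_config values are assumptions chosen for the example rather than defaults taken from this patch.

import asyncio

from lightrag.llm.gemini import gemini_complete_if_cache


async def main() -> None:
    # generation_config mirrors the Gemini generation settings exposed via
    # GeminiLLMOptions; the keys and values here are illustrative only.
    text = await gemini_complete_if_cache(
        "gemini-2.0-flash",  # assumed model name
        "Summarize the LightRAG indexing pipeline.",
        system_prompt="You are a concise technical assistant.",
        history_messages=[],
        api_key=None,  # None falls back to the environment-configured key
        generation_config={"temperature": 0.2, "max_output_tokens": 512},
    )
    print(text)


asyncio.run(main())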