Raphaël MANSUY 2025-12-04 19:14:30 +08:00
parent 96f23d59af
commit 3aa214b005
3 changed files with 181 additions and 3 deletions


@@ -16,22 +16,44 @@ from collections.abc import AsyncIterator
from functools import lru_cache
from typing import Any
from lightrag.utils import logger, remove_think_tags, safe_unicode_decode
import numpy as np
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.utils import (
logger,
remove_think_tags,
safe_unicode_decode,
wrap_embedding_func_with_attrs,
)
import pipmaster as pm
# Install the Google Gemini client on demand
# Install the Google Gemini client and its dependencies on demand
if not pm.is_installed("google-genai"):
pm.install("google-genai")
if not pm.is_installed("google-api-core"):
pm.install("google-api-core")
from google import genai # type: ignore
from google.genai import types # type: ignore
from google.api_core import exceptions as google_api_exceptions # type: ignore
DEFAULT_GEMINI_ENDPOINT = "https://generativelanguage.googleapis.com"
LOG = logging.getLogger(__name__)
class InvalidResponseError(Exception):
"""Custom exception class for triggering retry mechanism when Gemini returns empty responses"""
pass
@lru_cache(maxsize=8)
def _get_gemini_client(
api_key: str, base_url: str | None, timeout: int | None = None
@@ -163,6 +185,21 @@ def _extract_response_text(
return ("\n".join(regular_parts), "\n".join(thought_parts))
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=(
retry_if_exception_type(google_api_exceptions.InternalServerError)
| retry_if_exception_type(google_api_exceptions.ServiceUnavailable)
| retry_if_exception_type(google_api_exceptions.ResourceExhausted)
| retry_if_exception_type(google_api_exceptions.GatewayTimeout)
| retry_if_exception_type(google_api_exceptions.BadGateway)
| retry_if_exception_type(google_api_exceptions.DeadlineExceeded)
| retry_if_exception_type(google_api_exceptions.Aborted)
| retry_if_exception_type(google_api_exceptions.Unknown)
| retry_if_exception_type(InvalidResponseError)
),
)
async def gemini_complete_if_cache(
model: str,
prompt: str,
@@ -369,7 +406,7 @@ async def gemini_complete_if_cache(
final_text = regular_text or ""
if not final_text:
raise RuntimeError("Gemini response did not contain any text content.")
raise InvalidResponseError("Gemini response did not contain any text content.")
if "\\u" in final_text:
final_text = safe_unicode_decode(final_text.encode("utf-8"))
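Swapping RuntimeError for InvalidResponseError above is what makes empty responses retryable: tenacity only re-invokes the call when the raised exception matches the retry predicate, and InvalidResponseError is listed in it. A standalone sketch of that behavior (illustrative toy code, not part of this commit):

from tenacity import retry, stop_after_attempt, retry_if_exception_type

class InvalidResponseError(Exception):
    """Stands in for the empty-response marker added in this commit."""

calls = {"n": 0}

@retry(stop=stop_after_attempt(3), retry=retry_if_exception_type(InvalidResponseError))
def sometimes_empty() -> str:
    calls["n"] += 1
    if calls["n"] < 3:
        raise InvalidResponseError("empty response")  # matches the predicate -> retried
    return "ok"

print(sometimes_empty(), calls["n"])  # -> ok 3
# A RuntimeError would not match the predicate and would propagate on the first attempt.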
@@ -416,7 +453,143 @@ async def gemini_model_complete(
)
@wrap_embedding_func_with_attrs(embedding_dim=1536)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=(
retry_if_exception_type(google_api_exceptions.InternalServerError)
| retry_if_exception_type(google_api_exceptions.ServiceUnavailable)
| retry_if_exception_type(google_api_exceptions.ResourceExhausted)
| retry_if_exception_type(google_api_exceptions.GatewayTimeout)
| retry_if_exception_type(google_api_exceptions.BadGateway)
| retry_if_exception_type(google_api_exceptions.DeadlineExceeded)
| retry_if_exception_type(google_api_exceptions.Aborted)
| retry_if_exception_type(google_api_exceptions.Unknown)
),
)
async def gemini_embed(
texts: list[str],
model: str = "gemini-embedding-001",
base_url: str | None = None,
api_key: str | None = None,
embedding_dim: int | None = None,
task_type: str = "RETRIEVAL_DOCUMENT",
timeout: int | None = None,
token_tracker: Any | None = None,
) -> np.ndarray:
"""Generate embeddings for a list of texts using Gemini's API.
This function uses Google's Gemini embedding model to generate text embeddings.
It supports dynamic dimension control and automatic normalization for dimensions
less than 3072.
Args:
texts: List of texts to embed.
model: The Gemini embedding model to use. Default is "gemini-embedding-001".
base_url: Optional custom API endpoint.
api_key: Optional Gemini API key. If None, uses environment variables.
embedding_dim: Optional embedding dimension for dynamic dimension reduction.
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
Do NOT manually pass this parameter when calling the function directly.
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator
or the EMBEDDING_DIM environment variable.
Supported range: 128-3072. Recommended values: 768, 1536, 3072.
task_type: Task type for embedding optimization. Default is "RETRIEVAL_DOCUMENT".
Supported types: SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING,
RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, CODE_RETRIEVAL_QUERY,
QUESTION_ANSWERING, FACT_VERIFICATION.
timeout: Request timeout in seconds (converted to milliseconds for the Gemini API).
token_tracker: Optional token usage tracker for monitoring API usage.
Returns:
A numpy array of embeddings, one per input text. For dimensions < 3072,
the embeddings are L2-normalized to ensure optimal semantic similarity performance.
Raises:
ValueError: If API key is not provided or configured.
RuntimeError: If the response from Gemini is invalid or empty.
Note:
- For dimension 3072: Embeddings are already normalized by the API
- For dimensions < 3072: Embeddings are L2-normalized after retrieval
- Normalization ensures accurate semantic similarity via cosine distance
"""
loop = asyncio.get_running_loop()
key = _ensure_api_key(api_key)
# Convert timeout from seconds to milliseconds for Gemini API
timeout_ms = timeout * 1000 if timeout else None
client = _get_gemini_client(key, base_url, timeout_ms)
# Prepare embedding configuration
config_kwargs: dict[str, Any] = {}
# Add task_type to config
if task_type:
config_kwargs["task_type"] = task_type
# Add output_dimensionality if embedding_dim is provided
if embedding_dim is not None:
config_kwargs["output_dimensionality"] = embedding_dim
# Create config object if we have parameters
config_obj = types.EmbedContentConfig(**config_kwargs) if config_kwargs else None
def _call_embed() -> Any:
"""Call Gemini embedding API in executor thread."""
request_kwargs: dict[str, Any] = {
"model": model,
"contents": texts,
}
if config_obj is not None:
request_kwargs["config"] = config_obj
return client.models.embed_content(**request_kwargs)
# Execute API call in thread pool
response = await loop.run_in_executor(None, _call_embed)
# Extract embeddings from response
if not hasattr(response, "embeddings") or not response.embeddings:
raise RuntimeError("Gemini response did not contain embeddings.")
# Convert embeddings to numpy array
embeddings = np.array(
[np.array(e.values, dtype=np.float32) for e in response.embeddings]
)
# Apply L2 normalization for dimensions < 3072
# The 3072 dimension embedding is already normalized by Gemini API
if embedding_dim and embedding_dim < 3072:
# Normalize each embedding vector to unit length
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
# Avoid division by zero
norms = np.where(norms == 0, 1, norms)
embeddings = embeddings / norms
logger.debug(
f"Applied L2 normalization to {len(embeddings)} embeddings of dimension {embedding_dim}"
)
# Track token usage if tracker is provided
# Note: Gemini embedding API may not provide usage metadata
if token_tracker and hasattr(response, "usage_metadata"):
usage = response.usage_metadata
token_counts = {
"prompt_tokens": getattr(usage, "prompt_token_count", 0),
"total_tokens": getattr(usage, "total_token_count", 0),
}
token_tracker.add_usage(token_counts)
logger.debug(
f"Generated {len(embeddings)} Gemini embeddings with dimension {embeddings.shape[1]}"
)
return embeddings
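Why the normalization step matters: per the comments above, only the full 3072-dimensional embedding comes back unit-norm from the API, so for smaller output dimensionalities dot products no longer equal cosine similarity until the vectors are rescaled. A toy numpy illustration of the same operation (stand-in vectors, not real embeddings):

import numpy as np

e = np.array([[3.0, 4.0], [1.0, 0.0]], dtype=np.float32)  # stand-ins for truncated embeddings
norms = np.linalg.norm(e, axis=1, keepdims=True)
norms = np.where(norms == 0, 1, norms)  # same division-by-zero guard as above
unit = e / norms
print(np.linalg.norm(unit, axis=1))  # -> [1. 1.]
print(float(unit[0] @ unit[1]))      # 0.6: the dot product equals cosine similarity once unit-norm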
__all__ = [
"gemini_complete_if_cache",
"gemini_model_complete",
"gemini_embed",
]
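For reference, a minimal direct-call sketch of the new embedding function. The module path lightrag.llm.gemini is assumed from context, embedding_dim is passed explicitly only for illustration (in LightRAG it is injected by the EmbeddingFunc wrapper, as the docstring warns), and a Gemini API key is assumed to be configured in the environment:

import asyncio
import numpy as np
from lightrag.llm.gemini import gemini_embed  # module path assumed

async def main() -> None:
    vecs = await gemini_embed(
        ["knowledge graphs", "retrieval-augmented generation"],
        embedding_dim=1536,              # illustration only; normally wrapper-injected
        task_type="RETRIEVAL_DOCUMENT",
    )
    print(vecs.shape)                    # -> (2, 1536)
    print(np.linalg.norm(vecs, axis=1))  # each ~1.0 after the L2 normalization above

asyncio.run(main())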


@@ -24,6 +24,7 @@ dependencies = [
"aiohttp",
"configparser",
"future",
"google-api-core>=2.0.0,<3.0.0",
"google-genai>=1.0.0,<2.0.0",
"json_repair",
"nano-vectordb",
@@ -60,6 +61,7 @@ api = [
"tenacity",
"tiktoken",
"xlsxwriter>=3.1.0",
"google-api-core>=2.0.0,<3.0.0",
"google-genai>=1.0.0,<2.0.0",
# API-specific dependencies
"aiofiles",
@@ -107,6 +109,7 @@ offline-llm = [
"aioboto3>=12.0.0,<16.0.0",
"voyageai>=0.2.0,<1.0.0",
"llama-index>=0.9.0,<1.0.0",
"google-api-core>=2.0.0,<3.0.0",
"google-genai>=1.0.0,<2.0.0",
]


@@ -10,6 +10,8 @@
# LLM provider dependencies (with version constraints matching pyproject.toml)
aioboto3>=12.0.0,<16.0.0
anthropic>=0.18.0,<1.0.0
google-api-core>=2.0.0,<3.0.0
google-genai>=1.0.0,<2.0.0
llama-index>=0.9.0,<1.0.0
ollama>=0.1.0,<1.0.0
openai>=1.0.0,<3.0.0
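A quick way to confirm that the new google-api-core and google-genai pins resolve in a given environment (standard-library check, illustrative):

from importlib.metadata import PackageNotFoundError, version

for pkg in ("google-genai", "google-api-core"):
    try:
        print(f"{pkg}=={version(pkg)}")  # should fall inside the ranges pinned above
    except PackageNotFoundError:
        print(f"{pkg} is not installed")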