cherry-pick 3d9de5ed
This commit is contained in:
parent
96f23d59af
commit
3aa214b005
3 changed files with 181 additions and 3 deletions
|
|
@ -16,22 +16,44 @@ from collections.abc import AsyncIterator
|
|||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
from lightrag.utils import logger, remove_think_tags, safe_unicode_decode
|
||||
import numpy as np
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
|
||||
from lightrag.utils import (
|
||||
logger,
|
||||
remove_think_tags,
|
||||
safe_unicode_decode,
|
||||
wrap_embedding_func_with_attrs,
|
||||
)
|
||||
|
||||
import pipmaster as pm
|
||||
|
||||
# Install the Google Gemini client on demand
|
||||
# Install the Google Gemini client and its dependencies on demand
|
||||
if not pm.is_installed("google-genai"):
|
||||
pm.install("google-genai")
|
||||
if not pm.is_installed("google-api-core"):
|
||||
pm.install("google-api-core")
|
||||
|
||||
from google import genai # type: ignore
|
||||
from google.genai import types # type: ignore
|
||||
from google.api_core import exceptions as google_api_exceptions # type: ignore
|
||||
|
||||
DEFAULT_GEMINI_ENDPOINT = "https://generativelanguage.googleapis.com"
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InvalidResponseError(Exception):
|
||||
"""Custom exception class for triggering retry mechanism when Gemini returns empty responses"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
def _get_gemini_client(
|
||||
api_key: str, base_url: str | None, timeout: int | None = None
|
||||
|
|
@ -163,6 +185,21 @@ def _extract_response_text(
|
|||
return ("\n".join(regular_parts), "\n".join(thought_parts))
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=(
|
||||
retry_if_exception_type(google_api_exceptions.InternalServerError)
|
||||
| retry_if_exception_type(google_api_exceptions.ServiceUnavailable)
|
||||
| retry_if_exception_type(google_api_exceptions.ResourceExhausted)
|
||||
| retry_if_exception_type(google_api_exceptions.GatewayTimeout)
|
||||
| retry_if_exception_type(google_api_exceptions.BadGateway)
|
||||
| retry_if_exception_type(google_api_exceptions.DeadlineExceeded)
|
||||
| retry_if_exception_type(google_api_exceptions.Aborted)
|
||||
| retry_if_exception_type(google_api_exceptions.Unknown)
|
||||
| retry_if_exception_type(InvalidResponseError)
|
||||
),
|
||||
)
|
||||
async def gemini_complete_if_cache(
|
||||
model: str,
|
||||
prompt: str,
|
||||
|
|
@ -369,7 +406,7 @@ async def gemini_complete_if_cache(
|
|||
final_text = regular_text or ""
|
||||
|
||||
if not final_text:
|
||||
raise RuntimeError("Gemini response did not contain any text content.")
|
||||
raise InvalidResponseError("Gemini response did not contain any text content.")
|
||||
|
||||
if "\\u" in final_text:
|
||||
final_text = safe_unicode_decode(final_text.encode("utf-8"))
|
||||
|
|
@ -416,7 +453,143 @@ async def gemini_model_complete(
|
|||
)
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1536)
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=(
|
||||
retry_if_exception_type(google_api_exceptions.InternalServerError)
|
||||
| retry_if_exception_type(google_api_exceptions.ServiceUnavailable)
|
||||
| retry_if_exception_type(google_api_exceptions.ResourceExhausted)
|
||||
| retry_if_exception_type(google_api_exceptions.GatewayTimeout)
|
||||
| retry_if_exception_type(google_api_exceptions.BadGateway)
|
||||
| retry_if_exception_type(google_api_exceptions.DeadlineExceeded)
|
||||
| retry_if_exception_type(google_api_exceptions.Aborted)
|
||||
| retry_if_exception_type(google_api_exceptions.Unknown)
|
||||
),
|
||||
)
|
||||
async def gemini_embed(
|
||||
texts: list[str],
|
||||
model: str = "gemini-embedding-001",
|
||||
base_url: str | None = None,
|
||||
api_key: str | None = None,
|
||||
embedding_dim: int | None = None,
|
||||
task_type: str = "RETRIEVAL_DOCUMENT",
|
||||
timeout: int | None = None,
|
||||
token_tracker: Any | None = None,
|
||||
) -> np.ndarray:
|
||||
"""Generate embeddings for a list of texts using Gemini's API.
|
||||
|
||||
This function uses Google's Gemini embedding model to generate text embeddings.
|
||||
It supports dynamic dimension control and automatic normalization for dimensions
|
||||
less than 3072.
|
||||
|
||||
Args:
|
||||
texts: List of texts to embed.
|
||||
model: The Gemini embedding model to use. Default is "gemini-embedding-001".
|
||||
base_url: Optional custom API endpoint.
|
||||
api_key: Optional Gemini API key. If None, uses environment variables.
|
||||
embedding_dim: Optional embedding dimension for dynamic dimension reduction.
|
||||
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
|
||||
Do NOT manually pass this parameter when calling the function directly.
|
||||
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator
|
||||
or the EMBEDDING_DIM environment variable.
|
||||
Supported range: 128-3072. Recommended values: 768, 1536, 3072.
|
||||
task_type: Task type for embedding optimization. Default is "RETRIEVAL_DOCUMENT".
|
||||
Supported types: SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING,
|
||||
RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, CODE_RETRIEVAL_QUERY,
|
||||
QUESTION_ANSWERING, FACT_VERIFICATION.
|
||||
timeout: Request timeout in seconds (will be converted to milliseconds for Gemini API).
|
||||
token_tracker: Optional token usage tracker for monitoring API usage.
|
||||
|
||||
Returns:
|
||||
A numpy array of embeddings, one per input text. For dimensions < 3072,
|
||||
the embeddings are L2-normalized to ensure optimal semantic similarity performance.
|
||||
|
||||
Raises:
|
||||
ValueError: If API key is not provided or configured.
|
||||
RuntimeError: If the response from Gemini is invalid or empty.
|
||||
|
||||
Note:
|
||||
- For dimension 3072: Embeddings are already normalized by the API
|
||||
- For dimensions < 3072: Embeddings are L2-normalized after retrieval
|
||||
- Normalization ensures accurate semantic similarity via cosine distance
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
key = _ensure_api_key(api_key)
|
||||
# Convert timeout from seconds to milliseconds for Gemini API
|
||||
timeout_ms = timeout * 1000 if timeout else None
|
||||
client = _get_gemini_client(key, base_url, timeout_ms)
|
||||
|
||||
# Prepare embedding configuration
|
||||
config_kwargs: dict[str, Any] = {}
|
||||
|
||||
# Add task_type to config
|
||||
if task_type:
|
||||
config_kwargs["task_type"] = task_type
|
||||
|
||||
# Add output_dimensionality if embedding_dim is provided
|
||||
if embedding_dim is not None:
|
||||
config_kwargs["output_dimensionality"] = embedding_dim
|
||||
|
||||
# Create config object if we have parameters
|
||||
config_obj = types.EmbedContentConfig(**config_kwargs) if config_kwargs else None
|
||||
|
||||
def _call_embed() -> Any:
|
||||
"""Call Gemini embedding API in executor thread."""
|
||||
request_kwargs: dict[str, Any] = {
|
||||
"model": model,
|
||||
"contents": texts,
|
||||
}
|
||||
if config_obj is not None:
|
||||
request_kwargs["config"] = config_obj
|
||||
|
||||
return client.models.embed_content(**request_kwargs)
|
||||
|
||||
# Execute API call in thread pool
|
||||
response = await loop.run_in_executor(None, _call_embed)
|
||||
|
||||
# Extract embeddings from response
|
||||
if not hasattr(response, "embeddings") or not response.embeddings:
|
||||
raise RuntimeError("Gemini response did not contain embeddings.")
|
||||
|
||||
# Convert embeddings to numpy array
|
||||
embeddings = np.array(
|
||||
[np.array(e.values, dtype=np.float32) for e in response.embeddings]
|
||||
)
|
||||
|
||||
# Apply L2 normalization for dimensions < 3072
|
||||
# The 3072 dimension embedding is already normalized by Gemini API
|
||||
if embedding_dim and embedding_dim < 3072:
|
||||
# Normalize each embedding vector to unit length
|
||||
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||
# Avoid division by zero
|
||||
norms = np.where(norms == 0, 1, norms)
|
||||
embeddings = embeddings / norms
|
||||
logger.debug(
|
||||
f"Applied L2 normalization to {len(embeddings)} embeddings of dimension {embedding_dim}"
|
||||
)
|
||||
|
||||
# Track token usage if tracker is provided
|
||||
# Note: Gemini embedding API may not provide usage metadata
|
||||
if token_tracker and hasattr(response, "usage_metadata"):
|
||||
usage = response.usage_metadata
|
||||
token_counts = {
|
||||
"prompt_tokens": getattr(usage, "prompt_token_count", 0),
|
||||
"total_tokens": getattr(usage, "total_token_count", 0),
|
||||
}
|
||||
token_tracker.add_usage(token_counts)
|
||||
|
||||
logger.debug(
|
||||
f"Generated {len(embeddings)} Gemini embeddings with dimension {embeddings.shape[1]}"
|
||||
)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
__all__ = [
|
||||
"gemini_complete_if_cache",
|
||||
"gemini_model_complete",
|
||||
"gemini_embed",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ dependencies = [
|
|||
"aiohttp",
|
||||
"configparser",
|
||||
"future",
|
||||
"google-api-core>=2.0.0,<3.0.0",
|
||||
"google-genai>=1.0.0,<2.0.0",
|
||||
"json_repair",
|
||||
"nano-vectordb",
|
||||
|
|
@ -60,6 +61,7 @@ api = [
|
|||
"tenacity",
|
||||
"tiktoken",
|
||||
"xlsxwriter>=3.1.0",
|
||||
"google-api-core>=2.0.0,<3.0.0",
|
||||
"google-genai>=1.0.0,<2.0.0",
|
||||
# API-specific dependencies
|
||||
"aiofiles",
|
||||
|
|
@ -107,6 +109,7 @@ offline-llm = [
|
|||
"aioboto3>=12.0.0,<16.0.0",
|
||||
"voyageai>=0.2.0,<1.0.0",
|
||||
"llama-index>=0.9.0,<1.0.0",
|
||||
"google-api-core>=2.0.0,<3.0.0",
|
||||
"google-genai>=1.0.0,<2.0.0",
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@
|
|||
# LLM provider dependencies (with version constraints matching pyproject.toml)
|
||||
aioboto3>=12.0.0,<16.0.0
|
||||
anthropic>=0.18.0,<1.0.0
|
||||
google-api-core>=2.0.0,<3.0.0
|
||||
google-genai>=1.0.0,<2.0.0
|
||||
llama-index>=0.9.0,<1.0.0
|
||||
ollama>=0.1.0,<1.0.0
|
||||
openai>=1.0.0,<3.0.0
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue