cherry-pick 3d9de5ed
This commit is contained in:
parent
96f23d59af
commit
3aa214b005
3 changed files with 181 additions and 3 deletions
|
|
@ -16,22 +16,44 @@ from collections.abc import AsyncIterator
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from lightrag.utils import logger, remove_think_tags, safe_unicode_decode
|
import numpy as np
|
||||||
|
from tenacity import (
|
||||||
|
retry,
|
||||||
|
stop_after_attempt,
|
||||||
|
wait_exponential,
|
||||||
|
retry_if_exception_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
from lightrag.utils import (
|
||||||
|
logger,
|
||||||
|
remove_think_tags,
|
||||||
|
safe_unicode_decode,
|
||||||
|
wrap_embedding_func_with_attrs,
|
||||||
|
)
|
||||||
|
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
|
|
||||||
# Install the Google Gemini client on demand
|
# Install the Google Gemini client and its dependencies on demand
|
||||||
if not pm.is_installed("google-genai"):
|
if not pm.is_installed("google-genai"):
|
||||||
pm.install("google-genai")
|
pm.install("google-genai")
|
||||||
|
if not pm.is_installed("google-api-core"):
|
||||||
|
pm.install("google-api-core")
|
||||||
|
|
||||||
from google import genai # type: ignore
|
from google import genai # type: ignore
|
||||||
from google.genai import types # type: ignore
|
from google.genai import types # type: ignore
|
||||||
|
from google.api_core import exceptions as google_api_exceptions # type: ignore
|
||||||
|
|
||||||
DEFAULT_GEMINI_ENDPOINT = "https://generativelanguage.googleapis.com"
|
DEFAULT_GEMINI_ENDPOINT = "https://generativelanguage.googleapis.com"
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidResponseError(Exception):
|
||||||
|
"""Custom exception class for triggering retry mechanism when Gemini returns empty responses"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=8)
|
@lru_cache(maxsize=8)
|
||||||
def _get_gemini_client(
|
def _get_gemini_client(
|
||||||
api_key: str, base_url: str | None, timeout: int | None = None
|
api_key: str, base_url: str | None, timeout: int | None = None
|
||||||
|
|
@ -163,6 +185,21 @@ def _extract_response_text(
|
||||||
return ("\n".join(regular_parts), "\n".join(thought_parts))
|
return ("\n".join(regular_parts), "\n".join(thought_parts))
|
||||||
|
|
||||||
|
|
||||||
|
@retry(
|
||||||
|
stop=stop_after_attempt(3),
|
||||||
|
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||||
|
retry=(
|
||||||
|
retry_if_exception_type(google_api_exceptions.InternalServerError)
|
||||||
|
| retry_if_exception_type(google_api_exceptions.ServiceUnavailable)
|
||||||
|
| retry_if_exception_type(google_api_exceptions.ResourceExhausted)
|
||||||
|
| retry_if_exception_type(google_api_exceptions.GatewayTimeout)
|
||||||
|
| retry_if_exception_type(google_api_exceptions.BadGateway)
|
||||||
|
| retry_if_exception_type(google_api_exceptions.DeadlineExceeded)
|
||||||
|
| retry_if_exception_type(google_api_exceptions.Aborted)
|
||||||
|
| retry_if_exception_type(google_api_exceptions.Unknown)
|
||||||
|
| retry_if_exception_type(InvalidResponseError)
|
||||||
|
),
|
||||||
|
)
|
||||||
async def gemini_complete_if_cache(
|
async def gemini_complete_if_cache(
|
||||||
model: str,
|
model: str,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
|
@ -369,7 +406,7 @@ async def gemini_complete_if_cache(
|
||||||
final_text = regular_text or ""
|
final_text = regular_text or ""
|
||||||
|
|
||||||
if not final_text:
|
if not final_text:
|
||||||
raise RuntimeError("Gemini response did not contain any text content.")
|
raise InvalidResponseError("Gemini response did not contain any text content.")
|
||||||
|
|
||||||
if "\\u" in final_text:
|
if "\\u" in final_text:
|
||||||
final_text = safe_unicode_decode(final_text.encode("utf-8"))
|
final_text = safe_unicode_decode(final_text.encode("utf-8"))
|
||||||
|
|
@ -416,7 +453,143 @@ async def gemini_model_complete(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@wrap_embedding_func_with_attrs(embedding_dim=1536)
|
||||||
|
@retry(
|
||||||
|
stop=stop_after_attempt(3),
|
||||||
|
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||||
|
retry=(
|
||||||
|
retry_if_exception_type(google_api_exceptions.InternalServerError)
|
||||||
|
| retry_if_exception_type(google_api_exceptions.ServiceUnavailable)
|
||||||
|
| retry_if_exception_type(google_api_exceptions.ResourceExhausted)
|
||||||
|
| retry_if_exception_type(google_api_exceptions.GatewayTimeout)
|
||||||
|
| retry_if_exception_type(google_api_exceptions.BadGateway)
|
||||||
|
| retry_if_exception_type(google_api_exceptions.DeadlineExceeded)
|
||||||
|
| retry_if_exception_type(google_api_exceptions.Aborted)
|
||||||
|
| retry_if_exception_type(google_api_exceptions.Unknown)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
async def gemini_embed(
|
||||||
|
texts: list[str],
|
||||||
|
model: str = "gemini-embedding-001",
|
||||||
|
base_url: str | None = None,
|
||||||
|
api_key: str | None = None,
|
||||||
|
embedding_dim: int | None = None,
|
||||||
|
task_type: str = "RETRIEVAL_DOCUMENT",
|
||||||
|
timeout: int | None = None,
|
||||||
|
token_tracker: Any | None = None,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Generate embeddings for a list of texts using Gemini's API.
|
||||||
|
|
||||||
|
This function uses Google's Gemini embedding model to generate text embeddings.
|
||||||
|
It supports dynamic dimension control and automatic normalization for dimensions
|
||||||
|
less than 3072.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: List of texts to embed.
|
||||||
|
model: The Gemini embedding model to use. Default is "gemini-embedding-001".
|
||||||
|
base_url: Optional custom API endpoint.
|
||||||
|
api_key: Optional Gemini API key. If None, uses environment variables.
|
||||||
|
embedding_dim: Optional embedding dimension for dynamic dimension reduction.
|
||||||
|
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
|
||||||
|
Do NOT manually pass this parameter when calling the function directly.
|
||||||
|
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator
|
||||||
|
or the EMBEDDING_DIM environment variable.
|
||||||
|
Supported range: 128-3072. Recommended values: 768, 1536, 3072.
|
||||||
|
task_type: Task type for embedding optimization. Default is "RETRIEVAL_DOCUMENT".
|
||||||
|
Supported types: SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING,
|
||||||
|
RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, CODE_RETRIEVAL_QUERY,
|
||||||
|
QUESTION_ANSWERING, FACT_VERIFICATION.
|
||||||
|
timeout: Request timeout in seconds (will be converted to milliseconds for Gemini API).
|
||||||
|
token_tracker: Optional token usage tracker for monitoring API usage.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A numpy array of embeddings, one per input text. For dimensions < 3072,
|
||||||
|
the embeddings are L2-normalized to ensure optimal semantic similarity performance.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If API key is not provided or configured.
|
||||||
|
RuntimeError: If the response from Gemini is invalid or empty.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- For dimension 3072: Embeddings are already normalized by the API
|
||||||
|
- For dimensions < 3072: Embeddings are L2-normalized after retrieval
|
||||||
|
- Normalization ensures accurate semantic similarity via cosine distance
|
||||||
|
"""
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
|
key = _ensure_api_key(api_key)
|
||||||
|
# Convert timeout from seconds to milliseconds for Gemini API
|
||||||
|
timeout_ms = timeout * 1000 if timeout else None
|
||||||
|
client = _get_gemini_client(key, base_url, timeout_ms)
|
||||||
|
|
||||||
|
# Prepare embedding configuration
|
||||||
|
config_kwargs: dict[str, Any] = {}
|
||||||
|
|
||||||
|
# Add task_type to config
|
||||||
|
if task_type:
|
||||||
|
config_kwargs["task_type"] = task_type
|
||||||
|
|
||||||
|
# Add output_dimensionality if embedding_dim is provided
|
||||||
|
if embedding_dim is not None:
|
||||||
|
config_kwargs["output_dimensionality"] = embedding_dim
|
||||||
|
|
||||||
|
# Create config object if we have parameters
|
||||||
|
config_obj = types.EmbedContentConfig(**config_kwargs) if config_kwargs else None
|
||||||
|
|
||||||
|
def _call_embed() -> Any:
|
||||||
|
"""Call Gemini embedding API in executor thread."""
|
||||||
|
request_kwargs: dict[str, Any] = {
|
||||||
|
"model": model,
|
||||||
|
"contents": texts,
|
||||||
|
}
|
||||||
|
if config_obj is not None:
|
||||||
|
request_kwargs["config"] = config_obj
|
||||||
|
|
||||||
|
return client.models.embed_content(**request_kwargs)
|
||||||
|
|
||||||
|
# Execute API call in thread pool
|
||||||
|
response = await loop.run_in_executor(None, _call_embed)
|
||||||
|
|
||||||
|
# Extract embeddings from response
|
||||||
|
if not hasattr(response, "embeddings") or not response.embeddings:
|
||||||
|
raise RuntimeError("Gemini response did not contain embeddings.")
|
||||||
|
|
||||||
|
# Convert embeddings to numpy array
|
||||||
|
embeddings = np.array(
|
||||||
|
[np.array(e.values, dtype=np.float32) for e in response.embeddings]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply L2 normalization for dimensions < 3072
|
||||||
|
# The 3072 dimension embedding is already normalized by Gemini API
|
||||||
|
if embedding_dim and embedding_dim < 3072:
|
||||||
|
# Normalize each embedding vector to unit length
|
||||||
|
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||||
|
# Avoid division by zero
|
||||||
|
norms = np.where(norms == 0, 1, norms)
|
||||||
|
embeddings = embeddings / norms
|
||||||
|
logger.debug(
|
||||||
|
f"Applied L2 normalization to {len(embeddings)} embeddings of dimension {embedding_dim}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Track token usage if tracker is provided
|
||||||
|
# Note: Gemini embedding API may not provide usage metadata
|
||||||
|
if token_tracker and hasattr(response, "usage_metadata"):
|
||||||
|
usage = response.usage_metadata
|
||||||
|
token_counts = {
|
||||||
|
"prompt_tokens": getattr(usage, "prompt_token_count", 0),
|
||||||
|
"total_tokens": getattr(usage, "total_token_count", 0),
|
||||||
|
}
|
||||||
|
token_tracker.add_usage(token_counts)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Generated {len(embeddings)} Gemini embeddings with dimension {embeddings.shape[1]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"gemini_complete_if_cache",
|
"gemini_complete_if_cache",
|
||||||
"gemini_model_complete",
|
"gemini_model_complete",
|
||||||
|
"gemini_embed",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ dependencies = [
|
||||||
"aiohttp",
|
"aiohttp",
|
||||||
"configparser",
|
"configparser",
|
||||||
"future",
|
"future",
|
||||||
|
"google-api-core>=2.0.0,<3.0.0",
|
||||||
"google-genai>=1.0.0,<2.0.0",
|
"google-genai>=1.0.0,<2.0.0",
|
||||||
"json_repair",
|
"json_repair",
|
||||||
"nano-vectordb",
|
"nano-vectordb",
|
||||||
|
|
@ -60,6 +61,7 @@ api = [
|
||||||
"tenacity",
|
"tenacity",
|
||||||
"tiktoken",
|
"tiktoken",
|
||||||
"xlsxwriter>=3.1.0",
|
"xlsxwriter>=3.1.0",
|
||||||
|
"google-api-core>=2.0.0,<3.0.0",
|
||||||
"google-genai>=1.0.0,<2.0.0",
|
"google-genai>=1.0.0,<2.0.0",
|
||||||
# API-specific dependencies
|
# API-specific dependencies
|
||||||
"aiofiles",
|
"aiofiles",
|
||||||
|
|
@ -107,6 +109,7 @@ offline-llm = [
|
||||||
"aioboto3>=12.0.0,<16.0.0",
|
"aioboto3>=12.0.0,<16.0.0",
|
||||||
"voyageai>=0.2.0,<1.0.0",
|
"voyageai>=0.2.0,<1.0.0",
|
||||||
"llama-index>=0.9.0,<1.0.0",
|
"llama-index>=0.9.0,<1.0.0",
|
||||||
|
"google-api-core>=2.0.0,<3.0.0",
|
||||||
"google-genai>=1.0.0,<2.0.0",
|
"google-genai>=1.0.0,<2.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,8 @@
|
||||||
# LLM provider dependencies (with version constraints matching pyproject.toml)
|
# LLM provider dependencies (with version constraints matching pyproject.toml)
|
||||||
aioboto3>=12.0.0,<16.0.0
|
aioboto3>=12.0.0,<16.0.0
|
||||||
anthropic>=0.18.0,<1.0.0
|
anthropic>=0.18.0,<1.0.0
|
||||||
|
google-api-core>=2.0.0,<3.0.0
|
||||||
|
google-genai>=1.0.0,<2.0.0
|
||||||
llama-index>=0.9.0,<1.0.0
|
llama-index>=0.9.0,<1.0.0
|
||||||
ollama>=0.1.0,<1.0.0
|
ollama>=0.1.0,<1.0.0
|
||||||
openai>=1.0.0,<3.0.0
|
openai>=1.0.0,<3.0.0
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue