Merge pull request #2329 from danielaskdd/gemini-embedding

Feat: Add Gemini Embedding Support to LightRAG
This commit is contained in:
Daniel.y 2025-11-08 04:10:52 +08:00 committed by GitHub
commit 29a349f25b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 220 additions and 10 deletions

View file

@ -8,6 +8,7 @@ import logging
from dotenv import load_dotenv
from lightrag.utils import get_env_value
from lightrag.llm.binding_options import (
GeminiEmbeddingOptions,
GeminiLLMOptions,
OllamaEmbeddingOptions,
OllamaLLMOptions,
@ -238,7 +239,15 @@ def parse_args() -> argparse.Namespace:
"--embedding-binding",
type=str,
default=get_env_value("EMBEDDING_BINDING", "ollama"),
choices=["lollms", "ollama", "openai", "azure_openai", "aws_bedrock", "jina"],
choices=[
"lollms",
"ollama",
"openai",
"azure_openai",
"aws_bedrock",
"jina",
"gemini",
],
help="Embedding binding type (default: from env or ollama)",
)
parser.add_argument(
@ -265,12 +274,19 @@ def parse_args() -> argparse.Namespace:
if "--embedding-binding" in sys.argv:
try:
idx = sys.argv.index("--embedding-binding")
if idx + 1 < len(sys.argv) and sys.argv[idx + 1] == "ollama":
OllamaEmbeddingOptions.add_args(parser)
if idx + 1 < len(sys.argv):
if sys.argv[idx + 1] == "ollama":
OllamaEmbeddingOptions.add_args(parser)
elif sys.argv[idx + 1] == "gemini":
GeminiEmbeddingOptions.add_args(parser)
except IndexError:
pass
elif os.environ.get("EMBEDDING_BINDING") == "ollama":
OllamaEmbeddingOptions.add_args(parser)
else:
env_embedding_binding = os.environ.get("EMBEDDING_BINDING")
if env_embedding_binding == "ollama":
OllamaEmbeddingOptions.add_args(parser)
elif env_embedding_binding == "gemini":
GeminiEmbeddingOptions.add_args(parser)
# Add OpenAI LLM options when llm-binding is openai or azure_openai
if "--llm-binding" in sys.argv:

View file

@ -89,6 +89,7 @@ class LLMConfigCache:
# Initialize configurations based on binding conditions
self.openai_llm_options = None
self.gemini_llm_options = None
self.gemini_embedding_options = None
self.ollama_llm_options = None
self.ollama_embedding_options = None
@ -135,6 +136,23 @@ class LLMConfigCache:
)
self.ollama_embedding_options = {}
# Only initialize and log Gemini Embedding options when using Gemini Embedding binding
if args.embedding_binding == "gemini":
try:
from lightrag.llm.binding_options import GeminiEmbeddingOptions
self.gemini_embedding_options = GeminiEmbeddingOptions.options_dict(
args
)
logger.info(
f"Gemini Embedding Options: {self.gemini_embedding_options}"
)
except ImportError:
logger.warning(
"GeminiEmbeddingOptions not available, using default configuration"
)
self.gemini_embedding_options = {}
def check_frontend_build():
"""Check if frontend is built and optionally check if source is up-to-date
@ -296,6 +314,7 @@ def create_app(args):
"azure_openai",
"aws_bedrock",
"jina",
"gemini",
]:
raise Exception("embedding binding not supported")
@ -649,6 +668,26 @@ def create_app(args):
base_url=host,
api_key=api_key,
)
elif binding == "gemini":
from lightrag.llm.gemini import gemini_embed
# Use pre-processed configuration if available, otherwise fallback to dynamic parsing
if config_cache.gemini_embedding_options is not None:
gemini_options = config_cache.gemini_embedding_options
else:
# Fallback for cases where config cache wasn't initialized properly
from lightrag.llm.binding_options import GeminiEmbeddingOptions
gemini_options = GeminiEmbeddingOptions.options_dict(args)
return await gemini_embed(
texts,
model=model,
base_url=host,
api_key=api_key,
embedding_dim=embedding_dim,
task_type=gemini_options.get("task_type", "RETRIEVAL_DOCUMENT"),
)
else: # openai and compatible
from lightrag.llm.openai import openai_embed
@ -718,12 +757,12 @@ def create_app(args):
has_embedding_dim_param = "embedding_dim" in sig.parameters
# Determine send_dimensions value based on binding type
# Jina REQUIRES dimension parameter (forced to True)
# Jina and Gemini REQUIRE dimension parameter (forced to True)
# OpenAI and others: controlled by EMBEDDING_SEND_DIM environment variable
if args.embedding_binding == "jina":
# Jina API requires dimension parameter - always send it
if args.embedding_binding in ["jina", "gemini"]:
# Jina and Gemini APIs require dimension parameter - always send it
send_dimensions = has_embedding_dim_param
dimension_control = "forced by Jina API"
dimension_control = f"forced by {args.embedding_binding.title()} API"
else:
# For OpenAI and other bindings, respect EMBEDDING_SEND_DIM setting
send_dimensions = embedding_send_dim and has_embedding_dim_param

View file

@ -508,6 +508,19 @@ class GeminiLLMOptions(BindingOptions):
}
@dataclass
class GeminiEmbeddingOptions(BindingOptions):
"""Options for Google Gemini embedding models."""
_binding_name: ClassVar[str] = "gemini_embedding"
task_type: str = "RETRIEVAL_DOCUMENT"
_help: ClassVar[dict[str, str]] = {
"task_type": "Task type for embedding optimization (RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING, CODE_RETRIEVAL_QUERY, QUESTION_ANSWERING, FACT_VERIFICATION)",
}
# =============================================================================
# Binding Options for OpenAI
# =============================================================================

View file

@ -16,7 +16,20 @@ from collections.abc import AsyncIterator
from functools import lru_cache
from typing import Any
from lightrag.utils import logger, remove_think_tags, safe_unicode_decode
import numpy as np
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.utils import (
logger,
remove_think_tags,
safe_unicode_decode,
wrap_embedding_func_with_attrs,
)
import pipmaster as pm
@ -416,7 +429,136 @@ async def gemini_model_complete(
)
@wrap_embedding_func_with_attrs(embedding_dim=1536)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=(
retry_if_exception_type(Exception) # Gemini uses generic exceptions
),
)
async def gemini_embed(
texts: list[str],
model: str = "gemini-embedding-001",
base_url: str | None = None,
api_key: str | None = None,
embedding_dim: int | None = None,
task_type: str = "RETRIEVAL_DOCUMENT",
timeout: int | None = None,
token_tracker: Any | None = None,
) -> np.ndarray:
"""Generate embeddings for a list of texts using Gemini's API.
This function uses Google's Gemini embedding model to generate text embeddings.
It supports dynamic dimension control and automatic normalization for dimensions
less than 3072.
Args:
texts: List of texts to embed.
model: The Gemini embedding model to use. Default is "gemini-embedding-001".
base_url: Optional custom API endpoint.
api_key: Optional Gemini API key. If None, uses environment variables.
embedding_dim: Optional embedding dimension for dynamic dimension reduction.
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
Do NOT manually pass this parameter when calling the function directly.
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator
or the EMBEDDING_DIM environment variable.
Supported range: 128-3072. Recommended values: 768, 1536, 3072.
task_type: Task type for embedding optimization. Default is "RETRIEVAL_DOCUMENT".
Supported types: SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING,
RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, CODE_RETRIEVAL_QUERY,
QUESTION_ANSWERING, FACT_VERIFICATION.
timeout: Request timeout in seconds (will be converted to milliseconds for Gemini API).
token_tracker: Optional token usage tracker for monitoring API usage.
Returns:
A numpy array of embeddings, one per input text. For dimensions < 3072,
the embeddings are L2-normalized to ensure optimal semantic similarity performance.
Raises:
ValueError: If API key is not provided or configured.
RuntimeError: If the response from Gemini is invalid or empty.
Note:
- For dimension 3072: Embeddings are already normalized by the API
- For dimensions < 3072: Embeddings are L2-normalized after retrieval
- Normalization ensures accurate semantic similarity via cosine distance
"""
loop = asyncio.get_running_loop()
key = _ensure_api_key(api_key)
# Convert timeout from seconds to milliseconds for Gemini API
timeout_ms = timeout * 1000 if timeout else None
client = _get_gemini_client(key, base_url, timeout_ms)
# Prepare embedding configuration
config_kwargs: dict[str, Any] = {}
# Add task_type to config
if task_type:
config_kwargs["task_type"] = task_type
# Add output_dimensionality if embedding_dim is provided
if embedding_dim is not None:
config_kwargs["output_dimensionality"] = embedding_dim
# Create config object if we have parameters
config_obj = types.EmbedContentConfig(**config_kwargs) if config_kwargs else None
def _call_embed() -> Any:
"""Call Gemini embedding API in executor thread."""
request_kwargs: dict[str, Any] = {
"model": model,
"contents": texts,
}
if config_obj is not None:
request_kwargs["config"] = config_obj
return client.models.embed_content(**request_kwargs)
# Execute API call in thread pool
response = await loop.run_in_executor(None, _call_embed)
# Extract embeddings from response
if not hasattr(response, "embeddings") or not response.embeddings:
raise RuntimeError("Gemini response did not contain embeddings.")
# Convert embeddings to numpy array
embeddings = np.array(
[np.array(e.values, dtype=np.float32) for e in response.embeddings]
)
# Apply L2 normalization for dimensions < 3072
# The 3072 dimension embedding is already normalized by Gemini API
if embedding_dim and embedding_dim < 3072:
# Normalize each embedding vector to unit length
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
# Avoid division by zero
norms = np.where(norms == 0, 1, norms)
embeddings = embeddings / norms
logger.debug(
f"Applied L2 normalization to {len(embeddings)} embeddings of dimension {embedding_dim}"
)
# Track token usage if tracker is provided
# Note: Gemini embedding API may not provide usage metadata
if token_tracker and hasattr(response, "usage_metadata"):
usage = response.usage_metadata
token_counts = {
"prompt_tokens": getattr(usage, "prompt_token_count", 0),
"total_tokens": getattr(usage, "total_token_count", 0),
}
token_tracker.add_usage(token_counts)
logger.debug(
f"Generated {len(embeddings)} Gemini embeddings with dimension {embeddings.shape[1]}"
)
return embeddings
__all__ = [
"gemini_complete_if_cache",
"gemini_model_complete",
"gemini_embed",
]