extended to use Gemini, switched to use gemini-flash-latest

Humphry 2025-10-20 13:17:16 +03:00
parent c0f69395c7
commit 0b3d31507e
10 changed files with 429 additions and 5 deletions

View file

@ -120,6 +120,8 @@ cp env.example .env
docker compose up
```
> Tip: When targeting Google Gemini, set `LLM_BINDING=gemini`, choose a model such as `LLM_MODEL=gemini-flash-latest`, and provide your Gemini key via `LLM_BINDING_API_KEY` (or `GEMINI_API_KEY`). The server now understands this binding out of the box; a minimal `.env` example is shown below.
> Historical versions of LightRAG docker images can be found here: [LightRAG Docker Images](https://github.com/HKUDS/LightRAG/pkgs/container/lightrag)
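For example, a minimal `.env` for the Gemini binding could look like this (the key is a placeholder):

```
LLM_BINDING=gemini
LLM_MODEL=gemini-flash-latest
LLM_BINDING_API_KEY=your_gemini_api_key
# GEMINI_API_KEY=your_gemini_api_key   # alternative to LLM_BINDING_API_KEY
```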
### Install LightRAG Core

View file

@ -154,7 +154,7 @@ MAX_PARALLEL_INSERT=2
###########################################################
### LLM Configuration
### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock
### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock, gemini
###########################################################
### LLM request timeout setting for all LLMs (0 means no timeout for Ollama)
# LLM_TIMEOUT=180
@ -174,6 +174,14 @@ LLM_BINDING_API_KEY=your_api_key
# LLM_BINDING_API_KEY=your_api_key
# LLM_BINDING=openai
### Gemini example
# LLM_BINDING=gemini
# LLM_MODEL=gemini-flash-latest
# LLM_BINDING_HOST=https://generativelanguage.googleapis.com
# LLM_BINDING_API_KEY=your_gemini_api_key
# GEMINI_LLM_MAX_OUTPUT_TOKENS=8192
# GEMINI_LLM_TEMPERATURE=0.7
### OpenAI Compatible API Specific Parameters
### Increased temperature values may mitigate infinite inference loops in certain LLMs, such as Qwen3-30B.
# OPENAI_LLM_TEMPERATURE=0.9

View file

@ -8,6 +8,7 @@ import logging
from dotenv import load_dotenv
from lightrag.utils import get_env_value
from lightrag.llm.binding_options import (
GeminiLLMOptions,
OllamaEmbeddingOptions,
OllamaLLMOptions,
OpenAILLMOptions,
@ -63,6 +64,9 @@ def get_default_host(binding_type: str) -> str:
"lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"),
"azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"),
"openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"),
"gemini": os.getenv(
"LLM_BINDING_HOST", "https://generativelanguage.googleapis.com"
),
}
return default_hosts.get(
binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434")
@ -226,6 +230,7 @@ def parse_args() -> argparse.Namespace:
"openai-ollama",
"azure_openai",
"aws_bedrock",
"gemini",
],
help="LLM binding type (default: from env or ollama)",
)
@ -281,6 +286,16 @@ def parse_args() -> argparse.Namespace:
elif os.environ.get("LLM_BINDING") in ["openai", "azure_openai"]:
OpenAILLMOptions.add_args(parser)
if "--llm-binding" in sys.argv:
try:
idx = sys.argv.index("--llm-binding")
if idx + 1 < len(sys.argv) and sys.argv[idx + 1] == "gemini":
GeminiLLMOptions.add_args(parser)
except IndexError:
pass
elif os.environ.get("LLM_BINDING") == "gemini":
GeminiLLMOptions.add_args(parser)
args = parser.parse_args()
# convert relative path to absolute path
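As a sketch (not part of the diff), the same selection can also be made on the command line; `--llm-binding gemini` is exactly what the argv check above looks for, while the entry-point name and the `--llm-model` flag are assumptions about the existing CLI:

```bash
# hypothetical invocation; adjust the entry point to however you launch the API server
LLM_BINDING_API_KEY=your_gemini_api_key \
  lightrag-server --llm-binding gemini --llm-model gemini-flash-latest
```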

View file

@ -104,6 +104,7 @@ class LLMConfigCache:
# Initialize configurations based on binding conditions
self.openai_llm_options = None
self.gemini_llm_options = None
self.ollama_llm_options = None
self.ollama_embedding_options = None
@ -114,6 +115,12 @@ class LLMConfigCache:
self.openai_llm_options = OpenAILLMOptions.options_dict(args)
logger.info(f"OpenAI LLM Options: {self.openai_llm_options}")
if args.llm_binding == "gemini":
from lightrag.llm.binding_options import GeminiLLMOptions
self.gemini_llm_options = GeminiLLMOptions.options_dict(args)
logger.info(f"Gemini LLM Options: {self.gemini_llm_options}")
# Only initialize and log Ollama LLM options when using Ollama LLM binding
if args.llm_binding == "ollama":
try:
@ -282,6 +289,7 @@ def create_app(args):
"openai",
"azure_openai",
"aws_bedrock",
"gemini",
]:
raise Exception("llm binding not supported")
@ -500,6 +508,42 @@ def create_app(args):
return optimized_azure_openai_model_complete
def create_optimized_gemini_llm_func(
config_cache: LLMConfigCache, args
):
"""Create optimized Gemini LLM function with cached configuration"""
async def optimized_gemini_model_complete(
prompt,
system_prompt=None,
history_messages=None,
keyword_extraction=False,
**kwargs,
) -> str:
from lightrag.llm.gemini import gemini_complete_if_cache
if history_messages is None:
history_messages = []
if (
config_cache.gemini_llm_options is not None
and "generation_config" not in kwargs
):
kwargs["generation_config"] = dict(config_cache.gemini_llm_options)
return await gemini_complete_if_cache(
args.llm_model,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
api_key=args.llm_binding_api_key,
base_url=args.llm_binding_host,
keyword_extraction=keyword_extraction,
**kwargs,
)
return optimized_gemini_model_complete
def create_llm_model_func(binding: str):
"""
Create LLM model function based on binding type.
@ -521,6 +565,8 @@ def create_app(args):
return create_optimized_azure_openai_llm_func(
config_cache, args, llm_timeout
)
elif binding == "gemini":
return create_optimized_gemini_llm_func(config_cache, args)
else: # openai and compatible
# Use optimized function with pre-processed configuration
return create_optimized_openai_llm_func(config_cache, args, llm_timeout)

View file

@ -9,12 +9,26 @@ from argparse import ArgumentParser, Namespace
import argparse
import json
from dataclasses import asdict, dataclass, field
from typing import Any, ClassVar, List
from typing import Any, ClassVar, List, get_args, get_origin
from lightrag.utils import get_env_value
from lightrag.constants import DEFAULT_TEMPERATURE
def _resolve_optional_type(field_type: Any) -> Any:
"""Return the concrete type for Optional/Union annotations."""
origin = get_origin(field_type)
if origin in (list, dict, tuple):
return field_type
args = get_args(field_type)
if args:
non_none_args = [arg for arg in args if arg is not type(None)]
if len(non_none_args) == 1:
return non_none_args[0]
return field_type
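For illustration (not part of this commit), the helper collapses Optional unions to their single concrete member and leaves container annotations untouched, so argparse always receives a callable type:

```python
# hypothetical quick check; imports the private helper from the module it lives in
from typing import Optional

from lightrag.llm.binding_options import _resolve_optional_type

assert _resolve_optional_type(Optional[int]) is int    # Union[int, None] -> int
assert _resolve_optional_type(int | None) is int       # PEP 604 unions behave the same
assert _resolve_optional_type(list[str]) == list[str]  # containers are returned as-is
assert _resolve_optional_type(str) is str              # plain types pass through
```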
# =============================================================================
# BindingOptions Base Class
# =============================================================================
@ -177,9 +191,13 @@ class BindingOptions:
help=arg_item["help"],
)
else:
resolved_type = arg_item["type"]
if resolved_type is not None:
resolved_type = _resolve_optional_type(resolved_type)
group.add_argument(
f"--{arg_item['argname']}",
type=arg_item["type"],
type=resolved_type,
default=get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS),
help=arg_item["help"],
)
@ -210,7 +228,7 @@ class BindingOptions:
argdef = {
"argname": f"{args_prefix}-{field.name}",
"env_name": f"{env_var_prefix}{field.name.upper()}",
"type": field.type,
"type": _resolve_optional_type(field.type),
"default": default_value,
"help": f"{cls._binding_name} -- " + help.get(field.name, ""),
}
@ -454,6 +472,39 @@ class OllamaLLMOptions(_OllamaOptionsMixin, BindingOptions):
_binding_name: ClassVar[str] = "ollama_llm"
@dataclass
class GeminiLLMOptions(BindingOptions):
"""Options for Google Gemini models."""
_binding_name: ClassVar[str] = "gemini_llm"
temperature: float = DEFAULT_TEMPERATURE
top_p: float = 0.95
top_k: int = 40
max_output_tokens: int | None = None
candidate_count: int = 1
presence_penalty: float = 0.0
frequency_penalty: float = 0.0
stop_sequences: List[str] = field(default_factory=list)
response_mime_type: str | None = None
safety_settings: dict | None = None
system_instruction: str | None = None
_help: ClassVar[dict[str, str]] = {
"temperature": "Controls randomness (0.0-2.0, higher = more creative)",
"top_p": "Nucleus sampling parameter (0.0-1.0)",
"top_k": "Limits sampling to the top K tokens (1 disables the limit)",
"max_output_tokens": "Maximum tokens generated in the response",
"candidate_count": "Number of candidates returned per request",
"presence_penalty": "Penalty for token presence (-2.0 to 2.0)",
"frequency_penalty": "Penalty for token frequency (-2.0 to 2.0)",
"stop_sequences": 'Stop sequences (JSON array of strings, e.g., \'["END"]\')',
"response_mime_type": "Desired MIME type for the response (e.g., application/json)",
"safety_settings": "JSON object with Gemini safety settings overrides",
"system_instruction": "Default system instruction applied to every request",
}
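As a sketch of how these options flow through (not part of the diff), `BindingOptions` derives `GEMINI_LLM_*` environment variables from the field names, and `options_dict(args)` later yields the mapping the server forwards as `generation_config`; the printed values assume the usual env-to-default conversion:

```python
import argparse
import os

from lightrag.llm.binding_options import GeminiLLMOptions

# hypothetical environment, matching the GEMINI_LLM_ prefix used in env.example
os.environ["GEMINI_LLM_TEMPERATURE"] = "0.3"
os.environ["GEMINI_LLM_MAX_OUTPUT_TOKENS"] = "8192"

parser = argparse.ArgumentParser()
GeminiLLMOptions.add_args(parser)   # registers --gemini-llm-* flags with env-derived defaults
args = parser.parse_args([])        # no CLI overrides, so the environment values win

options = GeminiLLMOptions.options_dict(args)
print(options.get("temperature"), options.get("max_output_tokens"))  # expected: 0.3 8192
```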
# =============================================================================
# Binding Options for OpenAI
# =============================================================================

lightrag/llm/gemini.py (new file, 297 lines)
View file

@ -0,0 +1,297 @@
"""
Gemini LLM binding for LightRAG.
This module provides asynchronous helpers that adapt Google's Gemini models
to the same interface used by the rest of the LightRAG LLM bindings. The
implementation mirrors the OpenAI helpers while relying on the official
``google-genai`` client under the hood.
"""
from __future__ import annotations
import asyncio
import logging
import os
from collections.abc import AsyncIterator
from functools import lru_cache
from typing import Any
from lightrag.utils import logger, remove_think_tags, safe_unicode_decode
import pipmaster as pm
# Install the Google Gemini client on demand
if not pm.is_installed("google-genai"):
pm.install("google-genai")
from google import genai # type: ignore
from google.genai import types # type: ignore
DEFAULT_GEMINI_ENDPOINT = "https://generativelanguage.googleapis.com"
LOG = logging.getLogger(__name__)
@lru_cache(maxsize=8)
def _get_gemini_client(api_key: str, base_url: str | None) -> genai.Client:
"""
Create (or fetch cached) Gemini client.
Args:
api_key: Google Gemini API key.
base_url: Optional custom API endpoint.
Returns:
genai.Client: Configured Gemini client instance.
"""
client_kwargs: dict[str, Any] = {"api_key": api_key}
if base_url and base_url != DEFAULT_GEMINI_ENDPOINT:
try:
client_kwargs["http_options"] = types.HttpOptions(api_endpoint=base_url)
except Exception as exc: # pragma: no cover - defensive
LOG.warning("Failed to apply custom Gemini endpoint %s: %s", base_url, exc)
try:
return genai.Client(**client_kwargs)
except TypeError:
# Older google-genai releases don't accept http_options; retry without it.
client_kwargs.pop("http_options", None)
return genai.Client(**client_kwargs)
def _ensure_api_key(api_key: str | None) -> str:
key = api_key or os.getenv("LLM_BINDING_API_KEY") or os.getenv("GEMINI_API_KEY")
if not key:
raise ValueError(
"Gemini API key not provided. "
"Set LLM_BINDING_API_KEY or GEMINI_API_KEY in the environment."
)
return key
def _build_generation_config(
base_config: dict[str, Any] | None,
system_prompt: str | None,
keyword_extraction: bool,
) -> types.GenerateContentConfig | None:
config_data = dict(base_config or {})
if system_prompt:
if config_data.get("system_instruction"):
config_data["system_instruction"] = (
f"{config_data['system_instruction']}\n{system_prompt}"
)
else:
config_data["system_instruction"] = system_prompt
if keyword_extraction and not config_data.get("response_mime_type"):
config_data["response_mime_type"] = "application/json"
# Remove entries that are explicitly set to None to avoid type errors
sanitized = {
key: value
for key, value in config_data.items()
if value is not None and value != ""
}
if not sanitized:
return None
return types.GenerateContentConfig(**sanitized)
def _format_history_messages(history_messages: list[dict[str, Any]] | None) -> str:
if not history_messages:
return ""
history_lines: list[str] = []
for message in history_messages:
role = message.get("role", "user")
content = message.get("content", "")
history_lines.append(f"[{role}] {content}")
return "\n".join(history_lines)
def _extract_response_text(response: Any) -> str:
if getattr(response, "text", None):
return response.text
candidates = getattr(response, "candidates", None)
if not candidates:
return ""
parts: list[str] = []
for candidate in candidates:
if not getattr(candidate, "content", None):
continue
for part in getattr(candidate.content, "parts", []):
text = getattr(part, "text", None)
if text:
parts.append(text)
return "\n".join(parts)
async def gemini_complete_if_cache(
model: str,
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
*,
api_key: str | None = None,
base_url: str | None = None,
generation_config: dict[str, Any] | None = None,
keyword_extraction: bool = False,
token_tracker: Any | None = None,
hashing_kv: Any | None = None, # noqa: ARG001 - present for interface parity
stream: bool | None = None,
enable_cot: bool = False, # noqa: ARG001 - not supported by Gemini currently
timeout: float | None = None, # noqa: ARG001 - handled by caller if needed
**_: Any,
) -> str | AsyncIterator[str]:
loop = asyncio.get_running_loop()
key = _ensure_api_key(api_key)
client = _get_gemini_client(key, base_url)
history_block = _format_history_messages(history_messages)
prompt_sections = []
if history_block:
prompt_sections.append(history_block)
prompt_sections.append(f"[user] {prompt}")
combined_prompt = "\n".join(prompt_sections)
config_obj = _build_generation_config(
generation_config,
system_prompt=system_prompt,
keyword_extraction=keyword_extraction,
)
request_kwargs: dict[str, Any] = {
"model": model,
"contents": [combined_prompt],
}
if config_obj is not None:
request_kwargs["config"] = config_obj
def _call_model():
return client.models.generate_content(**request_kwargs)
if stream:
queue: asyncio.Queue[Any] = asyncio.Queue()
usage_container: dict[str, Any] = {}
def _stream_model() -> None:
try:
stream_kwargs = dict(request_kwargs)
stream_iterator = client.models.generate_content_stream(**stream_kwargs)
for chunk in stream_iterator:
usage = getattr(chunk, "usage_metadata", None)
if usage is not None:
usage_container["usage"] = usage
text_piece = getattr(chunk, "text", None) or _extract_response_text(chunk)
if text_piece:
loop.call_soon_threadsafe(queue.put_nowait, text_piece)
loop.call_soon_threadsafe(queue.put_nowait, None)
except Exception as exc: # pragma: no cover - surface runtime issues
loop.call_soon_threadsafe(queue.put_nowait, exc)
loop.run_in_executor(None, _stream_model)
async def _async_stream() -> AsyncIterator[str]:
accumulated = ""
emitted = ""
try:
while True:
item = await queue.get()
if item is None:
break
if isinstance(item, Exception):
raise item
chunk_text = str(item)
if "\\u" in chunk_text:
chunk_text = safe_unicode_decode(chunk_text.encode("utf-8"))
accumulated += chunk_text
sanitized = remove_think_tags(accumulated)
if sanitized.startswith(emitted):
delta = sanitized[len(emitted) :]
else:
delta = sanitized
emitted = sanitized
if delta:
yield delta
finally:
usage = usage_container.get("usage")
if token_tracker and usage:
token_tracker.add_usage(
{
"prompt_tokens": getattr(usage, "prompt_token_count", 0),
"completion_tokens": getattr(
usage, "candidates_token_count", 0
),
"total_tokens": getattr(usage, "total_token_count", 0),
}
)
return _async_stream()
response = await asyncio.to_thread(_call_model)
text = _extract_response_text(response)
if not text:
raise RuntimeError("Gemini response did not contain any text content.")
if "\\u" in text:
text = safe_unicode_decode(text.encode("utf-8"))
text = remove_think_tags(text)
usage = getattr(response, "usage_metadata", None)
if token_tracker and usage:
token_tracker.add_usage(
{
"prompt_tokens": getattr(usage, "prompt_token_count", 0),
"completion_tokens": getattr(usage, "candidates_token_count", 0),
"total_tokens": getattr(usage, "total_token_count", 0),
}
)
logger.debug("Gemini response length: %s", len(text))
return text
async def gemini_model_complete(
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
keyword_extraction: bool = False,
**kwargs: Any,
) -> str | AsyncIterator[str]:
hashing_kv = kwargs.get("hashing_kv")
model_name = None
if hashing_kv is not None:
model_name = hashing_kv.global_config.get("llm_model_name")
if model_name is None:
model_name = kwargs.pop("model_name", None)
if model_name is None:
raise ValueError("Gemini model name not provided in configuration.")
return await gemini_complete_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
keyword_extraction=keyword_extraction,
**kwargs,
)
__all__ = [
"gemini_complete_if_cache",
"gemini_model_complete",
]
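A minimal usage sketch (not part of the commit) of the low-level helper on its own; the model name, key, prompt, and generation values are placeholders:

```python
import asyncio

from lightrag.llm.gemini import gemini_complete_if_cache


async def main() -> None:
    # one-off, non-streaming call; the key can also come from LLM_BINDING_API_KEY / GEMINI_API_KEY
    answer = await gemini_complete_if_cache(
        "gemini-flash-latest",
        "Summarize what LightRAG does in one sentence.",
        system_prompt="You are a concise technical assistant.",
        api_key="your_gemini_api_key",
        generation_config={"temperature": 0.3, "max_output_tokens": 512},
    )
    print(answer)


asyncio.run(main())
```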

View file

@ -1783,7 +1783,7 @@ def normalize_extracted_info(name: str, remove_inner_quotes=False) -> str:
- Filter out short numeric-only text (length < 3 and only digits/dots)
- remove_inner_quotes = True
remove Chinese quotes
remove English queotes in and around chinese
remove English quotes in and around chinese
Convert non-breaking spaces to regular spaces
Convert narrow non-breaking spaces after non-digits to regular spaces

View file

@ -24,6 +24,7 @@ dependencies = [
"aiohttp",
"configparser",
"future",
"google-genai>=1.0.0,<2.0.0",
"json_repair",
"nano-vectordb",
"networkx",
@ -59,6 +60,7 @@ api = [
"tenacity",
"tiktoken",
"xlsxwriter>=3.1.0",
"google-genai>=1.0.0,<2.0.0",
# API-specific dependencies
"aiofiles",
"ascii_colors",
@ -105,6 +107,7 @@ offline-llm = [
"aioboto3>=12.0.0,<16.0.0",
"voyageai>=0.2.0,<1.0.0",
"llama-index>=0.9.0,<1.0.0",
"google-genai>=1.0.0,<2.0.0",
]
offline = [

View file

@ -13,5 +13,6 @@ anthropic>=0.18.0,<1.0.0
llama-index>=0.9.0,<1.0.0
ollama>=0.1.0,<1.0.0
openai>=1.0.0,<2.0.0
google-genai>=1.0.0,<2.0.0
voyageai>=0.2.0,<1.0.0
zhipuai>=2.0.0,<3.0.0

View file

@ -19,6 +19,7 @@ llama-index>=0.9.0,<1.0.0
neo4j>=5.0.0,<7.0.0
ollama>=0.1.0,<1.0.0
openai>=1.0.0,<2.0.0
google-genai>=1.0.0,<2.0.0
openpyxl>=3.0.0,<4.0.0
pymilvus>=2.6.2,<3.0.0
pymongo>=4.0.0,<5.0.0