extended to use gemini, switched to use gemini-flash-latest

parent c0f69395c7
commit 0b3d31507e

10 changed files with 429 additions and 5 deletions
@@ -120,6 +120,8 @@ cp env.example .env
 docker compose up
 ```
 
+> Tip: When targeting Google Gemini, set `LLM_BINDING=gemini`, choose a model such as `LLM_MODEL=gemini-flash-latest`, and provide your Gemini key via `LLM_BINDING_API_KEY` (or `GEMINI_API_KEY`). The server now understands this binding out of the box.
+
 > Historical versions of LightRAG docker images can be found here: [LightRAG Docker Images]( https://github.com/HKUDS/LightRAG/pkgs/container/lightrag)
 
 ### Install LightRAG Core
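If you want to sanity-check the new binding end to end, the sketch below sends one query to a locally running server started with the Gemini settings above. It is a minimal sketch only: the port (9621) and the `/query` payload follow the commonly documented LightRAG API defaults and may differ in your deployment.

```python
# Hedged sketch: assumes a LightRAG server configured with LLM_BINDING=gemini
# is already running on localhost:9621 (adjust host, port, and payload as needed).
import json
import urllib.request

payload = {"query": "What does the ingested corpus say about Gemini?", "mode": "hybrid"}
request = urllib.request.Request(
    "http://localhost:9621/query",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(request) as response:
    print(json.loads(response.read().decode("utf-8")))
```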
env.example (10 changes)

@@ -154,7 +154,7 @@ MAX_PARALLEL_INSERT=2
 ###########################################################
 ### LLM Configuration
-### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock
+### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock, gemini
 ###########################################################
 ### LLM request timeout setting for all llm (0 means no timeout for Ollma)
 # LLM_TIMEOUT=180

@@ -174,6 +174,14 @@ LLM_BINDING_API_KEY=your_api_key
 # LLM_BINDING_API_KEY=your_api_key
 # LLM_BINDING=openai
 
+### Gemini example
+# LLM_BINDING=gemini
+# LLM_MODEL=gemini-flash-latest
+# LLM_BINDING_HOST=https://generativelanguage.googleapis.com
+# LLM_BINDING_API_KEY=your_gemini_api_key
+# GEMINI_LLM_MAX_OUTPUT_TOKENS=8192
+# GEMINI_LLM_TEMPERATURE=0.7
+
 ### OpenAI Compatible API Specific Parameters
 ### Increased temperature values may mitigate infinite inference loops in certain LLM, such as Qwen3-30B.
 # OPENAI_LLM_TEMPERATURE=0.9
@@ -8,6 +8,7 @@ import logging
 from dotenv import load_dotenv
 from lightrag.utils import get_env_value
 from lightrag.llm.binding_options import (
+    GeminiLLMOptions,
     OllamaEmbeddingOptions,
     OllamaLLMOptions,
     OpenAILLMOptions,

@@ -63,6 +64,9 @@ def get_default_host(binding_type: str) -> str:
         "lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"),
         "azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"),
         "openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"),
+        "gemini": os.getenv(
+            "LLM_BINDING_HOST", "https://generativelanguage.googleapis.com"
+        ),
     }
     return default_hosts.get(
         binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434")

@@ -226,6 +230,7 @@ def parse_args() -> argparse.Namespace:
             "openai-ollama",
             "azure_openai",
             "aws_bedrock",
+            "gemini",
         ],
         help="LLM binding type (default: from env or ollama)",
     )

@@ -281,6 +286,16 @@ def parse_args() -> argparse.Namespace:
     elif os.environ.get("LLM_BINDING") in ["openai", "azure_openai"]:
         OpenAILLMOptions.add_args(parser)
 
+    if "--llm-binding" in sys.argv:
+        try:
+            idx = sys.argv.index("--llm-binding")
+            if idx + 1 < len(sys.argv) and sys.argv[idx + 1] == "gemini":
+                GeminiLLMOptions.add_args(parser)
+        except IndexError:
+            pass
+    elif os.environ.get("LLM_BINDING") == "gemini":
+        GeminiLLMOptions.add_args(parser)
+
     args = parser.parse_args()
 
     # convert relative path to absolute path
@@ -104,6 +104,7 @@ class LLMConfigCache:
 
         # Initialize configurations based on binding conditions
         self.openai_llm_options = None
+        self.gemini_llm_options = None
         self.ollama_llm_options = None
         self.ollama_embedding_options = None

@@ -114,6 +115,12 @@ class LLMConfigCache:
             self.openai_llm_options = OpenAILLMOptions.options_dict(args)
             logger.info(f"OpenAI LLM Options: {self.openai_llm_options}")
 
+        if args.llm_binding == "gemini":
+            from lightrag.llm.binding_options import GeminiLLMOptions
+
+            self.gemini_llm_options = GeminiLLMOptions.options_dict(args)
+            logger.info(f"Gemini LLM Options: {self.gemini_llm_options}")
+
         # Only initialize and log Ollama LLM options when using Ollama LLM binding
         if args.llm_binding == "ollama":
             try:

@@ -282,6 +289,7 @@ def create_app(args):
         "openai",
         "azure_openai",
         "aws_bedrock",
+        "gemini",
     ]:
         raise Exception("llm binding not supported")

@@ -500,6 +508,42 @@ def create_app(args):
 
         return optimized_azure_openai_model_complete
 
+    def create_optimized_gemini_llm_func(
+        config_cache: LLMConfigCache, args
+    ):
+        """Create optimized Gemini LLM function with cached configuration"""
+
+        async def optimized_gemini_model_complete(
+            prompt,
+            system_prompt=None,
+            history_messages=None,
+            keyword_extraction=False,
+            **kwargs,
+        ) -> str:
+            from lightrag.llm.gemini import gemini_complete_if_cache
+
+            if history_messages is None:
+                history_messages = []
+
+            if (
+                config_cache.gemini_llm_options is not None
+                and "generation_config" not in kwargs
+            ):
+                kwargs["generation_config"] = dict(config_cache.gemini_llm_options)
+
+            return await gemini_complete_if_cache(
+                args.llm_model,
+                prompt,
+                system_prompt=system_prompt,
+                history_messages=history_messages,
+                api_key=args.llm_binding_api_key,
+                base_url=args.llm_binding_host,
+                keyword_extraction=keyword_extraction,
+                **kwargs,
+            )
+
+        return optimized_gemini_model_complete
+
     def create_llm_model_func(binding: str):
         """
         Create LLM model function based on binding type.

@@ -521,6 +565,8 @@ def create_app(args):
             return create_optimized_azure_openai_llm_func(
                 config_cache, args, llm_timeout
             )
+        elif binding == "gemini":
+            return create_optimized_gemini_llm_func(config_cache, args)
         else: # openai and compatible
             # Use optimized function with pre-processed configuration
             return create_optimized_openai_llm_func(config_cache, args, llm_timeout)
@@ -9,12 +9,26 @@ from argparse import ArgumentParser, Namespace
 import argparse
 import json
 from dataclasses import asdict, dataclass, field
-from typing import Any, ClassVar, List
+from typing import Any, ClassVar, List, get_args, get_origin
 
 from lightrag.utils import get_env_value
 from lightrag.constants import DEFAULT_TEMPERATURE
 
+
+def _resolve_optional_type(field_type: Any) -> Any:
+    """Return the concrete type for Optional/Union annotations."""
+    origin = get_origin(field_type)
+    if origin in (list, dict, tuple):
+        return field_type
+
+    args = get_args(field_type)
+    if args:
+        non_none_args = [arg for arg in args if arg is not type(None)]
+        if len(non_none_args) == 1:
+            return non_none_args[0]
+    return field_type
+
 
 # =============================================================================
 # BindingOptions Base Class
 # =============================================================================

@@ -177,9 +191,13 @@ class BindingOptions:
                 help=arg_item["help"],
             )
         else:
+            resolved_type = arg_item["type"]
+            if resolved_type is not None:
+                resolved_type = _resolve_optional_type(resolved_type)
+
             group.add_argument(
                 f"--{arg_item['argname']}",
-                type=arg_item["type"],
+                type=resolved_type,
                 default=get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS),
                 help=arg_item["help"],
             )

@@ -210,7 +228,7 @@ class BindingOptions:
             argdef = {
                 "argname": f"{args_prefix}-{field.name}",
                 "env_name": f"{env_var_prefix}{field.name.upper()}",
-                "type": field.type,
+                "type": _resolve_optional_type(field.type),
                 "default": default_value,
                 "help": f"{cls._binding_name} -- " + help.get(field.name, ""),
             }

@@ -454,6 +472,39 @@ class OllamaLLMOptions(_OllamaOptionsMixin, BindingOptions):
     _binding_name: ClassVar[str] = "ollama_llm"
 
 
+@dataclass
+class GeminiLLMOptions(BindingOptions):
+    """Options for Google Gemini models."""
+
+    _binding_name: ClassVar[str] = "gemini_llm"
+
+    temperature: float = DEFAULT_TEMPERATURE
+    top_p: float = 0.95
+    top_k: int = 40
+    max_output_tokens: int | None = None
+    candidate_count: int = 1
+    presence_penalty: float = 0.0
+    frequency_penalty: float = 0.0
+    stop_sequences: List[str] = field(default_factory=list)
+    response_mime_type: str | None = None
+    safety_settings: dict | None = None
+    system_instruction: str | None = None
+
+    _help: ClassVar[dict[str, str]] = {
+        "temperature": "Controls randomness (0.0-2.0, higher = more creative)",
+        "top_p": "Nucleus sampling parameter (0.0-1.0)",
+        "top_k": "Limits sampling to the top K tokens (1 disables the limit)",
+        "max_output_tokens": "Maximum tokens generated in the response",
+        "candidate_count": "Number of candidates returned per request",
+        "presence_penalty": "Penalty for token presence (-2.0 to 2.0)",
+        "frequency_penalty": "Penalty for token frequency (-2.0 to 2.0)",
+        "stop_sequences": 'Stop sequences (JSON array of strings, e.g., \'["END"]\')',
+        "response_mime_type": "Desired MIME type for the response (e.g., application/json)",
+        "safety_settings": "JSON object with Gemini safety settings overrides",
+        "system_instruction": "Default system instruction applied to every request",
+    }
+
+
 # =============================================================================
 # Binding Options for OpenAI
 # =============================================================================
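As a quick orientation for the new dataclass, here is a minimal sketch that exercises only the `add_args`/`options_dict` interface already used elsewhere in this commit; the environment variable comes from the `env.example` hunk above, everything else is illustrative.

```python
# Minimal sketch: GEMINI_LLM_* environment variables feed the gemini_llm option
# group, and options_dict() turns the parsed args back into the plain dict the
# server later passes to the Gemini binding as generation_config.
import argparse
import os

from lightrag.llm.binding_options import GeminiLLMOptions

os.environ.setdefault("GEMINI_LLM_TEMPERATURE", "0.7")  # same key as env.example

parser = argparse.ArgumentParser()
GeminiLLMOptions.add_args(parser)   # registers the gemini_llm option group
args = parser.parse_args([])        # no CLI flags; env values and defaults apply
print(GeminiLLMOptions.options_dict(args))
```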
lightrag/llm/gemini.py (new file, 297 additions)

@@ -0,0 +1,297 @@
"""
Gemini LLM binding for LightRAG.

This module provides asynchronous helpers that adapt Google's Gemini models
to the same interface used by the rest of the LightRAG LLM bindings. The
implementation mirrors the OpenAI helpers while relying on the official
``google-genai`` client under the hood.
"""

from __future__ import annotations

import asyncio
import logging
import os
from collections.abc import AsyncIterator
from functools import lru_cache
from typing import Any

from lightrag.utils import logger, remove_think_tags, safe_unicode_decode

import pipmaster as pm

# Install the Google Gemini client on demand
if not pm.is_installed("google-genai"):
    pm.install("google-genai")

from google import genai  # type: ignore
from google.genai import types  # type: ignore

DEFAULT_GEMINI_ENDPOINT = "https://generativelanguage.googleapis.com"

LOG = logging.getLogger(__name__)


@lru_cache(maxsize=8)
def _get_gemini_client(api_key: str, base_url: str | None) -> genai.Client:
    """
    Create (or fetch cached) Gemini client.

    Args:
        api_key: Google Gemini API key.
        base_url: Optional custom API endpoint.

    Returns:
        genai.Client: Configured Gemini client instance.
    """
    client_kwargs: dict[str, Any] = {"api_key": api_key}

    if base_url and base_url != DEFAULT_GEMINI_ENDPOINT:
        try:
            client_kwargs["http_options"] = types.HttpOptions(api_endpoint=base_url)
        except Exception as exc:  # pragma: no cover - defensive
            LOG.warning("Failed to apply custom Gemini endpoint %s: %s", base_url, exc)

    try:
        return genai.Client(**client_kwargs)
    except TypeError:
        # Older google-genai releases don't accept http_options; retry without it.
        client_kwargs.pop("http_options", None)
        return genai.Client(**client_kwargs)


def _ensure_api_key(api_key: str | None) -> str:
    key = api_key or os.getenv("LLM_BINDING_API_KEY") or os.getenv("GEMINI_API_KEY")
    if not key:
        raise ValueError(
            "Gemini API key not provided. "
            "Set LLM_BINDING_API_KEY or GEMINI_API_KEY in the environment."
        )
    return key


def _build_generation_config(
    base_config: dict[str, Any] | None,
    system_prompt: str | None,
    keyword_extraction: bool,
) -> types.GenerateContentConfig | None:
    config_data = dict(base_config or {})

    if system_prompt:
        if config_data.get("system_instruction"):
            config_data["system_instruction"] = (
                f"{config_data['system_instruction']}\n{system_prompt}"
            )
        else:
            config_data["system_instruction"] = system_prompt

    if keyword_extraction and not config_data.get("response_mime_type"):
        config_data["response_mime_type"] = "application/json"

    # Remove entries that are explicitly set to None to avoid type errors
    sanitized = {
        key: value
        for key, value in config_data.items()
        if value is not None and value != ""
    }

    if not sanitized:
        return None

    return types.GenerateContentConfig(**sanitized)


def _format_history_messages(history_messages: list[dict[str, Any]] | None) -> str:
    if not history_messages:
        return ""

    history_lines: list[str] = []
    for message in history_messages:
        role = message.get("role", "user")
        content = message.get("content", "")
        history_lines.append(f"[{role}] {content}")

    return "\n".join(history_lines)


def _extract_response_text(response: Any) -> str:
    if getattr(response, "text", None):
        return response.text

    candidates = getattr(response, "candidates", None)
    if not candidates:
        return ""

    parts: list[str] = []
    for candidate in candidates:
        if not getattr(candidate, "content", None):
            continue
        for part in getattr(candidate.content, "parts", []):
            text = getattr(part, "text", None)
            if text:
                parts.append(text)

    return "\n".join(parts)


async def gemini_complete_if_cache(
    model: str,
    prompt: str,
    system_prompt: str | None = None,
    history_messages: list[dict[str, Any]] | None = None,
    *,
    api_key: str | None = None,
    base_url: str | None = None,
    generation_config: dict[str, Any] | None = None,
    keyword_extraction: bool = False,
    token_tracker: Any | None = None,
    hashing_kv: Any | None = None,  # noqa: ARG001 - present for interface parity
    stream: bool | None = None,
    enable_cot: bool = False,  # noqa: ARG001 - not supported by Gemini currently
    timeout: float | None = None,  # noqa: ARG001 - handled by caller if needed
    **_: Any,
) -> str | AsyncIterator[str]:
    loop = asyncio.get_running_loop()

    key = _ensure_api_key(api_key)
    client = _get_gemini_client(key, base_url)

    history_block = _format_history_messages(history_messages)
    prompt_sections = []
    if history_block:
        prompt_sections.append(history_block)
    prompt_sections.append(f"[user] {prompt}")
    combined_prompt = "\n".join(prompt_sections)

    config_obj = _build_generation_config(
        generation_config,
        system_prompt=system_prompt,
        keyword_extraction=keyword_extraction,
    )

    request_kwargs: dict[str, Any] = {
        "model": model,
        "contents": [combined_prompt],
    }
    if config_obj is not None:
        request_kwargs["config"] = config_obj

    def _call_model():
        return client.models.generate_content(**request_kwargs)

    if stream:
        queue: asyncio.Queue[Any] = asyncio.Queue()
        usage_container: dict[str, Any] = {}

        def _stream_model() -> None:
            try:
                stream_kwargs = dict(request_kwargs)
                stream_iterator = client.models.generate_content_stream(**stream_kwargs)
                for chunk in stream_iterator:
                    usage = getattr(chunk, "usage_metadata", None)
                    if usage is not None:
                        usage_container["usage"] = usage
                    text_piece = getattr(chunk, "text", None) or _extract_response_text(chunk)
                    if text_piece:
                        loop.call_soon_threadsafe(queue.put_nowait, text_piece)
                loop.call_soon_threadsafe(queue.put_nowait, None)
            except Exception as exc:  # pragma: no cover - surface runtime issues
                loop.call_soon_threadsafe(queue.put_nowait, exc)

        loop.run_in_executor(None, _stream_model)

        async def _async_stream() -> AsyncIterator[str]:
            accumulated = ""
            emitted = ""
            try:
                while True:
                    item = await queue.get()
                    if item is None:
                        break
                    if isinstance(item, Exception):
                        raise item

                    chunk_text = str(item)
                    if "\\u" in chunk_text:
                        chunk_text = safe_unicode_decode(chunk_text.encode("utf-8"))

                    accumulated += chunk_text
                    sanitized = remove_think_tags(accumulated)
                    if sanitized.startswith(emitted):
                        delta = sanitized[len(emitted) :]
                    else:
                        delta = sanitized
                    emitted = sanitized

                    if delta:
                        yield delta
            finally:
                usage = usage_container.get("usage")
                if token_tracker and usage:
                    token_tracker.add_usage(
                        {
                            "prompt_tokens": getattr(usage, "prompt_token_count", 0),
                            "completion_tokens": getattr(
                                usage, "candidates_token_count", 0
                            ),
                            "total_tokens": getattr(usage, "total_token_count", 0),
                        }
                    )

        return _async_stream()

    response = await asyncio.to_thread(_call_model)

    text = _extract_response_text(response)
    if not text:
        raise RuntimeError("Gemini response did not contain any text content.")

    if "\\u" in text:
        text = safe_unicode_decode(text.encode("utf-8"))

    text = remove_think_tags(text)

    usage = getattr(response, "usage_metadata", None)
    if token_tracker and usage:
        token_tracker.add_usage(
            {
                "prompt_tokens": getattr(usage, "prompt_token_count", 0),
                "completion_tokens": getattr(usage, "candidates_token_count", 0),
                "total_tokens": getattr(usage, "total_token_count", 0),
            }
        )

    logger.debug("Gemini response length: %s", len(text))
    return text


async def gemini_model_complete(
    prompt: str,
    system_prompt: str | None = None,
    history_messages: list[dict[str, Any]] | None = None,
    keyword_extraction: bool = False,
    **kwargs: Any,
) -> str | AsyncIterator[str]:
    hashing_kv = kwargs.get("hashing_kv")
    model_name = None
    if hashing_kv is not None:
        model_name = hashing_kv.global_config.get("llm_model_name")
    if model_name is None:
        model_name = kwargs.pop("model_name", None)
    if model_name is None:
        raise ValueError("Gemini model name not provided in configuration.")

    return await gemini_complete_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        keyword_extraction=keyword_extraction,
        **kwargs,
    )


__all__ = [
    "gemini_complete_if_cache",
    "gemini_model_complete",
]
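For completeness, a small hedged sketch of calling the helper above directly (outside the server); the model name and prompt are placeholders, and the API key is resolved from `GEMINI_API_KEY` or `LLM_BINDING_API_KEY` exactly as `_ensure_api_key` does.

```python
# Hedged sketch: requires GEMINI_API_KEY (or LLM_BINDING_API_KEY) in the environment.
import asyncio

from lightrag.llm.gemini import gemini_complete_if_cache


async def main() -> None:
    answer = await gemini_complete_if_cache(
        "gemini-flash-latest",                  # model used elsewhere in this commit
        "Summarize LightRAG in one sentence.",  # placeholder prompt
        system_prompt="You are a concise assistant.",
        generation_config={"temperature": 0.7, "max_output_tokens": 256},
    )
    print(answer)


asyncio.run(main())
```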
@@ -1783,7 +1783,7 @@ def normalize_extracted_info(name: str, remove_inner_quotes=False) -> str:
     - Filter out short numeric-only text (length < 3 and only digits/dots)
     - remove_inner_quotes = True
         remove Chinese quotes
-        remove English queotes in and around chinese
+        remove English quotes in and around chinese
         Convert non-breaking spaces to regular spaces
         Convert narrow non-breaking spaces after non-digits to regular spaces
@@ -24,6 +24,7 @@ dependencies = [
     "aiohttp",
     "configparser",
     "future",
+    "google-genai>=1.0.0,<2.0.0",
     "json_repair",
     "nano-vectordb",
     "networkx",

@@ -59,6 +60,7 @@ api = [
     "tenacity",
     "tiktoken",
     "xlsxwriter>=3.1.0",
+    "google-genai>=1.0.0,<2.0.0",
     # API-specific dependencies
     "aiofiles",
     "ascii_colors",

@@ -105,6 +107,7 @@ offline-llm = [
     "aioboto3>=12.0.0,<16.0.0",
     "voyageai>=0.2.0,<1.0.0",
     "llama-index>=0.9.0,<1.0.0",
+    "google-genai>=1.0.0,<2.0.0",
 ]
 
 offline = [
@@ -13,5 +13,6 @@ anthropic>=0.18.0,<1.0.0
 llama-index>=0.9.0,<1.0.0
 ollama>=0.1.0,<1.0.0
 openai>=1.0.0,<2.0.0
+google-genai>=1.0.0,<2.0.0
 voyageai>=0.2.0,<1.0.0
 zhipuai>=2.0.0,<3.0.0
@@ -19,6 +19,7 @@ llama-index>=0.9.0,<1.0.0
 neo4j>=5.0.0,<7.0.0
 ollama>=0.1.0,<1.0.0
 openai>=1.0.0,<2.0.0
+google-genai>=1.0.0,<2.0.0
 openpyxl>=3.0.0,<4.0.0
 pymilvus>=2.6.2,<3.0.0
 pymongo>=4.0.0,<5.0.0