Merge branch 'main' into feature-add-tigergraph-support

This commit is contained in:
Alexander Belikov 2025-11-07 14:18:33 +01:00
commit dc2898d358
26 changed files with 1361 additions and 1476 deletions

4
.gitignore vendored
View file

@ -67,8 +67,8 @@ download_models_hf.py
# Frontend build output (built during PyPI release)
/lightrag/api/webui/
# unit-test files
test_*
# temporary test files in project root
/test_*
# Cline files
memory-bank

View file

@ -53,7 +53,7 @@
## 🎉 新闻
- [x] [2025.11.05]🎯📢添加**基于RAGAS的**LightRAG评估框架。
- [x] [2025.11.05]🎯📢添加**基于RAGAS的**评估框架和**Langfuse**可观测性支持
- [x] [2025.10.22]🎯📢消除处理**大规模数据集**的瓶颈。
- [x] [2025.09.15]🎯📢显著提升**小型LLM**如Qwen3-30B-A3B的知识图谱提取准确性。
- [x] [2025.08.29]🎯📢现已支持**Reranker**,显著提升混合查询性能。
@ -1463,6 +1463,50 @@ LightRAG服务器提供全面的知识图谱可视化功能。它支持各种重
![iShot_2025-03-23_12.40.08](./README.assets/iShot_2025-03-23_12.40.08.png)
## Langfuse 可观测性集成
Langfuse 为 OpenAI 客户端提供了直接替代方案,可自动跟踪所有 LLM 交互,使开发者能够在无需修改代码的情况下监控、调试和优化其 RAG 系统。
### 安装 Langfuse 可选依赖
```
pip install lightrag-hku
pip install lightrag-hku[observability]
# 或从源代码安装并启用调试模式
pip install -e .
pip install -e ".[observability]"
```
### 配置 Langfuse 环境变量
修改 .env 文件:
```
## Langfuse 可观测性(可选)
# LLM 可观测性和追踪平台
# 安装命令: pip install lightrag-hku[observability]
# 注册地址: https://cloud.langfuse.com 或自托管部署
LANGFUSE_SECRET_KEY=""
LANGFUSE_PUBLIC_KEY=""
LANGFUSE_HOST="https://cloud.langfuse.com" # 或您的自托管实例地址
LANGFUSE_ENABLE_TRACE=true
```
### Langfuse 使用说明
安装并配置完成后Langfuse 会自动追踪所有 OpenAI LLM 调用。Langfuse 仪表板功能包括:
- **追踪**:查看完整的 LLM 调用链
- **分析**Token 使用量、延迟、成本指标
- **调试**:检查提示词和响应内容
- **评估**:比较模型输出结果
- **监控**:实时告警功能
### 重要提示
**注意**LightRAG 目前仅把 OpenAI 兼容的 API 调用接入了 Langfuse。Ollama、Azure 和 AWS Bedrock 等 API 还无法使用 Langfuse 的可观测性功能。
## RAGAS评估
**RAGAS**Retrieval Augmented Generation Assessment检索增强生成评估是一个使用LLM对RAG系统进行无参考评估的框架。我们提供了基于RAGAS的评估脚本。详细信息请参阅[基于RAGAS的评估框架](lightrag/evaluation/README.md)。

View file

@ -51,7 +51,7 @@
---
## 🎉 News
- [x] [2025.11.05]🎯📢Add **RAGAS-based** Evaluation Framework for LightRAG.
- [x] [2025.11.05]🎯📢Add **RAGAS-based** Evaluation Framework and **Langfuse** observability for LightRAG.
- [x] [2025.10.22]🎯📢Eliminate bottlenecks in processing **large-scale datasets**.
- [x] [2025.09.15]🎯📢Significantly enhances KG extraction accuracy for **small LLMs** like Qwen3-30B-A3B.
- [x] [2025.08.29]🎯📢**Reranker** is supported now , significantly boosting performance for mixed queries.
@ -1543,6 +1543,50 @@ The LightRAG Server offers a comprehensive knowledge graph visualization feature
![iShot_2025-03-23_12.40.08](./README.assets/iShot_2025-03-23_12.40.08.png)
## Langfuse observability integration
Langfuse provides a drop-in replacement for the OpenAI client that automatically tracks all LLM interactions, enabling developers to monitor, debug, and optimize their RAG systems without code changes.
### Installation with Langfuse option
```
pip install lightrag-hku
pip install lightrag-hku[observability]
# Or install from souce code with debug mode enabled
pip install -e .
pip install -e ".[observability]"
```
### Config Langfuse env vars
modify .env file:
```
## Langfuse Observability (Optional)
# LLM observability and tracing platform
# Install with: pip install lightrag-hku[observability]
# Sign up at: https://cloud.langfuse.com or self-host
LANGFUSE_SECRET_KEY=""
LANGFUSE_PUBLIC_KEY=""
LANGFUSE_HOST="https://cloud.langfuse.com" # or your self-hosted instance
LANGFUSE_ENABLE_TRACE=true
```
### Langfuse Usage
Once installed and configured, Langfuse automatically traces all OpenAI LLM calls. Langfuse dashboard features include:
- **Tracing**: View complete LLM call chains
- **Analytics**: Token usage, latency, cost metrics
- **Debugging**: Inspect prompts and responses
- **Evaluation**: Compare model outputs
- **Monitoring**: Real-time alerting
### Important Notice
**Note**: LightRAG currently only integrates OpenAI-compatible API calls with Langfuse. APIs such as Ollama, Azure, and AWS Bedrock are not yet supported for Langfuse observability.
## RAGAS-based Evaluation
**RAGAS** (Retrieval Augmented Generation Assessment) is a framework for reference-free evaluation of RAG systems using LLMs. There is an evaluation script based on RAGAS. For detailed information, please refer to [RAGAS-based Evaluation Framework](lightrag/evaluation/README.md).

View file

@ -170,7 +170,7 @@ MAX_PARALLEL_INSERT=2
###########################################################################
### LLM Configuration
### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock
### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock, gemini
### LLM_BINDING_HOST: host only for Ollama, endpoint for other LLM service
###########################################################################
### LLM request timeout setting for all llm (0 means no timeout for Ollma)
@ -191,6 +191,15 @@ LLM_BINDING_API_KEY=your_api_key
# LLM_BINDING_API_KEY=your_api_key
# LLM_BINDING=openai
### Gemini example
# LLM_BINDING=gemini
# LLM_MODEL=gemini-flash-latest
# LLM_BINDING_API_KEY=your_gemini_api_key
# LLM_BINDING_HOST=https://generativelanguage.googleapis.com
GEMINI_LLM_THINKING_CONFIG='{"thinking_budget": 0, "include_thoughts": false}'
# GEMINI_LLM_MAX_OUTPUT_TOKENS=9000
# GEMINI_LLM_TEMPERATURE=0.7
### OpenAI Compatible API Specific Parameters
### Increased temperature values may mitigate infinite inference loops in certain LLM, such as Qwen3-30B.
# OPENAI_LLM_TEMPERATURE=0.9

View file

@ -1,105 +0,0 @@
# pip install -q -U google-genai to use gemini as a client
import os
import numpy as np
from google import genai
from google.genai import types
from dotenv import load_dotenv
from lightrag.utils import EmbeddingFunc
from lightrag import LightRAG, QueryParam
from sentence_transformers import SentenceTransformer
from lightrag.kg.shared_storage import initialize_pipeline_status
import asyncio
import nest_asyncio
# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()
load_dotenv()
gemini_api_key = os.getenv("GEMINI_API_KEY")
WORKING_DIR = "./dickens"
if os.path.exists(WORKING_DIR):
import shutil
shutil.rmtree(WORKING_DIR)
os.mkdir(WORKING_DIR)
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
# 1. Initialize the GenAI Client with your Gemini API Key
client = genai.Client(api_key=gemini_api_key)
# 2. Combine prompts: system prompt, history, and user prompt
if history_messages is None:
history_messages = []
combined_prompt = ""
if system_prompt:
combined_prompt += f"{system_prompt}\n"
for msg in history_messages:
# Each msg is expected to be a dict: {"role": "...", "content": "..."}
combined_prompt += f"{msg['role']}: {msg['content']}\n"
# Finally, add the new user prompt
combined_prompt += f"user: {prompt}"
# 3. Call the Gemini model
response = client.models.generate_content(
model="gemini-1.5-flash",
contents=[combined_prompt],
config=types.GenerateContentConfig(max_output_tokens=500, temperature=0.1),
)
# 4. Return the response text
return response.text
async def embedding_func(texts: list[str]) -> np.ndarray:
model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = model.encode(texts, convert_to_numpy=True)
return embeddings
async def initialize_rag():
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=384,
max_token_size=8192,
func=embedding_func,
),
)
await rag.initialize_storages()
await initialize_pipeline_status()
return rag
def main():
# Initialize RAG instance
rag = asyncio.run(initialize_rag())
file_path = "story.txt"
with open(file_path, "r") as file:
text = file.read()
rag.insert(text)
response = rag.query(
query="What is the main theme of the story?",
param=QueryParam(mode="hybrid", top_k=5, response_type="single line"),
)
print(response)
if __name__ == "__main__":
main()

View file

@ -1,230 +0,0 @@
# pip install -q -U google-genai to use gemini as a client
import os
from typing import Optional
import dataclasses
from pathlib import Path
import hashlib
import numpy as np
from google import genai
from google.genai import types
from dotenv import load_dotenv
from lightrag.utils import EmbeddingFunc, Tokenizer
from lightrag import LightRAG, QueryParam
from sentence_transformers import SentenceTransformer
from lightrag.kg.shared_storage import initialize_pipeline_status
import sentencepiece as spm
import requests
import asyncio
import nest_asyncio
# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()
load_dotenv()
gemini_api_key = os.getenv("GEMINI_API_KEY")
WORKING_DIR = "./dickens"
if os.path.exists(WORKING_DIR):
import shutil
shutil.rmtree(WORKING_DIR)
os.mkdir(WORKING_DIR)
class GemmaTokenizer(Tokenizer):
# adapted from google-cloud-aiplatform[tokenization]
@dataclasses.dataclass(frozen=True)
class _TokenizerConfig:
tokenizer_model_url: str
tokenizer_model_hash: str
_TOKENIZERS = {
"google/gemma2": _TokenizerConfig(
tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model",
tokenizer_model_hash="61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2",
),
"google/gemma3": _TokenizerConfig(
tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/cb7c0152a369e43908e769eb09e1ce6043afe084/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
tokenizer_model_hash="1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c",
),
}
def __init__(
self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None
):
# https://github.com/google/gemma_pytorch/tree/main/tokenizer
if "1.5" in model_name or "1.0" in model_name:
# up to gemini 1.5 gemma2 is a comparable local tokenizer
# https://github.com/googleapis/python-aiplatform/blob/main/vertexai/tokenization/_tokenizer_loading.py
tokenizer_name = "google/gemma2"
else:
# for gemini > 2.0 gemma3 was used
tokenizer_name = "google/gemma3"
file_url = self._TOKENIZERS[tokenizer_name].tokenizer_model_url
tokenizer_model_name = file_url.rsplit("/", 1)[1]
expected_hash = self._TOKENIZERS[tokenizer_name].tokenizer_model_hash
tokenizer_dir = Path(tokenizer_dir)
if tokenizer_dir.is_dir():
file_path = tokenizer_dir / tokenizer_model_name
model_data = self._maybe_load_from_cache(
file_path=file_path, expected_hash=expected_hash
)
else:
model_data = None
if not model_data:
model_data = self._load_from_url(
file_url=file_url, expected_hash=expected_hash
)
self.save_tokenizer_to_cache(cache_path=file_path, model_data=model_data)
tokenizer = spm.SentencePieceProcessor()
tokenizer.LoadFromSerializedProto(model_data)
super().__init__(model_name=model_name, tokenizer=tokenizer)
def _is_valid_model(self, model_data: bytes, expected_hash: str) -> bool:
"""Returns true if the content is valid by checking the hash."""
return hashlib.sha256(model_data).hexdigest() == expected_hash
def _maybe_load_from_cache(self, file_path: Path, expected_hash: str) -> bytes:
"""Loads the model data from the cache path."""
if not file_path.is_file():
return
with open(file_path, "rb") as f:
content = f.read()
if self._is_valid_model(model_data=content, expected_hash=expected_hash):
return content
# Cached file corrupted.
self._maybe_remove_file(file_path)
def _load_from_url(self, file_url: str, expected_hash: str) -> bytes:
"""Loads model bytes from the given file url."""
resp = requests.get(file_url)
resp.raise_for_status()
content = resp.content
if not self._is_valid_model(model_data=content, expected_hash=expected_hash):
actual_hash = hashlib.sha256(content).hexdigest()
raise ValueError(
f"Downloaded model file is corrupted."
f" Expected hash {expected_hash}. Got file hash {actual_hash}."
)
return content
@staticmethod
def save_tokenizer_to_cache(cache_path: Path, model_data: bytes) -> None:
"""Saves the model data to the cache path."""
try:
if not cache_path.is_file():
cache_dir = cache_path.parent
cache_dir.mkdir(parents=True, exist_ok=True)
with open(cache_path, "wb") as f:
f.write(model_data)
except OSError:
# Don't raise if we cannot write file.
pass
@staticmethod
def _maybe_remove_file(file_path: Path) -> None:
"""Removes the file if exists."""
if not file_path.is_file():
return
try:
file_path.unlink()
except OSError:
# Don't raise if we cannot remove file.
pass
# def encode(self, content: str) -> list[int]:
# return self.tokenizer.encode(content)
# def decode(self, tokens: list[int]) -> str:
# return self.tokenizer.decode(tokens)
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
# 1. Initialize the GenAI Client with your Gemini API Key
client = genai.Client(api_key=gemini_api_key)
# 2. Combine prompts: system prompt, history, and user prompt
if history_messages is None:
history_messages = []
combined_prompt = ""
if system_prompt:
combined_prompt += f"{system_prompt}\n"
for msg in history_messages:
# Each msg is expected to be a dict: {"role": "...", "content": "..."}
combined_prompt += f"{msg['role']}: {msg['content']}\n"
# Finally, add the new user prompt
combined_prompt += f"user: {prompt}"
# 3. Call the Gemini model
response = client.models.generate_content(
model="gemini-1.5-flash",
contents=[combined_prompt],
config=types.GenerateContentConfig(max_output_tokens=500, temperature=0.1),
)
# 4. Return the response text
return response.text
async def embedding_func(texts: list[str]) -> np.ndarray:
model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = model.encode(texts, convert_to_numpy=True)
return embeddings
async def initialize_rag():
rag = LightRAG(
working_dir=WORKING_DIR,
# tiktoken_model_name="gpt-4o-mini",
tokenizer=GemmaTokenizer(
tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"),
model_name="gemini-2.0-flash",
),
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=384,
max_token_size=8192,
func=embedding_func,
),
)
await rag.initialize_storages()
await initialize_pipeline_status()
return rag
def main():
# Initialize RAG instance
rag = asyncio.run(initialize_rag())
file_path = "story.txt"
with open(file_path, "r") as file:
text = file.read()
rag.insert(text)
response = rag.query(
query="What is the main theme of the story?",
param=QueryParam(mode="hybrid", top_k=5, response_type="single line"),
)
print(response)
if __name__ == "__main__":
main()

View file

@ -1,151 +0,0 @@
# pip install -q -U google-genai to use gemini as a client
import os
import asyncio
import numpy as np
import nest_asyncio
from google import genai
from google.genai import types
from dotenv import load_dotenv
from lightrag.utils import EmbeddingFunc
from lightrag import LightRAG, QueryParam
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.llm.siliconcloud import siliconcloud_embedding
from lightrag.utils import setup_logger
from lightrag.utils import TokenTracker
setup_logger("lightrag", level="DEBUG")
# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()
load_dotenv()
gemini_api_key = os.getenv("GEMINI_API_KEY")
siliconflow_api_key = os.getenv("SILICONFLOW_API_KEY")
WORKING_DIR = "./dickens"
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
token_tracker = TokenTracker()
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
# 1. Initialize the GenAI Client with your Gemini API Key
client = genai.Client(api_key=gemini_api_key)
# 2. Combine prompts: system prompt, history, and user prompt
if history_messages is None:
history_messages = []
combined_prompt = ""
if system_prompt:
combined_prompt += f"{system_prompt}\n"
for msg in history_messages:
# Each msg is expected to be a dict: {"role": "...", "content": "..."}
combined_prompt += f"{msg['role']}: {msg['content']}\n"
# Finally, add the new user prompt
combined_prompt += f"user: {prompt}"
# 3. Call the Gemini model
response = client.models.generate_content(
model="gemini-2.0-flash",
contents=[combined_prompt],
config=types.GenerateContentConfig(
max_output_tokens=5000, temperature=0, top_k=10
),
)
# 4. Get token counts with null safety
usage = getattr(response, "usage_metadata", None)
prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
total_tokens = getattr(usage, "total_token_count", 0) or (
prompt_tokens + completion_tokens
)
token_counts = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
}
token_tracker.add_usage(token_counts)
# 5. Return the response text
return response.text
async def embedding_func(texts: list[str]) -> np.ndarray:
return await siliconcloud_embedding(
texts,
model="BAAI/bge-m3",
api_key=siliconflow_api_key,
max_token_size=512,
)
async def initialize_rag():
rag = LightRAG(
working_dir=WORKING_DIR,
entity_extract_max_gleaning=1,
enable_llm_cache=True,
enable_llm_cache_for_entity_extract=True,
embedding_cache_config={"enabled": True, "similarity_threshold": 0.90},
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=1024,
max_token_size=8192,
func=embedding_func,
),
)
await rag.initialize_storages()
await initialize_pipeline_status()
return rag
def main():
# Initialize RAG instance
rag = asyncio.run(initialize_rag())
with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
# Context Manager Method
with token_tracker:
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="naive")
)
)
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="local")
)
)
print(
rag.query(
"What are the top themes in this story?",
param=QueryParam(mode="global"),
)
)
print(
rag.query(
"What are the top themes in this story?",
param=QueryParam(mode="hybrid"),
)
)
if __name__ == "__main__":
main()

View file

@ -50,6 +50,7 @@ LightRAG necessitates the integration of both an LLM (Large Language Model) and
* openai or openai compatible
* azure_openai
* aws_bedrock
* gemini
It is recommended to use environment variables to configure the LightRAG Server. There is an example environment variable file named `env.example` in the root directory of the project. Please copy this file to the startup directory and rename it to `.env`. After that, you can modify the parameters related to the LLM and Embedding models in the `.env` file. It is important to note that the LightRAG Server will load the environment variables from `.env` into the system environment variables each time it starts. **LightRAG Server will prioritize the settings in the system environment variables to .env file**.
@ -72,6 +73,8 @@ EMBEDDING_DIM=1024
# EMBEDDING_BINDING_API_KEY=your_api_key
```
> When targeting Google Gemini, set `LLM_BINDING=gemini`, choose a model such as `LLM_MODEL=gemini-flash-latest`, and provide your Gemini key via `LLM_BINDING_API_KEY` (or `GEMINI_API_KEY`).
* Ollama LLM + Ollama Embedding:
```

View file

@ -1 +1 @@
__api_version__ = "0250"
__api_version__ = "0251"

View file

@ -8,6 +8,7 @@ import logging
from dotenv import load_dotenv
from lightrag.utils import get_env_value
from lightrag.llm.binding_options import (
GeminiLLMOptions,
OllamaEmbeddingOptions,
OllamaLLMOptions,
OpenAILLMOptions,
@ -63,6 +64,9 @@ def get_default_host(binding_type: str) -> str:
"lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"),
"azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"),
"openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"),
"gemini": os.getenv(
"LLM_BINDING_HOST", "https://generativelanguage.googleapis.com"
),
}
return default_hosts.get(
binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434")
@ -226,6 +230,7 @@ def parse_args() -> argparse.Namespace:
"openai-ollama",
"azure_openai",
"aws_bedrock",
"gemini",
],
help="LLM binding type (default: from env or ollama)",
)
@ -281,6 +286,16 @@ def parse_args() -> argparse.Namespace:
elif os.environ.get("LLM_BINDING") in ["openai", "azure_openai"]:
OpenAILLMOptions.add_args(parser)
if "--llm-binding" in sys.argv:
try:
idx = sys.argv.index("--llm-binding")
if idx + 1 < len(sys.argv) and sys.argv[idx + 1] == "gemini":
GeminiLLMOptions.add_args(parser)
except IndexError:
pass
elif os.environ.get("LLM_BINDING") == "gemini":
GeminiLLMOptions.add_args(parser)
args = parser.parse_args()
# convert relative path to absolute path

View file

@ -89,6 +89,7 @@ class LLMConfigCache:
# Initialize configurations based on binding conditions
self.openai_llm_options = None
self.gemini_llm_options = None
self.ollama_llm_options = None
self.ollama_embedding_options = None
@ -99,6 +100,12 @@ class LLMConfigCache:
self.openai_llm_options = OpenAILLMOptions.options_dict(args)
logger.info(f"OpenAI LLM Options: {self.openai_llm_options}")
if args.llm_binding == "gemini":
from lightrag.llm.binding_options import GeminiLLMOptions
self.gemini_llm_options = GeminiLLMOptions.options_dict(args)
logger.info(f"Gemini LLM Options: {self.gemini_llm_options}")
# Only initialize and log Ollama LLM options when using Ollama LLM binding
if args.llm_binding == "ollama":
try:
@ -279,6 +286,7 @@ def create_app(args):
"openai",
"azure_openai",
"aws_bedrock",
"gemini",
]:
raise Exception("llm binding not supported")
@ -504,6 +512,44 @@ def create_app(args):
return optimized_azure_openai_model_complete
def create_optimized_gemini_llm_func(
config_cache: LLMConfigCache, args, llm_timeout: int
):
"""Create optimized Gemini LLM function with cached configuration"""
async def optimized_gemini_model_complete(
prompt,
system_prompt=None,
history_messages=None,
keyword_extraction=False,
**kwargs,
) -> str:
from lightrag.llm.gemini import gemini_complete_if_cache
if history_messages is None:
history_messages = []
# Use pre-processed configuration to avoid repeated parsing
kwargs["timeout"] = llm_timeout
if (
config_cache.gemini_llm_options is not None
and "generation_config" not in kwargs
):
kwargs["generation_config"] = dict(config_cache.gemini_llm_options)
return await gemini_complete_if_cache(
args.llm_model,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
api_key=args.llm_binding_api_key,
base_url=args.llm_binding_host,
keyword_extraction=keyword_extraction,
**kwargs,
)
return optimized_gemini_model_complete
def create_llm_model_func(binding: str):
"""
Create LLM model function based on binding type.
@ -525,6 +571,8 @@ def create_app(args):
return create_optimized_azure_openai_llm_func(
config_cache, args, llm_timeout
)
elif binding == "gemini":
return create_optimized_gemini_llm_func(config_cache, args, llm_timeout)
else: # openai and compatible
# Use optimized function with pre-processed configuration
return create_optimized_openai_llm_func(config_cache, args, llm_timeout)

View file

@ -19,7 +19,6 @@ from typing import (
from .utils import EmbeddingFunc
from .types import KnowledgeGraph
from .constants import (
GRAPH_FIELD_SEP,
DEFAULT_TOP_K,
DEFAULT_CHUNK_TOP_K,
DEFAULT_MAX_ENTITY_TOKENS,
@ -528,56 +527,6 @@ class BaseGraphStorage(StorageNameSpace, ABC):
result[node_id] = edges if edges is not None else []
return result
@abstractmethod
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
"""Get all nodes that are associated with the given chunk_ids.
Args:
chunk_ids (list[str]): A list of chunk IDs to find associated nodes for.
Returns:
list[dict]: A list of nodes, where each node is a dictionary of its properties.
An empty list if no matching nodes are found.
"""
@abstractmethod
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
"""Get all edges that are associated with the given chunk_ids.
Args:
chunk_ids (list[str]): A list of chunk IDs to find associated edges for.
Returns:
list[dict]: A list of edges, where each edge is a dictionary of its properties.
An empty list if no matching edges are found.
"""
# Default implementation iterates through all nodes and their edges, which is inefficient.
# This method should be overridden by subclasses for better performance.
all_edges = []
all_labels = await self.get_all_labels()
processed_edges = set()
for label in all_labels:
edges = await self.get_node_edges(label)
if edges:
for src_id, tgt_id in edges:
# Avoid processing the same edge twice in an undirected graph
edge_tuple = tuple(sorted((src_id, tgt_id)))
if edge_tuple in processed_edges:
continue
processed_edges.add(edge_tuple)
edge = await self.get_edge(src_id, tgt_id)
if edge and "source_id" in edge:
source_ids = set(edge["source_id"].split(GRAPH_FIELD_SEP))
if not source_ids.isdisjoint(chunk_ids):
# Add source and target to the edge dict for easier processing later
edge_with_nodes = edge.copy()
edge_with_nodes["source"] = src_id
edge_with_nodes["target"] = tgt_id
all_edges.append(edge_with_nodes)
return all_edges
@abstractmethod
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""Insert a new node or update an existing node in the graph.

View file

@ -8,7 +8,6 @@ import configparser
from ..utils import logger
from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
import pipmaster as pm
@ -784,79 +783,6 @@ class MemgraphStorage(BaseGraphStorage):
degrees = int(src_degree) + int(trg_degree)
return degrees
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
"""Get all nodes that are associated with the given chunk_ids.
Args:
chunk_ids: List of chunk IDs to find associated nodes for
Returns:
list[dict]: A list of nodes, where each node is a dictionary of its properties.
An empty list if no matching nodes are found.
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
workspace_label = self._get_workspace_label()
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = f"""
UNWIND $chunk_ids AS chunk_id
MATCH (n:`{workspace_label}`)
WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
RETURN DISTINCT n
"""
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
nodes = []
async for record in result:
node = record["n"]
node_dict = dict(node)
node_dict["id"] = node_dict.get("entity_id")
nodes.append(node_dict)
await result.consume()
return nodes
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
"""Get all edges that are associated with the given chunk_ids.
Args:
chunk_ids: List of chunk IDs to find associated edges for
Returns:
list[dict]: A list of edges, where each edge is a dictionary of its properties.
An empty list if no matching edges are found.
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
workspace_label = self._get_workspace_label()
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = f"""
UNWIND $chunk_ids AS chunk_id
MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`)
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id
// Ensure we only return each unique edge once by ordering the source and target
WITH a, b, r,
CASE WHEN source_id <= target_id THEN source_id ELSE target_id END AS ordered_source,
CASE WHEN source_id <= target_id THEN target_id ELSE source_id END AS ordered_target
RETURN DISTINCT ordered_source AS source, ordered_target AS target, properties(r) AS properties
"""
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
edges = []
async for record in result:
edge_properties = record["properties"]
edge_properties["source"] = record["source"]
edge_properties["target"] = record["target"]
edges.append(edge_properties)
await result.consume()
return edges
async def get_knowledge_graph(
self,
node_label: str,

View file

@ -1036,45 +1036,6 @@ class MongoGraphStorage(BaseGraphStorage):
return result
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
"""Get all nodes that are associated with the given chunk_ids.
Args:
chunk_ids (list[str]): A list of chunk IDs to find associated nodes for.
Returns:
list[dict]: A list of nodes, where each node is a dictionary of its properties.
An empty list if no matching nodes are found.
"""
if not chunk_ids:
return []
cursor = self.collection.find({"source_ids": {"$in": chunk_ids}})
return [doc async for doc in cursor]
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
"""Get all edges that are associated with the given chunk_ids.
Args:
chunk_ids (list[str]): A list of chunk IDs to find associated edges for.
Returns:
list[dict]: A list of edges, where each edge is a dictionary of its properties.
An empty list if no matching edges are found.
"""
if not chunk_ids:
return []
cursor = self.edge_collection.find({"source_ids": {"$in": chunk_ids}})
edges = []
async for edge in cursor:
edge["source"] = edge["source_node_id"]
edge["target"] = edge["target_node_id"]
edges.append(edge)
return edges
#
# -------------------------------------------------------------------------
# UPSERTS

View file

@ -16,7 +16,6 @@ import logging
from ..utils import logger
from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
import pipmaster as pm
@ -904,49 +903,6 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() # Ensure results are fully consumed
return edges_dict
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
workspace_label = self._get_workspace_label()
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = f"""
UNWIND $chunk_ids AS chunk_id
MATCH (n:`{workspace_label}`)
WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
RETURN DISTINCT n
"""
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
nodes = []
async for record in result:
node = record["n"]
node_dict = dict(node)
# Add node id (entity_id) to the dictionary for easier access
node_dict["id"] = node_dict.get("entity_id")
nodes.append(node_dict)
await result.consume()
return nodes
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
workspace_label = self._get_workspace_label()
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = f"""
UNWIND $chunk_ids AS chunk_id
MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`)
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
"""
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
edges = []
async for record in result:
edge_properties = record["properties"]
edge_properties["source"] = record["source"]
edge_properties["target"] = record["target"]
edges.append(edge_properties)
await result.consume()
return edges
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),

View file

@ -5,7 +5,6 @@ from typing import final
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from lightrag.utils import logger
from lightrag.base import BaseGraphStorage
from lightrag.constants import GRAPH_FIELD_SEP
import networkx as nx
from .shared_storage import (
get_storage_lock,
@ -470,33 +469,6 @@ class NetworkXStorage(BaseGraphStorage):
)
return result
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
chunk_ids_set = set(chunk_ids)
graph = await self._get_graph()
matching_nodes = []
for node_id, node_data in graph.nodes(data=True):
if "source_id" in node_data:
node_source_ids = set(node_data["source_id"].split(GRAPH_FIELD_SEP))
if not node_source_ids.isdisjoint(chunk_ids_set):
node_data_with_id = node_data.copy()
node_data_with_id["id"] = node_id
matching_nodes.append(node_data_with_id)
return matching_nodes
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
chunk_ids_set = set(chunk_ids)
graph = await self._get_graph()
matching_edges = []
for u, v, edge_data in graph.edges(data=True):
if "source_id" in edge_data:
edge_source_ids = set(edge_data["source_id"].split(GRAPH_FIELD_SEP))
if not edge_source_ids.isdisjoint(chunk_ids_set):
edge_data_with_nodes = edge_data.copy()
edge_data_with_nodes["source"] = u
edge_data_with_nodes["target"] = v
matching_edges.append(edge_data_with_nodes)
return matching_edges
async def get_all_nodes(self) -> list[dict]:
"""Get all nodes in the graph.

View file

@ -33,7 +33,6 @@ from ..base import (
)
from ..namespace import NameSpace, is_namespace
from ..utils import logger
from ..constants import GRAPH_FIELD_SEP
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock, get_storage_lock
import pipmaster as pm
@ -3569,17 +3568,13 @@ class PGGraphStorage(BaseGraphStorage):
async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier, return only node properties"""
label = self._normalize_node_id(node_id)
result = await self.get_nodes_batch(node_ids=[label])
result = await self.get_nodes_batch(node_ids=[node_id])
if result and node_id in result:
return result[node_id]
return None
async def node_degree(self, node_id: str) -> int:
label = self._normalize_node_id(node_id)
result = await self.node_degrees_batch(node_ids=[label])
result = await self.node_degrees_batch(node_ids=[node_id])
if result and node_id in result:
return result[node_id]
@ -3592,12 +3587,11 @@ class PGGraphStorage(BaseGraphStorage):
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
"""Get edge properties between two nodes"""
src_label = self._normalize_node_id(source_node_id)
tgt_label = self._normalize_node_id(target_node_id)
result = await self.get_edges_batch([{"src": src_label, "tgt": tgt_label}])
if result and (src_label, tgt_label) in result:
return result[(src_label, tgt_label)]
result = await self.get_edges_batch(
[{"src": source_node_id, "tgt": target_node_id}]
)
if result and (source_node_id, target_node_id) in result:
return result[(source_node_id, target_node_id)]
return None
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
@ -3795,13 +3789,17 @@ class PGGraphStorage(BaseGraphStorage):
if not node_ids:
return {}
seen = set()
unique_ids = []
seen: set[str] = set()
unique_ids: list[str] = []
lookup: dict[str, str] = {}
requested: set[str] = set()
for nid in node_ids:
nid_norm = self._normalize_node_id(nid)
if nid_norm not in seen:
seen.add(nid_norm)
unique_ids.append(nid_norm)
if nid not in seen:
seen.add(nid)
unique_ids.append(nid)
requested.add(nid)
lookup[nid] = nid
lookup[self._normalize_node_id(nid)] = nid
# Build result dictionary
nodes_dict = {}
@ -3840,10 +3838,18 @@ class PGGraphStorage(BaseGraphStorage):
node_dict = json.loads(node_dict)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse node string in batch: {node_dict}"
f"[{self.workspace}] Failed to parse node string in batch: {node_dict}"
)
nodes_dict[result["node_id"]] = node_dict
node_key = result["node_id"]
original_key = lookup.get(node_key)
if original_key is None:
logger.warning(
f"[{self.workspace}] Node {node_key} not found in lookup map"
)
original_key = node_key
if original_key in requested:
nodes_dict[original_key] = node_dict
return nodes_dict
@ -3866,13 +3872,17 @@ class PGGraphStorage(BaseGraphStorage):
if not node_ids:
return {}
seen = set()
seen: set[str] = set()
unique_ids: list[str] = []
lookup: dict[str, str] = {}
requested: set[str] = set()
for nid in node_ids:
n = self._normalize_node_id(nid)
if n not in seen:
seen.add(n)
unique_ids.append(n)
if nid not in seen:
seen.add(nid)
unique_ids.append(nid)
requested.add(nid)
lookup[nid] = nid
lookup[self._normalize_node_id(nid)] = nid
out_degrees = {}
in_degrees = {}
@ -3924,8 +3934,16 @@ class PGGraphStorage(BaseGraphStorage):
node_id = row["node_id"]
if not node_id:
continue
out_degrees[node_id] = int(row.get("out_degree", 0) or 0)
in_degrees[node_id] = int(row.get("in_degree", 0) or 0)
node_key = node_id
original_key = lookup.get(node_key)
if original_key is None:
logger.warning(
f"[{self.workspace}] Node {node_key} not found in lookup map"
)
original_key = node_key
if original_key in requested:
out_degrees[original_key] = int(row.get("out_degree", 0) or 0)
in_degrees[original_key] = int(row.get("in_degree", 0) or 0)
degrees_dict = {}
for node_id in node_ids:
@ -4054,7 +4072,7 @@ class PGGraphStorage(BaseGraphStorage):
edge_props = json.loads(edge_props)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse edge properties string: {edge_props}"
f"[{self.workspace}]Failed to parse edge properties string: {edge_props}"
)
continue
@ -4070,7 +4088,7 @@ class PGGraphStorage(BaseGraphStorage):
edge_props = json.loads(edge_props)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse edge properties string: {edge_props}"
f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
)
continue
@ -4175,102 +4193,6 @@ class PGGraphStorage(BaseGraphStorage):
labels.append(result["label"])
return labels
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
"""
Retrieves nodes from the graph that are associated with a given list of chunk IDs.
This method uses a Cypher query with UNWIND to efficiently find all nodes
where the `source_id` property contains any of the specified chunk IDs.
"""
# The string representation of the list for the cypher query
chunk_ids_str = json.dumps(chunk_ids)
query = f"""
SELECT * FROM cypher('{self.graph_name}', $$
UNWIND {chunk_ids_str} AS chunk_id
MATCH (n:base)
WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, '{GRAPH_FIELD_SEP}')
RETURN n
$$) AS (n agtype);
"""
results = await self._query(query)
# Build result list
nodes = []
for result in results:
if result["n"]:
node_dict = result["n"]["properties"]
# Process string result, parse it to JSON dictionary
if isinstance(node_dict, str):
try:
node_dict = json.loads(node_dict)
except json.JSONDecodeError:
logger.warning(
f"[{self.workspace}] Failed to parse node string in batch: {node_dict}"
)
node_dict["id"] = node_dict["entity_id"]
nodes.append(node_dict)
return nodes
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
"""
Retrieves edges from the graph that are associated with a given list of chunk IDs.
This method uses a Cypher query with UNWIND to efficiently find all edges
where the `source_id` property contains any of the specified chunk IDs.
"""
chunk_ids_str = json.dumps(chunk_ids)
query = f"""
SELECT * FROM cypher('{self.graph_name}', $$
UNWIND {chunk_ids_str} AS chunk_id
MATCH ()-[r]-()
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, '{GRAPH_FIELD_SEP}')
RETURN DISTINCT r, startNode(r) AS source, endNode(r) AS target
$$) AS (edge agtype, source agtype, target agtype);
"""
results = await self._query(query)
edges = []
if results:
for item in results:
edge_agtype = item["edge"]["properties"]
# Process string result, parse it to JSON dictionary
if isinstance(edge_agtype, str):
try:
edge_agtype = json.loads(edge_agtype)
except json.JSONDecodeError:
logger.warning(
f"[{self.workspace}] Failed to parse edge string in batch: {edge_agtype}"
)
source_agtype = item["source"]["properties"]
# Process string result, parse it to JSON dictionary
if isinstance(source_agtype, str):
try:
source_agtype = json.loads(source_agtype)
except json.JSONDecodeError:
logger.warning(
f"[{self.workspace}] Failed to parse node string in batch: {source_agtype}"
)
target_agtype = item["target"]["properties"]
# Process string result, parse it to JSON dictionary
if isinstance(target_agtype, str):
try:
target_agtype = json.loads(target_agtype)
except json.JSONDecodeError:
logger.warning(
f"[{self.workspace}] Failed to parse node string in batch: {target_agtype}"
)
if edge_agtype and source_agtype and target_agtype:
edge_properties = edge_agtype
edge_properties["source"] = source_agtype["entity_id"]
edge_properties["target"] = target_agtype["entity_id"]
edges.append(edge_properties)
return edges
async def _bfs_subgraph(
self, node_label: str, max_depth: int, max_nodes: int
) -> KnowledgeGraph:

View file

@ -3235,38 +3235,31 @@ class LightRAG:
if entity_chunk_updates and self.entity_chunks:
entity_upsert_payload = {}
entity_delete_ids: set[str] = set()
for entity_name, remaining in entity_chunk_updates.items():
if not remaining:
entity_delete_ids.add(entity_name)
else:
entity_upsert_payload[entity_name] = {
"chunk_ids": remaining,
"count": len(remaining),
"updated_at": current_time,
}
if entity_delete_ids:
await self.entity_chunks.delete(list(entity_delete_ids))
# Empty entities are deleted alongside graph nodes later
continue
entity_upsert_payload[entity_name] = {
"chunk_ids": remaining,
"count": len(remaining),
"updated_at": current_time,
}
if entity_upsert_payload:
await self.entity_chunks.upsert(entity_upsert_payload)
if relation_chunk_updates and self.relation_chunks:
relation_upsert_payload = {}
relation_delete_ids: set[str] = set()
for edge_tuple, remaining in relation_chunk_updates.items():
storage_key = make_relation_chunk_key(*edge_tuple)
if not remaining:
relation_delete_ids.add(storage_key)
else:
relation_upsert_payload[storage_key] = {
"chunk_ids": remaining,
"count": len(remaining),
"updated_at": current_time,
}
# Empty relations are deleted alongside graph edges later
continue
storage_key = make_relation_chunk_key(*edge_tuple)
relation_upsert_payload[storage_key] = {
"chunk_ids": remaining,
"count": len(remaining),
"updated_at": current_time,
}
if relation_delete_ids:
await self.relation_chunks.delete(list(relation_delete_ids))
if relation_upsert_payload:
await self.relation_chunks.upsert(relation_upsert_payload)
@ -3296,7 +3289,7 @@ class LightRAG:
# 6. Delete relationships that have no remaining sources
if relationships_to_delete:
try:
# Delete from vector database
# Delete from relation vdb
rel_ids_to_delete = []
for src, tgt in relationships_to_delete:
rel_ids_to_delete.extend(
@ -3333,15 +3326,16 @@ class LightRAG:
# 7. Delete entities that have no remaining sources
if entities_to_delete:
try:
# Batch get all edges for entities to avoid N+1 query problem
nodes_edges_dict = await self.chunk_entity_relation_graph.get_nodes_edges_batch(
list(entities_to_delete)
)
# Debug: Check and log all edges before deleting nodes
edges_to_delete = set()
edges_still_exist = 0
for entity in entities_to_delete:
edges = (
await self.chunk_entity_relation_graph.get_node_edges(
entity
)
)
for entity, edges in nodes_edges_dict.items():
if edges:
for src, tgt in edges:
# Normalize edge representation (sorted for consistency)
@ -3364,6 +3358,7 @@ class LightRAG:
f"Edge still exists: {src} <-- {tgt}"
)
edges_still_exist += 1
if edges_still_exist:
logger.warning(
f"⚠️ {edges_still_exist} entities still has edges before deletion"
@ -3399,7 +3394,7 @@ class LightRAG:
list(entities_to_delete)
)
# Delete from vector database
# Delete from vector vdb
entity_vdb_ids = [
compute_mdhash_id(entity, prefix="ent-")
for entity in entities_to_delete

View file

@ -9,12 +9,26 @@ from argparse import ArgumentParser, Namespace
import argparse
import json
from dataclasses import asdict, dataclass, field
from typing import Any, ClassVar, List
from typing import Any, ClassVar, List, get_args, get_origin
from lightrag.utils import get_env_value
from lightrag.constants import DEFAULT_TEMPERATURE
def _resolve_optional_type(field_type: Any) -> Any:
"""Return the concrete type for Optional/Union annotations."""
origin = get_origin(field_type)
if origin in (list, dict, tuple):
return field_type
args = get_args(field_type)
if args:
non_none_args = [arg for arg in args if arg is not type(None)]
if len(non_none_args) == 1:
return non_none_args[0]
return field_type
# =============================================================================
# BindingOptions Base Class
# =============================================================================
@ -177,9 +191,13 @@ class BindingOptions:
help=arg_item["help"],
)
else:
resolved_type = arg_item["type"]
if resolved_type is not None:
resolved_type = _resolve_optional_type(resolved_type)
group.add_argument(
f"--{arg_item['argname']}",
type=arg_item["type"],
type=resolved_type,
default=get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS),
help=arg_item["help"],
)
@ -210,7 +228,7 @@ class BindingOptions:
argdef = {
"argname": f"{args_prefix}-{field.name}",
"env_name": f"{env_var_prefix}{field.name.upper()}",
"type": field.type,
"type": _resolve_optional_type(field.type),
"default": default_value,
"help": f"{cls._binding_name} -- " + help.get(field.name, ""),
}
@ -454,6 +472,39 @@ class OllamaLLMOptions(_OllamaOptionsMixin, BindingOptions):
_binding_name: ClassVar[str] = "ollama_llm"
@dataclass
class GeminiLLMOptions(BindingOptions):
"""Options for Google Gemini models."""
_binding_name: ClassVar[str] = "gemini_llm"
temperature: float = DEFAULT_TEMPERATURE
top_p: float = 0.95
top_k: int = 40
max_output_tokens: int | None = None
candidate_count: int = 1
presence_penalty: float = 0.0
frequency_penalty: float = 0.0
stop_sequences: List[str] = field(default_factory=list)
seed: int | None = None
thinking_config: dict | None = None
safety_settings: dict | None = None
_help: ClassVar[dict[str, str]] = {
"temperature": "Controls randomness (0.0-2.0, higher = more creative)",
"top_p": "Nucleus sampling parameter (0.0-1.0)",
"top_k": "Limits sampling to the top K tokens (1 disables the limit)",
"max_output_tokens": "Maximum tokens generated in the response",
"candidate_count": "Number of candidates returned per request",
"presence_penalty": "Penalty for token presence (-2.0 to 2.0)",
"frequency_penalty": "Penalty for token frequency (-2.0 to 2.0)",
"stop_sequences": "Stop sequences (JSON array of strings, e.g., '[\"END\"]')",
"seed": "Random seed for reproducible generation (leave empty for random)",
"thinking_config": "Thinking configuration (JSON dict, e.g., '{\"thinking_budget\": 1024}' or '{\"include_thoughts\": true}')",
"safety_settings": "JSON object with Gemini safety settings overrides",
}
# =============================================================================
# Binding Options for OpenAI
# =============================================================================

422
lightrag/llm/gemini.py Normal file
View file

@ -0,0 +1,422 @@
"""
Gemini LLM binding for LightRAG.
This module provides asynchronous helpers that adapt Google's Gemini models
to the same interface used by the rest of the LightRAG LLM bindings. The
implementation mirrors the OpenAI helpers while relying on the official
``google-genai`` client under the hood.
"""
from __future__ import annotations
import asyncio
import logging
import os
from collections.abc import AsyncIterator
from functools import lru_cache
from typing import Any
from lightrag.utils import logger, remove_think_tags, safe_unicode_decode
import pipmaster as pm
# Install the Google Gemini client on demand
if not pm.is_installed("google-genai"):
pm.install("google-genai")
from google import genai # type: ignore
from google.genai import types # type: ignore
DEFAULT_GEMINI_ENDPOINT = "https://generativelanguage.googleapis.com"
LOG = logging.getLogger(__name__)
@lru_cache(maxsize=8)
def _get_gemini_client(
api_key: str, base_url: str | None, timeout: int | None = None
) -> genai.Client:
"""
Create (or fetch cached) Gemini client.
Args:
api_key: Google Gemini API key.
base_url: Optional custom API endpoint.
timeout: Optional request timeout in milliseconds.
Returns:
genai.Client: Configured Gemini client instance.
"""
client_kwargs: dict[str, Any] = {"api_key": api_key}
if base_url and base_url != DEFAULT_GEMINI_ENDPOINT or timeout is not None:
try:
http_options_kwargs = {}
if base_url and base_url != DEFAULT_GEMINI_ENDPOINT:
http_options_kwargs["api_endpoint"] = base_url
if timeout is not None:
http_options_kwargs["timeout"] = timeout
client_kwargs["http_options"] = types.HttpOptions(**http_options_kwargs)
except Exception as exc: # pragma: no cover - defensive
LOG.warning("Failed to apply custom Gemini http_options: %s", exc)
try:
return genai.Client(**client_kwargs)
except TypeError:
# Older google-genai releases don't accept http_options; retry without it.
client_kwargs.pop("http_options", None)
return genai.Client(**client_kwargs)
def _ensure_api_key(api_key: str | None) -> str:
key = api_key or os.getenv("LLM_BINDING_API_KEY") or os.getenv("GEMINI_API_KEY")
if not key:
raise ValueError(
"Gemini API key not provided. "
"Set LLM_BINDING_API_KEY or GEMINI_API_KEY in the environment."
)
return key
def _build_generation_config(
base_config: dict[str, Any] | None,
system_prompt: str | None,
keyword_extraction: bool,
) -> types.GenerateContentConfig | None:
config_data = dict(base_config or {})
if system_prompt:
if config_data.get("system_instruction"):
config_data["system_instruction"] = (
f"{config_data['system_instruction']}\n{system_prompt}"
)
else:
config_data["system_instruction"] = system_prompt
if keyword_extraction and not config_data.get("response_mime_type"):
config_data["response_mime_type"] = "application/json"
# Remove entries that are explicitly set to None to avoid type errors
sanitized = {
key: value
for key, value in config_data.items()
if value is not None and value != ""
}
if not sanitized:
return None
return types.GenerateContentConfig(**sanitized)
def _format_history_messages(history_messages: list[dict[str, Any]] | None) -> str:
if not history_messages:
return ""
history_lines: list[str] = []
for message in history_messages:
role = message.get("role", "user")
content = message.get("content", "")
history_lines.append(f"[{role}] {content}")
return "\n".join(history_lines)
def _extract_response_text(
response: Any, extract_thoughts: bool = False
) -> tuple[str, str]:
"""
Extract text content from Gemini response, separating regular content from thoughts.
Args:
response: Gemini API response object
extract_thoughts: Whether to extract thought content separately
Returns:
Tuple of (regular_text, thought_text)
"""
candidates = getattr(response, "candidates", None)
if not candidates:
return ("", "")
regular_parts: list[str] = []
thought_parts: list[str] = []
for candidate in candidates:
if not getattr(candidate, "content", None):
continue
# Use 'or []' to handle None values from parts attribute
for part in getattr(candidate.content, "parts", None) or []:
text = getattr(part, "text", None)
if not text:
continue
# Check if this part is thought content using the 'thought' attribute
is_thought = getattr(part, "thought", False)
if is_thought and extract_thoughts:
thought_parts.append(text)
elif not is_thought:
regular_parts.append(text)
return ("\n".join(regular_parts), "\n".join(thought_parts))
async def gemini_complete_if_cache(
model: str,
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
enable_cot: bool = False,
base_url: str | None = None,
api_key: str | None = None,
token_tracker: Any | None = None,
stream: bool | None = None,
keyword_extraction: bool = False,
generation_config: dict[str, Any] | None = None,
timeout: int | None = None,
**_: Any,
) -> str | AsyncIterator[str]:
"""
Complete a prompt using Gemini's API with Chain of Thought (COT) support.
This function supports automatic integration of reasoning content from Gemini models
that provide Chain of Thought capabilities via the thinking_config API feature.
COT Integration:
- When enable_cot=True: Thought content is wrapped in <think>...</think> tags
- When enable_cot=False: Thought content is filtered out, only regular content returned
- Thought content is identified by the 'thought' attribute on response parts
- Requires thinking_config to be enabled in generation_config for API to return thoughts
Args:
model: The Gemini model to use.
prompt: The prompt to complete.
system_prompt: Optional system prompt to include.
history_messages: Optional list of previous messages in the conversation.
api_key: Optional Gemini API key. If None, uses environment variable.
base_url: Optional custom API endpoint.
generation_config: Optional generation configuration dict.
keyword_extraction: Whether to use JSON response format.
token_tracker: Optional token usage tracker for monitoring API usage.
stream: Whether to stream the response.
hashing_kv: Storage interface (for interface parity with other bindings).
enable_cot: Whether to include Chain of Thought content in the response.
timeout: Request timeout in seconds (will be converted to milliseconds for Gemini API).
**_: Additional keyword arguments (ignored).
Returns:
The completed text (with COT content if enable_cot=True) or an async iterator
of text chunks if streaming. COT content is wrapped in <think>...</think> tags.
Raises:
RuntimeError: If the response from Gemini is empty.
ValueError: If API key is not provided or configured.
"""
loop = asyncio.get_running_loop()
key = _ensure_api_key(api_key)
# Convert timeout from seconds to milliseconds for Gemini API
timeout_ms = timeout * 1000 if timeout else None
client = _get_gemini_client(key, base_url, timeout_ms)
history_block = _format_history_messages(history_messages)
prompt_sections = []
if history_block:
prompt_sections.append(history_block)
prompt_sections.append(f"[user] {prompt}")
combined_prompt = "\n".join(prompt_sections)
config_obj = _build_generation_config(
generation_config,
system_prompt=system_prompt,
keyword_extraction=keyword_extraction,
)
request_kwargs: dict[str, Any] = {
"model": model,
"contents": [combined_prompt],
}
if config_obj is not None:
request_kwargs["config"] = config_obj
def _call_model():
return client.models.generate_content(**request_kwargs)
if stream:
queue: asyncio.Queue[Any] = asyncio.Queue()
usage_container: dict[str, Any] = {}
def _stream_model() -> None:
# COT state tracking for streaming
cot_active = False
cot_started = False
initial_content_seen = False
try:
stream_kwargs = dict(request_kwargs)
stream_iterator = client.models.generate_content_stream(**stream_kwargs)
for chunk in stream_iterator:
usage = getattr(chunk, "usage_metadata", None)
if usage is not None:
usage_container["usage"] = usage
# Extract both regular and thought content
regular_text, thought_text = _extract_response_text(
chunk, extract_thoughts=True
)
if enable_cot:
# Process regular content
if regular_text:
if not initial_content_seen:
initial_content_seen = True
# Close COT section if it was active
if cot_active:
loop.call_soon_threadsafe(queue.put_nowait, "</think>")
cot_active = False
# Send regular content
loop.call_soon_threadsafe(queue.put_nowait, regular_text)
# Process thought content
if thought_text:
if not initial_content_seen and not cot_started:
# Start COT section
loop.call_soon_threadsafe(queue.put_nowait, "<think>")
cot_active = True
cot_started = True
# Send thought content if COT is active
if cot_active:
loop.call_soon_threadsafe(
queue.put_nowait, thought_text
)
else:
# COT disabled - only send regular content
if regular_text:
loop.call_soon_threadsafe(queue.put_nowait, regular_text)
# Ensure COT is properly closed if still active
if cot_active:
loop.call_soon_threadsafe(queue.put_nowait, "</think>")
loop.call_soon_threadsafe(queue.put_nowait, None)
except Exception as exc: # pragma: no cover - surface runtime issues
# Try to close COT tag before reporting error
if cot_active:
try:
loop.call_soon_threadsafe(queue.put_nowait, "</think>")
except Exception:
pass
loop.call_soon_threadsafe(queue.put_nowait, exc)
loop.run_in_executor(None, _stream_model)
async def _async_stream() -> AsyncIterator[str]:
try:
while True:
item = await queue.get()
if item is None:
break
if isinstance(item, Exception):
raise item
chunk_text = str(item)
if "\\u" in chunk_text:
chunk_text = safe_unicode_decode(chunk_text.encode("utf-8"))
# Yield the chunk directly without filtering
# COT filtering is already handled in _stream_model()
yield chunk_text
finally:
usage = usage_container.get("usage")
if token_tracker and usage:
token_tracker.add_usage(
{
"prompt_tokens": getattr(usage, "prompt_token_count", 0),
"completion_tokens": getattr(
usage, "candidates_token_count", 0
),
"total_tokens": getattr(usage, "total_token_count", 0),
}
)
return _async_stream()
response = await asyncio.to_thread(_call_model)
# Extract both regular text and thought text
regular_text, thought_text = _extract_response_text(response, extract_thoughts=True)
# Apply COT filtering logic based on enable_cot parameter
if enable_cot:
# Include thought content wrapped in <think> tags
if thought_text and thought_text.strip():
if not regular_text or regular_text.strip() == "":
# Only thought content available
final_text = f"<think>{thought_text}</think>"
else:
# Both content types present: prepend thought to regular content
final_text = f"<think>{thought_text}</think>{regular_text}"
else:
# No thought content, use regular content only
final_text = regular_text or ""
else:
# Filter out thought content, return only regular content
final_text = regular_text or ""
if not final_text:
raise RuntimeError("Gemini response did not contain any text content.")
if "\\u" in final_text:
final_text = safe_unicode_decode(final_text.encode("utf-8"))
final_text = remove_think_tags(final_text)
usage = getattr(response, "usage_metadata", None)
if token_tracker and usage:
token_tracker.add_usage(
{
"prompt_tokens": getattr(usage, "prompt_token_count", 0),
"completion_tokens": getattr(usage, "candidates_token_count", 0),
"total_tokens": getattr(usage, "total_token_count", 0),
}
)
logger.debug("Gemini response length: %s", len(final_text))
return final_text
async def gemini_model_complete(
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
keyword_extraction: bool = False,
**kwargs: Any,
) -> str | AsyncIterator[str]:
hashing_kv = kwargs.get("hashing_kv")
model_name = None
if hashing_kv is not None:
model_name = hashing_kv.global_config.get("llm_model_name")
if model_name is None:
model_name = kwargs.pop("model_name", None)
if model_name is None:
raise ValueError("Gemini model name not provided in configuration.")
return await gemini_complete_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
keyword_extraction=keyword_extraction,
**kwargs,
)
__all__ = [
"gemini_complete_if_cache",
"gemini_model_complete",
]

View file

@ -138,6 +138,9 @@ async def openai_complete_if_cache(
base_url: str | None = None,
api_key: str | None = None,
token_tracker: Any | None = None,
keyword_extraction: bool = False, # Will be removed from kwargs before passing to OpenAI
stream: bool | None = None,
timeout: int | None = None,
**kwargs: Any,
) -> str:
"""Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration.
@ -172,8 +175,9 @@ async def openai_complete_if_cache(
- openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
These will be passed to the client constructor but will be overridden by
explicit parameters (api_key, base_url).
- hashing_kv: Will be removed from kwargs before passing to OpenAI.
- keyword_extraction: Will be removed from kwargs before passing to OpenAI.
- stream: Whether to stream the response. Default is False.
- timeout: Request timeout in seconds. Default is None.
Returns:
The completed text (with integrated COT content if available) or an async iterator

View file

@ -1795,7 +1795,7 @@ def normalize_extracted_info(name: str, remove_inner_quotes=False) -> str:
- Filter out short numeric-only text (length < 3 and only digits/dots)
- remove_inner_quotes = True
remove Chinese quotes
remove English queotes in and around chinese
remove English quotes in and around chinese
Convert non-breaking spaces to regular spaces
Convert narrow non-breaking spaces after non-digits to regular spaces

View file

@ -24,6 +24,7 @@ dependencies = [
"aiohttp",
"configparser",
"future",
"google-genai>=1.0.0,<2.0.0",
"json_repair",
"nano-vectordb",
"networkx",
@ -59,6 +60,7 @@ api = [
"tenacity",
"tiktoken",
"xlsxwriter>=3.1.0",
"google-genai>=1.0.0,<2.0.0",
# API-specific dependencies
"aiofiles",
"ascii_colors",
@ -107,6 +109,7 @@ offline-llm = [
"aioboto3>=12.0.0,<16.0.0",
"voyageai>=0.2.0,<1.0.0",
"llama-index>=0.9.0,<1.0.0",
"google-genai>=1.0.0,<2.0.0",
]
offline = [

View file

@ -10,6 +10,7 @@
# LLM provider dependencies (with version constraints matching pyproject.toml)
aioboto3>=12.0.0,<16.0.0
anthropic>=0.18.0,<1.0.0
google-genai>=1.0.0,<2.0.0
llama-index>=0.9.0,<1.0.0
ollama>=0.1.0,<1.0.0
openai>=1.0.0,<3.0.0

View file

@ -13,6 +13,7 @@ anthropic>=0.18.0,<1.0.0
# Storage backend dependencies
asyncpg>=0.29.0,<1.0.0
google-genai>=1.0.0,<2.0.0
# Document processing dependencies
llama-index>=0.9.0,<1.0.0

File diff suppressed because it is too large Load diff