Merge branch 'main' into feature-add-tigergraph-support

Commit dc2898d358
26 changed files with 1361 additions and 1476 deletions

.gitignore (vendored, 4 changes)
@@ -67,8 +67,8 @@ download_models_hf.py
# Frontend build output (built during PyPI release)
/lightrag/api/webui/

# unit-test files
test_*
# temporary test files in project root
/test_*

# Cline files
memory-bank
README-zh.md (46 changes)

@@ -53,7 +53,7 @@

## 🎉 新闻

- [x] [2025.11.05]🎯📢添加**基于RAGAS的**LightRAG评估框架。
- [x] [2025.11.05]🎯📢添加**基于RAGAS的**评估框架和**Langfuse**可观测性支持。
- [x] [2025.10.22]🎯📢消除处理**大规模数据集**的瓶颈。
- [x] [2025.09.15]🎯📢显著提升**小型LLM**(如Qwen3-30B-A3B)的知识图谱提取准确性。
- [x] [2025.08.29]🎯📢现已支持**Reranker**,显著提升混合查询性能。

@@ -1463,6 +1463,50 @@ LightRAG服务器提供全面的知识图谱可视化功能。它支持各种重



## Langfuse 可观测性集成

Langfuse 为 OpenAI 客户端提供了直接替代方案,可自动跟踪所有 LLM 交互,使开发者能够在无需修改代码的情况下监控、调试和优化其 RAG 系统。

### 安装 Langfuse 可选依赖

```
pip install lightrag-hku
pip install lightrag-hku[observability]

# 或从源代码安装并启用调试模式
pip install -e .
pip install -e ".[observability]"
```

### 配置 Langfuse 环境变量

修改 .env 文件:

```
## Langfuse 可观测性(可选)
# LLM 可观测性和追踪平台
# 安装命令: pip install lightrag-hku[observability]
# 注册地址: https://cloud.langfuse.com 或自托管部署
LANGFUSE_SECRET_KEY=""
LANGFUSE_PUBLIC_KEY=""
LANGFUSE_HOST="https://cloud.langfuse.com"  # 或您的自托管实例地址
LANGFUSE_ENABLE_TRACE=true
```

### Langfuse 使用说明

安装并配置完成后,Langfuse 会自动追踪所有 OpenAI LLM 调用。Langfuse 仪表板功能包括:

- **追踪**:查看完整的 LLM 调用链
- **分析**:Token 使用量、延迟、成本指标
- **调试**:检查提示词和响应内容
- **评估**:比较模型输出结果
- **监控**:实时告警功能

### 重要提示

**注意**:LightRAG 目前仅把 OpenAI 兼容的 API 调用接入了 Langfuse。Ollama、Azure 和 AWS Bedrock 等 API 还无法使用 Langfuse 的可观测性功能。

## RAGAS评估

**RAGAS**(Retrieval Augmented Generation Assessment,检索增强生成评估)是一个使用LLM对RAG系统进行无参考评估的框架。我们提供了基于RAGAS的评估脚本。详细信息请参阅[基于RAGAS的评估框架](lightrag/evaluation/README.md)。
README.md (46 changes)

@@ -51,7 +51,7 @@

---
## 🎉 News
- [x] [2025.11.05]🎯📢Add **RAGAS-based** Evaluation Framework for LightRAG.
- [x] [2025.11.05]🎯📢Add **RAGAS-based** Evaluation Framework and **Langfuse** observability for LightRAG.
- [x] [2025.10.22]🎯📢Eliminate bottlenecks in processing **large-scale datasets**.
- [x] [2025.09.15]🎯📢Significantly enhances KG extraction accuracy for **small LLMs** like Qwen3-30B-A3B.
- [x] [2025.08.29]🎯📢**Reranker** is now supported, significantly boosting performance for mixed queries.

@@ -1543,6 +1543,50 @@ The LightRAG Server offers a comprehensive knowledge graph visualization feature



## Langfuse observability integration

Langfuse provides a drop-in replacement for the OpenAI client that automatically tracks all LLM interactions, enabling developers to monitor, debug, and optimize their RAG systems without code changes.

### Installation with Langfuse option

```
pip install lightrag-hku
pip install lightrag-hku[observability]

# Or install from source code with debug mode enabled
pip install -e .
pip install -e ".[observability]"
```

### Configure Langfuse environment variables

Modify the .env file:

```
## Langfuse Observability (Optional)
# LLM observability and tracing platform
# Install with: pip install lightrag-hku[observability]
# Sign up at: https://cloud.langfuse.com or self-host
LANGFUSE_SECRET_KEY=""
LANGFUSE_PUBLIC_KEY=""
LANGFUSE_HOST="https://cloud.langfuse.com"  # or your self-hosted instance
LANGFUSE_ENABLE_TRACE=true
```

### Langfuse Usage

Once installed and configured, Langfuse automatically traces all OpenAI LLM calls; a minimal sketch of the underlying drop-in pattern follows the feature list below. Langfuse dashboard features include:

- **Tracing**: View complete LLM call chains
- **Analytics**: Token usage, latency, cost metrics
- **Debugging**: Inspect prompts and responses
- **Evaluation**: Compare model outputs
- **Monitoring**: Real-time alerting
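
The sketch below illustrates the drop-in pattern Langfuse relies on, shown outside of LightRAG. It is a hedged example rather than LightRAG code: it assumes the `langfuse` and `openai` packages are installed and that the `LANGFUSE_*` variables above plus `OPENAI_API_KEY` are set. Inside LightRAG, the `[observability]` extra wires this up for you.

```
# Illustrative sketch of Langfuse's drop-in OpenAI client (not LightRAG code).
# Assumes: pip install langfuse openai, plus LANGFUSE_* and OPENAI_API_KEY env vars.
from langfuse.openai import OpenAI  # drop-in replacement for openai.OpenAI

client = OpenAI()

# The call below is traced automatically and appears in the Langfuse dashboard.
response = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder model name
    messages=[{"role": "user", "content": "Summarize what LightRAG does."}],
)
print(response.choices[0].message.content)
```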

### Important Notice

**Note**: LightRAG currently only integrates OpenAI-compatible API calls with Langfuse. APIs such as Ollama, Azure, and AWS Bedrock are not yet supported for Langfuse observability.

## RAGAS-based Evaluation

**RAGAS** (Retrieval Augmented Generation Assessment) is a framework for reference-free evaluation of RAG systems using LLMs. An evaluation script based on RAGAS is provided; for details, see the [RAGAS-based Evaluation Framework](lightrag/evaluation/README.md).
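
For orientation, the sketch below shows the general shape of a RAGAS evaluation run. It is an assumption-laden illustration, not the bundled script: it presumes the `ragas` and `datasets` packages plus an OpenAI key for the judge LLM, and the column names follow the classic RAGAS schema, which may differ from what `lightrag/evaluation` uses.

```
# Hypothetical sketch of a reference-free RAGAS evaluation (not the bundled script).
# Assumes: pip install ragas datasets, and OPENAI_API_KEY set for the judge LLM.
from datasets import Dataset
from ragas import evaluate
from ragas.metrics import faithfulness, answer_relevancy

# Each row pairs a question with the RAG answer and the retrieved contexts.
data = {
    "question": ["What is LightRAG?"],
    "answer": ["LightRAG is a retrieval-augmented generation framework ..."],
    "contexts": [["LightRAG combines knowledge graphs with vector retrieval ..."]],
}

result = evaluate(Dataset.from_dict(data), metrics=[faithfulness, answer_relevancy])
print(result)  # per-metric scores for the evaluated samples
```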
env.example (11 changes)

@@ -170,7 +170,7 @@ MAX_PARALLEL_INSERT=2

###########################################################################
### LLM Configuration
### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock
### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock, gemini
### LLM_BINDING_HOST: host only for Ollama, endpoint for other LLM service
###########################################################################
### LLM request timeout setting for all llm (0 means no timeout for Ollama)

@@ -191,6 +191,15 @@ LLM_BINDING_API_KEY=your_api_key
# LLM_BINDING_API_KEY=your_api_key
# LLM_BINDING=openai

### Gemini example
# LLM_BINDING=gemini
# LLM_MODEL=gemini-flash-latest
# LLM_BINDING_API_KEY=your_gemini_api_key
# LLM_BINDING_HOST=https://generativelanguage.googleapis.com
GEMINI_LLM_THINKING_CONFIG='{"thinking_budget": 0, "include_thoughts": false}'
# GEMINI_LLM_MAX_OUTPUT_TOKENS=9000
# GEMINI_LLM_TEMPERATURE=0.7

### OpenAI Compatible API Specific Parameters
### Increased temperature values may mitigate infinite inference loops in certain LLMs, such as Qwen3-30B.
# OPENAI_LLM_TEMPERATURE=0.9

|
|
@ -1,105 +0,0 @@
|
|||
# pip install -q -U google-genai to use gemini as a client
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from dotenv import load_dotenv
|
||||
from lightrag.utils import EmbeddingFunc
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from lightrag.kg.shared_storage import initialize_pipeline_status
|
||||
|
||||
import asyncio
|
||||
import nest_asyncio
|
||||
|
||||
# Apply nest_asyncio to solve event loop issues
|
||||
nest_asyncio.apply()
|
||||
|
||||
load_dotenv()
|
||||
gemini_api_key = os.getenv("GEMINI_API_KEY")
|
||||
|
||||
WORKING_DIR = "./dickens"
|
||||
|
||||
if os.path.exists(WORKING_DIR):
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(WORKING_DIR)
|
||||
|
||||
os.mkdir(WORKING_DIR)
|
||||
|
||||
|
||||
async def llm_model_func(
|
||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
||||
) -> str:
|
||||
# 1. Initialize the GenAI Client with your Gemini API Key
|
||||
client = genai.Client(api_key=gemini_api_key)
|
||||
|
||||
# 2. Combine prompts: system prompt, history, and user prompt
|
||||
if history_messages is None:
|
||||
history_messages = []
|
||||
|
||||
combined_prompt = ""
|
||||
if system_prompt:
|
||||
combined_prompt += f"{system_prompt}\n"
|
||||
|
||||
for msg in history_messages:
|
||||
# Each msg is expected to be a dict: {"role": "...", "content": "..."}
|
||||
combined_prompt += f"{msg['role']}: {msg['content']}\n"
|
||||
|
||||
# Finally, add the new user prompt
|
||||
combined_prompt += f"user: {prompt}"
|
||||
|
||||
# 3. Call the Gemini model
|
||||
response = client.models.generate_content(
|
||||
model="gemini-1.5-flash",
|
||||
contents=[combined_prompt],
|
||||
config=types.GenerateContentConfig(max_output_tokens=500, temperature=0.1),
|
||||
)
|
||||
|
||||
# 4. Return the response text
|
||||
return response.text
|
||||
|
||||
|
||||
async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||
model = SentenceTransformer("all-MiniLM-L6-v2")
|
||||
embeddings = model.encode(texts, convert_to_numpy=True)
|
||||
return embeddings
|
||||
|
||||
|
||||
async def initialize_rag():
|
||||
rag = LightRAG(
|
||||
working_dir=WORKING_DIR,
|
||||
llm_model_func=llm_model_func,
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=384,
|
||||
max_token_size=8192,
|
||||
func=embedding_func,
|
||||
),
|
||||
)
|
||||
|
||||
await rag.initialize_storages()
|
||||
await initialize_pipeline_status()
|
||||
|
||||
return rag
|
||||
|
||||
|
||||
def main():
|
||||
# Initialize RAG instance
|
||||
rag = asyncio.run(initialize_rag())
|
||||
file_path = "story.txt"
|
||||
with open(file_path, "r") as file:
|
||||
text = file.read()
|
||||
|
||||
rag.insert(text)
|
||||
|
||||
response = rag.query(
|
||||
query="What is the main theme of the story?",
|
||||
param=QueryParam(mode="hybrid", top_k=5, response_type="single line"),
|
||||
)
|
||||
|
||||
print(response)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,230 +0,0 @@
|
|||
# pip install -q -U google-genai to use gemini as a client
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
import dataclasses
|
||||
from pathlib import Path
|
||||
import hashlib
|
||||
import numpy as np
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from dotenv import load_dotenv
|
||||
from lightrag.utils import EmbeddingFunc, Tokenizer
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from lightrag.kg.shared_storage import initialize_pipeline_status
|
||||
import sentencepiece as spm
|
||||
import requests
|
||||
|
||||
import asyncio
|
||||
import nest_asyncio
|
||||
|
||||
# Apply nest_asyncio to solve event loop issues
|
||||
nest_asyncio.apply()
|
||||
|
||||
load_dotenv()
|
||||
gemini_api_key = os.getenv("GEMINI_API_KEY")
|
||||
|
||||
WORKING_DIR = "./dickens"
|
||||
|
||||
if os.path.exists(WORKING_DIR):
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(WORKING_DIR)
|
||||
|
||||
os.mkdir(WORKING_DIR)
|
||||
|
||||
|
||||
class GemmaTokenizer(Tokenizer):
|
||||
# adapted from google-cloud-aiplatform[tokenization]
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class _TokenizerConfig:
|
||||
tokenizer_model_url: str
|
||||
tokenizer_model_hash: str
|
||||
|
||||
_TOKENIZERS = {
|
||||
"google/gemma2": _TokenizerConfig(
|
||||
tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model",
|
||||
tokenizer_model_hash="61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2",
|
||||
),
|
||||
"google/gemma3": _TokenizerConfig(
|
||||
tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/cb7c0152a369e43908e769eb09e1ce6043afe084/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
|
||||
tokenizer_model_hash="1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c",
|
||||
),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None
|
||||
):
|
||||
# https://github.com/google/gemma_pytorch/tree/main/tokenizer
|
||||
if "1.5" in model_name or "1.0" in model_name:
|
||||
# up to gemini 1.5 gemma2 is a comparable local tokenizer
|
||||
# https://github.com/googleapis/python-aiplatform/blob/main/vertexai/tokenization/_tokenizer_loading.py
|
||||
tokenizer_name = "google/gemma2"
|
||||
else:
|
||||
# for gemini > 2.0 gemma3 was used
|
||||
tokenizer_name = "google/gemma3"
|
||||
|
||||
file_url = self._TOKENIZERS[tokenizer_name].tokenizer_model_url
|
||||
tokenizer_model_name = file_url.rsplit("/", 1)[1]
|
||||
expected_hash = self._TOKENIZERS[tokenizer_name].tokenizer_model_hash
|
||||
|
||||
tokenizer_dir = Path(tokenizer_dir)
|
||||
if tokenizer_dir.is_dir():
|
||||
file_path = tokenizer_dir / tokenizer_model_name
|
||||
model_data = self._maybe_load_from_cache(
|
||||
file_path=file_path, expected_hash=expected_hash
|
||||
)
|
||||
else:
|
||||
model_data = None
|
||||
if not model_data:
|
||||
model_data = self._load_from_url(
|
||||
file_url=file_url, expected_hash=expected_hash
|
||||
)
|
||||
self.save_tokenizer_to_cache(cache_path=file_path, model_data=model_data)
|
||||
|
||||
tokenizer = spm.SentencePieceProcessor()
|
||||
tokenizer.LoadFromSerializedProto(model_data)
|
||||
super().__init__(model_name=model_name, tokenizer=tokenizer)
|
||||
|
||||
def _is_valid_model(self, model_data: bytes, expected_hash: str) -> bool:
|
||||
"""Returns true if the content is valid by checking the hash."""
|
||||
return hashlib.sha256(model_data).hexdigest() == expected_hash
|
||||
|
||||
def _maybe_load_from_cache(self, file_path: Path, expected_hash: str) -> bytes:
|
||||
"""Loads the model data from the cache path."""
|
||||
if not file_path.is_file():
|
||||
return
|
||||
with open(file_path, "rb") as f:
|
||||
content = f.read()
|
||||
if self._is_valid_model(model_data=content, expected_hash=expected_hash):
|
||||
return content
|
||||
|
||||
# Cached file corrupted.
|
||||
self._maybe_remove_file(file_path)
|
||||
|
||||
def _load_from_url(self, file_url: str, expected_hash: str) -> bytes:
|
||||
"""Loads model bytes from the given file url."""
|
||||
resp = requests.get(file_url)
|
||||
resp.raise_for_status()
|
||||
content = resp.content
|
||||
|
||||
if not self._is_valid_model(model_data=content, expected_hash=expected_hash):
|
||||
actual_hash = hashlib.sha256(content).hexdigest()
|
||||
raise ValueError(
|
||||
f"Downloaded model file is corrupted."
|
||||
f" Expected hash {expected_hash}. Got file hash {actual_hash}."
|
||||
)
|
||||
return content
|
||||
|
||||
@staticmethod
|
||||
def save_tokenizer_to_cache(cache_path: Path, model_data: bytes) -> None:
|
||||
"""Saves the model data to the cache path."""
|
||||
try:
|
||||
if not cache_path.is_file():
|
||||
cache_dir = cache_path.parent
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
with open(cache_path, "wb") as f:
|
||||
f.write(model_data)
|
||||
except OSError:
|
||||
# Don't raise if we cannot write file.
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _maybe_remove_file(file_path: Path) -> None:
|
||||
"""Removes the file if exists."""
|
||||
if not file_path.is_file():
|
||||
return
|
||||
try:
|
||||
file_path.unlink()
|
||||
except OSError:
|
||||
# Don't raise if we cannot remove file.
|
||||
pass
|
||||
|
||||
# def encode(self, content: str) -> list[int]:
|
||||
# return self.tokenizer.encode(content)
|
||||
|
||||
# def decode(self, tokens: list[int]) -> str:
|
||||
# return self.tokenizer.decode(tokens)
|
||||
|
||||
|
||||
async def llm_model_func(
|
||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
||||
) -> str:
|
||||
# 1. Initialize the GenAI Client with your Gemini API Key
|
||||
client = genai.Client(api_key=gemini_api_key)
|
||||
|
||||
# 2. Combine prompts: system prompt, history, and user prompt
|
||||
if history_messages is None:
|
||||
history_messages = []
|
||||
|
||||
combined_prompt = ""
|
||||
if system_prompt:
|
||||
combined_prompt += f"{system_prompt}\n"
|
||||
|
||||
for msg in history_messages:
|
||||
# Each msg is expected to be a dict: {"role": "...", "content": "..."}
|
||||
combined_prompt += f"{msg['role']}: {msg['content']}\n"
|
||||
|
||||
# Finally, add the new user prompt
|
||||
combined_prompt += f"user: {prompt}"
|
||||
|
||||
# 3. Call the Gemini model
|
||||
response = client.models.generate_content(
|
||||
model="gemini-1.5-flash",
|
||||
contents=[combined_prompt],
|
||||
config=types.GenerateContentConfig(max_output_tokens=500, temperature=0.1),
|
||||
)
|
||||
|
||||
# 4. Return the response text
|
||||
return response.text
|
||||
|
||||
|
||||
async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||
model = SentenceTransformer("all-MiniLM-L6-v2")
|
||||
embeddings = model.encode(texts, convert_to_numpy=True)
|
||||
return embeddings
|
||||
|
||||
|
||||
async def initialize_rag():
|
||||
rag = LightRAG(
|
||||
working_dir=WORKING_DIR,
|
||||
# tiktoken_model_name="gpt-4o-mini",
|
||||
tokenizer=GemmaTokenizer(
|
||||
tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"),
|
||||
model_name="gemini-2.0-flash",
|
||||
),
|
||||
llm_model_func=llm_model_func,
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=384,
|
||||
max_token_size=8192,
|
||||
func=embedding_func,
|
||||
),
|
||||
)
|
||||
|
||||
await rag.initialize_storages()
|
||||
await initialize_pipeline_status()
|
||||
|
||||
return rag
|
||||
|
||||
|
||||
def main():
|
||||
# Initialize RAG instance
|
||||
rag = asyncio.run(initialize_rag())
|
||||
file_path = "story.txt"
|
||||
with open(file_path, "r") as file:
|
||||
text = file.read()
|
||||
|
||||
rag.insert(text)
|
||||
|
||||
response = rag.query(
|
||||
query="What is the main theme of the story?",
|
||||
param=QueryParam(mode="hybrid", top_k=5, response_type="single line"),
|
||||
)
|
||||
|
||||
print(response)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,151 +0,0 @@
|
|||
# pip install -q -U google-genai to use gemini as a client
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import numpy as np
|
||||
import nest_asyncio
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from dotenv import load_dotenv
|
||||
from lightrag.utils import EmbeddingFunc
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from lightrag.kg.shared_storage import initialize_pipeline_status
|
||||
from lightrag.llm.siliconcloud import siliconcloud_embedding
|
||||
from lightrag.utils import setup_logger
|
||||
from lightrag.utils import TokenTracker
|
||||
|
||||
setup_logger("lightrag", level="DEBUG")
|
||||
|
||||
# Apply nest_asyncio to solve event loop issues
|
||||
nest_asyncio.apply()
|
||||
|
||||
load_dotenv()
|
||||
gemini_api_key = os.getenv("GEMINI_API_KEY")
|
||||
siliconflow_api_key = os.getenv("SILICONFLOW_API_KEY")
|
||||
|
||||
WORKING_DIR = "./dickens"
|
||||
|
||||
if not os.path.exists(WORKING_DIR):
|
||||
os.mkdir(WORKING_DIR)
|
||||
|
||||
token_tracker = TokenTracker()
|
||||
|
||||
|
||||
async def llm_model_func(
|
||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
||||
) -> str:
|
||||
# 1. Initialize the GenAI Client with your Gemini API Key
|
||||
client = genai.Client(api_key=gemini_api_key)
|
||||
|
||||
# 2. Combine prompts: system prompt, history, and user prompt
|
||||
if history_messages is None:
|
||||
history_messages = []
|
||||
|
||||
combined_prompt = ""
|
||||
if system_prompt:
|
||||
combined_prompt += f"{system_prompt}\n"
|
||||
|
||||
for msg in history_messages:
|
||||
# Each msg is expected to be a dict: {"role": "...", "content": "..."}
|
||||
combined_prompt += f"{msg['role']}: {msg['content']}\n"
|
||||
|
||||
# Finally, add the new user prompt
|
||||
combined_prompt += f"user: {prompt}"
|
||||
|
||||
# 3. Call the Gemini model
|
||||
response = client.models.generate_content(
|
||||
model="gemini-2.0-flash",
|
||||
contents=[combined_prompt],
|
||||
config=types.GenerateContentConfig(
|
||||
max_output_tokens=5000, temperature=0, top_k=10
|
||||
),
|
||||
)
|
||||
|
||||
# 4. Get token counts with null safety
|
||||
usage = getattr(response, "usage_metadata", None)
|
||||
prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
|
||||
completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
|
||||
total_tokens = getattr(usage, "total_token_count", 0) or (
|
||||
prompt_tokens + completion_tokens
|
||||
)
|
||||
|
||||
token_counts = {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
}
|
||||
|
||||
token_tracker.add_usage(token_counts)
|
||||
|
||||
# 5. Return the response text
|
||||
return response.text
|
||||
|
||||
|
||||
async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||
return await siliconcloud_embedding(
|
||||
texts,
|
||||
model="BAAI/bge-m3",
|
||||
api_key=siliconflow_api_key,
|
||||
max_token_size=512,
|
||||
)
|
||||
|
||||
|
||||
async def initialize_rag():
|
||||
rag = LightRAG(
|
||||
working_dir=WORKING_DIR,
|
||||
entity_extract_max_gleaning=1,
|
||||
enable_llm_cache=True,
|
||||
enable_llm_cache_for_entity_extract=True,
|
||||
embedding_cache_config={"enabled": True, "similarity_threshold": 0.90},
|
||||
llm_model_func=llm_model_func,
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=1024,
|
||||
max_token_size=8192,
|
||||
func=embedding_func,
|
||||
),
|
||||
)
|
||||
|
||||
await rag.initialize_storages()
|
||||
await initialize_pipeline_status()
|
||||
|
||||
return rag
|
||||
|
||||
|
||||
def main():
|
||||
# Initialize RAG instance
|
||||
rag = asyncio.run(initialize_rag())
|
||||
|
||||
with open("./book.txt", "r", encoding="utf-8") as f:
|
||||
rag.insert(f.read())
|
||||
|
||||
# Context Manager Method
|
||||
with token_tracker:
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="naive")
|
||||
)
|
||||
)
|
||||
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="local")
|
||||
)
|
||||
)
|
||||
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?",
|
||||
param=QueryParam(mode="global"),
|
||||
)
|
||||
)
|
||||
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?",
|
||||
param=QueryParam(mode="hybrid"),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@@ -50,6 +50,7 @@ LightRAG necessitates the integration of both an LLM (Large Language Model) and
* openai or openai compatible
* azure_openai
* aws_bedrock
* gemini

It is recommended to use environment variables to configure the LightRAG Server. There is an example environment variable file named `env.example` in the root directory of the project. Please copy this file to the startup directory and rename it to `.env`. After that, you can modify the parameters related to the LLM and Embedding models in the `.env` file. It is important to note that the LightRAG Server loads the environment variables from `.env` into the system environment variables each time it starts. **The LightRAG Server gives priority to settings already present in the system environment variables over those in the `.env` file**.

@@ -72,6 +73,8 @@ EMBEDDING_DIM=1024
# EMBEDDING_BINDING_API_KEY=your_api_key
```

> When targeting Google Gemini, set `LLM_BINDING=gemini`, choose a model such as `LLM_MODEL=gemini-flash-latest`, and provide your Gemini key via `LLM_BINDING_API_KEY` (or `GEMINI_API_KEY`).
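
As an illustration of what the new binding does under the hood, the sketch below calls the helper added in `lightrag/llm/gemini.py` directly. Treat it as a hedged example: the model name and prompt are placeholders, the keyword arguments mirror the call made by the API server in this change set, and it assumes `GEMINI_API_KEY` is set and the Gemini dependencies are installed.

```
# Hypothetical direct use of the Gemini binding introduced in this change.
# Assumes GEMINI_API_KEY is set and google-genai is installed (auto-installed on first use).
import asyncio
import os

from lightrag.llm.gemini import gemini_complete_if_cache

async def demo() -> None:
    # Arguments mirror the server-side call: model, prompt, then optional kwargs.
    text = await gemini_complete_if_cache(
        "gemini-flash-latest",                      # LLM_MODEL (placeholder)
        "Give a one-sentence summary of LightRAG.",
        system_prompt="You are a concise assistant.",
        history_messages=[],
        api_key=os.getenv("GEMINI_API_KEY"),
    )
    print(text)

asyncio.run(demo())
```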

* Ollama LLM + Ollama Embedding:

```
@@ -1 +1 @@
__api_version__ = "0250"
__api_version__ = "0251"

@ -8,6 +8,7 @@ import logging
|
|||
from dotenv import load_dotenv
|
||||
from lightrag.utils import get_env_value
|
||||
from lightrag.llm.binding_options import (
|
||||
GeminiLLMOptions,
|
||||
OllamaEmbeddingOptions,
|
||||
OllamaLLMOptions,
|
||||
OpenAILLMOptions,
|
||||
|
|
@ -63,6 +64,9 @@ def get_default_host(binding_type: str) -> str:
|
|||
"lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"),
|
||||
"azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"),
|
||||
"openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"),
|
||||
"gemini": os.getenv(
|
||||
"LLM_BINDING_HOST", "https://generativelanguage.googleapis.com"
|
||||
),
|
||||
}
|
||||
return default_hosts.get(
|
||||
binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434")
|
||||
|
|
@ -226,6 +230,7 @@ def parse_args() -> argparse.Namespace:
|
|||
"openai-ollama",
|
||||
"azure_openai",
|
||||
"aws_bedrock",
|
||||
"gemini",
|
||||
],
|
||||
help="LLM binding type (default: from env or ollama)",
|
||||
)
|
||||
|
|
@ -281,6 +286,16 @@ def parse_args() -> argparse.Namespace:
|
|||
elif os.environ.get("LLM_BINDING") in ["openai", "azure_openai"]:
|
||||
OpenAILLMOptions.add_args(parser)
|
||||
|
||||
if "--llm-binding" in sys.argv:
|
||||
try:
|
||||
idx = sys.argv.index("--llm-binding")
|
||||
if idx + 1 < len(sys.argv) and sys.argv[idx + 1] == "gemini":
|
||||
GeminiLLMOptions.add_args(parser)
|
||||
except IndexError:
|
||||
pass
|
||||
elif os.environ.get("LLM_BINDING") == "gemini":
|
||||
GeminiLLMOptions.add_args(parser)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# convert relative path to absolute path
|
||||
|
|
|
|||
|
|
@ -89,6 +89,7 @@ class LLMConfigCache:
|
|||
|
||||
# Initialize configurations based on binding conditions
|
||||
self.openai_llm_options = None
|
||||
self.gemini_llm_options = None
|
||||
self.ollama_llm_options = None
|
||||
self.ollama_embedding_options = None
|
||||
|
||||
|
|
@ -99,6 +100,12 @@ class LLMConfigCache:
|
|||
self.openai_llm_options = OpenAILLMOptions.options_dict(args)
|
||||
logger.info(f"OpenAI LLM Options: {self.openai_llm_options}")
|
||||
|
||||
if args.llm_binding == "gemini":
|
||||
from lightrag.llm.binding_options import GeminiLLMOptions
|
||||
|
||||
self.gemini_llm_options = GeminiLLMOptions.options_dict(args)
|
||||
logger.info(f"Gemini LLM Options: {self.gemini_llm_options}")
|
||||
|
||||
# Only initialize and log Ollama LLM options when using Ollama LLM binding
|
||||
if args.llm_binding == "ollama":
|
||||
try:
|
||||
|
|
@ -279,6 +286,7 @@ def create_app(args):
|
|||
"openai",
|
||||
"azure_openai",
|
||||
"aws_bedrock",
|
||||
"gemini",
|
||||
]:
|
||||
raise Exception("llm binding not supported")
|
||||
|
||||
|
|
@ -504,6 +512,44 @@ def create_app(args):
|
|||
|
||||
return optimized_azure_openai_model_complete
|
||||
|
||||
def create_optimized_gemini_llm_func(
|
||||
config_cache: LLMConfigCache, args, llm_timeout: int
|
||||
):
|
||||
"""Create optimized Gemini LLM function with cached configuration"""
|
||||
|
||||
async def optimized_gemini_model_complete(
|
||||
prompt,
|
||||
system_prompt=None,
|
||||
history_messages=None,
|
||||
keyword_extraction=False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
from lightrag.llm.gemini import gemini_complete_if_cache
|
||||
|
||||
if history_messages is None:
|
||||
history_messages = []
|
||||
|
||||
# Use pre-processed configuration to avoid repeated parsing
|
||||
kwargs["timeout"] = llm_timeout
|
||||
if (
|
||||
config_cache.gemini_llm_options is not None
|
||||
and "generation_config" not in kwargs
|
||||
):
|
||||
kwargs["generation_config"] = dict(config_cache.gemini_llm_options)
|
||||
|
||||
return await gemini_complete_if_cache(
|
||||
args.llm_model,
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
api_key=args.llm_binding_api_key,
|
||||
base_url=args.llm_binding_host,
|
||||
keyword_extraction=keyword_extraction,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return optimized_gemini_model_complete
|
||||
|
||||
def create_llm_model_func(binding: str):
|
||||
"""
|
||||
Create LLM model function based on binding type.
|
||||
|
|
@ -525,6 +571,8 @@ def create_app(args):
|
|||
return create_optimized_azure_openai_llm_func(
|
||||
config_cache, args, llm_timeout
|
||||
)
|
||||
elif binding == "gemini":
|
||||
return create_optimized_gemini_llm_func(config_cache, args, llm_timeout)
|
||||
else: # openai and compatible
|
||||
# Use optimized function with pre-processed configuration
|
||||
return create_optimized_openai_llm_func(config_cache, args, llm_timeout)
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ from typing import (
|
|||
from .utils import EmbeddingFunc
|
||||
from .types import KnowledgeGraph
|
||||
from .constants import (
|
||||
GRAPH_FIELD_SEP,
|
||||
DEFAULT_TOP_K,
|
||||
DEFAULT_CHUNK_TOP_K,
|
||||
DEFAULT_MAX_ENTITY_TOKENS,
|
||||
|
|
@ -528,56 +527,6 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|||
result[node_id] = edges if edges is not None else []
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
||||
"""Get all nodes that are associated with the given chunk_ids.
|
||||
|
||||
Args:
|
||||
chunk_ids (list[str]): A list of chunk IDs to find associated nodes for.
|
||||
|
||||
Returns:
|
||||
list[dict]: A list of nodes, where each node is a dictionary of its properties.
|
||||
An empty list if no matching nodes are found.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
||||
"""Get all edges that are associated with the given chunk_ids.
|
||||
|
||||
Args:
|
||||
chunk_ids (list[str]): A list of chunk IDs to find associated edges for.
|
||||
|
||||
Returns:
|
||||
list[dict]: A list of edges, where each edge is a dictionary of its properties.
|
||||
An empty list if no matching edges are found.
|
||||
"""
|
||||
# Default implementation iterates through all nodes and their edges, which is inefficient.
|
||||
# This method should be overridden by subclasses for better performance.
|
||||
all_edges = []
|
||||
all_labels = await self.get_all_labels()
|
||||
processed_edges = set()
|
||||
|
||||
for label in all_labels:
|
||||
edges = await self.get_node_edges(label)
|
||||
if edges:
|
||||
for src_id, tgt_id in edges:
|
||||
# Avoid processing the same edge twice in an undirected graph
|
||||
edge_tuple = tuple(sorted((src_id, tgt_id)))
|
||||
if edge_tuple in processed_edges:
|
||||
continue
|
||||
processed_edges.add(edge_tuple)
|
||||
|
||||
edge = await self.get_edge(src_id, tgt_id)
|
||||
if edge and "source_id" in edge:
|
||||
source_ids = set(edge["source_id"].split(GRAPH_FIELD_SEP))
|
||||
if not source_ids.isdisjoint(chunk_ids):
|
||||
# Add source and target to the edge dict for easier processing later
|
||||
edge_with_nodes = edge.copy()
|
||||
edge_with_nodes["source"] = src_id
|
||||
edge_with_nodes["target"] = tgt_id
|
||||
all_edges.append(edge_with_nodes)
|
||||
return all_edges
|
||||
|
||||
@abstractmethod
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||
"""Insert a new node or update an existing node in the graph.
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ import configparser
|
|||
from ..utils import logger
|
||||
from ..base import BaseGraphStorage
|
||||
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||
from ..constants import GRAPH_FIELD_SEP
|
||||
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
|
||||
import pipmaster as pm
|
||||
|
||||
|
|
@ -784,79 +783,6 @@ class MemgraphStorage(BaseGraphStorage):
|
|||
degrees = int(src_degree) + int(trg_degree)
|
||||
return degrees
|
||||
|
||||
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
||||
"""Get all nodes that are associated with the given chunk_ids.
|
||||
|
||||
Args:
|
||||
chunk_ids: List of chunk IDs to find associated nodes for
|
||||
|
||||
Returns:
|
||||
list[dict]: A list of nodes, where each node is a dictionary of its properties.
|
||||
An empty list if no matching nodes are found.
|
||||
"""
|
||||
if self._driver is None:
|
||||
raise RuntimeError(
|
||||
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
||||
)
|
||||
workspace_label = self._get_workspace_label()
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
query = f"""
|
||||
UNWIND $chunk_ids AS chunk_id
|
||||
MATCH (n:`{workspace_label}`)
|
||||
WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
|
||||
RETURN DISTINCT n
|
||||
"""
|
||||
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
|
||||
nodes = []
|
||||
async for record in result:
|
||||
node = record["n"]
|
||||
node_dict = dict(node)
|
||||
node_dict["id"] = node_dict.get("entity_id")
|
||||
nodes.append(node_dict)
|
||||
await result.consume()
|
||||
return nodes
|
||||
|
||||
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
||||
"""Get all edges that are associated with the given chunk_ids.
|
||||
|
||||
Args:
|
||||
chunk_ids: List of chunk IDs to find associated edges for
|
||||
|
||||
Returns:
|
||||
list[dict]: A list of edges, where each edge is a dictionary of its properties.
|
||||
An empty list if no matching edges are found.
|
||||
"""
|
||||
if self._driver is None:
|
||||
raise RuntimeError(
|
||||
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
||||
)
|
||||
workspace_label = self._get_workspace_label()
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
query = f"""
|
||||
UNWIND $chunk_ids AS chunk_id
|
||||
MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`)
|
||||
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
|
||||
WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id
|
||||
// Ensure we only return each unique edge once by ordering the source and target
|
||||
WITH a, b, r,
|
||||
CASE WHEN source_id <= target_id THEN source_id ELSE target_id END AS ordered_source,
|
||||
CASE WHEN source_id <= target_id THEN target_id ELSE source_id END AS ordered_target
|
||||
RETURN DISTINCT ordered_source AS source, ordered_target AS target, properties(r) AS properties
|
||||
"""
|
||||
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
|
||||
edges = []
|
||||
async for record in result:
|
||||
edge_properties = record["properties"]
|
||||
edge_properties["source"] = record["source"]
|
||||
edge_properties["target"] = record["target"]
|
||||
edges.append(edge_properties)
|
||||
await result.consume()
|
||||
return edges
|
||||
|
||||
async def get_knowledge_graph(
|
||||
self,
|
||||
node_label: str,
|
||||
|
|
|
|||
|
|
@ -1036,45 +1036,6 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
|
||||
return result
|
||||
|
||||
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
||||
"""Get all nodes that are associated with the given chunk_ids.
|
||||
|
||||
Args:
|
||||
chunk_ids (list[str]): A list of chunk IDs to find associated nodes for.
|
||||
|
||||
Returns:
|
||||
list[dict]: A list of nodes, where each node is a dictionary of its properties.
|
||||
An empty list if no matching nodes are found.
|
||||
"""
|
||||
if not chunk_ids:
|
||||
return []
|
||||
|
||||
cursor = self.collection.find({"source_ids": {"$in": chunk_ids}})
|
||||
return [doc async for doc in cursor]
|
||||
|
||||
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
||||
"""Get all edges that are associated with the given chunk_ids.
|
||||
|
||||
Args:
|
||||
chunk_ids (list[str]): A list of chunk IDs to find associated edges for.
|
||||
|
||||
Returns:
|
||||
list[dict]: A list of edges, where each edge is a dictionary of its properties.
|
||||
An empty list if no matching edges are found.
|
||||
"""
|
||||
if not chunk_ids:
|
||||
return []
|
||||
|
||||
cursor = self.edge_collection.find({"source_ids": {"$in": chunk_ids}})
|
||||
|
||||
edges = []
|
||||
async for edge in cursor:
|
||||
edge["source"] = edge["source_node_id"]
|
||||
edge["target"] = edge["target_node_id"]
|
||||
edges.append(edge)
|
||||
|
||||
return edges
|
||||
|
||||
#
|
||||
# -------------------------------------------------------------------------
|
||||
# UPSERTS
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ import logging
|
|||
from ..utils import logger
|
||||
from ..base import BaseGraphStorage
|
||||
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||
from ..constants import GRAPH_FIELD_SEP
|
||||
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
|
||||
import pipmaster as pm
|
||||
|
||||
|
|
@ -904,49 +903,6 @@ class Neo4JStorage(BaseGraphStorage):
|
|||
await result.consume() # Ensure results are fully consumed
|
||||
return edges_dict
|
||||
|
||||
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
||||
workspace_label = self._get_workspace_label()
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
query = f"""
|
||||
UNWIND $chunk_ids AS chunk_id
|
||||
MATCH (n:`{workspace_label}`)
|
||||
WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
|
||||
RETURN DISTINCT n
|
||||
"""
|
||||
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
|
||||
nodes = []
|
||||
async for record in result:
|
||||
node = record["n"]
|
||||
node_dict = dict(node)
|
||||
# Add node id (entity_id) to the dictionary for easier access
|
||||
node_dict["id"] = node_dict.get("entity_id")
|
||||
nodes.append(node_dict)
|
||||
await result.consume()
|
||||
return nodes
|
||||
|
||||
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
||||
workspace_label = self._get_workspace_label()
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
query = f"""
|
||||
UNWIND $chunk_ids AS chunk_id
|
||||
MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`)
|
||||
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
|
||||
RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
|
||||
"""
|
||||
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
|
||||
edges = []
|
||||
async for record in result:
|
||||
edge_properties = record["properties"]
|
||||
edge_properties["source"] = record["source"]
|
||||
edge_properties["target"] = record["target"]
|
||||
edges.append(edge_properties)
|
||||
await result.consume()
|
||||
return edges
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from typing import final
|
|||
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||
from lightrag.utils import logger
|
||||
from lightrag.base import BaseGraphStorage
|
||||
from lightrag.constants import GRAPH_FIELD_SEP
|
||||
import networkx as nx
|
||||
from .shared_storage import (
|
||||
get_storage_lock,
|
||||
|
|
@ -470,33 +469,6 @@ class NetworkXStorage(BaseGraphStorage):
|
|||
)
|
||||
return result
|
||||
|
||||
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
||||
chunk_ids_set = set(chunk_ids)
|
||||
graph = await self._get_graph()
|
||||
matching_nodes = []
|
||||
for node_id, node_data in graph.nodes(data=True):
|
||||
if "source_id" in node_data:
|
||||
node_source_ids = set(node_data["source_id"].split(GRAPH_FIELD_SEP))
|
||||
if not node_source_ids.isdisjoint(chunk_ids_set):
|
||||
node_data_with_id = node_data.copy()
|
||||
node_data_with_id["id"] = node_id
|
||||
matching_nodes.append(node_data_with_id)
|
||||
return matching_nodes
|
||||
|
||||
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
||||
chunk_ids_set = set(chunk_ids)
|
||||
graph = await self._get_graph()
|
||||
matching_edges = []
|
||||
for u, v, edge_data in graph.edges(data=True):
|
||||
if "source_id" in edge_data:
|
||||
edge_source_ids = set(edge_data["source_id"].split(GRAPH_FIELD_SEP))
|
||||
if not edge_source_ids.isdisjoint(chunk_ids_set):
|
||||
edge_data_with_nodes = edge_data.copy()
|
||||
edge_data_with_nodes["source"] = u
|
||||
edge_data_with_nodes["target"] = v
|
||||
matching_edges.append(edge_data_with_nodes)
|
||||
return matching_edges
|
||||
|
||||
async def get_all_nodes(self) -> list[dict]:
|
||||
"""Get all nodes in the graph.
|
||||
|
||||
|
|
|
|||
|
|
@ -33,7 +33,6 @@ from ..base import (
|
|||
)
|
||||
from ..namespace import NameSpace, is_namespace
|
||||
from ..utils import logger
|
||||
from ..constants import GRAPH_FIELD_SEP
|
||||
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock, get_storage_lock
|
||||
|
||||
import pipmaster as pm
|
||||
|
|
@ -3569,17 +3568,13 @@ class PGGraphStorage(BaseGraphStorage):
|
|||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||
"""Get node by its label identifier, return only node properties"""
|
||||
|
||||
label = self._normalize_node_id(node_id)
|
||||
|
||||
result = await self.get_nodes_batch(node_ids=[label])
|
||||
result = await self.get_nodes_batch(node_ids=[node_id])
|
||||
if result and node_id in result:
|
||||
return result[node_id]
|
||||
return None
|
||||
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
label = self._normalize_node_id(node_id)
|
||||
|
||||
result = await self.node_degrees_batch(node_ids=[label])
|
||||
result = await self.node_degrees_batch(node_ids=[node_id])
|
||||
if result and node_id in result:
|
||||
return result[node_id]
|
||||
|
||||
|
|
@ -3592,12 +3587,11 @@ class PGGraphStorage(BaseGraphStorage):
|
|||
self, source_node_id: str, target_node_id: str
|
||||
) -> dict[str, str] | None:
|
||||
"""Get edge properties between two nodes"""
|
||||
src_label = self._normalize_node_id(source_node_id)
|
||||
tgt_label = self._normalize_node_id(target_node_id)
|
||||
|
||||
result = await self.get_edges_batch([{"src": src_label, "tgt": tgt_label}])
|
||||
if result and (src_label, tgt_label) in result:
|
||||
return result[(src_label, tgt_label)]
|
||||
result = await self.get_edges_batch(
|
||||
[{"src": source_node_id, "tgt": target_node_id}]
|
||||
)
|
||||
if result and (source_node_id, target_node_id) in result:
|
||||
return result[(source_node_id, target_node_id)]
|
||||
return None
|
||||
|
||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||
|
|
@ -3795,13 +3789,17 @@ class PGGraphStorage(BaseGraphStorage):
|
|||
if not node_ids:
|
||||
return {}
|
||||
|
||||
seen = set()
|
||||
unique_ids = []
|
||||
seen: set[str] = set()
|
||||
unique_ids: list[str] = []
|
||||
lookup: dict[str, str] = {}
|
||||
requested: set[str] = set()
|
||||
for nid in node_ids:
|
||||
nid_norm = self._normalize_node_id(nid)
|
||||
if nid_norm not in seen:
|
||||
seen.add(nid_norm)
|
||||
unique_ids.append(nid_norm)
|
||||
if nid not in seen:
|
||||
seen.add(nid)
|
||||
unique_ids.append(nid)
|
||||
requested.add(nid)
|
||||
lookup[nid] = nid
|
||||
lookup[self._normalize_node_id(nid)] = nid
|
||||
|
||||
# Build result dictionary
|
||||
nodes_dict = {}
|
||||
|
|
@ -3840,10 +3838,18 @@ class PGGraphStorage(BaseGraphStorage):
|
|||
node_dict = json.loads(node_dict)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to parse node string in batch: {node_dict}"
|
||||
f"[{self.workspace}] Failed to parse node string in batch: {node_dict}"
|
||||
)
|
||||
|
||||
nodes_dict[result["node_id"]] = node_dict
|
||||
node_key = result["node_id"]
|
||||
original_key = lookup.get(node_key)
|
||||
if original_key is None:
|
||||
logger.warning(
|
||||
f"[{self.workspace}] Node {node_key} not found in lookup map"
|
||||
)
|
||||
original_key = node_key
|
||||
if original_key in requested:
|
||||
nodes_dict[original_key] = node_dict
|
||||
|
||||
return nodes_dict
|
||||
|
||||
|
|
@ -3866,13 +3872,17 @@ class PGGraphStorage(BaseGraphStorage):
|
|||
if not node_ids:
|
||||
return {}
|
||||
|
||||
seen = set()
|
||||
seen: set[str] = set()
|
||||
unique_ids: list[str] = []
|
||||
lookup: dict[str, str] = {}
|
||||
requested: set[str] = set()
|
||||
for nid in node_ids:
|
||||
n = self._normalize_node_id(nid)
|
||||
if n not in seen:
|
||||
seen.add(n)
|
||||
unique_ids.append(n)
|
||||
if nid not in seen:
|
||||
seen.add(nid)
|
||||
unique_ids.append(nid)
|
||||
requested.add(nid)
|
||||
lookup[nid] = nid
|
||||
lookup[self._normalize_node_id(nid)] = nid
|
||||
|
||||
out_degrees = {}
|
||||
in_degrees = {}
|
||||
|
|
@ -3924,8 +3934,16 @@ class PGGraphStorage(BaseGraphStorage):
|
|||
node_id = row["node_id"]
|
||||
if not node_id:
|
||||
continue
|
||||
out_degrees[node_id] = int(row.get("out_degree", 0) or 0)
|
||||
in_degrees[node_id] = int(row.get("in_degree", 0) or 0)
|
||||
node_key = node_id
|
||||
original_key = lookup.get(node_key)
|
||||
if original_key is None:
|
||||
logger.warning(
|
||||
f"[{self.workspace}] Node {node_key} not found in lookup map"
|
||||
)
|
||||
original_key = node_key
|
||||
if original_key in requested:
|
||||
out_degrees[original_key] = int(row.get("out_degree", 0) or 0)
|
||||
in_degrees[original_key] = int(row.get("in_degree", 0) or 0)
|
||||
|
||||
degrees_dict = {}
|
||||
for node_id in node_ids:
|
||||
|
|
@ -4054,7 +4072,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||
edge_props = json.loads(edge_props)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to parse edge properties string: {edge_props}"
|
||||
f"[{self.workspace}]Failed to parse edge properties string: {edge_props}"
|
||||
)
|
||||
continue
|
||||
|
||||
|
|
@ -4070,7 +4088,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||
edge_props = json.loads(edge_props)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to parse edge properties string: {edge_props}"
|
||||
f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
|
||||
)
|
||||
continue
|
||||
|
||||
|
|
@ -4175,102 +4193,6 @@ class PGGraphStorage(BaseGraphStorage):
|
|||
labels.append(result["label"])
|
||||
return labels
|
||||
|
||||
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
||||
"""
|
||||
Retrieves nodes from the graph that are associated with a given list of chunk IDs.
|
||||
This method uses a Cypher query with UNWIND to efficiently find all nodes
|
||||
where the `source_id` property contains any of the specified chunk IDs.
|
||||
"""
|
||||
# The string representation of the list for the cypher query
|
||||
chunk_ids_str = json.dumps(chunk_ids)
|
||||
|
||||
query = f"""
|
||||
SELECT * FROM cypher('{self.graph_name}', $$
|
||||
UNWIND {chunk_ids_str} AS chunk_id
|
||||
MATCH (n:base)
|
||||
WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, '{GRAPH_FIELD_SEP}')
|
||||
RETURN n
|
||||
$$) AS (n agtype);
|
||||
"""
|
||||
results = await self._query(query)
|
||||
|
||||
# Build result list
|
||||
nodes = []
|
||||
for result in results:
|
||||
if result["n"]:
|
||||
node_dict = result["n"]["properties"]
|
||||
|
||||
# Process string result, parse it to JSON dictionary
|
||||
if isinstance(node_dict, str):
|
||||
try:
|
||||
node_dict = json.loads(node_dict)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"[{self.workspace}] Failed to parse node string in batch: {node_dict}"
|
||||
)
|
||||
|
||||
node_dict["id"] = node_dict["entity_id"]
|
||||
nodes.append(node_dict)
|
||||
|
||||
return nodes
|
||||
|
||||
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
||||
"""
|
||||
Retrieves edges from the graph that are associated with a given list of chunk IDs.
|
||||
This method uses a Cypher query with UNWIND to efficiently find all edges
|
||||
where the `source_id` property contains any of the specified chunk IDs.
|
||||
"""
|
||||
chunk_ids_str = json.dumps(chunk_ids)
|
||||
|
||||
query = f"""
|
||||
SELECT * FROM cypher('{self.graph_name}', $$
|
||||
UNWIND {chunk_ids_str} AS chunk_id
|
||||
MATCH ()-[r]-()
|
||||
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, '{GRAPH_FIELD_SEP}')
|
||||
RETURN DISTINCT r, startNode(r) AS source, endNode(r) AS target
|
||||
$$) AS (edge agtype, source agtype, target agtype);
|
||||
"""
|
||||
results = await self._query(query)
|
||||
edges = []
|
||||
if results:
|
||||
for item in results:
|
||||
edge_agtype = item["edge"]["properties"]
|
||||
# Process string result, parse it to JSON dictionary
|
||||
if isinstance(edge_agtype, str):
|
||||
try:
|
||||
edge_agtype = json.loads(edge_agtype)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"[{self.workspace}] Failed to parse edge string in batch: {edge_agtype}"
|
||||
)
|
||||
|
||||
source_agtype = item["source"]["properties"]
|
||||
# Process string result, parse it to JSON dictionary
|
||||
if isinstance(source_agtype, str):
|
||||
try:
|
||||
source_agtype = json.loads(source_agtype)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"[{self.workspace}] Failed to parse node string in batch: {source_agtype}"
|
||||
)
|
||||
|
||||
target_agtype = item["target"]["properties"]
|
||||
# Process string result, parse it to JSON dictionary
|
||||
if isinstance(target_agtype, str):
|
||||
try:
|
||||
target_agtype = json.loads(target_agtype)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"[{self.workspace}] Failed to parse node string in batch: {target_agtype}"
|
||||
)
|
||||
|
||||
if edge_agtype and source_agtype and target_agtype:
|
||||
edge_properties = edge_agtype
|
||||
edge_properties["source"] = source_agtype["entity_id"]
|
||||
edge_properties["target"] = target_agtype["entity_id"]
|
||||
edges.append(edge_properties)
|
||||
return edges
|
||||
|
||||
async def _bfs_subgraph(
|
||||
self, node_label: str, max_depth: int, max_nodes: int
|
||||
) -> KnowledgeGraph:
|
||||
|
|
|
|||
|
|
@ -3235,38 +3235,31 @@ class LightRAG:
|
|||
|
||||
if entity_chunk_updates and self.entity_chunks:
|
||||
entity_upsert_payload = {}
|
||||
entity_delete_ids: set[str] = set()
|
||||
for entity_name, remaining in entity_chunk_updates.items():
|
||||
if not remaining:
|
||||
entity_delete_ids.add(entity_name)
|
||||
else:
|
||||
entity_upsert_payload[entity_name] = {
|
||||
"chunk_ids": remaining,
|
||||
"count": len(remaining),
|
||||
"updated_at": current_time,
|
||||
}
|
||||
|
||||
if entity_delete_ids:
|
||||
await self.entity_chunks.delete(list(entity_delete_ids))
|
||||
# Empty entities are deleted alongside graph nodes later
|
||||
continue
|
||||
entity_upsert_payload[entity_name] = {
|
||||
"chunk_ids": remaining,
|
||||
"count": len(remaining),
|
||||
"updated_at": current_time,
|
||||
}
|
||||
if entity_upsert_payload:
|
||||
await self.entity_chunks.upsert(entity_upsert_payload)
|
||||
|
||||
if relation_chunk_updates and self.relation_chunks:
|
||||
relation_upsert_payload = {}
|
||||
relation_delete_ids: set[str] = set()
|
||||
for edge_tuple, remaining in relation_chunk_updates.items():
|
||||
storage_key = make_relation_chunk_key(*edge_tuple)
|
||||
if not remaining:
|
||||
relation_delete_ids.add(storage_key)
|
||||
else:
|
||||
relation_upsert_payload[storage_key] = {
|
||||
"chunk_ids": remaining,
|
||||
"count": len(remaining),
|
||||
"updated_at": current_time,
|
||||
}
|
||||
# Empty relations are deleted alongside graph edges later
|
||||
continue
|
||||
storage_key = make_relation_chunk_key(*edge_tuple)
|
||||
relation_upsert_payload[storage_key] = {
|
||||
"chunk_ids": remaining,
|
||||
"count": len(remaining),
|
||||
"updated_at": current_time,
|
||||
}
|
||||
|
||||
if relation_delete_ids:
|
||||
await self.relation_chunks.delete(list(relation_delete_ids))
|
||||
if relation_upsert_payload:
|
||||
await self.relation_chunks.upsert(relation_upsert_payload)
|
||||
|
||||
|
|
@ -3296,7 +3289,7 @@ class LightRAG:
|
|||
# 6. Delete relationships that have no remaining sources
|
||||
if relationships_to_delete:
|
||||
try:
|
||||
# Delete from vector database
|
||||
# Delete from relation vdb
|
||||
rel_ids_to_delete = []
|
||||
for src, tgt in relationships_to_delete:
|
||||
rel_ids_to_delete.extend(
|
||||
|
|
@ -3333,15 +3326,16 @@ class LightRAG:
|
|||
# 7. Delete entities that have no remaining sources
|
||||
if entities_to_delete:
|
||||
try:
|
||||
# Batch get all edges for entities to avoid N+1 query problem
|
||||
nodes_edges_dict = await self.chunk_entity_relation_graph.get_nodes_edges_batch(
|
||||
list(entities_to_delete)
|
||||
)
|
||||
|
||||
# Debug: Check and log all edges before deleting nodes
|
||||
edges_to_delete = set()
|
||||
edges_still_exist = 0
|
||||
for entity in entities_to_delete:
|
||||
edges = (
|
||||
await self.chunk_entity_relation_graph.get_node_edges(
|
||||
entity
|
||||
)
|
||||
)
|
||||
|
||||
for entity, edges in nodes_edges_dict.items():
|
||||
if edges:
|
||||
for src, tgt in edges:
|
||||
# Normalize edge representation (sorted for consistency)
|
||||
|
|
@ -3364,6 +3358,7 @@ class LightRAG:
|
|||
f"Edge still exists: {src} <-- {tgt}"
|
||||
)
|
||||
edges_still_exist += 1
|
||||
|
||||
if edges_still_exist:
|
||||
logger.warning(
|
||||
f"⚠️ {edges_still_exist} entities still has edges before deletion"
|
||||
|
|
@ -3399,7 +3394,7 @@ class LightRAG:
|
|||
list(entities_to_delete)
|
||||
)
|
||||
|
||||
# Delete from vector database
|
||||
# Delete from vector vdb
|
||||
entity_vdb_ids = [
|
||||
compute_mdhash_id(entity, prefix="ent-")
|
||||
for entity in entities_to_delete
|
||||
|
|
|
|||
|
|
@ -9,12 +9,26 @@ from argparse import ArgumentParser, Namespace
|
|||
import argparse
|
||||
import json
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, ClassVar, List
|
||||
from typing import Any, ClassVar, List, get_args, get_origin
|
||||
|
||||
from lightrag.utils import get_env_value
|
||||
from lightrag.constants import DEFAULT_TEMPERATURE
|
||||
|
||||
|
||||
def _resolve_optional_type(field_type: Any) -> Any:
|
||||
"""Return the concrete type for Optional/Union annotations."""
|
||||
origin = get_origin(field_type)
|
||||
if origin in (list, dict, tuple):
|
||||
return field_type
|
||||
|
||||
args = get_args(field_type)
|
||||
if args:
|
||||
non_none_args = [arg for arg in args if arg is not type(None)]
|
||||
if len(non_none_args) == 1:
|
||||
return non_none_args[0]
|
||||
return field_type
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# BindingOptions Base Class
|
||||
# =============================================================================
|
||||
|
|
@ -177,9 +191,13 @@ class BindingOptions:
|
|||
help=arg_item["help"],
|
||||
)
|
||||
else:
|
||||
resolved_type = arg_item["type"]
|
||||
if resolved_type is not None:
|
||||
resolved_type = _resolve_optional_type(resolved_type)
|
||||
|
||||
group.add_argument(
|
||||
f"--{arg_item['argname']}",
|
||||
type=arg_item["type"],
|
||||
type=resolved_type,
|
||||
default=get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS),
|
||||
help=arg_item["help"],
|
||||
)
|
||||
|
|
@ -210,7 +228,7 @@ class BindingOptions:
|
|||
argdef = {
|
||||
"argname": f"{args_prefix}-{field.name}",
|
||||
"env_name": f"{env_var_prefix}{field.name.upper()}",
|
||||
"type": field.type,
|
||||
"type": _resolve_optional_type(field.type),
|
||||
"default": default_value,
|
||||
"help": f"{cls._binding_name} -- " + help.get(field.name, ""),
|
||||
}
|
||||
|
|
@ -454,6 +472,39 @@ class OllamaLLMOptions(_OllamaOptionsMixin, BindingOptions):
|
|||
_binding_name: ClassVar[str] = "ollama_llm"
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeminiLLMOptions(BindingOptions):
|
||||
"""Options for Google Gemini models."""
|
||||
|
||||
_binding_name: ClassVar[str] = "gemini_llm"
|
||||
|
||||
temperature: float = DEFAULT_TEMPERATURE
|
||||
top_p: float = 0.95
|
||||
top_k: int = 40
|
||||
max_output_tokens: int | None = None
|
||||
candidate_count: int = 1
|
||||
presence_penalty: float = 0.0
|
||||
frequency_penalty: float = 0.0
|
||||
stop_sequences: List[str] = field(default_factory=list)
|
||||
seed: int | None = None
|
||||
thinking_config: dict | None = None
|
||||
safety_settings: dict | None = None
|
||||
|
||||
_help: ClassVar[dict[str, str]] = {
|
||||
"temperature": "Controls randomness (0.0-2.0, higher = more creative)",
|
||||
"top_p": "Nucleus sampling parameter (0.0-1.0)",
|
||||
"top_k": "Limits sampling to the top K tokens (1 disables the limit)",
|
||||
"max_output_tokens": "Maximum tokens generated in the response",
|
||||
"candidate_count": "Number of candidates returned per request",
|
||||
"presence_penalty": "Penalty for token presence (-2.0 to 2.0)",
|
||||
"frequency_penalty": "Penalty for token frequency (-2.0 to 2.0)",
|
||||
"stop_sequences": "Stop sequences (JSON array of strings, e.g., '[\"END\"]')",
|
||||
"seed": "Random seed for reproducible generation (leave empty for random)",
|
||||
"thinking_config": "Thinking configuration (JSON dict, e.g., '{\"thinking_budget\": 1024}' or '{\"include_thoughts\": true}')",
|
||||
"safety_settings": "JSON object with Gemini safety settings overrides",
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Binding Options for OpenAI
|
||||
# =============================================================================
|
||||
|
|
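A hedged sketch of using the new options dataclass directly. Only the field names and defaults come from the diff; the import path below is an assumption about where the binding options live, so adjust it to your checkout:

```
from dataclasses import asdict

# Assumed module path for the binding options; adjust if it differs in your checkout.
from lightrag.llm.binding_options import GeminiLLMOptions

opts = GeminiLLMOptions(
    temperature=0.7,
    top_k=40,
    max_output_tokens=2048,
    stop_sequences=["END"],
    thinking_config={"thinking_budget": 1024, "include_thoughts": True},
)

# asdict() yields a plain dict that can serve as the Gemini generation_config;
# the binding's _build_generation_config drops None/empty entries before building
# types.GenerateContentConfig from it.
print(asdict(opts))
```

On the server side the same fields should surface as `--gemini-llm-*` flags and `GEMINI_LLM_*` environment variables, following the prefix that `BindingOptions` derives from `_binding_name` (the derivation is outside this hunk, so treat the exact names as assumptions).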
|
|||
422
lightrag/llm/gemini.py
Normal file
|
|
@ -0,0 +1,422 @@
|
|||
"""
|
||||
Gemini LLM binding for LightRAG.
|
||||
|
||||
This module provides asynchronous helpers that adapt Google's Gemini models
|
||||
to the same interface used by the rest of the LightRAG LLM bindings. The
|
||||
implementation mirrors the OpenAI helpers while relying on the official
|
||||
``google-genai`` client under the hood.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import AsyncIterator
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
from lightrag.utils import logger, remove_think_tags, safe_unicode_decode
|
||||
|
||||
import pipmaster as pm
|
||||
|
||||
# Install the Google Gemini client on demand
|
||||
if not pm.is_installed("google-genai"):
|
||||
pm.install("google-genai")
|
||||
|
||||
from google import genai # type: ignore
|
||||
from google.genai import types # type: ignore
|
||||
|
||||
DEFAULT_GEMINI_ENDPOINT = "https://generativelanguage.googleapis.com"
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
def _get_gemini_client(
|
||||
api_key: str, base_url: str | None, timeout: int | None = None
|
||||
) -> genai.Client:
|
||||
"""
|
||||
Create (or fetch a cached) Gemini client.
|
||||
|
||||
Args:
|
||||
api_key: Google Gemini API key.
|
||||
base_url: Optional custom API endpoint.
|
||||
timeout: Optional request timeout in milliseconds.
|
||||
|
||||
Returns:
|
||||
genai.Client: Configured Gemini client instance.
|
||||
"""
|
||||
client_kwargs: dict[str, Any] = {"api_key": api_key}
|
||||
|
||||
if base_url and base_url != DEFAULT_GEMINI_ENDPOINT or timeout is not None:
|
||||
try:
|
||||
http_options_kwargs = {}
|
||||
if base_url and base_url != DEFAULT_GEMINI_ENDPOINT:
|
||||
http_options_kwargs["api_endpoint"] = base_url
|
||||
if timeout is not None:
|
||||
http_options_kwargs["timeout"] = timeout
|
||||
|
||||
client_kwargs["http_options"] = types.HttpOptions(**http_options_kwargs)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
LOG.warning("Failed to apply custom Gemini http_options: %s", exc)
|
||||
|
||||
try:
|
||||
return genai.Client(**client_kwargs)
|
||||
except TypeError:
|
||||
# Older google-genai releases don't accept http_options; retry without it.
|
||||
client_kwargs.pop("http_options", None)
|
||||
return genai.Client(**client_kwargs)
|
||||
|
||||
|
||||
def _ensure_api_key(api_key: str | None) -> str:
|
||||
key = api_key or os.getenv("LLM_BINDING_API_KEY") or os.getenv("GEMINI_API_KEY")
|
||||
if not key:
|
||||
raise ValueError(
|
||||
"Gemini API key not provided. "
|
||||
"Set LLM_BINDING_API_KEY or GEMINI_API_KEY in the environment."
|
||||
)
|
||||
return key
|
||||
|
||||
|
||||
def _build_generation_config(
|
||||
base_config: dict[str, Any] | None,
|
||||
system_prompt: str | None,
|
||||
keyword_extraction: bool,
|
||||
) -> types.GenerateContentConfig | None:
|
||||
config_data = dict(base_config or {})
|
||||
|
||||
if system_prompt:
|
||||
if config_data.get("system_instruction"):
|
||||
config_data["system_instruction"] = (
|
||||
f"{config_data['system_instruction']}\n{system_prompt}"
|
||||
)
|
||||
else:
|
||||
config_data["system_instruction"] = system_prompt
|
||||
|
||||
if keyword_extraction and not config_data.get("response_mime_type"):
|
||||
config_data["response_mime_type"] = "application/json"
|
||||
|
||||
# Remove entries that are explicitly set to None to avoid type errors
|
||||
sanitized = {
|
||||
key: value
|
||||
for key, value in config_data.items()
|
||||
if value is not None and value != ""
|
||||
}
|
||||
|
||||
if not sanitized:
|
||||
return None
|
||||
|
||||
return types.GenerateContentConfig(**sanitized)
|
||||
|
||||
|
||||
def _format_history_messages(history_messages: list[dict[str, Any]] | None) -> str:
|
||||
if not history_messages:
|
||||
return ""
|
||||
|
||||
history_lines: list[str] = []
|
||||
for message in history_messages:
|
||||
role = message.get("role", "user")
|
||||
content = message.get("content", "")
|
||||
history_lines.append(f"[{role}] {content}")
|
||||
|
||||
return "\n".join(history_lines)
|
||||
|
||||
|
||||
def _extract_response_text(
|
||||
response: Any, extract_thoughts: bool = False
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Extract text content from Gemini response, separating regular content from thoughts.
|
||||
|
||||
Args:
|
||||
response: Gemini API response object
|
||||
extract_thoughts: Whether to extract thought content separately
|
||||
|
||||
Returns:
|
||||
Tuple of (regular_text, thought_text)
|
||||
"""
|
||||
candidates = getattr(response, "candidates", None)
|
||||
if not candidates:
|
||||
return ("", "")
|
||||
|
||||
regular_parts: list[str] = []
|
||||
thought_parts: list[str] = []
|
||||
|
||||
for candidate in candidates:
|
||||
if not getattr(candidate, "content", None):
|
||||
continue
|
||||
# Use 'or []' to handle None values from parts attribute
|
||||
for part in getattr(candidate.content, "parts", None) or []:
|
||||
text = getattr(part, "text", None)
|
||||
if not text:
|
||||
continue
|
||||
|
||||
# Check if this part is thought content using the 'thought' attribute
|
||||
is_thought = getattr(part, "thought", False)
|
||||
|
||||
if is_thought and extract_thoughts:
|
||||
thought_parts.append(text)
|
||||
elif not is_thought:
|
||||
regular_parts.append(text)
|
||||
|
||||
return ("\n".join(regular_parts), "\n".join(thought_parts))
|
||||
|
||||
|
||||
async def gemini_complete_if_cache(
|
||||
model: str,
|
||||
prompt: str,
|
||||
system_prompt: str | None = None,
|
||||
history_messages: list[dict[str, Any]] | None = None,
|
||||
enable_cot: bool = False,
|
||||
base_url: str | None = None,
|
||||
api_key: str | None = None,
|
||||
token_tracker: Any | None = None,
|
||||
stream: bool | None = None,
|
||||
keyword_extraction: bool = False,
|
||||
generation_config: dict[str, Any] | None = None,
|
||||
timeout: int | None = None,
|
||||
**_: Any,
|
||||
) -> str | AsyncIterator[str]:
|
||||
"""
|
||||
Complete a prompt using Gemini's API with Chain of Thought (COT) support.
|
||||
|
||||
This function supports automatic integration of reasoning content from Gemini models
|
||||
that provide Chain of Thought capabilities via the thinking_config API feature.
|
||||
|
||||
COT Integration:
|
||||
- When enable_cot=True: Thought content is wrapped in <think>...</think> tags
|
||||
- When enable_cot=False: Thought content is filtered out, only regular content returned
|
||||
- Thought content is identified by the 'thought' attribute on response parts
|
||||
- Requires thinking_config to be enabled in generation_config for the API to return thoughts
|
||||
|
||||
Args:
|
||||
model: The Gemini model to use.
|
||||
prompt: The prompt to complete.
|
||||
system_prompt: Optional system prompt to include.
|
||||
history_messages: Optional list of previous messages in the conversation.
|
||||
api_key: Optional Gemini API key. If None, uses environment variable.
|
||||
base_url: Optional custom API endpoint.
|
||||
generation_config: Optional generation configuration dict.
|
||||
keyword_extraction: Whether to use JSON response format.
|
||||
token_tracker: Optional token usage tracker for monitoring API usage.
|
||||
stream: Whether to stream the response.
|
||||
hashing_kv: Storage interface (for interface parity with other bindings).
|
||||
enable_cot: Whether to include Chain of Thought content in the response.
|
||||
timeout: Request timeout in seconds (will be converted to milliseconds for Gemini API).
|
||||
**_: Additional keyword arguments (ignored).
|
||||
|
||||
Returns:
|
||||
The completed text (with COT content if enable_cot=True) or an async iterator
|
||||
of text chunks if streaming. COT content is wrapped in <think>...</think> tags.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the response from Gemini is empty.
|
||||
ValueError: If API key is not provided or configured.
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
key = _ensure_api_key(api_key)
|
||||
# Convert timeout from seconds to milliseconds for Gemini API
|
||||
timeout_ms = timeout * 1000 if timeout else None
|
||||
client = _get_gemini_client(key, base_url, timeout_ms)
|
||||
|
||||
history_block = _format_history_messages(history_messages)
|
||||
prompt_sections = []
|
||||
if history_block:
|
||||
prompt_sections.append(history_block)
|
||||
prompt_sections.append(f"[user] {prompt}")
|
||||
combined_prompt = "\n".join(prompt_sections)
|
||||
|
||||
config_obj = _build_generation_config(
|
||||
generation_config,
|
||||
system_prompt=system_prompt,
|
||||
keyword_extraction=keyword_extraction,
|
||||
)
|
||||
|
||||
request_kwargs: dict[str, Any] = {
|
||||
"model": model,
|
||||
"contents": [combined_prompt],
|
||||
}
|
||||
if config_obj is not None:
|
||||
request_kwargs["config"] = config_obj
|
||||
|
||||
def _call_model():
|
||||
return client.models.generate_content(**request_kwargs)
|
||||
|
||||
if stream:
|
||||
queue: asyncio.Queue[Any] = asyncio.Queue()
|
||||
usage_container: dict[str, Any] = {}
|
||||
|
||||
def _stream_model() -> None:
|
||||
# COT state tracking for streaming
|
||||
cot_active = False
|
||||
cot_started = False
|
||||
initial_content_seen = False
|
||||
|
||||
try:
|
||||
stream_kwargs = dict(request_kwargs)
|
||||
stream_iterator = client.models.generate_content_stream(**stream_kwargs)
|
||||
for chunk in stream_iterator:
|
||||
usage = getattr(chunk, "usage_metadata", None)
|
||||
if usage is not None:
|
||||
usage_container["usage"] = usage
|
||||
|
||||
# Extract both regular and thought content
|
||||
regular_text, thought_text = _extract_response_text(
|
||||
chunk, extract_thoughts=True
|
||||
)
|
||||
|
||||
if enable_cot:
|
||||
# Process regular content
|
||||
if regular_text:
|
||||
if not initial_content_seen:
|
||||
initial_content_seen = True
|
||||
|
||||
# Close COT section if it was active
|
||||
if cot_active:
|
||||
loop.call_soon_threadsafe(queue.put_nowait, "</think>")
|
||||
cot_active = False
|
||||
|
||||
# Send regular content
|
||||
loop.call_soon_threadsafe(queue.put_nowait, regular_text)
|
||||
|
||||
# Process thought content
|
||||
if thought_text:
|
||||
if not initial_content_seen and not cot_started:
|
||||
# Start COT section
|
||||
loop.call_soon_threadsafe(queue.put_nowait, "<think>")
|
||||
cot_active = True
|
||||
cot_started = True
|
||||
|
||||
# Send thought content if COT is active
|
||||
if cot_active:
|
||||
loop.call_soon_threadsafe(
|
||||
queue.put_nowait, thought_text
|
||||
)
|
||||
else:
|
||||
# COT disabled - only send regular content
|
||||
if regular_text:
|
||||
loop.call_soon_threadsafe(queue.put_nowait, regular_text)
|
||||
|
||||
# Ensure COT is properly closed if still active
|
||||
if cot_active:
|
||||
loop.call_soon_threadsafe(queue.put_nowait, "</think>")
|
||||
|
||||
loop.call_soon_threadsafe(queue.put_nowait, None)
|
||||
except Exception as exc: # pragma: no cover - surface runtime issues
|
||||
# Try to close COT tag before reporting error
|
||||
if cot_active:
|
||||
try:
|
||||
loop.call_soon_threadsafe(queue.put_nowait, "</think>")
|
||||
except Exception:
|
||||
pass
|
||||
loop.call_soon_threadsafe(queue.put_nowait, exc)
|
||||
|
||||
loop.run_in_executor(None, _stream_model)
|
||||
|
||||
async def _async_stream() -> AsyncIterator[str]:
|
||||
try:
|
||||
while True:
|
||||
item = await queue.get()
|
||||
if item is None:
|
||||
break
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
|
||||
chunk_text = str(item)
|
||||
if "\\u" in chunk_text:
|
||||
chunk_text = safe_unicode_decode(chunk_text.encode("utf-8"))
|
||||
|
||||
# Yield the chunk directly without filtering
|
||||
# COT filtering is already handled in _stream_model()
|
||||
yield chunk_text
|
||||
finally:
|
||||
usage = usage_container.get("usage")
|
||||
if token_tracker and usage:
|
||||
token_tracker.add_usage(
|
||||
{
|
||||
"prompt_tokens": getattr(usage, "prompt_token_count", 0),
|
||||
"completion_tokens": getattr(
|
||||
usage, "candidates_token_count", 0
|
||||
),
|
||||
"total_tokens": getattr(usage, "total_token_count", 0),
|
||||
}
|
||||
)
|
||||
|
||||
return _async_stream()
|
||||
|
||||
response = await asyncio.to_thread(_call_model)
|
||||
|
||||
# Extract both regular text and thought text
|
||||
regular_text, thought_text = _extract_response_text(response, extract_thoughts=True)
|
||||
|
||||
# Apply COT filtering logic based on enable_cot parameter
|
||||
if enable_cot:
|
||||
# Include thought content wrapped in <think> tags
|
||||
if thought_text and thought_text.strip():
|
||||
if not regular_text or regular_text.strip() == "":
|
||||
# Only thought content available
|
||||
final_text = f"<think>{thought_text}</think>"
|
||||
else:
|
||||
# Both content types present: prepend thought to regular content
|
||||
final_text = f"<think>{thought_text}</think>{regular_text}"
|
||||
else:
|
||||
# No thought content, use regular content only
|
||||
final_text = regular_text or ""
|
||||
else:
|
||||
# Filter out thought content, return only regular content
|
||||
final_text = regular_text or ""
|
||||
|
||||
if not final_text:
|
||||
raise RuntimeError("Gemini response did not contain any text content.")
|
||||
|
||||
if "\\u" in final_text:
|
||||
final_text = safe_unicode_decode(final_text.encode("utf-8"))
|
||||
|
||||
final_text = remove_think_tags(final_text)
|
||||
|
||||
usage = getattr(response, "usage_metadata", None)
|
||||
if token_tracker and usage:
|
||||
token_tracker.add_usage(
|
||||
{
|
||||
"prompt_tokens": getattr(usage, "prompt_token_count", 0),
|
||||
"completion_tokens": getattr(usage, "candidates_token_count", 0),
|
||||
"total_tokens": getattr(usage, "total_token_count", 0),
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug("Gemini response length: %s", len(final_text))
|
||||
return final_text
|
||||
|
||||
|
||||
async def gemini_model_complete(
|
||||
prompt: str,
|
||||
system_prompt: str | None = None,
|
||||
history_messages: list[dict[str, Any]] | None = None,
|
||||
keyword_extraction: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> str | AsyncIterator[str]:
|
||||
hashing_kv = kwargs.get("hashing_kv")
|
||||
model_name = None
|
||||
if hashing_kv is not None:
|
||||
model_name = hashing_kv.global_config.get("llm_model_name")
|
||||
if model_name is None:
|
||||
model_name = kwargs.pop("model_name", None)
|
||||
if model_name is None:
|
||||
raise ValueError("Gemini model name not provided in configuration.")
|
||||
|
||||
return await gemini_complete_if_cache(
|
||||
model_name,
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
keyword_extraction=keyword_extraction,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"gemini_complete_if_cache",
|
||||
"gemini_model_complete",
|
||||
]
|
||||
|
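A minimal usage sketch of the new binding. The module path `lightrag.llm.gemini` and the keyword arguments come from the file above; the model name and generation settings are placeholders, and the API key is expected in `GEMINI_API_KEY` or `LLM_BINDING_API_KEY` as `_ensure_api_key` requires:

```
import asyncio

from lightrag.llm.gemini import gemini_complete_if_cache


async def main() -> None:
    # Non-streaming call: returns the full completion as a string.
    answer = await gemini_complete_if_cache(
        "gemini-2.0-flash",                      # placeholder model name
        "Summarize LightRAG in one sentence.",
        system_prompt="You are a concise assistant.",
        generation_config={"temperature": 0.7, "max_output_tokens": 256},
    )
    print(answer)

    # Streaming call: returns an async iterator of text chunks; with enable_cot=True
    # any thought content arrives wrapped in <think>...</think> tags.
    chunks = await gemini_complete_if_cache(
        "gemini-2.0-flash",
        "List three graph databases.",
        stream=True,
        enable_cot=False,
    )
    async for chunk in chunks:
        print(chunk, end="", flush=True)


asyncio.run(main())
```

When routed through the server, `gemini_model_complete` resolves the model name from `hashing_kv.global_config["llm_model_name"]`, so direct callers normally pass the model themselves as shown above.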
|
@ -138,6 +138,9 @@ async def openai_complete_if_cache(
|
|||
base_url: str | None = None,
|
||||
api_key: str | None = None,
|
||||
token_tracker: Any | None = None,
|
||||
keyword_extraction: bool = False, # Will be removed from kwargs before passing to OpenAI
|
||||
stream: bool | None = None,
|
||||
timeout: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration.
|
||||
|
|
@ -172,8 +175,9 @@ async def openai_complete_if_cache(
|
|||
- openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
|
||||
These will be passed to the client constructor but will be overridden by
|
||||
explicit parameters (api_key, base_url).
|
||||
- hashing_kv: Will be removed from kwargs before passing to OpenAI.
|
||||
- keyword_extraction: Will be removed from kwargs before passing to OpenAI.
|
||||
- stream: Whether to stream the response. Default is False.
|
||||
- timeout: Request timeout in seconds. Default is None.
|
||||
|
||||
Returns:
|
||||
The completed text (with integrated COT content if available) or an async iterator
|
||||
|
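A short, hedged sketch of the newly documented `timeout` and `stream` keyword arguments. The leading positional parameters are not shown in this hunk, so their order here mirrors the other bindings and is an assumption; the model name is a placeholder:

```
import asyncio

from lightrag.llm.openai import openai_complete_if_cache


async def main() -> None:
    text = await openai_complete_if_cache(
        "gpt-4o-mini",                  # placeholder model name
        "Explain the timeout parameter in one sentence.",
        system_prompt="Be brief.",
        timeout=30,                     # request timeout in seconds (per the docstring)
        stream=False,                   # default; set True to get an async iterator of chunks
    )
    print(text)


asyncio.run(main())
```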
|
|
|||
|
|
@ -1795,7 +1795,7 @@ def normalize_extracted_info(name: str, remove_inner_quotes=False) -> str:
|
|||
- Filter out short numeric-only text (length < 3 and only digits/dots)
|
||||
- remove_inner_quotes = True
|
||||
remove Chinese quotes
|
||||
remove English queotes in and around chinese
|
||||
remove English quotes in and around chinese
|
||||
Convert non-breaking spaces to regular spaces
|
||||
Convert narrow non-breaking spaces after non-digits to regular spaces
|
||||
|
||||
|
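Those space-handling rules are subtle, so here is a self-contained illustration of the two conversions the docstring describes. It is not the LightRAG implementation, only a demonstration of the intended behavior:

```
import re


def _demo_space_rules(text: str) -> str:
    # Non-breaking space (U+00A0) always becomes a regular space.
    text = text.replace("\u00a0", " ")
    # Narrow non-breaking space (U+202F) becomes a space only after a non-digit,
    # so digit groupings like "12 345" keep their separator.
    return re.sub(r"(?<=\D)\u202f", " ", text)


print(_demo_space_rules("LightRAG\u00a0graph"))   # -> 'LightRAG graph'
print(_demo_space_rules("12\u202f345"))           # digit grouping left unchanged
```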
|
|
|||
|
|
@ -24,6 +24,7 @@ dependencies = [
|
|||
"aiohttp",
|
||||
"configparser",
|
||||
"future",
|
||||
"google-genai>=1.0.0,<2.0.0",
|
||||
"json_repair",
|
||||
"nano-vectordb",
|
||||
"networkx",
|
||||
|
|
@ -59,6 +60,7 @@ api = [
|
|||
"tenacity",
|
||||
"tiktoken",
|
||||
"xlsxwriter>=3.1.0",
|
||||
"google-genai>=1.0.0,<2.0.0",
|
||||
# API-specific dependencies
|
||||
"aiofiles",
|
||||
"ascii_colors",
|
||||
|
|
@ -107,6 +109,7 @@ offline-llm = [
|
|||
"aioboto3>=12.0.0,<16.0.0",
|
||||
"voyageai>=0.2.0,<1.0.0",
|
||||
"llama-index>=0.9.0,<1.0.0",
|
||||
"google-genai>=1.0.0,<2.0.0",
|
||||
]
|
||||
|
||||
offline = [
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
# LLM provider dependencies (with version constraints matching pyproject.toml)
|
||||
aioboto3>=12.0.0,<16.0.0
|
||||
anthropic>=0.18.0,<1.0.0
|
||||
google-genai>=1.0.0,<2.0.0
|
||||
llama-index>=0.9.0,<1.0.0
|
||||
ollama>=0.1.0,<1.0.0
|
||||
openai>=1.0.0,<3.0.0
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ anthropic>=0.18.0,<1.0.0
|
|||
|
||||
# Storage backend dependencies
|
||||
asyncpg>=0.29.0,<1.0.0
|
||||
google-genai>=1.0.0,<2.0.0
|
||||
|
||||
# Document processing dependencies
|
||||
llama-index>=0.9.0,<1.0.0
|
||||
|
|
|
|||
File diff suppressed because it is too large