Merge branch 'main' into feature-add-tigergraph-support

commit dc2898d358
26 changed files with 1361 additions and 1476 deletions
.gitignore (vendored, 4 changes)

@@ -67,8 +67,8 @@ download_models_hf.py
 # Frontend build output (built during PyPI release)
 /lightrag/api/webui/

-# unit-test files
-test_*
+# temporary test files in project root
+/test_*

 # Cline files
 memory-bank
README-zh.md (46 changes)

@@ -53,7 +53,7 @@

 ## 🎉 新闻

-- [x] [2025.11.05]🎯📢添加**基于RAGAS的**LightRAG评估框架。
+- [x] [2025.11.05]🎯📢添加**基于RAGAS的**评估框架和**Langfuse**可观测性支持。
 - [x] [2025.10.22]🎯📢消除处理**大规模数据集**的瓶颈。
 - [x] [2025.09.15]🎯📢显著提升**小型LLM**(如Qwen3-30B-A3B)的知识图谱提取准确性。
 - [x] [2025.08.29]🎯📢现已支持**Reranker**,显著提升混合查询性能。

@@ -1463,6 +1463,50 @@ LightRAG服务器提供全面的知识图谱可视化功能。它支持各种重

 

+
+## Langfuse 可观测性集成
+
+Langfuse 为 OpenAI 客户端提供了直接替代方案,可自动跟踪所有 LLM 交互,使开发者能够在无需修改代码的情况下监控、调试和优化其 RAG 系统。
+
+### 安装 Langfuse 可选依赖
+
+```
+pip install lightrag-hku
+pip install lightrag-hku[observability]
+
+# 或从源代码安装并启用调试模式
+pip install -e .
+pip install -e ".[observability]"
+```
+
+### 配置 Langfuse 环境变量
+
+修改 .env 文件:
+
+```
+## Langfuse 可观测性(可选)
+# LLM 可观测性和追踪平台
+# 安装命令: pip install lightrag-hku[observability]
+# 注册地址: https://cloud.langfuse.com 或自托管部署
+LANGFUSE_SECRET_KEY=""
+LANGFUSE_PUBLIC_KEY=""
+LANGFUSE_HOST="https://cloud.langfuse.com" # 或您的自托管实例地址
+LANGFUSE_ENABLE_TRACE=true
+```
+
+### Langfuse 使用说明
+
+安装并配置完成后,Langfuse 会自动追踪所有 OpenAI LLM 调用。Langfuse 仪表板功能包括:
+
+- **追踪**:查看完整的 LLM 调用链
+- **分析**:Token 使用量、延迟、成本指标
+- **调试**:检查提示词和响应内容
+- **评估**:比较模型输出结果
+- **监控**:实时告警功能
+
+### 重要提示
+
+**注意**:LightRAG 目前仅把 OpenAI 兼容的 API 调用接入了 Langfuse。Ollama、Azure 和 AWS Bedrock 等 API 还无法使用 Langfuse 的可观测性功能。
+
 ## RAGAS评估

 **RAGAS**(Retrieval Augmented Generation Assessment,检索增强生成评估)是一个使用LLM对RAG系统进行无参考评估的框架。我们提供了基于RAGAS的评估脚本。详细信息请参阅[基于RAGAS的评估框架](lightrag/evaluation/README.md)。
README.md (46 changes)

@@ -51,7 +51,7 @@

 ---
 ## 🎉 News
-- [x] [2025.11.05]🎯📢Add **RAGAS-based** Evaluation Framework for LightRAG.
+- [x] [2025.11.05]🎯📢Add **RAGAS-based** Evaluation Framework and **Langfuse** observability for LightRAG.
 - [x] [2025.10.22]🎯📢Eliminate bottlenecks in processing **large-scale datasets**.
 - [x] [2025.09.15]🎯📢Significantly enhances KG extraction accuracy for **small LLMs** like Qwen3-30B-A3B.
 - [x] [2025.08.29]🎯📢**Reranker** is supported now, significantly boosting performance for mixed queries.

@@ -1543,6 +1543,50 @@ The LightRAG Server offers a comprehensive knowledge graph visualization feature

 

+
+## Langfuse observability integration
+
+Langfuse provides a drop-in replacement for the OpenAI client that automatically tracks all LLM interactions, enabling developers to monitor, debug, and optimize their RAG systems without code changes.
+
+### Installation with Langfuse option
+
+```
+pip install lightrag-hku
+pip install lightrag-hku[observability]
+
+# Or install from source code with debug mode enabled
+pip install -e .
+pip install -e ".[observability]"
+```
+
+### Config Langfuse env vars
+
+Modify the .env file:
+
+```
+## Langfuse Observability (Optional)
+# LLM observability and tracing platform
+# Install with: pip install lightrag-hku[observability]
+# Sign up at: https://cloud.langfuse.com or self-host
+LANGFUSE_SECRET_KEY=""
+LANGFUSE_PUBLIC_KEY=""
+LANGFUSE_HOST="https://cloud.langfuse.com" # or your self-hosted instance
+LANGFUSE_ENABLE_TRACE=true
+```
+
+### Langfuse Usage
+
+Once installed and configured, Langfuse automatically traces all OpenAI LLM calls. Langfuse dashboard features include:
+
+- **Tracing**: View complete LLM call chains
+- **Analytics**: Token usage, latency, cost metrics
+- **Debugging**: Inspect prompts and responses
+- **Evaluation**: Compare model outputs
+- **Monitoring**: Real-time alerting
+
+### Important Notice
+
+**Note**: LightRAG currently only integrates OpenAI-compatible API calls with Langfuse. APIs such as Ollama, Azure, and AWS Bedrock are not yet supported for Langfuse observability.
+
 ## RAGAS-based Evaluation

 **RAGAS** (Retrieval Augmented Generation Assessment) is a framework for reference-free evaluation of RAG systems using LLMs. There is an evaluation script based on RAGAS. For detailed information, please refer to [RAGAS-based Evaluation Framework](lightrag/evaluation/README.md).
env.example (11 changes)

@@ -170,7 +170,7 @@ MAX_PARALLEL_INSERT=2

 ###########################################################################
 ### LLM Configuration
-### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock
+### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock, gemini
 ### LLM_BINDING_HOST: host only for Ollama, endpoint for other LLM service
 ###########################################################################
 ### LLM request timeout setting for all llm (0 means no timeout for Ollma)

@@ -191,6 +191,15 @@ LLM_BINDING_API_KEY=your_api_key
 # LLM_BINDING_API_KEY=your_api_key
 # LLM_BINDING=openai

+### Gemini example
+# LLM_BINDING=gemini
+# LLM_MODEL=gemini-flash-latest
+# LLM_BINDING_API_KEY=your_gemini_api_key
+# LLM_BINDING_HOST=https://generativelanguage.googleapis.com
+GEMINI_LLM_THINKING_CONFIG='{"thinking_budget": 0, "include_thoughts": false}'
+# GEMINI_LLM_MAX_OUTPUT_TOKENS=9000
+# GEMINI_LLM_TEMPERATURE=0.7
+
 ### OpenAI Compatible API Specific Parameters
 ### Increased temperature values may mitigate infinite inference loops in certain LLM, such as Qwen3-30B.
 # OPENAI_LLM_TEMPERATURE=0.9
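The `GEMINI_LLM_THINKING_CONFIG` line above packs a JSON object into a single environment variable. A minimal sketch of how such a value can be read back; the `read_json_env` helper is hypothetical, for illustration only, not LightRAG's actual loader:

```python
import json
import os
from typing import Optional


def read_json_env(name: str, default: Optional[dict] = None) -> dict:
    """Parse a JSON-valued env var like GEMINI_LLM_THINKING_CONFIG (hypothetical helper)."""
    raw = os.environ.get(name)
    if not raw:
        return default or {}
    return json.loads(raw)


# The exact value shown in the env.example snippet above:
os.environ["GEMINI_LLM_THINKING_CONFIG"] = (
    '{"thinking_budget": 0, "include_thoughts": false}'
)
cfg = read_json_env("GEMINI_LLM_THINKING_CONFIG")
print(cfg)  # JSON false becomes Python False
```

Note that quoting the value in the `.env` file keeps the embedded spaces and braces intact as a single shell-safe string.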
@@ -1,105 +0,0 @@ (file deleted)
-# pip install -q -U google-genai to use gemini as a client
-
-import os
-import numpy as np
-from google import genai
-from google.genai import types
-from dotenv import load_dotenv
-from lightrag.utils import EmbeddingFunc
-from lightrag import LightRAG, QueryParam
-from sentence_transformers import SentenceTransformer
-from lightrag.kg.shared_storage import initialize_pipeline_status
-
-import asyncio
-import nest_asyncio
-
-# Apply nest_asyncio to solve event loop issues
-nest_asyncio.apply()
-
-load_dotenv()
-gemini_api_key = os.getenv("GEMINI_API_KEY")
-
-WORKING_DIR = "./dickens"
-
-if os.path.exists(WORKING_DIR):
-    import shutil
-
-    shutil.rmtree(WORKING_DIR)
-
-os.mkdir(WORKING_DIR)
-
-
-async def llm_model_func(
-    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
-    # 1. Initialize the GenAI Client with your Gemini API Key
-    client = genai.Client(api_key=gemini_api_key)
-
-    # 2. Combine prompts: system prompt, history, and user prompt
-    if history_messages is None:
-        history_messages = []
-
-    combined_prompt = ""
-    if system_prompt:
-        combined_prompt += f"{system_prompt}\n"
-
-    for msg in history_messages:
-        # Each msg is expected to be a dict: {"role": "...", "content": "..."}
-        combined_prompt += f"{msg['role']}: {msg['content']}\n"
-
-    # Finally, add the new user prompt
-    combined_prompt += f"user: {prompt}"
-
-    # 3. Call the Gemini model
-    response = client.models.generate_content(
-        model="gemini-1.5-flash",
-        contents=[combined_prompt],
-        config=types.GenerateContentConfig(max_output_tokens=500, temperature=0.1),
-    )
-
-    # 4. Return the response text
-    return response.text
-
-
-async def embedding_func(texts: list[str]) -> np.ndarray:
-    model = SentenceTransformer("all-MiniLM-L6-v2")
-    embeddings = model.encode(texts, convert_to_numpy=True)
-    return embeddings
-
-
-async def initialize_rag():
-    rag = LightRAG(
-        working_dir=WORKING_DIR,
-        llm_model_func=llm_model_func,
-        embedding_func=EmbeddingFunc(
-            embedding_dim=384,
-            max_token_size=8192,
-            func=embedding_func,
-        ),
-    )
-
-    await rag.initialize_storages()
-    await initialize_pipeline_status()
-
-    return rag
-
-
-def main():
-    # Initialize RAG instance
-    rag = asyncio.run(initialize_rag())
-    file_path = "story.txt"
-    with open(file_path, "r") as file:
-        text = file.read()
-
-    rag.insert(text)
-
-    response = rag.query(
-        query="What is the main theme of the story?",
-        param=QueryParam(mode="hybrid", top_k=5, response_type="single line"),
-    )
-
-    print(response)
-
-
-if __name__ == "__main__":
-    main()
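The deleted example above flattens the system prompt, chat history, and new user prompt into a single string before calling Gemini. That combination step, extracted as a pure function so it can be shown in isolation (a sketch, not LightRAG's own API):

```python
def combine_prompt(prompt, system_prompt=None, history_messages=None) -> str:
    """Mirror of the prompt-flattening logic in the deleted example:
    system prompt first, then each history turn, then the new user prompt."""
    history_messages = history_messages or []
    combined = ""
    if system_prompt:
        combined += f"{system_prompt}\n"
    for msg in history_messages:
        # Each msg is expected to be a dict: {"role": "...", "content": "..."}
        combined += f"{msg['role']}: {msg['content']}\n"
    combined += f"user: {prompt}"
    return combined


print(combine_prompt(
    "hi",
    system_prompt="be brief",
    history_messages=[{"role": "user", "content": "a"}],
))
# be brief
# user: a
# user: hi
```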
@@ -1,230 +0,0 @@ (file deleted)
-# pip install -q -U google-genai to use gemini as a client
-
-import os
-from typing import Optional
-import dataclasses
-from pathlib import Path
-import hashlib
-import numpy as np
-from google import genai
-from google.genai import types
-from dotenv import load_dotenv
-from lightrag.utils import EmbeddingFunc, Tokenizer
-from lightrag import LightRAG, QueryParam
-from sentence_transformers import SentenceTransformer
-from lightrag.kg.shared_storage import initialize_pipeline_status
-import sentencepiece as spm
-import requests
-
-import asyncio
-import nest_asyncio
-
-# Apply nest_asyncio to solve event loop issues
-nest_asyncio.apply()
-
-load_dotenv()
-gemini_api_key = os.getenv("GEMINI_API_KEY")
-
-WORKING_DIR = "./dickens"
-
-if os.path.exists(WORKING_DIR):
-    import shutil
-
-    shutil.rmtree(WORKING_DIR)
-
-os.mkdir(WORKING_DIR)
-
-
-class GemmaTokenizer(Tokenizer):
-    # adapted from google-cloud-aiplatform[tokenization]
-
-    @dataclasses.dataclass(frozen=True)
-    class _TokenizerConfig:
-        tokenizer_model_url: str
-        tokenizer_model_hash: str
-
-    _TOKENIZERS = {
-        "google/gemma2": _TokenizerConfig(
-            tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model",
-            tokenizer_model_hash="61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2",
-        ),
-        "google/gemma3": _TokenizerConfig(
-            tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/cb7c0152a369e43908e769eb09e1ce6043afe084/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
-            tokenizer_model_hash="1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c",
-        ),
-    }
-
-    def __init__(
-        self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None
-    ):
-        # https://github.com/google/gemma_pytorch/tree/main/tokenizer
-        if "1.5" in model_name or "1.0" in model_name:
-            # up to gemini 1.5 gemma2 is a comparable local tokenizer
-            # https://github.com/googleapis/python-aiplatform/blob/main/vertexai/tokenization/_tokenizer_loading.py
-            tokenizer_name = "google/gemma2"
-        else:
-            # for gemini > 2.0 gemma3 was used
-            tokenizer_name = "google/gemma3"
-
-        file_url = self._TOKENIZERS[tokenizer_name].tokenizer_model_url
-        tokenizer_model_name = file_url.rsplit("/", 1)[1]
-        expected_hash = self._TOKENIZERS[tokenizer_name].tokenizer_model_hash
-
-        tokenizer_dir = Path(tokenizer_dir)
-        if tokenizer_dir.is_dir():
-            file_path = tokenizer_dir / tokenizer_model_name
-            model_data = self._maybe_load_from_cache(
-                file_path=file_path, expected_hash=expected_hash
-            )
-        else:
-            model_data = None
-        if not model_data:
-            model_data = self._load_from_url(
-                file_url=file_url, expected_hash=expected_hash
-            )
-            self.save_tokenizer_to_cache(cache_path=file_path, model_data=model_data)
-
-        tokenizer = spm.SentencePieceProcessor()
-        tokenizer.LoadFromSerializedProto(model_data)
-        super().__init__(model_name=model_name, tokenizer=tokenizer)
-
-    def _is_valid_model(self, model_data: bytes, expected_hash: str) -> bool:
-        """Returns true if the content is valid by checking the hash."""
-        return hashlib.sha256(model_data).hexdigest() == expected_hash
-
-    def _maybe_load_from_cache(self, file_path: Path, expected_hash: str) -> bytes:
-        """Loads the model data from the cache path."""
-        if not file_path.is_file():
-            return
-        with open(file_path, "rb") as f:
-            content = f.read()
-        if self._is_valid_model(model_data=content, expected_hash=expected_hash):
-            return content
-
-        # Cached file corrupted.
-        self._maybe_remove_file(file_path)
-
-    def _load_from_url(self, file_url: str, expected_hash: str) -> bytes:
-        """Loads model bytes from the given file url."""
-        resp = requests.get(file_url)
-        resp.raise_for_status()
-        content = resp.content
-
-        if not self._is_valid_model(model_data=content, expected_hash=expected_hash):
-            actual_hash = hashlib.sha256(content).hexdigest()
-            raise ValueError(
-                f"Downloaded model file is corrupted."
-                f" Expected hash {expected_hash}. Got file hash {actual_hash}."
-            )
-        return content
-
-    @staticmethod
-    def save_tokenizer_to_cache(cache_path: Path, model_data: bytes) -> None:
-        """Saves the model data to the cache path."""
-        try:
-            if not cache_path.is_file():
-                cache_dir = cache_path.parent
-                cache_dir.mkdir(parents=True, exist_ok=True)
-                with open(cache_path, "wb") as f:
-                    f.write(model_data)
-        except OSError:
-            # Don't raise if we cannot write file.
-            pass
-
-    @staticmethod
-    def _maybe_remove_file(file_path: Path) -> None:
-        """Removes the file if exists."""
-        if not file_path.is_file():
-            return
-        try:
-            file_path.unlink()
-        except OSError:
-            # Don't raise if we cannot remove file.
-            pass
-
-    # def encode(self, content: str) -> list[int]:
-    #     return self.tokenizer.encode(content)
-
-    # def decode(self, tokens: list[int]) -> str:
-    #     return self.tokenizer.decode(tokens)
-
-
-async def llm_model_func(
-    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
-    # 1. Initialize the GenAI Client with your Gemini API Key
-    client = genai.Client(api_key=gemini_api_key)
-
-    # 2. Combine prompts: system prompt, history, and user prompt
-    if history_messages is None:
-        history_messages = []
-
-    combined_prompt = ""
-    if system_prompt:
-        combined_prompt += f"{system_prompt}\n"
-
-    for msg in history_messages:
-        # Each msg is expected to be a dict: {"role": "...", "content": "..."}
-        combined_prompt += f"{msg['role']}: {msg['content']}\n"
-
-    # Finally, add the new user prompt
-    combined_prompt += f"user: {prompt}"
-
-    # 3. Call the Gemini model
-    response = client.models.generate_content(
-        model="gemini-1.5-flash",
-        contents=[combined_prompt],
-        config=types.GenerateContentConfig(max_output_tokens=500, temperature=0.1),
-    )
-
-    # 4. Return the response text
-    return response.text
-
-
-async def embedding_func(texts: list[str]) -> np.ndarray:
-    model = SentenceTransformer("all-MiniLM-L6-v2")
-    embeddings = model.encode(texts, convert_to_numpy=True)
-    return embeddings
-
-
-async def initialize_rag():
-    rag = LightRAG(
-        working_dir=WORKING_DIR,
-        # tiktoken_model_name="gpt-4o-mini",
-        tokenizer=GemmaTokenizer(
-            tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"),
-            model_name="gemini-2.0-flash",
-        ),
-        llm_model_func=llm_model_func,
-        embedding_func=EmbeddingFunc(
-            embedding_dim=384,
-            max_token_size=8192,
-            func=embedding_func,
-        ),
-    )
-
-    await rag.initialize_storages()
-    await initialize_pipeline_status()
-
-    return rag
-
-
-def main():
-    # Initialize RAG instance
-    rag = asyncio.run(initialize_rag())
-    file_path = "story.txt"
-    with open(file_path, "r") as file:
-        text = file.read()
-
-    rag.insert(text)
-
-    response = rag.query(
-        query="What is the main theme of the story?",
-        param=QueryParam(mode="hybrid", top_k=5, response_type="single line"),
-    )
-
-    print(response)
-
-
-if __name__ == "__main__":
-    main()
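The `GemmaTokenizer` in the deleted example above validates tokenizer bytes against a pinned SHA-256 before trusting either a cached copy or a fresh download. The check itself, isolated as a free function (names are illustrative):

```python
import hashlib


def is_valid_model(model_data: bytes, expected_hash: str) -> bool:
    """Same check as GemmaTokenizer._is_valid_model in the deleted example:
    trust bytes only if their SHA-256 matches the pinned digest."""
    return hashlib.sha256(model_data).hexdigest() == expected_hash


data = b"tokenizer-bytes"
pinned = hashlib.sha256(data).hexdigest()
print(is_valid_model(data, pinned), is_valid_model(b"corrupted", pinned))
# True False
```

Pinning the digest this way means a tampered or truncated download is rejected and refetched rather than silently loaded.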
@@ -1,151 +0,0 @@ (file deleted)
-# pip install -q -U google-genai to use gemini as a client
-
-import os
-import asyncio
-import numpy as np
-import nest_asyncio
-from google import genai
-from google.genai import types
-from dotenv import load_dotenv
-from lightrag.utils import EmbeddingFunc
-from lightrag import LightRAG, QueryParam
-from lightrag.kg.shared_storage import initialize_pipeline_status
-from lightrag.llm.siliconcloud import siliconcloud_embedding
-from lightrag.utils import setup_logger
-from lightrag.utils import TokenTracker
-
-setup_logger("lightrag", level="DEBUG")
-
-# Apply nest_asyncio to solve event loop issues
-nest_asyncio.apply()
-
-load_dotenv()
-gemini_api_key = os.getenv("GEMINI_API_KEY")
-siliconflow_api_key = os.getenv("SILICONFLOW_API_KEY")
-
-WORKING_DIR = "./dickens"
-
-if not os.path.exists(WORKING_DIR):
-    os.mkdir(WORKING_DIR)
-
-token_tracker = TokenTracker()
-
-
-async def llm_model_func(
-    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
-    # 1. Initialize the GenAI Client with your Gemini API Key
-    client = genai.Client(api_key=gemini_api_key)
-
-    # 2. Combine prompts: system prompt, history, and user prompt
-    if history_messages is None:
-        history_messages = []
-
-    combined_prompt = ""
-    if system_prompt:
-        combined_prompt += f"{system_prompt}\n"
-
-    for msg in history_messages:
-        # Each msg is expected to be a dict: {"role": "...", "content": "..."}
-        combined_prompt += f"{msg['role']}: {msg['content']}\n"
-
-    # Finally, add the new user prompt
-    combined_prompt += f"user: {prompt}"
-
-    # 3. Call the Gemini model
-    response = client.models.generate_content(
-        model="gemini-2.0-flash",
-        contents=[combined_prompt],
-        config=types.GenerateContentConfig(
-            max_output_tokens=5000, temperature=0, top_k=10
-        ),
-    )
-
-    # 4. Get token counts with null safety
-    usage = getattr(response, "usage_metadata", None)
-    prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
-    completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
-    total_tokens = getattr(usage, "total_token_count", 0) or (
-        prompt_tokens + completion_tokens
-    )
-
-    token_counts = {
-        "prompt_tokens": prompt_tokens,
-        "completion_tokens": completion_tokens,
-        "total_tokens": total_tokens,
-    }
-
-    token_tracker.add_usage(token_counts)
-
-    # 5. Return the response text
-    return response.text
-
-
-async def embedding_func(texts: list[str]) -> np.ndarray:
-    return await siliconcloud_embedding(
-        texts,
-        model="BAAI/bge-m3",
-        api_key=siliconflow_api_key,
-        max_token_size=512,
-    )
-
-
-async def initialize_rag():
-    rag = LightRAG(
-        working_dir=WORKING_DIR,
-        entity_extract_max_gleaning=1,
-        enable_llm_cache=True,
-        enable_llm_cache_for_entity_extract=True,
-        embedding_cache_config={"enabled": True, "similarity_threshold": 0.90},
-        llm_model_func=llm_model_func,
-        embedding_func=EmbeddingFunc(
-            embedding_dim=1024,
-            max_token_size=8192,
-            func=embedding_func,
-        ),
-    )
-
-    await rag.initialize_storages()
-    await initialize_pipeline_status()
-
-    return rag
-
-
-def main():
-    # Initialize RAG instance
-    rag = asyncio.run(initialize_rag())
-
-    with open("./book.txt", "r", encoding="utf-8") as f:
-        rag.insert(f.read())
-
-    # Context Manager Method
-    with token_tracker:
-        print(
-            rag.query(
-                "What are the top themes in this story?", param=QueryParam(mode="naive")
-            )
-        )
-
-        print(
-            rag.query(
-                "What are the top themes in this story?", param=QueryParam(mode="local")
-            )
-        )
-
-        print(
-            rag.query(
-                "What are the top themes in this story?",
-                param=QueryParam(mode="global"),
-            )
-        )
-
-        print(
-            rag.query(
-                "What are the top themes in this story?",
-                param=QueryParam(mode="hybrid"),
-            )
-        )
-
-
-if __name__ == "__main__":
-    main()
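The deleted `TokenTracker` example above reads Gemini's `usage_metadata` defensively with `getattr`, so a missing or `None` usage object degrades to zero counts instead of raising. That pattern in isolation; the stand-in response objects below are fabricated for illustration:

```python
from types import SimpleNamespace


def extract_token_counts(response) -> dict:
    """Null-safe usage extraction, as in the deleted TokenTracker example."""
    usage = getattr(response, "usage_metadata", None)
    prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
    completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
    # Fall back to the sum when total_token_count is absent or None.
    total_tokens = getattr(usage, "total_token_count", 0) or (
        prompt_tokens + completion_tokens
    )
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,
    }


resp = SimpleNamespace(
    usage_metadata=SimpleNamespace(
        prompt_token_count=12, candidates_token_count=30, total_token_count=None
    )
)
print(extract_token_counts(resp))
# {'prompt_tokens': 12, 'completion_tokens': 30, 'total_tokens': 42}
```

The `or 0` guards matter because the SDK may report a field as `None` rather than omitting it; a bare `getattr(..., 0)` alone would not catch that case.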
@@ -50,6 +50,7 @@ LightRAG necessitates the integration of both an LLM (Large Language Model) and
 * openai or openai compatible
 * azure_openai
 * aws_bedrock
+* gemini

 It is recommended to use environment variables to configure the LightRAG Server. There is an example environment variable file named `env.example` in the root directory of the project. Please copy this file to the startup directory and rename it to `.env`. After that, you can modify the parameters related to the LLM and Embedding models in the `.env` file. It is important to note that the LightRAG Server will load the environment variables from `.env` into the system environment variables each time it starts. **LightRAG Server will prioritize the settings in the system environment variables to .env file**.

@@ -72,6 +73,8 @@ EMBEDDING_DIM=1024
 # EMBEDDING_BINDING_API_KEY=your_api_key
 ```

+> When targeting Google Gemini, set `LLM_BINDING=gemini`, choose a model such as `LLM_MODEL=gemini-flash-latest`, and provide your Gemini key via `LLM_BINDING_API_KEY` (or `GEMINI_API_KEY`).
+
 * Ollama LLM + Ollama Embedding:

 ```
@@ -1 +1 @@
-__api_version__ = "0250"
+__api_version__ = "0251"
@@ -8,6 +8,7 @@ import logging
 from dotenv import load_dotenv
 from lightrag.utils import get_env_value
 from lightrag.llm.binding_options import (
+    GeminiLLMOptions,
     OllamaEmbeddingOptions,
     OllamaLLMOptions,
     OpenAILLMOptions,

@@ -63,6 +64,9 @@ def get_default_host(binding_type: str) -> str:
         "lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"),
         "azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"),
         "openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"),
+        "gemini": os.getenv(
+            "LLM_BINDING_HOST", "https://generativelanguage.googleapis.com"
+        ),
     }
     return default_hosts.get(
         binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434")

@@ -226,6 +230,7 @@ def parse_args() -> argparse.Namespace:
             "openai-ollama",
             "azure_openai",
             "aws_bedrock",
+            "gemini",
         ],
         help="LLM binding type (default: from env or ollama)",
     )

@@ -281,6 +286,16 @@ def parse_args() -> argparse.Namespace:
     elif os.environ.get("LLM_BINDING") in ["openai", "azure_openai"]:
         OpenAILLMOptions.add_args(parser)

+    if "--llm-binding" in sys.argv:
+        try:
+            idx = sys.argv.index("--llm-binding")
+            if idx + 1 < len(sys.argv) and sys.argv[idx + 1] == "gemini":
+                GeminiLLMOptions.add_args(parser)
+        except IndexError:
+            pass
+    elif os.environ.get("LLM_BINDING") == "gemini":
+        GeminiLLMOptions.add_args(parser)
+
     args = parser.parse_args()

     # convert relative path to absolute path
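The `parse_args` change above peeks at `sys.argv` before full parsing, so Gemini-specific options are only registered on the parser when `--llm-binding gemini` is actually requested. A standalone sketch of that two-phase pattern; the option names here are illustrative, not LightRAG's real flags:

```python
import argparse


def build_parser(argv: list) -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument("--llm-binding", default="ollama")
    # Peek at argv before parsing, as parse_args() above does, to decide
    # whether binding-specific options should exist at all.
    if "--llm-binding" in argv:
        idx = argv.index("--llm-binding")
        if idx + 1 < len(argv) and argv[idx + 1] == "gemini":
            parser.add_argument("--gemini-temperature", type=float, default=0.7)
    return parser


args = build_parser(["--llm-binding", "gemini"]).parse_args(
    ["--llm-binding", "gemini"]
)
print(args.llm_binding, args.gemini_temperature)
# gemini 0.7
```

The trade-off of this pattern is that the registered option set depends on the raw command line, so `--help` output changes with the binding; the benefit is that unrelated bindings never see flags that do not apply to them.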
@@ -89,6 +89,7 @@ class LLMConfigCache:
 
         # Initialize configurations based on binding conditions
         self.openai_llm_options = None
+        self.gemini_llm_options = None
         self.ollama_llm_options = None
         self.ollama_embedding_options = None
 
@@ -99,6 +100,12 @@ class LLMConfigCache:
             self.openai_llm_options = OpenAILLMOptions.options_dict(args)
             logger.info(f"OpenAI LLM Options: {self.openai_llm_options}")
 
+        if args.llm_binding == "gemini":
+            from lightrag.llm.binding_options import GeminiLLMOptions
+
+            self.gemini_llm_options = GeminiLLMOptions.options_dict(args)
+            logger.info(f"Gemini LLM Options: {self.gemini_llm_options}")
+
         # Only initialize and log Ollama LLM options when using Ollama LLM binding
         if args.llm_binding == "ollama":
             try:
@@ -279,6 +286,7 @@ def create_app(args):
         "openai",
         "azure_openai",
         "aws_bedrock",
+        "gemini",
     ]:
         raise Exception("llm binding not supported")
 
@@ -504,6 +512,44 @@ def create_app(args):
 
         return optimized_azure_openai_model_complete
 
+    def create_optimized_gemini_llm_func(
+        config_cache: LLMConfigCache, args, llm_timeout: int
+    ):
+        """Create optimized Gemini LLM function with cached configuration"""
+
+        async def optimized_gemini_model_complete(
+            prompt,
+            system_prompt=None,
+            history_messages=None,
+            keyword_extraction=False,
+            **kwargs,
+        ) -> str:
+            from lightrag.llm.gemini import gemini_complete_if_cache
+
+            if history_messages is None:
+                history_messages = []
+
+            # Use pre-processed configuration to avoid repeated parsing
+            kwargs["timeout"] = llm_timeout
+            if (
+                config_cache.gemini_llm_options is not None
+                and "generation_config" not in kwargs
+            ):
+                kwargs["generation_config"] = dict(config_cache.gemini_llm_options)
+
+            return await gemini_complete_if_cache(
+                args.llm_model,
+                prompt,
+                system_prompt=system_prompt,
+                history_messages=history_messages,
+                api_key=args.llm_binding_api_key,
+                base_url=args.llm_binding_host,
+                keyword_extraction=keyword_extraction,
+                **kwargs,
+            )
+
+        return optimized_gemini_model_complete
+
     def create_llm_model_func(binding: str):
         """
         Create LLM model function based on binding type.
@@ -525,6 +571,8 @@ def create_app(args):
             return create_optimized_azure_openai_llm_func(
                 config_cache, args, llm_timeout
             )
+        elif binding == "gemini":
+            return create_optimized_gemini_llm_func(config_cache, args, llm_timeout)
         else:  # openai and compatible
             # Use optimized function with pre-processed configuration
             return create_optimized_openai_llm_func(config_cache, args, llm_timeout)
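The factory added above resolves the cached Gemini options once, at creation time, and returns an async closure that reuses them on every call. A runnable sketch of that closure-factory pattern under simplified assumptions (the `ConfigCache` class and echo-style `complete` body are stand-ins, not LightRAG's real implementation):

```python
import asyncio


class ConfigCache:
    """Stand-in for LightRAG's LLMConfigCache (illustrative only)."""

    def __init__(self, options):
        self.gemini_llm_options = options


def create_llm_func(config_cache: ConfigCache, timeout: int):
    """Return an async completion function with config resolved once, at creation."""

    async def complete(prompt: str, **kwargs) -> str:
        kwargs["timeout"] = timeout
        # Caller-supplied generation_config wins over the cached defaults.
        if (
            config_cache.gemini_llm_options is not None
            and "generation_config" not in kwargs
        ):
            kwargs["generation_config"] = dict(config_cache.gemini_llm_options)
        # A real implementation would call the provider here; we echo for the sketch.
        return f"{prompt} (timeout={kwargs['timeout']})"

    return complete


llm = create_llm_func(ConfigCache({"temperature": 0.7}), timeout=30)
result = asyncio.run(llm("hello"))
print(result)
```

Because the options dict is captured by the closure, per-request calls avoid re-parsing configuration, which is the point of the `LLMConfigCache` change in this diff.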
@@ -19,7 +19,6 @@ from typing import (
 from .utils import EmbeddingFunc
 from .types import KnowledgeGraph
 from .constants import (
-    GRAPH_FIELD_SEP,
     DEFAULT_TOP_K,
     DEFAULT_CHUNK_TOP_K,
     DEFAULT_MAX_ENTITY_TOKENS,
@@ -528,56 +527,6 @@ class BaseGraphStorage(StorageNameSpace, ABC):
             result[node_id] = edges if edges is not None else []
         return result
 
-    @abstractmethod
-    async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
-        """Get all nodes that are associated with the given chunk_ids.
-
-        Args:
-            chunk_ids (list[str]): A list of chunk IDs to find associated nodes for.
-
-        Returns:
-            list[dict]: A list of nodes, where each node is a dictionary of its properties.
-                An empty list if no matching nodes are found.
-        """
-
-    @abstractmethod
-    async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
-        """Get all edges that are associated with the given chunk_ids.
-
-        Args:
-            chunk_ids (list[str]): A list of chunk IDs to find associated edges for.
-
-        Returns:
-            list[dict]: A list of edges, where each edge is a dictionary of its properties.
-                An empty list if no matching edges are found.
-        """
-        # Default implementation iterates through all nodes and their edges, which is inefficient.
-        # This method should be overridden by subclasses for better performance.
-        all_edges = []
-        all_labels = await self.get_all_labels()
-        processed_edges = set()
-
-        for label in all_labels:
-            edges = await self.get_node_edges(label)
-            if edges:
-                for src_id, tgt_id in edges:
-                    # Avoid processing the same edge twice in an undirected graph
-                    edge_tuple = tuple(sorted((src_id, tgt_id)))
-                    if edge_tuple in processed_edges:
-                        continue
-                    processed_edges.add(edge_tuple)
-
-                    edge = await self.get_edge(src_id, tgt_id)
-                    if edge and "source_id" in edge:
-                        source_ids = set(edge["source_id"].split(GRAPH_FIELD_SEP))
-                        if not source_ids.isdisjoint(chunk_ids):
-                            # Add source and target to the edge dict for easier processing later
-                            edge_with_nodes = edge.copy()
-                            edge_with_nodes["source"] = src_id
-                            edge_with_nodes["target"] = tgt_id
-                            all_edges.append(edge_with_nodes)
-        return all_edges
-
     @abstractmethod
     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         """Insert a new node or update an existing node in the graph.
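The `get_*_by_chunk_ids` methods removed across these storage backends all shared one membership test: split a node's or edge's `source_id` field on `GRAPH_FIELD_SEP` and check set overlap with the requested chunk IDs. A self-contained sketch of that test (the `"<SEP>"` separator value is an assumption about `lightrag.constants`, and `node_matches_chunks` is a hypothetical helper, not part of the library):

```python
GRAPH_FIELD_SEP = "<SEP>"  # assumed separator value for joined chunk ids


def node_matches_chunks(node_data: dict, chunk_ids: list[str]) -> bool:
    """True if the node's source_id field references any of the given chunks."""
    if "source_id" not in node_data:
        return False
    source_ids = set(node_data["source_id"].split(GRAPH_FIELD_SEP))
    # isdisjoint is O(min(len(a), len(b))) and short-circuits on first overlap
    return not source_ids.isdisjoint(chunk_ids)


node = {"source_id": "chunk-1<SEP>chunk-2"}
print(node_matches_chunks(node, ["chunk-2"]))  # True
print(node_matches_chunks(node, ["chunk-9"]))  # False
```

The database-backed implementations below express the same predicate in Cypher as `chunk_id IN split(x.source_id, $sep)`.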
@@ -8,7 +8,6 @@ import configparser
 from ..utils import logger
 from ..base import BaseGraphStorage
 from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
-from ..constants import GRAPH_FIELD_SEP
 from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
 import pipmaster as pm
 
@@ -784,79 +783,6 @@ class MemgraphStorage(BaseGraphStorage):
         degrees = int(src_degree) + int(trg_degree)
         return degrees
 
-    async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
-        """Get all nodes that are associated with the given chunk_ids.
-
-        Args:
-            chunk_ids: List of chunk IDs to find associated nodes for
-
-        Returns:
-            list[dict]: A list of nodes, where each node is a dictionary of its properties.
-                An empty list if no matching nodes are found.
-        """
-        if self._driver is None:
-            raise RuntimeError(
-                "Memgraph driver is not initialized. Call 'await initialize()' first."
-            )
-        workspace_label = self._get_workspace_label()
-        async with self._driver.session(
-            database=self._DATABASE, default_access_mode="READ"
-        ) as session:
-            query = f"""
-                UNWIND $chunk_ids AS chunk_id
-                MATCH (n:`{workspace_label}`)
-                WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
-                RETURN DISTINCT n
-            """
-            result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
-            nodes = []
-            async for record in result:
-                node = record["n"]
-                node_dict = dict(node)
-                node_dict["id"] = node_dict.get("entity_id")
-                nodes.append(node_dict)
-            await result.consume()
-            return nodes
-
-    async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
-        """Get all edges that are associated with the given chunk_ids.
-
-        Args:
-            chunk_ids: List of chunk IDs to find associated edges for
-
-        Returns:
-            list[dict]: A list of edges, where each edge is a dictionary of its properties.
-                An empty list if no matching edges are found.
-        """
-        if self._driver is None:
-            raise RuntimeError(
-                "Memgraph driver is not initialized. Call 'await initialize()' first."
-            )
-        workspace_label = self._get_workspace_label()
-        async with self._driver.session(
-            database=self._DATABASE, default_access_mode="READ"
-        ) as session:
-            query = f"""
-                UNWIND $chunk_ids AS chunk_id
-                MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`)
-                WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
-                WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id
-                // Ensure we only return each unique edge once by ordering the source and target
-                WITH a, b, r,
-                     CASE WHEN source_id <= target_id THEN source_id ELSE target_id END AS ordered_source,
-                     CASE WHEN source_id <= target_id THEN target_id ELSE source_id END AS ordered_target
-                RETURN DISTINCT ordered_source AS source, ordered_target AS target, properties(r) AS properties
-            """
-            result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
-            edges = []
-            async for record in result:
-                edge_properties = record["properties"]
-                edge_properties["source"] = record["source"]
-                edge_properties["target"] = record["target"]
-                edges.append(edge_properties)
-            await result.consume()
-            return edges
-
     async def get_knowledge_graph(
         self,
         node_label: str,
@@ -1036,45 +1036,6 @@ class MongoGraphStorage(BaseGraphStorage):
 
         return result
 
-    async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
-        """Get all nodes that are associated with the given chunk_ids.
-
-        Args:
-            chunk_ids (list[str]): A list of chunk IDs to find associated nodes for.
-
-        Returns:
-            list[dict]: A list of nodes, where each node is a dictionary of its properties.
-                An empty list if no matching nodes are found.
-        """
-        if not chunk_ids:
-            return []
-
-        cursor = self.collection.find({"source_ids": {"$in": chunk_ids}})
-        return [doc async for doc in cursor]
-
-    async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
-        """Get all edges that are associated with the given chunk_ids.
-
-        Args:
-            chunk_ids (list[str]): A list of chunk IDs to find associated edges for.
-
-        Returns:
-            list[dict]: A list of edges, where each edge is a dictionary of its properties.
-                An empty list if no matching edges are found.
-        """
-        if not chunk_ids:
-            return []
-
-        cursor = self.edge_collection.find({"source_ids": {"$in": chunk_ids}})
-
-        edges = []
-        async for edge in cursor:
-            edge["source"] = edge["source_node_id"]
-            edge["target"] = edge["target_node_id"]
-            edges.append(edge)
-
-        return edges
-
     #
     # -------------------------------------------------------------------------
     # UPSERTS
|
@@ -16,7 +16,6 @@ import logging
 from ..utils import logger
 from ..base import BaseGraphStorage
 from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
-from ..constants import GRAPH_FIELD_SEP
 from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
 import pipmaster as pm
 
@@ -904,49 +903,6 @@ class Neo4JStorage(BaseGraphStorage):
             await result.consume()  # Ensure results are fully consumed
         return edges_dict
 
-    async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
-        workspace_label = self._get_workspace_label()
-        async with self._driver.session(
-            database=self._DATABASE, default_access_mode="READ"
-        ) as session:
-            query = f"""
-                UNWIND $chunk_ids AS chunk_id
-                MATCH (n:`{workspace_label}`)
-                WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
-                RETURN DISTINCT n
-            """
-            result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
-            nodes = []
-            async for record in result:
-                node = record["n"]
-                node_dict = dict(node)
-                # Add node id (entity_id) to the dictionary for easier access
-                node_dict["id"] = node_dict.get("entity_id")
-                nodes.append(node_dict)
-            await result.consume()
-            return nodes
-
-    async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
-        workspace_label = self._get_workspace_label()
-        async with self._driver.session(
-            database=self._DATABASE, default_access_mode="READ"
-        ) as session:
-            query = f"""
-                UNWIND $chunk_ids AS chunk_id
-                MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`)
-                WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
-                RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
-            """
-            result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
-            edges = []
-            async for record in result:
-                edge_properties = record["properties"]
-                edge_properties["source"] = record["source"]
-                edge_properties["target"] = record["target"]
-                edges.append(edge_properties)
-            await result.consume()
-            return edges
-
     @retry(
         stop=stop_after_attempt(3),
         wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -5,7 +5,6 @@ from typing import final
 from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
 from lightrag.utils import logger
 from lightrag.base import BaseGraphStorage
-from lightrag.constants import GRAPH_FIELD_SEP
 import networkx as nx
 from .shared_storage import (
     get_storage_lock,
@@ -470,33 +469,6 @@ class NetworkXStorage(BaseGraphStorage):
         )
         return result
 
-    async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
-        chunk_ids_set = set(chunk_ids)
-        graph = await self._get_graph()
-        matching_nodes = []
-        for node_id, node_data in graph.nodes(data=True):
-            if "source_id" in node_data:
-                node_source_ids = set(node_data["source_id"].split(GRAPH_FIELD_SEP))
-                if not node_source_ids.isdisjoint(chunk_ids_set):
-                    node_data_with_id = node_data.copy()
-                    node_data_with_id["id"] = node_id
-                    matching_nodes.append(node_data_with_id)
-        return matching_nodes
-
-    async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
-        chunk_ids_set = set(chunk_ids)
-        graph = await self._get_graph()
-        matching_edges = []
-        for u, v, edge_data in graph.edges(data=True):
-            if "source_id" in edge_data:
-                edge_source_ids = set(edge_data["source_id"].split(GRAPH_FIELD_SEP))
-                if not edge_source_ids.isdisjoint(chunk_ids_set):
-                    edge_data_with_nodes = edge_data.copy()
-                    edge_data_with_nodes["source"] = u
-                    edge_data_with_nodes["target"] = v
-                    matching_edges.append(edge_data_with_nodes)
-        return matching_edges
-
     async def get_all_nodes(self) -> list[dict]:
         """Get all nodes in the graph.
@@ -33,7 +33,6 @@ from ..base import (
 )
 from ..namespace import NameSpace, is_namespace
 from ..utils import logger
-from ..constants import GRAPH_FIELD_SEP
 from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock, get_storage_lock
 
 import pipmaster as pm
@@ -3569,17 +3568,13 @@ class PGGraphStorage(BaseGraphStorage):
     async def get_node(self, node_id: str) -> dict[str, str] | None:
         """Get node by its label identifier, return only node properties"""
-        label = self._normalize_node_id(node_id)
-
-        result = await self.get_nodes_batch(node_ids=[label])
+        result = await self.get_nodes_batch(node_ids=[node_id])
         if result and node_id in result:
             return result[node_id]
         return None
 
     async def node_degree(self, node_id: str) -> int:
-        label = self._normalize_node_id(node_id)
-
-        result = await self.node_degrees_batch(node_ids=[label])
+        result = await self.node_degrees_batch(node_ids=[node_id])
         if result and node_id in result:
             return result[node_id]
@@ -3592,12 +3587,11 @@ class PGGraphStorage(BaseGraphStorage):
         self, source_node_id: str, target_node_id: str
     ) -> dict[str, str] | None:
         """Get edge properties between two nodes"""
-        src_label = self._normalize_node_id(source_node_id)
-        tgt_label = self._normalize_node_id(target_node_id)
-
-        result = await self.get_edges_batch([{"src": src_label, "tgt": tgt_label}])
-        if result and (src_label, tgt_label) in result:
-            return result[(src_label, tgt_label)]
+        result = await self.get_edges_batch(
+            [{"src": source_node_id, "tgt": target_node_id}]
+        )
+        if result and (source_node_id, target_node_id) in result:
+            return result[(source_node_id, target_node_id)]
         return None
 
     async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
@@ -3795,13 +3789,17 @@ class PGGraphStorage(BaseGraphStorage):
         if not node_ids:
             return {}
 
-        seen = set()
-        unique_ids = []
+        seen: set[str] = set()
+        unique_ids: list[str] = []
+        lookup: dict[str, str] = {}
+        requested: set[str] = set()
         for nid in node_ids:
-            nid_norm = self._normalize_node_id(nid)
-            if nid_norm not in seen:
-                seen.add(nid_norm)
-                unique_ids.append(nid_norm)
+            if nid not in seen:
+                seen.add(nid)
+                unique_ids.append(nid)
+                requested.add(nid)
+                lookup[nid] = nid
+                lookup[self._normalize_node_id(nid)] = nid
 
         # Build result dictionary
         nodes_dict = {}
@@ -3840,10 +3838,18 @@ class PGGraphStorage(BaseGraphStorage):
                         node_dict = json.loads(node_dict)
                     except json.JSONDecodeError:
                         logger.warning(
-                            f"Failed to parse node string in batch: {node_dict}"
+                            f"[{self.workspace}] Failed to parse node string in batch: {node_dict}"
                         )
 
-                nodes_dict[result["node_id"]] = node_dict
+                node_key = result["node_id"]
+                original_key = lookup.get(node_key)
+                if original_key is None:
+                    logger.warning(
+                        f"[{self.workspace}] Node {node_key} not found in lookup map"
+                    )
+                    original_key = node_key
+                if original_key in requested:
+                    nodes_dict[original_key] = node_dict
 
         return nodes_dict
@@ -3866,13 +3872,17 @@ class PGGraphStorage(BaseGraphStorage):
         if not node_ids:
             return {}
 
-        seen = set()
+        seen: set[str] = set()
         unique_ids: list[str] = []
+        lookup: dict[str, str] = {}
+        requested: set[str] = set()
         for nid in node_ids:
-            n = self._normalize_node_id(nid)
-            if n not in seen:
-                seen.add(n)
-                unique_ids.append(n)
+            if nid not in seen:
+                seen.add(nid)
+                unique_ids.append(nid)
+                requested.add(nid)
+                lookup[nid] = nid
+                lookup[self._normalize_node_id(nid)] = nid
 
         out_degrees = {}
         in_degrees = {}
@@ -3924,8 +3934,16 @@ class PGGraphStorage(BaseGraphStorage):
             node_id = row["node_id"]
             if not node_id:
                 continue
-            out_degrees[node_id] = int(row.get("out_degree", 0) or 0)
-            in_degrees[node_id] = int(row.get("in_degree", 0) or 0)
+            node_key = node_id
+            original_key = lookup.get(node_key)
+            if original_key is None:
+                logger.warning(
+                    f"[{self.workspace}] Node {node_key} not found in lookup map"
+                )
+                original_key = node_key
+            if original_key in requested:
+                out_degrees[original_key] = int(row.get("out_degree", 0) or 0)
+                in_degrees[original_key] = int(row.get("in_degree", 0) or 0)
 
         degrees_dict = {}
         for node_id in node_ids:
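The PostgreSQL hunks above keep querying with normalized node IDs but key the returned dicts by the caller's original IDs, via a `lookup` map populated with both spellings. A toy sketch of that map-back pattern (the `batch_lookup` helper, the lowercase normalizer, and the fake `fetch` backend are all illustrative stand-ins):

```python
def batch_lookup(node_ids, normalize, fetch):
    """Query with de-duplicated ids, but key the result dict by the caller's originals."""
    seen, unique_ids = set(), []
    lookup = {}        # normalized (and original) id -> original id
    requested = set()  # originals the caller actually asked for
    for nid in node_ids:
        if nid not in seen:
            seen.add(nid)
            unique_ids.append(nid)
            requested.add(nid)
            lookup[nid] = nid
            lookup[normalize(nid)] = nid

    out = {}
    # The backend may answer with normalized keys; map each back to the
    # caller-supplied spelling and drop anything that was never requested.
    for returned_id, value in fetch(unique_ids).items():
        original = lookup.get(returned_id, returned_id)
        if original in requested:
            out[original] = value
    return out


result = batch_lookup(
    ["Node A"],
    normalize=lambda s: s.lower(),
    fetch=lambda ids: {"node a": {"degree": 3}},
)
print(result)
```

This keeps lookups like `result[node_id]` in `get_node` and `node_degree` working with the un-normalized ID the caller passed in.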
@@ -4054,7 +4072,7 @@ class PGGraphStorage(BaseGraphStorage):
                     edge_props = json.loads(edge_props)
                 except json.JSONDecodeError:
                     logger.warning(
-                        f"Failed to parse edge properties string: {edge_props}"
+                        f"[{self.workspace}]Failed to parse edge properties string: {edge_props}"
                     )
                     continue
@@ -4070,7 +4088,7 @@ class PGGraphStorage(BaseGraphStorage):
                     edge_props = json.loads(edge_props)
                 except json.JSONDecodeError:
                     logger.warning(
-                        f"Failed to parse edge properties string: {edge_props}"
+                        f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
                    )
                     continue
@@ -4175,102 +4193,6 @@ class PGGraphStorage(BaseGraphStorage):
             labels.append(result["label"])
         return labels
 
-    async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
-        """
-        Retrieves nodes from the graph that are associated with a given list of chunk IDs.
-        This method uses a Cypher query with UNWIND to efficiently find all nodes
-        where the `source_id` property contains any of the specified chunk IDs.
-        """
-        # The string representation of the list for the cypher query
-        chunk_ids_str = json.dumps(chunk_ids)
-
-        query = f"""
-            SELECT * FROM cypher('{self.graph_name}', $$
-                UNWIND {chunk_ids_str} AS chunk_id
-                MATCH (n:base)
-                WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, '{GRAPH_FIELD_SEP}')
-                RETURN n
-            $$) AS (n agtype);
-        """
-        results = await self._query(query)
-
-        # Build result list
-        nodes = []
-        for result in results:
-            if result["n"]:
-                node_dict = result["n"]["properties"]
-
-                # Process string result, parse it to JSON dictionary
-                if isinstance(node_dict, str):
-                    try:
-                        node_dict = json.loads(node_dict)
-                    except json.JSONDecodeError:
-                        logger.warning(
-                            f"[{self.workspace}] Failed to parse node string in batch: {node_dict}"
-                        )
-
-                node_dict["id"] = node_dict["entity_id"]
-                nodes.append(node_dict)
-
-        return nodes
-
-    async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
-        """
-        Retrieves edges from the graph that are associated with a given list of chunk IDs.
-        This method uses a Cypher query with UNWIND to efficiently find all edges
-        where the `source_id` property contains any of the specified chunk IDs.
-        """
-        chunk_ids_str = json.dumps(chunk_ids)
-
-        query = f"""
-            SELECT * FROM cypher('{self.graph_name}', $$
-                UNWIND {chunk_ids_str} AS chunk_id
-                MATCH ()-[r]-()
-                WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, '{GRAPH_FIELD_SEP}')
-                RETURN DISTINCT r, startNode(r) AS source, endNode(r) AS target
-            $$) AS (edge agtype, source agtype, target agtype);
-        """
-        results = await self._query(query)
-        edges = []
-        if results:
-            for item in results:
-                edge_agtype = item["edge"]["properties"]
-                # Process string result, parse it to JSON dictionary
-                if isinstance(edge_agtype, str):
-                    try:
-                        edge_agtype = json.loads(edge_agtype)
-                    except json.JSONDecodeError:
-                        logger.warning(
-                            f"[{self.workspace}] Failed to parse edge string in batch: {edge_agtype}"
-                        )
-
-                source_agtype = item["source"]["properties"]
-                # Process string result, parse it to JSON dictionary
-                if isinstance(source_agtype, str):
-                    try:
-                        source_agtype = json.loads(source_agtype)
-                    except json.JSONDecodeError:
-                        logger.warning(
-                            f"[{self.workspace}] Failed to parse node string in batch: {source_agtype}"
-                        )
-
-                target_agtype = item["target"]["properties"]
-                # Process string result, parse it to JSON dictionary
-                if isinstance(target_agtype, str):
-                    try:
-                        target_agtype = json.loads(target_agtype)
-                    except json.JSONDecodeError:
-                        logger.warning(
-                            f"[{self.workspace}] Failed to parse node string in batch: {target_agtype}"
-                        )
-
-                if edge_agtype and source_agtype and target_agtype:
-                    edge_properties = edge_agtype
-                    edge_properties["source"] = source_agtype["entity_id"]
-                    edge_properties["target"] = target_agtype["entity_id"]
-                    edges.append(edge_properties)
-        return edges
-
     async def _bfs_subgraph(
         self, node_label: str, max_depth: int, max_nodes: int
     ) -> KnowledgeGraph:
@ -3235,38 +3235,31 @@ class LightRAG:
|
||||||
|
|
||||||
if entity_chunk_updates and self.entity_chunks:
|
if entity_chunk_updates and self.entity_chunks:
|
||||||
entity_upsert_payload = {}
|
entity_upsert_payload = {}
|
||||||
```diff
-            entity_delete_ids: set[str] = set()
             for entity_name, remaining in entity_chunk_updates.items():
                 if not remaining:
-                    entity_delete_ids.add(entity_name)
-                else:
-                    entity_upsert_payload[entity_name] = {
-                        "chunk_ids": remaining,
-                        "count": len(remaining),
-                        "updated_at": current_time,
-                    }
-
-            if entity_delete_ids:
-                await self.entity_chunks.delete(list(entity_delete_ids))
+                    # Empty entities are deleted alongside graph nodes later
+                    continue
+                entity_upsert_payload[entity_name] = {
+                    "chunk_ids": remaining,
+                    "count": len(remaining),
+                    "updated_at": current_time,
+                }
+
             if entity_upsert_payload:
                 await self.entity_chunks.upsert(entity_upsert_payload)

             if relation_chunk_updates and self.relation_chunks:
                 relation_upsert_payload = {}
-                relation_delete_ids: set[str] = set()
                 for edge_tuple, remaining in relation_chunk_updates.items():
-                    storage_key = make_relation_chunk_key(*edge_tuple)
                     if not remaining:
-                        relation_delete_ids.add(storage_key)
-                    else:
-                        relation_upsert_payload[storage_key] = {
-                            "chunk_ids": remaining,
-                            "count": len(remaining),
-                            "updated_at": current_time,
-                        }
-
-                if relation_delete_ids:
-                    await self.relation_chunks.delete(list(relation_delete_ids))
+                        # Empty relations are deleted alongside graph edges later
+                        continue
+                    storage_key = make_relation_chunk_key(*edge_tuple)
+                    relation_upsert_payload[storage_key] = {
+                        "chunk_ids": remaining,
+                        "count": len(remaining),
+                        "updated_at": current_time,
+                    }
+
                 if relation_upsert_payload:
                     await self.relation_chunks.upsert(relation_upsert_payload)
```
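The change above stops deleting empty entity/relation chunk entries inline and instead skips them during payload construction, leaving cleanup to the later graph-node deletion step. A minimal sketch of the new payload-building logic (the function name `build_upsert_payload` is hypothetical, not part of LightRAG):

```python
def build_upsert_payload(chunk_updates: dict[str, list[str]], current_time: str) -> dict:
    """Keep only entities that still reference at least one chunk."""
    payload = {}
    for entity_name, remaining in chunk_updates.items():
        if not remaining:
            # Empty entities are handled later by the graph-node deletion path
            continue
        payload[entity_name] = {
            "chunk_ids": remaining,
            "count": len(remaining),
            "updated_at": current_time,
        }
    return payload


updates = {"Alice": ["chunk-1", "chunk-2"], "Bob": []}
payload = build_upsert_payload(updates, "2025-01-01T00:00:00Z")
print(sorted(payload))            # ['Alice']
print(payload["Alice"]["count"])  # 2
```

Entities with an empty `remaining` list never reach the upsert call, so no separate delete round-trip is needed here.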
```diff
@@ -3296,7 +3289,7 @@ class LightRAG:
         # 6. Delete relationships that have no remaining sources
         if relationships_to_delete:
             try:
-                # Delete from vector database
+                # Delete from relation vdb
                 rel_ids_to_delete = []
                 for src, tgt in relationships_to_delete:
                     rel_ids_to_delete.extend(
```
```diff
@@ -3333,15 +3326,16 @@ class LightRAG:
         # 7. Delete entities that have no remaining sources
         if entities_to_delete:
             try:
+                # Batch get all edges for entities to avoid N+1 query problem
+                nodes_edges_dict = await self.chunk_entity_relation_graph.get_nodes_edges_batch(
+                    list(entities_to_delete)
+                )
+
                 # Debug: Check and log all edges before deleting nodes
                 edges_to_delete = set()
                 edges_still_exist = 0
-                for entity in entities_to_delete:
-                    edges = (
-                        await self.chunk_entity_relation_graph.get_node_edges(
-                            entity
-                        )
-                    )
+                for entity, edges in nodes_edges_dict.items():
                     if edges:
                         for src, tgt in edges:
                             # Normalize edge representation (sorted for consistency)
```
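The hunk above replaces one storage call per entity with a single batched call, the classic N+1 query fix. A self-contained sketch of why this matters (the `ToyGraph` class is a hypothetical stand-in for the graph storage interface, not LightRAG code):

```python
import asyncio


class ToyGraph:
    """Hypothetical stand-in for a graph storage backend."""

    def __init__(self, edges: dict[str, list[tuple[str, str]]]):
        self._edges = edges
        self.calls = 0  # count storage round-trips

    async def get_node_edges(self, node: str) -> list[tuple[str, str]]:
        self.calls += 1  # one round-trip per node
        return self._edges.get(node, [])

    async def get_nodes_edges_batch(
        self, nodes: list[str]
    ) -> dict[str, list[tuple[str, str]]]:
        self.calls += 1  # one round-trip for the whole batch
        return {n: self._edges.get(n, []) for n in nodes}


async def main() -> tuple[int, int]:
    graph = ToyGraph({"a": [("a", "b")], "b": [("b", "a")], "c": []})

    # N+1 pattern: one storage call per entity
    for node in ["a", "b", "c"]:
        await graph.get_node_edges(node)
    n_plus_one_calls = graph.calls

    graph.calls = 0
    # Batched pattern: a single call covers all entities
    await graph.get_nodes_edges_batch(["a", "b", "c"])
    return n_plus_one_calls, graph.calls


print(asyncio.run(main()))  # (3, 1)
```

With a real network-backed store, collapsing N round-trips into one is where the latency win comes from.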
```diff
@@ -3364,6 +3358,7 @@ class LightRAG:
                                     f"Edge still exists: {src} <-- {tgt}"
                                 )
                                 edges_still_exist += 1

                 if edges_still_exist:
                     logger.warning(
                         f"⚠️ {edges_still_exist} entities still has edges before deletion"
```
```diff
@@ -3399,7 +3394,7 @@ class LightRAG:
                     list(entities_to_delete)
                 )

-                # Delete from vector database
+                # Delete from vector vdb
                 entity_vdb_ids = [
                     compute_mdhash_id(entity, prefix="ent-")
                     for entity in entities_to_delete
```
```diff
@@ -9,12 +9,26 @@ from argparse import ArgumentParser, Namespace
 import argparse
 import json
 from dataclasses import asdict, dataclass, field
-from typing import Any, ClassVar, List
+from typing import Any, ClassVar, List, get_args, get_origin

 from lightrag.utils import get_env_value
 from lightrag.constants import DEFAULT_TEMPERATURE


+def _resolve_optional_type(field_type: Any) -> Any:
+    """Return the concrete type for Optional/Union annotations."""
+    origin = get_origin(field_type)
+    if origin in (list, dict, tuple):
+        return field_type
+
+    args = get_args(field_type)
+    if args:
+        non_none_args = [arg for arg in args if arg is not type(None)]
+        if len(non_none_args) == 1:
+            return non_none_args[0]
+    return field_type
+
+
 # =============================================================================
 # BindingOptions Base Class
 # =============================================================================
```
```diff
@@ -177,9 +191,13 @@ class BindingOptions:
                     help=arg_item["help"],
                 )
             else:
+                resolved_type = arg_item["type"]
+                if resolved_type is not None:
+                    resolved_type = _resolve_optional_type(resolved_type)
+
                 group.add_argument(
                     f"--{arg_item['argname']}",
-                    type=arg_item["type"],
+                    type=resolved_type,
                     default=get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS),
                     help=arg_item["help"],
                 )
```
```diff
@@ -210,7 +228,7 @@ class BindingOptions:
             argdef = {
                 "argname": f"{args_prefix}-{field.name}",
                 "env_name": f"{env_var_prefix}{field.name.upper()}",
-                "type": field.type,
+                "type": _resolve_optional_type(field.type),
                 "default": default_value,
                 "help": f"{cls._binding_name} -- " + help.get(field.name, ""),
             }
```
```diff
@@ -454,6 +472,39 @@ class OllamaLLMOptions(_OllamaOptionsMixin, BindingOptions):
     _binding_name: ClassVar[str] = "ollama_llm"


+@dataclass
+class GeminiLLMOptions(BindingOptions):
+    """Options for Google Gemini models."""
+
+    _binding_name: ClassVar[str] = "gemini_llm"
+
+    temperature: float = DEFAULT_TEMPERATURE
+    top_p: float = 0.95
+    top_k: int = 40
+    max_output_tokens: int | None = None
+    candidate_count: int = 1
+    presence_penalty: float = 0.0
+    frequency_penalty: float = 0.0
+    stop_sequences: List[str] = field(default_factory=list)
+    seed: int | None = None
+    thinking_config: dict | None = None
+    safety_settings: dict | None = None
+
+    _help: ClassVar[dict[str, str]] = {
+        "temperature": "Controls randomness (0.0-2.0, higher = more creative)",
+        "top_p": "Nucleus sampling parameter (0.0-1.0)",
+        "top_k": "Limits sampling to the top K tokens (1 disables the limit)",
+        "max_output_tokens": "Maximum tokens generated in the response",
+        "candidate_count": "Number of candidates returned per request",
+        "presence_penalty": "Penalty for token presence (-2.0 to 2.0)",
+        "frequency_penalty": "Penalty for token frequency (-2.0 to 2.0)",
+        "stop_sequences": "Stop sequences (JSON array of strings, e.g., '[\"END\"]')",
+        "seed": "Random seed for reproducible generation (leave empty for random)",
+        "thinking_config": "Thinking configuration (JSON dict, e.g., '{\"thinking_budget\": 1024}' or '{\"include_thoughts\": true}')",
+        "safety_settings": "JSON object with Gemini safety settings overrides",
+    }
+
+
 # =============================================================================
 # Binding Options for OpenAI
 # =============================================================================
```
lightrag/llm/gemini.py — new file, 422 lines (@@ -0,0 +1,422 @@)

```python
"""
Gemini LLM binding for LightRAG.

This module provides asynchronous helpers that adapt Google's Gemini models
to the same interface used by the rest of the LightRAG LLM bindings. The
implementation mirrors the OpenAI helpers while relying on the official
``google-genai`` client under the hood.
"""

from __future__ import annotations

import asyncio
import logging
import os
from collections.abc import AsyncIterator
from functools import lru_cache
from typing import Any

from lightrag.utils import logger, remove_think_tags, safe_unicode_decode

import pipmaster as pm

# Install the Google Gemini client on demand
if not pm.is_installed("google-genai"):
    pm.install("google-genai")

from google import genai  # type: ignore
from google.genai import types  # type: ignore

DEFAULT_GEMINI_ENDPOINT = "https://generativelanguage.googleapis.com"

LOG = logging.getLogger(__name__)


@lru_cache(maxsize=8)
def _get_gemini_client(
    api_key: str, base_url: str | None, timeout: int | None = None
) -> genai.Client:
    """
    Create (or fetch cached) Gemini client.

    Args:
        api_key: Google Gemini API key.
        base_url: Optional custom API endpoint.
        timeout: Optional request timeout in milliseconds.

    Returns:
        genai.Client: Configured Gemini client instance.
    """
    client_kwargs: dict[str, Any] = {"api_key": api_key}

    if base_url and base_url != DEFAULT_GEMINI_ENDPOINT or timeout is not None:
        try:
            http_options_kwargs = {}
            if base_url and base_url != DEFAULT_GEMINI_ENDPOINT:
                http_options_kwargs["api_endpoint"] = base_url
            if timeout is not None:
                http_options_kwargs["timeout"] = timeout

            client_kwargs["http_options"] = types.HttpOptions(**http_options_kwargs)
        except Exception as exc:  # pragma: no cover - defensive
            LOG.warning("Failed to apply custom Gemini http_options: %s", exc)

    try:
        return genai.Client(**client_kwargs)
    except TypeError:
        # Older google-genai releases don't accept http_options; retry without it.
        client_kwargs.pop("http_options", None)
        return genai.Client(**client_kwargs)


def _ensure_api_key(api_key: str | None) -> str:
    key = api_key or os.getenv("LLM_BINDING_API_KEY") or os.getenv("GEMINI_API_KEY")
    if not key:
        raise ValueError(
            "Gemini API key not provided. "
            "Set LLM_BINDING_API_KEY or GEMINI_API_KEY in the environment."
        )
    return key


def _build_generation_config(
    base_config: dict[str, Any] | None,
    system_prompt: str | None,
    keyword_extraction: bool,
) -> types.GenerateContentConfig | None:
    config_data = dict(base_config or {})

    if system_prompt:
        if config_data.get("system_instruction"):
            config_data["system_instruction"] = (
                f"{config_data['system_instruction']}\n{system_prompt}"
            )
        else:
            config_data["system_instruction"] = system_prompt

    if keyword_extraction and not config_data.get("response_mime_type"):
        config_data["response_mime_type"] = "application/json"

    # Remove entries that are explicitly set to None to avoid type errors
    sanitized = {
        key: value
        for key, value in config_data.items()
        if value is not None and value != ""
    }

    if not sanitized:
        return None

    return types.GenerateContentConfig(**sanitized)


def _format_history_messages(history_messages: list[dict[str, Any]] | None) -> str:
    if not history_messages:
        return ""

    history_lines: list[str] = []
    for message in history_messages:
        role = message.get("role", "user")
        content = message.get("content", "")
        history_lines.append(f"[{role}] {content}")

    return "\n".join(history_lines)


def _extract_response_text(
    response: Any, extract_thoughts: bool = False
) -> tuple[str, str]:
    """
    Extract text content from Gemini response, separating regular content from thoughts.

    Args:
        response: Gemini API response object
        extract_thoughts: Whether to extract thought content separately

    Returns:
        Tuple of (regular_text, thought_text)
    """
    candidates = getattr(response, "candidates", None)
    if not candidates:
        return ("", "")

    regular_parts: list[str] = []
    thought_parts: list[str] = []

    for candidate in candidates:
        if not getattr(candidate, "content", None):
            continue
        # Use 'or []' to handle None values from parts attribute
        for part in getattr(candidate.content, "parts", None) or []:
            text = getattr(part, "text", None)
            if not text:
                continue

            # Check if this part is thought content using the 'thought' attribute
            is_thought = getattr(part, "thought", False)

            if is_thought and extract_thoughts:
                thought_parts.append(text)
            elif not is_thought:
                regular_parts.append(text)

    return ("\n".join(regular_parts), "\n".join(thought_parts))


async def gemini_complete_if_cache(
    model: str,
    prompt: str,
    system_prompt: str | None = None,
    history_messages: list[dict[str, Any]] | None = None,
    enable_cot: bool = False,
    base_url: str | None = None,
    api_key: str | None = None,
    token_tracker: Any | None = None,
    stream: bool | None = None,
    keyword_extraction: bool = False,
    generation_config: dict[str, Any] | None = None,
    timeout: int | None = None,
    **_: Any,
) -> str | AsyncIterator[str]:
    """
    Complete a prompt using Gemini's API with Chain of Thought (COT) support.

    This function supports automatic integration of reasoning content from Gemini models
    that provide Chain of Thought capabilities via the thinking_config API feature.

    COT Integration:
    - When enable_cot=True: Thought content is wrapped in <think>...</think> tags
    - When enable_cot=False: Thought content is filtered out, only regular content returned
    - Thought content is identified by the 'thought' attribute on response parts
    - Requires thinking_config to be enabled in generation_config for API to return thoughts

    Args:
        model: The Gemini model to use.
        prompt: The prompt to complete.
        system_prompt: Optional system prompt to include.
        history_messages: Optional list of previous messages in the conversation.
        api_key: Optional Gemini API key. If None, uses environment variable.
        base_url: Optional custom API endpoint.
        generation_config: Optional generation configuration dict.
        keyword_extraction: Whether to use JSON response format.
        token_tracker: Optional token usage tracker for monitoring API usage.
        stream: Whether to stream the response.
        hashing_kv: Storage interface (for interface parity with other bindings).
        enable_cot: Whether to include Chain of Thought content in the response.
        timeout: Request timeout in seconds (will be converted to milliseconds for Gemini API).
        **_: Additional keyword arguments (ignored).

    Returns:
        The completed text (with COT content if enable_cot=True) or an async iterator
        of text chunks if streaming. COT content is wrapped in <think>...</think> tags.

    Raises:
        RuntimeError: If the response from Gemini is empty.
        ValueError: If API key is not provided or configured.
    """
    loop = asyncio.get_running_loop()

    key = _ensure_api_key(api_key)
    # Convert timeout from seconds to milliseconds for Gemini API
    timeout_ms = timeout * 1000 if timeout else None
    client = _get_gemini_client(key, base_url, timeout_ms)

    history_block = _format_history_messages(history_messages)
    prompt_sections = []
    if history_block:
        prompt_sections.append(history_block)
    prompt_sections.append(f"[user] {prompt}")
    combined_prompt = "\n".join(prompt_sections)

    config_obj = _build_generation_config(
        generation_config,
        system_prompt=system_prompt,
        keyword_extraction=keyword_extraction,
    )

    request_kwargs: dict[str, Any] = {
        "model": model,
        "contents": [combined_prompt],
    }
    if config_obj is not None:
        request_kwargs["config"] = config_obj

    def _call_model():
        return client.models.generate_content(**request_kwargs)

    if stream:
        queue: asyncio.Queue[Any] = asyncio.Queue()
        usage_container: dict[str, Any] = {}

        def _stream_model() -> None:
            # COT state tracking for streaming
            cot_active = False
            cot_started = False
            initial_content_seen = False

            try:
                stream_kwargs = dict(request_kwargs)
                stream_iterator = client.models.generate_content_stream(**stream_kwargs)
                for chunk in stream_iterator:
                    usage = getattr(chunk, "usage_metadata", None)
                    if usage is not None:
                        usage_container["usage"] = usage

                    # Extract both regular and thought content
                    regular_text, thought_text = _extract_response_text(
                        chunk, extract_thoughts=True
                    )

                    if enable_cot:
                        # Process regular content
                        if regular_text:
                            if not initial_content_seen:
                                initial_content_seen = True

                            # Close COT section if it was active
                            if cot_active:
                                loop.call_soon_threadsafe(queue.put_nowait, "</think>")
                                cot_active = False

                            # Send regular content
                            loop.call_soon_threadsafe(queue.put_nowait, regular_text)

                        # Process thought content
                        if thought_text:
                            if not initial_content_seen and not cot_started:
                                # Start COT section
                                loop.call_soon_threadsafe(queue.put_nowait, "<think>")
                                cot_active = True
                                cot_started = True

                            # Send thought content if COT is active
                            if cot_active:
                                loop.call_soon_threadsafe(
                                    queue.put_nowait, thought_text
                                )
                    else:
                        # COT disabled - only send regular content
                        if regular_text:
                            loop.call_soon_threadsafe(queue.put_nowait, regular_text)

                # Ensure COT is properly closed if still active
                if cot_active:
                    loop.call_soon_threadsafe(queue.put_nowait, "</think>")

                loop.call_soon_threadsafe(queue.put_nowait, None)
            except Exception as exc:  # pragma: no cover - surface runtime issues
                # Try to close COT tag before reporting error
                if cot_active:
                    try:
                        loop.call_soon_threadsafe(queue.put_nowait, "</think>")
                    except Exception:
                        pass
                loop.call_soon_threadsafe(queue.put_nowait, exc)

        loop.run_in_executor(None, _stream_model)

        async def _async_stream() -> AsyncIterator[str]:
            try:
                while True:
                    item = await queue.get()
                    if item is None:
                        break
                    if isinstance(item, Exception):
                        raise item

                    chunk_text = str(item)
                    if "\\u" in chunk_text:
                        chunk_text = safe_unicode_decode(chunk_text.encode("utf-8"))

                    # Yield the chunk directly without filtering
                    # COT filtering is already handled in _stream_model()
                    yield chunk_text
            finally:
                usage = usage_container.get("usage")
                if token_tracker and usage:
                    token_tracker.add_usage(
                        {
                            "prompt_tokens": getattr(usage, "prompt_token_count", 0),
                            "completion_tokens": getattr(
                                usage, "candidates_token_count", 0
                            ),
                            "total_tokens": getattr(usage, "total_token_count", 0),
                        }
                    )

        return _async_stream()

    response = await asyncio.to_thread(_call_model)

    # Extract both regular text and thought text
    regular_text, thought_text = _extract_response_text(response, extract_thoughts=True)

    # Apply COT filtering logic based on enable_cot parameter
    if enable_cot:
        # Include thought content wrapped in <think> tags
        if thought_text and thought_text.strip():
            if not regular_text or regular_text.strip() == "":
                # Only thought content available
                final_text = f"<think>{thought_text}</think>"
            else:
                # Both content types present: prepend thought to regular content
                final_text = f"<think>{thought_text}</think>{regular_text}"
        else:
            # No thought content, use regular content only
            final_text = regular_text or ""
    else:
        # Filter out thought content, return only regular content
        final_text = regular_text or ""

    if not final_text:
        raise RuntimeError("Gemini response did not contain any text content.")

    if "\\u" in final_text:
        final_text = safe_unicode_decode(final_text.encode("utf-8"))

    final_text = remove_think_tags(final_text)

    usage = getattr(response, "usage_metadata", None)
    if token_tracker and usage:
        token_tracker.add_usage(
            {
                "prompt_tokens": getattr(usage, "prompt_token_count", 0),
                "completion_tokens": getattr(usage, "candidates_token_count", 0),
                "total_tokens": getattr(usage, "total_token_count", 0),
            }
        )

    logger.debug("Gemini response length: %s", len(final_text))
    return final_text


async def gemini_model_complete(
    prompt: str,
    system_prompt: str | None = None,
    history_messages: list[dict[str, Any]] | None = None,
    keyword_extraction: bool = False,
    **kwargs: Any,
) -> str | AsyncIterator[str]:
    hashing_kv = kwargs.get("hashing_kv")
    model_name = None
    if hashing_kv is not None:
        model_name = hashing_kv.global_config.get("llm_model_name")
    if model_name is None:
        model_name = kwargs.pop("model_name", None)
    if model_name is None:
        raise ValueError("Gemini model name not provided in configuration.")

    return await gemini_complete_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        keyword_extraction=keyword_extraction,
        **kwargs,
    )


__all__ = [
    "gemini_complete_if_cache",
    "gemini_model_complete",
]
```
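Two pieces of the new binding are pure string manipulation and can be exercised without the `google-genai` client or an API key: flattening chat history into a single prompt, and assembling the final text with `<think>` tags in the non-streaming COT path. The sketch below reimplements both standalone (function names `format_history` and `assemble_final_text` are illustrative, not the module's actual names):

```python
def format_history(history_messages):
    """Mirror of _format_history_messages: flatten chat turns into labeled lines."""
    if not history_messages:
        return ""
    return "\n".join(
        f"[{m.get('role', 'user')}] {m.get('content', '')}" for m in history_messages
    )


def assemble_final_text(regular_text: str, thought_text: str, enable_cot: bool) -> str:
    """Mirror of the non-streaming COT assembly in gemini_complete_if_cache."""
    if enable_cot and thought_text.strip():
        if not regular_text.strip():
            # Only thought content available
            return f"<think>{thought_text}</think>"
        # Prepend thought content to the regular answer
        return f"<think>{thought_text}</think>{regular_text}"
    # COT disabled or no thoughts: regular content only
    return regular_text or ""


history = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}]
print(format_history(history))
# [user] Hi
# [assistant] Hello
print(assemble_final_text("Answer", "Reasoning", enable_cot=True))
# <think>Reasoning</think>Answer
print(assemble_final_text("Answer", "Reasoning", enable_cot=False))
# Answer
```

Note that in the actual module the assembled text is then passed through `remove_think_tags` before being returned, so the tags primarily matter for the streaming path and for downstream consumers that parse them.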
```diff
@@ -138,6 +138,9 @@ async def openai_complete_if_cache(
     base_url: str | None = None,
     api_key: str | None = None,
     token_tracker: Any | None = None,
+    keyword_extraction: bool = False,  # Will be removed from kwargs before passing to OpenAI
+    stream: bool | None = None,
+    timeout: int | None = None,
     **kwargs: Any,
 ) -> str:
     """Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration.
```
```diff
@@ -172,8 +175,9 @@ async def openai_complete_if_cache(
         - openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
           These will be passed to the client constructor but will be overridden by
           explicit parameters (api_key, base_url).
-        - hashing_kv: Will be removed from kwargs before passing to OpenAI.
         - keyword_extraction: Will be removed from kwargs before passing to OpenAI.
+        - stream: Whether to stream the response. Default is False.
+        - timeout: Request timeout in seconds. Default is None.

     Returns:
         The completed text (with integrated COT content if available) or an async iterator
```
```diff
@@ -1795,7 +1795,7 @@ def normalize_extracted_info(name: str, remove_inner_quotes=False) -> str:
     - Filter out short numeric-only text (length < 3 and only digits/dots)
     - remove_inner_quotes = True
         remove Chinese quotes
-        remove English queotes in and around chinese
+        remove English quotes in and around chinese
         Convert non-breaking spaces to regular spaces
         Convert narrow non-breaking spaces after non-digits to regular spaces
```
```diff
@@ -24,6 +24,7 @@ dependencies = [
     "aiohttp",
     "configparser",
     "future",
+    "google-genai>=1.0.0,<2.0.0",
     "json_repair",
     "nano-vectordb",
     "networkx",
```
```diff
@@ -59,6 +60,7 @@ api = [
     "tenacity",
     "tiktoken",
     "xlsxwriter>=3.1.0",
+    "google-genai>=1.0.0,<2.0.0",
     # API-specific dependencies
     "aiofiles",
     "ascii_colors",
```
```diff
@@ -107,6 +109,7 @@ offline-llm = [
     "aioboto3>=12.0.0,<16.0.0",
     "voyageai>=0.2.0,<1.0.0",
     "llama-index>=0.9.0,<1.0.0",
+    "google-genai>=1.0.0,<2.0.0",
 ]

 offline = [
```
```diff
@@ -10,6 +10,7 @@
 # LLM provider dependencies (with version constraints matching pyproject.toml)
 aioboto3>=12.0.0,<16.0.0
 anthropic>=0.18.0,<1.0.0
+google-genai>=1.0.0,<2.0.0
 llama-index>=0.9.0,<1.0.0
 ollama>=0.1.0,<1.0.0
 openai>=1.0.0,<3.0.0
```
```diff
@@ -13,6 +13,7 @@ anthropic>=0.18.0,<1.0.0

 # Storage backend dependencies
 asyncpg>=0.29.0,<1.0.0
+google-genai>=1.0.0,<2.0.0

 # Document processing dependencies
 llama-index>=0.9.0,<1.0.0
```
File diff suppressed because it is too large.