diff --git a/.gitignore b/.gitignore
index 9598a6fe..8a5059c8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -67,8 +67,8 @@ download_models_hf.py
# Frontend build output (built during PyPI release)
/lightrag/api/webui/
-# unit-test files
-test_*
+# temporary test files in project root
+/test_*
# Cline files
memory-bank
diff --git a/README-zh.md b/README-zh.md
index c90889dd..3032cdb6 100644
--- a/README-zh.md
+++ b/README-zh.md
@@ -53,7 +53,7 @@
## 🎉 新闻
-- [x] [2025.11.05]🎯📢添加**基于RAGAS的**LightRAG评估框架。
+- [x] [2025.11.05]🎯📢添加**基于RAGAS的**评估框架和**Langfuse**可观测性支持。
- [x] [2025.10.22]🎯📢消除处理**大规模数据集**的瓶颈。
- [x] [2025.09.15]🎯📢显著提升**小型LLM**(如Qwen3-30B-A3B)的知识图谱提取准确性。
- [x] [2025.08.29]🎯📢现已支持**Reranker**,显著提升混合查询性能。
@@ -1463,6 +1463,50 @@ LightRAG服务器提供全面的知识图谱可视化功能。它支持各种重

+## Langfuse 可观测性集成
+
+Langfuse 为 OpenAI 客户端提供了直接替代方案,可自动跟踪所有 LLM 交互,使开发者能够在无需修改代码的情况下监控、调试和优化其 RAG 系统。
+
+### 安装 Langfuse 可选依赖
+
+```
+# 从 PyPI 安装(observability 扩展用于启用 Langfuse)
+pip install lightrag-hku
+pip install lightrag-hku[observability]
+
+# 或从源代码安装(可编辑开发模式)
+pip install -e .
+pip install -e ".[observability]"
+```
+
+### 配置 Langfuse 环境变量
+
+修改 .env 文件:
+
+```
+## Langfuse 可观测性(可选)
+# LLM 可观测性和追踪平台
+# 安装命令: pip install lightrag-hku[observability]
+# 注册地址: https://cloud.langfuse.com 或自托管部署
+LANGFUSE_SECRET_KEY=""
+LANGFUSE_PUBLIC_KEY=""
+LANGFUSE_HOST="https://cloud.langfuse.com" # 或您的自托管实例地址
+LANGFUSE_ENABLE_TRACE=true
+```
+
+### Langfuse 使用说明
+
+安装并配置完成后,Langfuse 会自动追踪所有 OpenAI LLM 调用。Langfuse 仪表板功能包括:
+
+- **追踪**:查看完整的 LLM 调用链
+- **分析**:Token 使用量、延迟、成本指标
+- **调试**:检查提示词和响应内容
+- **评估**:比较模型输出结果
+- **监控**:实时告警功能
+
+### 重要提示
+
+**注意**:LightRAG 目前仅把 OpenAI 兼容的 API 调用接入了 Langfuse。Ollama、Azure 和 AWS Bedrock 等 API 还无法使用 Langfuse 的可观测性功能。
+
## RAGAS评估
**RAGAS**(Retrieval Augmented Generation Assessment,检索增强生成评估)是一个使用LLM对RAG系统进行无参考评估的框架。我们提供了基于RAGAS的评估脚本。详细信息请参阅[基于RAGAS的评估框架](lightrag/evaluation/README.md)。
diff --git a/README.md b/README.md
index 559cd0ef..7a366c05 100644
--- a/README.md
+++ b/README.md
@@ -51,7 +51,7 @@
---
## 🎉 News
-- [x] [2025.11.05]🎯📢Add **RAGAS-based** Evaluation Framework for LightRAG.
+- [x] [2025.11.05]🎯📢Add **RAGAS-based** Evaluation Framework and **Langfuse** observability for LightRAG.
- [x] [2025.10.22]🎯📢Eliminate bottlenecks in processing **large-scale datasets**.
- [x] [2025.09.15]🎯📢Significantly enhances KG extraction accuracy for **small LLMs** like Qwen3-30B-A3B.
- [x] [2025.08.29]🎯📢**Reranker** is supported now , significantly boosting performance for mixed queries.
@@ -1543,6 +1543,50 @@ The LightRAG Server offers a comprehensive knowledge graph visualization feature

+## Langfuse Observability Integration
+
+Langfuse provides a drop-in replacement for the OpenAI client that automatically tracks all LLM interactions, enabling developers to monitor, debug, and optimize their RAG systems without code changes.
+
+### Install the Langfuse Optional Dependency
+
+```
+# Install from PyPI (the observability extra adds Langfuse support)
+pip install lightrag-hku
+pip install lightrag-hku[observability]
+
+# Or install from source code in editable (development) mode
+pip install -e .
+pip install -e ".[observability]"
+```
+
+### Configure Langfuse Environment Variables
+
+Modify the `.env` file:
+
+```
+## Langfuse Observability (Optional)
+# LLM observability and tracing platform
+# Install with: pip install lightrag-hku[observability]
+# Sign up at: https://cloud.langfuse.com or self-host
+LANGFUSE_SECRET_KEY=""
+LANGFUSE_PUBLIC_KEY=""
+LANGFUSE_HOST="https://cloud.langfuse.com" # or your self-hosted instance
+LANGFUSE_ENABLE_TRACE=true
+```
+
+### Langfuse Usage
+
+Once installed and configured, Langfuse automatically traces all OpenAI LLM calls. Langfuse dashboard features include:
+
+- **Tracing**: View complete LLM call chains
+- **Analytics**: Token usage, latency, cost metrics
+- **Debugging**: Inspect prompts and responses
+- **Evaluation**: Compare model outputs
+- **Monitoring**: Real-time alerting
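+
+Under the hood, the integration relies on Langfuse's drop-in replacement for the OpenAI client. The sketch below is for illustration only: LightRAG performs the equivalent wiring internally, so no application code changes are required, and the model name is just a placeholder.
+
+```python
+# Illustration of the Langfuse drop-in pattern (not required when using LightRAG).
+# Assumes OPENAI_API_KEY and the LANGFUSE_* variables above are set.
+import asyncio
+
+from langfuse.openai import AsyncOpenAI  # instead of: from openai import AsyncOpenAI
+
+
+async def main() -> None:
+    client = AsyncOpenAI()  # same constructor and call signatures as the stock client
+    response = await client.chat.completions.create(
+        model="gpt-4o-mini",  # placeholder model name
+        messages=[{"role": "user", "content": "Say hello"}],
+    )
+    # The request/response pair now shows up as a trace in the Langfuse dashboard.
+    print(response.choices[0].message.content)
+
+
+asyncio.run(main())
+```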
+
+### Important Notice
+
+**Note**: LightRAG currently integrates only OpenAI-compatible API calls with Langfuse. Other bindings such as Ollama, Azure OpenAI, and AWS Bedrock are not yet covered by Langfuse observability.
+
## RAGAS-based Evaluation
**RAGAS** (Retrieval Augmented Generation Assessment) is a framework for reference-free evaluation of RAG systems using LLMs. There is an evaluation script based on RAGAS. For detailed information, please refer to [RAGAS-based Evaluation Framework](lightrag/evaluation/README.md).
diff --git a/env.example b/env.example
index b68280d2..1961b4a4 100644
--- a/env.example
+++ b/env.example
@@ -170,7 +170,7 @@ MAX_PARALLEL_INSERT=2
###########################################################################
### LLM Configuration
-### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock
+### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock, gemini
### LLM_BINDING_HOST: host only for Ollama, endpoint for other LLM service
###########################################################################
### LLM request timeout setting for all llm (0 means no timeout for Ollma)
@@ -191,6 +191,15 @@ LLM_BINDING_API_KEY=your_api_key
# LLM_BINDING_API_KEY=your_api_key
# LLM_BINDING=openai
+### Gemini example
+# LLM_BINDING=gemini
+# LLM_MODEL=gemini-flash-latest
+# LLM_BINDING_API_KEY=your_gemini_api_key
+# LLM_BINDING_HOST=https://generativelanguage.googleapis.com
+GEMINI_LLM_THINKING_CONFIG='{"thinking_budget": 0, "include_thoughts": false}'
+# GEMINI_LLM_MAX_OUTPUT_TOKENS=9000
+# GEMINI_LLM_TEMPERATURE=0.7
+
### OpenAI Compatible API Specific Parameters
### Increased temperature values may mitigate infinite inference loops in certain LLM, such as Qwen3-30B.
# OPENAI_LLM_TEMPERATURE=0.9
diff --git a/examples/lightrag_gemini_demo.py b/examples/lightrag_gemini_demo.py
deleted file mode 100644
index cd2bb579..00000000
--- a/examples/lightrag_gemini_demo.py
+++ /dev/null
@@ -1,105 +0,0 @@
-# pip install -q -U google-genai to use gemini as a client
-
-import os
-import numpy as np
-from google import genai
-from google.genai import types
-from dotenv import load_dotenv
-from lightrag.utils import EmbeddingFunc
-from lightrag import LightRAG, QueryParam
-from sentence_transformers import SentenceTransformer
-from lightrag.kg.shared_storage import initialize_pipeline_status
-
-import asyncio
-import nest_asyncio
-
-# Apply nest_asyncio to solve event loop issues
-nest_asyncio.apply()
-
-load_dotenv()
-gemini_api_key = os.getenv("GEMINI_API_KEY")
-
-WORKING_DIR = "./dickens"
-
-if os.path.exists(WORKING_DIR):
- import shutil
-
- shutil.rmtree(WORKING_DIR)
-
-os.mkdir(WORKING_DIR)
-
-
-async def llm_model_func(
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
- # 1. Initialize the GenAI Client with your Gemini API Key
- client = genai.Client(api_key=gemini_api_key)
-
- # 2. Combine prompts: system prompt, history, and user prompt
- if history_messages is None:
- history_messages = []
-
- combined_prompt = ""
- if system_prompt:
- combined_prompt += f"{system_prompt}\n"
-
- for msg in history_messages:
- # Each msg is expected to be a dict: {"role": "...", "content": "..."}
- combined_prompt += f"{msg['role']}: {msg['content']}\n"
-
- # Finally, add the new user prompt
- combined_prompt += f"user: {prompt}"
-
- # 3. Call the Gemini model
- response = client.models.generate_content(
- model="gemini-1.5-flash",
- contents=[combined_prompt],
- config=types.GenerateContentConfig(max_output_tokens=500, temperature=0.1),
- )
-
- # 4. Return the response text
- return response.text
-
-
-async def embedding_func(texts: list[str]) -> np.ndarray:
- model = SentenceTransformer("all-MiniLM-L6-v2")
- embeddings = model.encode(texts, convert_to_numpy=True)
- return embeddings
-
-
-async def initialize_rag():
- rag = LightRAG(
- working_dir=WORKING_DIR,
- llm_model_func=llm_model_func,
- embedding_func=EmbeddingFunc(
- embedding_dim=384,
- max_token_size=8192,
- func=embedding_func,
- ),
- )
-
- await rag.initialize_storages()
- await initialize_pipeline_status()
-
- return rag
-
-
-def main():
- # Initialize RAG instance
- rag = asyncio.run(initialize_rag())
- file_path = "story.txt"
- with open(file_path, "r") as file:
- text = file.read()
-
- rag.insert(text)
-
- response = rag.query(
- query="What is the main theme of the story?",
- param=QueryParam(mode="hybrid", top_k=5, response_type="single line"),
- )
-
- print(response)
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/lightrag_gemini_demo_no_tiktoken.py b/examples/lightrag_gemini_demo_no_tiktoken.py
deleted file mode 100644
index 92c74201..00000000
--- a/examples/lightrag_gemini_demo_no_tiktoken.py
+++ /dev/null
@@ -1,230 +0,0 @@
-# pip install -q -U google-genai to use gemini as a client
-
-import os
-from typing import Optional
-import dataclasses
-from pathlib import Path
-import hashlib
-import numpy as np
-from google import genai
-from google.genai import types
-from dotenv import load_dotenv
-from lightrag.utils import EmbeddingFunc, Tokenizer
-from lightrag import LightRAG, QueryParam
-from sentence_transformers import SentenceTransformer
-from lightrag.kg.shared_storage import initialize_pipeline_status
-import sentencepiece as spm
-import requests
-
-import asyncio
-import nest_asyncio
-
-# Apply nest_asyncio to solve event loop issues
-nest_asyncio.apply()
-
-load_dotenv()
-gemini_api_key = os.getenv("GEMINI_API_KEY")
-
-WORKING_DIR = "./dickens"
-
-if os.path.exists(WORKING_DIR):
- import shutil
-
- shutil.rmtree(WORKING_DIR)
-
-os.mkdir(WORKING_DIR)
-
-
-class GemmaTokenizer(Tokenizer):
- # adapted from google-cloud-aiplatform[tokenization]
-
- @dataclasses.dataclass(frozen=True)
- class _TokenizerConfig:
- tokenizer_model_url: str
- tokenizer_model_hash: str
-
- _TOKENIZERS = {
- "google/gemma2": _TokenizerConfig(
- tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model",
- tokenizer_model_hash="61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2",
- ),
- "google/gemma3": _TokenizerConfig(
- tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/cb7c0152a369e43908e769eb09e1ce6043afe084/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
- tokenizer_model_hash="1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c",
- ),
- }
-
- def __init__(
- self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None
- ):
- # https://github.com/google/gemma_pytorch/tree/main/tokenizer
- if "1.5" in model_name or "1.0" in model_name:
- # up to gemini 1.5 gemma2 is a comparable local tokenizer
- # https://github.com/googleapis/python-aiplatform/blob/main/vertexai/tokenization/_tokenizer_loading.py
- tokenizer_name = "google/gemma2"
- else:
- # for gemini > 2.0 gemma3 was used
- tokenizer_name = "google/gemma3"
-
- file_url = self._TOKENIZERS[tokenizer_name].tokenizer_model_url
- tokenizer_model_name = file_url.rsplit("/", 1)[1]
- expected_hash = self._TOKENIZERS[tokenizer_name].tokenizer_model_hash
-
- tokenizer_dir = Path(tokenizer_dir)
- if tokenizer_dir.is_dir():
- file_path = tokenizer_dir / tokenizer_model_name
- model_data = self._maybe_load_from_cache(
- file_path=file_path, expected_hash=expected_hash
- )
- else:
- model_data = None
- if not model_data:
- model_data = self._load_from_url(
- file_url=file_url, expected_hash=expected_hash
- )
- self.save_tokenizer_to_cache(cache_path=file_path, model_data=model_data)
-
- tokenizer = spm.SentencePieceProcessor()
- tokenizer.LoadFromSerializedProto(model_data)
- super().__init__(model_name=model_name, tokenizer=tokenizer)
-
- def _is_valid_model(self, model_data: bytes, expected_hash: str) -> bool:
- """Returns true if the content is valid by checking the hash."""
- return hashlib.sha256(model_data).hexdigest() == expected_hash
-
- def _maybe_load_from_cache(self, file_path: Path, expected_hash: str) -> bytes:
- """Loads the model data from the cache path."""
- if not file_path.is_file():
- return
- with open(file_path, "rb") as f:
- content = f.read()
- if self._is_valid_model(model_data=content, expected_hash=expected_hash):
- return content
-
- # Cached file corrupted.
- self._maybe_remove_file(file_path)
-
- def _load_from_url(self, file_url: str, expected_hash: str) -> bytes:
- """Loads model bytes from the given file url."""
- resp = requests.get(file_url)
- resp.raise_for_status()
- content = resp.content
-
- if not self._is_valid_model(model_data=content, expected_hash=expected_hash):
- actual_hash = hashlib.sha256(content).hexdigest()
- raise ValueError(
- f"Downloaded model file is corrupted."
- f" Expected hash {expected_hash}. Got file hash {actual_hash}."
- )
- return content
-
- @staticmethod
- def save_tokenizer_to_cache(cache_path: Path, model_data: bytes) -> None:
- """Saves the model data to the cache path."""
- try:
- if not cache_path.is_file():
- cache_dir = cache_path.parent
- cache_dir.mkdir(parents=True, exist_ok=True)
- with open(cache_path, "wb") as f:
- f.write(model_data)
- except OSError:
- # Don't raise if we cannot write file.
- pass
-
- @staticmethod
- def _maybe_remove_file(file_path: Path) -> None:
- """Removes the file if exists."""
- if not file_path.is_file():
- return
- try:
- file_path.unlink()
- except OSError:
- # Don't raise if we cannot remove file.
- pass
-
- # def encode(self, content: str) -> list[int]:
- # return self.tokenizer.encode(content)
-
- # def decode(self, tokens: list[int]) -> str:
- # return self.tokenizer.decode(tokens)
-
-
-async def llm_model_func(
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
- # 1. Initialize the GenAI Client with your Gemini API Key
- client = genai.Client(api_key=gemini_api_key)
-
- # 2. Combine prompts: system prompt, history, and user prompt
- if history_messages is None:
- history_messages = []
-
- combined_prompt = ""
- if system_prompt:
- combined_prompt += f"{system_prompt}\n"
-
- for msg in history_messages:
- # Each msg is expected to be a dict: {"role": "...", "content": "..."}
- combined_prompt += f"{msg['role']}: {msg['content']}\n"
-
- # Finally, add the new user prompt
- combined_prompt += f"user: {prompt}"
-
- # 3. Call the Gemini model
- response = client.models.generate_content(
- model="gemini-1.5-flash",
- contents=[combined_prompt],
- config=types.GenerateContentConfig(max_output_tokens=500, temperature=0.1),
- )
-
- # 4. Return the response text
- return response.text
-
-
-async def embedding_func(texts: list[str]) -> np.ndarray:
- model = SentenceTransformer("all-MiniLM-L6-v2")
- embeddings = model.encode(texts, convert_to_numpy=True)
- return embeddings
-
-
-async def initialize_rag():
- rag = LightRAG(
- working_dir=WORKING_DIR,
- # tiktoken_model_name="gpt-4o-mini",
- tokenizer=GemmaTokenizer(
- tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"),
- model_name="gemini-2.0-flash",
- ),
- llm_model_func=llm_model_func,
- embedding_func=EmbeddingFunc(
- embedding_dim=384,
- max_token_size=8192,
- func=embedding_func,
- ),
- )
-
- await rag.initialize_storages()
- await initialize_pipeline_status()
-
- return rag
-
-
-def main():
- # Initialize RAG instance
- rag = asyncio.run(initialize_rag())
- file_path = "story.txt"
- with open(file_path, "r") as file:
- text = file.read()
-
- rag.insert(text)
-
- response = rag.query(
- query="What is the main theme of the story?",
- param=QueryParam(mode="hybrid", top_k=5, response_type="single line"),
- )
-
- print(response)
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/lightrag_gemini_track_token_demo.py b/examples/lightrag_gemini_track_token_demo.py
deleted file mode 100644
index a72fc717..00000000
--- a/examples/lightrag_gemini_track_token_demo.py
+++ /dev/null
@@ -1,151 +0,0 @@
-# pip install -q -U google-genai to use gemini as a client
-
-import os
-import asyncio
-import numpy as np
-import nest_asyncio
-from google import genai
-from google.genai import types
-from dotenv import load_dotenv
-from lightrag.utils import EmbeddingFunc
-from lightrag import LightRAG, QueryParam
-from lightrag.kg.shared_storage import initialize_pipeline_status
-from lightrag.llm.siliconcloud import siliconcloud_embedding
-from lightrag.utils import setup_logger
-from lightrag.utils import TokenTracker
-
-setup_logger("lightrag", level="DEBUG")
-
-# Apply nest_asyncio to solve event loop issues
-nest_asyncio.apply()
-
-load_dotenv()
-gemini_api_key = os.getenv("GEMINI_API_KEY")
-siliconflow_api_key = os.getenv("SILICONFLOW_API_KEY")
-
-WORKING_DIR = "./dickens"
-
-if not os.path.exists(WORKING_DIR):
- os.mkdir(WORKING_DIR)
-
-token_tracker = TokenTracker()
-
-
-async def llm_model_func(
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
- # 1. Initialize the GenAI Client with your Gemini API Key
- client = genai.Client(api_key=gemini_api_key)
-
- # 2. Combine prompts: system prompt, history, and user prompt
- if history_messages is None:
- history_messages = []
-
- combined_prompt = ""
- if system_prompt:
- combined_prompt += f"{system_prompt}\n"
-
- for msg in history_messages:
- # Each msg is expected to be a dict: {"role": "...", "content": "..."}
- combined_prompt += f"{msg['role']}: {msg['content']}\n"
-
- # Finally, add the new user prompt
- combined_prompt += f"user: {prompt}"
-
- # 3. Call the Gemini model
- response = client.models.generate_content(
- model="gemini-2.0-flash",
- contents=[combined_prompt],
- config=types.GenerateContentConfig(
- max_output_tokens=5000, temperature=0, top_k=10
- ),
- )
-
- # 4. Get token counts with null safety
- usage = getattr(response, "usage_metadata", None)
- prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
- completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
- total_tokens = getattr(usage, "total_token_count", 0) or (
- prompt_tokens + completion_tokens
- )
-
- token_counts = {
- "prompt_tokens": prompt_tokens,
- "completion_tokens": completion_tokens,
- "total_tokens": total_tokens,
- }
-
- token_tracker.add_usage(token_counts)
-
- # 5. Return the response text
- return response.text
-
-
-async def embedding_func(texts: list[str]) -> np.ndarray:
- return await siliconcloud_embedding(
- texts,
- model="BAAI/bge-m3",
- api_key=siliconflow_api_key,
- max_token_size=512,
- )
-
-
-async def initialize_rag():
- rag = LightRAG(
- working_dir=WORKING_DIR,
- entity_extract_max_gleaning=1,
- enable_llm_cache=True,
- enable_llm_cache_for_entity_extract=True,
- embedding_cache_config={"enabled": True, "similarity_threshold": 0.90},
- llm_model_func=llm_model_func,
- embedding_func=EmbeddingFunc(
- embedding_dim=1024,
- max_token_size=8192,
- func=embedding_func,
- ),
- )
-
- await rag.initialize_storages()
- await initialize_pipeline_status()
-
- return rag
-
-
-def main():
- # Initialize RAG instance
- rag = asyncio.run(initialize_rag())
-
- with open("./book.txt", "r", encoding="utf-8") as f:
- rag.insert(f.read())
-
- # Context Manager Method
- with token_tracker:
- print(
- rag.query(
- "What are the top themes in this story?", param=QueryParam(mode="naive")
- )
- )
-
- print(
- rag.query(
- "What are the top themes in this story?", param=QueryParam(mode="local")
- )
- )
-
- print(
- rag.query(
- "What are the top themes in this story?",
- param=QueryParam(mode="global"),
- )
- )
-
- print(
- rag.query(
- "What are the top themes in this story?",
- param=QueryParam(mode="hybrid"),
- )
- )
-
-
-if __name__ == "__main__":
- main()
diff --git a/lightrag/api/README.md b/lightrag/api/README.md
index bc21fac4..f62e24d3 100644
--- a/lightrag/api/README.md
+++ b/lightrag/api/README.md
@@ -50,6 +50,7 @@ LightRAG necessitates the integration of both an LLM (Large Language Model) and
* openai or openai compatible
* azure_openai
* aws_bedrock
+* gemini
It is recommended to use environment variables to configure the LightRAG Server. There is an example environment variable file named `env.example` in the root directory of the project. Please copy this file to the startup directory and rename it to `.env`. After that, you can modify the parameters related to the LLM and Embedding models in the `.env` file. It is important to note that the LightRAG Server will load the environment variables from `.env` into the system environment variables each time it starts. **LightRAG Server will prioritize the settings in the system environment variables to .env file**.
@@ -72,6 +73,8 @@ EMBEDDING_DIM=1024
# EMBEDDING_BINDING_API_KEY=your_api_key
```
+> When targeting Google Gemini, set `LLM_BINDING=gemini`, choose a model such as `LLM_MODEL=gemini-flash-latest`, and provide your Gemini key via `LLM_BINDING_API_KEY` (or `GEMINI_API_KEY`).
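+>
+> A minimal `.env` sketch for this binding (the values mirror the Gemini example in `env.example`; the host line is optional and defaults to the public endpoint):
+>
+> ```
+> LLM_BINDING=gemini
+> LLM_MODEL=gemini-flash-latest
+> LLM_BINDING_API_KEY=your_gemini_api_key
+> # LLM_BINDING_HOST=https://generativelanguage.googleapis.com
+> ```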
+
* Ollama LLM + Ollama Embedding:
```
diff --git a/lightrag/api/__init__.py b/lightrag/api/__init__.py
index cd277b6e..1623b064 100644
--- a/lightrag/api/__init__.py
+++ b/lightrag/api/__init__.py
@@ -1 +1 @@
-__api_version__ = "0250"
+__api_version__ = "0251"
diff --git a/lightrag/api/config.py b/lightrag/api/config.py
index de4fa00b..92e81253 100644
--- a/lightrag/api/config.py
+++ b/lightrag/api/config.py
@@ -8,6 +8,7 @@ import logging
from dotenv import load_dotenv
from lightrag.utils import get_env_value
from lightrag.llm.binding_options import (
+ GeminiLLMOptions,
OllamaEmbeddingOptions,
OllamaLLMOptions,
OpenAILLMOptions,
@@ -63,6 +64,9 @@ def get_default_host(binding_type: str) -> str:
"lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"),
"azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"),
"openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"),
+ "gemini": os.getenv(
+ "LLM_BINDING_HOST", "https://generativelanguage.googleapis.com"
+ ),
}
return default_hosts.get(
binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434")
@@ -226,6 +230,7 @@ def parse_args() -> argparse.Namespace:
"openai-ollama",
"azure_openai",
"aws_bedrock",
+ "gemini",
],
help="LLM binding type (default: from env or ollama)",
)
@@ -281,6 +286,16 @@ def parse_args() -> argparse.Namespace:
elif os.environ.get("LLM_BINDING") in ["openai", "azure_openai"]:
OpenAILLMOptions.add_args(parser)
+ if "--llm-binding" in sys.argv:
+ try:
+ idx = sys.argv.index("--llm-binding")
+ if idx + 1 < len(sys.argv) and sys.argv[idx + 1] == "gemini":
+ GeminiLLMOptions.add_args(parser)
+ except IndexError:
+ pass
+ elif os.environ.get("LLM_BINDING") == "gemini":
+ GeminiLLMOptions.add_args(parser)
+
args = parser.parse_args()
# convert relative path to absolute path
diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index fc1e0484..c9bb1a44 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -89,6 +89,7 @@ class LLMConfigCache:
# Initialize configurations based on binding conditions
self.openai_llm_options = None
+ self.gemini_llm_options = None
self.ollama_llm_options = None
self.ollama_embedding_options = None
@@ -99,6 +100,12 @@ class LLMConfigCache:
self.openai_llm_options = OpenAILLMOptions.options_dict(args)
logger.info(f"OpenAI LLM Options: {self.openai_llm_options}")
+ if args.llm_binding == "gemini":
+ from lightrag.llm.binding_options import GeminiLLMOptions
+
+ self.gemini_llm_options = GeminiLLMOptions.options_dict(args)
+ logger.info(f"Gemini LLM Options: {self.gemini_llm_options}")
+
# Only initialize and log Ollama LLM options when using Ollama LLM binding
if args.llm_binding == "ollama":
try:
@@ -279,6 +286,7 @@ def create_app(args):
"openai",
"azure_openai",
"aws_bedrock",
+ "gemini",
]:
raise Exception("llm binding not supported")
@@ -504,6 +512,44 @@ def create_app(args):
return optimized_azure_openai_model_complete
+ def create_optimized_gemini_llm_func(
+ config_cache: LLMConfigCache, args, llm_timeout: int
+ ):
+ """Create optimized Gemini LLM function with cached configuration"""
+
+ async def optimized_gemini_model_complete(
+ prompt,
+ system_prompt=None,
+ history_messages=None,
+ keyword_extraction=False,
+ **kwargs,
+ ) -> str:
+ from lightrag.llm.gemini import gemini_complete_if_cache
+
+ if history_messages is None:
+ history_messages = []
+
+ # Use pre-processed configuration to avoid repeated parsing
+ kwargs["timeout"] = llm_timeout
+ if (
+ config_cache.gemini_llm_options is not None
+ and "generation_config" not in kwargs
+ ):
+ kwargs["generation_config"] = dict(config_cache.gemini_llm_options)
+
+ return await gemini_complete_if_cache(
+ args.llm_model,
+ prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ api_key=args.llm_binding_api_key,
+ base_url=args.llm_binding_host,
+ keyword_extraction=keyword_extraction,
+ **kwargs,
+ )
+
+ return optimized_gemini_model_complete
+
def create_llm_model_func(binding: str):
"""
Create LLM model function based on binding type.
@@ -525,6 +571,8 @@ def create_app(args):
return create_optimized_azure_openai_llm_func(
config_cache, args, llm_timeout
)
+ elif binding == "gemini":
+ return create_optimized_gemini_llm_func(config_cache, args, llm_timeout)
else: # openai and compatible
# Use optimized function with pre-processed configuration
return create_optimized_openai_llm_func(config_cache, args, llm_timeout)
diff --git a/lightrag/base.py b/lightrag/base.py
index 3cf40136..bae0728b 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -19,7 +19,6 @@ from typing import (
from .utils import EmbeddingFunc
from .types import KnowledgeGraph
from .constants import (
- GRAPH_FIELD_SEP,
DEFAULT_TOP_K,
DEFAULT_CHUNK_TOP_K,
DEFAULT_MAX_ENTITY_TOKENS,
@@ -528,56 +527,6 @@ class BaseGraphStorage(StorageNameSpace, ABC):
result[node_id] = edges if edges is not None else []
return result
- @abstractmethod
- async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
- """Get all nodes that are associated with the given chunk_ids.
-
- Args:
- chunk_ids (list[str]): A list of chunk IDs to find associated nodes for.
-
- Returns:
- list[dict]: A list of nodes, where each node is a dictionary of its properties.
- An empty list if no matching nodes are found.
- """
-
- @abstractmethod
- async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
- """Get all edges that are associated with the given chunk_ids.
-
- Args:
- chunk_ids (list[str]): A list of chunk IDs to find associated edges for.
-
- Returns:
- list[dict]: A list of edges, where each edge is a dictionary of its properties.
- An empty list if no matching edges are found.
- """
- # Default implementation iterates through all nodes and their edges, which is inefficient.
- # This method should be overridden by subclasses for better performance.
- all_edges = []
- all_labels = await self.get_all_labels()
- processed_edges = set()
-
- for label in all_labels:
- edges = await self.get_node_edges(label)
- if edges:
- for src_id, tgt_id in edges:
- # Avoid processing the same edge twice in an undirected graph
- edge_tuple = tuple(sorted((src_id, tgt_id)))
- if edge_tuple in processed_edges:
- continue
- processed_edges.add(edge_tuple)
-
- edge = await self.get_edge(src_id, tgt_id)
- if edge and "source_id" in edge:
- source_ids = set(edge["source_id"].split(GRAPH_FIELD_SEP))
- if not source_ids.isdisjoint(chunk_ids):
- # Add source and target to the edge dict for easier processing later
- edge_with_nodes = edge.copy()
- edge_with_nodes["source"] = src_id
- edge_with_nodes["target"] = tgt_id
- all_edges.append(edge_with_nodes)
- return all_edges
-
@abstractmethod
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""Insert a new node or update an existing node in the graph.
diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py
index f48ea20f..d81c2ebd 100644
--- a/lightrag/kg/memgraph_impl.py
+++ b/lightrag/kg/memgraph_impl.py
@@ -8,7 +8,6 @@ import configparser
from ..utils import logger
from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
-from ..constants import GRAPH_FIELD_SEP
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
import pipmaster as pm
@@ -784,79 +783,6 @@ class MemgraphStorage(BaseGraphStorage):
degrees = int(src_degree) + int(trg_degree)
return degrees
- async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
- """Get all nodes that are associated with the given chunk_ids.
-
- Args:
- chunk_ids: List of chunk IDs to find associated nodes for
-
- Returns:
- list[dict]: A list of nodes, where each node is a dictionary of its properties.
- An empty list if no matching nodes are found.
- """
- if self._driver is None:
- raise RuntimeError(
- "Memgraph driver is not initialized. Call 'await initialize()' first."
- )
- workspace_label = self._get_workspace_label()
- async with self._driver.session(
- database=self._DATABASE, default_access_mode="READ"
- ) as session:
- query = f"""
- UNWIND $chunk_ids AS chunk_id
- MATCH (n:`{workspace_label}`)
- WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
- RETURN DISTINCT n
- """
- result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
- nodes = []
- async for record in result:
- node = record["n"]
- node_dict = dict(node)
- node_dict["id"] = node_dict.get("entity_id")
- nodes.append(node_dict)
- await result.consume()
- return nodes
-
- async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
- """Get all edges that are associated with the given chunk_ids.
-
- Args:
- chunk_ids: List of chunk IDs to find associated edges for
-
- Returns:
- list[dict]: A list of edges, where each edge is a dictionary of its properties.
- An empty list if no matching edges are found.
- """
- if self._driver is None:
- raise RuntimeError(
- "Memgraph driver is not initialized. Call 'await initialize()' first."
- )
- workspace_label = self._get_workspace_label()
- async with self._driver.session(
- database=self._DATABASE, default_access_mode="READ"
- ) as session:
- query = f"""
- UNWIND $chunk_ids AS chunk_id
- MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`)
- WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
- WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id
- // Ensure we only return each unique edge once by ordering the source and target
- WITH a, b, r,
- CASE WHEN source_id <= target_id THEN source_id ELSE target_id END AS ordered_source,
- CASE WHEN source_id <= target_id THEN target_id ELSE source_id END AS ordered_target
- RETURN DISTINCT ordered_source AS source, ordered_target AS target, properties(r) AS properties
- """
- result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
- edges = []
- async for record in result:
- edge_properties = record["properties"]
- edge_properties["source"] = record["source"]
- edge_properties["target"] = record["target"]
- edges.append(edge_properties)
- await result.consume()
- return edges
-
async def get_knowledge_graph(
self,
node_label: str,
diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py
index e55062f1..30452c74 100644
--- a/lightrag/kg/mongo_impl.py
+++ b/lightrag/kg/mongo_impl.py
@@ -1036,45 +1036,6 @@ class MongoGraphStorage(BaseGraphStorage):
return result
- async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
- """Get all nodes that are associated with the given chunk_ids.
-
- Args:
- chunk_ids (list[str]): A list of chunk IDs to find associated nodes for.
-
- Returns:
- list[dict]: A list of nodes, where each node is a dictionary of its properties.
- An empty list if no matching nodes are found.
- """
- if not chunk_ids:
- return []
-
- cursor = self.collection.find({"source_ids": {"$in": chunk_ids}})
- return [doc async for doc in cursor]
-
- async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
- """Get all edges that are associated with the given chunk_ids.
-
- Args:
- chunk_ids (list[str]): A list of chunk IDs to find associated edges for.
-
- Returns:
- list[dict]: A list of edges, where each edge is a dictionary of its properties.
- An empty list if no matching edges are found.
- """
- if not chunk_ids:
- return []
-
- cursor = self.edge_collection.find({"source_ids": {"$in": chunk_ids}})
-
- edges = []
- async for edge in cursor:
- edge["source"] = edge["source_node_id"]
- edge["target"] = edge["target_node_id"]
- edges.append(edge)
-
- return edges
-
#
# -------------------------------------------------------------------------
# UPSERTS
diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index 896e5973..76fa11f2 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -16,7 +16,6 @@ import logging
from ..utils import logger
from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
-from ..constants import GRAPH_FIELD_SEP
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
import pipmaster as pm
@@ -904,49 +903,6 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() # Ensure results are fully consumed
return edges_dict
- async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
- workspace_label = self._get_workspace_label()
- async with self._driver.session(
- database=self._DATABASE, default_access_mode="READ"
- ) as session:
- query = f"""
- UNWIND $chunk_ids AS chunk_id
- MATCH (n:`{workspace_label}`)
- WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
- RETURN DISTINCT n
- """
- result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
- nodes = []
- async for record in result:
- node = record["n"]
- node_dict = dict(node)
- # Add node id (entity_id) to the dictionary for easier access
- node_dict["id"] = node_dict.get("entity_id")
- nodes.append(node_dict)
- await result.consume()
- return nodes
-
- async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
- workspace_label = self._get_workspace_label()
- async with self._driver.session(
- database=self._DATABASE, default_access_mode="READ"
- ) as session:
- query = f"""
- UNWIND $chunk_ids AS chunk_id
- MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`)
- WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
- RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
- """
- result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
- edges = []
- async for record in result:
- edge_properties = record["properties"]
- edge_properties["source"] = record["source"]
- edge_properties["target"] = record["target"]
- edges.append(edge_properties)
- await result.consume()
- return edges
-
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py
index 88a182d6..48a2d2af 100644
--- a/lightrag/kg/networkx_impl.py
+++ b/lightrag/kg/networkx_impl.py
@@ -5,7 +5,6 @@ from typing import final
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from lightrag.utils import logger
from lightrag.base import BaseGraphStorage
-from lightrag.constants import GRAPH_FIELD_SEP
import networkx as nx
from .shared_storage import (
get_storage_lock,
@@ -470,33 +469,6 @@ class NetworkXStorage(BaseGraphStorage):
)
return result
- async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
- chunk_ids_set = set(chunk_ids)
- graph = await self._get_graph()
- matching_nodes = []
- for node_id, node_data in graph.nodes(data=True):
- if "source_id" in node_data:
- node_source_ids = set(node_data["source_id"].split(GRAPH_FIELD_SEP))
- if not node_source_ids.isdisjoint(chunk_ids_set):
- node_data_with_id = node_data.copy()
- node_data_with_id["id"] = node_id
- matching_nodes.append(node_data_with_id)
- return matching_nodes
-
- async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
- chunk_ids_set = set(chunk_ids)
- graph = await self._get_graph()
- matching_edges = []
- for u, v, edge_data in graph.edges(data=True):
- if "source_id" in edge_data:
- edge_source_ids = set(edge_data["source_id"].split(GRAPH_FIELD_SEP))
- if not edge_source_ids.isdisjoint(chunk_ids_set):
- edge_data_with_nodes = edge_data.copy()
- edge_data_with_nodes["source"] = u
- edge_data_with_nodes["target"] = v
- matching_edges.append(edge_data_with_nodes)
- return matching_edges
-
async def get_all_nodes(self) -> list[dict]:
"""Get all nodes in the graph.
diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index 723de69f..d043176e 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -33,7 +33,6 @@ from ..base import (
)
from ..namespace import NameSpace, is_namespace
from ..utils import logger
-from ..constants import GRAPH_FIELD_SEP
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock, get_storage_lock
import pipmaster as pm
@@ -3569,17 +3568,13 @@ class PGGraphStorage(BaseGraphStorage):
async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier, return only node properties"""
- label = self._normalize_node_id(node_id)
-
- result = await self.get_nodes_batch(node_ids=[label])
+ result = await self.get_nodes_batch(node_ids=[node_id])
if result and node_id in result:
return result[node_id]
return None
async def node_degree(self, node_id: str) -> int:
- label = self._normalize_node_id(node_id)
-
- result = await self.node_degrees_batch(node_ids=[label])
+ result = await self.node_degrees_batch(node_ids=[node_id])
if result and node_id in result:
return result[node_id]
@@ -3592,12 +3587,11 @@ class PGGraphStorage(BaseGraphStorage):
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
"""Get edge properties between two nodes"""
- src_label = self._normalize_node_id(source_node_id)
- tgt_label = self._normalize_node_id(target_node_id)
-
- result = await self.get_edges_batch([{"src": src_label, "tgt": tgt_label}])
- if result and (src_label, tgt_label) in result:
- return result[(src_label, tgt_label)]
+ result = await self.get_edges_batch(
+ [{"src": source_node_id, "tgt": target_node_id}]
+ )
+ if result and (source_node_id, target_node_id) in result:
+ return result[(source_node_id, target_node_id)]
return None
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
@@ -3795,13 +3789,17 @@ class PGGraphStorage(BaseGraphStorage):
if not node_ids:
return {}
- seen = set()
- unique_ids = []
+ seen: set[str] = set()
+ unique_ids: list[str] = []
+ lookup: dict[str, str] = {}
+ requested: set[str] = set()
for nid in node_ids:
- nid_norm = self._normalize_node_id(nid)
- if nid_norm not in seen:
- seen.add(nid_norm)
- unique_ids.append(nid_norm)
+ if nid not in seen:
+ seen.add(nid)
+ unique_ids.append(nid)
+ requested.add(nid)
+ lookup[nid] = nid
+ lookup[self._normalize_node_id(nid)] = nid
# Build result dictionary
nodes_dict = {}
@@ -3840,10 +3838,18 @@ class PGGraphStorage(BaseGraphStorage):
node_dict = json.loads(node_dict)
except json.JSONDecodeError:
logger.warning(
- f"Failed to parse node string in batch: {node_dict}"
+ f"[{self.workspace}] Failed to parse node string in batch: {node_dict}"
)
- nodes_dict[result["node_id"]] = node_dict
+ node_key = result["node_id"]
+ original_key = lookup.get(node_key)
+ if original_key is None:
+ logger.warning(
+ f"[{self.workspace}] Node {node_key} not found in lookup map"
+ )
+ original_key = node_key
+ if original_key in requested:
+ nodes_dict[original_key] = node_dict
return nodes_dict
@@ -3866,13 +3872,17 @@ class PGGraphStorage(BaseGraphStorage):
if not node_ids:
return {}
- seen = set()
+ seen: set[str] = set()
unique_ids: list[str] = []
+ lookup: dict[str, str] = {}
+ requested: set[str] = set()
for nid in node_ids:
- n = self._normalize_node_id(nid)
- if n not in seen:
- seen.add(n)
- unique_ids.append(n)
+ if nid not in seen:
+ seen.add(nid)
+ unique_ids.append(nid)
+ requested.add(nid)
+ lookup[nid] = nid
+ lookup[self._normalize_node_id(nid)] = nid
out_degrees = {}
in_degrees = {}
@@ -3924,8 +3934,16 @@ class PGGraphStorage(BaseGraphStorage):
node_id = row["node_id"]
if not node_id:
continue
- out_degrees[node_id] = int(row.get("out_degree", 0) or 0)
- in_degrees[node_id] = int(row.get("in_degree", 0) or 0)
+ node_key = node_id
+ original_key = lookup.get(node_key)
+ if original_key is None:
+ logger.warning(
+ f"[{self.workspace}] Node {node_key} not found in lookup map"
+ )
+ original_key = node_key
+ if original_key in requested:
+ out_degrees[original_key] = int(row.get("out_degree", 0) or 0)
+ in_degrees[original_key] = int(row.get("in_degree", 0) or 0)
degrees_dict = {}
for node_id in node_ids:
@@ -4054,7 +4072,7 @@ class PGGraphStorage(BaseGraphStorage):
edge_props = json.loads(edge_props)
except json.JSONDecodeError:
logger.warning(
- f"Failed to parse edge properties string: {edge_props}"
+                        f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
)
continue
@@ -4070,7 +4088,7 @@ class PGGraphStorage(BaseGraphStorage):
edge_props = json.loads(edge_props)
except json.JSONDecodeError:
logger.warning(
- f"Failed to parse edge properties string: {edge_props}"
+ f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
)
continue
@@ -4175,102 +4193,6 @@ class PGGraphStorage(BaseGraphStorage):
labels.append(result["label"])
return labels
- async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
- """
- Retrieves nodes from the graph that are associated with a given list of chunk IDs.
- This method uses a Cypher query with UNWIND to efficiently find all nodes
- where the `source_id` property contains any of the specified chunk IDs.
- """
- # The string representation of the list for the cypher query
- chunk_ids_str = json.dumps(chunk_ids)
-
- query = f"""
- SELECT * FROM cypher('{self.graph_name}', $$
- UNWIND {chunk_ids_str} AS chunk_id
- MATCH (n:base)
- WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, '{GRAPH_FIELD_SEP}')
- RETURN n
- $$) AS (n agtype);
- """
- results = await self._query(query)
-
- # Build result list
- nodes = []
- for result in results:
- if result["n"]:
- node_dict = result["n"]["properties"]
-
- # Process string result, parse it to JSON dictionary
- if isinstance(node_dict, str):
- try:
- node_dict = json.loads(node_dict)
- except json.JSONDecodeError:
- logger.warning(
- f"[{self.workspace}] Failed to parse node string in batch: {node_dict}"
- )
-
- node_dict["id"] = node_dict["entity_id"]
- nodes.append(node_dict)
-
- return nodes
-
- async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
- """
- Retrieves edges from the graph that are associated with a given list of chunk IDs.
- This method uses a Cypher query with UNWIND to efficiently find all edges
- where the `source_id` property contains any of the specified chunk IDs.
- """
- chunk_ids_str = json.dumps(chunk_ids)
-
- query = f"""
- SELECT * FROM cypher('{self.graph_name}', $$
- UNWIND {chunk_ids_str} AS chunk_id
- MATCH ()-[r]-()
- WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, '{GRAPH_FIELD_SEP}')
- RETURN DISTINCT r, startNode(r) AS source, endNode(r) AS target
- $$) AS (edge agtype, source agtype, target agtype);
- """
- results = await self._query(query)
- edges = []
- if results:
- for item in results:
- edge_agtype = item["edge"]["properties"]
- # Process string result, parse it to JSON dictionary
- if isinstance(edge_agtype, str):
- try:
- edge_agtype = json.loads(edge_agtype)
- except json.JSONDecodeError:
- logger.warning(
- f"[{self.workspace}] Failed to parse edge string in batch: {edge_agtype}"
- )
-
- source_agtype = item["source"]["properties"]
- # Process string result, parse it to JSON dictionary
- if isinstance(source_agtype, str):
- try:
- source_agtype = json.loads(source_agtype)
- except json.JSONDecodeError:
- logger.warning(
- f"[{self.workspace}] Failed to parse node string in batch: {source_agtype}"
- )
-
- target_agtype = item["target"]["properties"]
- # Process string result, parse it to JSON dictionary
- if isinstance(target_agtype, str):
- try:
- target_agtype = json.loads(target_agtype)
- except json.JSONDecodeError:
- logger.warning(
- f"[{self.workspace}] Failed to parse node string in batch: {target_agtype}"
- )
-
- if edge_agtype and source_agtype and target_agtype:
- edge_properties = edge_agtype
- edge_properties["source"] = source_agtype["entity_id"]
- edge_properties["target"] = target_agtype["entity_id"]
- edges.append(edge_properties)
- return edges
-
async def _bfs_subgraph(
self, node_label: str, max_depth: int, max_nodes: int
) -> KnowledgeGraph:
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index af2067d8..acf157da 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -3235,38 +3235,31 @@ class LightRAG:
if entity_chunk_updates and self.entity_chunks:
entity_upsert_payload = {}
- entity_delete_ids: set[str] = set()
for entity_name, remaining in entity_chunk_updates.items():
if not remaining:
- entity_delete_ids.add(entity_name)
- else:
- entity_upsert_payload[entity_name] = {
- "chunk_ids": remaining,
- "count": len(remaining),
- "updated_at": current_time,
- }
-
- if entity_delete_ids:
- await self.entity_chunks.delete(list(entity_delete_ids))
+ # Empty entities are deleted alongside graph nodes later
+ continue
+ entity_upsert_payload[entity_name] = {
+ "chunk_ids": remaining,
+ "count": len(remaining),
+ "updated_at": current_time,
+ }
if entity_upsert_payload:
await self.entity_chunks.upsert(entity_upsert_payload)
if relation_chunk_updates and self.relation_chunks:
relation_upsert_payload = {}
- relation_delete_ids: set[str] = set()
for edge_tuple, remaining in relation_chunk_updates.items():
- storage_key = make_relation_chunk_key(*edge_tuple)
if not remaining:
- relation_delete_ids.add(storage_key)
- else:
- relation_upsert_payload[storage_key] = {
- "chunk_ids": remaining,
- "count": len(remaining),
- "updated_at": current_time,
- }
+ # Empty relations are deleted alongside graph edges later
+ continue
+ storage_key = make_relation_chunk_key(*edge_tuple)
+ relation_upsert_payload[storage_key] = {
+ "chunk_ids": remaining,
+ "count": len(remaining),
+ "updated_at": current_time,
+ }
- if relation_delete_ids:
- await self.relation_chunks.delete(list(relation_delete_ids))
if relation_upsert_payload:
await self.relation_chunks.upsert(relation_upsert_payload)
@@ -3296,7 +3289,7 @@ class LightRAG:
# 6. Delete relationships that have no remaining sources
if relationships_to_delete:
try:
- # Delete from vector database
+ # Delete from relation vdb
rel_ids_to_delete = []
for src, tgt in relationships_to_delete:
rel_ids_to_delete.extend(
@@ -3333,15 +3326,16 @@ class LightRAG:
# 7. Delete entities that have no remaining sources
if entities_to_delete:
try:
+ # Batch get all edges for entities to avoid N+1 query problem
+ nodes_edges_dict = await self.chunk_entity_relation_graph.get_nodes_edges_batch(
+ list(entities_to_delete)
+ )
+
# Debug: Check and log all edges before deleting nodes
edges_to_delete = set()
edges_still_exist = 0
- for entity in entities_to_delete:
- edges = (
- await self.chunk_entity_relation_graph.get_node_edges(
- entity
- )
- )
+
+ for entity, edges in nodes_edges_dict.items():
if edges:
for src, tgt in edges:
# Normalize edge representation (sorted for consistency)
@@ -3364,6 +3358,7 @@ class LightRAG:
f"Edge still exists: {src} <-- {tgt}"
)
edges_still_exist += 1
+
if edges_still_exist:
logger.warning(
f"⚠️ {edges_still_exist} entities still has edges before deletion"
@@ -3399,7 +3394,7 @@ class LightRAG:
list(entities_to_delete)
)
- # Delete from vector database
+                    # Delete from entity vdb
entity_vdb_ids = [
compute_mdhash_id(entity, prefix="ent-")
for entity in entities_to_delete
diff --git a/lightrag/llm/binding_options.py b/lightrag/llm/binding_options.py
index f17ba0f8..b3affb37 100644
--- a/lightrag/llm/binding_options.py
+++ b/lightrag/llm/binding_options.py
@@ -9,12 +9,26 @@ from argparse import ArgumentParser, Namespace
import argparse
import json
from dataclasses import asdict, dataclass, field
-from typing import Any, ClassVar, List
+from typing import Any, ClassVar, List, get_args, get_origin
from lightrag.utils import get_env_value
from lightrag.constants import DEFAULT_TEMPERATURE
+def _resolve_optional_type(field_type: Any) -> Any:
+ """Return the concrete type for Optional/Union annotations."""
+ origin = get_origin(field_type)
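+    # For example, `int | None` and Optional[int] resolve to int, while container
+    # annotations such as list[str] are returned unchanged.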
+ if origin in (list, dict, tuple):
+ return field_type
+
+ args = get_args(field_type)
+ if args:
+ non_none_args = [arg for arg in args if arg is not type(None)]
+ if len(non_none_args) == 1:
+ return non_none_args[0]
+ return field_type
+
+
# =============================================================================
# BindingOptions Base Class
# =============================================================================
@@ -177,9 +191,13 @@ class BindingOptions:
help=arg_item["help"],
)
else:
+ resolved_type = arg_item["type"]
+ if resolved_type is not None:
+ resolved_type = _resolve_optional_type(resolved_type)
+
group.add_argument(
f"--{arg_item['argname']}",
- type=arg_item["type"],
+ type=resolved_type,
default=get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS),
help=arg_item["help"],
)
@@ -210,7 +228,7 @@ class BindingOptions:
argdef = {
"argname": f"{args_prefix}-{field.name}",
"env_name": f"{env_var_prefix}{field.name.upper()}",
- "type": field.type,
+ "type": _resolve_optional_type(field.type),
"default": default_value,
"help": f"{cls._binding_name} -- " + help.get(field.name, ""),
}
@@ -454,6 +472,39 @@ class OllamaLLMOptions(_OllamaOptionsMixin, BindingOptions):
_binding_name: ClassVar[str] = "ollama_llm"
+@dataclass
+class GeminiLLMOptions(BindingOptions):
+ """Options for Google Gemini models."""
+
+ _binding_name: ClassVar[str] = "gemini_llm"
+
+ temperature: float = DEFAULT_TEMPERATURE
+ top_p: float = 0.95
+ top_k: int = 40
+ max_output_tokens: int | None = None
+ candidate_count: int = 1
+ presence_penalty: float = 0.0
+ frequency_penalty: float = 0.0
+ stop_sequences: List[str] = field(default_factory=list)
+ seed: int | None = None
+ thinking_config: dict | None = None
+ safety_settings: dict | None = None
+
+ _help: ClassVar[dict[str, str]] = {
+ "temperature": "Controls randomness (0.0-2.0, higher = more creative)",
+ "top_p": "Nucleus sampling parameter (0.0-1.0)",
+ "top_k": "Limits sampling to the top K tokens (1 disables the limit)",
+ "max_output_tokens": "Maximum tokens generated in the response",
+ "candidate_count": "Number of candidates returned per request",
+ "presence_penalty": "Penalty for token presence (-2.0 to 2.0)",
+ "frequency_penalty": "Penalty for token frequency (-2.0 to 2.0)",
+ "stop_sequences": "Stop sequences (JSON array of strings, e.g., '[\"END\"]')",
+ "seed": "Random seed for reproducible generation (leave empty for random)",
+ "thinking_config": "Thinking configuration (JSON dict, e.g., '{\"thinking_budget\": 1024}' or '{\"include_thoughts\": true}')",
+ "safety_settings": "JSON object with Gemini safety settings overrides",
+ }
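+
+    # Example environment overrides (a sketch; variable names follow the
+    # GEMINI_LLM_<FIELD> pattern shown in the Gemini block of env.example):
+    #   GEMINI_LLM_TEMPERATURE=0.7
+    #   GEMINI_LLM_MAX_OUTPUT_TOKENS=9000
+    #   GEMINI_LLM_THINKING_CONFIG='{"thinking_budget": 0, "include_thoughts": false}'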
+
+
# =============================================================================
# Binding Options for OpenAI
# =============================================================================
diff --git a/lightrag/llm/gemini.py b/lightrag/llm/gemini.py
new file mode 100644
index 00000000..f06ec6b3
--- /dev/null
+++ b/lightrag/llm/gemini.py
@@ -0,0 +1,422 @@
+"""
+Gemini LLM binding for LightRAG.
+
+This module provides asynchronous helpers that adapt Google's Gemini models
+to the same interface used by the rest of the LightRAG LLM bindings. The
+implementation mirrors the OpenAI helpers while relying on the official
+``google-genai`` client under the hood.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+import os
+from collections.abc import AsyncIterator
+from functools import lru_cache
+from typing import Any
+
+from lightrag.utils import logger, remove_think_tags, safe_unicode_decode
+
+import pipmaster as pm
+
+# Install the Google Gemini client on demand
+if not pm.is_installed("google-genai"):
+ pm.install("google-genai")
+
+from google import genai # type: ignore
+from google.genai import types # type: ignore
+
+DEFAULT_GEMINI_ENDPOINT = "https://generativelanguage.googleapis.com"
+
+LOG = logging.getLogger(__name__)
+
+
+@lru_cache(maxsize=8)
+def _get_gemini_client(
+ api_key: str, base_url: str | None, timeout: int | None = None
+) -> genai.Client:
+ """
+ Create (or fetch cached) Gemini client.
+
+ Args:
+ api_key: Google Gemini API key.
+ base_url: Optional custom API endpoint.
+ timeout: Optional request timeout in milliseconds.
+
+ Returns:
+ genai.Client: Configured Gemini client instance.
+ """
+ client_kwargs: dict[str, Any] = {"api_key": api_key}
+
+    if (base_url and base_url != DEFAULT_GEMINI_ENDPOINT) or timeout is not None:
+ try:
+ http_options_kwargs = {}
+ if base_url and base_url != DEFAULT_GEMINI_ENDPOINT:
+ http_options_kwargs["api_endpoint"] = base_url
+ if timeout is not None:
+ http_options_kwargs["timeout"] = timeout
+
+ client_kwargs["http_options"] = types.HttpOptions(**http_options_kwargs)
+ except Exception as exc: # pragma: no cover - defensive
+ LOG.warning("Failed to apply custom Gemini http_options: %s", exc)
+
+ try:
+ return genai.Client(**client_kwargs)
+ except TypeError:
+ # Older google-genai releases don't accept http_options; retry without it.
+ client_kwargs.pop("http_options", None)
+ return genai.Client(**client_kwargs)
+
+
+def _ensure_api_key(api_key: str | None) -> str:
+ key = api_key or os.getenv("LLM_BINDING_API_KEY") or os.getenv("GEMINI_API_KEY")
+ if not key:
+ raise ValueError(
+ "Gemini API key not provided. "
+ "Set LLM_BINDING_API_KEY or GEMINI_API_KEY in the environment."
+ )
+ return key
+
+
+def _build_generation_config(
+ base_config: dict[str, Any] | None,
+ system_prompt: str | None,
+ keyword_extraction: bool,
+) -> types.GenerateContentConfig | None:
+ config_data = dict(base_config or {})
+
+ if system_prompt:
+ if config_data.get("system_instruction"):
+ config_data["system_instruction"] = (
+ f"{config_data['system_instruction']}\n{system_prompt}"
+ )
+ else:
+ config_data["system_instruction"] = system_prompt
+
+ if keyword_extraction and not config_data.get("response_mime_type"):
+ config_data["response_mime_type"] = "application/json"
+
+ # Remove entries that are explicitly set to None to avoid type errors
+ sanitized = {
+ key: value
+ for key, value in config_data.items()
+ if value is not None and value != ""
+ }
+
+ if not sanitized:
+ return None
+
+ return types.GenerateContentConfig(**sanitized)
+
+
+def _format_history_messages(history_messages: list[dict[str, Any]] | None) -> str:
+ if not history_messages:
+ return ""
+
+ history_lines: list[str] = []
+ for message in history_messages:
+ role = message.get("role", "user")
+ content = message.get("content", "")
+ history_lines.append(f"[{role}] {content}")
+
+ return "\n".join(history_lines)
+
+
+def _extract_response_text(
+ response: Any, extract_thoughts: bool = False
+) -> tuple[str, str]:
+ """
+ Extract text content from Gemini response, separating regular content from thoughts.
+
+ Args:
+ response: Gemini API response object
+ extract_thoughts: Whether to extract thought content separately
+
+ Returns:
+ Tuple of (regular_text, thought_text)
+ """
+ candidates = getattr(response, "candidates", None)
+ if not candidates:
+ return ("", "")
+
+ regular_parts: list[str] = []
+ thought_parts: list[str] = []
+
+ for candidate in candidates:
+ if not getattr(candidate, "content", None):
+ continue
+ # Use 'or []' to handle None values from parts attribute
+ for part in getattr(candidate.content, "parts", None) or []:
+ text = getattr(part, "text", None)
+ if not text:
+ continue
+
+ # Check if this part is thought content using the 'thought' attribute
+ is_thought = getattr(part, "thought", False)
+
+ if is_thought and extract_thoughts:
+ thought_parts.append(text)
+ elif not is_thought:
+ regular_parts.append(text)
+
+ return ("\n".join(regular_parts), "\n".join(thought_parts))
+
+
+async def gemini_complete_if_cache(
+ model: str,
+ prompt: str,
+ system_prompt: str | None = None,
+ history_messages: list[dict[str, Any]] | None = None,
+ enable_cot: bool = False,
+ base_url: str | None = None,
+ api_key: str | None = None,
+ token_tracker: Any | None = None,
+ stream: bool | None = None,
+ keyword_extraction: bool = False,
+ generation_config: dict[str, Any] | None = None,
+ timeout: int | None = None,
+ **_: Any,
+) -> str | AsyncIterator[str]:
+ """
+ Complete a prompt using Gemini's API with Chain of Thought (COT) support.
+
+ This function supports automatic integration of reasoning content from Gemini models
+ that provide Chain of Thought capabilities via the thinking_config API feature.
+
+ COT Integration:
+    - When enable_cot=True: Thought content is wrapped in <think>...</think> tags
+ - When enable_cot=False: Thought content is filtered out, only regular content returned
+ - Thought content is identified by the 'thought' attribute on response parts
+ - Requires thinking_config to be enabled in generation_config for API to return thoughts
+
+ Args:
+ model: The Gemini model to use.
+ prompt: The prompt to complete.
+ system_prompt: Optional system prompt to include.
+ history_messages: Optional list of previous messages in the conversation.
+ api_key: Optional Gemini API key. If None, uses environment variable.
+ base_url: Optional custom API endpoint.
+ generation_config: Optional generation configuration dict.
+ keyword_extraction: Whether to use JSON response format.
+ token_tracker: Optional token usage tracker for monitoring API usage.
+ stream: Whether to stream the response.
+ hashing_kv: Storage interface (for interface parity with other bindings).
+ enable_cot: Whether to include Chain of Thought content in the response.
+ timeout: Request timeout in seconds (will be converted to milliseconds for Gemini API).
+ **_: Additional keyword arguments (ignored).
+
+ Returns:
+ The completed text (with COT content if enable_cot=True) or an async iterator
+        of text chunks if streaming. COT content is wrapped in <think>...</think> tags.
+
+ Raises:
+ RuntimeError: If the response from Gemini is empty.
+ ValueError: If API key is not provided or configured.
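+
+    Example (illustrative sketch only; the model id below is an assumption):
+        >>> import asyncio
+        >>> async def _demo() -> str:
+        ...     return await gemini_complete_if_cache(
+        ...         "gemini-2.0-flash",
+        ...         "Summarize LightRAG in one sentence.",
+        ...         system_prompt="You are a concise assistant.",
+        ...     )
+        >>> # asyncio.run(_demo())  # requires GEMINI_API_KEY or LLM_BINDING_API_KEY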
+ """
+ loop = asyncio.get_running_loop()
+
+ key = _ensure_api_key(api_key)
+ # Convert timeout from seconds to milliseconds for Gemini API
+ timeout_ms = timeout * 1000 if timeout else None
+ client = _get_gemini_client(key, base_url, timeout_ms)
+
+ history_block = _format_history_messages(history_messages)
+ prompt_sections = []
+ if history_block:
+ prompt_sections.append(history_block)
+ prompt_sections.append(f"[user] {prompt}")
+ combined_prompt = "\n".join(prompt_sections)
+
+ config_obj = _build_generation_config(
+ generation_config,
+ system_prompt=system_prompt,
+ keyword_extraction=keyword_extraction,
+ )
+
+ request_kwargs: dict[str, Any] = {
+ "model": model,
+ "contents": [combined_prompt],
+ }
+ if config_obj is not None:
+ request_kwargs["config"] = config_obj
+
+ def _call_model():
+ return client.models.generate_content(**request_kwargs)
+
+ if stream:
+ queue: asyncio.Queue[Any] = asyncio.Queue()
+ usage_container: dict[str, Any] = {}
+
+ def _stream_model() -> None:
+ # COT state tracking for streaming
+ cot_active = False
+ cot_started = False
+ initial_content_seen = False
+
+ try:
+ stream_kwargs = dict(request_kwargs)
+ stream_iterator = client.models.generate_content_stream(**stream_kwargs)
+ for chunk in stream_iterator:
+ usage = getattr(chunk, "usage_metadata", None)
+ if usage is not None:
+ usage_container["usage"] = usage
+
+ # Extract both regular and thought content
+ regular_text, thought_text = _extract_response_text(
+ chunk, extract_thoughts=True
+ )
+
+ if enable_cot:
+ # Process regular content
+ if regular_text:
+ if not initial_content_seen:
+ initial_content_seen = True
+
+ # Close COT section if it was active
+ if cot_active:
+                                loop.call_soon_threadsafe(queue.put_nowait, "</think>")
+ cot_active = False
+
+ # Send regular content
+ loop.call_soon_threadsafe(queue.put_nowait, regular_text)
+
+ # Process thought content
+ if thought_text:
+ if not initial_content_seen and not cot_started:
+ # Start COT section
+                                loop.call_soon_threadsafe(queue.put_nowait, "<think>")
+ cot_active = True
+ cot_started = True
+
+ # Send thought content if COT is active
+ if cot_active:
+ loop.call_soon_threadsafe(
+ queue.put_nowait, thought_text
+ )
+ else:
+ # COT disabled - only send regular content
+ if regular_text:
+ loop.call_soon_threadsafe(queue.put_nowait, regular_text)
+
+ # Ensure COT is properly closed if still active
+ if cot_active:
+                    loop.call_soon_threadsafe(queue.put_nowait, "</think>")
+
+ loop.call_soon_threadsafe(queue.put_nowait, None)
+ except Exception as exc: # pragma: no cover - surface runtime issues
+ # Try to close COT tag before reporting error
+ if cot_active:
+ try:
+                        loop.call_soon_threadsafe(queue.put_nowait, "</think>")
+ except Exception:
+ pass
+ loop.call_soon_threadsafe(queue.put_nowait, exc)
+
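+        # generate_content_stream is synchronous, so run it in the default executor;
+        # chunks (and any error) are relayed back to this event loop through the queue.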
+ loop.run_in_executor(None, _stream_model)
+
+ async def _async_stream() -> AsyncIterator[str]:
+ try:
+ while True:
+ item = await queue.get()
+ if item is None:
+ break
+ if isinstance(item, Exception):
+ raise item
+
+ chunk_text = str(item)
+ if "\\u" in chunk_text:
+ chunk_text = safe_unicode_decode(chunk_text.encode("utf-8"))
+
+ # Yield the chunk directly without filtering
+ # COT filtering is already handled in _stream_model()
+ yield chunk_text
+ finally:
+ usage = usage_container.get("usage")
+ if token_tracker and usage:
+ token_tracker.add_usage(
+ {
+ "prompt_tokens": getattr(usage, "prompt_token_count", 0),
+ "completion_tokens": getattr(
+ usage, "candidates_token_count", 0
+ ),
+ "total_tokens": getattr(usage, "total_token_count", 0),
+ }
+ )
+
+ return _async_stream()
+
+ response = await asyncio.to_thread(_call_model)
+
+ # Extract both regular text and thought text
+ regular_text, thought_text = _extract_response_text(response, extract_thoughts=True)
+
+ # Apply COT filtering logic based on enable_cot parameter
+ if enable_cot:
+        # Include thought content wrapped in <think>...</think> tags
+        if thought_text and thought_text.strip():
+            if not regular_text or regular_text.strip() == "":
+                # Only thought content available
+                final_text = f"<think>{thought_text}</think>"
+            else:
+                # Both content types present: prepend wrapped thought to regular content
+                final_text = f"<think>{thought_text}</think>{regular_text}"
+ else:
+ # No thought content, use regular content only
+ final_text = regular_text or ""
+ else:
+ # Filter out thought content, return only regular content
+ final_text = regular_text or ""
+
+ if not final_text:
+ raise RuntimeError("Gemini response did not contain any text content.")
+
+ if "\\u" in final_text:
+ final_text = safe_unicode_decode(final_text.encode("utf-8"))
+
+ final_text = remove_think_tags(final_text)
+
+ usage = getattr(response, "usage_metadata", None)
+ if token_tracker and usage:
+ token_tracker.add_usage(
+ {
+ "prompt_tokens": getattr(usage, "prompt_token_count", 0),
+ "completion_tokens": getattr(usage, "candidates_token_count", 0),
+ "total_tokens": getattr(usage, "total_token_count", 0),
+ }
+ )
+
+ logger.debug("Gemini response length: %s", len(final_text))
+ return final_text
+
+
+async def gemini_model_complete(
+ prompt: str,
+ system_prompt: str | None = None,
+ history_messages: list[dict[str, Any]] | None = None,
+ keyword_extraction: bool = False,
+ **kwargs: Any,
+) -> str | AsyncIterator[str]:
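+    """Resolve the model name from LightRAG's global config (via hashing_kv) or a
+    model_name kwarg, then delegate to gemini_complete_if_cache."""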
+ hashing_kv = kwargs.get("hashing_kv")
+ model_name = None
+ if hashing_kv is not None:
+ model_name = hashing_kv.global_config.get("llm_model_name")
+ if model_name is None:
+ model_name = kwargs.pop("model_name", None)
+ if model_name is None:
+ raise ValueError("Gemini model name not provided in configuration.")
+
+ return await gemini_complete_if_cache(
+ model_name,
+ prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ keyword_extraction=keyword_extraction,
+ **kwargs,
+ )
+
+
+__all__ = [
+ "gemini_complete_if_cache",
+ "gemini_model_complete",
+]
diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
index 66c3bfe4..511a3a62 100644
--- a/lightrag/llm/openai.py
+++ b/lightrag/llm/openai.py
@@ -138,6 +138,9 @@ async def openai_complete_if_cache(
base_url: str | None = None,
api_key: str | None = None,
token_tracker: Any | None = None,
+ keyword_extraction: bool = False, # Will be removed from kwargs before passing to OpenAI
+ stream: bool | None = None,
+ timeout: int | None = None,
**kwargs: Any,
) -> str:
"""Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration.
@@ -172,8 +175,9 @@ async def openai_complete_if_cache(
- openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
These will be passed to the client constructor but will be overridden by
explicit parameters (api_key, base_url).
- - hashing_kv: Will be removed from kwargs before passing to OpenAI.
- keyword_extraction: Will be removed from kwargs before passing to OpenAI.
+        - stream: Whether to stream the response. Default is None (non-streaming).
+ - timeout: Request timeout in seconds. Default is None.
Returns:
The completed text (with integrated COT content if available) or an async iterator
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 6c382894..28559af5 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -1795,7 +1795,7 @@ def normalize_extracted_info(name: str, remove_inner_quotes=False) -> str:
- Filter out short numeric-only text (length < 3 and only digits/dots)
- remove_inner_quotes = True
remove Chinese quotes
- remove English queotes in and around chinese
+        remove English quotes in and around Chinese
Convert non-breaking spaces to regular spaces
Convert narrow non-breaking spaces after non-digits to regular spaces
diff --git a/pyproject.toml b/pyproject.toml
index e7ff4262..c665e49a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,6 +24,7 @@ dependencies = [
"aiohttp",
"configparser",
"future",
+ "google-genai>=1.0.0,<2.0.0",
"json_repair",
"nano-vectordb",
"networkx",
@@ -59,6 +60,7 @@ api = [
"tenacity",
"tiktoken",
"xlsxwriter>=3.1.0",
+ "google-genai>=1.0.0,<2.0.0",
# API-specific dependencies
"aiofiles",
"ascii_colors",
@@ -107,6 +109,7 @@ offline-llm = [
"aioboto3>=12.0.0,<16.0.0",
"voyageai>=0.2.0,<1.0.0",
"llama-index>=0.9.0,<1.0.0",
+ "google-genai>=1.0.0,<2.0.0",
]
offline = [
diff --git a/requirements-offline-llm.txt b/requirements-offline-llm.txt
index 269847a2..4e8b7168 100644
--- a/requirements-offline-llm.txt
+++ b/requirements-offline-llm.txt
@@ -10,6 +10,7 @@
# LLM provider dependencies (with version constraints matching pyproject.toml)
aioboto3>=12.0.0,<16.0.0
anthropic>=0.18.0,<1.0.0
+google-genai>=1.0.0,<2.0.0
llama-index>=0.9.0,<1.0.0
ollama>=0.1.0,<1.0.0
openai>=1.0.0,<3.0.0
diff --git a/requirements-offline.txt b/requirements-offline.txt
index 32191005..8dfb1b01 100644
--- a/requirements-offline.txt
+++ b/requirements-offline.txt
@@ -13,6 +13,7 @@ anthropic>=0.18.0,<1.0.0
# Storage backend dependencies
asyncpg>=0.29.0,<1.0.0
+google-genai>=1.0.0,<2.0.0
# Document processing dependencies
llama-index>=0.9.0,<1.0.0
diff --git a/tests/test_graph_storage.py b/tests/test_graph_storage.py
index 62f658ff..c6932384 100644
--- a/tests/test_graph_storage.py
+++ b/tests/test_graph_storage.py
@@ -1,11 +1,11 @@
#!/usr/bin/env python
"""
-通用图存储测试程序
+General-purpose graph storage test program.
-该程序根据.env中的LIGHTRAG_GRAPH_STORAGE配置选择使用的图存储类型,
-并对其进行基本操作和高级操作的测试。
+This program selects the graph storage type to use based on the LIGHTRAG_GRAPH_STORAGE configuration in .env,
+and tests its basic and advanced operations.
-支持的图存储类型包括:
+Supported graph storage types include:
- NetworkXStorage
- Neo4JStorage
- MongoDBStorage
@@ -21,7 +21,7 @@ import numpy as np
from dotenv import load_dotenv
from ascii_colors import ASCIIColors
-# 添加项目根目录到Python路径
+# Add the project root directory to the Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from lightrag.types import KnowledgeGraph
@@ -35,325 +35,352 @@ from lightrag.kg.shared_storage import initialize_share_data
from lightrag.constants import GRAPH_FIELD_SEP
-# 模拟的嵌入函数,返回随机向量
+# Mock embedding function that returns random vectors
async def mock_embedding_func(texts):
- return np.random.rand(len(texts), 10) # 返回10维随机向量
+ return np.random.rand(len(texts), 10) # Return 10-dimensional random vectors
def check_env_file():
"""
- 检查.env文件是否存在,如果不存在则发出警告
- 返回True表示应该继续执行,False表示应该退出
+ Check if the .env file exists and issue a warning if it does not.
+ Returns True to continue execution, False to exit.
"""
if not os.path.exists(".env"):
- warning_msg = "警告: 当前目录中没有找到.env文件,这可能会影响存储配置的加载。"
+ warning_msg = "Warning: .env file not found in the current directory. This may affect storage configuration loading."
ASCIIColors.yellow(warning_msg)
- # 检查是否在交互式终端中运行
+ # Check if running in an interactive terminal
if sys.stdin.isatty():
- response = input("是否继续执行? (yes/no): ")
+ response = input("Do you want to continue? (yes/no): ")
if response.lower() != "yes":
- ASCIIColors.red("测试程序已取消")
+ ASCIIColors.red("Test program cancelled.")
return False
return True
async def initialize_graph_storage():
"""
- 根据环境变量初始化相应的图存储实例
- 返回初始化的存储实例
+ Initialize the corresponding graph storage instance based on environment variables.
+ Returns the initialized storage instance.
"""
- # 从环境变量中获取图存储类型
+ # Get the graph storage type from environment variables
graph_storage_type = os.getenv("LIGHTRAG_GRAPH_STORAGE", "NetworkXStorage")
- # 验证存储类型是否有效
+ # Validate the storage type
try:
verify_storage_implementation("GRAPH_STORAGE", graph_storage_type)
except ValueError as e:
- ASCIIColors.red(f"错误: {str(e)}")
+ ASCIIColors.red(f"Error: {str(e)}")
ASCIIColors.yellow(
- f"支持的图存储类型: {', '.join(STORAGE_IMPLEMENTATIONS['GRAPH_STORAGE']['implementations'])}"
+ f"Supported graph storage types: {', '.join(STORAGE_IMPLEMENTATIONS['GRAPH_STORAGE']['implementations'])}"
)
return None
- # 检查所需的环境变量
+ # Check for required environment variables
required_env_vars = STORAGE_ENV_REQUIREMENTS.get(graph_storage_type, [])
missing_env_vars = [var for var in required_env_vars if not os.getenv(var)]
if missing_env_vars:
ASCIIColors.red(
- f"错误: {graph_storage_type} 需要以下环境变量,但未设置: {', '.join(missing_env_vars)}"
+ f"Error: {graph_storage_type} requires the following environment variables, but they are not set: {', '.join(missing_env_vars)}"
)
return None
- # 动态导入相应的模块
+ # Dynamically import the corresponding module
module_path = STORAGES.get(graph_storage_type)
if not module_path:
- ASCIIColors.red(f"错误: 未找到 {graph_storage_type} 的模块路径")
+ ASCIIColors.red(f"Error: Module path for {graph_storage_type} not found.")
return None
try:
module = importlib.import_module(module_path, package="lightrag")
storage_class = getattr(module, graph_storage_type)
except (ImportError, AttributeError) as e:
- ASCIIColors.red(f"错误: 导入 {graph_storage_type} 失败: {str(e)}")
+ ASCIIColors.red(f"Error: Failed to import {graph_storage_type}: {str(e)}")
return None
- # 初始化存储实例
+ # Initialize the storage instance
global_config = {
- "embedding_batch_num": 10, # 批处理大小
+ "embedding_batch_num": 10, # Batch size
"vector_db_storage_cls_kwargs": {
- "cosine_better_than_threshold": 0.5 # 余弦相似度阈值
+ "cosine_better_than_threshold": 0.5 # Cosine similarity threshold
},
- "working_dir": os.environ.get("WORKING_DIR", "./rag_storage"), # 工作目录
+ "working_dir": os.environ.get(
+ "WORKING_DIR", "./rag_storage"
+ ), # Working directory
}
- # 如果使用 NetworkXStorage,需要先初始化 shared_storage
- if graph_storage_type == "NetworkXStorage":
- initialize_share_data() # 使用单进程模式
+ # Initialize shared_storage for all storage types (required for locks)
+ # All graph storage implementations use locks like get_data_init_lock() and get_graph_db_lock()
+ initialize_share_data() # Use single-process mode (workers=1)
try:
storage = storage_class(
namespace="test_graph",
+ workspace="test_workspace",
global_config=global_config,
embedding_func=mock_embedding_func,
)
- # 初始化连接
+ # Initialize the connection
await storage.initialize()
return storage
except Exception as e:
- ASCIIColors.red(f"错误: 初始化 {graph_storage_type} 失败: {str(e)}")
+ ASCIIColors.red(f"Error: Failed to initialize {graph_storage_type}: {str(e)}")
return None
async def test_graph_basic(storage):
"""
- 测试图数据库的基本操作:
- 1. 使用 upsert_node 插入两个节点
- 2. 使用 upsert_edge 插入一条连接两个节点的边
- 3. 使用 get_node 读取一个节点
- 4. 使用 get_edge 读取一条边
+ Test basic graph database operations:
+ 1. Use upsert_node to insert two nodes.
+ 2. Use upsert_edge to insert an edge connecting the two nodes.
+ 3. Use get_node to read a node.
+ 4. Use get_edge to read an edge.
"""
try:
- # 1. 插入第一个节点
- node1_id = "人工智能"
+ # 1. Insert the first node
+ node1_id = "Artificial Intelligence"
node1_data = {
"entity_id": node1_id,
- "description": "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。",
- "keywords": "AI,机器学习,深度学习",
- "entity_type": "技术领域",
+ "description": "Artificial intelligence is a branch of computer science that aims to understand the essence of intelligence and produce a new kind of intelligent machine that can react in a manner similar to human intelligence.",
+ "keywords": "AI,Machine Learning,Deep Learning",
+ "entity_type": "Technology Field",
}
- print(f"插入节点1: {node1_id}")
+ print(f"Inserting node 1: {node1_id}")
await storage.upsert_node(node1_id, node1_data)
- # 2. 插入第二个节点
- node2_id = "机器学习"
+ # 2. Insert the second node
+ node2_id = "Machine Learning"
node2_data = {
"entity_id": node2_id,
- "description": "机器学习是人工智能的一个分支,它使用统计学方法让计算机系统在不被明确编程的情况下也能够学习。",
- "keywords": "监督学习,无监督学习,强化学习",
- "entity_type": "技术领域",
+ "description": "Machine learning is a branch of artificial intelligence that uses statistical methods to enable computer systems to learn without being explicitly programmed.",
+ "keywords": "Supervised Learning,Unsupervised Learning,Reinforcement Learning",
+ "entity_type": "Technology Field",
}
- print(f"插入节点2: {node2_id}")
+ print(f"Inserting node 2: {node2_id}")
await storage.upsert_node(node2_id, node2_data)
- # 3. 插入连接边
+ # 3. Insert the connecting edge
edge_data = {
- "relationship": "包含",
+ "relationship": "includes",
"weight": 1.0,
- "description": "人工智能领域包含机器学习这个子领域",
+ "description": "The field of artificial intelligence includes the subfield of machine learning.",
}
- print(f"插入边: {node1_id} -> {node2_id}")
+ print(f"Inserting edge: {node1_id} -> {node2_id}")
await storage.upsert_edge(node1_id, node2_id, edge_data)
- # 4. 读取节点属性
- print(f"读取节点属性: {node1_id}")
+ # 4. Read node properties
+ print(f"Reading node properties: {node1_id}")
node1_props = await storage.get_node(node1_id)
if node1_props:
- print(f"成功读取节点属性: {node1_id}")
- print(f"节点描述: {node1_props.get('description', '无描述')}")
- print(f"节点类型: {node1_props.get('entity_type', '无类型')}")
- print(f"节点关键词: {node1_props.get('keywords', '无关键词')}")
- # 验证返回的属性是否正确
+ print(f"Successfully read node properties: {node1_id}")
+ print(
+ f"Node description: {node1_props.get('description', 'No description')}"
+ )
+ print(f"Node type: {node1_props.get('entity_type', 'No type')}")
+ print(f"Node keywords: {node1_props.get('keywords', 'No keywords')}")
+ # Verify that the returned properties are correct
assert (
node1_props.get("entity_id") == node1_id
- ), f"节点ID不匹配: 期望 {node1_id}, 实际 {node1_props.get('entity_id')}"
+ ), f"Node ID mismatch: expected {node1_id}, got {node1_props.get('entity_id')}"
assert (
node1_props.get("description") == node1_data["description"]
- ), "节点描述不匹配"
+ ), "Node description mismatch"
assert (
node1_props.get("entity_type") == node1_data["entity_type"]
- ), "节点类型不匹配"
+ ), "Node type mismatch"
else:
- print(f"读取节点属性失败: {node1_id}")
- assert False, f"未能读取节点属性: {node1_id}"
+ print(f"Failed to read node properties: {node1_id}")
+ assert False, f"Failed to read node properties: {node1_id}"
- # 5. 读取边属性
- print(f"读取边属性: {node1_id} -> {node2_id}")
+ # 5. Read edge properties
+ print(f"Reading edge properties: {node1_id} -> {node2_id}")
edge_props = await storage.get_edge(node1_id, node2_id)
if edge_props:
- print(f"成功读取边属性: {node1_id} -> {node2_id}")
- print(f"边关系: {edge_props.get('relationship', '无关系')}")
- print(f"边描述: {edge_props.get('description', '无描述')}")
- print(f"边权重: {edge_props.get('weight', '无权重')}")
- # 验证返回的属性是否正确
+ print(f"Successfully read edge properties: {node1_id} -> {node2_id}")
+ print(
+ f"Edge relationship: {edge_props.get('relationship', 'No relationship')}"
+ )
+ print(
+ f"Edge description: {edge_props.get('description', 'No description')}"
+ )
+ print(f"Edge weight: {edge_props.get('weight', 'No weight')}")
+ # Verify that the returned properties are correct
assert (
edge_props.get("relationship") == edge_data["relationship"]
- ), "边关系不匹配"
+ ), "Edge relationship mismatch"
assert (
edge_props.get("description") == edge_data["description"]
- ), "边描述不匹配"
- assert edge_props.get("weight") == edge_data["weight"], "边权重不匹配"
+ ), "Edge description mismatch"
+ assert (
+ edge_props.get("weight") == edge_data["weight"]
+ ), "Edge weight mismatch"
else:
- print(f"读取边属性失败: {node1_id} -> {node2_id}")
- assert False, f"未能读取边属性: {node1_id} -> {node2_id}"
+ print(f"Failed to read edge properties: {node1_id} -> {node2_id}")
+ assert False, f"Failed to read edge properties: {node1_id} -> {node2_id}"
- # 5.1 验证无向图特性 - 读取反向边属性
- print(f"读取反向边属性: {node2_id} -> {node1_id}")
+ # 5.1 Verify undirected graph property - read reverse edge properties
+ print(f"Reading reverse edge properties: {node2_id} -> {node1_id}")
reverse_edge_props = await storage.get_edge(node2_id, node1_id)
if reverse_edge_props:
- print(f"成功读取反向边属性: {node2_id} -> {node1_id}")
- print(f"反向边关系: {reverse_edge_props.get('relationship', '无关系')}")
- print(f"反向边描述: {reverse_edge_props.get('description', '无描述')}")
- print(f"反向边权重: {reverse_edge_props.get('weight', '无权重')}")
- # 验证正向和反向边属性是否相同
+ print(
+ f"Successfully read reverse edge properties: {node2_id} -> {node1_id}"
+ )
+ print(
+ f"Reverse edge relationship: {reverse_edge_props.get('relationship', 'No relationship')}"
+ )
+ print(
+ f"Reverse edge description: {reverse_edge_props.get('description', 'No description')}"
+ )
+ print(
+ f"Reverse edge weight: {reverse_edge_props.get('weight', 'No weight')}"
+ )
+ # Verify that forward and reverse edge properties are the same
assert (
edge_props == reverse_edge_props
- ), "正向和反向边属性不一致,无向图特性验证失败"
- print("无向图特性验证成功:正向和反向边属性一致")
+ ), "Forward and reverse edge properties are not consistent, undirected graph property verification failed"
+ print(
+ "Undirected graph property verification successful: forward and reverse edge properties are consistent"
+ )
else:
- print(f"读取反向边属性失败: {node2_id} -> {node1_id}")
- assert (
- False
- ), f"未能读取反向边属性: {node2_id} -> {node1_id},无向图特性验证失败"
+ print(f"Failed to read reverse edge properties: {node2_id} -> {node1_id}")
+ assert False, f"Failed to read reverse edge properties: {node2_id} -> {node1_id}, undirected graph property verification failed"
- print("基本测试完成,数据已保留在数据库中")
+ print("Basic tests completed, data is preserved in the database.")
return True
except Exception as e:
- ASCIIColors.red(f"测试过程中发生错误: {str(e)}")
+ ASCIIColors.red(f"An error occurred during the test: {str(e)}")
return False
async def test_graph_advanced(storage):
"""
- 测试图数据库的高级操作:
- 1. 使用 node_degree 获取节点的度数
- 2. 使用 edge_degree 获取边的度数
- 3. 使用 get_node_edges 获取节点的所有边
- 4. 使用 get_all_labels 获取所有标签
- 5. 使用 get_knowledge_graph 获取知识图谱
- 6. 使用 delete_node 删除节点
- 7. 使用 remove_nodes 批量删除节点
- 8. 使用 remove_edges 删除边
- 9. 使用 drop 清理数据
+ Test advanced graph database operations:
+ 1. Use node_degree to get the degree of a node.
+ 2. Use edge_degree to get the degree of an edge.
+ 3. Use get_node_edges to get all edges of a node.
+ 4. Use get_all_labels to get all labels.
+ 5. Use get_knowledge_graph to get a knowledge graph.
+ 6. Use delete_node to delete a node.
+ 7. Use remove_nodes to delete multiple nodes.
+ 8. Use remove_edges to delete edges.
+ 9. Use drop to clean up data.
"""
try:
- # 1. 插入测试数据
- # 插入节点1: 人工智能
- node1_id = "人工智能"
+ # 1. Insert test data
+ # Insert node 1: Artificial Intelligence
+ node1_id = "Artificial Intelligence"
node1_data = {
"entity_id": node1_id,
- "description": "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。",
- "keywords": "AI,机器学习,深度学习",
- "entity_type": "技术领域",
+ "description": "Artificial intelligence is a branch of computer science that aims to understand the essence of intelligence and produce a new kind of intelligent machine that can react in a manner similar to human intelligence.",
+ "keywords": "AI,Machine Learning,Deep Learning",
+ "entity_type": "Technology Field",
}
- print(f"插入节点1: {node1_id}")
+ print(f"Inserting node 1: {node1_id}")
await storage.upsert_node(node1_id, node1_data)
- # 插入节点2: 机器学习
- node2_id = "机器学习"
+ # Insert node 2: Machine Learning
+ node2_id = "Machine Learning"
node2_data = {
"entity_id": node2_id,
- "description": "机器学习是人工智能的一个分支,它使用统计学方法让计算机系统在不被明确编程的情况下也能够学习。",
- "keywords": "监督学习,无监督学习,强化学习",
- "entity_type": "技术领域",
+ "description": "Machine learning is a branch of artificial intelligence that uses statistical methods to enable computer systems to learn without being explicitly programmed.",
+ "keywords": "Supervised Learning,Unsupervised Learning,Reinforcement Learning",
+ "entity_type": "Technology Field",
}
- print(f"插入节点2: {node2_id}")
+ print(f"Inserting node 2: {node2_id}")
await storage.upsert_node(node2_id, node2_data)
- # 插入节点3: 深度学习
- node3_id = "深度学习"
+ # Insert node 3: Deep Learning
+ node3_id = "Deep Learning"
node3_data = {
"entity_id": node3_id,
- "description": "深度学习是机器学习的一个分支,它使用多层神经网络来模拟人脑的学习过程。",
- "keywords": "神经网络,CNN,RNN",
- "entity_type": "技术领域",
+ "description": "Deep learning is a branch of machine learning that uses multi-layered neural networks to simulate the learning process of the human brain.",
+ "keywords": "Neural Networks,CNN,RNN",
+ "entity_type": "Technology Field",
}
- print(f"插入节点3: {node3_id}")
+ print(f"Inserting node 3: {node3_id}")
await storage.upsert_node(node3_id, node3_data)
- # 插入边1: 人工智能 -> 机器学习
+ # Insert edge 1: Artificial Intelligence -> Machine Learning
edge1_data = {
- "relationship": "包含",
+ "relationship": "includes",
"weight": 1.0,
- "description": "人工智能领域包含机器学习这个子领域",
+ "description": "The field of artificial intelligence includes the subfield of machine learning.",
}
- print(f"插入边1: {node1_id} -> {node2_id}")
+ print(f"Inserting edge 1: {node1_id} -> {node2_id}")
await storage.upsert_edge(node1_id, node2_id, edge1_data)
- # 插入边2: 机器学习 -> 深度学习
+ # Insert edge 2: Machine Learning -> Deep Learning
edge2_data = {
- "relationship": "包含",
+ "relationship": "includes",
"weight": 1.0,
- "description": "机器学习领域包含深度学习这个子领域",
+ "description": "The field of machine learning includes the subfield of deep learning.",
}
- print(f"插入边2: {node2_id} -> {node3_id}")
+ print(f"Inserting edge 2: {node2_id} -> {node3_id}")
await storage.upsert_edge(node2_id, node3_id, edge2_data)
- # 2. 测试 node_degree - 获取节点的度数
- print(f"== 测试 node_degree: {node1_id}")
+ # 2. Test node_degree - get the degree of a node
+ print(f"== Testing node_degree: {node1_id}")
node1_degree = await storage.node_degree(node1_id)
- print(f"节点 {node1_id} 的度数: {node1_degree}")
- assert node1_degree == 1, f"节点 {node1_id} 的度数应为1,实际为 {node1_degree}"
+ print(f"Degree of node {node1_id}: {node1_degree}")
+ assert (
+ node1_degree == 1
+ ), f"Degree of node {node1_id} should be 1, but got {node1_degree}"
- # 2.1 测试所有节点的度数
- print("== 测试所有节点的度数")
+ # 2.1 Test degrees of all nodes
+ print("== Testing degrees of all nodes")
node2_degree = await storage.node_degree(node2_id)
node3_degree = await storage.node_degree(node3_id)
- print(f"节点 {node2_id} 的度数: {node2_degree}")
- print(f"节点 {node3_id} 的度数: {node3_degree}")
- assert node2_degree == 2, f"节点 {node2_id} 的度数应为2,实际为 {node2_degree}"
- assert node3_degree == 1, f"节点 {node3_id} 的度数应为1,实际为 {node3_degree}"
+ print(f"Degree of node {node2_id}: {node2_degree}")
+ print(f"Degree of node {node3_id}: {node3_degree}")
+ assert (
+ node2_degree == 2
+ ), f"Degree of node {node2_id} should be 2, but got {node2_degree}"
+ assert (
+ node3_degree == 1
+ ), f"Degree of node {node3_id} should be 1, but got {node3_degree}"
- # 3. 测试 edge_degree - 获取边的度数
- print(f"== 测试 edge_degree: {node1_id} -> {node2_id}")
+ # 3. Test edge_degree - get the degree of an edge
+ print(f"== Testing edge_degree: {node1_id} -> {node2_id}")
edge_degree = await storage.edge_degree(node1_id, node2_id)
- print(f"边 {node1_id} -> {node2_id} 的度数: {edge_degree}")
+ print(f"Degree of edge {node1_id} -> {node2_id}: {edge_degree}")
assert (
edge_degree == 3
- ), f"边 {node1_id} -> {node2_id} 的度数应为3,实际为 {edge_degree}"
+ ), f"Degree of edge {node1_id} -> {node2_id} should be 3, but got {edge_degree}"
- # 3.1 测试反向边的度数 - 验证无向图特性
- print(f"== 测试反向边的度数: {node2_id} -> {node1_id}")
+ # 3.1 Test reverse edge degree - verify undirected graph property
+ print(f"== Testing reverse edge degree: {node2_id} -> {node1_id}")
reverse_edge_degree = await storage.edge_degree(node2_id, node1_id)
- print(f"反向边 {node2_id} -> {node1_id} 的度数: {reverse_edge_degree}")
+ print(f"Degree of reverse edge {node2_id} -> {node1_id}: {reverse_edge_degree}")
assert (
edge_degree == reverse_edge_degree
- ), "正向边和反向边的度数不一致,无向图特性验证失败"
- print("无向图特性验证成功:正向边和反向边的度数一致")
+ ), "Degrees of forward and reverse edges are not consistent, undirected graph property verification failed"
+ print(
+ "Undirected graph property verification successful: degrees of forward and reverse edges are consistent"
+ )
- # 4. 测试 get_node_edges - 获取节点的所有边
- print(f"== 测试 get_node_edges: {node2_id}")
+ # 4. Test get_node_edges - get all edges of a node
+ print(f"== Testing get_node_edges: {node2_id}")
node2_edges = await storage.get_node_edges(node2_id)
- print(f"节点 {node2_id} 的所有边: {node2_edges}")
+ print(f"All edges of node {node2_id}: {node2_edges}")
assert (
len(node2_edges) == 2
- ), f"节点 {node2_id} 应有2条边,实际有 {len(node2_edges)}"
+ ), f"Node {node2_id} should have 2 edges, but got {len(node2_edges)}"
- # 4.1 验证节点边的无向图特性
- print("== 验证节点边的无向图特性")
- # 检查是否包含与node1和node3的连接关系(无论方向)
+ # 4.1 Verify undirected graph property of node edges
+ print("== Verifying undirected graph property of node edges")
+ # Check if it includes connections with node1 and node3 (regardless of direction)
has_connection_with_node1 = False
has_connection_with_node3 = False
for edge in node2_edges:
- # 检查是否有与node1的连接(无论方向)
+ # Check for connection with node1 (regardless of direction)
if (edge[0] == node1_id and edge[1] == node2_id) or (
edge[0] == node2_id and edge[1] == node1_id
):
has_connection_with_node1 = True
- # 检查是否有与node3的连接(无论方向)
+ # Check for connection with node3 (regardless of direction)
if (edge[0] == node2_id and edge[1] == node3_id) or (
edge[0] == node3_id and edge[1] == node2_id
):
@@ -361,377 +388,408 @@ async def test_graph_advanced(storage):
assert (
has_connection_with_node1
- ), f"节点 {node2_id} 的边列表中应包含与 {node1_id} 的连接"
+ ), f"Edge list of node {node2_id} should include a connection with {node1_id}"
assert (
has_connection_with_node3
- ), f"节点 {node2_id} 的边列表中应包含与 {node3_id} 的连接"
- print(f"无向图特性验证成功:节点 {node2_id} 的边列表包含所有相关的边")
+ ), f"Edge list of node {node2_id} should include a connection with {node3_id}"
+ print(
+ f"Undirected graph property verification successful: edge list of node {node2_id} contains all relevant edges"
+ )
- # 5. 测试 get_all_labels - 获取所有标签
- print("== 测试 get_all_labels")
+ # 5. Test get_all_labels - get all labels
+ print("== Testing get_all_labels")
all_labels = await storage.get_all_labels()
- print(f"所有标签: {all_labels}")
- assert len(all_labels) == 3, f"应有3个标签,实际有 {len(all_labels)}"
- assert node1_id in all_labels, f"{node1_id} 应在标签列表中"
- assert node2_id in all_labels, f"{node2_id} 应在标签列表中"
- assert node3_id in all_labels, f"{node3_id} 应在标签列表中"
+ print(f"All labels: {all_labels}")
+ assert len(all_labels) == 3, f"Should have 3 labels, but got {len(all_labels)}"
+ assert node1_id in all_labels, f"{node1_id} should be in the label list"
+ assert node2_id in all_labels, f"{node2_id} should be in the label list"
+ assert node3_id in all_labels, f"{node3_id} should be in the label list"
- # 6. 测试 get_knowledge_graph - 获取知识图谱
- print("== 测试 get_knowledge_graph")
+ # 6. Test get_knowledge_graph - get a knowledge graph
+ print("== Testing get_knowledge_graph")
kg = await storage.get_knowledge_graph("*", max_depth=2, max_nodes=10)
- print(f"知识图谱节点数: {len(kg.nodes)}")
- print(f"知识图谱边数: {len(kg.edges)}")
- assert isinstance(kg, KnowledgeGraph), "返回结果应为 KnowledgeGraph 类型"
- assert len(kg.nodes) == 3, f"知识图谱应有3个节点,实际有 {len(kg.nodes)}"
- assert len(kg.edges) == 2, f"知识图谱应有2条边,实际有 {len(kg.edges)}"
+ print(f"Number of nodes in knowledge graph: {len(kg.nodes)}")
+ print(f"Number of edges in knowledge graph: {len(kg.edges)}")
+ assert isinstance(
+ kg, KnowledgeGraph
+ ), "The returned result should be of type KnowledgeGraph"
+ assert (
+ len(kg.nodes) == 3
+ ), f"The knowledge graph should have 3 nodes, but got {len(kg.nodes)}"
+ assert (
+ len(kg.edges) == 2
+ ), f"The knowledge graph should have 2 edges, but got {len(kg.edges)}"
- # 7. 测试 delete_node - 删除节点
- print(f"== 测试 delete_node: {node3_id}")
+ # 7. Test delete_node - delete a node
+ print(f"== Testing delete_node: {node3_id}")
await storage.delete_node(node3_id)
node3_props = await storage.get_node(node3_id)
- print(f"删除后查询节点属性 {node3_id}: {node3_props}")
- assert node3_props is None, f"节点 {node3_id} 应已被删除"
+ print(f"Querying node properties after deletion {node3_id}: {node3_props}")
+ assert node3_props is None, f"Node {node3_id} should have been deleted"
- # 重新插入节点3用于后续测试
+ # Re-insert node 3 for subsequent tests
await storage.upsert_node(node3_id, node3_data)
await storage.upsert_edge(node2_id, node3_id, edge2_data)
- # 8. 测试 remove_edges - 删除边
- print(f"== 测试 remove_edges: {node2_id} -> {node3_id}")
+ # 8. Test remove_edges - delete edges
+ print(f"== Testing remove_edges: {node2_id} -> {node3_id}")
await storage.remove_edges([(node2_id, node3_id)])
edge_props = await storage.get_edge(node2_id, node3_id)
- print(f"删除后查询边属性 {node2_id} -> {node3_id}: {edge_props}")
- assert edge_props is None, f"边 {node2_id} -> {node3_id} 应已被删除"
+ print(
+ f"Querying edge properties after deletion {node2_id} -> {node3_id}: {edge_props}"
+ )
+ assert (
+ edge_props is None
+ ), f"Edge {node2_id} -> {node3_id} should have been deleted"
- # 8.1 验证删除边的无向图特性
- print(f"== 验证删除边的无向图特性: {node3_id} -> {node2_id}")
+ # 8.1 Verify undirected graph property of edge deletion
+ print(
+ f"== Verifying undirected graph property of edge deletion: {node3_id} -> {node2_id}"
+ )
reverse_edge_props = await storage.get_edge(node3_id, node2_id)
- print(f"删除后查询反向边属性 {node3_id} -> {node2_id}: {reverse_edge_props}")
+ print(
+ f"Querying reverse edge properties after deletion {node3_id} -> {node2_id}: {reverse_edge_props}"
+ )
assert (
reverse_edge_props is None
- ), f"反向边 {node3_id} -> {node2_id} 也应被删除,无向图特性验证失败"
- print("无向图特性验证成功:删除一个方向的边后,反向边也被删除")
+ ), f"Reverse edge {node3_id} -> {node2_id} should also be deleted, undirected graph property verification failed"
+ print(
+ "Undirected graph property verification successful: deleting an edge in one direction also deletes the reverse edge"
+ )
- # 9. 测试 remove_nodes - 批量删除节点
- print(f"== 测试 remove_nodes: [{node2_id}, {node3_id}]")
+ # 9. Test remove_nodes - delete multiple nodes
+ print(f"== Testing remove_nodes: [{node2_id}, {node3_id}]")
await storage.remove_nodes([node2_id, node3_id])
node2_props = await storage.get_node(node2_id)
node3_props = await storage.get_node(node3_id)
- print(f"删除后查询节点属性 {node2_id}: {node2_props}")
- print(f"删除后查询节点属性 {node3_id}: {node3_props}")
- assert node2_props is None, f"节点 {node2_id} 应已被删除"
- assert node3_props is None, f"节点 {node3_id} 应已被删除"
+ print(f"Querying node properties after deletion {node2_id}: {node2_props}")
+ print(f"Querying node properties after deletion {node3_id}: {node3_props}")
+ assert node2_props is None, f"Node {node2_id} should have been deleted"
+ assert node3_props is None, f"Node {node3_id} should have been deleted"
- print("\n高级测试完成")
+ print("\nAdvanced tests completed.")
return True
except Exception as e:
- ASCIIColors.red(f"测试过程中发生错误: {str(e)}")
+ ASCIIColors.red(f"An error occurred during the test: {str(e)}")
return False
async def test_graph_batch_operations(storage):
"""
- 测试图数据库的批量操作:
- 1. 使用 get_nodes_batch 批量获取多个节点的属性
- 2. 使用 node_degrees_batch 批量获取多个节点的度数
- 3. 使用 edge_degrees_batch 批量获取多个边的度数
- 4. 使用 get_edges_batch 批量获取多个边的属性
- 5. 使用 get_nodes_edges_batch 批量获取多个节点的所有边
+ Test batch operations of the graph database:
+ 1. Use get_nodes_batch to get properties of multiple nodes in batch.
+ 2. Use node_degrees_batch to get degrees of multiple nodes in batch.
+ 3. Use edge_degrees_batch to get degrees of multiple edges in batch.
+ 4. Use get_edges_batch to get properties of multiple edges in batch.
+ 5. Use get_nodes_edges_batch to get all edges of multiple nodes in batch.
"""
try:
chunk1_id = "1"
chunk2_id = "2"
chunk3_id = "3"
- # 1. 插入测试数据
- # 插入节点1: 人工智能
- node1_id = "人工智能"
+ # 1. Insert test data
+ # Insert node 1: Artificial Intelligence
+ node1_id = "Artificial Intelligence"
node1_data = {
"entity_id": node1_id,
- "description": "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。",
- "keywords": "AI,机器学习,深度学习",
- "entity_type": "技术领域",
+ "description": "Artificial intelligence is a branch of computer science that aims to understand the essence of intelligence and produce a new kind of intelligent machine that can react in a manner similar to human intelligence.",
+ "keywords": "AI,Machine Learning,Deep Learning",
+ "entity_type": "Technology Field",
"source_id": GRAPH_FIELD_SEP.join([chunk1_id, chunk2_id]),
}
- print(f"插入节点1: {node1_id}")
+ print(f"Inserting node 1: {node1_id}")
await storage.upsert_node(node1_id, node1_data)
- # 插入节点2: 机器学习
- node2_id = "机器学习"
+ # Insert node 2: Machine Learning
+ node2_id = "Machine Learning"
node2_data = {
"entity_id": node2_id,
- "description": "机器学习是人工智能的一个分支,它使用统计学方法让计算机系统在不被明确编程的情况下也能够学习。",
- "keywords": "监督学习,无监督学习,强化学习",
- "entity_type": "技术领域",
+ "description": "Machine learning is a branch of artificial intelligence that uses statistical methods to enable computer systems to learn without being explicitly programmed.",
+ "keywords": "Supervised Learning,Unsupervised Learning,Reinforcement Learning",
+ "entity_type": "Technology Field",
"source_id": GRAPH_FIELD_SEP.join([chunk2_id, chunk3_id]),
}
- print(f"插入节点2: {node2_id}")
+ print(f"Inserting node 2: {node2_id}")
await storage.upsert_node(node2_id, node2_data)
- # 插入节点3: 深度学习
- node3_id = "深度学习"
+ # Insert node 3: Deep Learning
+ node3_id = "Deep Learning"
node3_data = {
"entity_id": node3_id,
- "description": "深度学习是机器学习的一个分支,它使用多层神经网络来模拟人脑的学习过程。",
- "keywords": "神经网络,CNN,RNN",
- "entity_type": "技术领域",
+ "description": "Deep learning is a branch of machine learning that uses multi-layered neural networks to simulate the learning process of the human brain.",
+ "keywords": "Neural Networks,CNN,RNN",
+ "entity_type": "Technology Field",
"source_id": GRAPH_FIELD_SEP.join([chunk3_id]),
}
- print(f"插入节点3: {node3_id}")
+ print(f"Inserting node 3: {node3_id}")
await storage.upsert_node(node3_id, node3_data)
- # 插入节点4: 自然语言处理
- node4_id = "自然语言处理"
+ # Insert node 4: Natural Language Processing
+ node4_id = "Natural Language Processing"
node4_data = {
"entity_id": node4_id,
- "description": "自然语言处理是人工智能的一个分支,专注于使计算机理解和处理人类语言。",
- "keywords": "NLP,文本分析,语言模型",
- "entity_type": "技术领域",
+ "description": "Natural language processing is a branch of artificial intelligence that focuses on enabling computers to understand and process human language.",
+ "keywords": "NLP,Text Analysis,Language Models",
+ "entity_type": "Technology Field",
}
- print(f"插入节点4: {node4_id}")
+ print(f"Inserting node 4: {node4_id}")
await storage.upsert_node(node4_id, node4_data)
- # 插入节点5: 计算机视觉
- node5_id = "计算机视觉"
+ # Insert node 5: Computer Vision
+ node5_id = "Computer Vision"
node5_data = {
"entity_id": node5_id,
- "description": "计算机视觉是人工智能的一个分支,专注于使计算机能够从图像或视频中获取信息。",
- "keywords": "CV,图像识别,目标检测",
- "entity_type": "技术领域",
+ "description": "Computer vision is a branch of artificial intelligence that focuses on enabling computers to gain information from images or videos.",
+ "keywords": "CV,Image Recognition,Object Detection",
+ "entity_type": "Technology Field",
}
- print(f"插入节点5: {node5_id}")
+ print(f"Inserting node 5: {node5_id}")
await storage.upsert_node(node5_id, node5_data)
- # 插入边1: 人工智能 -> 机器学习
+ # Insert edge 1: Artificial Intelligence -> Machine Learning
edge1_data = {
- "relationship": "包含",
+ "relationship": "includes",
"weight": 1.0,
- "description": "人工智能领域包含机器学习这个子领域",
+ "description": "The field of artificial intelligence includes the subfield of machine learning.",
"source_id": GRAPH_FIELD_SEP.join([chunk1_id, chunk2_id]),
}
- print(f"插入边1: {node1_id} -> {node2_id}")
+ print(f"Inserting edge 1: {node1_id} -> {node2_id}")
await storage.upsert_edge(node1_id, node2_id, edge1_data)
- # 插入边2: 机器学习 -> 深度学习
+ # Insert edge 2: Machine Learning -> Deep Learning
edge2_data = {
- "relationship": "包含",
+ "relationship": "includes",
"weight": 1.0,
- "description": "机器学习领域包含深度学习这个子领域",
+ "description": "The field of machine learning includes the subfield of deep learning.",
"source_id": GRAPH_FIELD_SEP.join([chunk2_id, chunk3_id]),
}
- print(f"插入边2: {node2_id} -> {node3_id}")
+ print(f"Inserting edge 2: {node2_id} -> {node3_id}")
await storage.upsert_edge(node2_id, node3_id, edge2_data)
- # 插入边3: 人工智能 -> 自然语言处理
+ # Insert edge 3: Artificial Intelligence -> Natural Language Processing
edge3_data = {
- "relationship": "包含",
+ "relationship": "includes",
"weight": 1.0,
- "description": "人工智能领域包含自然语言处理这个子领域",
+ "description": "The field of artificial intelligence includes the subfield of natural language processing.",
"source_id": GRAPH_FIELD_SEP.join([chunk3_id]),
}
- print(f"插入边3: {node1_id} -> {node4_id}")
+ print(f"Inserting edge 3: {node1_id} -> {node4_id}")
await storage.upsert_edge(node1_id, node4_id, edge3_data)
- # 插入边4: 人工智能 -> 计算机视觉
+ # Insert edge 4: Artificial Intelligence -> Computer Vision
edge4_data = {
- "relationship": "包含",
+ "relationship": "includes",
"weight": 1.0,
- "description": "人工智能领域包含计算机视觉这个子领域",
+ "description": "The field of artificial intelligence includes the subfield of computer vision.",
}
- print(f"插入边4: {node1_id} -> {node5_id}")
+ print(f"Inserting edge 4: {node1_id} -> {node5_id}")
await storage.upsert_edge(node1_id, node5_id, edge4_data)
- # 插入边5: 深度学习 -> 自然语言处理
+ # Insert edge 5: Deep Learning -> Natural Language Processing
edge5_data = {
- "relationship": "应用于",
+ "relationship": "applied to",
"weight": 0.8,
- "description": "深度学习技术应用于自然语言处理领域",
+ "description": "Deep learning techniques are applied in the field of natural language processing.",
}
- print(f"插入边5: {node3_id} -> {node4_id}")
+ print(f"Inserting edge 5: {node3_id} -> {node4_id}")
await storage.upsert_edge(node3_id, node4_id, edge5_data)
- # 插入边6: 深度学习 -> 计算机视觉
+ # Insert edge 6: Deep Learning -> Computer Vision
edge6_data = {
- "relationship": "应用于",
+ "relationship": "applied to",
"weight": 0.8,
- "description": "深度学习技术应用于计算机视觉领域",
+ "description": "Deep learning techniques are applied in the field of computer vision.",
}
- print(f"插入边6: {node3_id} -> {node5_id}")
+ print(f"Inserting edge 6: {node3_id} -> {node5_id}")
await storage.upsert_edge(node3_id, node5_id, edge6_data)
- # 2. 测试 get_nodes_batch - 批量获取多个节点的属性
- print("== 测试 get_nodes_batch")
+ # 2. Test get_nodes_batch - batch get properties of multiple nodes
+ print("== Testing get_nodes_batch")
node_ids = [node1_id, node2_id, node3_id]
nodes_dict = await storage.get_nodes_batch(node_ids)
- print(f"批量获取节点属性结果: {nodes_dict.keys()}")
- assert len(nodes_dict) == 3, f"应返回3个节点,实际返回 {len(nodes_dict)} 个"
- assert node1_id in nodes_dict, f"{node1_id} 应在返回结果中"
- assert node2_id in nodes_dict, f"{node2_id} 应在返回结果中"
- assert node3_id in nodes_dict, f"{node3_id} 应在返回结果中"
+ print(f"Batch get node properties result: {nodes_dict.keys()}")
+ assert len(nodes_dict) == 3, f"Should return 3 nodes, but got {len(nodes_dict)}"
+ assert node1_id in nodes_dict, f"{node1_id} should be in the result"
+ assert node2_id in nodes_dict, f"{node2_id} should be in the result"
+ assert node3_id in nodes_dict, f"{node3_id} should be in the result"
assert (
nodes_dict[node1_id]["description"] == node1_data["description"]
- ), f"{node1_id} 描述不匹配"
+ ), f"{node1_id} description mismatch"
assert (
nodes_dict[node2_id]["description"] == node2_data["description"]
- ), f"{node2_id} 描述不匹配"
+ ), f"{node2_id} description mismatch"
assert (
nodes_dict[node3_id]["description"] == node3_data["description"]
- ), f"{node3_id} 描述不匹配"
+ ), f"{node3_id} description mismatch"
- # 3. 测试 node_degrees_batch - 批量获取多个节点的度数
- print("== 测试 node_degrees_batch")
+ # 3. Test node_degrees_batch - batch get degrees of multiple nodes
+ print("== Testing node_degrees_batch")
node_degrees = await storage.node_degrees_batch(node_ids)
- print(f"批量获取节点度数结果: {node_degrees}")
+ print(f"Batch get node degrees result: {node_degrees}")
assert (
len(node_degrees) == 3
- ), f"应返回3个节点的度数,实际返回 {len(node_degrees)} 个"
- assert node1_id in node_degrees, f"{node1_id} 应在返回结果中"
- assert node2_id in node_degrees, f"{node2_id} 应在返回结果中"
- assert node3_id in node_degrees, f"{node3_id} 应在返回结果中"
+ ), f"Should return degrees of 3 nodes, but got {len(node_degrees)}"
+ assert node1_id in node_degrees, f"{node1_id} should be in the result"
+ assert node2_id in node_degrees, f"{node2_id} should be in the result"
+ assert node3_id in node_degrees, f"{node3_id} should be in the result"
assert (
node_degrees[node1_id] == 3
- ), f"{node1_id} 度数应为3,实际为 {node_degrees[node1_id]}"
+ ), f"Degree of {node1_id} should be 3, but got {node_degrees[node1_id]}"
assert (
node_degrees[node2_id] == 2
- ), f"{node2_id} 度数应为2,实际为 {node_degrees[node2_id]}"
+ ), f"Degree of {node2_id} should be 2, but got {node_degrees[node2_id]}"
assert (
node_degrees[node3_id] == 3
- ), f"{node3_id} 度数应为3,实际为 {node_degrees[node3_id]}"
+ ), f"Degree of {node3_id} should be 3, but got {node_degrees[node3_id]}"
- # 4. 测试 edge_degrees_batch - 批量获取多个边的度数
- print("== 测试 edge_degrees_batch")
+ # 4. Test edge_degrees_batch - batch get degrees of multiple edges
+ print("== Testing edge_degrees_batch")
edges = [(node1_id, node2_id), (node2_id, node3_id), (node3_id, node4_id)]
edge_degrees = await storage.edge_degrees_batch(edges)
- print(f"批量获取边度数结果: {edge_degrees}")
+ print(f"Batch get edge degrees result: {edge_degrees}")
assert (
len(edge_degrees) == 3
- ), f"应返回3条边的度数,实际返回 {len(edge_degrees)} 条"
+ ), f"Should return degrees of 3 edges, but got {len(edge_degrees)}"
assert (
node1_id,
node2_id,
- ) in edge_degrees, f"边 {node1_id} -> {node2_id} 应在返回结果中"
+ ) in edge_degrees, f"Edge {node1_id} -> {node2_id} should be in the result"
assert (
node2_id,
node3_id,
- ) in edge_degrees, f"边 {node2_id} -> {node3_id} 应在返回结果中"
+ ) in edge_degrees, f"Edge {node2_id} -> {node3_id} should be in the result"
assert (
node3_id,
node4_id,
- ) in edge_degrees, f"边 {node3_id} -> {node4_id} 应在返回结果中"
- # 验证边的度数是否正确(源节点度数 + 目标节点度数)
+ ) in edge_degrees, f"Edge {node3_id} -> {node4_id} should be in the result"
+ # Verify edge degrees (sum of source and target node degrees)
assert (
edge_degrees[(node1_id, node2_id)] == 5
- ), f"边 {node1_id} -> {node2_id} 度数应为5,实际为 {edge_degrees[(node1_id, node2_id)]}"
+ ), f"Degree of edge {node1_id} -> {node2_id} should be 5, but got {edge_degrees[(node1_id, node2_id)]}"
assert (
edge_degrees[(node2_id, node3_id)] == 5
- ), f"边 {node2_id} -> {node3_id} 度数应为5,实际为 {edge_degrees[(node2_id, node3_id)]}"
+ ), f"Degree of edge {node2_id} -> {node3_id} should be 5, but got {edge_degrees[(node2_id, node3_id)]}"
assert (
edge_degrees[(node3_id, node4_id)] == 5
- ), f"边 {node3_id} -> {node4_id} 度数应为5,实际为 {edge_degrees[(node3_id, node4_id)]}"
+ ), f"Degree of edge {node3_id} -> {node4_id} should be 5, but got {edge_degrees[(node3_id, node4_id)]}"
- # 5. 测试 get_edges_batch - 批量获取多个边的属性
- print("== 测试 get_edges_batch")
- # 将元组列表转换为Neo4j风格的字典列表
+ # 5. Test get_edges_batch - batch get properties of multiple edges
+ print("== Testing get_edges_batch")
+        # Convert the list of tuples to a Neo4j-style list of dicts
edge_dicts = [{"src": src, "tgt": tgt} for src, tgt in edges]
edges_dict = await storage.get_edges_batch(edge_dicts)
- print(f"批量获取边属性结果: {edges_dict.keys()}")
- assert len(edges_dict) == 3, f"应返回3条边的属性,实际返回 {len(edges_dict)} 条"
+ print(f"Batch get edge properties result: {edges_dict.keys()}")
+ assert (
+ len(edges_dict) == 3
+ ), f"Should return properties of 3 edges, but got {len(edges_dict)}"
assert (
node1_id,
node2_id,
- ) in edges_dict, f"边 {node1_id} -> {node2_id} 应在返回结果中"
+ ) in edges_dict, f"Edge {node1_id} -> {node2_id} should be in the result"
assert (
node2_id,
node3_id,
- ) in edges_dict, f"边 {node2_id} -> {node3_id} 应在返回结果中"
+ ) in edges_dict, f"Edge {node2_id} -> {node3_id} should be in the result"
assert (
node3_id,
node4_id,
- ) in edges_dict, f"边 {node3_id} -> {node4_id} 应在返回结果中"
+ ) in edges_dict, f"Edge {node3_id} -> {node4_id} should be in the result"
assert (
edges_dict[(node1_id, node2_id)]["relationship"]
== edge1_data["relationship"]
- ), f"边 {node1_id} -> {node2_id} 关系不匹配"
+ ), f"Edge {node1_id} -> {node2_id} relationship mismatch"
assert (
edges_dict[(node2_id, node3_id)]["relationship"]
== edge2_data["relationship"]
- ), f"边 {node2_id} -> {node3_id} 关系不匹配"
+ ), f"Edge {node2_id} -> {node3_id} relationship mismatch"
assert (
edges_dict[(node3_id, node4_id)]["relationship"]
== edge5_data["relationship"]
- ), f"边 {node3_id} -> {node4_id} 关系不匹配"
+ ), f"Edge {node3_id} -> {node4_id} relationship mismatch"
- # 5.1 测试反向边的批量获取 - 验证无向图特性
- print("== 测试反向边的批量获取")
- # 创建反向边的字典列表
+ # 5.1 Test batch get of reverse edges - verify undirected property
+ print("== Testing batch get of reverse edges")
+ # Create list of dicts for reverse edges
reverse_edge_dicts = [{"src": tgt, "tgt": src} for src, tgt in edges]
reverse_edges_dict = await storage.get_edges_batch(reverse_edge_dicts)
- print(f"批量获取反向边属性结果: {reverse_edges_dict.keys()}")
+ print(f"Batch get reverse edge properties result: {reverse_edges_dict.keys()}")
assert (
len(reverse_edges_dict) == 3
- ), f"应返回3条反向边的属性,实际返回 {len(reverse_edges_dict)} 条"
+ ), f"Should return properties of 3 reverse edges, but got {len(reverse_edges_dict)}"
- # 验证正向和反向边的属性是否一致
+ # Verify that properties of forward and reverse edges are consistent
for (src, tgt), props in edges_dict.items():
assert (
- tgt,
- src,
- ) in reverse_edges_dict, f"反向边 {tgt} -> {src} 应在返回结果中"
+ (
+ tgt,
+ src,
+ )
+ in reverse_edges_dict
+ ), f"Reverse edge {tgt} -> {src} should be in the result"
assert (
props == reverse_edges_dict[(tgt, src)]
- ), f"边 {src} -> {tgt} 和反向边 {tgt} -> {src} 的属性不一致"
+ ), f"Properties of edge {src} -> {tgt} and reverse edge {tgt} -> {src} are inconsistent"
- print("无向图特性验证成功:批量获取的正向和反向边属性一致")
+ print(
+ "Undirected graph property verification successful: properties of batch-retrieved forward and reverse edges are consistent"
+ )
- # 6. 测试 get_nodes_edges_batch - 批量获取多个节点的所有边
- print("== 测试 get_nodes_edges_batch")
+ # 6. Test get_nodes_edges_batch - batch get all edges of multiple nodes
+ print("== Testing get_nodes_edges_batch")
nodes_edges = await storage.get_nodes_edges_batch([node1_id, node3_id])
- print(f"批量获取节点边结果: {nodes_edges.keys()}")
+ print(f"Batch get node edges result: {nodes_edges.keys()}")
assert (
len(nodes_edges) == 2
- ), f"应返回2个节点的边,实际返回 {len(nodes_edges)} 个"
- assert node1_id in nodes_edges, f"{node1_id} 应在返回结果中"
- assert node3_id in nodes_edges, f"{node3_id} 应在返回结果中"
+ ), f"Should return edges for 2 nodes, but got {len(nodes_edges)}"
+ assert node1_id in nodes_edges, f"{node1_id} should be in the result"
+ assert node3_id in nodes_edges, f"{node3_id} should be in the result"
assert (
len(nodes_edges[node1_id]) == 3
- ), f"{node1_id} 应有3条边,实际有 {len(nodes_edges[node1_id])} 条"
+ ), f"{node1_id} should have 3 edges, but has {len(nodes_edges[node1_id])}"
assert (
len(nodes_edges[node3_id]) == 3
- ), f"{node3_id} 应有3条边,实际有 {len(nodes_edges[node3_id])} 条"
+ ), f"{node3_id} should have 3 edges, but has {len(nodes_edges[node3_id])}"
- # 6.1 验证批量获取节点边的无向图特性
- print("== 验证批量获取节点边的无向图特性")
+ # 6.1 Verify undirected property of batch-retrieved node edges
+ print("== Verifying undirected property of batch-retrieved node edges")
- # 检查节点1的边是否包含所有相关的边(无论方向)
+ # Check if node 1's edges include all relevant edges (regardless of direction)
node1_outgoing_edges = [
(src, tgt) for src, tgt in nodes_edges[node1_id] if src == node1_id
]
node1_incoming_edges = [
(src, tgt) for src, tgt in nodes_edges[node1_id] if tgt == node1_id
]
- print(f"节点 {node1_id} 的出边: {node1_outgoing_edges}")
- print(f"节点 {node1_id} 的入边: {node1_incoming_edges}")
+ print(f"Outgoing edges of node {node1_id}: {node1_outgoing_edges}")
+ print(f"Incoming edges of node {node1_id}: {node1_incoming_edges}")
- # 检查是否包含到机器学习、自然语言处理和计算机视觉的边
+ # Check for edges to Machine Learning, Natural Language Processing, and Computer Vision
has_edge_to_node2 = any(tgt == node2_id for _, tgt in node1_outgoing_edges)
has_edge_to_node4 = any(tgt == node4_id for _, tgt in node1_outgoing_edges)
has_edge_to_node5 = any(tgt == node5_id for _, tgt in node1_outgoing_edges)
- assert has_edge_to_node2, f"节点 {node1_id} 的边列表中应包含到 {node2_id} 的边"
- assert has_edge_to_node4, f"节点 {node1_id} 的边列表中应包含到 {node4_id} 的边"
- assert has_edge_to_node5, f"节点 {node1_id} 的边列表中应包含到 {node5_id} 的边"
+ assert (
+ has_edge_to_node2
+ ), f"Edge list of node {node1_id} should include an edge to {node2_id}"
+ assert (
+ has_edge_to_node4
+ ), f"Edge list of node {node1_id} should include an edge to {node4_id}"
+ assert (
+ has_edge_to_node5
+ ), f"Edge list of node {node1_id} should include an edge to {node5_id}"
- # 检查节点3的边是否包含所有相关的边(无论方向)
+ # Check if node 3's edges include all relevant edges (regardless of direction)
node3_outgoing_edges = [
(src, tgt) for src, tgt in nodes_edges[node3_id] if src == node3_id
]
node3_incoming_edges = [
(src, tgt) for src, tgt in nodes_edges[node3_id] if tgt == node3_id
]
- print(f"节点 {node3_id} 的出边: {node3_outgoing_edges}")
- print(f"节点 {node3_id} 的入边: {node3_incoming_edges}")
+ print(f"Outgoing edges of node {node3_id}: {node3_outgoing_edges}")
+ print(f"Incoming edges of node {node3_id}: {node3_incoming_edges}")
- # 检查是否包含与机器学习、自然语言处理和计算机视觉的连接(忽略方向)
+ # Check for connections with Machine Learning, Natural Language Processing, and Computer Vision (ignoring direction)
has_connection_with_node2 = any(
(src == node2_id and tgt == node3_id)
or (src == node3_id and tgt == node2_id)
@@ -750,155 +808,89 @@ async def test_graph_batch_operations(storage):
assert (
has_connection_with_node2
- ), f"节点 {node3_id} 的边列表中应包含与 {node2_id} 的连接"
+ ), f"Edge list of node {node3_id} should include a connection with {node2_id}"
assert (
has_connection_with_node4
- ), f"节点 {node3_id} 的边列表中应包含与 {node4_id} 的连接"
+ ), f"Edge list of node {node3_id} should include a connection with {node4_id}"
assert (
has_connection_with_node5
- ), f"节点 {node3_id} 的边列表中应包含与 {node5_id} 的连接"
+ ), f"Edge list of node {node3_id} should include a connection with {node5_id}"
- print("无向图特性验证成功:批量获取的节点边包含所有相关的边(无论方向)")
-
- # 7. 测试 get_nodes_by_chunk_ids - 批量根据 chunk_ids 获取多个节点
- print("== 测试 get_nodes_by_chunk_ids")
-
- print("== 测试单个 chunk_id,匹配多个节点")
- nodes = await storage.get_nodes_by_chunk_ids([chunk2_id])
- assert len(nodes) == 2, f"{chunk1_id} 应有2个节点,实际有 {len(nodes)} 个"
-
- has_node1 = any(node["entity_id"] == node1_id for node in nodes)
- has_node2 = any(node["entity_id"] == node2_id for node in nodes)
-
- assert has_node1, f"节点 {node1_id} 应在返回结果中"
- assert has_node2, f"节点 {node2_id} 应在返回结果中"
-
- print("== 测试多个 chunk_id,部分匹配多个节点")
- nodes = await storage.get_nodes_by_chunk_ids([chunk2_id, chunk3_id])
- assert (
- len(nodes) == 3
- ), f"{chunk2_id}, {chunk3_id} 应有3个节点,实际有 {len(nodes)} 个"
-
- has_node1 = any(node["entity_id"] == node1_id for node in nodes)
- has_node2 = any(node["entity_id"] == node2_id for node in nodes)
- has_node3 = any(node["entity_id"] == node3_id for node in nodes)
-
- assert has_node1, f"节点 {node1_id} 应在返回结果中"
- assert has_node2, f"节点 {node2_id} 应在返回结果中"
- assert has_node3, f"节点 {node3_id} 应在返回结果中"
-
- # 8. 测试 get_edges_by_chunk_ids - 批量根据 chunk_ids 获取多条边
- print("== 测试 get_edges_by_chunk_ids")
-
- print("== 测试单个 chunk_id,匹配多条边")
- edges = await storage.get_edges_by_chunk_ids([chunk2_id])
- assert len(edges) == 2, f"{chunk2_id} 应有2条边,实际有 {len(edges)} 条"
-
- has_edge_node1_node2 = any(
- edge["source"] == node1_id and edge["target"] == node2_id for edge in edges
- )
- has_edge_node2_node3 = any(
- edge["source"] == node2_id and edge["target"] == node3_id for edge in edges
+ print(
+ "Undirected graph property verification successful: batch-retrieved node edges include all relevant edges (regardless of direction)"
)
- assert has_edge_node1_node2, f"{chunk2_id} 应包含 {node1_id} 到 {node2_id} 的边"
- assert has_edge_node2_node3, f"{chunk2_id} 应包含 {node2_id} 到 {node3_id} 的边"
-
- print("== 测试多个 chunk_id,部分匹配多条边")
- edges = await storage.get_edges_by_chunk_ids([chunk2_id, chunk3_id])
- assert (
- len(edges) == 3
- ), f"{chunk2_id}, {chunk3_id} 应有3条边,实际有 {len(edges)} 条"
-
- has_edge_node1_node2 = any(
- edge["source"] == node1_id and edge["target"] == node2_id for edge in edges
- )
- has_edge_node2_node3 = any(
- edge["source"] == node2_id and edge["target"] == node3_id for edge in edges
- )
- has_edge_node1_node4 = any(
- edge["source"] == node1_id and edge["target"] == node4_id for edge in edges
- )
-
- assert (
- has_edge_node1_node2
- ), f"{chunk2_id}, {chunk3_id} 应包含 {node1_id} 到 {node2_id} 的边"
- assert (
- has_edge_node2_node3
- ), f"{chunk2_id}, {chunk3_id} 应包含 {node2_id} 到 {node3_id} 的边"
- assert (
- has_edge_node1_node4
- ), f"{chunk2_id}, {chunk3_id} 应包含 {node1_id} 到 {node4_id} 的边"
-
- print("\n批量操作测试完成")
+ print("\nBatch operations tests completed.")
return True
except Exception as e:
- ASCIIColors.red(f"测试过程中发生错误: {str(e)}")
+ ASCIIColors.red(f"An error occurred during the test: {str(e)}")
return False
async def test_graph_special_characters(storage):
"""
- 测试图数据库对特殊字符的处理:
- 1. 测试节点名称和描述中包含单引号、双引号和反斜杠
- 2. 测试边的描述中包含单引号、双引号和反斜杠
- 3. 验证特殊字符是否被正确保存和检索
+ Test the graph database's handling of special characters:
+ 1. Test node names and descriptions containing single quotes, double quotes, and backslashes.
+ 2. Test edge descriptions containing single quotes, double quotes, and backslashes.
+ 3. Verify that special characters are saved and retrieved correctly.
"""
try:
- # 1. 测试节点名称中的特殊字符
- node1_id = "包含'单引号'的节点"
+ # 1. Test special characters in node name
+ node1_id = "Node with 'single quotes'"
node1_data = {
"entity_id": node1_id,
- "description": "这个描述包含'单引号'、\"双引号\"和\\反斜杠",
- "keywords": "特殊字符,引号,转义",
- "entity_type": "测试节点",
+ "description": "This description contains 'single quotes', \"double quotes\", and \\backslashes",
+ "keywords": "special characters,quotes,escaping",
+ "entity_type": "Test Node",
}
- print(f"插入包含特殊字符的节点1: {node1_id}")
+ print(f"Inserting node with special characters 1: {node1_id}")
await storage.upsert_node(node1_id, node1_data)
- # 2. 测试节点名称中的双引号
- node2_id = '包含"双引号"的节点'
+ # 2. Test double quotes in node name
+ node2_id = 'Node with "double quotes"'
node2_data = {
"entity_id": node2_id,
- "description": "这个描述同时包含'单引号'和\"双引号\"以及\\反斜杠\\路径",
- "keywords": "特殊字符,引号,JSON",
- "entity_type": "测试节点",
+ "description": "This description contains both 'single quotes' and \"double quotes\" and \\a\\path",
+ "keywords": "special characters,quotes,JSON",
+ "entity_type": "Test Node",
}
- print(f"插入包含特殊字符的节点2: {node2_id}")
+ print(f"Inserting node with special characters 2: {node2_id}")
await storage.upsert_node(node2_id, node2_data)
- # 3. 测试节点名称中的反斜杠
- node3_id = "包含\\反斜杠\\的节点"
+ # 3. Test backslashes in node name
+ node3_id = "Node with \\backslashes\\"
node3_data = {
"entity_id": node3_id,
- "description": "这个描述包含Windows路径C:\\Program Files\\和转义字符\\n\\t",
- "keywords": "反斜杠,路径,转义",
- "entity_type": "测试节点",
+ "description": "This description contains a Windows path C:\\Program Files\\ and escape characters \\n\\t",
+ "keywords": "backslashes,paths,escaping",
+ "entity_type": "Test Node",
}
- print(f"插入包含特殊字符的节点3: {node3_id}")
+ print(f"Inserting node with special characters 3: {node3_id}")
await storage.upsert_node(node3_id, node3_data)
- # 4. 测试边描述中的特殊字符
+ # 4. Test special characters in edge description
edge1_data = {
- "relationship": "特殊'关系'",
+ "relationship": "special 'relationship'",
"weight": 1.0,
- "description": "这个边描述包含'单引号'、\"双引号\"和\\反斜杠",
+ "description": "This edge description contains 'single quotes', \"double quotes\", and \\backslashes",
}
- print(f"插入包含特殊字符的边: {node1_id} -> {node2_id}")
+ print(f"Inserting edge with special characters: {node1_id} -> {node2_id}")
await storage.upsert_edge(node1_id, node2_id, edge1_data)
- # 5. 测试边描述中的更复杂特殊字符组合
+ # 5. Test more complex combination of special characters in edge description
edge2_data = {
- "relationship": '复杂"关系"\\类型',
+ "relationship": 'complex "relationship"\\type',
"weight": 0.8,
- "description": "包含SQL注入尝试: SELECT * FROM users WHERE name='admin'--",
+ "description": "Contains SQL injection attempt: SELECT * FROM users WHERE name='admin'--",
}
- print(f"插入包含复杂特殊字符的边: {node2_id} -> {node3_id}")
+ print(
+ f"Inserting edge with complex special characters: {node2_id} -> {node3_id}"
+ )
await storage.upsert_edge(node2_id, node3_id, edge2_data)
- # 6. 验证节点特殊字符是否正确保存
- print("\n== 验证节点特殊字符")
+ # 6. Verify that node special characters are saved correctly
+ print("\n== Verifying node special characters")
for node_id, original_data in [
(node1_id, node1_data),
(node2_id, node2_data),
@@ -906,196 +898,226 @@ async def test_graph_special_characters(storage):
]:
node_props = await storage.get_node(node_id)
if node_props:
- print(f"成功读取节点: {node_id}")
- print(f"节点描述: {node_props.get('description', '无描述')}")
+ print(f"Successfully read node: {node_id}")
+ print(
+ f"Node description: {node_props.get('description', 'No description')}"
+ )
- # 验证节点ID是否正确保存
+ # Verify node ID is saved correctly
assert (
node_props.get("entity_id") == node_id
- ), f"节点ID不匹配: 期望 {node_id}, 实际 {node_props.get('entity_id')}"
+ ), f"Node ID mismatch: expected {node_id}, got {node_props.get('entity_id')}"
- # 验证描述是否正确保存
+ # Verify description is saved correctly
assert (
node_props.get("description") == original_data["description"]
- ), f"节点描述不匹配: 期望 {original_data['description']}, 实际 {node_props.get('description')}"
+ ), f"Node description mismatch: expected {original_data['description']}, got {node_props.get('description')}"
- print(f"节点 {node_id} 特殊字符验证成功")
+ print(f"Node {node_id} special character verification successful")
else:
- print(f"读取节点属性失败: {node_id}")
- assert False, f"未能读取节点属性: {node_id}"
+ print(f"Failed to read node properties: {node_id}")
+ assert False, f"Failed to read node properties: {node_id}"
- # 7. 验证边特殊字符是否正确保存
- print("\n== 验证边特殊字符")
+ # 7. Verify that edge special characters are saved correctly
+ print("\n== Verifying edge special characters")
edge1_props = await storage.get_edge(node1_id, node2_id)
if edge1_props:
- print(f"成功读取边: {node1_id} -> {node2_id}")
- print(f"边关系: {edge1_props.get('relationship', '无关系')}")
- print(f"边描述: {edge1_props.get('description', '无描述')}")
+ print(f"Successfully read edge: {node1_id} -> {node2_id}")
+ print(
+ f"Edge relationship: {edge1_props.get('relationship', 'No relationship')}"
+ )
+ print(
+ f"Edge description: {edge1_props.get('description', 'No description')}"
+ )
- # 验证边关系是否正确保存
+ # Verify edge relationship is saved correctly
assert (
edge1_props.get("relationship") == edge1_data["relationship"]
- ), f"边关系不匹配: 期望 {edge1_data['relationship']}, 实际 {edge1_props.get('relationship')}"
+ ), f"Edge relationship mismatch: expected {edge1_data['relationship']}, got {edge1_props.get('relationship')}"
- # 验证边描述是否正确保存
+ # Verify edge description is saved correctly
assert (
edge1_props.get("description") == edge1_data["description"]
- ), f"边描述不匹配: 期望 {edge1_data['description']}, 实际 {edge1_props.get('description')}"
+ ), f"Edge description mismatch: expected {edge1_data['description']}, got {edge1_props.get('description')}"
- print(f"边 {node1_id} -> {node2_id} 特殊字符验证成功")
+ print(
+ f"Edge {node1_id} -> {node2_id} special character verification successful"
+ )
else:
- print(f"读取边属性失败: {node1_id} -> {node2_id}")
- assert False, f"未能读取边属性: {node1_id} -> {node2_id}"
+ print(f"Failed to read edge properties: {node1_id} -> {node2_id}")
+ assert False, f"Failed to read edge properties: {node1_id} -> {node2_id}"
edge2_props = await storage.get_edge(node2_id, node3_id)
if edge2_props:
- print(f"成功读取边: {node2_id} -> {node3_id}")
- print(f"边关系: {edge2_props.get('relationship', '无关系')}")
- print(f"边描述: {edge2_props.get('description', '无描述')}")
+ print(f"Successfully read edge: {node2_id} -> {node3_id}")
+ print(
+ f"Edge relationship: {edge2_props.get('relationship', 'No relationship')}"
+ )
+ print(
+ f"Edge description: {edge2_props.get('description', 'No description')}"
+ )
- # 验证边关系是否正确保存
+ # Verify edge relationship is saved correctly
assert (
edge2_props.get("relationship") == edge2_data["relationship"]
- ), f"边关系不匹配: 期望 {edge2_data['relationship']}, 实际 {edge2_props.get('relationship')}"
+ ), f"Edge relationship mismatch: expected {edge2_data['relationship']}, got {edge2_props.get('relationship')}"
- # 验证边描述是否正确保存
+ # Verify edge description is saved correctly
assert (
edge2_props.get("description") == edge2_data["description"]
- ), f"边描述不匹配: 期望 {edge2_data['description']}, 实际 {edge2_props.get('description')}"
+ ), f"Edge description mismatch: expected {edge2_data['description']}, got {edge2_props.get('description')}"
- print(f"边 {node2_id} -> {node3_id} 特殊字符验证成功")
+ print(
+ f"Edge {node2_id} -> {node3_id} special character verification successful"
+ )
else:
- print(f"读取边属性失败: {node2_id} -> {node3_id}")
- assert False, f"未能读取边属性: {node2_id} -> {node3_id}"
+ print(f"Failed to read edge properties: {node2_id} -> {node3_id}")
+ assert False, f"Failed to read edge properties: {node2_id} -> {node3_id}"
- print("\n特殊字符测试完成,数据已保留在数据库中")
+ print("\nSpecial character tests completed, data is preserved in the database.")
return True
except Exception as e:
- ASCIIColors.red(f"测试过程中发生错误: {str(e)}")
+ ASCIIColors.red(f"An error occurred during the test: {str(e)}")
return False
async def test_graph_undirected_property(storage):
"""
- 专门测试图存储的无向图特性:
- 1. 验证插入一个方向的边后,反向查询是否能获得相同的结果
- 2. 验证边的属性在正向和反向查询中是否一致
- 3. 验证删除一个方向的边后,另一个方向的边是否也被删除
- 4. 验证批量操作中的无向图特性
+ Specifically test the undirected graph property of the storage:
+ 1. Verify that after inserting an edge in one direction, a reverse query can retrieve the same result.
+ 2. Verify that edge properties are consistent in forward and reverse queries.
+ 3. Verify that after deleting an edge in one direction, the edge in the other direction is also deleted.
+ 4. Verify the undirected property in batch operations.
"""
try:
- # 1. 插入测试数据
- # 插入节点1: 计算机科学
- node1_id = "计算机科学"
+ # 1. Insert test data
+ # Insert node 1: Computer Science
+ node1_id = "Computer Science"
node1_data = {
"entity_id": node1_id,
- "description": "计算机科学是研究计算机及其应用的科学。",
- "keywords": "计算机,科学,技术",
- "entity_type": "学科",
+ "description": "Computer science is the study of computers and their applications.",
+ "keywords": "computer,science,technology",
+ "entity_type": "Discipline",
}
- print(f"插入节点1: {node1_id}")
+ print(f"Inserting node 1: {node1_id}")
await storage.upsert_node(node1_id, node1_data)
- # 插入节点2: 数据结构
- node2_id = "数据结构"
+ # Insert node 2: Data Structures
+ node2_id = "Data Structures"
node2_data = {
"entity_id": node2_id,
- "description": "数据结构是计算机科学中的一个基础概念,用于组织和存储数据。",
- "keywords": "数据,结构,组织",
- "entity_type": "概念",
+ "description": "A data structure is a fundamental concept in computer science used to organize and store data.",
+ "keywords": "data,structure,organization",
+ "entity_type": "Concept",
}
- print(f"插入节点2: {node2_id}")
+ print(f"Inserting node 2: {node2_id}")
await storage.upsert_node(node2_id, node2_data)
- # 插入节点3: 算法
- node3_id = "算法"
+ # Insert node 3: Algorithms
+ node3_id = "Algorithms"
node3_data = {
"entity_id": node3_id,
- "description": "算法是解决问题的步骤和方法。",
- "keywords": "算法,步骤,方法",
- "entity_type": "概念",
+ "description": "An algorithm is a set of steps and methods for solving problems.",
+ "keywords": "algorithm,steps,methods",
+ "entity_type": "Concept",
}
- print(f"插入节点3: {node3_id}")
+ print(f"Inserting node 3: {node3_id}")
await storage.upsert_node(node3_id, node3_data)
- # 2. 测试插入边后的无向图特性
- print("\n== 测试插入边后的无向图特性")
+ # 2. Test undirected property after edge insertion
+ print("\n== Testing undirected property after edge insertion")
- # 插入边1: 计算机科学 -> 数据结构
+ # Insert edge 1: Computer Science -> Data Structures
edge1_data = {
- "relationship": "包含",
+ "relationship": "includes",
"weight": 1.0,
- "description": "计算机科学包含数据结构这个概念",
+ "description": "Computer science includes the concept of data structures.",
}
- print(f"插入边1: {node1_id} -> {node2_id}")
+ print(f"Inserting edge 1: {node1_id} -> {node2_id}")
await storage.upsert_edge(node1_id, node2_id, edge1_data)
- # 验证正向查询
+ # Verify forward query
forward_edge = await storage.get_edge(node1_id, node2_id)
- print(f"正向边属性: {forward_edge}")
- assert forward_edge is not None, f"未能读取正向边属性: {node1_id} -> {node2_id}"
+ print(f"Forward edge properties: {forward_edge}")
+ assert (
+ forward_edge is not None
+ ), f"Failed to read forward edge properties: {node1_id} -> {node2_id}"
- # 验证反向查询
+ # Verify reverse query
reverse_edge = await storage.get_edge(node2_id, node1_id)
- print(f"反向边属性: {reverse_edge}")
- assert reverse_edge is not None, f"未能读取反向边属性: {node2_id} -> {node1_id}"
+ print(f"Reverse edge properties: {reverse_edge}")
+ assert (
+ reverse_edge is not None
+ ), f"Failed to read reverse edge properties: {node2_id} -> {node1_id}"
- # 验证正向和反向边属性是否一致
+ # Verify that forward and reverse edge properties are consistent
assert (
forward_edge == reverse_edge
- ), "正向和反向边属性不一致,无向图特性验证失败"
- print("无向图特性验证成功:正向和反向边属性一致")
+ ), "Forward and reverse edge properties are inconsistent, undirected property verification failed"
+ print(
+ "Undirected property verification successful: forward and reverse edge properties are consistent"
+ )
- # 3. 测试边的度数的无向图特性
- print("\n== 测试边的度数的无向图特性")
+ # 3. Test undirected property of edge degree
+ print("\n== Testing undirected property of edge degree")
- # 插入边2: 计算机科学 -> 算法
+ # Insert edge 2: Computer Science -> Algorithms
edge2_data = {
- "relationship": "包含",
+ "relationship": "includes",
"weight": 1.0,
- "description": "计算机科学包含算法这个概念",
+ "description": "Computer science includes the concept of algorithms.",
}
- print(f"插入边2: {node1_id} -> {node3_id}")
+ print(f"Inserting edge 2: {node1_id} -> {node3_id}")
await storage.upsert_edge(node1_id, node3_id, edge2_data)
- # 验证正向和反向边的度数
+ # Verify degrees of forward and reverse edges
forward_degree = await storage.edge_degree(node1_id, node2_id)
reverse_degree = await storage.edge_degree(node2_id, node1_id)
- print(f"正向边 {node1_id} -> {node2_id} 的度数: {forward_degree}")
- print(f"反向边 {node2_id} -> {node1_id} 的度数: {reverse_degree}")
+ print(f"Degree of forward edge {node1_id} -> {node2_id}: {forward_degree}")
+ print(f"Degree of reverse edge {node2_id} -> {node1_id}: {reverse_degree}")
assert (
forward_degree == reverse_degree
- ), "正向和反向边的度数不一致,无向图特性验证失败"
- print("无向图特性验证成功:正向和反向边的度数一致")
+ ), "Degrees of forward and reverse edges are inconsistent, undirected property verification failed"
+ print(
+ "Undirected property verification successful: degrees of forward and reverse edges are consistent"
+ )
- # 4. 测试删除边的无向图特性
- print("\n== 测试删除边的无向图特性")
+ # 4. Test undirected property of edge deletion
+ print("\n== Testing undirected property of edge deletion")
- # 删除正向边
- print(f"删除边: {node1_id} -> {node2_id}")
+ # Delete forward edge
+ print(f"Deleting edge: {node1_id} -> {node2_id}")
await storage.remove_edges([(node1_id, node2_id)])
- # 验证正向边是否被删除
+ # Verify forward edge is deleted
forward_edge = await storage.get_edge(node1_id, node2_id)
- print(f"删除后查询正向边属性 {node1_id} -> {node2_id}: {forward_edge}")
- assert forward_edge is None, f"边 {node1_id} -> {node2_id} 应已被删除"
+ print(
+ f"Querying forward edge properties after deletion {node1_id} -> {node2_id}: {forward_edge}"
+ )
+ assert (
+ forward_edge is None
+ ), f"Edge {node1_id} -> {node2_id} should have been deleted"
- # 验证反向边是否也被删除
+ # Verify reverse edge is also deleted
reverse_edge = await storage.get_edge(node2_id, node1_id)
- print(f"删除后查询反向边属性 {node2_id} -> {node1_id}: {reverse_edge}")
+ print(
+ f"Querying reverse edge properties after deletion {node2_id} -> {node1_id}: {reverse_edge}"
+ )
assert (
reverse_edge is None
- ), f"反向边 {node2_id} -> {node1_id} 也应被删除,无向图特性验证失败"
- print("无向图特性验证成功:删除一个方向的边后,反向边也被删除")
+ ), f"Reverse edge {node2_id} -> {node1_id} should also be deleted, undirected property verification failed"
+ print(
+ "Undirected property verification successful: deleting an edge in one direction also deletes the reverse edge"
+ )
- # 5. 测试批量操作中的无向图特性
- print("\n== 测试批量操作中的无向图特性")
+ # 5. Test undirected property in batch operations
+ print("\n== Testing undirected property in batch operations")
- # 重新插入边
+ # Re-insert edge
await storage.upsert_edge(node1_id, node2_id, edge1_data)
- # 批量获取边属性
+ # Batch get edge properties
edge_dicts = [
{"src": node1_id, "tgt": node2_id},
{"src": node1_id, "tgt": node3_id},
@@ -1108,32 +1130,37 @@ async def test_graph_undirected_property(storage):
edges_dict = await storage.get_edges_batch(edge_dicts)
reverse_edges_dict = await storage.get_edges_batch(reverse_edge_dicts)
- print(f"批量获取正向边属性结果: {edges_dict.keys()}")
- print(f"批量获取反向边属性结果: {reverse_edges_dict.keys()}")
+ print(f"Batch get forward edge properties result: {edges_dict.keys()}")
+ print(f"Batch get reverse edge properties result: {reverse_edges_dict.keys()}")
- # 验证正向和反向边的属性是否一致
+ # Verify that properties of forward and reverse edges are consistent
for (src, tgt), props in edges_dict.items():
assert (
- tgt,
- src,
- ) in reverse_edges_dict, f"反向边 {tgt} -> {src} 应在返回结果中"
+ (
+ tgt,
+ src,
+ )
+ in reverse_edges_dict
+ ), f"Reverse edge {tgt} -> {src} should be in the result"
assert (
props == reverse_edges_dict[(tgt, src)]
- ), f"边 {src} -> {tgt} 和反向边 {tgt} -> {src} 的属性不一致"
+ ), f"Properties of edge {src} -> {tgt} and reverse edge {tgt} -> {src} are inconsistent"
- print("无向图特性验证成功:批量获取的正向和反向边属性一致")
+ print(
+ "Undirected property verification successful: properties of batch-retrieved forward and reverse edges are consistent"
+ )
- # 6. 测试批量获取节点边的无向图特性
- print("\n== 测试批量获取节点边的无向图特性")
+ # 6. Test undirected property of batch-retrieved node edges
+ print("\n== Testing undirected property of batch-retrieved node edges")
nodes_edges = await storage.get_nodes_edges_batch([node1_id, node2_id])
- print(f"批量获取节点边结果: {nodes_edges.keys()}")
+ print(f"Batch get node edges result: {nodes_edges.keys()}")
- # 检查节点1的边是否包含所有相关的边(无论方向)
+ # Check if node 1's edges include all relevant edges (regardless of direction)
node1_edges = nodes_edges[node1_id]
node2_edges = nodes_edges[node2_id]
- # 检查节点1是否有到节点2和节点3的边
+ # Check if node 1 has edges to node 2 and node 3
has_edge_to_node2 = any(
(src == node1_id and tgt == node2_id) for src, tgt in node1_edges
)
@@ -1141,10 +1168,14 @@ async def test_graph_undirected_property(storage):
(src == node1_id and tgt == node3_id) for src, tgt in node1_edges
)
- assert has_edge_to_node2, f"节点 {node1_id} 的边列表中应包含到 {node2_id} 的边"
- assert has_edge_to_node3, f"节点 {node1_id} 的边列表中应包含到 {node3_id} 的边"
+ assert (
+ has_edge_to_node2
+ ), f"Edge list of node {node1_id} should include an edge to {node2_id}"
+ assert (
+ has_edge_to_node3
+ ), f"Edge list of node {node1_id} should include an edge to {node3_id}"
- # 检查节点2是否有到节点1的边
+ # Check if node 2 has a connection with node 1
has_edge_to_node1 = any(
(src == node2_id and tgt == node1_id)
or (src == node1_id and tgt == node2_id)
@@ -1152,64 +1183,76 @@ async def test_graph_undirected_property(storage):
)
assert (
has_edge_to_node1
- ), f"节点 {node2_id} 的边列表中应包含与 {node1_id} 的连接"
+ ), f"Edge list of node {node2_id} should include a connection with {node1_id}"
- print("无向图特性验证成功:批量获取的节点边包含所有相关的边(无论方向)")
+ print(
+ "Undirected property verification successful: batch-retrieved node edges include all relevant edges (regardless of direction)"
+ )
- print("\n无向图特性测试完成")
+ print("\nUndirected property tests completed.")
return True
except Exception as e:
- ASCIIColors.red(f"测试过程中发生错误: {str(e)}")
+ ASCIIColors.red(f"An error occurred during the test: {str(e)}")
return False
async def main():
- """主函数"""
- # 显示程序标题
+ """Main function"""
+ # Display program title
ASCIIColors.cyan("""
╔══════════════════════════════════════════════════════════════╗
- ║ 通用图存储测试程序 ║
+ ║ General Graph Storage Test Program ║
╚══════════════════════════════════════════════════════════════╝
""")
- # 检查.env文件
+ # Check for .env file
if not check_env_file():
return
- # 加载环境变量
+ # Load environment variables
load_dotenv(dotenv_path=".env", override=False)
- # 获取图存储类型
+ # Get graph storage type
graph_storage_type = os.getenv("LIGHTRAG_GRAPH_STORAGE", "NetworkXStorage")
- ASCIIColors.magenta(f"\n当前配置的图存储类型: {graph_storage_type}")
+ ASCIIColors.magenta(
+ f"\nCurrently configured graph storage type: {graph_storage_type}"
+ )
ASCIIColors.white(
- f"支持的图存储类型: {', '.join(STORAGE_IMPLEMENTATIONS['GRAPH_STORAGE']['implementations'])}"
+ f"Supported graph storage types: {', '.join(STORAGE_IMPLEMENTATIONS['GRAPH_STORAGE']['implementations'])}"
)
- # 初始化存储实例
+ # Initialize storage instance
storage = await initialize_graph_storage()
if not storage:
- ASCIIColors.red("初始化存储实例失败,测试程序退出")
+ ASCIIColors.red("Failed to initialize storage instance, exiting test program.")
return
try:
- # 显示测试选项
- ASCIIColors.yellow("\n请选择测试类型:")
- ASCIIColors.white("1. 基本测试 (节点和边的插入、读取)")
- ASCIIColors.white("2. 高级测试 (度数、标签、知识图谱、删除操作等)")
- ASCIIColors.white("3. 批量操作测试 (批量获取节点、边属性和度数等)")
- ASCIIColors.white("4. 无向图特性测试 (验证存储的无向图特性)")
- ASCIIColors.white("5. 特殊字符测试 (验证单引号、双引号和反斜杠等特殊字符)")
- ASCIIColors.white("6. 全部测试")
+ # Display test options
+ ASCIIColors.yellow("\nPlease select a test type:")
+ ASCIIColors.white("1. Basic Test (Node and edge insertion, reading)")
+ ASCIIColors.white(
+ "2. Advanced Test (Degree, labels, knowledge graph, deletion, etc.)"
+ )
+ ASCIIColors.white(
+ "3. Batch Operations Test (Batch get node/edge properties, degrees, etc.)"
+ )
+ ASCIIColors.white(
+ "4. Undirected Property Test (Verify undirected properties of the storage)"
+ )
+ ASCIIColors.white(
+ "5. Special Characters Test (Verify handling of single/double quotes, backslashes, etc.)"
+ )
+ ASCIIColors.white("6. All Tests")
- choice = input("\n请输入选项 (1/2/3/4/5/6): ")
+ choice = input("\nEnter your choice (1/2/3/4/5/6): ")
- # 在执行测试前清理数据
+ # Clean data before running tests
if choice in ["1", "2", "3", "4", "5", "6"]:
- ASCIIColors.yellow("\n执行测试前清理数据...")
+ ASCIIColors.yellow("\nCleaning data before running tests...")
await storage.drop()
- ASCIIColors.green("数据清理完成\n")
+ ASCIIColors.green("Data cleanup complete\n")
if choice == "1":
await test_graph_basic(storage)
@@ -1222,34 +1265,36 @@ async def main():
elif choice == "5":
await test_graph_special_characters(storage)
elif choice == "6":
- ASCIIColors.cyan("\n=== 开始基本测试 ===")
+ ASCIIColors.cyan("\n=== Starting Basic Test ===")
basic_result = await test_graph_basic(storage)
if basic_result:
- ASCIIColors.cyan("\n=== 开始高级测试 ===")
+ ASCIIColors.cyan("\n=== Starting Advanced Test ===")
advanced_result = await test_graph_advanced(storage)
if advanced_result:
- ASCIIColors.cyan("\n=== 开始批量操作测试 ===")
+ ASCIIColors.cyan("\n=== Starting Batch Operations Test ===")
batch_result = await test_graph_batch_operations(storage)
if batch_result:
- ASCIIColors.cyan("\n=== 开始无向图特性测试 ===")
+ ASCIIColors.cyan("\n=== Starting Undirected Property Test ===")
undirected_result = await test_graph_undirected_property(
storage
)
if undirected_result:
- ASCIIColors.cyan("\n=== 开始特殊字符测试 ===")
+ ASCIIColors.cyan(
+ "\n=== Starting Special Characters Test ==="
+ )
await test_graph_special_characters(storage)
else:
- ASCIIColors.red("无效的选项")
+ ASCIIColors.red("Invalid choice")
finally:
- # 关闭连接
+ # Close connection
if storage:
await storage.finalize()
- ASCIIColors.green("\n存储连接已关闭")
+ ASCIIColors.green("\nStorage connection closed.")
if __name__ == "__main__":