Merge branch 'main' into add-Memgraph-graph-db

This commit is contained in:
yangdx 2025-07-04 23:53:07 +08:00
commit a69194c079
28 changed files with 3042 additions and 793 deletions

View file

@ -58,6 +58,8 @@ SUMMARY_LANGUAGE=English
# FORCE_LLM_SUMMARY_ON_MERGE=6 # FORCE_LLM_SUMMARY_ON_MERGE=6
### Max tokens for entity/relations description after merge ### Max tokens for entity/relations description after merge
# MAX_TOKEN_SUMMARY=500 # MAX_TOKEN_SUMMARY=500
### Maximum number of entity extraction attempts for ambiguous content
# MAX_GLEANING=1
### Number of parallel processing documents(Less than MAX_ASYNC/2 is recommended) ### Number of parallel processing documents(Less than MAX_ASYNC/2 is recommended)
# MAX_PARALLEL_INSERT=2 # MAX_PARALLEL_INSERT=2
@ -112,15 +114,6 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
# LIGHTRAG_DOC_STATUS_STORAGE=PGDocStatusStorage # LIGHTRAG_DOC_STATUS_STORAGE=PGDocStatusStorage
# LIGHTRAG_GRAPH_STORAGE=Neo4JStorage # LIGHTRAG_GRAPH_STORAGE=Neo4JStorage
### TiDB Configuration (Deprecated)
# TIDB_HOST=localhost
# TIDB_PORT=4000
# TIDB_USER=your_username
# TIDB_PASSWORD='your_password'
# TIDB_DATABASE=your_database
### separating all data from difference Lightrag instances(deprecating)
# TIDB_WORKSPACE=default
### PostgreSQL Configuration ### PostgreSQL Configuration
POSTGRES_HOST=localhost POSTGRES_HOST=localhost
POSTGRES_PORT=5432 POSTGRES_PORT=5432
@ -128,7 +121,7 @@ POSTGRES_USER=your_username
POSTGRES_PASSWORD='your_password' POSTGRES_PASSWORD='your_password'
POSTGRES_DATABASE=your_database POSTGRES_DATABASE=your_database
POSTGRES_MAX_CONNECTIONS=12 POSTGRES_MAX_CONNECTIONS=12
### separating all data from difference Lightrag instances(deprecating) ### separating all data from difference Lightrag instances
# POSTGRES_WORKSPACE=default # POSTGRES_WORKSPACE=default
### Neo4j Configuration ### Neo4j Configuration
@ -144,14 +137,15 @@ NEO4J_PASSWORD='your_password'
# AGE_POSTGRES_PORT=8529 # AGE_POSTGRES_PORT=8529
# AGE Graph Name(apply to PostgreSQL and independent AGM) # AGE Graph Name(apply to PostgreSQL and independent AGM)
### AGE_GRAPH_NAME is precated ### AGE_GRAPH_NAME is deprecated
# AGE_GRAPH_NAME=lightrag # AGE_GRAPH_NAME=lightrag
### MongoDB Configuration ### MongoDB Configuration
MONGO_URI=mongodb://root:root@localhost:27017/ MONGO_URI=mongodb://root:root@localhost:27017/
MONGO_DATABASE=LightRAG MONGO_DATABASE=LightRAG
### separating all data from difference Lightrag instances(deprecating) ### separating all data from difference Lightrag instances(deprecating)
# MONGODB_GRAPH=false ### separating all data from difference Lightrag instances
# MONGODB_WORKSPACE=default
### Milvus Configuration ### Milvus Configuration
MILVUS_URI=http://localhost:19530 MILVUS_URI=http://localhost:19530

View file

@ -11,9 +11,74 @@ This example shows how to:
import os import os
import argparse import argparse
import asyncio import asyncio
import logging
import logging.config
from pathlib import Path
# Add project root directory to Python path
import sys
sys.path.append(str(Path(__file__).parent.parent))
from lightrag.llm.openai import openai_complete_if_cache, openai_embed from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc from lightrag.utils import EmbeddingFunc, logger, set_verbose_debug
from raganything.raganything import RAGAnything from raganything import RAGAnything, RAGAnythingConfig
def configure_logging():
"""Configure logging for the application"""
# Get log directory path from environment variable or use current directory
log_dir = os.getenv("LOG_DIR", os.getcwd())
log_file_path = os.path.abspath(os.path.join(log_dir, "raganything_example.log"))
print(f"\nRAGAnything example log file: {log_file_path}\n")
os.makedirs(os.path.dirname(log_dir), exist_ok=True)
# Get log file max size and backup count from environment variables
log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB
log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups
logging.config.dictConfig(
{
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"default": {
"format": "%(levelname)s: %(message)s",
},
"detailed": {
"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
},
},
"handlers": {
"console": {
"formatter": "default",
"class": "logging.StreamHandler",
"stream": "ext://sys.stderr",
},
"file": {
"formatter": "detailed",
"class": "logging.handlers.RotatingFileHandler",
"filename": log_file_path,
"maxBytes": log_max_bytes,
"backupCount": log_backup_count,
"encoding": "utf-8",
},
},
"loggers": {
"lightrag": {
"handlers": ["console", "file"],
"level": "INFO",
"propagate": False,
},
},
}
)
# Set the logger level to INFO
logger.setLevel(logging.INFO)
# Enable verbose debug if needed
set_verbose_debug(os.getenv("VERBOSE", "false").lower() == "true")
async def process_with_rag( async def process_with_rag(
@ -31,15 +96,21 @@ async def process_with_rag(
output_dir: Output directory for RAG results output_dir: Output directory for RAG results
api_key: OpenAI API key api_key: OpenAI API key
base_url: Optional base URL for API base_url: Optional base URL for API
working_dir: Working directory for RAG storage
""" """
try: try:
# Initialize RAGAnything # Create RAGAnything configuration
rag = RAGAnything( config = RAGAnythingConfig(
working_dir=working_dir, working_dir=working_dir or "./rag_storage",
llm_model_func=lambda prompt, mineru_parse_method="auto",
system_prompt=None, enable_image_processing=True,
history_messages=[], enable_table_processing=True,
**kwargs: openai_complete_if_cache( enable_equation_processing=True,
)
# Define LLM model function
def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
return openai_complete_if_cache(
"gpt-4o-mini", "gpt-4o-mini",
prompt, prompt,
system_prompt=system_prompt, system_prompt=system_prompt,
@ -47,81 +118,123 @@ async def process_with_rag(
api_key=api_key, api_key=api_key,
base_url=base_url, base_url=base_url,
**kwargs, **kwargs,
),
vision_model_func=lambda prompt,
system_prompt=None,
history_messages=[],
image_data=None,
**kwargs: openai_complete_if_cache(
"gpt-4o",
"",
system_prompt=None,
history_messages=[],
messages=[
{"role": "system", "content": system_prompt}
if system_prompt
else None,
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_data}"
},
},
],
}
if image_data
else {"role": "user", "content": prompt},
],
api_key=api_key,
base_url=base_url,
**kwargs,
) )
if image_data
else openai_complete_if_cache( # Define vision model function for image processing
"gpt-4o-mini", def vision_model_func(
prompt, prompt, system_prompt=None, history_messages=[], image_data=None, **kwargs
system_prompt=system_prompt, ):
history_messages=history_messages, if image_data:
api_key=api_key, return openai_complete_if_cache(
base_url=base_url, "gpt-4o",
**kwargs, "",
), system_prompt=None,
embedding_func=EmbeddingFunc( history_messages=[],
embedding_dim=3072, messages=[
max_token_size=8192, {"role": "system", "content": system_prompt}
func=lambda texts: openai_embed( if system_prompt
texts, else None,
model="text-embedding-3-large", {
"role": "user",
"content": [
{"type": "text", "text": prompt},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_data}"
},
},
],
}
if image_data
else {"role": "user", "content": prompt},
],
api_key=api_key, api_key=api_key,
base_url=base_url, base_url=base_url,
), **kwargs,
)
else:
return llm_model_func(prompt, system_prompt, history_messages, **kwargs)
# Define embedding function
embedding_func = EmbeddingFunc(
embedding_dim=3072,
max_token_size=8192,
func=lambda texts: openai_embed(
texts,
model="text-embedding-3-large",
api_key=api_key,
base_url=base_url,
), ),
) )
# Initialize RAGAnything with new dataclass structure
rag = RAGAnything(
config=config,
llm_model_func=llm_model_func,
vision_model_func=vision_model_func,
embedding_func=embedding_func,
)
# Process document # Process document
await rag.process_document_complete( await rag.process_document_complete(
file_path=file_path, output_dir=output_dir, parse_method="auto" file_path=file_path, output_dir=output_dir, parse_method="auto"
) )
# Example queries # Example queries - demonstrating different query approaches
queries = [ logger.info("\nQuerying processed document:")
# 1. Pure text queries using aquery()
text_queries = [
"What is the main content of the document?", "What is the main content of the document?",
"Describe the images and figures in the document", "What are the key topics discussed?",
"Tell me about the experimental results and data tables",
] ]
print("\nQuerying processed document:") for query in text_queries:
for query in queries: logger.info(f"\n[Text Query]: {query}")
print(f"\nQuery: {query}") result = await rag.aquery(query, mode="hybrid")
result = await rag.query_with_multimodal(query, mode="hybrid") logger.info(f"Answer: {result}")
print(f"Answer: {result}")
# 2. Multimodal query with specific multimodal content using aquery_with_multimodal()
logger.info(
"\n[Multimodal Query]: Analyzing performance data in context of document"
)
multimodal_result = await rag.aquery_with_multimodal(
"Compare this performance data with any similar results mentioned in the document",
multimodal_content=[
{
"type": "table",
"table_data": """Method,Accuracy,Processing_Time
RAGAnything,95.2%,120ms
Traditional_RAG,87.3%,180ms
Baseline,82.1%,200ms""",
"table_caption": "Performance comparison results",
}
],
mode="hybrid",
)
logger.info(f"Answer: {multimodal_result}")
# 3. Another multimodal query with equation content
logger.info("\n[Multimodal Query]: Mathematical formula analysis")
equation_result = await rag.aquery_with_multimodal(
"Explain this formula and relate it to any mathematical concepts in the document",
multimodal_content=[
{
"type": "equation",
"latex": "F1 = 2 \\cdot \\frac{precision \\cdot recall}{precision + recall}",
"equation_caption": "F1-score calculation formula",
}
],
mode="hybrid",
)
logger.info(f"Answer: {equation_result}")
except Exception as e: except Exception as e:
print(f"Error processing with RAG: {str(e)}") logger.error(f"Error processing with RAG: {str(e)}")
import traceback
logger.error(traceback.format_exc())
def main(): def main():
@ -135,12 +248,20 @@ def main():
"--output", "-o", default="./output", help="Output directory path" "--output", "-o", default="./output", help="Output directory path"
) )
parser.add_argument( parser.add_argument(
"--api-key", required=True, help="OpenAI API key for RAG processing" "--api-key",
default=os.getenv("OPENAI_API_KEY"),
help="OpenAI API key (defaults to OPENAI_API_KEY env var)",
) )
parser.add_argument("--base-url", help="Optional base URL for API") parser.add_argument("--base-url", help="Optional base URL for API")
args = parser.parse_args() args = parser.parse_args()
# Check if API key is provided
if not args.api_key:
logger.error("Error: OpenAI API key is required")
logger.error("Set OPENAI_API_KEY environment variable or use --api-key option")
return
# Create output directory if specified # Create output directory if specified
if args.output: if args.output:
os.makedirs(args.output, exist_ok=True) os.makedirs(args.output, exist_ok=True)
@ -154,4 +275,12 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
# Configure logging first
configure_logging()
print("RAGAnything Example")
print("=" * 30)
print("Processing document with multimodal RAG pipeline")
print("=" * 30)
main() main()

View file

@ -52,18 +52,23 @@ async def copy_from_postgres_to_json():
embedding_func=None, embedding_func=None,
) )
# Get all cache data using the new flattened structure
all_data = await from_llm_response_cache.get_all()
# Convert flattened data to hierarchical structure for JsonKVStorage
kv = {} kv = {}
for c_id in await from_llm_response_cache.all_keys(): for flattened_key, cache_entry in all_data.items():
print(f"Copying {c_id}") # Parse flattened key: {mode}:{cache_type}:{hash}
workspace = c_id["workspace"] parts = flattened_key.split(":", 2)
mode = c_id["mode"] if len(parts) == 3:
_id = c_id["id"] mode, cache_type, hash_value = parts
postgres_db.workspace = workspace if mode not in kv:
obj = await from_llm_response_cache.get_by_mode_and_id(mode, _id) kv[mode] = {}
if mode not in kv: kv[mode][hash_value] = cache_entry
kv[mode] = {} print(f"Copying {flattened_key} -> {mode}[{hash_value}]")
kv[mode][_id] = obj[_id] else:
print(f"Object {obj}") print(f"Skipping invalid key format: {flattened_key}")
await to_llm_response_cache.upsert(kv) await to_llm_response_cache.upsert(kv)
await to_llm_response_cache.index_done_callback() await to_llm_response_cache.index_done_callback()
print("Mission accomplished!") print("Mission accomplished!")
@ -85,13 +90,24 @@ async def copy_from_json_to_postgres():
db=postgres_db, db=postgres_db,
) )
for mode in await from_llm_response_cache.all_keys(): # Get all cache data from JsonKVStorage (hierarchical structure)
print(f"Copying {mode}") all_data = await from_llm_response_cache.get_all()
caches = await from_llm_response_cache.get_by_id(mode)
for k, v in caches.items(): # Convert hierarchical data to flattened structure for PGKVStorage
item = {mode: {k: v}} flattened_data = {}
print(f"\tCopying {item}") for mode, mode_data in all_data.items():
await to_llm_response_cache.upsert(item) print(f"Processing mode: {mode}")
for hash_value, cache_entry in mode_data.items():
# Determine cache_type from cache entry or use default
cache_type = cache_entry.get("cache_type", "extract")
# Create flattened key: {mode}:{cache_type}:{hash}
flattened_key = f"{mode}:{cache_type}:{hash_value}"
flattened_data[flattened_key] = cache_entry
print(f"\tConverting {mode}[{hash_value}] -> {flattened_key}")
# Upsert the flattened data
await to_llm_response_cache.upsert(flattened_data)
print("Mission accomplished!")
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -1 +1 @@
__api_version__ = "0176" __api_version__ = "0178"

View file

@ -62,6 +62,51 @@ router = APIRouter(
temp_prefix = "__tmp__" temp_prefix = "__tmp__"
def sanitize_filename(filename: str, input_dir: Path) -> str:
"""
Sanitize uploaded filename to prevent Path Traversal attacks.
Args:
filename: The original filename from the upload
input_dir: The target input directory
Returns:
str: Sanitized filename that is safe to use
Raises:
HTTPException: If the filename is unsafe or invalid
"""
# Basic validation
if not filename or not filename.strip():
raise HTTPException(status_code=400, detail="Filename cannot be empty")
# Remove path separators and traversal sequences
clean_name = filename.replace("/", "").replace("\\", "")
clean_name = clean_name.replace("..", "")
# Remove control characters and null bytes
clean_name = "".join(c for c in clean_name if ord(c) >= 32 and c != "\x7f")
# Remove leading/trailing whitespace and dots
clean_name = clean_name.strip().strip(".")
# Check if anything is left after sanitization
if not clean_name:
raise HTTPException(
status_code=400, detail="Invalid filename after sanitization"
)
# Verify the final path stays within the input directory
try:
final_path = (input_dir / clean_name).resolve()
if not final_path.is_relative_to(input_dir.resolve()):
raise HTTPException(status_code=400, detail="Unsafe filename detected")
except (OSError, ValueError):
raise HTTPException(status_code=400, detail="Invalid filename")
return clean_name
class ScanResponse(BaseModel): class ScanResponse(BaseModel):
"""Response model for document scanning operation """Response model for document scanning operation
@ -783,7 +828,7 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
try: try:
new_files = doc_manager.scan_directory_for_new_files() new_files = doc_manager.scan_directory_for_new_files()
total_files = len(new_files) total_files = len(new_files)
logger.info(f"Found {total_files} new files to index.") logger.info(f"Found {total_files} files to index.")
if not new_files: if not new_files:
return return
@ -816,8 +861,13 @@ async def background_delete_documents(
successful_deletions = [] successful_deletions = []
failed_deletions = [] failed_deletions = []
# Set pipeline status to busy for deletion # Double-check pipeline status before proceeding
async with pipeline_status_lock: async with pipeline_status_lock:
if pipeline_status.get("busy", False):
logger.warning("Error: Unexpected pipeline busy state, aborting deletion.")
return # Abort deletion operation
# Set pipeline status to busy for deletion
pipeline_status.update( pipeline_status.update(
{ {
"busy": True, "busy": True,
@ -926,13 +976,26 @@ async def background_delete_documents(
async with pipeline_status_lock: async with pipeline_status_lock:
pipeline_status["history_messages"].append(error_msg) pipeline_status["history_messages"].append(error_msg)
finally: finally:
# Final summary # Final summary and check for pending requests
async with pipeline_status_lock: async with pipeline_status_lock:
pipeline_status["busy"] = False pipeline_status["busy"] = False
completion_msg = f"Deletion completed: {len(successful_deletions)} successful, {len(failed_deletions)} failed" completion_msg = f"Deletion completed: {len(successful_deletions)} successful, {len(failed_deletions)} failed"
pipeline_status["latest_message"] = completion_msg pipeline_status["latest_message"] = completion_msg
pipeline_status["history_messages"].append(completion_msg) pipeline_status["history_messages"].append(completion_msg)
# Check if there are pending document indexing requests
has_pending_request = pipeline_status.get("request_pending", False)
# If there are pending requests, start document processing pipeline
if has_pending_request:
try:
logger.info(
"Processing pending document indexing requests after deletion"
)
await rag.apipeline_process_enqueue_documents()
except Exception as e:
logger.error(f"Error processing pending documents after deletion: {e}")
def create_document_routes( def create_document_routes(
rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None
@ -986,18 +1049,21 @@ def create_document_routes(
HTTPException: If the file type is not supported (400) or other errors occur (500). HTTPException: If the file type is not supported (400) or other errors occur (500).
""" """
try: try:
if not doc_manager.is_supported_file(file.filename): # Sanitize filename to prevent Path Traversal attacks
safe_filename = sanitize_filename(file.filename, doc_manager.input_dir)
if not doc_manager.is_supported_file(safe_filename):
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}", detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}",
) )
file_path = doc_manager.input_dir / file.filename file_path = doc_manager.input_dir / safe_filename
# Check if file already exists # Check if file already exists
if file_path.exists(): if file_path.exists():
return InsertResponse( return InsertResponse(
status="duplicated", status="duplicated",
message=f"File '{file.filename}' already exists in the input directory.", message=f"File '{safe_filename}' already exists in the input directory.",
) )
with open(file_path, "wb") as buffer: with open(file_path, "wb") as buffer:
@ -1008,7 +1074,7 @@ def create_document_routes(
return InsertResponse( return InsertResponse(
status="success", status="success",
message=f"File '{file.filename}' uploaded successfully. Processing will continue in background.", message=f"File '{safe_filename}' uploaded successfully. Processing will continue in background.",
) )
except Exception as e: except Exception as e:
logger.error(f"Error /documents/upload: {file.filename}: {str(e)}") logger.error(f"Error /documents/upload: {file.filename}: {str(e)}")

View file

@ -234,7 +234,7 @@ class OllamaAPI:
@self.router.get("/version", dependencies=[Depends(combined_auth)]) @self.router.get("/version", dependencies=[Depends(combined_auth)])
async def get_version(): async def get_version():
"""Get Ollama version information""" """Get Ollama version information"""
return OllamaVersionResponse(version="0.5.4") return OllamaVersionResponse(version="0.9.3")
@self.router.get("/tags", dependencies=[Depends(combined_auth)]) @self.router.get("/tags", dependencies=[Depends(combined_auth)])
async def get_tags(): async def get_tags():
@ -244,9 +244,9 @@ class OllamaAPI:
{ {
"name": self.ollama_server_infos.LIGHTRAG_MODEL, "name": self.ollama_server_infos.LIGHTRAG_MODEL,
"model": self.ollama_server_infos.LIGHTRAG_MODEL, "model": self.ollama_server_infos.LIGHTRAG_MODEL,
"modified_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"size": self.ollama_server_infos.LIGHTRAG_SIZE, "size": self.ollama_server_infos.LIGHTRAG_SIZE,
"digest": self.ollama_server_infos.LIGHTRAG_DIGEST, "digest": self.ollama_server_infos.LIGHTRAG_DIGEST,
"modified_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"details": { "details": {
"parent_model": "", "parent_model": "",
"format": "gguf", "format": "gguf",
@ -337,7 +337,10 @@ class OllamaAPI:
data = { data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL, "model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": "",
"done": True, "done": True,
"done_reason": "stop",
"context": [],
"total_duration": total_time, "total_duration": total_time,
"load_duration": 0, "load_duration": 0,
"prompt_eval_count": prompt_tokens, "prompt_eval_count": prompt_tokens,
@ -377,6 +380,7 @@ class OllamaAPI:
"model": self.ollama_server_infos.LIGHTRAG_MODEL, "model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": f"\n\nError: {error_msg}", "response": f"\n\nError: {error_msg}",
"error": f"\n\nError: {error_msg}",
"done": False, "done": False,
} }
yield f"{json.dumps(error_data, ensure_ascii=False)}\n" yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
@ -385,6 +389,7 @@ class OllamaAPI:
final_data = { final_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL, "model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": "",
"done": True, "done": True,
} }
yield f"{json.dumps(final_data, ensure_ascii=False)}\n" yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
@ -399,7 +404,10 @@ class OllamaAPI:
data = { data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL, "model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": "",
"done": True, "done": True,
"done_reason": "stop",
"context": [],
"total_duration": total_time, "total_duration": total_time,
"load_duration": 0, "load_duration": 0,
"prompt_eval_count": prompt_tokens, "prompt_eval_count": prompt_tokens,
@ -444,6 +452,8 @@ class OllamaAPI:
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": str(response_text), "response": str(response_text),
"done": True, "done": True,
"done_reason": "stop",
"context": [],
"total_duration": total_time, "total_duration": total_time,
"load_duration": 0, "load_duration": 0,
"prompt_eval_count": prompt_tokens, "prompt_eval_count": prompt_tokens,
@ -557,6 +567,12 @@ class OllamaAPI:
data = { data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL, "model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": "",
"images": None,
},
"done_reason": "stop",
"done": True, "done": True,
"total_duration": total_time, "total_duration": total_time,
"load_duration": 0, "load_duration": 0,
@ -605,6 +621,7 @@ class OllamaAPI:
"content": f"\n\nError: {error_msg}", "content": f"\n\nError: {error_msg}",
"images": None, "images": None,
}, },
"error": f"\n\nError: {error_msg}",
"done": False, "done": False,
} }
yield f"{json.dumps(error_data, ensure_ascii=False)}\n" yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
@ -613,6 +630,11 @@ class OllamaAPI:
final_data = { final_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL, "model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": "",
"images": None,
},
"done": True, "done": True,
} }
yield f"{json.dumps(final_data, ensure_ascii=False)}\n" yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
@ -633,6 +655,7 @@ class OllamaAPI:
"content": "", "content": "",
"images": None, "images": None,
}, },
"done_reason": "stop",
"done": True, "done": True,
"total_duration": total_time, "total_duration": total_time,
"load_duration": 0, "load_duration": 0,
@ -697,6 +720,7 @@ class OllamaAPI:
"content": str(response_text), "content": str(response_text),
"images": None, "images": None,
}, },
"done_reason": "stop",
"done": True, "done": True,
"total_duration": total_time, "total_duration": total_time,
"load_duration": 0, "load_duration": 0,

View file

@ -183,6 +183,9 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
if isinstance(response, str): if isinstance(response, str):
# If it's a string, send it all at once # If it's a string, send it all at once
yield f"{json.dumps({'response': response})}\n" yield f"{json.dumps({'response': response})}\n"
elif response is None:
# Handle None response (e.g., when only_need_context=True but no context found)
yield f"{json.dumps({'response': 'No relevant context found for the query.'})}\n"
else: else:
# If it's an async generator, send chunks one by one # If it's an async generator, send chunks one by one
try: try:

View file

@ -297,6 +297,8 @@ class BaseKVStorage(StorageNameSpace, ABC):
@dataclass @dataclass
class BaseGraphStorage(StorageNameSpace, ABC): class BaseGraphStorage(StorageNameSpace, ABC):
"""All operations related to edges in graph should be undirected."""
embedding_func: EmbeddingFunc embedding_func: EmbeddingFunc
@abstractmethod @abstractmethod
@ -468,17 +470,6 @@ class BaseGraphStorage(StorageNameSpace, ABC):
list[dict]: A list of nodes, where each node is a dictionary of its properties. list[dict]: A list of nodes, where each node is a dictionary of its properties.
An empty list if no matching nodes are found. An empty list if no matching nodes are found.
""" """
# Default implementation iterates through all nodes, which is inefficient.
# This method should be overridden by subclasses for better performance.
all_nodes = []
all_labels = await self.get_all_labels()
for label in all_labels:
node = await self.get_node(label)
if node and "source_id" in node:
source_ids = set(node["source_id"].split(GRAPH_FIELD_SEP))
if not source_ids.isdisjoint(chunk_ids):
all_nodes.append(node)
return all_nodes
@abstractmethod @abstractmethod
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
@ -643,6 +634,8 @@ class DocProcessingStatus:
"""ISO format timestamp when document was last updated""" """ISO format timestamp when document was last updated"""
chunks_count: int | None = None chunks_count: int | None = None
"""Number of chunks after splitting, used for processing""" """Number of chunks after splitting, used for processing"""
chunks_list: list[str] | None = field(default_factory=list)
"""List of chunk IDs associated with this document, used for deletion"""
error: str | None = None error: str | None = None
"""Error message if failed""" """Error message if failed"""
metadata: dict[str, Any] = field(default_factory=dict) metadata: dict[str, Any] = field(default_factory=dict)

View file

@ -7,6 +7,7 @@ consistency and makes maintenance easier.
""" """
# Default values for environment variables # Default values for environment variables
DEFAULT_MAX_GLEANING = 1
DEFAULT_MAX_TOKEN_SUMMARY = 500 DEFAULT_MAX_TOKEN_SUMMARY = 500
DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE = 6 DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE = 6
DEFAULT_WOKERS = 2 DEFAULT_WOKERS = 2

View file

@ -26,11 +26,11 @@ STORAGE_IMPLEMENTATIONS = {
"implementations": [ "implementations": [
"NanoVectorDBStorage", "NanoVectorDBStorage",
"MilvusVectorDBStorage", "MilvusVectorDBStorage",
"ChromaVectorDBStorage",
"PGVectorStorage", "PGVectorStorage",
"FaissVectorDBStorage", "FaissVectorDBStorage",
"QdrantVectorDBStorage", "QdrantVectorDBStorage",
"MongoVectorDBStorage", "MongoVectorDBStorage",
# "ChromaVectorDBStorage",
# "TiDBVectorDBStorage", # "TiDBVectorDBStorage",
], ],
"required_methods": ["query", "upsert"], "required_methods": ["query", "upsert"],
@ -38,6 +38,7 @@ STORAGE_IMPLEMENTATIONS = {
"DOC_STATUS_STORAGE": { "DOC_STATUS_STORAGE": {
"implementations": [ "implementations": [
"JsonDocStatusStorage", "JsonDocStatusStorage",
"RedisDocStatusStorage",
"PGDocStatusStorage", "PGDocStatusStorage",
"MongoDocStatusStorage", "MongoDocStatusStorage",
], ],
@ -81,6 +82,7 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
"MongoVectorDBStorage": [], "MongoVectorDBStorage": [],
# Document Status Storage Implementations # Document Status Storage Implementations
"JsonDocStatusStorage": [], "JsonDocStatusStorage": [],
"RedisDocStatusStorage": ["REDIS_URI"],
"PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"MongoDocStatusStorage": [], "MongoDocStatusStorage": [],
} }
@ -98,6 +100,7 @@ STORAGES = {
"MongoGraphStorage": ".kg.mongo_impl", "MongoGraphStorage": ".kg.mongo_impl",
"MongoVectorDBStorage": ".kg.mongo_impl", "MongoVectorDBStorage": ".kg.mongo_impl",
"RedisKVStorage": ".kg.redis_impl", "RedisKVStorage": ".kg.redis_impl",
"RedisDocStatusStorage": ".kg.redis_impl",
"ChromaVectorDBStorage": ".kg.chroma_impl", "ChromaVectorDBStorage": ".kg.chroma_impl",
# "TiDBKVStorage": ".kg.tidb_impl", # "TiDBKVStorage": ".kg.tidb_impl",
# "TiDBVectorDBStorage": ".kg.tidb_impl", # "TiDBVectorDBStorage": ".kg.tidb_impl",

View file

@ -109,7 +109,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
raise raise
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}") logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data: if not data:
return return
@ -234,7 +234,6 @@ class ChromaVectorDBStorage(BaseVectorStorage):
ids: List of vector IDs to be deleted ids: List of vector IDs to be deleted
""" """
try: try:
logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
self._collection.delete(ids=ids) self._collection.delete(ids=ids)
logger.debug( logger.debug(
f"Successfully deleted {len(ids)} vectors from {self.namespace}" f"Successfully deleted {len(ids)} vectors from {self.namespace}"

View file

@ -257,7 +257,7 @@ class TiDBKVStorage(BaseKVStorage):
################ INSERT full_doc AND chunks ################ ################ INSERT full_doc AND chunks ################
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}") logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data: if not data:
return return
left_data = {k: v for k, v in data.items() if k not in self._data} left_data = {k: v for k, v in data.items() if k not in self._data}
@ -454,11 +454,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):
###### INSERT entities And relationships ###### ###### INSERT entities And relationships ######
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data: if not data:
return return
logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
# Get current time as UNIX timestamp # Get current time as UNIX timestamp
import time import time
@ -522,11 +520,6 @@ class TiDBVectorDBStorage(BaseVectorStorage):
} }
await self.db.execute(SQL_TEMPLATES["upsert_relationship"], param) await self.db.execute(SQL_TEMPLATES["upsert_relationship"], param)
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
params = {"workspace": self.db.workspace, "status": status}
return await self.db.query(SQL, params, multirows=True)
async def delete(self, ids: list[str]) -> None: async def delete(self, ids: list[str]) -> None:
"""Delete vectors with specified IDs from the storage. """Delete vectors with specified IDs from the storage.

View file

@ -17,14 +17,13 @@ from .shared_storage import (
set_all_update_flags, set_all_update_flags,
) )
import faiss # type: ignore
USE_GPU = os.getenv("FAISS_USE_GPU", "0") == "1" USE_GPU = os.getenv("FAISS_USE_GPU", "0") == "1"
FAISS_PACKAGE = "faiss-gpu" if USE_GPU else "faiss-cpu" FAISS_PACKAGE = "faiss-gpu" if USE_GPU else "faiss-cpu"
if not pm.is_installed(FAISS_PACKAGE): if not pm.is_installed(FAISS_PACKAGE):
pm.install(FAISS_PACKAGE) pm.install(FAISS_PACKAGE)
import faiss # type: ignore
@final @final
@dataclass @dataclass

View file

@ -118,6 +118,10 @@ class JsonDocStatusStorage(DocStatusStorage):
return return
logger.debug(f"Inserting {len(data)} records to {self.namespace}") logger.debug(f"Inserting {len(data)} records to {self.namespace}")
async with self._storage_lock: async with self._storage_lock:
# Ensure chunks_list field exists for new documents
for doc_id, doc_data in data.items():
if "chunks_list" not in doc_data:
doc_data["chunks_list"] = []
self._data.update(data) self._data.update(data)
await set_all_update_flags(self.namespace) await set_all_update_flags(self.namespace)

View file

@ -42,19 +42,14 @@ class JsonKVStorage(BaseKVStorage):
if need_init: if need_init:
loaded_data = load_json(self._file_name) or {} loaded_data = load_json(self._file_name) or {}
async with self._storage_lock: async with self._storage_lock:
self._data.update(loaded_data) # Migrate legacy cache structure if needed
if self.namespace.endswith("_cache"):
# Calculate data count based on namespace loaded_data = await self._migrate_legacy_cache_structure(
if self.namespace.endswith("cache"): loaded_data
# For cache namespaces, sum the cache entries across all cache types
data_count = sum(
len(first_level_dict)
for first_level_dict in loaded_data.values()
if isinstance(first_level_dict, dict)
) )
else:
# For non-cache namespaces, use the original count method self._data.update(loaded_data)
data_count = len(loaded_data) data_count = len(loaded_data)
logger.info( logger.info(
f"Process {os.getpid()} KV load {self.namespace} with {data_count} records" f"Process {os.getpid()} KV load {self.namespace} with {data_count} records"
@ -67,17 +62,8 @@ class JsonKVStorage(BaseKVStorage):
dict(self._data) if hasattr(self._data, "_getvalue") else self._data dict(self._data) if hasattr(self._data, "_getvalue") else self._data
) )
# Calculate data count based on namespace # Calculate data count - all data is now flattened
if self.namespace.endswith("cache"): data_count = len(data_dict)
# # For cache namespaces, sum the cache entries across all cache types
data_count = sum(
len(first_level_dict)
for first_level_dict in data_dict.values()
if isinstance(first_level_dict, dict)
)
else:
# For non-cache namespaces, use the original count method
data_count = len(data_dict)
logger.debug( logger.debug(
f"Process {os.getpid()} KV writting {data_count} records to {self.namespace}" f"Process {os.getpid()} KV writting {data_count} records to {self.namespace}"
@ -92,22 +78,49 @@ class JsonKVStorage(BaseKVStorage):
Dictionary containing all stored data Dictionary containing all stored data
""" """
async with self._storage_lock: async with self._storage_lock:
return dict(self._data) result = {}
for key, value in self._data.items():
if value:
# Create a copy to avoid modifying the original data
data = dict(value)
# Ensure time fields are present, provide default values for old data
data.setdefault("create_time", 0)
data.setdefault("update_time", 0)
result[key] = data
else:
result[key] = value
return result
async def get_by_id(self, id: str) -> dict[str, Any] | None: async def get_by_id(self, id: str) -> dict[str, Any] | None:
async with self._storage_lock: async with self._storage_lock:
return self._data.get(id) result = self._data.get(id)
if result:
# Create a copy to avoid modifying the original data
result = dict(result)
# Ensure time fields are present, provide default values for old data
result.setdefault("create_time", 0)
result.setdefault("update_time", 0)
# Ensure _id field contains the clean ID
result["_id"] = id
return result
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
async with self._storage_lock: async with self._storage_lock:
return [ results = []
( for id in ids:
{k: v for k, v in self._data[id].items()} data = self._data.get(id, None)
if self._data.get(id, None) if data:
else None # Create a copy to avoid modifying the original data
) result = {k: v for k, v in data.items()}
for id in ids # Ensure time fields are present, provide default values for old data
] result.setdefault("create_time", 0)
result.setdefault("update_time", 0)
# Ensure _id field contains the clean ID
result["_id"] = id
results.append(result)
else:
results.append(None)
return results
async def filter_keys(self, keys: set[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
async with self._storage_lock: async with self._storage_lock:
@ -121,8 +134,29 @@ class JsonKVStorage(BaseKVStorage):
""" """
if not data: if not data:
return return
import time
current_time = int(time.time()) # Get current Unix timestamp
logger.debug(f"Inserting {len(data)} records to {self.namespace}") logger.debug(f"Inserting {len(data)} records to {self.namespace}")
async with self._storage_lock: async with self._storage_lock:
# Add timestamps to data based on whether key exists
for k, v in data.items():
# For text_chunks namespace, ensure llm_cache_list field exists
if "text_chunks" in self.namespace:
if "llm_cache_list" not in v:
v["llm_cache_list"] = []
# Add timestamps based on whether key exists
if k in self._data: # Key exists, only update update_time
v["update_time"] = current_time
else: # New key, set both create_time and update_time
v["create_time"] = current_time
v["update_time"] = current_time
v["_id"] = k
self._data.update(data) self._data.update(data)
await set_all_update_flags(self.namespace) await set_all_update_flags(self.namespace)
@ -150,14 +184,14 @@ class JsonKVStorage(BaseKVStorage):
await set_all_update_flags(self.namespace) await set_all_update_flags(self.namespace)
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool: async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
"""Delete specific records from storage by by cache mode """Delete specific records from storage by cache mode
Importance notes for in-memory storage: Importance notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback 1. Changes will be persisted to disk during the next index_done_callback
2. update flags to notify other processes that data persistence is needed 2. update flags to notify other processes that data persistence is needed
Args: Args:
ids (list[str]): List of cache mode to be drop from storage modes (list[str]): List of cache modes to be dropped from storage
Returns: Returns:
True: if the cache drop successfully True: if the cache drop successfully
@ -167,9 +201,29 @@ class JsonKVStorage(BaseKVStorage):
return False return False
try: try:
await self.delete(modes) async with self._storage_lock:
keys_to_delete = []
modes_set = set(modes) # Convert to set for efficient lookup
for key in list(self._data.keys()):
# Parse flattened cache key: mode:cache_type:hash
parts = key.split(":", 2)
if len(parts) == 3 and parts[0] in modes_set:
keys_to_delete.append(key)
# Batch delete
for key in keys_to_delete:
self._data.pop(key, None)
if keys_to_delete:
await set_all_update_flags(self.namespace)
logger.info(
f"Dropped {len(keys_to_delete)} cache entries for modes: {modes}"
)
return True return True
except Exception: except Exception as e:
logger.error(f"Error dropping cache by modes: {e}")
return False return False
# async def drop_cache_by_chunk_ids(self, chunk_ids: list[str] | None = None) -> bool: # async def drop_cache_by_chunk_ids(self, chunk_ids: list[str] | None = None) -> bool:
@ -245,9 +299,58 @@ class JsonKVStorage(BaseKVStorage):
logger.error(f"Error dropping {self.namespace}: {e}") logger.error(f"Error dropping {self.namespace}: {e}")
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
async def _migrate_legacy_cache_structure(self, data: dict) -> dict:
"""Migrate legacy nested cache structure to flattened structure
Args:
data: Original data dictionary that may contain legacy structure
Returns:
Migrated data dictionary with flattened cache keys
"""
from lightrag.utils import generate_cache_key
# Early return if data is empty
if not data:
return data
# Check first entry to see if it's already in new format
first_key = next(iter(data.keys()))
if ":" in first_key and len(first_key.split(":")) == 3:
# Already in flattened format, return as-is
return data
migrated_data = {}
migration_count = 0
for key, value in data.items():
# Check if this is a legacy nested cache structure
if isinstance(value, dict) and all(
isinstance(v, dict) and "return" in v for v in value.values()
):
# This looks like a legacy cache mode with nested structure
mode = key
for cache_hash, cache_entry in value.items():
cache_type = cache_entry.get("cache_type", "extract")
flattened_key = generate_cache_key(mode, cache_type, cache_hash)
migrated_data[flattened_key] = cache_entry
migration_count += 1
else:
# Keep non-cache data or already flattened cache data as-is
migrated_data[key] = value
if migration_count > 0:
logger.info(
f"Migrated {migration_count} legacy cache entries to flattened structure"
)
# Persist migrated data immediately
write_json(migrated_data, self._file_name)
return migrated_data
async def finalize(self): async def finalize(self):
"""Finalize storage resources """Finalize storage resources
Persistence cache data to disk before exiting Persistence cache data to disk before exiting
""" """
if self.namespace.endswith("cache"): if self.namespace.endswith("_cache"):
await self.index_done_callback() await self.index_done_callback()

View file

@ -15,7 +15,7 @@ if not pm.is_installed("pymilvus"):
pm.install("pymilvus") pm.install("pymilvus")
import configparser import configparser
from pymilvus import MilvusClient # type: ignore from pymilvus import MilvusClient, DataType, CollectionSchema, FieldSchema # type: ignore
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read("config.ini", "utf-8") config.read("config.ini", "utf-8")
@ -24,16 +24,605 @@ config.read("config.ini", "utf-8")
@final @final
@dataclass @dataclass
class MilvusVectorDBStorage(BaseVectorStorage): class MilvusVectorDBStorage(BaseVectorStorage):
@staticmethod def _create_schema_for_namespace(self) -> CollectionSchema:
def create_collection_if_not_exist( """Create schema based on the current instance's namespace"""
client: MilvusClient, collection_name: str, **kwargs
): # Get vector dimension from embedding_func
if client.has_collection(collection_name): dimension = self.embedding_func.embedding_dim
return
client.create_collection( # Base fields (common to all collections)
collection_name, max_length=64, id_type="string", **kwargs base_fields = [
FieldSchema(
name="id", dtype=DataType.VARCHAR, max_length=64, is_primary=True
),
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
FieldSchema(name="created_at", dtype=DataType.INT64),
]
# Determine specific fields based on namespace
if "entities" in self.namespace.lower():
specific_fields = [
FieldSchema(
name="entity_name",
dtype=DataType.VARCHAR,
max_length=256,
nullable=True,
),
FieldSchema(
name="entity_type",
dtype=DataType.VARCHAR,
max_length=64,
nullable=True,
),
FieldSchema(
name="file_path",
dtype=DataType.VARCHAR,
max_length=512,
nullable=True,
),
]
description = "LightRAG entities vector storage"
elif "relationships" in self.namespace.lower():
specific_fields = [
FieldSchema(
name="src_id", dtype=DataType.VARCHAR, max_length=256, nullable=True
),
FieldSchema(
name="tgt_id", dtype=DataType.VARCHAR, max_length=256, nullable=True
),
FieldSchema(name="weight", dtype=DataType.DOUBLE, nullable=True),
FieldSchema(
name="file_path",
dtype=DataType.VARCHAR,
max_length=512,
nullable=True,
),
]
description = "LightRAG relationships vector storage"
elif "chunks" in self.namespace.lower():
specific_fields = [
FieldSchema(
name="full_doc_id",
dtype=DataType.VARCHAR,
max_length=64,
nullable=True,
),
FieldSchema(
name="file_path",
dtype=DataType.VARCHAR,
max_length=512,
nullable=True,
),
]
description = "LightRAG chunks vector storage"
else:
# Default generic schema (backward compatibility)
specific_fields = [
FieldSchema(
name="file_path",
dtype=DataType.VARCHAR,
max_length=512,
nullable=True,
),
]
description = "LightRAG generic vector storage"
# Merge all fields
all_fields = base_fields + specific_fields
return CollectionSchema(
fields=all_fields,
description=description,
enable_dynamic_field=True, # Support dynamic fields
) )
def _get_index_params(self):
"""Get IndexParams in a version-compatible way"""
try:
# Try to use client's prepare_index_params method (most common)
if hasattr(self._client, "prepare_index_params"):
return self._client.prepare_index_params()
except Exception:
pass
try:
# Try to import IndexParams from different possible locations
from pymilvus.client.prepare import IndexParams
return IndexParams()
except ImportError:
pass
try:
from pymilvus.client.types import IndexParams
return IndexParams()
except ImportError:
pass
try:
from pymilvus import IndexParams
return IndexParams()
except ImportError:
pass
# If all else fails, return None to use fallback method
return None
def _create_vector_index_fallback(self):
"""Fallback method to create vector index using direct API"""
try:
self._client.create_index(
collection_name=self.namespace,
field_name="vector",
index_params={
"index_type": "HNSW",
"metric_type": "COSINE",
"params": {"M": 16, "efConstruction": 256},
},
)
logger.debug("Created vector index using fallback method")
except Exception as e:
logger.warning(f"Failed to create vector index using fallback method: {e}")
def _create_scalar_index_fallback(self, field_name: str, index_type: str):
"""Fallback method to create scalar index using direct API"""
# Skip unsupported index types
if index_type == "SORTED":
logger.info(
f"Skipping SORTED index for {field_name} (not supported in this Milvus version)"
)
return
try:
self._client.create_index(
collection_name=self.namespace,
field_name=field_name,
index_params={"index_type": index_type},
)
logger.debug(f"Created {field_name} index using fallback method")
except Exception as e:
logger.info(
f"Could not create {field_name} index using fallback method: {e}"
)
def _create_indexes_after_collection(self):
"""Create indexes after collection is created"""
try:
# Try to get IndexParams in a version-compatible way
IndexParamsClass = self._get_index_params()
if IndexParamsClass is not None:
# Use IndexParams approach if available
try:
# Create vector index first (required for most operations)
vector_index = IndexParamsClass
vector_index.add_index(
field_name="vector",
index_type="HNSW",
metric_type="COSINE",
params={"M": 16, "efConstruction": 256},
)
self._client.create_index(
collection_name=self.namespace, index_params=vector_index
)
logger.debug("Created vector index using IndexParams")
except Exception as e:
logger.debug(f"IndexParams method failed for vector index: {e}")
self._create_vector_index_fallback()
# Create scalar indexes based on namespace
if "entities" in self.namespace.lower():
# Create indexes for entity fields
try:
entity_name_index = self._get_index_params()
entity_name_index.add_index(
field_name="entity_name", index_type="INVERTED"
)
self._client.create_index(
collection_name=self.namespace,
index_params=entity_name_index,
)
except Exception as e:
logger.debug(f"IndexParams method failed for entity_name: {e}")
self._create_scalar_index_fallback("entity_name", "INVERTED")
try:
entity_type_index = self._get_index_params()
entity_type_index.add_index(
field_name="entity_type", index_type="INVERTED"
)
self._client.create_index(
collection_name=self.namespace,
index_params=entity_type_index,
)
except Exception as e:
logger.debug(f"IndexParams method failed for entity_type: {e}")
self._create_scalar_index_fallback("entity_type", "INVERTED")
elif "relationships" in self.namespace.lower():
# Create indexes for relationship fields
try:
src_id_index = self._get_index_params()
src_id_index.add_index(
field_name="src_id", index_type="INVERTED"
)
self._client.create_index(
collection_name=self.namespace, index_params=src_id_index
)
except Exception as e:
logger.debug(f"IndexParams method failed for src_id: {e}")
self._create_scalar_index_fallback("src_id", "INVERTED")
try:
tgt_id_index = self._get_index_params()
tgt_id_index.add_index(
field_name="tgt_id", index_type="INVERTED"
)
self._client.create_index(
collection_name=self.namespace, index_params=tgt_id_index
)
except Exception as e:
logger.debug(f"IndexParams method failed for tgt_id: {e}")
self._create_scalar_index_fallback("tgt_id", "INVERTED")
elif "chunks" in self.namespace.lower():
# Create indexes for chunk fields
try:
doc_id_index = self._get_index_params()
doc_id_index.add_index(
field_name="full_doc_id", index_type="INVERTED"
)
self._client.create_index(
collection_name=self.namespace, index_params=doc_id_index
)
except Exception as e:
logger.debug(f"IndexParams method failed for full_doc_id: {e}")
self._create_scalar_index_fallback("full_doc_id", "INVERTED")
# No common indexes needed
else:
# Fallback to direct API calls if IndexParams is not available
logger.info(
f"IndexParams not available, using fallback methods for {self.namespace}"
)
# Create vector index using fallback
self._create_vector_index_fallback()
# Create scalar indexes using fallback
if "entities" in self.namespace.lower():
self._create_scalar_index_fallback("entity_name", "INVERTED")
self._create_scalar_index_fallback("entity_type", "INVERTED")
elif "relationships" in self.namespace.lower():
self._create_scalar_index_fallback("src_id", "INVERTED")
self._create_scalar_index_fallback("tgt_id", "INVERTED")
elif "chunks" in self.namespace.lower():
self._create_scalar_index_fallback("full_doc_id", "INVERTED")
logger.info(f"Created indexes for collection: {self.namespace}")
except Exception as e:
logger.warning(f"Failed to create some indexes for {self.namespace}: {e}")
def _get_required_fields_for_namespace(self) -> dict:
"""Get required core field definitions for current namespace"""
# Base fields (common to all types)
base_fields = {
"id": {"type": "VarChar", "is_primary": True},
"vector": {"type": "FloatVector"},
"created_at": {"type": "Int64"},
}
# Add specific fields based on namespace
if "entities" in self.namespace.lower():
specific_fields = {
"entity_name": {"type": "VarChar"},
"entity_type": {"type": "VarChar"},
"file_path": {"type": "VarChar"},
}
elif "relationships" in self.namespace.lower():
specific_fields = {
"src_id": {"type": "VarChar"},
"tgt_id": {"type": "VarChar"},
"weight": {"type": "Double"},
"file_path": {"type": "VarChar"},
}
elif "chunks" in self.namespace.lower():
specific_fields = {
"full_doc_id": {"type": "VarChar"},
"file_path": {"type": "VarChar"},
}
else:
specific_fields = {
"file_path": {"type": "VarChar"},
}
return {**base_fields, **specific_fields}
def _is_field_compatible(self, existing_field: dict, expected_config: dict) -> bool:
"""Check compatibility of a single field"""
field_name = existing_field.get("name", "unknown")
existing_type = existing_field.get("type")
expected_type = expected_config.get("type")
logger.debug(
f"Checking field '{field_name}': existing_type={existing_type} (type={type(existing_type)}), expected_type={expected_type}"
)
# Convert DataType enum values to string names if needed
original_existing_type = existing_type
if hasattr(existing_type, "name"):
existing_type = existing_type.name
logger.debug(
f"Converted enum to name: {original_existing_type} -> {existing_type}"
)
elif isinstance(existing_type, int):
# Map common Milvus internal type codes to type names for backward compatibility
type_mapping = {
21: "VarChar",
101: "FloatVector",
5: "Int64",
9: "Double",
}
mapped_type = type_mapping.get(existing_type, str(existing_type))
logger.debug(f"Mapped numeric type: {existing_type} -> {mapped_type}")
existing_type = mapped_type
# Normalize type names for comparison
type_aliases = {
"VARCHAR": "VarChar",
"String": "VarChar",
"FLOAT_VECTOR": "FloatVector",
"INT64": "Int64",
"BigInt": "Int64",
"DOUBLE": "Double",
"Float": "Double",
}
original_existing = existing_type
original_expected = expected_type
existing_type = type_aliases.get(existing_type, existing_type)
expected_type = type_aliases.get(expected_type, expected_type)
if original_existing != existing_type or original_expected != expected_type:
logger.debug(
f"Applied aliases: {original_existing} -> {existing_type}, {original_expected} -> {expected_type}"
)
# Basic type compatibility check
type_compatible = existing_type == expected_type
logger.debug(
f"Type compatibility for '{field_name}': {existing_type} == {expected_type} -> {type_compatible}"
)
if not type_compatible:
logger.warning(
f"Type mismatch for field '{field_name}': expected {expected_type}, got {existing_type}"
)
return False
# Primary key check - be more flexible about primary key detection
if expected_config.get("is_primary"):
# Check multiple possible field names for primary key status
is_primary = (
existing_field.get("is_primary_key", False)
or existing_field.get("is_primary", False)
or existing_field.get("primary_key", False)
)
logger.debug(
f"Primary key check for '{field_name}': expected=True, actual={is_primary}"
)
logger.debug(f"Raw field data for '{field_name}': {existing_field}")
# For ID field, be more lenient - if it's the ID field, assume it should be primary
if field_name == "id" and not is_primary:
logger.info(
f"ID field '{field_name}' not marked as primary in existing collection, but treating as compatible"
)
# Don't fail for ID field primary key mismatch
elif not is_primary:
logger.warning(
f"Primary key mismatch for field '{field_name}': expected primary key, but field is not primary"
)
return False
logger.debug(f"Field '{field_name}' is compatible")
return True
def _check_vector_dimension(self, collection_info: dict):
"""Check vector dimension compatibility"""
current_dimension = self.embedding_func.embedding_dim
# Find vector field dimension
for field in collection_info.get("fields", []):
if field.get("name") == "vector":
field_type = field.get("type")
if field_type in ["FloatVector", "FLOAT_VECTOR"]:
existing_dimension = field.get("params", {}).get("dim")
if existing_dimension != current_dimension:
raise ValueError(
f"Vector dimension mismatch for collection '{self.namespace}': "
f"existing={existing_dimension}, current={current_dimension}"
)
logger.debug(f"Vector dimension check passed: {current_dimension}")
return
# If no vector field found, this might be an old collection created with simple schema
logger.warning(
f"Vector field not found in collection '{self.namespace}'. This might be an old collection created with simple schema."
)
logger.warning("Consider recreating the collection for optimal performance.")
return
def _check_schema_compatibility(self, collection_info: dict):
"""Check schema field compatibility"""
existing_fields = {
field["name"]: field for field in collection_info.get("fields", [])
}
# Check if this is an old collection created with simple schema
has_vector_field = any(
field.get("name") == "vector" for field in collection_info.get("fields", [])
)
if not has_vector_field:
logger.warning(
f"Collection {self.namespace} appears to be created with old simple schema (no vector field)"
)
logger.warning(
"This collection will work but may have suboptimal performance"
)
logger.warning("Consider recreating the collection for optimal performance")
return
# For collections with vector field, check basic compatibility
# Only check for critical incompatibilities, not missing optional fields
critical_fields = {"id": {"type": "VarChar", "is_primary": True}}
incompatible_fields = []
for field_name, expected_config in critical_fields.items():
if field_name in existing_fields:
existing_field = existing_fields[field_name]
if not self._is_field_compatible(existing_field, expected_config):
incompatible_fields.append(
f"{field_name}: expected {expected_config['type']}, "
f"got {existing_field.get('type')}"
)
if incompatible_fields:
raise ValueError(
f"Critical schema incompatibility in collection '{self.namespace}': {incompatible_fields}"
)
# Get all expected fields for informational purposes
expected_fields = self._get_required_fields_for_namespace()
missing_fields = [
field for field in expected_fields if field not in existing_fields
]
if missing_fields:
logger.info(
f"Collection {self.namespace} missing optional fields: {missing_fields}"
)
logger.info(
"These fields would be available in a newly created collection for better performance"
)
logger.debug(f"Schema compatibility check passed for {self.namespace}")
def _validate_collection_compatibility(self):
"""Validate existing collection's dimension and schema compatibility"""
try:
collection_info = self._client.describe_collection(self.namespace)
# 1. Check vector dimension
self._check_vector_dimension(collection_info)
# 2. Check schema compatibility
self._check_schema_compatibility(collection_info)
logger.info(f"Collection {self.namespace} compatibility validation passed")
except Exception as e:
logger.error(
f"Collection compatibility validation failed for {self.namespace}: {e}"
)
raise
def _create_collection_if_not_exist(self):
"""Create collection if not exists and check existing collection compatibility"""
try:
# First, list all collections to see what actually exists
try:
all_collections = self._client.list_collections()
logger.debug(f"All collections in database: {all_collections}")
except Exception as list_error:
logger.warning(f"Could not list collections: {list_error}")
all_collections = []
# Check if our specific collection exists
collection_exists = self._client.has_collection(self.namespace)
logger.info(
f"Collection '{self.namespace}' exists check: {collection_exists}"
)
if collection_exists:
# Double-check by trying to describe the collection
try:
self._client.describe_collection(self.namespace)
logger.info(
f"Collection '{self.namespace}' confirmed to exist, validating compatibility..."
)
self._validate_collection_compatibility()
return
except Exception as describe_error:
logger.warning(
f"Collection '{self.namespace}' exists but cannot be described: {describe_error}"
)
logger.info(
"Treating as if collection doesn't exist and creating new one..."
)
# Fall through to creation logic
# Collection doesn't exist, create new collection
logger.info(f"Creating new collection: {self.namespace}")
schema = self._create_schema_for_namespace()
# Create collection with schema only first
self._client.create_collection(
collection_name=self.namespace, schema=schema
)
# Then create indexes
self._create_indexes_after_collection()
logger.info(f"Successfully created Milvus collection: {self.namespace}")
except Exception as e:
logger.error(
f"Error in _create_collection_if_not_exist for {self.namespace}: {e}"
)
# If there's any error, try to force create the collection
logger.info(f"Attempting to force create collection {self.namespace}...")
try:
# Try to drop the collection first if it exists in a bad state
try:
if self._client.has_collection(self.namespace):
logger.info(
f"Dropping potentially corrupted collection {self.namespace}"
)
self._client.drop_collection(self.namespace)
except Exception as drop_error:
logger.warning(
f"Could not drop collection {self.namespace}: {drop_error}"
)
# Create fresh collection
schema = self._create_schema_for_namespace()
self._client.create_collection(
collection_name=self.namespace, schema=schema
)
self._create_indexes_after_collection()
logger.info(f"Successfully force-created collection {self.namespace}")
except Exception as create_error:
logger.error(
f"Failed to force-create collection {self.namespace}: {create_error}"
)
raise
def __post_init__(self): def __post_init__(self):
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = kwargs.get("cosine_better_than_threshold") cosine_threshold = kwargs.get("cosine_better_than_threshold")
@ -43,6 +632,10 @@ class MilvusVectorDBStorage(BaseVectorStorage):
) )
self.cosine_better_than_threshold = cosine_threshold self.cosine_better_than_threshold = cosine_threshold
# Ensure created_at is in meta_fields
if "created_at" not in self.meta_fields:
self.meta_fields.add("created_at")
self._client = MilvusClient( self._client = MilvusClient(
uri=os.environ.get( uri=os.environ.get(
"MILVUS_URI", "MILVUS_URI",
@ -68,14 +661,12 @@ class MilvusVectorDBStorage(BaseVectorStorage):
), ),
) )
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
MilvusVectorDBStorage.create_collection_if_not_exist(
self._client, # Create collection and check compatibility
self.namespace, self._create_collection_if_not_exist()
dimension=self.embedding_func.embedding_dim,
)
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}") logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data: if not data:
return return
@ -112,23 +703,25 @@ class MilvusVectorDBStorage(BaseVectorStorage):
embedding = await self.embedding_func( embedding = await self.embedding_func(
[query], _priority=5 [query], _priority=5
) # higher priority for query ) # higher priority for query
# Include all meta_fields (created_at is now always included)
output_fields = list(self.meta_fields)
results = self._client.search( results = self._client.search(
collection_name=self.namespace, collection_name=self.namespace,
data=embedding, data=embedding,
limit=top_k, limit=top_k,
output_fields=list(self.meta_fields) + ["created_at"], output_fields=output_fields,
search_params={ search_params={
"metric_type": "COSINE", "metric_type": "COSINE",
"params": {"radius": self.cosine_better_than_threshold}, "params": {"radius": self.cosine_better_than_threshold},
}, },
) )
print(results)
return [ return [
{ {
**dp["entity"], **dp["entity"],
"id": dp["id"], "id": dp["id"],
"distance": dp["distance"], "distance": dp["distance"],
# created_at is requested in output_fields, so it should be a top-level key in the result dict (dp)
"created_at": dp.get("created_at"), "created_at": dp.get("created_at"),
} }
for dp in results[0] for dp in results[0]
@ -232,20 +825,19 @@ class MilvusVectorDBStorage(BaseVectorStorage):
The vector data if found, or None if not found The vector data if found, or None if not found
""" """
try: try:
# Include all meta_fields (created_at is now always included) plus id
output_fields = list(self.meta_fields) + ["id"]
# Query Milvus for a specific ID # Query Milvus for a specific ID
result = self._client.query( result = self._client.query(
collection_name=self.namespace, collection_name=self.namespace,
filter=f'id == "{id}"', filter=f'id == "{id}"',
output_fields=list(self.meta_fields) + ["id", "created_at"], output_fields=output_fields,
) )
if not result or len(result) == 0: if not result or len(result) == 0:
return None return None
# Ensure the result contains created_at field
if "created_at" not in result[0]:
result[0]["created_at"] = None
return result[0] return result[0]
except Exception as e: except Exception as e:
logger.error(f"Error retrieving vector data for ID {id}: {e}") logger.error(f"Error retrieving vector data for ID {id}: {e}")
@ -264,6 +856,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
return [] return []
try: try:
# Include all meta_fields (created_at is now always included) plus id
output_fields = list(self.meta_fields) + ["id"]
# Prepare the ID filter expression # Prepare the ID filter expression
id_list = '", "'.join(ids) id_list = '", "'.join(ids)
filter_expr = f'id in ["{id_list}"]' filter_expr = f'id in ["{id_list}"]'
@ -272,14 +867,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
result = self._client.query( result = self._client.query(
collection_name=self.namespace, collection_name=self.namespace,
filter=filter_expr, filter=filter_expr,
output_fields=list(self.meta_fields) + ["id", "created_at"], output_fields=output_fields,
) )
# Ensure each result contains created_at field
for item in result:
if "created_at" not in item:
item["created_at"] = None
return result or [] return result or []
except Exception as e: except Exception as e:
logger.error(f"Error retrieving vector data for IDs {ids}: {e}") logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
@ -301,11 +891,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
self._client.drop_collection(self.namespace) self._client.drop_collection(self.namespace)
# Recreate the collection # Recreate the collection
MilvusVectorDBStorage.create_collection_if_not_exist( self._create_collection_if_not_exist()
self._client,
self.namespace,
dimension=self.embedding_func.embedding_dim,
)
logger.info( logger.info(
f"Process {os.getpid()} drop Milvus collection {self.namespace}" f"Process {os.getpid()} drop Milvus collection {self.namespace}"

View file

@ -1,4 +1,5 @@
import os import os
import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
import numpy as np import numpy as np
import configparser import configparser
@ -14,7 +15,6 @@ from ..base import (
DocStatus, DocStatus,
DocStatusStorage, DocStatusStorage,
) )
from ..namespace import NameSpace, is_namespace
from ..utils import logger, compute_mdhash_id from ..utils import logger, compute_mdhash_id
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP from ..constants import GRAPH_FIELD_SEP
@ -35,6 +35,7 @@ config.read("config.ini", "utf-8")
# Get maximum number of graph nodes from environment variable, default is 1000 # Get maximum number of graph nodes from environment variable, default is 1000
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
GRAPH_BFS_MODE = os.getenv("MONGO_GRAPH_BFS_MODE", "bidirectional")
class ClientManager: class ClientManager:
@ -96,11 +97,22 @@ class MongoKVStorage(BaseKVStorage):
self._data = None self._data = None
async def get_by_id(self, id: str) -> dict[str, Any] | None: async def get_by_id(self, id: str) -> dict[str, Any] | None:
return await self._data.find_one({"_id": id}) # Unified handling for flattened keys
doc = await self._data.find_one({"_id": id})
if doc:
# Ensure time fields are present, provide default values for old data
doc.setdefault("create_time", 0)
doc.setdefault("update_time", 0)
return doc
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
cursor = self._data.find({"_id": {"$in": ids}}) cursor = self._data.find({"_id": {"$in": ids}})
return await cursor.to_list() docs = await cursor.to_list()
# Ensure time fields are present for all documents
for doc in docs:
doc.setdefault("create_time", 0)
doc.setdefault("update_time", 0)
return docs
async def filter_keys(self, keys: set[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
cursor = self._data.find({"_id": {"$in": list(keys)}}, {"_id": 1}) cursor = self._data.find({"_id": {"$in": list(keys)}}, {"_id": 1})
@ -117,47 +129,53 @@ class MongoKVStorage(BaseKVStorage):
result = {} result = {}
async for doc in cursor: async for doc in cursor:
doc_id = doc.pop("_id") doc_id = doc.pop("_id")
# Ensure time fields are present for all documents
doc.setdefault("create_time", 0)
doc.setdefault("update_time", 0)
result[doc_id] = doc result[doc_id] = doc
return result return result
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}") logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data: if not data:
return return
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): # Unified handling for all namespaces with flattened keys
update_tasks: list[Any] = [] # Use bulk_write for better performance
for mode, items in data.items(): from pymongo import UpdateOne
for k, v in items.items():
key = f"{mode}_{k}"
data[mode][k]["_id"] = f"{mode}_{k}"
update_tasks.append(
self._data.update_one(
{"_id": key}, {"$setOnInsert": v}, upsert=True
)
)
await asyncio.gather(*update_tasks)
else:
update_tasks = []
for k, v in data.items():
data[k]["_id"] = k
update_tasks.append(
self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
)
await asyncio.gather(*update_tasks)
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: operations = []
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): current_time = int(time.time()) # Get current Unix timestamp
res = {}
v = await self._data.find_one({"_id": mode + "_" + id}) for k, v in data.items():
if v: # For text_chunks namespace, ensure llm_cache_list field exists
res[id] = v if self.namespace.endswith("text_chunks"):
logger.debug(f"llm_response_cache find one by:{id}") if "llm_cache_list" not in v:
return res v["llm_cache_list"] = []
else:
return None # Create a copy of v for $set operation, excluding create_time to avoid conflicts
else: v_for_set = v.copy()
return None v_for_set["_id"] = k # Use flattened key as _id
v_for_set["update_time"] = current_time # Always update update_time
# Remove create_time from $set to avoid conflict with $setOnInsert
v_for_set.pop("create_time", None)
operations.append(
UpdateOne(
{"_id": k},
{
"$set": v_for_set, # Update all fields except create_time
"$setOnInsert": {
"create_time": current_time
}, # Set create_time only on insert
},
upsert=True,
)
)
if operations:
await self._data.bulk_write(operations)
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
# Mongo handles persistence automatically # Mongo handles persistence automatically
@ -197,8 +215,8 @@ class MongoKVStorage(BaseKVStorage):
return False return False
try: try:
# Build regex pattern to match documents with the specified modes # Build regex pattern to match flattened key format: mode:cache_type:hash
pattern = f"^({'|'.join(modes)})_" pattern = f"^({'|'.join(modes)}):"
result = await self._data.delete_many({"_id": {"$regex": pattern}}) result = await self._data.delete_many({"_id": {"$regex": pattern}})
logger.info(f"Deleted {result.deleted_count} documents by modes: {modes}") logger.info(f"Deleted {result.deleted_count} documents by modes: {modes}")
return True return True
@ -262,11 +280,14 @@ class MongoDocStatusStorage(DocStatusStorage):
return data - existing_ids return data - existing_ids
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}") logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data: if not data:
return return
update_tasks: list[Any] = [] update_tasks: list[Any] = []
for k, v in data.items(): for k, v in data.items():
# Ensure chunks_list field exists and is an array
if "chunks_list" not in v:
v["chunks_list"] = []
data[k]["_id"] = k data[k]["_id"] = k
update_tasks.append( update_tasks.append(
self._data.update_one({"_id": k}, {"$set": v}, upsert=True) self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
@ -299,6 +320,7 @@ class MongoDocStatusStorage(DocStatusStorage):
updated_at=doc.get("updated_at"), updated_at=doc.get("updated_at"),
chunks_count=doc.get("chunks_count", -1), chunks_count=doc.get("chunks_count", -1),
file_path=doc.get("file_path", doc["_id"]), file_path=doc.get("file_path", doc["_id"]),
chunks_list=doc.get("chunks_list", []),
) )
for doc in result for doc in result
} }
@ -417,11 +439,21 @@ class MongoGraphStorage(BaseGraphStorage):
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
""" """
Check if there's a direct single-hop edge from source_node_id to target_node_id. Check if there's a direct single-hop edge between source_node_id and target_node_id.
""" """
# Direct check if the target_node appears among the edges array.
doc = await self.edge_collection.find_one( doc = await self.edge_collection.find_one(
{"source_node_id": source_node_id, "target_node_id": target_node_id}, {
"$or": [
{
"source_node_id": source_node_id,
"target_node_id": target_node_id,
},
{
"source_node_id": target_node_id,
"target_node_id": source_node_id,
},
]
},
{"_id": 1}, {"_id": 1},
) )
return doc is not None return doc is not None
@ -651,7 +683,7 @@ class MongoGraphStorage(BaseGraphStorage):
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None: ) -> None:
""" """
Upsert an edge from source_node_id -> target_node_id with optional 'relation'. Upsert an edge between source_node_id and target_node_id with optional 'relation'.
If an edge with the same target exists, we remove it and re-insert with updated data. If an edge with the same target exists, we remove it and re-insert with updated data.
""" """
# Ensure source node exists # Ensure source node exists
@ -663,8 +695,22 @@ class MongoGraphStorage(BaseGraphStorage):
GRAPH_FIELD_SEP GRAPH_FIELD_SEP
) )
edge_data["source_node_id"] = source_node_id
edge_data["target_node_id"] = target_node_id
await self.edge_collection.update_one( await self.edge_collection.update_one(
{"source_node_id": source_node_id, "target_node_id": target_node_id}, {
"$or": [
{
"source_node_id": source_node_id,
"target_node_id": target_node_id,
},
{
"source_node_id": target_node_id,
"target_node_id": source_node_id,
},
]
},
update_doc, update_doc,
upsert=True, upsert=True,
) )
@ -678,7 +724,7 @@ class MongoGraphStorage(BaseGraphStorage):
async def delete_node(self, node_id: str) -> None: async def delete_node(self, node_id: str) -> None:
""" """
1) Remove node's doc entirely. 1) Remove node's doc entirely.
2) Remove inbound edges from any doc that references node_id. 2) Remove inbound & outbound edges from any doc that references node_id.
""" """
# Remove all edges # Remove all edges
await self.edge_collection.delete_many( await self.edge_collection.delete_many(
@ -709,141 +755,369 @@ class MongoGraphStorage(BaseGraphStorage):
labels.append(doc["_id"]) labels.append(doc["_id"])
return labels return labels
def _construct_graph_node(
self, node_id, node_data: dict[str, str]
) -> KnowledgeGraphNode:
return KnowledgeGraphNode(
id=node_id,
labels=[node_id],
properties={
k: v
for k, v in node_data.items()
if k
not in [
"_id",
"connected_edges",
"source_ids",
"edge_count",
]
},
)
def _construct_graph_edge(self, edge_id: str, edge: dict[str, str]):
return KnowledgeGraphEdge(
id=edge_id,
type=edge.get("relationship", ""),
source=edge["source_node_id"],
target=edge["target_node_id"],
properties={
k: v
for k, v in edge.items()
if k
not in [
"_id",
"source_node_id",
"target_node_id",
"relationship",
"source_ids",
]
},
)
async def get_knowledge_graph_all_by_degree(
self, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES
) -> KnowledgeGraph:
"""
It's possible that the node with one or multiple relationships is retrieved,
while its neighbor is not. Then this node might seem like disconnected in UI.
"""
total_node_count = await self.collection.count_documents({})
result = KnowledgeGraph()
seen_edges = set()
result.is_truncated = total_node_count > max_nodes
if result.is_truncated:
# Get all node_ids ranked by degree if max_nodes exceeds total node count
pipeline = [
{"$project": {"source_node_id": 1, "_id": 0}},
{"$group": {"_id": "$source_node_id", "degree": {"$sum": 1}}},
{
"$unionWith": {
"coll": self._edge_collection_name,
"pipeline": [
{"$project": {"target_node_id": 1, "_id": 0}},
{
"$group": {
"_id": "$target_node_id",
"degree": {"$sum": 1},
}
},
],
}
},
{"$group": {"_id": "$_id", "degree": {"$sum": "$degree"}}},
{"$sort": {"degree": -1}},
{"$limit": max_nodes},
]
cursor = await self.edge_collection.aggregate(pipeline, allowDiskUse=True)
node_ids = []
async for doc in cursor:
node_id = str(doc["_id"])
node_ids.append(node_id)
cursor = self.collection.find({"_id": {"$in": node_ids}}, {"source_ids": 0})
async for doc in cursor:
result.nodes.append(self._construct_graph_node(doc["_id"], doc))
# As node count reaches the limit, only need to fetch the edges that directly connect to these nodes
edge_cursor = self.edge_collection.find(
{
"$and": [
{"source_node_id": {"$in": node_ids}},
{"target_node_id": {"$in": node_ids}},
]
}
)
else:
# All nodes and edges are needed
cursor = self.collection.find({}, {"source_ids": 0})
async for doc in cursor:
node_id = str(doc["_id"])
result.nodes.append(self._construct_graph_node(doc["_id"], doc))
edge_cursor = self.edge_collection.find({})
async for edge in edge_cursor:
edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
if edge_id not in seen_edges:
seen_edges.add(edge_id)
result.edges.append(self._construct_graph_edge(edge_id, edge))
return result
async def _bidirectional_bfs_nodes(
self,
node_labels: list[str],
seen_nodes: set[str],
result: KnowledgeGraph,
depth: int = 0,
max_depth: int = 3,
max_nodes: int = MAX_GRAPH_NODES,
) -> KnowledgeGraph:
if depth > max_depth or len(result.nodes) > max_nodes:
return result
cursor = self.collection.find({"_id": {"$in": node_labels}})
async for node in cursor:
node_id = node["_id"]
if node_id not in seen_nodes:
seen_nodes.add(node_id)
result.nodes.append(self._construct_graph_node(node_id, node))
if len(result.nodes) > max_nodes:
return result
# Collect neighbors
# Get both inbound and outbound one hop nodes
cursor = self.edge_collection.find(
{
"$or": [
{"source_node_id": {"$in": node_labels}},
{"target_node_id": {"$in": node_labels}},
]
}
)
neighbor_nodes = []
async for edge in cursor:
if edge["source_node_id"] not in seen_nodes:
neighbor_nodes.append(edge["source_node_id"])
if edge["target_node_id"] not in seen_nodes:
neighbor_nodes.append(edge["target_node_id"])
if neighbor_nodes:
result = await self._bidirectional_bfs_nodes(
neighbor_nodes, seen_nodes, result, depth + 1, max_depth, max_nodes
)
return result
async def get_knowledge_subgraph_bidirectional_bfs(
self,
node_label: str,
depth=0,
max_depth: int = 3,
max_nodes: int = MAX_GRAPH_NODES,
) -> KnowledgeGraph:
seen_nodes = set()
seen_edges = set()
result = KnowledgeGraph()
result = await self._bidirectional_bfs_nodes(
[node_label], seen_nodes, result, depth, max_depth, max_nodes
)
# Get all edges from seen_nodes
all_node_ids = list(seen_nodes)
cursor = self.edge_collection.find(
{
"$and": [
{"source_node_id": {"$in": all_node_ids}},
{"target_node_id": {"$in": all_node_ids}},
]
}
)
async for edge in cursor:
edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
if edge_id not in seen_edges:
result.edges.append(self._construct_graph_edge(edge_id, edge))
seen_edges.add(edge_id)
return result
async def get_knowledge_subgraph_in_out_bound_bfs(
self, node_label: str, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES
) -> KnowledgeGraph:
seen_nodes = set()
seen_edges = set()
result = KnowledgeGraph()
project_doc = {
"source_ids": 0,
"created_at": 0,
"entity_type": 0,
"file_path": 0,
}
# Verify if starting node exists
start_node = await self.collection.find_one({"_id": node_label})
if not start_node:
logger.warning(f"Starting node with label {node_label} does not exist!")
return result
seen_nodes.add(node_label)
result.nodes.append(self._construct_graph_node(node_label, start_node))
if max_depth == 0:
return result
# In MongoDB, depth = 0 means one-hop
max_depth = max_depth - 1
pipeline = [
{"$match": {"_id": node_label}},
{"$project": project_doc},
{
"$graphLookup": {
"from": self._edge_collection_name,
"startWith": "$_id",
"connectFromField": "target_node_id",
"connectToField": "source_node_id",
"maxDepth": max_depth,
"depthField": "depth",
"as": "connected_edges",
},
},
{
"$unionWith": {
"coll": self._collection_name,
"pipeline": [
{"$match": {"_id": node_label}},
{"$project": project_doc},
{
"$graphLookup": {
"from": self._edge_collection_name,
"startWith": "$_id",
"connectFromField": "source_node_id",
"connectToField": "target_node_id",
"maxDepth": max_depth,
"depthField": "depth",
"as": "connected_edges",
}
},
],
}
},
]
cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
node_edges = []
# Two records for node_label are returned capturing outbound and inbound connected_edges
async for doc in cursor:
if doc.get("connected_edges", []):
node_edges.extend(doc.get("connected_edges"))
# Sort the connected edges by depth ascending and weight descending
# And stores the source_node_id and target_node_id in sequence to retrieve the neighbouring nodes
node_edges = sorted(
node_edges,
key=lambda x: (x["depth"], -x["weight"]),
)
# As order matters, we need to use another list to store the node_id
# And only take the first max_nodes ones
node_ids = []
for edge in node_edges:
if len(node_ids) < max_nodes and edge["source_node_id"] not in seen_nodes:
node_ids.append(edge["source_node_id"])
seen_nodes.add(edge["source_node_id"])
if len(node_ids) < max_nodes and edge["target_node_id"] not in seen_nodes:
node_ids.append(edge["target_node_id"])
seen_nodes.add(edge["target_node_id"])
# Filter out all the node whose id is same as node_label so that we do not check existence next step
cursor = self.collection.find({"_id": {"$in": node_ids}})
async for doc in cursor:
result.nodes.append(self._construct_graph_node(str(doc["_id"]), doc))
for edge in node_edges:
if (
edge["source_node_id"] not in seen_nodes
or edge["target_node_id"] not in seen_nodes
):
continue
edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
if edge_id not in seen_edges:
result.edges.append(self._construct_graph_edge(edge_id, edge))
seen_edges.add(edge_id)
return result
async def get_knowledge_graph( async def get_knowledge_graph(
self, self,
node_label: str, node_label: str,
max_depth: int = 5, max_depth: int = 3,
max_nodes: int = MAX_GRAPH_NODES, max_nodes: int = MAX_GRAPH_NODES,
) -> KnowledgeGraph: ) -> KnowledgeGraph:
""" """
Get complete connected subgraph for specified node (including the starting node itself) Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
Args: Args:
node_label: Label of the nodes to start from node_label: Label of the starting node, * means all nodes
max_depth: Maximum depth of traversal (default: 5) max_depth: Maximum depth of the subgraph, Defaults to 3
max_nodes: Maxiumu nodes to return, Defaults to 1000
Returns: Returns:
KnowledgeGraph object containing nodes and edges of the subgraph KnowledgeGraph object containing nodes and edges, with an is_truncated flag
indicating whether the graph was truncated due to max_nodes limit
If a graph is like this and starting from B:
A B C F, B -> E, C D
Outbound BFS:
B E
Inbound BFS:
A B
C B
F C
Bidirectional BFS:
A B
B E
F C
C B
C D
""" """
label = node_label
result = KnowledgeGraph() result = KnowledgeGraph()
seen_nodes = set() start = time.perf_counter()
seen_edges = set()
node_edges = []
try: try:
# Optimize pipeline to avoid memory issues with large datasets # Optimize pipeline to avoid memory issues with large datasets
if label == "*": if node_label == "*":
# For getting all nodes, use a simpler pipeline to avoid memory issues result = await self.get_knowledge_graph_all_by_degree(
pipeline = [ max_depth, max_nodes
{"$limit": max_nodes}, # Limit early to reduce memory usage )
{ elif GRAPH_BFS_MODE == "in_out_bound":
"$graphLookup": { result = await self.get_knowledge_subgraph_in_out_bound_bfs(
"from": self._edge_collection_name, node_label, max_depth, max_nodes
"startWith": "$_id", )
"connectFromField": "target_node_id", else:
"connectToField": "source_node_id", result = await self.get_knowledge_subgraph_bidirectional_bfs(
"maxDepth": max_depth, node_label, 0, max_depth, max_nodes
"depthField": "depth",
"as": "connected_edges",
},
},
]
# Check if we need to set truncation flag
all_node_count = await self.collection.count_documents({})
result.is_truncated = all_node_count > max_nodes
else:
# Verify if starting node exists
start_node = await self.collection.find_one({"_id": label})
if not start_node:
logger.warning(f"Starting node with label {label} does not exist!")
return result
# For specific node queries, use the original pipeline but optimized
pipeline = [
{"$match": {"_id": label}},
{
"$graphLookup": {
"from": self._edge_collection_name,
"startWith": "$_id",
"connectFromField": "target_node_id",
"connectToField": "source_node_id",
"maxDepth": max_depth,
"depthField": "depth",
"as": "connected_edges",
},
},
{"$addFields": {"edge_count": {"$size": "$connected_edges"}}},
{"$sort": {"edge_count": -1}},
{"$limit": max_nodes},
]
cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
nodes_processed = 0
async for doc in cursor:
# Add the start node
node_id = str(doc["_id"])
result.nodes.append(
KnowledgeGraphNode(
id=node_id,
labels=[node_id],
properties={
k: v
for k, v in doc.items()
if k
not in [
"_id",
"connected_edges",
"edge_count",
]
},
)
) )
seen_nodes.add(node_id)
if doc.get("connected_edges", []):
node_edges.extend(doc.get("connected_edges"))
nodes_processed += 1 duration = time.perf_counter() - start
# Additional safety check to prevent memory issues
if nodes_processed >= max_nodes:
result.is_truncated = True
break
for edge in node_edges:
if (
edge["source_node_id"] not in seen_nodes
or edge["target_node_id"] not in seen_nodes
):
continue
edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
if edge_id not in seen_edges:
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type=edge.get("relationship", ""),
source=edge["source_node_id"],
target=edge["target_node_id"],
properties={
k: v
for k, v in edge.items()
if k
not in [
"_id",
"source_node_id",
"target_node_id",
"relationship",
]
},
)
)
seen_edges.add(edge_id)
logger.info( logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}" f"Subgraph query successful in {duration:.4f} seconds | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}"
) )
except PyMongoError as e: except PyMongoError as e:
@ -856,13 +1130,8 @@ class MongoGraphStorage(BaseGraphStorage):
try: try:
simple_cursor = self.collection.find({}).limit(max_nodes) simple_cursor = self.collection.find({}).limit(max_nodes)
async for doc in simple_cursor: async for doc in simple_cursor:
node_id = str(doc["_id"])
result.nodes.append( result.nodes.append(
KnowledgeGraphNode( self._construct_graph_node(str(doc["_id"]), doc)
id=node_id,
labels=[node_id],
properties={k: v for k, v in doc.items() if k != "_id"},
)
) )
result.is_truncated = True result.is_truncated = True
logger.info( logger.info(
@ -1023,13 +1292,11 @@ class MongoVectorDBStorage(BaseVectorStorage):
logger.debug("vector index already exist") logger.debug("vector index already exist")
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}") logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data: if not data:
return return
# Add current time as Unix timestamp # Add current time as Unix timestamp
import time
current_time = int(time.time()) current_time = int(time.time())
list_data = [ list_data = [
@ -1114,7 +1381,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
Args: Args:
ids: List of vector IDs to be deleted ids: List of vector IDs to be deleted
""" """
logger.info(f"Deleting {len(ids)} vectors from {self.namespace}") logger.debug(f"Deleting {len(ids)} vectors from {self.namespace}")
if not ids: if not ids:
return return

View file

@ -106,7 +106,9 @@ class NetworkXStorage(BaseGraphStorage):
async def edge_degree(self, src_id: str, tgt_id: str) -> int: async def edge_degree(self, src_id: str, tgt_id: str) -> int:
graph = await self._get_graph() graph = await self._get_graph()
return graph.degree(src_id) + graph.degree(tgt_id) src_degree = graph.degree(src_id) if graph.has_node(src_id) else 0
tgt_degree = graph.degree(tgt_id) if graph.has_node(tgt_id) else 0
return src_degree + tgt_degree
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str

View file

@ -136,6 +136,52 @@ class PostgreSQLDB:
except Exception as e: except Exception as e:
logger.warning(f"Failed to add chunk_id column to LIGHTRAG_LLM_CACHE: {e}") logger.warning(f"Failed to add chunk_id column to LIGHTRAG_LLM_CACHE: {e}")
async def _migrate_llm_cache_add_cache_type(self):
"""Add cache_type column to LIGHTRAG_LLM_CACHE table if it doesn't exist"""
try:
# Check if cache_type column exists
check_column_sql = """
SELECT column_name
FROM information_schema.columns
WHERE table_name = 'lightrag_llm_cache'
AND column_name = 'cache_type'
"""
column_info = await self.query(check_column_sql)
if not column_info:
logger.info("Adding cache_type column to LIGHTRAG_LLM_CACHE table")
add_column_sql = """
ALTER TABLE LIGHTRAG_LLM_CACHE
ADD COLUMN cache_type VARCHAR(32) NULL
"""
await self.execute(add_column_sql)
logger.info(
"Successfully added cache_type column to LIGHTRAG_LLM_CACHE table"
)
# Migrate existing data: extract cache_type from flattened keys
logger.info(
"Migrating existing LLM cache data to populate cache_type field"
)
update_sql = """
UPDATE LIGHTRAG_LLM_CACHE
SET cache_type = CASE
WHEN id LIKE '%:%:%' THEN split_part(id, ':', 2)
ELSE 'extract'
END
WHERE cache_type IS NULL
"""
await self.execute(update_sql)
logger.info("Successfully migrated existing LLM cache data")
else:
logger.info(
"cache_type column already exists in LIGHTRAG_LLM_CACHE table"
)
except Exception as e:
logger.warning(
f"Failed to add cache_type column to LIGHTRAG_LLM_CACHE: {e}"
)
async def _migrate_timestamp_columns(self): async def _migrate_timestamp_columns(self):
"""Migrate timestamp columns in tables to timezone-aware types, assuming original data is in UTC time""" """Migrate timestamp columns in tables to timezone-aware types, assuming original data is in UTC time"""
# Tables and columns that need migration # Tables and columns that need migration
@ -189,6 +235,239 @@ class PostgreSQLDB:
# Log error but don't interrupt the process # Log error but don't interrupt the process
logger.warning(f"Failed to migrate {table_name}.{column_name}: {e}") logger.warning(f"Failed to migrate {table_name}.{column_name}: {e}")
async def _migrate_doc_chunks_to_vdb_chunks(self):
"""
Migrate data from LIGHTRAG_DOC_CHUNKS to LIGHTRAG_VDB_CHUNKS if specific conditions are met.
This migration is intended for users who are upgrading and have an older table structure
where LIGHTRAG_DOC_CHUNKS contained a `content_vector` column.
"""
try:
# 1. Check if the new table LIGHTRAG_VDB_CHUNKS is empty
vdb_chunks_count_sql = "SELECT COUNT(1) as count FROM LIGHTRAG_VDB_CHUNKS"
vdb_chunks_count_result = await self.query(vdb_chunks_count_sql)
if vdb_chunks_count_result and vdb_chunks_count_result["count"] > 0:
logger.info(
"Skipping migration: LIGHTRAG_VDB_CHUNKS already contains data."
)
return
# 2. Check if `content_vector` column exists in the old table
check_column_sql = """
SELECT 1 FROM information_schema.columns
WHERE table_name = 'lightrag_doc_chunks' AND column_name = 'content_vector'
"""
column_exists = await self.query(check_column_sql)
if not column_exists:
logger.info(
"Skipping migration: `content_vector` not found in LIGHTRAG_DOC_CHUNKS"
)
return
# 3. Check if the old table LIGHTRAG_DOC_CHUNKS has data
doc_chunks_count_sql = "SELECT COUNT(1) as count FROM LIGHTRAG_DOC_CHUNKS"
doc_chunks_count_result = await self.query(doc_chunks_count_sql)
if not doc_chunks_count_result or doc_chunks_count_result["count"] == 0:
logger.info("Skipping migration: LIGHTRAG_DOC_CHUNKS is empty.")
return
# 4. Perform the migration
logger.info(
"Starting data migration from LIGHTRAG_DOC_CHUNKS to LIGHTRAG_VDB_CHUNKS..."
)
migration_sql = """
INSERT INTO LIGHTRAG_VDB_CHUNKS (
id, workspace, full_doc_id, chunk_order_index, tokens, content,
content_vector, file_path, create_time, update_time
)
SELECT
id, workspace, full_doc_id, chunk_order_index, tokens, content,
content_vector, file_path, create_time, update_time
FROM LIGHTRAG_DOC_CHUNKS
ON CONFLICT (workspace, id) DO NOTHING;
"""
await self.execute(migration_sql)
logger.info("Data migration to LIGHTRAG_VDB_CHUNKS completed successfully.")
except Exception as e:
logger.error(f"Failed during data migration to LIGHTRAG_VDB_CHUNKS: {e}")
# Do not re-raise, to allow the application to start
async def _check_llm_cache_needs_migration(self):
"""Check if LLM cache data needs migration by examining the first record"""
try:
# Only query the first record to determine format
check_sql = """
SELECT id FROM LIGHTRAG_LLM_CACHE
ORDER BY create_time ASC
LIMIT 1
"""
result = await self.query(check_sql)
if result and result.get("id"):
# If id doesn't contain colon, it's old format
return ":" not in result["id"]
return False # No data or already new format
except Exception as e:
logger.warning(f"Failed to check LLM cache migration status: {e}")
return False
    async def _migrate_llm_cache_to_flattened_keys(self):
        """Migrate legacy LLM cache rows to the flattened key format.

        Legacy rows store a bare hash as ``id``; the flattened format is
        ``{mode}:{cache_type}:{hash}``. For each legacy row this recomputes
        the hash via ``compute_args_hash(mode, original_prompt)``, inserts a
        new row under the flattened key (preserving timestamps and adding the
        ``cache_type`` column), and only then deletes the legacy row — so a
        crash mid-row never loses cache data. Per-row failures are logged and
        skipped; a wholesale failure is logged without re-raising so the
        application can still start.
        """
        try:
            # Legacy rows are identified by the absence of the ':' separator
            old_data_sql = """
            SELECT id, mode, original_prompt, return_value, chunk_id,
                   create_time, update_time
            FROM LIGHTRAG_LLM_CACHE
            WHERE id NOT LIKE '%:%'
            """
            old_records = await self.query(old_data_sql, multirows=True)
            if not old_records:
                logger.info("No old format LLM cache data found, skipping migration")
                return
            logger.info(
                f"Found {len(old_records)} old format cache records, starting migration..."
            )
            # Imported lazily to avoid a circular import at module load time
            from ..utils import compute_args_hash
            migrated_count = 0
            # Migrate rows one at a time; each row is insert-then-delete
            for record in old_records:
                try:
                    # Recompute the hash with the current canonical method
                    new_hash = compute_args_hash(
                        record["mode"], record["original_prompt"]
                    )
                    # "default" mode rows come from extraction; anything else is unknown
                    cache_type = "extract" if record["mode"] == "default" else "unknown"
                    # Build the flattened key: {mode}:{cache_type}:{hash}
                    new_key = f"{record['mode']}:{cache_type}:{new_hash}"
                    # Insert the migrated row; DO NOTHING keeps an existing
                    # flattened row from being clobbered on re-run
                    insert_sql = """
                    INSERT INTO LIGHTRAG_LLM_CACHE
                    (workspace, id, mode, original_prompt, return_value, chunk_id, cache_type, create_time, update_time)
                    VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
                    ON CONFLICT (workspace, mode, id) DO NOTHING
                    """
                    await self.execute(
                        insert_sql,
                        {
                            "workspace": self.workspace,
                            "id": new_key,
                            "mode": record["mode"],
                            "original_prompt": record["original_prompt"],
                            "return_value": record["return_value"],
                            "chunk_id": record["chunk_id"],
                            "cache_type": cache_type,  # New column introduced by this migration
                            "create_time": record["create_time"],
                            "update_time": record["update_time"],
                        },
                    )
                    # Remove the legacy row only after the insert succeeded
                    delete_sql = """
                    DELETE FROM LIGHTRAG_LLM_CACHE
                    WHERE workspace=$1 AND mode=$2 AND id=$3
                    """
                    await self.execute(
                        delete_sql,
                        {
                            "workspace": self.workspace,
                            "mode": record["mode"],
                            "id": record["id"],  # Legacy (pre-migration) id
                        },
                    )
                    migrated_count += 1
                except Exception as e:
                    # Skip the bad row; remaining rows still get migrated
                    logger.warning(
                        f"Failed to migrate cache record {record['id']}: {e}"
                    )
                    continue
            logger.info(
                f"Successfully migrated {migrated_count} cache records to flattened format"
            )
        except Exception as e:
            logger.error(f"LLM cache migration failed: {e}")
            # Don't raise exception, allow system to continue startup
async def _migrate_doc_status_add_chunks_list(self):
"""Add chunks_list column to LIGHTRAG_DOC_STATUS table if it doesn't exist"""
try:
# Check if chunks_list column exists
check_column_sql = """
SELECT column_name
FROM information_schema.columns
WHERE table_name = 'lightrag_doc_status'
AND column_name = 'chunks_list'
"""
column_info = await self.query(check_column_sql)
if not column_info:
logger.info("Adding chunks_list column to LIGHTRAG_DOC_STATUS table")
add_column_sql = """
ALTER TABLE LIGHTRAG_DOC_STATUS
ADD COLUMN chunks_list JSONB NULL DEFAULT '[]'::jsonb
"""
await self.execute(add_column_sql)
logger.info(
"Successfully added chunks_list column to LIGHTRAG_DOC_STATUS table"
)
else:
logger.info(
"chunks_list column already exists in LIGHTRAG_DOC_STATUS table"
)
except Exception as e:
logger.warning(
f"Failed to add chunks_list column to LIGHTRAG_DOC_STATUS: {e}"
)
async def _migrate_text_chunks_add_llm_cache_list(self):
"""Add llm_cache_list column to LIGHTRAG_DOC_CHUNKS table if it doesn't exist"""
try:
# Check if llm_cache_list column exists
check_column_sql = """
SELECT column_name
FROM information_schema.columns
WHERE table_name = 'lightrag_doc_chunks'
AND column_name = 'llm_cache_list'
"""
column_info = await self.query(check_column_sql)
if not column_info:
logger.info("Adding llm_cache_list column to LIGHTRAG_DOC_CHUNKS table")
add_column_sql = """
ALTER TABLE LIGHTRAG_DOC_CHUNKS
ADD COLUMN llm_cache_list JSONB NULL DEFAULT '[]'::jsonb
"""
await self.execute(add_column_sql)
logger.info(
"Successfully added llm_cache_list column to LIGHTRAG_DOC_CHUNKS table"
)
else:
logger.info(
"llm_cache_list column already exists in LIGHTRAG_DOC_CHUNKS table"
)
except Exception as e:
logger.warning(
f"Failed to add llm_cache_list column to LIGHTRAG_DOC_CHUNKS: {e}"
)
async def check_tables(self): async def check_tables(self):
# First create all tables # First create all tables
for k, v in TABLES.items(): for k, v in TABLES.items():
@ -240,6 +519,44 @@ class PostgreSQLDB:
logger.error(f"PostgreSQL, Failed to migrate LLM cache chunk_id field: {e}") logger.error(f"PostgreSQL, Failed to migrate LLM cache chunk_id field: {e}")
# Don't throw an exception, allow the initialization process to continue # Don't throw an exception, allow the initialization process to continue
# Migrate LLM cache table to add cache_type field if needed
try:
await self._migrate_llm_cache_add_cache_type()
except Exception as e:
logger.error(
f"PostgreSQL, Failed to migrate LLM cache cache_type field: {e}"
)
# Don't throw an exception, allow the initialization process to continue
# Finally, attempt to migrate old doc chunks data if needed
try:
await self._migrate_doc_chunks_to_vdb_chunks()
except Exception as e:
logger.error(f"PostgreSQL, Failed to migrate doc_chunks to vdb_chunks: {e}")
# Check and migrate LLM cache to flattened keys if needed
try:
if await self._check_llm_cache_needs_migration():
await self._migrate_llm_cache_to_flattened_keys()
except Exception as e:
logger.error(f"PostgreSQL, LLM cache migration failed: {e}")
# Migrate doc status to add chunks_list field if needed
try:
await self._migrate_doc_status_add_chunks_list()
except Exception as e:
logger.error(
f"PostgreSQL, Failed to migrate doc status chunks_list field: {e}"
)
# Migrate text chunks to add llm_cache_list field if needed
try:
await self._migrate_text_chunks_add_llm_cache_list()
except Exception as e:
logger.error(
f"PostgreSQL, Failed to migrate text chunks llm_cache_list field: {e}"
)
async def query( async def query(
self, self,
sql: str, sql: str,
@ -423,74 +740,139 @@ class PGKVStorage(BaseKVStorage):
try: try:
results = await self.db.query(sql, params, multirows=True) results = await self.db.query(sql, params, multirows=True)
# Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
result_dict = {} processed_results = {}
for row in results: for row in results:
mode = row["mode"] create_time = row.get("create_time", 0)
if mode not in result_dict: update_time = row.get("update_time", 0)
result_dict[mode] = {} # Map field names and add cache_type for compatibility
result_dict[mode][row["id"]] = row processed_row = {
return result_dict **row,
else: "return": row.get("return_value", ""),
return {row["id"]: row for row in results} "cache_type": row.get("original_prompt", "unknow"),
"original_prompt": row.get("original_prompt", ""),
"chunk_id": row.get("chunk_id"),
"mode": row.get("mode", "default"),
"create_time": create_time,
"update_time": create_time if update_time == 0 else update_time,
}
processed_results[row["id"]] = processed_row
return processed_results
# For text_chunks namespace, parse llm_cache_list JSON string back to list
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
processed_results = {}
for row in results:
llm_cache_list = row.get("llm_cache_list", [])
if isinstance(llm_cache_list, str):
try:
llm_cache_list = json.loads(llm_cache_list)
except json.JSONDecodeError:
llm_cache_list = []
row["llm_cache_list"] = llm_cache_list
create_time = row.get("create_time", 0)
update_time = row.get("update_time", 0)
row["create_time"] = create_time
row["update_time"] = (
create_time if update_time == 0 else update_time
)
processed_results[row["id"]] = row
return processed_results
# For other namespaces, return as-is
return {row["id"]: row for row in results}
except Exception as e: except Exception as e:
logger.error(f"Error retrieving all data from {self.namespace}: {e}") logger.error(f"Error retrieving all data from {self.namespace}: {e}")
return {} return {}
async def get_by_id(self, id: str) -> dict[str, Any] | None: async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get doc_full data by id.""" """Get data by id."""
sql = SQL_TEMPLATES["get_by_id_" + self.namespace] sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
params = {"workspace": self.db.workspace, "id": id} params = {"workspace": self.db.workspace, "id": id}
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): response = await self.db.query(sql, params)
array_res = await self.db.query(sql, params, multirows=True)
res = {}
for row in array_res:
res[row["id"]] = row
return res if res else None
else:
response = await self.db.query(sql, params)
return response if response else None
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: if response and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
"""Specifically for llm_response_cache.""" # Parse llm_cache_list JSON string back to list
sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace] llm_cache_list = response.get("llm_cache_list", [])
params = {"workspace": self.db.workspace, "mode": mode, "id": id} if isinstance(llm_cache_list, str):
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): try:
array_res = await self.db.query(sql, params, multirows=True) llm_cache_list = json.loads(llm_cache_list)
res = {} except json.JSONDecodeError:
for row in array_res: llm_cache_list = []
res[row["id"]] = row response["llm_cache_list"] = llm_cache_list
return res create_time = response.get("create_time", 0)
else: update_time = response.get("update_time", 0)
return None response["create_time"] = create_time
response["update_time"] = create_time if update_time == 0 else update_time
# Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
if response and is_namespace(
self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
):
create_time = response.get("create_time", 0)
update_time = response.get("update_time", 0)
# Map field names and add cache_type for compatibility
response = {
**response,
"return": response.get("return_value", ""),
"cache_type": response.get("cache_type"),
"original_prompt": response.get("original_prompt", ""),
"chunk_id": response.get("chunk_id"),
"mode": response.get("mode", "default"),
"create_time": create_time,
"update_time": create_time if update_time == 0 else update_time,
}
return response if response else None
# Query by id # Query by id
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get doc_chunks data by id""" """Get data by ids"""
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
ids=",".join([f"'{id}'" for id in ids]) ids=",".join([f"'{id}'" for id in ids])
) )
params = {"workspace": self.db.workspace} params = {"workspace": self.db.workspace}
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): results = await self.db.query(sql, params, multirows=True)
array_res = await self.db.query(sql, params, multirows=True)
modes = set()
dict_res: dict[str, dict] = {}
for row in array_res:
modes.add(row["mode"])
for mode in modes:
if mode not in dict_res:
dict_res[mode] = {}
for row in array_res:
dict_res[row["mode"]][row["id"]] = row
return [{k: v} for k, v in dict_res.items()]
else:
return await self.db.query(sql, params, multirows=True)
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]: if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
"""Specifically for llm_response_cache.""" # Parse llm_cache_list JSON string back to list for each result
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace] for result in results:
params = {"workspace": self.db.workspace, "status": status} llm_cache_list = result.get("llm_cache_list", [])
return await self.db.query(SQL, params, multirows=True) if isinstance(llm_cache_list, str):
try:
llm_cache_list = json.loads(llm_cache_list)
except json.JSONDecodeError:
llm_cache_list = []
result["llm_cache_list"] = llm_cache_list
create_time = result.get("create_time", 0)
update_time = result.get("update_time", 0)
result["create_time"] = create_time
result["update_time"] = create_time if update_time == 0 else update_time
# Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
if results and is_namespace(
self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
):
processed_results = []
for row in results:
create_time = row.get("create_time", 0)
update_time = row.get("update_time", 0)
# Map field names and add cache_type for compatibility
processed_row = {
**row,
"return": row.get("return_value", ""),
"cache_type": row.get("cache_type"),
"original_prompt": row.get("original_prompt", ""),
"chunk_id": row.get("chunk_id"),
"mode": row.get("mode", "default"),
"create_time": create_time,
"update_time": create_time if update_time == 0 else update_time,
}
processed_results.append(processed_row)
return processed_results
return results if results else []
async def filter_keys(self, keys: set[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
"""Filter out duplicated content""" """Filter out duplicated content"""
@ -520,7 +902,22 @@ class PGKVStorage(BaseKVStorage):
return return
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
pass current_time = datetime.datetime.now(timezone.utc)
for k, v in data.items():
upsert_sql = SQL_TEMPLATES["upsert_text_chunk"]
_data = {
"workspace": self.db.workspace,
"id": k,
"tokens": v["tokens"],
"chunk_order_index": v["chunk_order_index"],
"full_doc_id": v["full_doc_id"],
"content": v["content"],
"file_path": v["file_path"],
"llm_cache_list": json.dumps(v.get("llm_cache_list", [])),
"create_time": current_time,
"update_time": current_time,
}
await self.db.execute(upsert_sql, _data)
elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS): elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
for k, v in data.items(): for k, v in data.items():
upsert_sql = SQL_TEMPLATES["upsert_doc_full"] upsert_sql = SQL_TEMPLATES["upsert_doc_full"]
@ -531,19 +928,21 @@ class PGKVStorage(BaseKVStorage):
} }
await self.db.execute(upsert_sql, _data) await self.db.execute(upsert_sql, _data)
elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
for mode, items in data.items(): for k, v in data.items():
for k, v in items.items(): upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"] _data = {
_data = { "workspace": self.db.workspace,
"workspace": self.db.workspace, "id": k, # Use flattened key as id
"id": k, "original_prompt": v["original_prompt"],
"original_prompt": v["original_prompt"], "return_value": v["return"],
"return_value": v["return"], "mode": v.get("mode", "default"), # Get mode from data
"mode": mode, "chunk_id": v.get("chunk_id"),
"chunk_id": v.get("chunk_id"), "cache_type": v.get(
} "cache_type", "extract"
), # Get cache_type from data
}
await self.db.execute(upsert_sql, _data) await self.db.execute(upsert_sql, _data)
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
# PG handles persistence automatically # PG handles persistence automatically
@ -949,8 +1348,8 @@ class PGDocStatusStorage(DocStatusStorage):
else: else:
exist_keys = [] exist_keys = []
new_keys = set([s for s in keys if s not in exist_keys]) new_keys = set([s for s in keys if s not in exist_keys])
print(f"keys: {keys}") # print(f"keys: {keys}")
print(f"new_keys: {new_keys}") # print(f"new_keys: {new_keys}")
return new_keys return new_keys
except Exception as e: except Exception as e:
logger.error( logger.error(
@ -965,6 +1364,14 @@ class PGDocStatusStorage(DocStatusStorage):
if result is None or result == []: if result is None or result == []:
return None return None
else: else:
# Parse chunks_list JSON string back to list
chunks_list = result[0].get("chunks_list", [])
if isinstance(chunks_list, str):
try:
chunks_list = json.loads(chunks_list)
except json.JSONDecodeError:
chunks_list = []
return dict( return dict(
content=result[0]["content"], content=result[0]["content"],
content_length=result[0]["content_length"], content_length=result[0]["content_length"],
@ -974,6 +1381,7 @@ class PGDocStatusStorage(DocStatusStorage):
created_at=result[0]["created_at"], created_at=result[0]["created_at"],
updated_at=result[0]["updated_at"], updated_at=result[0]["updated_at"],
file_path=result[0]["file_path"], file_path=result[0]["file_path"],
chunks_list=chunks_list,
) )
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@ -988,19 +1396,32 @@ class PGDocStatusStorage(DocStatusStorage):
if not results: if not results:
return [] return []
return [
{ processed_results = []
"content": row["content"], for row in results:
"content_length": row["content_length"], # Parse chunks_list JSON string back to list
"content_summary": row["content_summary"], chunks_list = row.get("chunks_list", [])
"status": row["status"], if isinstance(chunks_list, str):
"chunks_count": row["chunks_count"], try:
"created_at": row["created_at"], chunks_list = json.loads(chunks_list)
"updated_at": row["updated_at"], except json.JSONDecodeError:
"file_path": row["file_path"], chunks_list = []
}
for row in results processed_results.append(
] {
"content": row["content"],
"content_length": row["content_length"],
"content_summary": row["content_summary"],
"status": row["status"],
"chunks_count": row["chunks_count"],
"created_at": row["created_at"],
"updated_at": row["updated_at"],
"file_path": row["file_path"],
"chunks_list": chunks_list,
}
)
return processed_results
async def get_status_counts(self) -> dict[str, int]: async def get_status_counts(self) -> dict[str, int]:
"""Get counts of documents in each status""" """Get counts of documents in each status"""
@ -1021,8 +1442,18 @@ class PGDocStatusStorage(DocStatusStorage):
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2" sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
params = {"workspace": self.db.workspace, "status": status.value} params = {"workspace": self.db.workspace, "status": status.value}
result = await self.db.query(sql, params, True) result = await self.db.query(sql, params, True)
docs_by_status = {
element["id"]: DocProcessingStatus( docs_by_status = {}
for element in result:
# Parse chunks_list JSON string back to list
chunks_list = element.get("chunks_list", [])
if isinstance(chunks_list, str):
try:
chunks_list = json.loads(chunks_list)
except json.JSONDecodeError:
chunks_list = []
docs_by_status[element["id"]] = DocProcessingStatus(
content=element["content"], content=element["content"],
content_summary=element["content_summary"], content_summary=element["content_summary"],
content_length=element["content_length"], content_length=element["content_length"],
@ -1031,9 +1462,9 @@ class PGDocStatusStorage(DocStatusStorage):
updated_at=element["updated_at"], updated_at=element["updated_at"],
chunks_count=element["chunks_count"], chunks_count=element["chunks_count"],
file_path=element["file_path"], file_path=element["file_path"],
chunks_list=chunks_list,
) )
for element in result
}
return docs_by_status return docs_by_status
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
@ -1097,10 +1528,10 @@ class PGDocStatusStorage(DocStatusStorage):
logger.warning(f"Unable to parse datetime string: {dt_str}") logger.warning(f"Unable to parse datetime string: {dt_str}")
return None return None
# Modified SQL to include created_at and updated_at in both INSERT and UPDATE operations # Modified SQL to include created_at, updated_at, and chunks_list in both INSERT and UPDATE operations
# Both fields are updated from the input data in both INSERT and UPDATE cases # All fields are updated from the input data in both INSERT and UPDATE cases
sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status,file_path,created_at,updated_at) sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status,file_path,chunks_list,created_at,updated_at)
values($1,$2,$3,$4,$5,$6,$7,$8,$9,$10) values($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11)
on conflict(id,workspace) do update set on conflict(id,workspace) do update set
content = EXCLUDED.content, content = EXCLUDED.content,
content_summary = EXCLUDED.content_summary, content_summary = EXCLUDED.content_summary,
@ -1108,6 +1539,7 @@ class PGDocStatusStorage(DocStatusStorage):
chunks_count = EXCLUDED.chunks_count, chunks_count = EXCLUDED.chunks_count,
status = EXCLUDED.status, status = EXCLUDED.status,
file_path = EXCLUDED.file_path, file_path = EXCLUDED.file_path,
chunks_list = EXCLUDED.chunks_list,
created_at = EXCLUDED.created_at, created_at = EXCLUDED.created_at,
updated_at = EXCLUDED.updated_at""" updated_at = EXCLUDED.updated_at"""
for k, v in data.items(): for k, v in data.items():
@ -1115,7 +1547,7 @@ class PGDocStatusStorage(DocStatusStorage):
created_at = parse_datetime(v.get("created_at")) created_at = parse_datetime(v.get("created_at"))
updated_at = parse_datetime(v.get("updated_at")) updated_at = parse_datetime(v.get("updated_at"))
# chunks_count is optional # chunks_count and chunks_list are optional
await self.db.execute( await self.db.execute(
sql, sql,
{ {
@ -1127,6 +1559,7 @@ class PGDocStatusStorage(DocStatusStorage):
"chunks_count": v["chunks_count"] if "chunks_count" in v else -1, "chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
"status": v["status"], "status": v["status"],
"file_path": v["file_path"], "file_path": v["file_path"],
"chunks_list": json.dumps(v.get("chunks_list", [])),
"created_at": created_at, # Use the converted datetime object "created_at": created_at, # Use the converted datetime object
"updated_at": updated_at, # Use the converted datetime object "updated_at": updated_at, # Use the converted datetime object
}, },
@ -2409,7 +2842,7 @@ class PGGraphStorage(BaseGraphStorage):
NAMESPACE_TABLE_MAP = { NAMESPACE_TABLE_MAP = {
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS", NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS", NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_VDB_CHUNKS",
NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_VDB_ENTITY", NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_VDB_ENTITY",
NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_VDB_RELATION", NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_VDB_RELATION",
NameSpace.DOC_STATUS: "LIGHTRAG_DOC_STATUS", NameSpace.DOC_STATUS: "LIGHTRAG_DOC_STATUS",
@ -2438,6 +2871,21 @@ TABLES = {
}, },
"LIGHTRAG_DOC_CHUNKS": { "LIGHTRAG_DOC_CHUNKS": {
"ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS ( "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
id VARCHAR(255),
workspace VARCHAR(255),
full_doc_id VARCHAR(256),
chunk_order_index INTEGER,
tokens INTEGER,
content TEXT,
file_path VARCHAR(256),
llm_cache_list JSONB NULL DEFAULT '[]'::jsonb,
create_time TIMESTAMP(0) WITH TIME ZONE,
update_time TIMESTAMP(0) WITH TIME ZONE,
CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id)
)"""
},
"LIGHTRAG_VDB_CHUNKS": {
"ddl": """CREATE TABLE LIGHTRAG_VDB_CHUNKS (
id VARCHAR(255), id VARCHAR(255),
workspace VARCHAR(255), workspace VARCHAR(255),
full_doc_id VARCHAR(256), full_doc_id VARCHAR(256),
@ -2448,7 +2896,7 @@ TABLES = {
file_path VARCHAR(256), file_path VARCHAR(256),
create_time TIMESTAMP(0) WITH TIME ZONE, create_time TIMESTAMP(0) WITH TIME ZONE,
update_time TIMESTAMP(0) WITH TIME ZONE, update_time TIMESTAMP(0) WITH TIME ZONE,
CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id) CONSTRAINT LIGHTRAG_VDB_CHUNKS_PK PRIMARY KEY (workspace, id)
)""" )"""
}, },
"LIGHTRAG_VDB_ENTITY": { "LIGHTRAG_VDB_ENTITY": {
@ -2503,6 +2951,7 @@ TABLES = {
chunks_count int4 NULL, chunks_count int4 NULL,
status varchar(64) NULL, status varchar(64) NULL,
file_path TEXT NULL, file_path TEXT NULL,
chunks_list JSONB NULL DEFAULT '[]'::jsonb,
created_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NULL, created_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NULL,
updated_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NULL, updated_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NULL,
CONSTRAINT LIGHTRAG_DOC_STATUS_PK PRIMARY KEY (workspace, id) CONSTRAINT LIGHTRAG_DOC_STATUS_PK PRIMARY KEY (workspace, id)
@ -2517,24 +2966,30 @@ SQL_TEMPLATES = {
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id=$2 FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id=$2
""", """,
"get_by_id_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content, "get_by_id_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
chunk_order_index, full_doc_id, file_path chunk_order_index, full_doc_id, file_path,
COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
create_time, update_time
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2 FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
""", """,
"get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id "get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type,
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 create_time, update_time
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2
""", """,
"get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3 FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3
""", """,
"get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content "get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids}) FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids})
""", """,
"get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content, "get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
chunk_order_index, full_doc_id, file_path chunk_order_index, full_doc_id, file_path,
COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
create_time, update_time
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids}) FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
""", """,
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id "get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type,
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode= IN ({ids}) create_time, update_time
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids})
""", """,
"filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})", "filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
"upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, workspace) "upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, workspace)
@ -2542,16 +2997,31 @@ SQL_TEMPLATES = {
ON CONFLICT (workspace,id) DO UPDATE ON CONFLICT (workspace,id) DO UPDATE
SET content = $2, update_time = CURRENT_TIMESTAMP SET content = $2, update_time = CURRENT_TIMESTAMP
""", """,
"upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode,chunk_id) "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode,chunk_id,cache_type)
VALUES ($1, $2, $3, $4, $5, $6) VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (workspace,mode,id) DO UPDATE ON CONFLICT (workspace,mode,id) DO UPDATE
SET original_prompt = EXCLUDED.original_prompt, SET original_prompt = EXCLUDED.original_prompt,
return_value=EXCLUDED.return_value, return_value=EXCLUDED.return_value,
mode=EXCLUDED.mode, mode=EXCLUDED.mode,
chunk_id=EXCLUDED.chunk_id, chunk_id=EXCLUDED.chunk_id,
cache_type=EXCLUDED.cache_type,
update_time = CURRENT_TIMESTAMP update_time = CURRENT_TIMESTAMP
""", """,
"upsert_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens, "upsert_text_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,
chunk_order_index, full_doc_id, content, file_path, llm_cache_list,
create_time, update_time)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
ON CONFLICT (workspace,id) DO UPDATE
SET tokens=EXCLUDED.tokens,
chunk_order_index=EXCLUDED.chunk_order_index,
full_doc_id=EXCLUDED.full_doc_id,
content = EXCLUDED.content,
file_path=EXCLUDED.file_path,
llm_cache_list=EXCLUDED.llm_cache_list,
update_time = EXCLUDED.update_time
""",
# SQL for VectorStorage
"upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens,
chunk_order_index, full_doc_id, content, content_vector, file_path, chunk_order_index, full_doc_id, content, content_vector, file_path,
create_time, update_time) create_time, update_time)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
@ -2564,7 +3034,6 @@ SQL_TEMPLATES = {
file_path=EXCLUDED.file_path, file_path=EXCLUDED.file_path,
update_time = EXCLUDED.update_time update_time = EXCLUDED.update_time
""", """,
# SQL for VectorStorage
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
content_vector, chunk_ids, file_path, create_time, update_time) content_vector, chunk_ids, file_path, create_time, update_time)
VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7, $8, $9) VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7, $8, $9)
@ -2591,7 +3060,7 @@ SQL_TEMPLATES = {
"relationships": """ "relationships": """
WITH relevant_chunks AS ( WITH relevant_chunks AS (
SELECT id as chunk_id SELECT id as chunk_id
FROM LIGHTRAG_DOC_CHUNKS FROM LIGHTRAG_VDB_CHUNKS
WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[]) WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
) )
SELECT source_id as src_id, target_id as tgt_id, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at SELECT source_id as src_id, target_id as tgt_id, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at
@ -2608,7 +3077,7 @@ SQL_TEMPLATES = {
"entities": """ "entities": """
WITH relevant_chunks AS ( WITH relevant_chunks AS (
SELECT id as chunk_id SELECT id as chunk_id
FROM LIGHTRAG_DOC_CHUNKS FROM LIGHTRAG_VDB_CHUNKS
WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[]) WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
) )
SELECT entity_name, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM SELECT entity_name, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM
@ -2625,13 +3094,13 @@ SQL_TEMPLATES = {
"chunks": """ "chunks": """
WITH relevant_chunks AS ( WITH relevant_chunks AS (
SELECT id as chunk_id SELECT id as chunk_id
FROM LIGHTRAG_DOC_CHUNKS FROM LIGHTRAG_VDB_CHUNKS
WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[]) WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
) )
SELECT id, content, file_path, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM SELECT id, content, file_path, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM
( (
SELECT id, content, file_path, create_time, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance SELECT id, content, file_path, create_time, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
FROM LIGHTRAG_DOC_CHUNKS FROM LIGHTRAG_VDB_CHUNKS
WHERE workspace=$1 WHERE workspace=$1
AND id IN (SELECT chunk_id FROM relevant_chunks) AND id IN (SELECT chunk_id FROM relevant_chunks)
) as chunk_distances ) as chunk_distances

View file

@ -85,7 +85,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
) )
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}") logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data: if not data:
return return

View file

@ -1,9 +1,10 @@
import os import os
from typing import Any, final from typing import Any, final, Union
from dataclasses import dataclass from dataclasses import dataclass
import pipmaster as pm import pipmaster as pm
import configparser import configparser
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import threading
if not pm.is_installed("redis"): if not pm.is_installed("redis"):
pm.install("redis") pm.install("redis")
@ -13,7 +14,12 @@ from redis.asyncio import Redis, ConnectionPool # type: ignore
from redis.exceptions import RedisError, ConnectionError # type: ignore from redis.exceptions import RedisError, ConnectionError # type: ignore
from lightrag.utils import logger from lightrag.utils import logger
from lightrag.base import BaseKVStorage from lightrag.base import (
BaseKVStorage,
DocStatusStorage,
DocStatus,
DocProcessingStatus,
)
import json import json
@ -26,6 +32,41 @@ SOCKET_TIMEOUT = 5.0
SOCKET_CONNECT_TIMEOUT = 3.0 SOCKET_CONNECT_TIMEOUT = 3.0
class RedisConnectionManager:
"""Shared Redis connection pool manager to avoid creating multiple pools for the same Redis URI"""
_pools = {}
_lock = threading.Lock()
@classmethod
def get_pool(cls, redis_url: str) -> ConnectionPool:
"""Get or create a connection pool for the given Redis URL"""
if redis_url not in cls._pools:
with cls._lock:
if redis_url not in cls._pools:
cls._pools[redis_url] = ConnectionPool.from_url(
redis_url,
max_connections=MAX_CONNECTIONS,
decode_responses=True,
socket_timeout=SOCKET_TIMEOUT,
socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
)
logger.info(f"Created shared Redis connection pool for {redis_url}")
return cls._pools[redis_url]
@classmethod
def close_all_pools(cls):
"""Close all connection pools (for cleanup)"""
with cls._lock:
for url, pool in cls._pools.items():
try:
pool.disconnect()
logger.info(f"Closed Redis connection pool for {url}")
except Exception as e:
logger.error(f"Error closing Redis pool for {url}: {e}")
cls._pools.clear()
@final @final
@dataclass @dataclass
class RedisKVStorage(BaseKVStorage): class RedisKVStorage(BaseKVStorage):
@ -33,19 +74,28 @@ class RedisKVStorage(BaseKVStorage):
redis_url = os.environ.get( redis_url = os.environ.get(
"REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379") "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
) )
# Create a connection pool with limits # Use shared connection pool
self._pool = ConnectionPool.from_url( self._pool = RedisConnectionManager.get_pool(redis_url)
redis_url,
max_connections=MAX_CONNECTIONS,
decode_responses=True,
socket_timeout=SOCKET_TIMEOUT,
socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
)
self._redis = Redis(connection_pool=self._pool) self._redis = Redis(connection_pool=self._pool)
logger.info( logger.info(
f"Initialized Redis connection pool for {self.namespace} with max {MAX_CONNECTIONS} connections" f"Initialized Redis KV storage for {self.namespace} using shared connection pool"
) )
async def initialize(self):
"""Initialize Redis connection and migrate legacy cache structure if needed"""
# Test connection
try:
async with self._get_redis_connection() as redis:
await redis.ping()
logger.info(f"Connected to Redis for namespace {self.namespace}")
except Exception as e:
logger.error(f"Failed to connect to Redis: {e}")
raise
# Migrate legacy cache structure if this is a cache namespace
if self.namespace.endswith("_cache"):
await self._migrate_legacy_cache_structure()
@asynccontextmanager @asynccontextmanager
async def _get_redis_connection(self): async def _get_redis_connection(self):
"""Safe context manager for Redis operations.""" """Safe context manager for Redis operations."""
@ -82,7 +132,13 @@ class RedisKVStorage(BaseKVStorage):
async with self._get_redis_connection() as redis: async with self._get_redis_connection() as redis:
try: try:
data = await redis.get(f"{self.namespace}:{id}") data = await redis.get(f"{self.namespace}:{id}")
return json.loads(data) if data else None if data:
result = json.loads(data)
# Ensure time fields are present, provide default values for old data
result.setdefault("create_time", 0)
result.setdefault("update_time", 0)
return result
return None
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.error(f"JSON decode error for id {id}: {e}") logger.error(f"JSON decode error for id {id}: {e}")
return None return None
@ -94,35 +150,113 @@ class RedisKVStorage(BaseKVStorage):
for id in ids: for id in ids:
pipe.get(f"{self.namespace}:{id}") pipe.get(f"{self.namespace}:{id}")
results = await pipe.execute() results = await pipe.execute()
return [json.loads(result) if result else None for result in results]
processed_results = []
for result in results:
if result:
data = json.loads(result)
# Ensure time fields are present for all documents
data.setdefault("create_time", 0)
data.setdefault("update_time", 0)
processed_results.append(data)
else:
processed_results.append(None)
return processed_results
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.error(f"JSON decode error in batch get: {e}") logger.error(f"JSON decode error in batch get: {e}")
return [None] * len(ids) return [None] * len(ids)
async def get_all(self) -> dict[str, Any]:
"""Get all data from storage
Returns:
Dictionary containing all stored data
"""
async with self._get_redis_connection() as redis:
try:
# Get all keys for this namespace
keys = await redis.keys(f"{self.namespace}:*")
if not keys:
return {}
# Get all values in batch
pipe = redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
# Build result dictionary
result = {}
for key, value in zip(keys, values):
if value:
# Extract the ID part (after namespace:)
key_id = key.split(":", 1)[1]
try:
data = json.loads(value)
# Ensure time fields are present for all documents
data.setdefault("create_time", 0)
data.setdefault("update_time", 0)
result[key_id] = data
except json.JSONDecodeError as e:
logger.error(f"JSON decode error for key {key}: {e}")
continue
return result
except Exception as e:
logger.error(f"Error getting all data from Redis: {e}")
return {}
async def filter_keys(self, keys: set[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
async with self._get_redis_connection() as redis: async with self._get_redis_connection() as redis:
pipe = redis.pipeline() pipe = redis.pipeline()
for key in keys: keys_list = list(keys) # Convert set to list for indexing
for key in keys_list:
pipe.exists(f"{self.namespace}:{key}") pipe.exists(f"{self.namespace}:{key}")
results = await pipe.execute() results = await pipe.execute()
existing_ids = {keys[i] for i, exists in enumerate(results) if exists} existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
return set(keys) - existing_ids return set(keys) - existing_ids
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
if not data: if not data:
return return
logger.info(f"Inserting {len(data)} items to {self.namespace}") import time
current_time = int(time.time()) # Get current Unix timestamp
async with self._get_redis_connection() as redis: async with self._get_redis_connection() as redis:
try: try:
# Check which keys already exist to determine create vs update
pipe = redis.pipeline()
for k in data.keys():
pipe.exists(f"{self.namespace}:{k}")
exists_results = await pipe.execute()
# Add timestamps to data
for i, (k, v) in enumerate(data.items()):
# For text_chunks namespace, ensure llm_cache_list field exists
if "text_chunks" in self.namespace:
if "llm_cache_list" not in v:
v["llm_cache_list"] = []
# Add timestamps based on whether key exists
if exists_results[i]: # Key exists, only update update_time
v["update_time"] = current_time
else: # New key, set both create_time and update_time
v["create_time"] = current_time
v["update_time"] = current_time
v["_id"] = k
# Store the data
pipe = redis.pipeline() pipe = redis.pipeline()
for k, v in data.items(): for k, v in data.items():
pipe.set(f"{self.namespace}:{k}", json.dumps(v)) pipe.set(f"{self.namespace}:{k}", json.dumps(v))
await pipe.execute() await pipe.execute()
for k in data:
data[k]["_id"] = k
except json.JSONEncodeError as e: except json.JSONEncodeError as e:
logger.error(f"JSON encode error during upsert: {e}") logger.error(f"JSON encode error during upsert: {e}")
raise raise
@ -148,13 +282,13 @@ class RedisKVStorage(BaseKVStorage):
) )
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool: async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
"""Delete specific records from storage by by cache mode """Delete specific records from storage by cache mode
Importance notes for Redis storage: Importance notes for Redis storage:
1. This will immediately delete the specified cache modes from Redis 1. This will immediately delete the specified cache modes from Redis
Args: Args:
modes (list[str]): List of cache mode to be drop from storage modes (list[str]): List of cache modes to be dropped from storage
Returns: Returns:
True: if the cache drop successfully True: if the cache drop successfully
@ -164,9 +298,47 @@ class RedisKVStorage(BaseKVStorage):
return False return False
try: try:
await self.delete(modes) async with self._get_redis_connection() as redis:
keys_to_delete = []
# Find matching keys for each mode using SCAN
for mode in modes:
# Use correct pattern to match flattened cache key format {namespace}:{mode}:{cache_type}:{hash}
pattern = f"{self.namespace}:{mode}:*"
cursor = 0
mode_keys = []
while True:
cursor, keys = await redis.scan(
cursor, match=pattern, count=1000
)
if keys:
mode_keys.extend(keys)
if cursor == 0:
break
keys_to_delete.extend(mode_keys)
logger.info(
f"Found {len(mode_keys)} keys for mode '{mode}' with pattern '{pattern}'"
)
if keys_to_delete:
# Batch delete
pipe = redis.pipeline()
for key in keys_to_delete:
pipe.delete(key)
results = await pipe.execute()
deleted_count = sum(results)
logger.info(
f"Dropped {deleted_count} cache entries for modes: {modes}"
)
else:
logger.warning(f"No cache entries found for modes: {modes}")
return True return True
except Exception: except Exception as e:
logger.error(f"Error dropping cache by modes in Redis: {e}")
return False return False
async def drop(self) -> dict[str, str]: async def drop(self) -> dict[str, str]:
@ -177,24 +349,370 @@ class RedisKVStorage(BaseKVStorage):
""" """
async with self._get_redis_connection() as redis: async with self._get_redis_connection() as redis:
try: try:
keys = await redis.keys(f"{self.namespace}:*") # Use SCAN to find all keys with the namespace prefix
pattern = f"{self.namespace}:*"
cursor = 0
deleted_count = 0
if keys: while True:
pipe = redis.pipeline() cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
for key in keys: if keys:
pipe.delete(key) # Delete keys in batches
results = await pipe.execute() pipe = redis.pipeline()
deleted_count = sum(results) for key in keys:
pipe.delete(key)
results = await pipe.execute()
deleted_count += sum(results)
logger.info(f"Dropped {deleted_count} keys from {self.namespace}") if cursor == 0:
return { break
"status": "success",
"message": f"{deleted_count} keys dropped", logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
} return {
else: "status": "success",
logger.info(f"No keys found to drop in {self.namespace}") "message": f"{deleted_count} keys dropped",
return {"status": "success", "message": "no keys to drop"} }
except Exception as e: except Exception as e:
logger.error(f"Error dropping keys from {self.namespace}: {e}") logger.error(f"Error dropping keys from {self.namespace}: {e}")
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
async def _migrate_legacy_cache_structure(self):
"""Migrate legacy nested cache structure to flattened structure for Redis
Redis already stores data in a flattened way, but we need to check for
legacy keys that might contain nested JSON structures and migrate them.
Early exit if any flattened key is found (indicating migration already done).
"""
from lightrag.utils import generate_cache_key
async with self._get_redis_connection() as redis:
# Get all keys for this namespace
keys = await redis.keys(f"{self.namespace}:*")
if not keys:
return
# Check if we have any flattened keys already - if so, skip migration
has_flattened_keys = False
keys_to_migrate = []
for key in keys:
# Extract the ID part (after namespace:)
key_id = key.split(":", 1)[1]
# Check if already in flattened format (contains exactly 2 colons for mode:cache_type:hash)
if ":" in key_id and len(key_id.split(":")) == 3:
has_flattened_keys = True
break # Early exit - migration already done
# Get the data to check if it's a legacy nested structure
data = await redis.get(key)
if data:
try:
parsed_data = json.loads(data)
# Check if this looks like a legacy cache mode with nested structure
if isinstance(parsed_data, dict) and all(
isinstance(v, dict) and "return" in v
for v in parsed_data.values()
):
keys_to_migrate.append((key, key_id, parsed_data))
except json.JSONDecodeError:
continue
# If we found any flattened keys, assume migration is already done
if has_flattened_keys:
logger.debug(
f"Found flattened cache keys in {self.namespace}, skipping migration"
)
return
if not keys_to_migrate:
return
# Perform migration
pipe = redis.pipeline()
migration_count = 0
for old_key, mode, nested_data in keys_to_migrate:
# Delete the old key
pipe.delete(old_key)
# Create new flattened keys
for cache_hash, cache_entry in nested_data.items():
cache_type = cache_entry.get("cache_type", "extract")
flattened_key = generate_cache_key(mode, cache_type, cache_hash)
full_key = f"{self.namespace}:{flattened_key}"
pipe.set(full_key, json.dumps(cache_entry))
migration_count += 1
await pipe.execute()
if migration_count > 0:
logger.info(
f"Migrated {migration_count} legacy cache entries to flattened structure in Redis"
)
@final
@dataclass
class RedisDocStatusStorage(DocStatusStorage):
"""Redis implementation of document status storage"""
def __post_init__(self):
redis_url = os.environ.get(
"REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
)
# Use shared connection pool
self._pool = RedisConnectionManager.get_pool(redis_url)
self._redis = Redis(connection_pool=self._pool)
logger.info(
f"Initialized Redis doc status storage for {self.namespace} using shared connection pool"
)
async def initialize(self):
"""Initialize Redis connection"""
try:
async with self._get_redis_connection() as redis:
await redis.ping()
logger.info(
f"Connected to Redis for doc status namespace {self.namespace}"
)
except Exception as e:
logger.error(f"Failed to connect to Redis for doc status: {e}")
raise
@asynccontextmanager
async def _get_redis_connection(self):
"""Safe context manager for Redis operations."""
try:
yield self._redis
except ConnectionError as e:
logger.error(f"Redis connection error in doc status {self.namespace}: {e}")
raise
except RedisError as e:
logger.error(f"Redis operation error in doc status {self.namespace}: {e}")
raise
except Exception as e:
logger.error(
f"Unexpected error in Redis doc status operation for {self.namespace}: {e}"
)
raise
async def close(self):
"""Close the Redis connection."""
if hasattr(self, "_redis") and self._redis:
await self._redis.close()
logger.debug(f"Closed Redis connection for doc status {self.namespace}")
async def __aenter__(self):
"""Support for async context manager."""
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Ensure Redis resources are cleaned up when exiting context."""
await self.close()
async def filter_keys(self, keys: set[str]) -> set[str]:
"""Return keys that should be processed (not in storage or not successfully processed)"""
async with self._get_redis_connection() as redis:
pipe = redis.pipeline()
keys_list = list(keys)
for key in keys_list:
pipe.exists(f"{self.namespace}:{key}")
results = await pipe.execute()
existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
return set(keys) - existing_ids
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
result: list[dict[str, Any]] = []
async with self._get_redis_connection() as redis:
try:
pipe = redis.pipeline()
for id in ids:
pipe.get(f"{self.namespace}:{id}")
results = await pipe.execute()
for result_data in results:
if result_data:
try:
result.append(json.loads(result_data))
except json.JSONDecodeError as e:
logger.error(f"JSON decode error in get_by_ids: {e}")
continue
except Exception as e:
logger.error(f"Error in get_by_ids: {e}")
return result
async def get_status_counts(self) -> dict[str, int]:
"""Get counts of documents in each status"""
counts = {status.value: 0 for status in DocStatus}
async with self._get_redis_connection() as redis:
try:
# Use SCAN to iterate through all keys in the namespace
cursor = 0
while True:
cursor, keys = await redis.scan(
cursor, match=f"{self.namespace}:*", count=1000
)
if keys:
# Get all values in batch
pipe = redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
# Count statuses
for value in values:
if value:
try:
doc_data = json.loads(value)
status = doc_data.get("status")
if status in counts:
counts[status] += 1
except json.JSONDecodeError:
continue
if cursor == 0:
break
except Exception as e:
logger.error(f"Error getting status counts: {e}")
return counts
async def get_docs_by_status(
self, status: DocStatus
) -> dict[str, DocProcessingStatus]:
"""Get all documents with a specific status"""
result = {}
async with self._get_redis_connection() as redis:
try:
# Use SCAN to iterate through all keys in the namespace
cursor = 0
while True:
cursor, keys = await redis.scan(
cursor, match=f"{self.namespace}:*", count=1000
)
if keys:
# Get all values in batch
pipe = redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
# Filter by status and create DocProcessingStatus objects
for key, value in zip(keys, values):
if value:
try:
doc_data = json.loads(value)
if doc_data.get("status") == status.value:
# Extract document ID from key
doc_id = key.split(":", 1)[1]
# Make a copy of the data to avoid modifying the original
data = doc_data.copy()
# If content is missing, use content_summary as content
if (
"content" not in data
and "content_summary" in data
):
data["content"] = data["content_summary"]
# If file_path is not in data, use document id as file path
if "file_path" not in data:
data["file_path"] = "no-file-path"
result[doc_id] = DocProcessingStatus(**data)
except (json.JSONDecodeError, KeyError) as e:
logger.error(
f"Error processing document {key}: {e}"
)
continue
if cursor == 0:
break
except Exception as e:
logger.error(f"Error getting docs by status: {e}")
return result
async def index_done_callback(self) -> None:
"""Redis handles persistence automatically"""
pass
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""Insert or update document status data"""
if not data:
return
logger.debug(f"Inserting {len(data)} records to {self.namespace}")
async with self._get_redis_connection() as redis:
try:
# Ensure chunks_list field exists for new documents
for doc_id, doc_data in data.items():
if "chunks_list" not in doc_data:
doc_data["chunks_list"] = []
pipe = redis.pipeline()
for k, v in data.items():
pipe.set(f"{self.namespace}:{k}", json.dumps(v))
await pipe.execute()
except json.JSONEncodeError as e:
logger.error(f"JSON encode error during upsert: {e}")
raise
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
async with self._get_redis_connection() as redis:
try:
data = await redis.get(f"{self.namespace}:{id}")
return json.loads(data) if data else None
except json.JSONDecodeError as e:
logger.error(f"JSON decode error for id {id}: {e}")
return None
async def delete(self, doc_ids: list[str]) -> None:
"""Delete specific records from storage by their IDs"""
if not doc_ids:
return
async with self._get_redis_connection() as redis:
pipe = redis.pipeline()
for doc_id in doc_ids:
pipe.delete(f"{self.namespace}:{doc_id}")
results = await pipe.execute()
deleted_count = sum(results)
logger.info(
f"Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}"
)
async def drop(self) -> dict[str, str]:
"""Drop all document status data from storage and clean up resources"""
try:
async with self._get_redis_connection() as redis:
# Use SCAN to find all keys with the namespace prefix
pattern = f"{self.namespace}:*"
cursor = 0
deleted_count = 0
while True:
cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
if keys:
# Delete keys in batches
pipe = redis.pipeline()
for key in keys:
pipe.delete(key)
results = await pipe.execute()
deleted_count += sum(results)
if cursor == 0:
break
logger.info(
f"Dropped {deleted_count} doc status keys from {self.namespace}"
)
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping doc status {self.namespace}: {e}")
return {"status": "error", "message": str(e)}

View file

@ -22,6 +22,7 @@ from typing import (
Dict, Dict,
) )
from lightrag.constants import ( from lightrag.constants import (
DEFAULT_MAX_GLEANING,
DEFAULT_MAX_TOKEN_SUMMARY, DEFAULT_MAX_TOKEN_SUMMARY,
DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE, DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE,
) )
@ -124,7 +125,9 @@ class LightRAG:
# Entity extraction # Entity extraction
# --- # ---
entity_extract_max_gleaning: int = field(default=1) entity_extract_max_gleaning: int = field(
default=get_env_value("MAX_GLEANING", DEFAULT_MAX_GLEANING, int)
)
"""Maximum number of entity extraction attempts for ambiguous content.""" """Maximum number of entity extraction attempts for ambiguous content."""
summary_to_max_tokens: int = field( summary_to_max_tokens: int = field(
@ -346,6 +349,7 @@ class LightRAG:
# Fix global_config now # Fix global_config now
global_config = asdict(self) global_config = asdict(self)
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()]) _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n") logger.debug(f"LightRAG init with param:\n {_print_config}\n")
@ -394,13 +398,13 @@ class LightRAG:
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
) )
# TODO: deprecating, text_chunks is redundant with chunks_vdb
self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace( namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
), ),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
) )
self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore
namespace=make_namespace( namespace=make_namespace(
self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
@ -949,6 +953,7 @@ class LightRAG:
**dp, **dp,
"full_doc_id": doc_id, "full_doc_id": doc_id,
"file_path": file_path, # Add file path to each chunk "file_path": file_path, # Add file path to each chunk
"llm_cache_list": [], # Initialize empty LLM cache list for each chunk
} }
for dp in self.chunking_func( for dp in self.chunking_func(
self.tokenizer, self.tokenizer,
@ -960,14 +965,17 @@ class LightRAG:
) )
} }
# Process document (text chunks and full docs) in parallel # Process document in two stages
# Create tasks with references for potential cancellation # Stage 1: Process text chunks and docs (parallel execution)
doc_status_task = asyncio.create_task( doc_status_task = asyncio.create_task(
self.doc_status.upsert( self.doc_status.upsert(
{ {
doc_id: { doc_id: {
"status": DocStatus.PROCESSING, "status": DocStatus.PROCESSING,
"chunks_count": len(chunks), "chunks_count": len(chunks),
"chunks_list": list(
chunks.keys()
), # Save chunks list
"content": status_doc.content, "content": status_doc.content,
"content_summary": status_doc.content_summary, "content_summary": status_doc.content_summary,
"content_length": status_doc.content_length, "content_length": status_doc.content_length,
@ -983,11 +991,6 @@ class LightRAG:
chunks_vdb_task = asyncio.create_task( chunks_vdb_task = asyncio.create_task(
self.chunks_vdb.upsert(chunks) self.chunks_vdb.upsert(chunks)
) )
entity_relation_task = asyncio.create_task(
self._process_entity_relation_graph(
chunks, pipeline_status, pipeline_status_lock
)
)
full_docs_task = asyncio.create_task( full_docs_task = asyncio.create_task(
self.full_docs.upsert( self.full_docs.upsert(
{doc_id: {"content": status_doc.content}} {doc_id: {"content": status_doc.content}}
@ -996,14 +999,26 @@ class LightRAG:
text_chunks_task = asyncio.create_task( text_chunks_task = asyncio.create_task(
self.text_chunks.upsert(chunks) self.text_chunks.upsert(chunks)
) )
tasks = [
# First stage tasks (parallel execution)
first_stage_tasks = [
doc_status_task, doc_status_task,
chunks_vdb_task, chunks_vdb_task,
entity_relation_task,
full_docs_task, full_docs_task,
text_chunks_task, text_chunks_task,
] ]
await asyncio.gather(*tasks) entity_relation_task = None
# Execute first stage tasks
await asyncio.gather(*first_stage_tasks)
# Stage 2: Process entity relation graph (after text_chunks are saved)
entity_relation_task = asyncio.create_task(
self._process_entity_relation_graph(
chunks, pipeline_status, pipeline_status_lock
)
)
await entity_relation_task
file_extraction_stage_ok = True file_extraction_stage_ok = True
except Exception as e: except Exception as e:
@ -1018,14 +1033,14 @@ class LightRAG:
) )
pipeline_status["history_messages"].append(error_msg) pipeline_status["history_messages"].append(error_msg)
# Cancel other tasks as they are no longer meaningful # Cancel tasks that are not yet completed
for task in [ all_tasks = first_stage_tasks + (
chunks_vdb_task, [entity_relation_task]
entity_relation_task, if entity_relation_task
full_docs_task, else []
text_chunks_task, )
]: for task in all_tasks:
if not task.done(): if task and not task.done():
task.cancel() task.cancel()
# Persistent llm cache # Persistent llm cache
@ -1075,6 +1090,9 @@ class LightRAG:
doc_id: { doc_id: {
"status": DocStatus.PROCESSED, "status": DocStatus.PROCESSED,
"chunks_count": len(chunks), "chunks_count": len(chunks),
"chunks_list": list(
chunks.keys()
), # 保留 chunks_list
"content": status_doc.content, "content": status_doc.content,
"content_summary": status_doc.content_summary, "content_summary": status_doc.content_summary,
"content_length": status_doc.content_length, "content_length": status_doc.content_length,
@ -1193,6 +1211,7 @@ class LightRAG:
pipeline_status=pipeline_status, pipeline_status=pipeline_status,
pipeline_status_lock=pipeline_status_lock, pipeline_status_lock=pipeline_status_lock,
llm_response_cache=self.llm_response_cache, llm_response_cache=self.llm_response_cache,
text_chunks_storage=self.text_chunks,
) )
return chunk_results return chunk_results
except Exception as e: except Exception as e:
@ -1723,28 +1742,10 @@ class LightRAG:
file_path="", file_path="",
) )
# 2. Get all chunks related to this document # 2. Get chunk IDs from document status
try: chunk_ids = set(doc_status_data.get("chunks_list", []))
all_chunks = await self.text_chunks.get_all()
related_chunks = {
chunk_id: chunk_data
for chunk_id, chunk_data in all_chunks.items()
if isinstance(chunk_data, dict)
and chunk_data.get("full_doc_id") == doc_id
}
# Update pipeline status after getting chunks count if not chunk_ids:
async with pipeline_status_lock:
log_message = f"Retrieved {len(related_chunks)} of {len(all_chunks)} related chunks"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
except Exception as e:
logger.error(f"Failed to retrieve chunks for document {doc_id}: {e}")
raise Exception(f"Failed to retrieve document chunks: {e}") from e
if not related_chunks:
logger.warning(f"No chunks found for document {doc_id}") logger.warning(f"No chunks found for document {doc_id}")
# Mark that deletion operations have started # Mark that deletion operations have started
deletion_operations_started = True deletion_operations_started = True
@ -1775,7 +1776,6 @@ class LightRAG:
file_path=file_path, file_path=file_path,
) )
chunk_ids = set(related_chunks.keys())
# Mark that deletion operations have started # Mark that deletion operations have started
deletion_operations_started = True deletion_operations_started = True
@ -1799,26 +1799,12 @@ class LightRAG:
) )
) )
# Update pipeline status after getting affected_nodes
async with pipeline_status_lock:
log_message = f"Found {len(affected_nodes)} affected entities"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
affected_edges = ( affected_edges = (
await self.chunk_entity_relation_graph.get_edges_by_chunk_ids( await self.chunk_entity_relation_graph.get_edges_by_chunk_ids(
list(chunk_ids) list(chunk_ids)
) )
) )
# Update pipeline status after getting affected_edges
async with pipeline_status_lock:
log_message = f"Found {len(affected_edges)} affected relations"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
except Exception as e: except Exception as e:
logger.error(f"Failed to analyze affected graph elements: {e}") logger.error(f"Failed to analyze affected graph elements: {e}")
raise Exception(f"Failed to analyze graph dependencies: {e}") from e raise Exception(f"Failed to analyze graph dependencies: {e}") from e
@ -1836,6 +1822,14 @@ class LightRAG:
elif remaining_sources != sources: elif remaining_sources != sources:
entities_to_rebuild[node_label] = remaining_sources entities_to_rebuild[node_label] = remaining_sources
async with pipeline_status_lock:
log_message = (
f"Found {len(entities_to_rebuild)} affected entities"
)
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
# Process relationships # Process relationships
for edge_data in affected_edges: for edge_data in affected_edges:
src = edge_data.get("source") src = edge_data.get("source")
@ -1857,6 +1851,14 @@ class LightRAG:
elif remaining_sources != sources: elif remaining_sources != sources:
relationships_to_rebuild[edge_tuple] = remaining_sources relationships_to_rebuild[edge_tuple] = remaining_sources
async with pipeline_status_lock:
log_message = (
f"Found {len(relationships_to_rebuild)} affected relations"
)
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
except Exception as e: except Exception as e:
logger.error(f"Failed to process graph analysis results: {e}") logger.error(f"Failed to process graph analysis results: {e}")
raise Exception(f"Failed to process graph dependencies: {e}") from e raise Exception(f"Failed to process graph dependencies: {e}") from e
@ -1940,17 +1942,13 @@ class LightRAG:
knowledge_graph_inst=self.chunk_entity_relation_graph, knowledge_graph_inst=self.chunk_entity_relation_graph,
entities_vdb=self.entities_vdb, entities_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb, relationships_vdb=self.relationships_vdb,
text_chunks=self.text_chunks, text_chunks_storage=self.text_chunks,
llm_response_cache=self.llm_response_cache, llm_response_cache=self.llm_response_cache,
global_config=asdict(self), global_config=asdict(self),
pipeline_status=pipeline_status,
pipeline_status_lock=pipeline_status_lock,
) )
async with pipeline_status_lock:
log_message = f"Successfully rebuilt {len(entities_to_rebuild)} entities and {len(relationships_to_rebuild)} relations"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
except Exception as e: except Exception as e:
logger.error(f"Failed to rebuild knowledge from chunks: {e}") logger.error(f"Failed to rebuild knowledge from chunks: {e}")
raise Exception( raise Exception(

View file

@ -25,6 +25,7 @@ from .utils import (
CacheData, CacheData,
get_conversation_turns, get_conversation_turns,
use_llm_func_with_cache, use_llm_func_with_cache,
update_chunk_cache_list,
) )
from .base import ( from .base import (
BaseGraphStorage, BaseGraphStorage,
@ -103,8 +104,6 @@ async def _handle_entity_relation_summary(
entity_or_relation_name: str, entity_or_relation_name: str,
description: str, description: str,
global_config: dict, global_config: dict,
pipeline_status: dict = None,
pipeline_status_lock=None,
llm_response_cache: BaseKVStorage | None = None, llm_response_cache: BaseKVStorage | None = None,
) -> str: ) -> str:
"""Handle entity relation summary """Handle entity relation summary
@ -247,9 +246,11 @@ async def _rebuild_knowledge_from_chunks(
knowledge_graph_inst: BaseGraphStorage, knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage, entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage,
text_chunks: BaseKVStorage, text_chunks_storage: BaseKVStorage,
llm_response_cache: BaseKVStorage, llm_response_cache: BaseKVStorage,
global_config: dict[str, str], global_config: dict[str, str],
pipeline_status: dict | None = None,
pipeline_status_lock=None,
) -> None: ) -> None:
"""Rebuild entity and relationship descriptions from cached extraction results """Rebuild entity and relationship descriptions from cached extraction results
@ -259,9 +260,12 @@ async def _rebuild_knowledge_from_chunks(
Args: Args:
entities_to_rebuild: Dict mapping entity_name -> set of remaining chunk_ids entities_to_rebuild: Dict mapping entity_name -> set of remaining chunk_ids
relationships_to_rebuild: Dict mapping (src, tgt) -> set of remaining chunk_ids relationships_to_rebuild: Dict mapping (src, tgt) -> set of remaining chunk_ids
text_chunks_data: Pre-loaded chunk data dict {chunk_id: chunk_data}
""" """
if not entities_to_rebuild and not relationships_to_rebuild: if not entities_to_rebuild and not relationships_to_rebuild:
return return
rebuilt_entities_count = 0
rebuilt_relationships_count = 0
# Get all referenced chunk IDs # Get all referenced chunk IDs
all_referenced_chunk_ids = set() all_referenced_chunk_ids = set()
@ -270,36 +274,74 @@ async def _rebuild_knowledge_from_chunks(
for chunk_ids in relationships_to_rebuild.values(): for chunk_ids in relationships_to_rebuild.values():
all_referenced_chunk_ids.update(chunk_ids) all_referenced_chunk_ids.update(chunk_ids)
logger.debug( status_message = f"Rebuilding knowledge from {len(all_referenced_chunk_ids)} cached chunk extractions"
f"Rebuilding knowledge from {len(all_referenced_chunk_ids)} cached chunk extractions" logger.info(status_message)
) if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = status_message
pipeline_status["history_messages"].append(status_message)
# Get cached extraction results for these chunks # Get cached extraction results for these chunks using storage
# cached_results chunk_id -> [list of extraction result from LLM cache sorted by created_at]
cached_results = await _get_cached_extraction_results( cached_results = await _get_cached_extraction_results(
llm_response_cache, all_referenced_chunk_ids llm_response_cache,
all_referenced_chunk_ids,
text_chunks_storage=text_chunks_storage,
) )
if not cached_results: if not cached_results:
logger.warning("No cached extraction results found, cannot rebuild") status_message = "No cached extraction results found, cannot rebuild"
logger.warning(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = status_message
pipeline_status["history_messages"].append(status_message)
return return
# Process cached results to get entities and relationships for each chunk # Process cached results to get entities and relationships for each chunk
chunk_entities = {} # chunk_id -> {entity_name: [entity_data]} chunk_entities = {} # chunk_id -> {entity_name: [entity_data]}
chunk_relationships = {} # chunk_id -> {(src, tgt): [relationship_data]} chunk_relationships = {} # chunk_id -> {(src, tgt): [relationship_data]}
for chunk_id, extraction_result in cached_results.items(): for chunk_id, extraction_results in cached_results.items():
try: try:
entities, relationships = await _parse_extraction_result( # Handle multiple extraction results per chunk
text_chunks=text_chunks, chunk_entities[chunk_id] = defaultdict(list)
extraction_result=extraction_result, chunk_relationships[chunk_id] = defaultdict(list)
chunk_id=chunk_id,
) # process multiple LLM extraction results for a single chunk_id
chunk_entities[chunk_id] = entities for extraction_result in extraction_results:
chunk_relationships[chunk_id] = relationships entities, relationships = await _parse_extraction_result(
text_chunks_storage=text_chunks_storage,
extraction_result=extraction_result,
chunk_id=chunk_id,
)
# Merge entities and relationships from this extraction result
# Only keep the first occurrence of each entity_name in the same chunk_id
for entity_name, entity_list in entities.items():
if (
entity_name not in chunk_entities[chunk_id]
or len(chunk_entities[chunk_id][entity_name]) == 0
):
chunk_entities[chunk_id][entity_name].extend(entity_list)
# Only keep the first occurrence of each rel_key in the same chunk_id
for rel_key, rel_list in relationships.items():
if (
rel_key not in chunk_relationships[chunk_id]
or len(chunk_relationships[chunk_id][rel_key]) == 0
):
chunk_relationships[chunk_id][rel_key].extend(rel_list)
except Exception as e: except Exception as e:
logger.error( status_message = (
f"Failed to parse cached extraction result for chunk {chunk_id}: {e}" f"Failed to parse cached extraction result for chunk {chunk_id}: {e}"
) )
logger.info(status_message) # Per requirement, change to info
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = status_message
pipeline_status["history_messages"].append(status_message)
continue continue
# Rebuild entities # Rebuild entities
@ -314,11 +356,22 @@ async def _rebuild_knowledge_from_chunks(
llm_response_cache=llm_response_cache, llm_response_cache=llm_response_cache,
global_config=global_config, global_config=global_config,
) )
logger.debug( rebuilt_entities_count += 1
f"Rebuilt entity {entity_name} from {len(chunk_ids)} cached extractions" status_message = (
f"Rebuilt entity: {entity_name} from {len(chunk_ids)} chunks"
) )
logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = status_message
pipeline_status["history_messages"].append(status_message)
except Exception as e: except Exception as e:
logger.error(f"Failed to rebuild entity {entity_name}: {e}") status_message = f"Failed to rebuild entity {entity_name}: {e}"
logger.info(status_message) # Per requirement, change to info
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = status_message
pipeline_status["history_messages"].append(status_message)
# Rebuild relationships # Rebuild relationships
for (src, tgt), chunk_ids in relationships_to_rebuild.items(): for (src, tgt), chunk_ids in relationships_to_rebuild.items():
@ -333,53 +386,112 @@ async def _rebuild_knowledge_from_chunks(
llm_response_cache=llm_response_cache, llm_response_cache=llm_response_cache,
global_config=global_config, global_config=global_config,
) )
logger.debug( rebuilt_relationships_count += 1
f"Rebuilt relationship {src}-{tgt} from {len(chunk_ids)} cached extractions" status_message = (
f"Rebuilt relationship: {src}->{tgt} from {len(chunk_ids)} chunks"
) )
logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = status_message
pipeline_status["history_messages"].append(status_message)
except Exception as e: except Exception as e:
logger.error(f"Failed to rebuild relationship {src}-{tgt}: {e}") status_message = f"Failed to rebuild relationship {src}->{tgt}: {e}"
logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = status_message
pipeline_status["history_messages"].append(status_message)
logger.debug("Completed rebuilding knowledge from cached extractions") status_message = f"KG rebuild completed: {rebuilt_entities_count} entities and {rebuilt_relationships_count} relationships."
logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = status_message
pipeline_status["history_messages"].append(status_message)
async def _get_cached_extraction_results( async def _get_cached_extraction_results(
llm_response_cache: BaseKVStorage, chunk_ids: set[str] llm_response_cache: BaseKVStorage,
) -> dict[str, str]: chunk_ids: set[str],
text_chunks_storage: BaseKVStorage,
) -> dict[str, list[str]]:
"""Get cached extraction results for specific chunk IDs """Get cached extraction results for specific chunk IDs
Args: Args:
llm_response_cache: LLM response cache storage
chunk_ids: Set of chunk IDs to get cached results for chunk_ids: Set of chunk IDs to get cached results for
text_chunks_data: Pre-loaded chunk data (optional, for performance)
text_chunks_storage: Text chunks storage (fallback if text_chunks_data is None)
Returns: Returns:
Dict mapping chunk_id -> extraction_result_text Dict mapping chunk_id -> list of extraction_result_text
""" """
cached_results = {} cached_results = {}
# Get all cached data for "default" mode (entity extraction cache) # Collect all LLM cache IDs from chunks
default_cache = await llm_response_cache.get_by_id("default") or {} all_cache_ids = set()
for cache_key, cache_entry in default_cache.items(): # Read from storage
chunk_data_list = await text_chunks_storage.get_by_ids(list(chunk_ids))
for chunk_id, chunk_data in zip(chunk_ids, chunk_data_list):
if chunk_data and isinstance(chunk_data, dict):
llm_cache_list = chunk_data.get("llm_cache_list", [])
if llm_cache_list:
all_cache_ids.update(llm_cache_list)
else:
logger.warning(
f"Chunk {chunk_id} data is invalid or None: {type(chunk_data)}"
)
if not all_cache_ids:
logger.warning(f"No LLM cache IDs found for {len(chunk_ids)} chunk IDs")
return cached_results
# Batch get LLM cache entries
cache_data_list = await llm_response_cache.get_by_ids(list(all_cache_ids))
# Process cache entries and group by chunk_id
valid_entries = 0
for cache_id, cache_entry in zip(all_cache_ids, cache_data_list):
if ( if (
isinstance(cache_entry, dict) cache_entry is not None
and isinstance(cache_entry, dict)
and cache_entry.get("cache_type") == "extract" and cache_entry.get("cache_type") == "extract"
and cache_entry.get("chunk_id") in chunk_ids and cache_entry.get("chunk_id") in chunk_ids
): ):
chunk_id = cache_entry["chunk_id"] chunk_id = cache_entry["chunk_id"]
extraction_result = cache_entry["return"] extraction_result = cache_entry["return"]
cached_results[chunk_id] = extraction_result create_time = cache_entry.get(
"create_time", 0
) # Get creation time, default to 0
valid_entries += 1
logger.debug( # Support multiple LLM caches per chunk
f"Found {len(cached_results)} cached extraction results for {len(chunk_ids)} chunk IDs" if chunk_id not in cached_results:
cached_results[chunk_id] = []
# Store tuple with extraction result and creation time for sorting
cached_results[chunk_id].append((extraction_result, create_time))
# Sort extraction results by create_time for each chunk
for chunk_id in cached_results:
# Sort by create_time (x[1]), then extract only extraction_result (x[0])
cached_results[chunk_id].sort(key=lambda x: x[1])
cached_results[chunk_id] = [item[0] for item in cached_results[chunk_id]]
logger.info(
f"Found {valid_entries} valid cache entries, {len(cached_results)} chunks with results"
) )
return cached_results return cached_results
async def _parse_extraction_result( async def _parse_extraction_result(
text_chunks: BaseKVStorage, extraction_result: str, chunk_id: str text_chunks_storage: BaseKVStorage, extraction_result: str, chunk_id: str
) -> tuple[dict, dict]: ) -> tuple[dict, dict]:
"""Parse cached extraction result using the same logic as extract_entities """Parse cached extraction result using the same logic as extract_entities
Args: Args:
text_chunks_storage: Text chunks storage to get chunk data
extraction_result: The cached LLM extraction result extraction_result: The cached LLM extraction result
chunk_id: The chunk ID for source tracking chunk_id: The chunk ID for source tracking
@ -387,8 +499,8 @@ async def _parse_extraction_result(
Tuple of (entities_dict, relationships_dict) Tuple of (entities_dict, relationships_dict)
""" """
# Get chunk data for file_path # Get chunk data for file_path from storage
chunk_data = await text_chunks.get_by_id(chunk_id) chunk_data = await text_chunks_storage.get_by_id(chunk_id)
file_path = ( file_path = (
chunk_data.get("file_path", "unknown_source") chunk_data.get("file_path", "unknown_source")
if chunk_data if chunk_data
@ -761,8 +873,6 @@ async def _merge_nodes_then_upsert(
entity_name, entity_name,
description, description,
global_config, global_config,
pipeline_status,
pipeline_status_lock,
llm_response_cache, llm_response_cache,
) )
else: else:
@ -925,8 +1035,6 @@ async def _merge_edges_then_upsert(
f"({src_id}, {tgt_id})", f"({src_id}, {tgt_id})",
description, description,
global_config, global_config,
pipeline_status,
pipeline_status_lock,
llm_response_cache, llm_response_cache,
) )
else: else:
@ -1102,6 +1210,7 @@ async def extract_entities(
pipeline_status: dict = None, pipeline_status: dict = None,
pipeline_status_lock=None, pipeline_status_lock=None,
llm_response_cache: BaseKVStorage | None = None, llm_response_cache: BaseKVStorage | None = None,
text_chunks_storage: BaseKVStorage | None = None,
) -> list: ) -> list:
use_llm_func: callable = global_config["llm_model_func"] use_llm_func: callable = global_config["llm_model_func"]
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
@ -1208,6 +1317,9 @@ async def extract_entities(
# Get file path from chunk data or use default # Get file path from chunk data or use default
file_path = chunk_dp.get("file_path", "unknown_source") file_path = chunk_dp.get("file_path", "unknown_source")
# Create cache keys collector for batch processing
cache_keys_collector = []
# Get initial extraction # Get initial extraction
hint_prompt = entity_extract_prompt.format( hint_prompt = entity_extract_prompt.format(
**{**context_base, "input_text": content} **{**context_base, "input_text": content}
@ -1219,7 +1331,10 @@ async def extract_entities(
llm_response_cache=llm_response_cache, llm_response_cache=llm_response_cache,
cache_type="extract", cache_type="extract",
chunk_id=chunk_key, chunk_id=chunk_key,
cache_keys_collector=cache_keys_collector,
) )
# Store LLM cache reference in chunk (will be handled by use_llm_func_with_cache)
history = pack_user_ass_to_openai_messages(hint_prompt, final_result) history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
# Process initial extraction with file path # Process initial extraction with file path
@ -1236,6 +1351,7 @@ async def extract_entities(
history_messages=history, history_messages=history,
cache_type="extract", cache_type="extract",
chunk_id=chunk_key, chunk_id=chunk_key,
cache_keys_collector=cache_keys_collector,
) )
history += pack_user_ass_to_openai_messages(continue_prompt, glean_result) history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
@ -1266,11 +1382,21 @@ async def extract_entities(
llm_response_cache=llm_response_cache, llm_response_cache=llm_response_cache,
history_messages=history, history_messages=history,
cache_type="extract", cache_type="extract",
cache_keys_collector=cache_keys_collector,
) )
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower() if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
if if_loop_result != "yes": if if_loop_result != "yes":
break break
# Batch update chunk's llm_cache_list with all collected cache keys
if cache_keys_collector and text_chunks_storage:
await update_chunk_cache_list(
chunk_key,
text_chunks_storage,
cache_keys_collector,
"entity_extraction",
)
processed_chunks += 1 processed_chunks += 1
entities_count = len(maybe_nodes) entities_count = len(maybe_nodes)
relations_count = len(maybe_edges) relations_count = len(maybe_edges)
@ -1343,7 +1469,7 @@ async def kg_query(
use_model_func = partial(use_model_func, _priority=5) use_model_func = partial(use_model_func, _priority=5)
# Handle cache # Handle cache
args_hash = compute_args_hash(query_param.mode, query, cache_type="query") args_hash = compute_args_hash(query_param.mode, query)
cached_response, quantized, min_val, max_val = await handle_cache( cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query" hashing_kv, args_hash, query, query_param.mode, cache_type="query"
) )
@ -1390,7 +1516,7 @@ async def kg_query(
) )
if query_param.only_need_context: if query_param.only_need_context:
return context return context if context is not None else PROMPTS["fail_response"]
if context is None: if context is None:
return PROMPTS["fail_response"] return PROMPTS["fail_response"]
@ -1502,7 +1628,7 @@ async def extract_keywords_only(
""" """
# 1. Handle cache if needed - add cache type for keywords # 1. Handle cache if needed - add cache type for keywords
args_hash = compute_args_hash(param.mode, text, cache_type="keywords") args_hash = compute_args_hash(param.mode, text)
cached_response, quantized, min_val, max_val = await handle_cache( cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, text, param.mode, cache_type="keywords" hashing_kv, args_hash, text, param.mode, cache_type="keywords"
) )
@ -1647,7 +1773,7 @@ async def _get_vector_context(
f"Truncate chunks from {len(valid_chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})" f"Truncate chunks from {len(valid_chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
) )
logger.info( logger.info(
f"Vector query: {len(maybe_trun_chunks)} chunks, top_k: {query_param.top_k}" f"Query chunks: {len(maybe_trun_chunks)} chunks, top_k: {query_param.top_k}"
) )
if not maybe_trun_chunks: if not maybe_trun_chunks:
@ -1871,7 +1997,7 @@ async def _get_node_data(
) )
logger.info( logger.info(
f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} chunks" f"Local query: {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} chunks"
) )
# build prompt # build prompt
@ -2180,7 +2306,7 @@ async def _get_edge_data(
), ),
) )
logger.info( logger.info(
f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} chunks" f"Global query: {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} chunks"
) )
relations_context = [] relations_context = []
@ -2369,7 +2495,7 @@ async def naive_query(
use_model_func = partial(use_model_func, _priority=5) use_model_func = partial(use_model_func, _priority=5)
# Handle cache # Handle cache
args_hash = compute_args_hash(query_param.mode, query, cache_type="query") args_hash = compute_args_hash(query_param.mode, query)
cached_response, quantized, min_val, max_val = await handle_cache( cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query" hashing_kv, args_hash, query, query_param.mode, cache_type="query"
) )
@ -2485,7 +2611,7 @@ async def kg_query_with_keywords(
# Apply higher priority (5) to query relation LLM function # Apply higher priority (5) to query relation LLM function
use_model_func = partial(use_model_func, _priority=5) use_model_func = partial(use_model_func, _priority=5)
args_hash = compute_args_hash(query_param.mode, query, cache_type="query") args_hash = compute_args_hash(query_param.mode, query)
cached_response, quantized, min_val, max_val = await handle_cache( cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query" hashing_kv, args_hash, query, query_param.mode, cache_type="query"
) )

View file

@ -14,7 +14,6 @@ from functools import wraps
from hashlib import md5 from hashlib import md5
from typing import Any, Protocol, Callable, TYPE_CHECKING, List from typing import Any, Protocol, Callable, TYPE_CHECKING, List
import numpy as np import numpy as np
from lightrag.prompt import PROMPTS
from dotenv import load_dotenv from dotenv import load_dotenv
from lightrag.constants import ( from lightrag.constants import (
DEFAULT_LOG_MAX_BYTES, DEFAULT_LOG_MAX_BYTES,
@ -278,11 +277,10 @@ def convert_response_to_json(response: str) -> dict[str, Any]:
raise e from None raise e from None
def compute_args_hash(*args: Any, cache_type: str | None = None) -> str: def compute_args_hash(*args: Any) -> str:
"""Compute a hash for the given arguments. """Compute a hash for the given arguments.
Args: Args:
*args: Arguments to hash *args: Arguments to hash
cache_type: Type of cache (e.g., 'keywords', 'query', 'extract')
Returns: Returns:
str: Hash string str: Hash string
""" """
@ -290,13 +288,40 @@ def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
# Convert all arguments to strings and join them # Convert all arguments to strings and join them
args_str = "".join([str(arg) for arg in args]) args_str = "".join([str(arg) for arg in args])
if cache_type:
args_str = f"{cache_type}:{args_str}"
# Compute MD5 hash # Compute MD5 hash
return hashlib.md5(args_str.encode()).hexdigest() return hashlib.md5(args_str.encode()).hexdigest()
def generate_cache_key(mode: str, cache_type: str, hash_value: str) -> str:
"""Generate a flattened cache key in the format {mode}:{cache_type}:{hash}
Args:
mode: Cache mode (e.g., 'default', 'local', 'global')
cache_type: Type of cache (e.g., 'extract', 'query', 'keywords')
hash_value: Hash value from compute_args_hash
Returns:
str: Flattened cache key
"""
return f"{mode}:{cache_type}:{hash_value}"
def parse_cache_key(cache_key: str) -> tuple[str, str, str] | None:
"""Parse a flattened cache key back into its components
Args:
cache_key: Flattened cache key in format {mode}:{cache_type}:{hash}
Returns:
tuple[str, str, str] | None: (mode, cache_type, hash) or None if invalid format
"""
parts = cache_key.split(":", 2)
if len(parts) == 3:
return parts[0], parts[1], parts[2]
return None
def compute_mdhash_id(content: str, prefix: str = "") -> str: def compute_mdhash_id(content: str, prefix: str = "") -> str:
""" """
Compute a unique ID for a given content string. Compute a unique ID for a given content string.
@ -783,131 +808,6 @@ def process_combine_contexts(*context_lists):
return combined_data return combined_data
async def get_best_cached_response(
hashing_kv,
current_embedding,
similarity_threshold=0.95,
mode="default",
use_llm_check=False,
llm_func=None,
original_prompt=None,
cache_type=None,
) -> str | None:
logger.debug(
f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
)
mode_cache = await hashing_kv.get_by_id(mode)
if not mode_cache:
return None
best_similarity = -1
best_response = None
best_prompt = None
best_cache_id = None
# Only iterate through cache entries for this mode
for cache_id, cache_data in mode_cache.items():
# Skip if cache_type doesn't match
if cache_type and cache_data.get("cache_type") != cache_type:
continue
# Check if cache data is valid
if cache_data["embedding"] is None:
continue
try:
# Safely convert cached embedding
cached_quantized = np.frombuffer(
bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
).reshape(cache_data["embedding_shape"])
# Ensure min_val and max_val are valid float values
embedding_min = cache_data.get("embedding_min")
embedding_max = cache_data.get("embedding_max")
if (
embedding_min is None
or embedding_max is None
or embedding_min >= embedding_max
):
logger.warning(
f"Invalid embedding min/max values: min={embedding_min}, max={embedding_max}"
)
continue
cached_embedding = dequantize_embedding(
cached_quantized,
embedding_min,
embedding_max,
)
except Exception as e:
logger.warning(f"Error processing cached embedding: {str(e)}")
continue
similarity = cosine_similarity(current_embedding, cached_embedding)
if similarity > best_similarity:
best_similarity = similarity
best_response = cache_data["return"]
best_prompt = cache_data["original_prompt"]
best_cache_id = cache_id
if best_similarity > similarity_threshold:
# If LLM check is enabled and all required parameters are provided
if (
use_llm_check
and llm_func
and original_prompt
and best_prompt
and best_response is not None
):
compare_prompt = PROMPTS["similarity_check"].format(
original_prompt=original_prompt, cached_prompt=best_prompt
)
try:
llm_result = await llm_func(compare_prompt)
llm_result = llm_result.strip()
llm_similarity = float(llm_result)
# Replace vector similarity with LLM similarity score
best_similarity = llm_similarity
if best_similarity < similarity_threshold:
log_data = {
"event": "cache_rejected_by_llm",
"type": cache_type,
"mode": mode,
"original_question": original_prompt[:100] + "..."
if len(original_prompt) > 100
else original_prompt,
"cached_question": best_prompt[:100] + "..."
if len(best_prompt) > 100
else best_prompt,
"similarity_score": round(best_similarity, 4),
"threshold": similarity_threshold,
}
logger.debug(json.dumps(log_data, ensure_ascii=False))
logger.info(f"Cache rejected by LLM(mode:{mode} tpye:{cache_type})")
return None
except Exception as e: # Catch all possible exceptions
logger.warning(f"LLM similarity check failed: {e}")
return None # Return None directly when LLM check fails
prompt_display = (
best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
)
log_data = {
"event": "cache_hit",
"type": cache_type,
"mode": mode,
"similarity": round(best_similarity, 4),
"cache_id": best_cache_id,
"original_prompt": prompt_display,
}
logger.debug(json.dumps(log_data, ensure_ascii=False))
return best_response
return None
def cosine_similarity(v1, v2): def cosine_similarity(v1, v2):
"""Calculate cosine similarity between two vectors""" """Calculate cosine similarity between two vectors"""
dot_product = np.dot(v1, v2) dot_product = np.dot(v1, v2)
@ -957,7 +857,7 @@ async def handle_cache(
mode="default", mode="default",
cache_type=None, cache_type=None,
): ):
"""Generic cache handling function""" """Generic cache handling function with flattened cache keys"""
if hashing_kv is None: if hashing_kv is None:
return None, None, None, None return None, None, None, None
@ -968,15 +868,14 @@ async def handle_cache(
if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"): if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"):
return None, None, None, None return None, None, None, None
if exists_func(hashing_kv, "get_by_mode_and_id"): # Use flattened cache key format: {mode}:{cache_type}:{hash}
mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {} flattened_key = generate_cache_key(mode, cache_type, args_hash)
else: cache_entry = await hashing_kv.get_by_id(flattened_key)
mode_cache = await hashing_kv.get_by_id(mode) or {} if cache_entry:
if args_hash in mode_cache: logger.debug(f"Flattened cache hit(key:{flattened_key})")
logger.debug(f"Non-embedding cached hit(mode:{mode} type:{cache_type})") return cache_entry["return"], None, None, None
return mode_cache[args_hash]["return"], None, None, None
logger.debug(f"Non-embedding cached missed(mode:{mode} type:{cache_type})") logger.debug(f"Cache missed(mode:{mode} type:{cache_type})")
return None, None, None, None return None, None, None, None
@ -994,7 +893,7 @@ class CacheData:
async def save_to_cache(hashing_kv, cache_data: CacheData): async def save_to_cache(hashing_kv, cache_data: CacheData):
"""Save data to cache, with improved handling for streaming responses and duplicate content. """Save data to cache using flattened key structure.
Args: Args:
hashing_kv: The key-value storage for caching hashing_kv: The key-value storage for caching
@ -1009,26 +908,21 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
logger.debug("Streaming response detected, skipping cache") logger.debug("Streaming response detected, skipping cache")
return return
# Get existing cache data # Use flattened cache key format: {mode}:{cache_type}:{hash}
if exists_func(hashing_kv, "get_by_mode_and_id"): flattened_key = generate_cache_key(
mode_cache = ( cache_data.mode, cache_data.cache_type, cache_data.args_hash
await hashing_kv.get_by_mode_and_id(cache_data.mode, cache_data.args_hash) )
or {}
)
else:
mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
# Check if we already have identical content cached # Check if we already have identical content cached
if cache_data.args_hash in mode_cache: existing_cache = await hashing_kv.get_by_id(flattened_key)
existing_content = mode_cache[cache_data.args_hash].get("return") if existing_cache:
existing_content = existing_cache.get("return")
if existing_content == cache_data.content: if existing_content == cache_data.content:
logger.info( logger.info(f"Cache content unchanged for {flattened_key}, skipping update")
f"Cache content unchanged for {cache_data.args_hash}, skipping update"
)
return return
# Update cache with new content # Create cache entry with flattened structure
mode_cache[cache_data.args_hash] = { cache_entry = {
"return": cache_data.content, "return": cache_data.content,
"cache_type": cache_data.cache_type, "cache_type": cache_data.cache_type,
"chunk_id": cache_data.chunk_id if cache_data.chunk_id is not None else None, "chunk_id": cache_data.chunk_id if cache_data.chunk_id is not None else None,
@ -1043,10 +937,10 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
"original_prompt": cache_data.prompt, "original_prompt": cache_data.prompt,
} }
logger.info(f" == LLM cache == saving {cache_data.mode}: {cache_data.args_hash}") logger.info(f" == LLM cache == saving: {flattened_key}")
# Only upsert if there's actual new content # Save using flattened key
await hashing_kv.upsert({cache_data.mode: mode_cache}) await hashing_kv.upsert({flattened_key: cache_entry})
def safe_unicode_decode(content): def safe_unicode_decode(content):
@ -1529,6 +1423,48 @@ def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any
return import_class return import_class
async def update_chunk_cache_list(
chunk_id: str,
text_chunks_storage: "BaseKVStorage",
cache_keys: list[str],
cache_scenario: str = "batch_update",
) -> None:
"""Update chunk's llm_cache_list with the given cache keys
Args:
chunk_id: Chunk identifier
text_chunks_storage: Text chunks storage instance
cache_keys: List of cache keys to add to the list
cache_scenario: Description of the cache scenario for logging
"""
if not cache_keys:
return
try:
chunk_data = await text_chunks_storage.get_by_id(chunk_id)
if chunk_data:
# Ensure llm_cache_list exists
if "llm_cache_list" not in chunk_data:
chunk_data["llm_cache_list"] = []
# Add cache keys to the list if not already present
existing_keys = set(chunk_data["llm_cache_list"])
new_keys = [key for key in cache_keys if key not in existing_keys]
if new_keys:
chunk_data["llm_cache_list"].extend(new_keys)
# Update the chunk in storage
await text_chunks_storage.upsert({chunk_id: chunk_data})
logger.debug(
f"Updated chunk {chunk_id} with {len(new_keys)} cache keys ({cache_scenario})"
)
except Exception as e:
logger.warning(
f"Failed to update chunk {chunk_id} with cache references on {cache_scenario}: {e}"
)
async def use_llm_func_with_cache( async def use_llm_func_with_cache(
input_text: str, input_text: str,
use_llm_func: callable, use_llm_func: callable,
@ -1537,6 +1473,7 @@ async def use_llm_func_with_cache(
history_messages: list[dict[str, str]] = None, history_messages: list[dict[str, str]] = None,
cache_type: str = "extract", cache_type: str = "extract",
chunk_id: str | None = None, chunk_id: str | None = None,
cache_keys_collector: list = None,
) -> str: ) -> str:
"""Call LLM function with cache support """Call LLM function with cache support
@ -1551,6 +1488,8 @@ async def use_llm_func_with_cache(
history_messages: History messages list history_messages: History messages list
cache_type: Type of cache cache_type: Type of cache
chunk_id: Chunk identifier to store in cache chunk_id: Chunk identifier to store in cache
text_chunks_storage: Text chunks storage to update llm_cache_list
cache_keys_collector: Optional list to collect cache keys for batch processing
Returns: Returns:
LLM response text LLM response text
@ -1563,6 +1502,9 @@ async def use_llm_func_with_cache(
_prompt = input_text _prompt = input_text
arg_hash = compute_args_hash(_prompt) arg_hash = compute_args_hash(_prompt)
# Generate cache key for this LLM call
cache_key = generate_cache_key("default", cache_type, arg_hash)
cached_return, _1, _2, _3 = await handle_cache( cached_return, _1, _2, _3 = await handle_cache(
llm_response_cache, llm_response_cache,
arg_hash, arg_hash,
@ -1573,6 +1515,11 @@ async def use_llm_func_with_cache(
if cached_return: if cached_return:
logger.debug(f"Found cache for {arg_hash}") logger.debug(f"Found cache for {arg_hash}")
statistic_data["llm_cache"] += 1 statistic_data["llm_cache"] += 1
# Add cache key to collector if provided
if cache_keys_collector is not None:
cache_keys_collector.append(cache_key)
return cached_return return cached_return
statistic_data["llm_call"] += 1 statistic_data["llm_call"] += 1
@ -1597,6 +1544,10 @@ async def use_llm_func_with_cache(
), ),
) )
# Add cache key to collector if provided
if cache_keys_collector is not None:
cache_keys_collector.append(cache_key)
return res return res
# When cache is disabled, directly call LLM # When cache is disabled, directly call LLM

View file

@ -6,7 +6,7 @@ from typing import Any, cast
from .base import DeletionResult from .base import DeletionResult
from .kg.shared_storage import get_graph_db_lock from .kg.shared_storage import get_graph_db_lock
from .prompt import GRAPH_FIELD_SEP from .constants import GRAPH_FIELD_SEP
from .utils import compute_mdhash_id, logger from .utils import compute_mdhash_id, logger
from .base import StorageNameSpace from .base import StorageNameSpace

View file

@ -57,6 +57,10 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
"Winner": "[Answer 1 or Answer 2]", "Winner": "[Answer 1 or Answer 2]",
"Explanation": "[Provide explanation here]" "Explanation": "[Provide explanation here]"
}}, }},
"Diversity": {{
"Winner": "[Answer 1 or Answer 2]",
"Explanation": "[Provide explanation here]"
}},
"Empowerment": {{ "Empowerment": {{
"Winner": "[Answer 1 or Answer 2]", "Winner": "[Answer 1 or Answer 2]",
"Explanation": "[Provide explanation here]" "Explanation": "[Provide explanation here]"

View file

@ -8,6 +8,7 @@
支持的图存储类型包括 支持的图存储类型包括
- NetworkXStorage - NetworkXStorage
- Neo4JStorage - Neo4JStorage
- MongoDBStorage
- PGGraphStorage - PGGraphStorage
- MemgraphStorage - MemgraphStorage
""" """