Finish implementing workspace isolation in lightrag_server

This commit is contained in:
Chen.Zhidong 2025-12-01 22:47:27 +08:00
parent 607c11c083
commit 4d2d781246
7 changed files with 410 additions and 301 deletions

View file

@ -2,67 +2,60 @@
LightRAG FastAPI Server LightRAG FastAPI Server
""" """
from fastapi import FastAPI, Depends, HTTPException, Request import configparser
import logging
import logging.config
import os
import sys
from contextlib import asynccontextmanager
from pathlib import Path
import pipmaster as pm
import uvicorn
from ascii_colors import ASCIIColors
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import ( from fastapi.openapi.docs import (
get_swagger_ui_html, get_swagger_ui_html,
get_swagger_ui_oauth2_redirect_html, get_swagger_ui_oauth2_redirect_html,
) )
import os from fastapi.responses import JSONResponse, RedirectResponse
import logging from fastapi.security import OAuth2PasswordRequestForm
import logging.config
import sys
import uvicorn
import pipmaster as pm
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi.responses import RedirectResponse
from pathlib import Path from lightrag import LightRAG
import configparser from lightrag import __version__ as core_version
from ascii_colors import ASCIIColors
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from dotenv import load_dotenv
from lightrag.api.utils_api import (
get_combined_auth_dependency,
display_splash_screen,
check_env_file,
)
from .config import (
global_args,
update_uvicorn_mode_config,
get_default_host,
)
from lightrag.utils import get_env_value
from lightrag import LightRAG, __version__ as core_version
from lightrag.api import __api_version__ from lightrag.api import __api_version__
from lightrag.types import GPTKeywordExtractionFormat from lightrag.api.auth import auth_handler
from lightrag.utils import EmbeddingFunc from lightrag.api.routers.document_routes import DocumentManager, create_document_routes
from lightrag.constants import (
DEFAULT_LOG_MAX_BYTES,
DEFAULT_LOG_BACKUP_COUNT,
DEFAULT_LOG_FILENAME,
DEFAULT_LLM_TIMEOUT,
DEFAULT_EMBEDDING_TIMEOUT,
)
from lightrag.api.routers.document_routes import (
DocumentManager,
create_document_routes,
)
from lightrag.api.routers.query_routes import create_query_routes
from lightrag.api.routers.graph_routes import create_graph_routes from lightrag.api.routers.graph_routes import create_graph_routes
from lightrag.api.routers.ollama_api import OllamaAPI from lightrag.api.routers.ollama_api import OllamaAPI
from lightrag.api.routers.query_routes import create_query_routes
from lightrag.utils import logger, set_verbose_debug from lightrag.api.utils_api import (
check_env_file,
display_splash_screen,
get_combined_auth_dependency,
)
from lightrag.constants import (
DEFAULT_EMBEDDING_TIMEOUT,
DEFAULT_LLM_TIMEOUT,
DEFAULT_LOG_BACKUP_COUNT,
DEFAULT_LOG_FILENAME,
DEFAULT_LOG_MAX_BYTES,
)
from lightrag.kg.shared_storage import ( from lightrag.kg.shared_storage import (
get_namespace_data,
get_default_workspace,
# set_default_workspace,
cleanup_keyed_lock, cleanup_keyed_lock,
finalize_share_data, finalize_share_data,
get_default_workspace,
get_namespace_data,
set_default_workspace,
) )
from fastapi.security import OAuth2PasswordRequestForm from lightrag.types import GPTKeywordExtractionFormat
from lightrag.api.auth import auth_handler from lightrag.utils import EmbeddingFunc, get_env_value, logger, set_verbose_debug
from .config import get_default_host, global_args, update_uvicorn_mode_config
# use the .env that is inside the current folder # use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance # allows to use different .env file for each lightrag instance
@ -343,8 +336,85 @@ def create_app(args):
# Check if API key is provided either through env var or args # Check if API key is provided either through env var or args
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
# Initialize document manager with workspace support for data isolation doc_manager_cache = {}
doc_manager = DocumentManager(args.input_dir, workspace=args.workspace)
def create_doc_manager(request: Request | None) -> DocumentManager:
"""Create or retrieve DocumentManager for the current workspace"""
workspace = args.workspace
if request is not None:
workspace = get_workspace_from_request(request, args.workspace)
logger.debug(f"Using DocumentManager for workspace: '{workspace}'")
if workspace in doc_manager_cache:
return doc_manager_cache[workspace]
doc_manager = DocumentManager(args.input_dir, workspace=workspace)
doc_manager_cache[workspace] = doc_manager
return doc_manager_cache[workspace]
rag_cache = {}
async def create_rag(request: Request | None) -> LightRAG:
"""Create or retrieve LightRAG instance for the current workspace"""
workspace = args.workspace
if request is not None:
workspace = get_workspace_from_request(request, args.workspace)
logger.debug(f"Using LightRAG instance for workspace: '{workspace}'")
if workspace in rag_cache:
return rag_cache[workspace]
# Create ollama_server_infos from command line arguments
from lightrag.api.config import OllamaServerInfos
ollama_server_infos = OllamaServerInfos(name=args.simulated_model_name, tag=args.simulated_model_tag)
# Initialize RAG with unified configuration
try:
rag = LightRAG(
working_dir=args.working_dir,
workspace=workspace,
llm_model_func=create_llm_model_func(args.llm_binding),
llm_model_name=args.llm_model,
llm_model_max_async=args.max_async,
summary_max_tokens=args.summary_max_tokens,
summary_context_size=args.summary_context_size,
chunk_token_size=int(args.chunk_size),
chunk_overlap_token_size=int(args.chunk_overlap_size),
llm_model_kwargs=create_llm_model_kwargs(args.llm_binding, args, llm_timeout),
embedding_func=embedding_func,
default_llm_timeout=llm_timeout,
default_embedding_timeout=embedding_timeout,
kv_storage=args.kv_storage,
graph_storage=args.graph_storage,
vector_storage=args.vector_storage,
doc_status_storage=args.doc_status_storage,
vector_db_storage_cls_kwargs={"cosine_better_than_threshold": args.cosine_threshold},
enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
enable_llm_cache=args.enable_llm_cache,
rerank_model_func=rerank_model_func,
max_parallel_insert=args.max_parallel_insert,
max_graph_nodes=args.max_graph_nodes,
addon_params={
"language": args.summary_language,
"entity_types": args.entity_types,
},
ollama_server_infos=ollama_server_infos,
)
# Initialize database connections
# Note: initialize_storages() now auto-initializes pipeline_status for rag.workspace
await rag.initialize_storages()
# Data migration regardless of storage implementation
await rag.check_and_migrate_data()
rag_cache[workspace] = rag
return rag
except Exception as e:
logger.error(f"Failed to initialize LightRAG: {e}")
raise
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
@ -353,12 +423,8 @@ def create_app(args):
app.state.background_tasks = set() app.state.background_tasks = set()
try: try:
# Initialize database connections create_doc_manager(None) # Pre-create default DocumentManager
# Note: initialize_storages() now auto-initializes pipeline_status for rag.workspace await create_rag(None) # Pre-create default LightRAG
await rag.initialize_storages()
# Data migration regardless of storage implementation
await rag.check_and_migrate_data()
ASCIIColors.green("\nServer is ready to accept connections! 🚀\n") ASCIIColors.green("\nServer is ready to accept connections! 🚀\n")
@ -366,7 +432,8 @@ def create_app(args):
finally: finally:
# Clean up database connections # Clean up database connections
await rag.finalize_storages() for rag in rag_cache.values():
await rag.finalize_storages()
if "LIGHTRAG_GUNICORN_MODE" not in os.environ: if "LIGHTRAG_GUNICORN_MODE" not in os.environ:
# Only perform cleanup in Uvicorn single-process mode # Only perform cleanup in Uvicorn single-process mode
@ -404,6 +471,7 @@ def create_app(args):
"tryItOutEnabled": True, "tryItOutEnabled": True,
} }
set_default_workspace(args.workspace)
app = FastAPI(**app_kwargs) app = FastAPI(**app_kwargs)
# Add custom validation error handler for /query/data endpoint # Add custom validation error handler for /query/data endpoint
@ -456,7 +524,7 @@ def create_app(args):
# Create combined auth dependency for all endpoints # Create combined auth dependency for all endpoints
combined_auth = get_combined_auth_dependency(api_key) combined_auth = get_combined_auth_dependency(api_key)
def get_workspace_from_request(request: Request) -> str | None: def get_workspace_from_request(request: Request, default: str) -> str:
""" """
Extract workspace from HTTP request header or use default. Extract workspace from HTTP request header or use default.
@ -474,7 +542,7 @@ def create_app(args):
workspace = request.headers.get("LIGHTRAG-WORKSPACE", "").strip() workspace = request.headers.get("LIGHTRAG-WORKSPACE", "").strip()
if not workspace: if not workspace:
workspace = None workspace = default
return workspace return workspace
@ -1022,66 +1090,19 @@ def create_app(args):
else: else:
logger.info("Reranking is disabled") logger.info("Reranking is disabled")
# Create ollama_server_infos from command line arguments
from lightrag.api.config import OllamaServerInfos
ollama_server_infos = OllamaServerInfos(
name=args.simulated_model_name, tag=args.simulated_model_tag
)
# Initialize RAG with unified configuration
try:
rag = LightRAG(
working_dir=args.working_dir,
workspace=args.workspace,
llm_model_func=create_llm_model_func(args.llm_binding),
llm_model_name=args.llm_model,
llm_model_max_async=args.max_async,
summary_max_tokens=args.summary_max_tokens,
summary_context_size=args.summary_context_size,
chunk_token_size=int(args.chunk_size),
chunk_overlap_token_size=int(args.chunk_overlap_size),
llm_model_kwargs=create_llm_model_kwargs(
args.llm_binding, args, llm_timeout
),
embedding_func=embedding_func,
default_llm_timeout=llm_timeout,
default_embedding_timeout=embedding_timeout,
kv_storage=args.kv_storage,
graph_storage=args.graph_storage,
vector_storage=args.vector_storage,
doc_status_storage=args.doc_status_storage,
vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": args.cosine_threshold
},
enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
enable_llm_cache=args.enable_llm_cache,
rerank_model_func=rerank_model_func,
max_parallel_insert=args.max_parallel_insert,
max_graph_nodes=args.max_graph_nodes,
addon_params={
"language": args.summary_language,
"entity_types": args.entity_types,
},
ollama_server_infos=ollama_server_infos,
)
except Exception as e:
logger.error(f"Failed to initialize LightRAG: {e}")
raise
# Add routes # Add routes
app.include_router( app.include_router(
create_document_routes( create_document_routes(
rag, create_rag,
doc_manager, create_doc_manager,
api_key, api_key,
) )
) )
app.include_router(create_query_routes(rag, api_key, args.top_k)) app.include_router(create_query_routes(create_rag, api_key, args.top_k))
app.include_router(create_graph_routes(rag, api_key)) app.include_router(create_graph_routes(create_rag, api_key))
# Add Ollama API routes # Add Ollama API routes
ollama_api = OllamaAPI(rag, top_k=args.top_k, api_key=api_key) ollama_api = OllamaAPI(create_rag, top_k=args.top_k, api_key=api_key)
app.include_router(ollama_api.router, prefix="/api") app.include_router(ollama_api.router, prefix="/api")
# Custom Swagger UI endpoint for offline support # Custom Swagger UI endpoint for offline support
@ -1212,10 +1233,7 @@ def create_app(args):
async def get_status(request: Request): async def get_status(request: Request):
"""Get current system status including WebUI availability""" """Get current system status including WebUI availability"""
try: try:
workspace = get_workspace_from_request(request) workspace = get_workspace_from_request(request, get_default_workspace())
default_workspace = get_default_workspace()
if workspace is None:
workspace = default_workspace
pipeline_status = await get_namespace_data( pipeline_status = await get_namespace_data(
"pipeline_status", workspace=workspace "pipeline_status", workspace=workspace
) )
@ -1250,7 +1268,7 @@ def create_app(args):
"vector_storage": args.vector_storage, "vector_storage": args.vector_storage,
"enable_llm_cache_for_extract": args.enable_llm_cache_for_extract, "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
"enable_llm_cache": args.enable_llm_cache, "enable_llm_cache": args.enable_llm_cache,
"workspace": default_workspace, "workspace": workspace,
"max_graph_nodes": args.max_graph_nodes, "max_graph_nodes": args.max_graph_nodes,
# Rerank configuration # Rerank configuration
"enable_rerank": rerank_model_func is not None, "enable_rerank": rerank_model_func is not None,

View file

@ -3,29 +3,31 @@ This module contains all document-related routes for the LightRAG API.
""" """
import asyncio import asyncio
from functools import lru_cache
from lightrag.utils import logger, get_pinyin_sort_key
import aiofiles
import shutil import shutil
import traceback import traceback
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from functools import lru_cache
from typing import Dict, List, Optional, Any, Literal
from io import BytesIO from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional
import aiofiles
from fastapi import ( from fastapi import (
APIRouter, APIRouter,
BackgroundTasks, BackgroundTasks,
Depends, Depends,
File, File,
HTTPException, HTTPException,
Request,
UploadFile, UploadFile,
) )
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from lightrag import LightRAG from lightrag import LightRAG
from lightrag.base import DeletionResult, DocProcessingStatus, DocStatus
from lightrag.utils import generate_track_id
from lightrag.api.utils_api import get_combined_auth_dependency from lightrag.api.utils_api import get_combined_auth_dependency
from lightrag.base import DeletionResult, DocProcessingStatus, DocStatus
from lightrag.utils import generate_track_id, get_pinyin_sort_key, logger
from ..config import global_args from ..config import global_args
@ -2029,16 +2031,12 @@ async def background_delete_documents(
logger.error(f"Error processing pending documents after deletion: {e}") logger.error(f"Error processing pending documents after deletion: {e}")
def create_document_routes( def create_document_routes(create_rag, create_doc_manager, api_key: Optional[str] = None):
rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None
):
# Create combined auth dependency for document routes # Create combined auth dependency for document routes
combined_auth = get_combined_auth_dependency(api_key) combined_auth = get_combined_auth_dependency(api_key)
@router.post( @router.post("/scan", response_model=ScanResponse, dependencies=[Depends(combined_auth)])
"/scan", response_model=ScanResponse, dependencies=[Depends(combined_auth)] async def scan_for_new_documents(raw_request: Request, background_tasks: BackgroundTasks):
)
async def scan_for_new_documents(background_tasks: BackgroundTasks):
""" """
Trigger the scanning process for new documents. Trigger the scanning process for new documents.
@ -2049,6 +2047,9 @@ def create_document_routes(
Returns: Returns:
ScanResponse: A response object containing the scanning status and track_id ScanResponse: A response object containing the scanning status and track_id
""" """
rag = await create_rag(raw_request)
doc_manager = create_doc_manager(raw_request)
# Generate track_id with "scan" prefix for scanning operation # Generate track_id with "scan" prefix for scanning operation
track_id = generate_track_id("scan") track_id = generate_track_id("scan")
@ -2060,11 +2061,9 @@ def create_document_routes(
track_id=track_id, track_id=track_id,
) )
@router.post( @router.post("/upload", response_model=InsertResponse, dependencies=[Depends(combined_auth)])
"/upload", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
)
async def upload_to_input_dir( async def upload_to_input_dir(
background_tasks: BackgroundTasks, file: UploadFile = File(...) raw_request: Request, background_tasks: BackgroundTasks, file: UploadFile = File(...)
): ):
""" """
Upload a file to the input directory and index it. Upload a file to the input directory and index it.
@ -2085,6 +2084,9 @@ def create_document_routes(
HTTPException: If the file type is not supported (400) or other errors occur (500). HTTPException: If the file type is not supported (400) or other errors occur (500).
""" """
try: try:
rag = await create_rag(raw_request)
doc_manager = create_doc_manager(raw_request)
# Sanitize filename to prevent Path Traversal attacks # Sanitize filename to prevent Path Traversal attacks
safe_filename = sanitize_filename(file.filename, doc_manager.input_dir) safe_filename = sanitize_filename(file.filename, doc_manager.input_dir)
@ -2133,12 +2135,8 @@ def create_document_routes(
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.post( @router.post("/text", response_model=InsertResponse, dependencies=[Depends(combined_auth)])
"/text", response_model=InsertResponse, dependencies=[Depends(combined_auth)] async def insert_text(raw_request: Request, request: InsertTextRequest, background_tasks: BackgroundTasks):
)
async def insert_text(
request: InsertTextRequest, background_tasks: BackgroundTasks
):
""" """
Insert text into the RAG system. Insert text into the RAG system.
@ -2156,6 +2154,8 @@ def create_document_routes(
HTTPException: If an error occurs during text processing (500). HTTPException: If an error occurs during text processing (500).
""" """
try: try:
rag = await create_rag(raw_request)
# Check if file_source already exists in doc_status storage # Check if file_source already exists in doc_status storage
if ( if (
request.file_source request.file_source
@ -2200,9 +2200,7 @@ def create_document_routes(
response_model=InsertResponse, response_model=InsertResponse,
dependencies=[Depends(combined_auth)], dependencies=[Depends(combined_auth)],
) )
async def insert_texts( async def insert_texts(raw_request: Request, request: InsertTextsRequest, background_tasks: BackgroundTasks):
request: InsertTextsRequest, background_tasks: BackgroundTasks
):
""" """
Insert multiple texts into the RAG system. Insert multiple texts into the RAG system.
@ -2220,6 +2218,8 @@ def create_document_routes(
HTTPException: If an error occurs during text processing (500). HTTPException: If an error occurs during text processing (500).
""" """
try: try:
rag = await create_rag(raw_request)
# Check if any file_sources already exist in doc_status storage # Check if any file_sources already exist in doc_status storage
if request.file_sources: if request.file_sources:
for file_source in request.file_sources: for file_source in request.file_sources:
@ -2261,10 +2261,8 @@ def create_document_routes(
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.delete( @router.delete("", response_model=ClearDocumentsResponse, dependencies=[Depends(combined_auth)])
"", response_model=ClearDocumentsResponse, dependencies=[Depends(combined_auth)] async def clear_documents(raw_request: Request):
)
async def clear_documents():
""" """
Clear all documents from the RAG system. Clear all documents from the RAG system.
@ -2285,46 +2283,49 @@ def create_document_routes(
HTTPException: Raised when a serious error occurs during the clearing process, HTTPException: Raised when a serious error occurs during the clearing process,
with status code 500 and error details in the detail field. with status code 500 and error details in the detail field.
""" """
from lightrag.kg.shared_storage import (
get_namespace_data,
get_namespace_lock,
)
# Get pipeline status and lock
pipeline_status = await get_namespace_data(
"pipeline_status", workspace=rag.workspace
)
pipeline_status_lock = get_namespace_lock(
"pipeline_status", workspace=rag.workspace
)
# Check and set status with lock
async with pipeline_status_lock:
if pipeline_status.get("busy", False):
return ClearDocumentsResponse(
status="busy",
message="Cannot clear documents while pipeline is busy",
)
# Set busy to true
pipeline_status.update(
{
"busy": True,
"job_name": "Clearing Documents",
"job_start": datetime.now().isoformat(),
"docs": 0,
"batchs": 0,
"cur_batch": 0,
"request_pending": False, # Clear any previous request
"latest_message": "Starting document clearing process",
}
)
# Cleaning history_messages without breaking it as a shared list object
del pipeline_status["history_messages"][:]
pipeline_status["history_messages"].append(
"Starting document clearing process"
)
try: try:
rag = await create_rag(raw_request)
doc_manager = create_doc_manager(raw_request)
from lightrag.kg.shared_storage import (
get_namespace_data,
get_namespace_lock,
)
# Get pipeline status and lock
pipeline_status = await get_namespace_data(
"pipeline_status", workspace=rag.workspace
)
pipeline_status_lock = get_namespace_lock(
"pipeline_status", workspace=rag.workspace
)
# Check and set status with lock
async with pipeline_status_lock:
if pipeline_status.get("busy", False):
return ClearDocumentsResponse(
status="busy",
message="Cannot clear documents while pipeline is busy",
)
# Set busy to true
pipeline_status.update(
{
"busy": True,
"job_name": "Clearing Documents",
"job_start": datetime.now().isoformat(),
"docs": 0,
"batchs": 0,
"cur_batch": 0,
"request_pending": False, # Clear any previous request
"latest_message": "Starting document clearing process",
}
)
# Cleaning history_messages without breaking it as a shared list object
del pipeline_status["history_messages"][:]
pipeline_status["history_messages"].append(
"Starting document clearing process"
)
# Use drop method to clear all data # Use drop method to clear all data
drop_tasks = [] drop_tasks = []
storages = [ storages = [
@ -2460,7 +2461,7 @@ def create_document_routes(
dependencies=[Depends(combined_auth)], dependencies=[Depends(combined_auth)],
response_model=PipelineStatusResponse, response_model=PipelineStatusResponse,
) )
async def get_pipeline_status() -> PipelineStatusResponse: async def get_pipeline_status(raw_request: Request) -> PipelineStatusResponse:
""" """
Get the current status of the document indexing pipeline. Get the current status of the document indexing pipeline.
@ -2485,10 +2486,12 @@ def create_document_routes(
HTTPException: If an error occurs while retrieving pipeline status (500) HTTPException: If an error occurs while retrieving pipeline status (500)
""" """
try: try:
rag = await create_rag(raw_request)
from lightrag.kg.shared_storage import ( from lightrag.kg.shared_storage import (
get_all_update_flags_status,
get_namespace_data, get_namespace_data,
get_namespace_lock, get_namespace_lock,
get_all_update_flags_status,
) )
pipeline_status = await get_namespace_data( pipeline_status = await get_namespace_data(
@ -2556,10 +2559,8 @@ def create_document_routes(
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
# TODO: Deprecated, use /documents/paginated instead # TODO: Deprecated, use /documents/paginated instead
@router.get( @router.get("", response_model=DocsStatusesResponse, dependencies=[Depends(combined_auth)])
"", response_model=DocsStatusesResponse, dependencies=[Depends(combined_auth)] async def documents(raw_request: Request) -> DocsStatusesResponse:
)
async def documents() -> DocsStatusesResponse:
""" """
Get the status of all documents in the system. This endpoint is deprecated; use /documents/paginated instead. Get the status of all documents in the system. This endpoint is deprecated; use /documents/paginated instead.
To prevent excessive resource consumption, a maximum of 1,000 records is returned. To prevent excessive resource consumption, a maximum of 1,000 records is returned.
@ -2578,6 +2579,8 @@ def create_document_routes(
HTTPException: If an error occurs while retrieving document statuses (500). HTTPException: If an error occurs while retrieving document statuses (500).
""" """
try: try:
rag = await create_rag(raw_request)
statuses = ( statuses = (
DocStatus.PENDING, DocStatus.PENDING,
DocStatus.PROCESSING, DocStatus.PROCESSING,
@ -2673,6 +2676,7 @@ def create_document_routes(
summary="Delete a document and all its associated data by its ID.", summary="Delete a document and all its associated data by its ID.",
) )
async def delete_document( async def delete_document(
raw_request: Request,
delete_request: DeleteDocRequest, delete_request: DeleteDocRequest,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
) -> DeleteDocByIdResponse: ) -> DeleteDocByIdResponse:
@ -2699,9 +2703,12 @@ def create_document_routes(
HTTPException: HTTPException:
- 500: If an unexpected internal error occurs during initialization. - 500: If an unexpected internal error occurs during initialization.
""" """
doc_ids = delete_request.doc_ids
try: try:
rag = await create_rag(raw_request)
doc_manager = create_doc_manager(raw_request)
doc_ids = delete_request.doc_ids
from lightrag.kg.shared_storage import ( from lightrag.kg.shared_storage import (
get_namespace_data, get_namespace_data,
get_namespace_lock, get_namespace_lock,
@ -2750,7 +2757,7 @@ def create_document_routes(
response_model=ClearCacheResponse, response_model=ClearCacheResponse,
dependencies=[Depends(combined_auth)], dependencies=[Depends(combined_auth)],
) )
async def clear_cache(request: ClearCacheRequest): async def clear_cache(raw_request: Request, request: ClearCacheRequest):
""" """
Clear all cache data from the LLM response cache storage. Clear all cache data from the LLM response cache storage.
@ -2767,6 +2774,7 @@ def create_document_routes(
HTTPException: If an error occurs during cache clearing (500). HTTPException: If an error occurs during cache clearing (500).
""" """
try: try:
rag = await create_rag(raw_request)
# Call the aclear_cache method (no modes parameter) # Call the aclear_cache method (no modes parameter)
await rag.aclear_cache() await rag.aclear_cache()
@ -2784,7 +2792,7 @@ def create_document_routes(
response_model=DeletionResult, response_model=DeletionResult,
dependencies=[Depends(combined_auth)], dependencies=[Depends(combined_auth)],
) )
async def delete_entity(request: DeleteEntityRequest): async def delete_entity(raw_request: Request, request: DeleteEntityRequest):
""" """
Delete an entity and all its relationships from the knowledge graph. Delete an entity and all its relationships from the knowledge graph.
@ -2798,6 +2806,8 @@ def create_document_routes(
HTTPException: If the entity is not found (404) or an error occurs (500). HTTPException: If the entity is not found (404) or an error occurs (500).
""" """
try: try:
rag = await create_rag(raw_request)
result = await rag.adelete_by_entity(entity_name=request.entity_name) result = await rag.adelete_by_entity(entity_name=request.entity_name)
if result.status == "not_found": if result.status == "not_found":
raise HTTPException(status_code=404, detail=result.message) raise HTTPException(status_code=404, detail=result.message)
@ -2819,7 +2829,7 @@ def create_document_routes(
response_model=DeletionResult, response_model=DeletionResult,
dependencies=[Depends(combined_auth)], dependencies=[Depends(combined_auth)],
) )
async def delete_relation(request: DeleteRelationRequest): async def delete_relation(raw_request: Request, request: DeleteRelationRequest):
""" """
Delete a relationship between two entities from the knowledge graph. Delete a relationship between two entities from the knowledge graph.
@ -2833,6 +2843,8 @@ def create_document_routes(
HTTPException: If the relation is not found (404) or an error occurs (500). HTTPException: If the relation is not found (404) or an error occurs (500).
""" """
try: try:
rag = await create_rag(raw_request)
result = await rag.adelete_by_relation( result = await rag.adelete_by_relation(
source_entity=request.source_entity, source_entity=request.source_entity,
target_entity=request.target_entity, target_entity=request.target_entity,
@ -2857,7 +2869,7 @@ def create_document_routes(
response_model=TrackStatusResponse, response_model=TrackStatusResponse,
dependencies=[Depends(combined_auth)], dependencies=[Depends(combined_auth)],
) )
async def get_track_status(track_id: str) -> TrackStatusResponse: async def get_track_status(raw_request: Request, track_id: str) -> TrackStatusResponse:
""" """
Get the processing status of documents by tracking ID. Get the processing status of documents by tracking ID.
@ -2877,6 +2889,8 @@ def create_document_routes(
HTTPException: If track_id is invalid (400) or an error occurs (500). HTTPException: If track_id is invalid (400) or an error occurs (500).
""" """
try: try:
rag = await create_rag(raw_request)
# Validate track_id # Validate track_id
if not track_id or not track_id.strip(): if not track_id or not track_id.strip():
raise HTTPException(status_code=400, detail="Track ID cannot be empty") raise HTTPException(status_code=400, detail="Track ID cannot be empty")
@ -2932,6 +2946,7 @@ def create_document_routes(
dependencies=[Depends(combined_auth)], dependencies=[Depends(combined_auth)],
) )
async def get_documents_paginated( async def get_documents_paginated(
raw_request: Request,
request: DocumentsRequest, request: DocumentsRequest,
) -> PaginatedDocsResponse: ) -> PaginatedDocsResponse:
""" """
@ -2954,6 +2969,8 @@ def create_document_routes(
HTTPException: If an error occurs while retrieving documents (500). HTTPException: If an error occurs while retrieving documents (500).
""" """
try: try:
rag = await create_rag(raw_request)
# Get paginated documents and status counts in parallel # Get paginated documents and status counts in parallel
docs_task = rag.doc_status.get_docs_paginated( docs_task = rag.doc_status.get_docs_paginated(
status_filter=request.status_filter, status_filter=request.status_filter,
@ -3018,7 +3035,7 @@ def create_document_routes(
response_model=StatusCountsResponse, response_model=StatusCountsResponse,
dependencies=[Depends(combined_auth)], dependencies=[Depends(combined_auth)],
) )
async def get_document_status_counts() -> StatusCountsResponse: async def get_document_status_counts(raw_request: Request) -> StatusCountsResponse:
""" """
Get counts of documents by status. Get counts of documents by status.
@ -3032,6 +3049,8 @@ def create_document_routes(
HTTPException: If an error occurs while retrieving status counts (500). HTTPException: If an error occurs while retrieving status counts (500).
""" """
try: try:
rag = await create_rag(raw_request)
status_counts = await rag.doc_status.get_all_status_counts() status_counts = await rag.doc_status.get_all_status_counts()
return StatusCountsResponse(status_counts=status_counts) return StatusCountsResponse(status_counts=status_counts)
@ -3045,7 +3064,7 @@ def create_document_routes(
response_model=ReprocessResponse, response_model=ReprocessResponse,
dependencies=[Depends(combined_auth)], dependencies=[Depends(combined_auth)],
) )
async def reprocess_failed_documents(background_tasks: BackgroundTasks): async def reprocess_failed_documents(raw_request: Request, background_tasks: BackgroundTasks):
""" """
Reprocess failed and pending documents. Reprocess failed and pending documents.
@ -3068,6 +3087,8 @@ def create_document_routes(
HTTPException: If an error occurs while initiating reprocessing (500). HTTPException: If an error occurs while initiating reprocessing (500).
""" """
try: try:
rag = await create_rag(raw_request)
# Generate track_id with "retry" prefix for retry operation # Generate track_id with "retry" prefix for retry operation
track_id = generate_track_id("retry") track_id = generate_track_id("retry")
@ -3093,7 +3114,7 @@ def create_document_routes(
response_model=CancelPipelineResponse, response_model=CancelPipelineResponse,
dependencies=[Depends(combined_auth)], dependencies=[Depends(combined_auth)],
) )
async def cancel_pipeline(): async def cancel_pipeline(raw_request: Request):
""" """
Request cancellation of the currently running pipeline. Request cancellation of the currently running pipeline.
@ -3115,6 +3136,8 @@ def create_document_routes(
HTTPException: If an error occurs while setting cancellation flag (500). HTTPException: If an error occurs while setting cancellation flag (500).
""" """
try: try:
rag = await create_rag(raw_request)
from lightrag.kg.shared_storage import ( from lightrag.kg.shared_storage import (
get_namespace_data, get_namespace_data,
get_namespace_lock, get_namespace_lock,

View file

@ -2,12 +2,14 @@
This module contains all graph-related routes for the LightRAG API. This module contains all graph-related routes for the LightRAG API.
""" """
from typing import Optional, Dict, Any
import traceback import traceback
from fastapi import APIRouter, Depends, Query, HTTPException from typing import Any, Dict, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from lightrag.utils import logger from lightrag.utils import logger
from ..utils_api import get_combined_auth_dependency from ..utils_api import get_combined_auth_dependency
router = APIRouter(tags=["graph"]) router = APIRouter(tags=["graph"])
@ -86,11 +88,11 @@ class RelationCreateRequest(BaseModel):
) )
def create_graph_routes(rag, api_key: Optional[str] = None): def create_graph_routes(create_rag, api_key: Optional[str] = None):
combined_auth = get_combined_auth_dependency(api_key) combined_auth = get_combined_auth_dependency(api_key)
@router.get("/graph/label/list", dependencies=[Depends(combined_auth)]) @router.get("/graph/label/list", dependencies=[Depends(combined_auth)])
async def get_graph_labels(): async def get_graph_labels(raw_request: Request):
""" """
Get all graph labels Get all graph labels
@ -98,6 +100,8 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
List[str]: List of graph labels List[str]: List of graph labels
""" """
try: try:
rag = await create_rag(raw_request)
return await rag.get_graph_labels() return await rag.get_graph_labels()
except Exception as e: except Exception as e:
logger.error(f"Error getting graph labels: {str(e)}") logger.error(f"Error getting graph labels: {str(e)}")
@ -108,9 +112,8 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
@router.get("/graph/label/popular", dependencies=[Depends(combined_auth)]) @router.get("/graph/label/popular", dependencies=[Depends(combined_auth)])
async def get_popular_labels( async def get_popular_labels(
limit: int = Query( raw_request: Request,
300, description="Maximum number of popular labels to return", ge=1, le=1000 limit: int = Query(300, description="Maximum number of popular labels to return", ge=1, le=1000),
),
): ):
""" """
Get popular labels by node degree (most connected entities) Get popular labels by node degree (most connected entities)
@ -122,6 +125,8 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
List[str]: List of popular labels sorted by degree (highest first) List[str]: List of popular labels sorted by degree (highest first)
""" """
try: try:
rag = await create_rag(raw_request)
return await rag.chunk_entity_relation_graph.get_popular_labels(limit) return await rag.chunk_entity_relation_graph.get_popular_labels(limit)
except Exception as e: except Exception as e:
logger.error(f"Error getting popular labels: {str(e)}") logger.error(f"Error getting popular labels: {str(e)}")
@ -132,6 +137,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
@router.get("/graph/label/search", dependencies=[Depends(combined_auth)]) @router.get("/graph/label/search", dependencies=[Depends(combined_auth)])
async def search_labels( async def search_labels(
raw_request: Request,
q: str = Query(..., description="Search query string"), q: str = Query(..., description="Search query string"),
limit: int = Query( limit: int = Query(
50, description="Maximum number of search results to return", ge=1, le=100 50, description="Maximum number of search results to return", ge=1, le=100
@ -148,6 +154,8 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
List[str]: List of matching labels sorted by relevance List[str]: List of matching labels sorted by relevance
""" """
try: try:
rag = await create_rag(raw_request)
return await rag.chunk_entity_relation_graph.search_labels(q, limit) return await rag.chunk_entity_relation_graph.search_labels(q, limit)
except Exception as e: except Exception as e:
logger.error(f"Error searching labels with query '{q}': {str(e)}") logger.error(f"Error searching labels with query '{q}': {str(e)}")
@ -158,6 +166,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
@router.get("/graphs", dependencies=[Depends(combined_auth)]) @router.get("/graphs", dependencies=[Depends(combined_auth)])
async def get_knowledge_graph( async def get_knowledge_graph(
raw_request: Request,
label: str = Query(..., description="Label to get knowledge graph for"), label: str = Query(..., description="Label to get knowledge graph for"),
max_depth: int = Query(3, description="Maximum depth of graph", ge=1), max_depth: int = Query(3, description="Maximum depth of graph", ge=1),
max_nodes: int = Query(1000, description="Maximum nodes to return", ge=1), max_nodes: int = Query(1000, description="Maximum nodes to return", ge=1),
@ -177,6 +186,8 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
Dict[str, List[str]]: Knowledge graph for label Dict[str, List[str]]: Knowledge graph for label
""" """
try: try:
rag = await create_rag(raw_request)
# Log the label parameter to check for leading spaces # Log the label parameter to check for leading spaces
logger.debug( logger.debug(
f"get_knowledge_graph called with label: '{label}' (length: {len(label)}, repr: {repr(label)})" f"get_knowledge_graph called with label: '{label}' (length: {len(label)}, repr: {repr(label)})"
@ -196,6 +207,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
@router.get("/graph/entity/exists", dependencies=[Depends(combined_auth)]) @router.get("/graph/entity/exists", dependencies=[Depends(combined_auth)])
async def check_entity_exists( async def check_entity_exists(
raw_request: Request,
name: str = Query(..., description="Entity name to check"), name: str = Query(..., description="Entity name to check"),
): ):
""" """
@ -208,6 +220,8 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
Dict[str, bool]: Dictionary with 'exists' key indicating if entity exists Dict[str, bool]: Dictionary with 'exists' key indicating if entity exists
""" """
try: try:
rag = await create_rag(raw_request)
exists = await rag.chunk_entity_relation_graph.has_node(name) exists = await rag.chunk_entity_relation_graph.has_node(name)
return {"exists": exists} return {"exists": exists}
except Exception as e: except Exception as e:
@ -218,7 +232,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
) )
@router.post("/graph/entity/edit", dependencies=[Depends(combined_auth)]) @router.post("/graph/entity/edit", dependencies=[Depends(combined_auth)])
async def update_entity(request: EntityUpdateRequest): async def update_entity(raw_request: Request, request: EntityUpdateRequest):
""" """
Update an entity's properties in the knowledge graph Update an entity's properties in the knowledge graph
@ -353,6 +367,8 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
} }
""" """
try: try:
rag = await create_rag(raw_request)
result = await rag.aedit_entity( result = await rag.aedit_entity(
entity_name=request.entity_name, entity_name=request.entity_name,
updated_data=request.updated_data, updated_data=request.updated_data,
@ -408,7 +424,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
) )
@router.post("/graph/relation/edit", dependencies=[Depends(combined_auth)]) @router.post("/graph/relation/edit", dependencies=[Depends(combined_auth)])
async def update_relation(request: RelationUpdateRequest): async def update_relation(raw_request: Request, request: RelationUpdateRequest):
"""Update a relation's properties in the knowledge graph """Update a relation's properties in the knowledge graph
Args: Args:
@ -418,6 +434,8 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
Dict: Updated relation information Dict: Updated relation information
""" """
try: try:
rag = await create_rag(raw_request)
result = await rag.aedit_relation( result = await rag.aedit_relation(
source_entity=request.source_id, source_entity=request.source_id,
target_entity=request.target_id, target_entity=request.target_id,
@ -443,7 +461,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
) )
@router.post("/graph/entity/create", dependencies=[Depends(combined_auth)]) @router.post("/graph/entity/create", dependencies=[Depends(combined_auth)])
async def create_entity(request: EntityCreateRequest): async def create_entity(raw_request: Request, request: EntityCreateRequest):
""" """
Create a new entity in the knowledge graph Create a new entity in the knowledge graph
@ -488,6 +506,8 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
} }
""" """
try: try:
rag = await create_rag(raw_request)
# Use the proper acreate_entity method which handles: # Use the proper acreate_entity method which handles:
# - Graph lock for concurrency # - Graph lock for concurrency
# - Vector embedding creation in entities_vdb # - Vector embedding creation in entities_vdb
@ -516,7 +536,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
) )
@router.post("/graph/relation/create", dependencies=[Depends(combined_auth)]) @router.post("/graph/relation/create", dependencies=[Depends(combined_auth)])
async def create_relation(request: RelationCreateRequest): async def create_relation(raw_request: Request, request: RelationCreateRequest):
""" """
Create a new relationship between two entities in the knowledge graph Create a new relationship between two entities in the knowledge graph
@ -573,6 +593,8 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
} }
""" """
try: try:
rag = await create_rag(raw_request)
# Use the proper acreate_relation method which handles: # Use the proper acreate_relation method which handles:
# - Graph lock for concurrency # - Graph lock for concurrency
# - Entity existence validation # - Entity existence validation
@ -605,7 +627,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
) )
@router.post("/graph/entities/merge", dependencies=[Depends(combined_auth)]) @router.post("/graph/entities/merge", dependencies=[Depends(combined_auth)])
async def merge_entities(request: EntityMergeRequest): async def merge_entities(raw_request: Request, request: EntityMergeRequest):
""" """
Merge multiple entities into a single entity, preserving all relationships Merge multiple entities into a single entity, preserving all relationships
@ -662,6 +684,8 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
- This operation cannot be undone, so verify entity names before merging - This operation cannot be undone, so verify entity names before merging
""" """
try: try:
rag = await create_rag(raw_request)
result = await rag.amerge_entities( result = await rag.amerge_entities(
source_entities=request.entities_to_change, source_entities=request.entities_to_change,
target_entity=request.entity_to_change_into, target_entity=request.entity_to_change_into,

View file

@ -1,17 +1,17 @@
from fastapi import APIRouter, HTTPException, Request import asyncio
from pydantic import BaseModel
from typing import List, Dict, Any, Optional, Type
from lightrag.utils import logger
import time
import json import json
import re import re
import time
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional, Type
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
import asyncio from pydantic import BaseModel
from lightrag import LightRAG, QueryParam from lightrag import LightRAG, QueryParam
from lightrag.utils import TiktokenTokenizer
from lightrag.api.utils_api import get_combined_auth_dependency from lightrag.api.utils_api import get_combined_auth_dependency
from fastapi import Depends from lightrag.utils import TiktokenTokenizer, logger
# query mode according to query prefix (bypass is not LightRAG quer mode) # query mode according to query prefix (bypass is not LightRAG quer mode)
@ -117,9 +117,7 @@ class OllamaPsResponse(BaseModel):
models: List[OllamaRunningModel] models: List[OllamaRunningModel]
async def parse_request_body( async def parse_request_body(request: Request, model_class: Type[BaseModel]) -> BaseModel:
request: Request, model_class: Type[BaseModel]
) -> BaseModel:
""" """
Parse request body based on Content-Type header. Parse request body based on Content-Type header.
Supports both application/json and application/octet-stream. Supports both application/json and application/octet-stream.
@ -151,9 +149,7 @@ async def parse_request_body(
except json.JSONDecodeError: except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="Invalid JSON in request body") raise HTTPException(status_code=400, detail="Invalid JSON in request body")
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(status_code=400, detail=f"Error parsing request body: {str(e)}")
status_code=400, detail=f"Error parsing request body: {str(e)}"
)
def estimate_tokens(text: str) -> int: def estimate_tokens(text: str) -> int:
@ -218,9 +214,8 @@ def parse_query_mode(query: str) -> tuple[str, SearchMode, bool, Optional[str]]:
class OllamaAPI: class OllamaAPI:
def __init__(self, rag: LightRAG, top_k: int = 60, api_key: Optional[str] = None): def __init__(self, create_rag, top_k: int = 60, api_key: Optional[str] = None):
self.rag = rag self.create_rag = create_rag
self.ollama_server_infos = rag.ollama_server_infos
self.top_k = top_k self.top_k = top_k
self.api_key = api_key self.api_key = api_key
self.router = APIRouter(tags=["ollama"]) self.router = APIRouter(tags=["ollama"])
@ -236,21 +231,24 @@ class OllamaAPI:
return OllamaVersionResponse(version="0.9.3") return OllamaVersionResponse(version="0.9.3")
@self.router.get("/tags", dependencies=[Depends(combined_auth)]) @self.router.get("/tags", dependencies=[Depends(combined_auth)])
async def get_tags(): async def get_tags(raw_request: Request):
"""Return available models acting as an Ollama server""" """Return available models acting as an Ollama server"""
rag = await self.create_rag(raw_request)
ollama_server_infos = rag.ollama_server_infos
return OllamaTagResponse( return OllamaTagResponse(
models=[ models=[
{ {
"name": self.ollama_server_infos.LIGHTRAG_MODEL, "name": ollama_server_infos.LIGHTRAG_MODEL,
"model": self.ollama_server_infos.LIGHTRAG_MODEL, "model": ollama_server_infos.LIGHTRAG_MODEL,
"modified_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "modified_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"size": self.ollama_server_infos.LIGHTRAG_SIZE, "size": ollama_server_infos.LIGHTRAG_SIZE,
"digest": self.ollama_server_infos.LIGHTRAG_DIGEST, "digest": ollama_server_infos.LIGHTRAG_DIGEST,
"details": { "details": {
"parent_model": "", "parent_model": "",
"format": "gguf", "format": "gguf",
"family": self.ollama_server_infos.LIGHTRAG_NAME, "family": ollama_server_infos.LIGHTRAG_NAME,
"families": [self.ollama_server_infos.LIGHTRAG_NAME], "families": [ollama_server_infos.LIGHTRAG_NAME],
"parameter_size": "13B", "parameter_size": "13B",
"quantization_level": "Q4_0", "quantization_level": "Q4_0",
}, },
@ -259,15 +257,18 @@ class OllamaAPI:
) )
@self.router.get("/ps", dependencies=[Depends(combined_auth)]) @self.router.get("/ps", dependencies=[Depends(combined_auth)])
async def get_running_models(): async def get_running_models(raw_request: Request):
"""List Running Models - returns currently running models""" """List Running Models - returns currently running models"""
rag = await self.create_rag(raw_request)
ollama_server_infos = rag.ollama_server_infos
return OllamaPsResponse( return OllamaPsResponse(
models=[ models=[
{ {
"name": self.ollama_server_infos.LIGHTRAG_MODEL, "name": ollama_server_infos.LIGHTRAG_MODEL,
"model": self.ollama_server_infos.LIGHTRAG_MODEL, "model": ollama_server_infos.LIGHTRAG_MODEL,
"size": self.ollama_server_infos.LIGHTRAG_SIZE, "size": ollama_server_infos.LIGHTRAG_SIZE,
"digest": self.ollama_server_infos.LIGHTRAG_DIGEST, "digest": ollama_server_infos.LIGHTRAG_DIGEST,
"details": { "details": {
"parent_model": "", "parent_model": "",
"format": "gguf", "format": "gguf",
@ -277,14 +278,12 @@ class OllamaAPI:
"quantization_level": "Q4_0", "quantization_level": "Q4_0",
}, },
"expires_at": "2050-12-31T14:38:31.83753-07:00", "expires_at": "2050-12-31T14:38:31.83753-07:00",
"size_vram": self.ollama_server_infos.LIGHTRAG_SIZE, "size_vram": ollama_server_infos.LIGHTRAG_SIZE,
} }
] ]
) )
@self.router.post( @self.router.post("/generate", dependencies=[Depends(combined_auth)], include_in_schema=True)
"/generate", dependencies=[Depends(combined_auth)], include_in_schema=True
)
async def generate(raw_request: Request): async def generate(raw_request: Request):
"""Handle generate completion requests acting as an Ollama model """Handle generate completion requests acting as an Ollama model
For compatibility purpose, the request is not processed by LightRAG, For compatibility purpose, the request is not processed by LightRAG,
@ -292,6 +291,9 @@ class OllamaAPI:
Supports both application/json and application/octet-stream Content-Types. Supports both application/json and application/octet-stream Content-Types.
""" """
try: try:
rag = await self.create_rag(raw_request)
ollama_server_infos = rag.ollama_server_infos
# Parse the request body manually # Parse the request body manually
request = await parse_request_body(raw_request, OllamaGenerateRequest) request = await parse_request_body(raw_request, OllamaGenerateRequest)
@ -300,12 +302,10 @@ class OllamaAPI:
prompt_tokens = estimate_tokens(query) prompt_tokens = estimate_tokens(query)
if request.system: if request.system:
self.rag.llm_model_kwargs["system_prompt"] = request.system rag.llm_model_kwargs["system_prompt"] = request.system
if request.stream: if request.stream:
response = await self.rag.llm_model_func( response = await rag.llm_model_func(query, stream=True, **rag.llm_model_kwargs)
query, stream=True, **self.rag.llm_model_kwargs
)
async def stream_generator(): async def stream_generator():
first_chunk_time = None first_chunk_time = None
@ -320,8 +320,8 @@ class OllamaAPI:
total_response = response total_response = response
data = { data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL, "model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": response, "response": response,
"done": False, "done": False,
} }
@ -333,8 +333,8 @@ class OllamaAPI:
eval_time = last_chunk_time - first_chunk_time eval_time = last_chunk_time - first_chunk_time
data = { data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL, "model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": "", "response": "",
"done": True, "done": True,
"done_reason": "stop", "done_reason": "stop",
@ -358,8 +358,8 @@ class OllamaAPI:
total_response += chunk total_response += chunk
data = { data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL, "model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": chunk, "response": chunk,
"done": False, "done": False,
} }
@ -375,8 +375,8 @@ class OllamaAPI:
# Send error message to client # Send error message to client
error_data = { error_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL, "model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": f"\n\nError: {error_msg}", "response": f"\n\nError: {error_msg}",
"error": f"\n\nError: {error_msg}", "error": f"\n\nError: {error_msg}",
"done": False, "done": False,
@ -385,8 +385,8 @@ class OllamaAPI:
# Send final message to close the stream # Send final message to close the stream
final_data = { final_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL, "model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": "", "response": "",
"done": True, "done": True,
} }
@ -400,8 +400,8 @@ class OllamaAPI:
eval_time = last_chunk_time - first_chunk_time eval_time = last_chunk_time - first_chunk_time
data = { data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL, "model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": "", "response": "",
"done": True, "done": True,
"done_reason": "stop", "done_reason": "stop",
@ -428,9 +428,7 @@ class OllamaAPI:
) )
else: else:
first_chunk_time = time.time_ns() first_chunk_time = time.time_ns()
response_text = await self.rag.llm_model_func( response_text = await rag.llm_model_func(query, stream=False, **rag.llm_model_kwargs)
query, stream=False, **self.rag.llm_model_kwargs
)
last_chunk_time = time.time_ns() last_chunk_time = time.time_ns()
if not response_text: if not response_text:
@ -442,8 +440,8 @@ class OllamaAPI:
eval_time = last_chunk_time - first_chunk_time eval_time = last_chunk_time - first_chunk_time
return { return {
"model": self.ollama_server_infos.LIGHTRAG_MODEL, "model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": str(response_text), "response": str(response_text),
"done": True, "done": True,
"done_reason": "stop", "done_reason": "stop",
@ -468,7 +466,11 @@ class OllamaAPI:
Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM. Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM.
Supports both application/json and application/octet-stream Content-Types. Supports both application/json and application/octet-stream Content-Types.
""" """
try: try:
rag = await self.create_rag(raw_request)
ollama_server_infos = rag.ollama_server_infos
# Parse the request body manually # Parse the request body manually
request = await parse_request_body(raw_request, OllamaChatRequest) request = await parse_request_body(raw_request, OllamaChatRequest)
@ -516,15 +518,15 @@ class OllamaAPI:
# Determine if the request is prefix with "/bypass" # Determine if the request is prefix with "/bypass"
if mode == SearchMode.bypass: if mode == SearchMode.bypass:
if request.system: if request.system:
self.rag.llm_model_kwargs["system_prompt"] = request.system rag.llm_model_kwargs["system_prompt"] = request.system
response = await self.rag.llm_model_func( response = await rag.llm_model_func(
cleaned_query, cleaned_query,
stream=True, stream=True,
history_messages=conversation_history, history_messages=conversation_history,
**self.rag.llm_model_kwargs, **rag.llm_model_kwargs,
) )
else: else:
response = await self.rag.aquery( response = await rag.aquery(
cleaned_query, param=query_param cleaned_query, param=query_param
) )
@ -541,8 +543,8 @@ class OllamaAPI:
total_response = response total_response = response
data = { data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL, "model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": { "message": {
"role": "assistant", "role": "assistant",
"content": response, "content": response,
@ -558,8 +560,8 @@ class OllamaAPI:
eval_time = last_chunk_time - first_chunk_time eval_time = last_chunk_time - first_chunk_time
data = { data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL, "model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": { "message": {
"role": "assistant", "role": "assistant",
"content": "", "content": "",
@ -586,8 +588,8 @@ class OllamaAPI:
total_response += chunk total_response += chunk
data = { data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL, "model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": { "message": {
"role": "assistant", "role": "assistant",
"content": chunk, "content": chunk,
@ -607,8 +609,8 @@ class OllamaAPI:
# Send error message to client # Send error message to client
error_data = { error_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL, "model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": { "message": {
"role": "assistant", "role": "assistant",
"content": f"\n\nError: {error_msg}", "content": f"\n\nError: {error_msg}",
@ -621,8 +623,8 @@ class OllamaAPI:
# Send final message to close the stream # Send final message to close the stream
final_data = { final_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL, "model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": { "message": {
"role": "assistant", "role": "assistant",
"content": "", "content": "",
@ -641,8 +643,8 @@ class OllamaAPI:
eval_time = last_chunk_time - first_chunk_time eval_time = last_chunk_time - first_chunk_time
data = { data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL, "model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": { "message": {
"role": "assistant", "role": "assistant",
"content": "", "content": "",
@ -678,18 +680,16 @@ class OllamaAPI:
) )
if match_result or mode == SearchMode.bypass: if match_result or mode == SearchMode.bypass:
if request.system: if request.system:
self.rag.llm_model_kwargs["system_prompt"] = request.system rag.llm_model_kwargs["system_prompt"] = request.system
response_text = await self.rag.llm_model_func( response_text = await rag.llm_model_func(
cleaned_query, cleaned_query,
stream=False, stream=False,
history_messages=conversation_history, history_messages=conversation_history,
**self.rag.llm_model_kwargs, **rag.llm_model_kwargs,
) )
else: else:
response_text = await self.rag.aquery( response_text = await rag.aquery(cleaned_query, param=query_param)
cleaned_query, param=query_param
)
last_chunk_time = time.time_ns() last_chunk_time = time.time_ns()
@ -702,8 +702,8 @@ class OllamaAPI:
eval_time = last_chunk_time - first_chunk_time eval_time = last_chunk_time - first_chunk_time
return { return {
"model": self.ollama_server_infos.LIGHTRAG_MODEL, "model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": { "message": {
"role": "assistant", "role": "assistant",
"content": str(response_text), "content": str(response_text),

View file

@ -4,12 +4,14 @@ This module contains all query-related routes for the LightRAG API.
import json import json
from typing import Any, Dict, List, Literal, Optional from typing import Any, Dict, List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException
from lightrag.base import QueryParam from fastapi import APIRouter, Depends, HTTPException, Request
from lightrag.api.utils_api import get_combined_auth_dependency
from lightrag.utils import logger
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from lightrag.api.utils_api import get_combined_auth_dependency
from lightrag.base import QueryParam
from lightrag.utils import logger
router = APIRouter(tags=["query"]) router = APIRouter(tags=["query"])
@ -190,7 +192,7 @@ class StreamChunkResponse(BaseModel):
) )
def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): def create_query_routes(create_rag, api_key: Optional[str] = None, top_k: int = 60):
combined_auth = get_combined_auth_dependency(api_key) combined_auth = get_combined_auth_dependency(api_key)
@router.post( @router.post(
@ -322,7 +324,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
}, },
}, },
) )
async def query_text(request: QueryRequest): async def query_text(raw_request: Request, request: QueryRequest):
""" """
Comprehensive RAG query endpoint with non-streaming response. Parameter "stream" is ignored. Comprehensive RAG query endpoint with non-streaming response. Parameter "stream" is ignored.
@ -402,6 +404,8 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
- 500: Internal processing error (e.g., LLM service unavailable) - 500: Internal processing error (e.g., LLM service unavailable)
""" """
try: try:
rag = await create_rag(raw_request)
param = request.to_query_params( param = request.to_query_params(
False False
) # Ensure stream=False for non-streaming endpoint ) # Ensure stream=False for non-streaming endpoint
@ -532,7 +536,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
}, },
}, },
) )
async def query_text_stream(request: QueryRequest): async def query_text_stream(raw_request: Request, request: QueryRequest):
""" """
Advanced RAG query endpoint with flexible streaming response. Advanced RAG query endpoint with flexible streaming response.
@ -660,6 +664,8 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
Use streaming mode for real-time interfaces and non-streaming for batch processing. Use streaming mode for real-time interfaces and non-streaming for batch processing.
""" """
try: try:
rag = await create_rag(raw_request)
# Use the stream parameter from the request, defaulting to True if not specified # Use the stream parameter from the request, defaulting to True if not specified
stream_mode = request.stream if request.stream is not None else True stream_mode = request.stream if request.stream is not None else True
param = request.to_query_params(stream_mode) param = request.to_query_params(stream_mode)
@ -1035,7 +1041,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
}, },
}, },
) )
async def query_data(request: QueryRequest): async def query_data(raw_request: Request, request: QueryRequest):
""" """
Advanced data retrieval endpoint for structured RAG analysis. Advanced data retrieval endpoint for structured RAG analysis.
@ -1139,6 +1145,8 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
as structured data analysis typically requires source attribution. as structured data analysis typically requires source attribution.
""" """
try: try:
rag = await create_rag(raw_request)
param = request.to_query_params(False) # No streaming for data endpoint param = request.to_query_params(False) # No streaming for data endpoint
response = await rag.aquery_data(request.query, param=param) response = await rag.aquery_data(request.query, param=param)
@ -1151,6 +1159,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
status="failure", status="failure",
message="Invalid response type", message="Invalid response type",
data={}, data={},
metadata={},
) )
except Exception as e: except Exception as e:
logger.error(f"Error processing data query: {str(e)}", exc_info=True) logger.error(f"Error processing data query: {str(e)}", exc_info=True)

View file

@ -289,6 +289,7 @@ const axiosInstance = axios.create({
axiosInstance.interceptors.request.use((config) => { axiosInstance.interceptors.request.use((config) => {
const apiKey = useSettingsStore.getState().apiKey const apiKey = useSettingsStore.getState().apiKey
const token = localStorage.getItem('LIGHTRAG-API-TOKEN'); const token = localStorage.getItem('LIGHTRAG-API-TOKEN');
const workspace = localStorage.getItem('LIGHTRAG-WORKSPACE');
// Always include token if it exists, regardless of path // Always include token if it exists, regardless of path
if (token) { if (token) {
@ -297,6 +298,9 @@ axiosInstance.interceptors.request.use((config) => {
if (apiKey) { if (apiKey) {
config.headers['X-API-Key'] = apiKey config.headers['X-API-Key'] = apiKey
} }
if (workspace) {
config.headers['LIGHTRAG-WORKSPACE'] = workspace
}
return config return config
}) })
@ -397,6 +401,7 @@ export const queryTextStream = async (
) => { ) => {
const apiKey = useSettingsStore.getState().apiKey; const apiKey = useSettingsStore.getState().apiKey;
const token = localStorage.getItem('LIGHTRAG-API-TOKEN'); const token = localStorage.getItem('LIGHTRAG-API-TOKEN');
const workspace = localStorage.getItem('LIGHTRAG-WORKSPACE');
const headers: HeadersInit = { const headers: HeadersInit = {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
'Accept': 'application/x-ndjson', 'Accept': 'application/x-ndjson',
@ -407,6 +412,9 @@ export const queryTextStream = async (
if (apiKey) { if (apiKey) {
headers['X-API-Key'] = apiKey; headers['X-API-Key'] = apiKey;
} }
if (workspace) {
headers['LIGHTRAG-WORKSPACE'] = workspace;
}
try { try {
const response = await fetch(`${backendBaseUrl}/query/stream`, { const response = await fetch(`${backendBaseUrl}/query/stream`, {

View file

@ -7,8 +7,10 @@ import { useAuthStore } from '@/stores/state'
import { cn } from '@/lib/utils' import { cn } from '@/lib/utils'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { navigationService } from '@/services/navigation' import { navigationService } from '@/services/navigation'
import { ZapIcon, GithubIcon, LogOutIcon } from 'lucide-react' import { ZapIcon, GithubIcon, LogOutIcon, CheckIcon } from 'lucide-react'
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@/components/ui/Tooltip' import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@/components/ui/Tooltip'
import {useState, useEffect} from "react";
import {toast} from 'sonner';
interface NavigationTabProps { interface NavigationTabProps {
value: string value: string
@ -57,6 +59,7 @@ function TabsNavigation() {
export default function SiteHeader() { export default function SiteHeader() {
const { t } = useTranslation() const { t } = useTranslation()
const { isGuestMode, coreVersion, apiVersion, username, webuiTitle, webuiDescription } = useAuthStore() const { isGuestMode, coreVersion, apiVersion, username, webuiTitle, webuiDescription } = useAuthStore()
const [workspace, setWorkspace] = useState('');
const versionDisplay = (coreVersion && apiVersion) const versionDisplay = (coreVersion && apiVersion)
? `${coreVersion}/${apiVersion}` ? `${coreVersion}/${apiVersion}`
@ -72,6 +75,26 @@ export default function SiteHeader() {
navigationService.navigateToLogin(); navigationService.navigateToLogin();
} }
useEffect(() => {
const ws = localStorage.getItem('LIGHTRAG-WORKSPACE') || '';
setWorkspace(ws);
}, []);
const handleWorkspaceUpdate = () => {
const trimed = workspace.trim();
if (trimed) {
localStorage.setItem('LIGHTRAG-WORKSPACE', trimed);
toast.success(t('Workspace set. Reloading page...'));
} else {
localStorage.removeItem('LIGHTRAG-WORKSPACE');
toast.success(t('Workspace cleared. Reloading page...'));
}
setTimeout(() => {
window.location.reload();
}, 500);
}
return ( return (
<header className="border-border/40 bg-background/95 supports-[backdrop-filter]:bg-background/60 sticky top-0 z-50 flex h-10 w-full border-b px-4 backdrop-blur"> <header className="border-border/40 bg-background/95 supports-[backdrop-filter]:bg-background/60 sticky top-0 z-50 flex h-10 w-full border-b px-4 backdrop-blur">
<div className="min-w-[200px] w-auto flex items-center"> <div className="min-w-[200px] w-auto flex items-center">
@ -111,6 +134,10 @@ export default function SiteHeader() {
<nav className="w-[200px] flex items-center justify-end"> <nav className="w-[200px] flex items-center justify-end">
<div className="flex items-center gap-2"> <div className="flex items-center gap-2">
<div className="flex items-center gap-1">
<input type="text" value={workspace} onChange={(e) => setWorkspace(e.target.value)} placeholder="workspace" className="h-6 w-20 px-1 text-xs border rounded bg-background text-foreground"/>
<Button variant="ghost" size="icon" className="h-6 w-6" side="bottom" tooltip={t('header.updateWorkspace', 'Update Workspace')} onClick={handleWorkspaceUpdate}><CheckIcon className="size-3" aria-hidden="true" /></Button>
</div>
{versionDisplay && ( {versionDisplay && (
<TooltipProvider> <TooltipProvider>
<Tooltip> <Tooltip>