This commit is contained in:
Chen Zhidong 2025-12-02 14:39:05 +07:00 committed by GitHub
commit b1a948e5cf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 410 additions and 301 deletions

View file

@ -2,67 +2,60 @@
LightRAG FastAPI Server
"""
from fastapi import FastAPI, Depends, HTTPException, Request
import configparser
import logging
import logging.config
import os
import sys
from contextlib import asynccontextmanager
from pathlib import Path
import pipmaster as pm
import uvicorn
from ascii_colors import ASCIIColors
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import (
get_swagger_ui_html,
get_swagger_ui_oauth2_redirect_html,
)
import os
import logging
import logging.config
import sys
import uvicorn
import pipmaster as pm
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm
from fastapi.staticfiles import StaticFiles
from fastapi.responses import RedirectResponse
from pathlib import Path
import configparser
from ascii_colors import ASCIIColors
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from dotenv import load_dotenv
from lightrag.api.utils_api import (
get_combined_auth_dependency,
display_splash_screen,
check_env_file,
)
from .config import (
global_args,
update_uvicorn_mode_config,
get_default_host,
)
from lightrag.utils import get_env_value
from lightrag import LightRAG, __version__ as core_version
from lightrag import LightRAG
from lightrag import __version__ as core_version
from lightrag.api import __api_version__
from lightrag.types import GPTKeywordExtractionFormat
from lightrag.utils import EmbeddingFunc
from lightrag.constants import (
DEFAULT_LOG_MAX_BYTES,
DEFAULT_LOG_BACKUP_COUNT,
DEFAULT_LOG_FILENAME,
DEFAULT_LLM_TIMEOUT,
DEFAULT_EMBEDDING_TIMEOUT,
)
from lightrag.api.routers.document_routes import (
DocumentManager,
create_document_routes,
)
from lightrag.api.routers.query_routes import create_query_routes
from lightrag.api.auth import auth_handler
from lightrag.api.routers.document_routes import DocumentManager, create_document_routes
from lightrag.api.routers.graph_routes import create_graph_routes
from lightrag.api.routers.ollama_api import OllamaAPI
from lightrag.utils import logger, set_verbose_debug
from lightrag.api.routers.query_routes import create_query_routes
from lightrag.api.utils_api import (
check_env_file,
display_splash_screen,
get_combined_auth_dependency,
)
from lightrag.constants import (
DEFAULT_EMBEDDING_TIMEOUT,
DEFAULT_LLM_TIMEOUT,
DEFAULT_LOG_BACKUP_COUNT,
DEFAULT_LOG_FILENAME,
DEFAULT_LOG_MAX_BYTES,
)
from lightrag.kg.shared_storage import (
get_namespace_data,
get_default_workspace,
# set_default_workspace,
cleanup_keyed_lock,
finalize_share_data,
get_default_workspace,
get_namespace_data,
set_default_workspace,
)
from fastapi.security import OAuth2PasswordRequestForm
from lightrag.api.auth import auth_handler
from lightrag.types import GPTKeywordExtractionFormat
from lightrag.utils import EmbeddingFunc, get_env_value, logger, set_verbose_debug
from .config import get_default_host, global_args, update_uvicorn_mode_config
# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
@ -343,8 +336,85 @@ def create_app(args):
# Check if API key is provided either through env var or args
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
# Initialize document manager with workspace support for data isolation
doc_manager = DocumentManager(args.input_dir, workspace=args.workspace)
doc_manager_cache = {}
def create_doc_manager(request: Request | None) -> DocumentManager:
    """Return the cached DocumentManager for the request's workspace, creating it on first use."""
    # Resolve the workspace: header override when a request is present,
    # otherwise fall back to the server's configured default.
    ws = (
        get_workspace_from_request(request, args.workspace)
        if request is not None
        else args.workspace
    )
    logger.debug(f"Using DocumentManager for workspace: '{ws}'")
    # Lazily populate the per-workspace cache so each workspace gets
    # exactly one DocumentManager instance.
    if ws not in doc_manager_cache:
        doc_manager_cache[ws] = DocumentManager(args.input_dir, workspace=ws)
    return doc_manager_cache[ws]
rag_cache = {}
async def create_rag(request: Request | None) -> LightRAG:
    """Create or retrieve the LightRAG instance for the current workspace.

    Instances are cached per workspace in ``rag_cache``; the first call for a
    workspace pays the cost of construction, storage initialization, and data
    migration, and subsequent calls return the cached instance.

    Args:
        request: Incoming HTTP request used to resolve the workspace header,
            or None to use the server's default workspace.

    Returns:
        The initialized LightRAG instance for the resolved workspace.

    Raises:
        Exception: Propagates any failure during LightRAG construction or
            storage initialization (logged before re-raising).
    """
    import asyncio

    workspace = args.workspace
    if request is not None:
        workspace = get_workspace_from_request(request, args.workspace)
    logger.debug(f"Using LightRAG instance for workspace: '{workspace}'")
    if workspace in rag_cache:
        return rag_cache[workspace]

    # Serialize instance creation. Without a lock, two concurrent first
    # requests for the same workspace would both build and initialize a
    # LightRAG (the awaits below yield control); one would overwrite the
    # other in rag_cache and the loser's storage connections would never
    # be finalized. The lock is created once and stored as a function
    # attribute — safe because no await occurs between the getattr/setattr.
    lock = getattr(create_rag, "_init_lock", None)
    if lock is None:
        lock = asyncio.Lock()
        create_rag._init_lock = lock

    async with lock:
        # Re-check after acquiring the lock: another coroutine may have
        # created the instance while we were waiting.
        if workspace in rag_cache:
            return rag_cache[workspace]

        # Create ollama_server_infos from command line arguments
        from lightrag.api.config import OllamaServerInfos

        ollama_server_infos = OllamaServerInfos(
            name=args.simulated_model_name, tag=args.simulated_model_tag
        )

        # Initialize RAG with unified configuration
        try:
            rag = LightRAG(
                working_dir=args.working_dir,
                workspace=workspace,
                llm_model_func=create_llm_model_func(args.llm_binding),
                llm_model_name=args.llm_model,
                llm_model_max_async=args.max_async,
                summary_max_tokens=args.summary_max_tokens,
                summary_context_size=args.summary_context_size,
                chunk_token_size=int(args.chunk_size),
                chunk_overlap_token_size=int(args.chunk_overlap_size),
                llm_model_kwargs=create_llm_model_kwargs(args.llm_binding, args, llm_timeout),
                embedding_func=embedding_func,
                default_llm_timeout=llm_timeout,
                default_embedding_timeout=embedding_timeout,
                kv_storage=args.kv_storage,
                graph_storage=args.graph_storage,
                vector_storage=args.vector_storage,
                doc_status_storage=args.doc_status_storage,
                vector_db_storage_cls_kwargs={"cosine_better_than_threshold": args.cosine_threshold},
                enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
                enable_llm_cache=args.enable_llm_cache,
                rerank_model_func=rerank_model_func,
                max_parallel_insert=args.max_parallel_insert,
                max_graph_nodes=args.max_graph_nodes,
                addon_params={
                    "language": args.summary_language,
                    "entity_types": args.entity_types,
                },
                ollama_server_infos=ollama_server_infos,
            )
            # Initialize database connections
            # Note: initialize_storages() now auto-initializes pipeline_status for rag.workspace
            await rag.initialize_storages()
            # Data migration regardless of storage implementation
            await rag.check_and_migrate_data()
            # Cache only after full initialization so a failed setup is
            # retried on the next request instead of serving a broken instance.
            rag_cache[workspace] = rag
            return rag
        except Exception as e:
            logger.error(f"Failed to initialize LightRAG: {e}")
            raise
@asynccontextmanager
async def lifespan(app: FastAPI):
@ -353,12 +423,8 @@ def create_app(args):
app.state.background_tasks = set()
try:
# Initialize database connections
# Note: initialize_storages() now auto-initializes pipeline_status for rag.workspace
await rag.initialize_storages()
# Data migration regardless of storage implementation
await rag.check_and_migrate_data()
create_doc_manager(None) # Pre-create default DocumentManager
await create_rag(None) # Pre-create default LightRAG
ASCIIColors.green("\nServer is ready to accept connections! 🚀\n")
@ -366,7 +432,8 @@ def create_app(args):
finally:
# Clean up database connections
await rag.finalize_storages()
for rag in rag_cache.values():
await rag.finalize_storages()
if "LIGHTRAG_GUNICORN_MODE" not in os.environ:
# Only perform cleanup in Uvicorn single-process mode
@ -404,6 +471,7 @@ def create_app(args):
"tryItOutEnabled": True,
}
set_default_workspace(args.workspace)
app = FastAPI(**app_kwargs)
# Add custom validation error handler for /query/data endpoint
@ -456,7 +524,7 @@ def create_app(args):
# Create combined auth dependency for all endpoints
combined_auth = get_combined_auth_dependency(api_key)
def get_workspace_from_request(request: Request) -> str | None:
def get_workspace_from_request(request: Request, default: str) -> str:
"""
Extract workspace from HTTP request header or use default.
@ -474,7 +542,7 @@ def create_app(args):
workspace = request.headers.get("LIGHTRAG-WORKSPACE", "").strip()
if not workspace:
workspace = None
workspace = default
return workspace
@ -1022,66 +1090,19 @@ def create_app(args):
else:
logger.info("Reranking is disabled")
# Create ollama_server_infos from command line arguments
from lightrag.api.config import OllamaServerInfos
ollama_server_infos = OllamaServerInfos(
name=args.simulated_model_name, tag=args.simulated_model_tag
)
# Initialize RAG with unified configuration
try:
rag = LightRAG(
working_dir=args.working_dir,
workspace=args.workspace,
llm_model_func=create_llm_model_func(args.llm_binding),
llm_model_name=args.llm_model,
llm_model_max_async=args.max_async,
summary_max_tokens=args.summary_max_tokens,
summary_context_size=args.summary_context_size,
chunk_token_size=int(args.chunk_size),
chunk_overlap_token_size=int(args.chunk_overlap_size),
llm_model_kwargs=create_llm_model_kwargs(
args.llm_binding, args, llm_timeout
),
embedding_func=embedding_func,
default_llm_timeout=llm_timeout,
default_embedding_timeout=embedding_timeout,
kv_storage=args.kv_storage,
graph_storage=args.graph_storage,
vector_storage=args.vector_storage,
doc_status_storage=args.doc_status_storage,
vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": args.cosine_threshold
},
enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
enable_llm_cache=args.enable_llm_cache,
rerank_model_func=rerank_model_func,
max_parallel_insert=args.max_parallel_insert,
max_graph_nodes=args.max_graph_nodes,
addon_params={
"language": args.summary_language,
"entity_types": args.entity_types,
},
ollama_server_infos=ollama_server_infos,
)
except Exception as e:
logger.error(f"Failed to initialize LightRAG: {e}")
raise
# Add routes
app.include_router(
create_document_routes(
rag,
doc_manager,
create_rag,
create_doc_manager,
api_key,
)
)
app.include_router(create_query_routes(rag, api_key, args.top_k))
app.include_router(create_graph_routes(rag, api_key))
app.include_router(create_query_routes(create_rag, api_key, args.top_k))
app.include_router(create_graph_routes(create_rag, api_key))
# Add Ollama API routes
ollama_api = OllamaAPI(rag, top_k=args.top_k, api_key=api_key)
ollama_api = OllamaAPI(create_rag, top_k=args.top_k, api_key=api_key)
app.include_router(ollama_api.router, prefix="/api")
# Custom Swagger UI endpoint for offline support
@ -1212,10 +1233,7 @@ def create_app(args):
async def get_status(request: Request):
"""Get current system status including WebUI availability"""
try:
workspace = get_workspace_from_request(request)
default_workspace = get_default_workspace()
if workspace is None:
workspace = default_workspace
workspace = get_workspace_from_request(request, get_default_workspace())
pipeline_status = await get_namespace_data(
"pipeline_status", workspace=workspace
)
@ -1250,7 +1268,7 @@ def create_app(args):
"vector_storage": args.vector_storage,
"enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
"enable_llm_cache": args.enable_llm_cache,
"workspace": default_workspace,
"workspace": workspace,
"max_graph_nodes": args.max_graph_nodes,
# Rerank configuration
"enable_rerank": rerank_model_func is not None,

View file

@ -3,29 +3,31 @@ This module contains all document-related routes for the LightRAG API.
"""
import asyncio
from functools import lru_cache
from lightrag.utils import logger, get_pinyin_sort_key
import aiofiles
import shutil
import traceback
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Optional, Any, Literal
from functools import lru_cache
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional
import aiofiles
from fastapi import (
APIRouter,
BackgroundTasks,
Depends,
File,
HTTPException,
Request,
UploadFile,
)
from pydantic import BaseModel, Field, field_validator
from lightrag import LightRAG
from lightrag.base import DeletionResult, DocProcessingStatus, DocStatus
from lightrag.utils import generate_track_id
from lightrag.api.utils_api import get_combined_auth_dependency
from lightrag.base import DeletionResult, DocProcessingStatus, DocStatus
from lightrag.utils import generate_track_id, get_pinyin_sort_key, logger
from ..config import global_args
@ -2029,16 +2031,12 @@ async def background_delete_documents(
logger.error(f"Error processing pending documents after deletion: {e}")
def create_document_routes(
rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None
):
def create_document_routes(create_rag, create_doc_manager, api_key: Optional[str] = None):
# Create combined auth dependency for document routes
combined_auth = get_combined_auth_dependency(api_key)
@router.post(
"/scan", response_model=ScanResponse, dependencies=[Depends(combined_auth)]
)
async def scan_for_new_documents(background_tasks: BackgroundTasks):
@router.post("/scan", response_model=ScanResponse, dependencies=[Depends(combined_auth)])
async def scan_for_new_documents(raw_request: Request, background_tasks: BackgroundTasks):
"""
Trigger the scanning process for new documents.
@ -2049,6 +2047,9 @@ def create_document_routes(
Returns:
ScanResponse: A response object containing the scanning status and track_id
"""
rag = await create_rag(raw_request)
doc_manager = create_doc_manager(raw_request)
# Generate track_id with "scan" prefix for scanning operation
track_id = generate_track_id("scan")
@ -2060,11 +2061,9 @@ def create_document_routes(
track_id=track_id,
)
@router.post(
"/upload", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
)
@router.post("/upload", response_model=InsertResponse, dependencies=[Depends(combined_auth)])
async def upload_to_input_dir(
background_tasks: BackgroundTasks, file: UploadFile = File(...)
raw_request: Request, background_tasks: BackgroundTasks, file: UploadFile = File(...)
):
"""
Upload a file to the input directory and index it.
@ -2085,6 +2084,9 @@ def create_document_routes(
HTTPException: If the file type is not supported (400) or other errors occur (500).
"""
try:
rag = await create_rag(raw_request)
doc_manager = create_doc_manager(raw_request)
# Sanitize filename to prevent Path Traversal attacks
safe_filename = sanitize_filename(file.filename, doc_manager.input_dir)
@ -2133,12 +2135,8 @@ def create_document_routes(
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@router.post(
"/text", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
)
async def insert_text(
request: InsertTextRequest, background_tasks: BackgroundTasks
):
@router.post("/text", response_model=InsertResponse, dependencies=[Depends(combined_auth)])
async def insert_text(raw_request: Request, request: InsertTextRequest, background_tasks: BackgroundTasks):
"""
Insert text into the RAG system.
@ -2156,6 +2154,8 @@ def create_document_routes(
HTTPException: If an error occurs during text processing (500).
"""
try:
rag = await create_rag(raw_request)
# Check if file_source already exists in doc_status storage
if (
request.file_source
@ -2200,9 +2200,7 @@ def create_document_routes(
response_model=InsertResponse,
dependencies=[Depends(combined_auth)],
)
async def insert_texts(
request: InsertTextsRequest, background_tasks: BackgroundTasks
):
async def insert_texts(raw_request: Request, request: InsertTextsRequest, background_tasks: BackgroundTasks):
"""
Insert multiple texts into the RAG system.
@ -2220,6 +2218,8 @@ def create_document_routes(
HTTPException: If an error occurs during text processing (500).
"""
try:
rag = await create_rag(raw_request)
# Check if any file_sources already exist in doc_status storage
if request.file_sources:
for file_source in request.file_sources:
@ -2261,10 +2261,8 @@ def create_document_routes(
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@router.delete(
"", response_model=ClearDocumentsResponse, dependencies=[Depends(combined_auth)]
)
async def clear_documents():
@router.delete("", response_model=ClearDocumentsResponse, dependencies=[Depends(combined_auth)])
async def clear_documents(raw_request: Request):
"""
Clear all documents from the RAG system.
@ -2285,46 +2283,49 @@ def create_document_routes(
HTTPException: Raised when a serious error occurs during the clearing process,
with status code 500 and error details in the detail field.
"""
from lightrag.kg.shared_storage import (
get_namespace_data,
get_namespace_lock,
)
# Get pipeline status and lock
pipeline_status = await get_namespace_data(
"pipeline_status", workspace=rag.workspace
)
pipeline_status_lock = get_namespace_lock(
"pipeline_status", workspace=rag.workspace
)
# Check and set status with lock
async with pipeline_status_lock:
if pipeline_status.get("busy", False):
return ClearDocumentsResponse(
status="busy",
message="Cannot clear documents while pipeline is busy",
)
# Set busy to true
pipeline_status.update(
{
"busy": True,
"job_name": "Clearing Documents",
"job_start": datetime.now().isoformat(),
"docs": 0,
"batchs": 0,
"cur_batch": 0,
"request_pending": False, # Clear any previous request
"latest_message": "Starting document clearing process",
}
)
# Cleaning history_messages without breaking it as a shared list object
del pipeline_status["history_messages"][:]
pipeline_status["history_messages"].append(
"Starting document clearing process"
)
try:
rag = await create_rag(raw_request)
doc_manager = create_doc_manager(raw_request)
from lightrag.kg.shared_storage import (
get_namespace_data,
get_namespace_lock,
)
# Get pipeline status and lock
pipeline_status = await get_namespace_data(
"pipeline_status", workspace=rag.workspace
)
pipeline_status_lock = get_namespace_lock(
"pipeline_status", workspace=rag.workspace
)
# Check and set status with lock
async with pipeline_status_lock:
if pipeline_status.get("busy", False):
return ClearDocumentsResponse(
status="busy",
message="Cannot clear documents while pipeline is busy",
)
# Set busy to true
pipeline_status.update(
{
"busy": True,
"job_name": "Clearing Documents",
"job_start": datetime.now().isoformat(),
"docs": 0,
"batchs": 0,
"cur_batch": 0,
"request_pending": False, # Clear any previous request
"latest_message": "Starting document clearing process",
}
)
# Cleaning history_messages without breaking it as a shared list object
del pipeline_status["history_messages"][:]
pipeline_status["history_messages"].append(
"Starting document clearing process"
)
# Use drop method to clear all data
drop_tasks = []
storages = [
@ -2460,7 +2461,7 @@ def create_document_routes(
dependencies=[Depends(combined_auth)],
response_model=PipelineStatusResponse,
)
async def get_pipeline_status() -> PipelineStatusResponse:
async def get_pipeline_status(raw_request: Request) -> PipelineStatusResponse:
"""
Get the current status of the document indexing pipeline.
@ -2485,10 +2486,12 @@ def create_document_routes(
HTTPException: If an error occurs while retrieving pipeline status (500)
"""
try:
rag = await create_rag(raw_request)
from lightrag.kg.shared_storage import (
get_all_update_flags_status,
get_namespace_data,
get_namespace_lock,
get_all_update_flags_status,
)
pipeline_status = await get_namespace_data(
@ -2556,10 +2559,8 @@ def create_document_routes(
raise HTTPException(status_code=500, detail=str(e))
# TODO: Deprecated, use /documents/paginated instead
@router.get(
"", response_model=DocsStatusesResponse, dependencies=[Depends(combined_auth)]
)
async def documents() -> DocsStatusesResponse:
@router.get("", response_model=DocsStatusesResponse, dependencies=[Depends(combined_auth)])
async def documents(raw_request: Request) -> DocsStatusesResponse:
"""
Get the status of all documents in the system. This endpoint is deprecated; use /documents/paginated instead.
To prevent excessive resource consumption, a maximum of 1,000 records is returned.
@ -2578,6 +2579,8 @@ def create_document_routes(
HTTPException: If an error occurs while retrieving document statuses (500).
"""
try:
rag = await create_rag(raw_request)
statuses = (
DocStatus.PENDING,
DocStatus.PROCESSING,
@ -2673,6 +2676,7 @@ def create_document_routes(
summary="Delete a document and all its associated data by its ID.",
)
async def delete_document(
raw_request: Request,
delete_request: DeleteDocRequest,
background_tasks: BackgroundTasks,
) -> DeleteDocByIdResponse:
@ -2699,9 +2703,12 @@ def create_document_routes(
HTTPException:
- 500: If an unexpected internal error occurs during initialization.
"""
doc_ids = delete_request.doc_ids
try:
rag = await create_rag(raw_request)
doc_manager = create_doc_manager(raw_request)
doc_ids = delete_request.doc_ids
from lightrag.kg.shared_storage import (
get_namespace_data,
get_namespace_lock,
@ -2750,7 +2757,7 @@ def create_document_routes(
response_model=ClearCacheResponse,
dependencies=[Depends(combined_auth)],
)
async def clear_cache(request: ClearCacheRequest):
async def clear_cache(raw_request: Request, request: ClearCacheRequest):
"""
Clear all cache data from the LLM response cache storage.
@ -2767,6 +2774,7 @@ def create_document_routes(
HTTPException: If an error occurs during cache clearing (500).
"""
try:
rag = await create_rag(raw_request)
# Call the aclear_cache method (no modes parameter)
await rag.aclear_cache()
@ -2784,7 +2792,7 @@ def create_document_routes(
response_model=DeletionResult,
dependencies=[Depends(combined_auth)],
)
async def delete_entity(request: DeleteEntityRequest):
async def delete_entity(raw_request: Request, request: DeleteEntityRequest):
"""
Delete an entity and all its relationships from the knowledge graph.
@ -2798,6 +2806,8 @@ def create_document_routes(
HTTPException: If the entity is not found (404) or an error occurs (500).
"""
try:
rag = await create_rag(raw_request)
result = await rag.adelete_by_entity(entity_name=request.entity_name)
if result.status == "not_found":
raise HTTPException(status_code=404, detail=result.message)
@ -2819,7 +2829,7 @@ def create_document_routes(
response_model=DeletionResult,
dependencies=[Depends(combined_auth)],
)
async def delete_relation(request: DeleteRelationRequest):
async def delete_relation(raw_request: Request, request: DeleteRelationRequest):
"""
Delete a relationship between two entities from the knowledge graph.
@ -2833,6 +2843,8 @@ def create_document_routes(
HTTPException: If the relation is not found (404) or an error occurs (500).
"""
try:
rag = await create_rag(raw_request)
result = await rag.adelete_by_relation(
source_entity=request.source_entity,
target_entity=request.target_entity,
@ -2857,7 +2869,7 @@ def create_document_routes(
response_model=TrackStatusResponse,
dependencies=[Depends(combined_auth)],
)
async def get_track_status(track_id: str) -> TrackStatusResponse:
async def get_track_status(raw_request: Request, track_id: str) -> TrackStatusResponse:
"""
Get the processing status of documents by tracking ID.
@ -2877,6 +2889,8 @@ def create_document_routes(
HTTPException: If track_id is invalid (400) or an error occurs (500).
"""
try:
rag = await create_rag(raw_request)
# Validate track_id
if not track_id or not track_id.strip():
raise HTTPException(status_code=400, detail="Track ID cannot be empty")
@ -2932,6 +2946,7 @@ def create_document_routes(
dependencies=[Depends(combined_auth)],
)
async def get_documents_paginated(
raw_request: Request,
request: DocumentsRequest,
) -> PaginatedDocsResponse:
"""
@ -2954,6 +2969,8 @@ def create_document_routes(
HTTPException: If an error occurs while retrieving documents (500).
"""
try:
rag = await create_rag(raw_request)
# Get paginated documents and status counts in parallel
docs_task = rag.doc_status.get_docs_paginated(
status_filter=request.status_filter,
@ -3018,7 +3035,7 @@ def create_document_routes(
response_model=StatusCountsResponse,
dependencies=[Depends(combined_auth)],
)
async def get_document_status_counts() -> StatusCountsResponse:
async def get_document_status_counts(raw_request: Request) -> StatusCountsResponse:
"""
Get counts of documents by status.
@ -3032,6 +3049,8 @@ def create_document_routes(
HTTPException: If an error occurs while retrieving status counts (500).
"""
try:
rag = await create_rag(raw_request)
status_counts = await rag.doc_status.get_all_status_counts()
return StatusCountsResponse(status_counts=status_counts)
@ -3045,7 +3064,7 @@ def create_document_routes(
response_model=ReprocessResponse,
dependencies=[Depends(combined_auth)],
)
async def reprocess_failed_documents(background_tasks: BackgroundTasks):
async def reprocess_failed_documents(raw_request: Request, background_tasks: BackgroundTasks):
"""
Reprocess failed and pending documents.
@ -3068,6 +3087,8 @@ def create_document_routes(
HTTPException: If an error occurs while initiating reprocessing (500).
"""
try:
rag = await create_rag(raw_request)
# Generate track_id with "retry" prefix for retry operation
track_id = generate_track_id("retry")
@ -3093,7 +3114,7 @@ def create_document_routes(
response_model=CancelPipelineResponse,
dependencies=[Depends(combined_auth)],
)
async def cancel_pipeline():
async def cancel_pipeline(raw_request: Request):
"""
Request cancellation of the currently running pipeline.
@ -3115,6 +3136,8 @@ def create_document_routes(
HTTPException: If an error occurs while setting cancellation flag (500).
"""
try:
rag = await create_rag(raw_request)
from lightrag.kg.shared_storage import (
get_namespace_data,
get_namespace_lock,

View file

@ -2,12 +2,14 @@
This module contains all graph-related routes for the LightRAG API.
"""
from typing import Optional, Dict, Any
import traceback
from fastapi import APIRouter, Depends, Query, HTTPException
from typing import Any, Dict, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from pydantic import BaseModel, Field
from lightrag.utils import logger
from ..utils_api import get_combined_auth_dependency
router = APIRouter(tags=["graph"])
@ -86,11 +88,11 @@ class RelationCreateRequest(BaseModel):
)
def create_graph_routes(rag, api_key: Optional[str] = None):
def create_graph_routes(create_rag, api_key: Optional[str] = None):
combined_auth = get_combined_auth_dependency(api_key)
@router.get("/graph/label/list", dependencies=[Depends(combined_auth)])
async def get_graph_labels():
async def get_graph_labels(raw_request: Request):
"""
Get all graph labels
@ -98,6 +100,8 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
List[str]: List of graph labels
"""
try:
rag = await create_rag(raw_request)
return await rag.get_graph_labels()
except Exception as e:
logger.error(f"Error getting graph labels: {str(e)}")
@ -108,9 +112,8 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
@router.get("/graph/label/popular", dependencies=[Depends(combined_auth)])
async def get_popular_labels(
limit: int = Query(
300, description="Maximum number of popular labels to return", ge=1, le=1000
),
raw_request: Request,
limit: int = Query(300, description="Maximum number of popular labels to return", ge=1, le=1000),
):
"""
Get popular labels by node degree (most connected entities)
@ -122,6 +125,8 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
List[str]: List of popular labels sorted by degree (highest first)
"""
try:
rag = await create_rag(raw_request)
return await rag.chunk_entity_relation_graph.get_popular_labels(limit)
except Exception as e:
logger.error(f"Error getting popular labels: {str(e)}")
@ -132,6 +137,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
@router.get("/graph/label/search", dependencies=[Depends(combined_auth)])
async def search_labels(
raw_request: Request,
q: str = Query(..., description="Search query string"),
limit: int = Query(
50, description="Maximum number of search results to return", ge=1, le=100
@ -148,6 +154,8 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
List[str]: List of matching labels sorted by relevance
"""
try:
rag = await create_rag(raw_request)
return await rag.chunk_entity_relation_graph.search_labels(q, limit)
except Exception as e:
logger.error(f"Error searching labels with query '{q}': {str(e)}")
@ -158,6 +166,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
@router.get("/graphs", dependencies=[Depends(combined_auth)])
async def get_knowledge_graph(
raw_request: Request,
label: str = Query(..., description="Label to get knowledge graph for"),
max_depth: int = Query(3, description="Maximum depth of graph", ge=1),
max_nodes: int = Query(1000, description="Maximum nodes to return", ge=1),
@ -177,6 +186,8 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
Dict[str, List[str]]: Knowledge graph for label
"""
try:
rag = await create_rag(raw_request)
# Log the label parameter to check for leading spaces
logger.debug(
f"get_knowledge_graph called with label: '{label}' (length: {len(label)}, repr: {repr(label)})"
@ -196,6 +207,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
@router.get("/graph/entity/exists", dependencies=[Depends(combined_auth)])
async def check_entity_exists(
raw_request: Request,
name: str = Query(..., description="Entity name to check"),
):
"""
@ -208,6 +220,8 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
Dict[str, bool]: Dictionary with 'exists' key indicating if entity exists
"""
try:
rag = await create_rag(raw_request)
exists = await rag.chunk_entity_relation_graph.has_node(name)
return {"exists": exists}
except Exception as e:
@ -218,7 +232,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
)
@router.post("/graph/entity/edit", dependencies=[Depends(combined_auth)])
async def update_entity(request: EntityUpdateRequest):
async def update_entity(raw_request: Request, request: EntityUpdateRequest):
"""
Update an entity's properties in the knowledge graph
@ -353,6 +367,8 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
}
"""
try:
rag = await create_rag(raw_request)
result = await rag.aedit_entity(
entity_name=request.entity_name,
updated_data=request.updated_data,
@ -408,7 +424,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
)
@router.post("/graph/relation/edit", dependencies=[Depends(combined_auth)])
async def update_relation(request: RelationUpdateRequest):
async def update_relation(raw_request: Request, request: RelationUpdateRequest):
"""Update a relation's properties in the knowledge graph
Args:
@ -418,6 +434,8 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
Dict: Updated relation information
"""
try:
rag = await create_rag(raw_request)
result = await rag.aedit_relation(
source_entity=request.source_id,
target_entity=request.target_id,
@ -443,7 +461,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
)
@router.post("/graph/entity/create", dependencies=[Depends(combined_auth)])
async def create_entity(request: EntityCreateRequest):
async def create_entity(raw_request: Request, request: EntityCreateRequest):
"""
Create a new entity in the knowledge graph
@ -488,6 +506,8 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
}
"""
try:
rag = await create_rag(raw_request)
# Use the proper acreate_entity method which handles:
# - Graph lock for concurrency
# - Vector embedding creation in entities_vdb
@ -516,7 +536,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
)
@router.post("/graph/relation/create", dependencies=[Depends(combined_auth)])
async def create_relation(request: RelationCreateRequest):
async def create_relation(raw_request: Request, request: RelationCreateRequest):
"""
Create a new relationship between two entities in the knowledge graph
@ -573,6 +593,8 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
}
"""
try:
rag = await create_rag(raw_request)
# Use the proper acreate_relation method which handles:
# - Graph lock for concurrency
# - Entity existence validation
@ -605,7 +627,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
)
@router.post("/graph/entities/merge", dependencies=[Depends(combined_auth)])
async def merge_entities(request: EntityMergeRequest):
async def merge_entities(raw_request: Request, request: EntityMergeRequest):
"""
Merge multiple entities into a single entity, preserving all relationships
@ -662,6 +684,8 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
- This operation cannot be undone, so verify entity names before merging
"""
try:
rag = await create_rag(raw_request)
result = await rag.amerge_entities(
source_entities=request.entities_to_change,
target_entity=request.entity_to_change_into,

View file

@ -1,17 +1,17 @@
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel
from typing import List, Dict, Any, Optional, Type
from lightrag.utils import logger
import time
import asyncio
import json
import re
import time
from enum import Enum
from typing import Any, Dict, List, Optional, Type
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse
import asyncio
from pydantic import BaseModel
from lightrag import LightRAG, QueryParam
from lightrag.utils import TiktokenTokenizer
from lightrag.api.utils_api import get_combined_auth_dependency
from fastapi import Depends
from lightrag.utils import TiktokenTokenizer, logger
# query mode according to query prefix (bypass is not a LightRAG query mode)
@ -117,9 +117,7 @@ class OllamaPsResponse(BaseModel):
models: List[OllamaRunningModel]
async def parse_request_body(
request: Request, model_class: Type[BaseModel]
) -> BaseModel:
async def parse_request_body(request: Request, model_class: Type[BaseModel]) -> BaseModel:
"""
Parse request body based on Content-Type header.
Supports both application/json and application/octet-stream.
@ -151,9 +149,7 @@ async def parse_request_body(
except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="Invalid JSON in request body")
except Exception as e:
raise HTTPException(
status_code=400, detail=f"Error parsing request body: {str(e)}"
)
raise HTTPException(status_code=400, detail=f"Error parsing request body: {str(e)}")
def estimate_tokens(text: str) -> int:
@ -218,9 +214,8 @@ def parse_query_mode(query: str) -> tuple[str, SearchMode, bool, Optional[str]]:
class OllamaAPI:
def __init__(self, rag: LightRAG, top_k: int = 60, api_key: Optional[str] = None):
self.rag = rag
self.ollama_server_infos = rag.ollama_server_infos
def __init__(self, create_rag, top_k: int = 60, api_key: Optional[str] = None):
self.create_rag = create_rag
self.top_k = top_k
self.api_key = api_key
self.router = APIRouter(tags=["ollama"])
@ -236,21 +231,24 @@ class OllamaAPI:
return OllamaVersionResponse(version="0.9.3")
@self.router.get("/tags", dependencies=[Depends(combined_auth)])
async def get_tags():
async def get_tags(raw_request: Request):
"""Return available models acting as an Ollama server"""
rag = await self.create_rag(raw_request)
ollama_server_infos = rag.ollama_server_infos
return OllamaTagResponse(
models=[
{
"name": self.ollama_server_infos.LIGHTRAG_MODEL,
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"modified_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"size": self.ollama_server_infos.LIGHTRAG_SIZE,
"digest": self.ollama_server_infos.LIGHTRAG_DIGEST,
"name": ollama_server_infos.LIGHTRAG_MODEL,
"model": ollama_server_infos.LIGHTRAG_MODEL,
"modified_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"size": ollama_server_infos.LIGHTRAG_SIZE,
"digest": ollama_server_infos.LIGHTRAG_DIGEST,
"details": {
"parent_model": "",
"format": "gguf",
"family": self.ollama_server_infos.LIGHTRAG_NAME,
"families": [self.ollama_server_infos.LIGHTRAG_NAME],
"family": ollama_server_infos.LIGHTRAG_NAME,
"families": [ollama_server_infos.LIGHTRAG_NAME],
"parameter_size": "13B",
"quantization_level": "Q4_0",
},
@ -259,15 +257,18 @@ class OllamaAPI:
)
@self.router.get("/ps", dependencies=[Depends(combined_auth)])
async def get_running_models():
async def get_running_models(raw_request: Request):
"""List Running Models - returns currently running models"""
rag = await self.create_rag(raw_request)
ollama_server_infos = rag.ollama_server_infos
return OllamaPsResponse(
models=[
{
"name": self.ollama_server_infos.LIGHTRAG_MODEL,
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"size": self.ollama_server_infos.LIGHTRAG_SIZE,
"digest": self.ollama_server_infos.LIGHTRAG_DIGEST,
"name": ollama_server_infos.LIGHTRAG_MODEL,
"model": ollama_server_infos.LIGHTRAG_MODEL,
"size": ollama_server_infos.LIGHTRAG_SIZE,
"digest": ollama_server_infos.LIGHTRAG_DIGEST,
"details": {
"parent_model": "",
"format": "gguf",
@ -277,14 +278,12 @@ class OllamaAPI:
"quantization_level": "Q4_0",
},
"expires_at": "2050-12-31T14:38:31.83753-07:00",
"size_vram": self.ollama_server_infos.LIGHTRAG_SIZE,
"size_vram": ollama_server_infos.LIGHTRAG_SIZE,
}
]
)
@self.router.post(
"/generate", dependencies=[Depends(combined_auth)], include_in_schema=True
)
@self.router.post("/generate", dependencies=[Depends(combined_auth)], include_in_schema=True)
async def generate(raw_request: Request):
"""Handle generate completion requests acting as an Ollama model
For compatibility purpose, the request is not processed by LightRAG,
@ -292,6 +291,9 @@ class OllamaAPI:
Supports both application/json and application/octet-stream Content-Types.
"""
try:
rag = await self.create_rag(raw_request)
ollama_server_infos = rag.ollama_server_infos
# Parse the request body manually
request = await parse_request_body(raw_request, OllamaGenerateRequest)
@ -300,12 +302,10 @@ class OllamaAPI:
prompt_tokens = estimate_tokens(query)
if request.system:
self.rag.llm_model_kwargs["system_prompt"] = request.system
rag.llm_model_kwargs["system_prompt"] = request.system
if request.stream:
response = await self.rag.llm_model_func(
query, stream=True, **self.rag.llm_model_kwargs
)
response = await rag.llm_model_func(query, stream=True, **rag.llm_model_kwargs)
async def stream_generator():
first_chunk_time = None
@ -320,8 +320,8 @@ class OllamaAPI:
total_response = response
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": response,
"done": False,
}
@ -333,8 +333,8 @@ class OllamaAPI:
eval_time = last_chunk_time - first_chunk_time
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": "",
"done": True,
"done_reason": "stop",
@ -358,8 +358,8 @@ class OllamaAPI:
total_response += chunk
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": chunk,
"done": False,
}
@ -375,8 +375,8 @@ class OllamaAPI:
# Send error message to client
error_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": f"\n\nError: {error_msg}",
"error": f"\n\nError: {error_msg}",
"done": False,
@ -385,8 +385,8 @@ class OllamaAPI:
# Send final message to close the stream
final_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": "",
"done": True,
}
@ -400,8 +400,8 @@ class OllamaAPI:
eval_time = last_chunk_time - first_chunk_time
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": "",
"done": True,
"done_reason": "stop",
@ -428,9 +428,7 @@ class OllamaAPI:
)
else:
first_chunk_time = time.time_ns()
response_text = await self.rag.llm_model_func(
query, stream=False, **self.rag.llm_model_kwargs
)
response_text = await rag.llm_model_func(query, stream=False, **rag.llm_model_kwargs)
last_chunk_time = time.time_ns()
if not response_text:
@ -442,8 +440,8 @@ class OllamaAPI:
eval_time = last_chunk_time - first_chunk_time
return {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": str(response_text),
"done": True,
"done_reason": "stop",
@ -468,7 +466,11 @@ class OllamaAPI:
Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM.
Supports both application/json and application/octet-stream Content-Types.
"""
try:
rag = await self.create_rag(raw_request)
ollama_server_infos = rag.ollama_server_infos
# Parse the request body manually
request = await parse_request_body(raw_request, OllamaChatRequest)
@ -516,15 +518,15 @@ class OllamaAPI:
# Determine if the request is prefix with "/bypass"
if mode == SearchMode.bypass:
if request.system:
self.rag.llm_model_kwargs["system_prompt"] = request.system
response = await self.rag.llm_model_func(
rag.llm_model_kwargs["system_prompt"] = request.system
response = await rag.llm_model_func(
cleaned_query,
stream=True,
history_messages=conversation_history,
**self.rag.llm_model_kwargs,
**rag.llm_model_kwargs,
)
else:
response = await self.rag.aquery(
response = await rag.aquery(
cleaned_query, param=query_param
)
@ -541,8 +543,8 @@ class OllamaAPI:
total_response = response
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": response,
@ -558,8 +560,8 @@ class OllamaAPI:
eval_time = last_chunk_time - first_chunk_time
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": "",
@ -586,8 +588,8 @@ class OllamaAPI:
total_response += chunk
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": chunk,
@ -607,8 +609,8 @@ class OllamaAPI:
# Send error message to client
error_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": f"\n\nError: {error_msg}",
@ -621,8 +623,8 @@ class OllamaAPI:
# Send final message to close the stream
final_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": "",
@ -641,8 +643,8 @@ class OllamaAPI:
eval_time = last_chunk_time - first_chunk_time
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": "",
@ -678,18 +680,16 @@ class OllamaAPI:
)
if match_result or mode == SearchMode.bypass:
if request.system:
self.rag.llm_model_kwargs["system_prompt"] = request.system
rag.llm_model_kwargs["system_prompt"] = request.system
response_text = await self.rag.llm_model_func(
response_text = await rag.llm_model_func(
cleaned_query,
stream=False,
history_messages=conversation_history,
**self.rag.llm_model_kwargs,
**rag.llm_model_kwargs,
)
else:
response_text = await self.rag.aquery(
cleaned_query, param=query_param
)
response_text = await rag.aquery(cleaned_query, param=query_param)
last_chunk_time = time.time_ns()
@ -702,8 +702,8 @@ class OllamaAPI:
eval_time = last_chunk_time - first_chunk_time
return {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": str(response_text),

View file

@ -4,12 +4,14 @@ This module contains all query-related routes for the LightRAG API.
import json
from typing import Any, Dict, List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException
from lightrag.base import QueryParam
from lightrag.api.utils_api import get_combined_auth_dependency
from lightrag.utils import logger
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel, Field, field_validator
from lightrag.api.utils_api import get_combined_auth_dependency
from lightrag.base import QueryParam
from lightrag.utils import logger
router = APIRouter(tags=["query"])
@ -190,7 +192,7 @@ class StreamChunkResponse(BaseModel):
)
def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
def create_query_routes(create_rag, api_key: Optional[str] = None, top_k: int = 60):
combined_auth = get_combined_auth_dependency(api_key)
@router.post(
@ -322,7 +324,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
},
},
)
async def query_text(request: QueryRequest):
async def query_text(raw_request: Request, request: QueryRequest):
"""
Comprehensive RAG query endpoint with non-streaming response. Parameter "stream" is ignored.
@ -402,6 +404,8 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
- 500: Internal processing error (e.g., LLM service unavailable)
"""
try:
rag = await create_rag(raw_request)
param = request.to_query_params(
False
) # Ensure stream=False for non-streaming endpoint
@ -532,7 +536,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
},
},
)
async def query_text_stream(request: QueryRequest):
async def query_text_stream(raw_request: Request, request: QueryRequest):
"""
Advanced RAG query endpoint with flexible streaming response.
@ -660,6 +664,8 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
Use streaming mode for real-time interfaces and non-streaming for batch processing.
"""
try:
rag = await create_rag(raw_request)
# Use the stream parameter from the request, defaulting to True if not specified
stream_mode = request.stream if request.stream is not None else True
param = request.to_query_params(stream_mode)
@ -1035,7 +1041,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
},
},
)
async def query_data(request: QueryRequest):
async def query_data(raw_request: Request, request: QueryRequest):
"""
Advanced data retrieval endpoint for structured RAG analysis.
@ -1139,6 +1145,8 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
as structured data analysis typically requires source attribution.
"""
try:
rag = await create_rag(raw_request)
param = request.to_query_params(False) # No streaming for data endpoint
response = await rag.aquery_data(request.query, param=param)
@ -1151,6 +1159,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
status="failure",
message="Invalid response type",
data={},
metadata={},
)
except Exception as e:
logger.error(f"Error processing data query: {str(e)}", exc_info=True)

View file

@ -289,6 +289,7 @@ const axiosInstance = axios.create({
axiosInstance.interceptors.request.use((config) => {
const apiKey = useSettingsStore.getState().apiKey
const token = localStorage.getItem('LIGHTRAG-API-TOKEN');
const workspace = localStorage.getItem('LIGHTRAG-WORKSPACE');
// Always include token if it exists, regardless of path
if (token) {
@ -297,6 +298,9 @@ axiosInstance.interceptors.request.use((config) => {
if (apiKey) {
config.headers['X-API-Key'] = apiKey
}
if (workspace) {
config.headers['LIGHTRAG-WORKSPACE'] = workspace
}
return config
})
@ -397,6 +401,7 @@ export const queryTextStream = async (
) => {
const apiKey = useSettingsStore.getState().apiKey;
const token = localStorage.getItem('LIGHTRAG-API-TOKEN');
const workspace = localStorage.getItem('LIGHTRAG-WORKSPACE');
const headers: HeadersInit = {
'Content-Type': 'application/json',
'Accept': 'application/x-ndjson',
@ -407,6 +412,9 @@ export const queryTextStream = async (
if (apiKey) {
headers['X-API-Key'] = apiKey;
}
if (workspace) {
headers['LIGHTRAG-WORKSPACE'] = workspace;
}
try {
const response = await fetch(`${backendBaseUrl}/query/stream`, {

View file

@ -7,8 +7,10 @@ import { useAuthStore } from '@/stores/state'
import { cn } from '@/lib/utils'
import { useTranslation } from 'react-i18next'
import { navigationService } from '@/services/navigation'
import { ZapIcon, GithubIcon, LogOutIcon } from 'lucide-react'
import { ZapIcon, GithubIcon, LogOutIcon, CheckIcon } from 'lucide-react'
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@/components/ui/Tooltip'
import {useState, useEffect} from "react";
import {toast} from 'sonner';
interface NavigationTabProps {
value: string
@ -57,6 +59,7 @@ function TabsNavigation() {
export default function SiteHeader() {
const { t } = useTranslation()
const { isGuestMode, coreVersion, apiVersion, username, webuiTitle, webuiDescription } = useAuthStore()
const [workspace, setWorkspace] = useState('');
const versionDisplay = (coreVersion && apiVersion)
? `${coreVersion}/${apiVersion}`
@ -72,6 +75,26 @@ export default function SiteHeader() {
navigationService.navigateToLogin();
}
useEffect(() => {
const ws = localStorage.getItem('LIGHTRAG-WORKSPACE') || '';
setWorkspace(ws);
}, []);
const handleWorkspaceUpdate = () => {
const trimed = workspace.trim();
if (trimed) {
localStorage.setItem('LIGHTRAG-WORKSPACE', trimed);
toast.success(t('Workspace set. Reloading page...'));
} else {
localStorage.removeItem('LIGHTRAG-WORKSPACE');
toast.success(t('Workspace cleared. Reloading page...'));
}
setTimeout(() => {
window.location.reload();
}, 500);
}
return (
<header className="border-border/40 bg-background/95 supports-[backdrop-filter]:bg-background/60 sticky top-0 z-50 flex h-10 w-full border-b px-4 backdrop-blur">
<div className="min-w-[200px] w-auto flex items-center">
@ -111,6 +134,10 @@ export default function SiteHeader() {
<nav className="w-[200px] flex items-center justify-end">
<div className="flex items-center gap-2">
<div className="flex items-center gap-1">
<input type="text" value={workspace} onChange={(e) => setWorkspace(e.target.value)} placeholder="workspace" className="h-6 w-20 px-1 text-xs border rounded bg-background text-foreground"/>
<Button variant="ghost" size="icon" className="h-6 w-6" side="bottom" tooltip={t('header.updateWorkspace', 'Update Workspace')} onClick={handleWorkspaceUpdate}><CheckIcon className="size-3" aria-hidden="true" /></Button>
</div>
{versionDisplay && (
<TooltipProvider>
<Tooltip>