"""
This module contains all query-related routes for the LightRAG API.
"""
|
|
|
|
import json
|
|
import logging
|
|
from typing import Any, Dict, List, Literal, Optional
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from lightrag.base import QueryParam
|
|
from lightrag.api.utils_api import get_combined_auth_dependency
|
|
from pydantic import BaseModel, Field, field_validator
|
|
|
|
from ascii_colors import trace_exception
|
|
|
|
router = APIRouter(tags=["query"])
|
|
|
|
|
|
class QueryRequest(BaseModel):
    """Request body shared by the /query, /query/stream and /query/data endpoints.

    Most fields default to None so that ``to_query_params`` can use
    ``model_dump(exclude_none=True)`` to forward only the values the caller
    explicitly set, letting ``QueryParam`` supply its own defaults for the rest.
    """

    # The query text; leading/trailing whitespace is stripped by a validator.
    # NOTE(review): min_length=3 is validated BEFORE the mode="after" strip
    # validator runs, so a whitespace-padded string of length >= 3 can pass
    # validation and end up shorter than 3 (or empty) after stripping.
    query: str = Field(
        min_length=3,
        description="The query text",
    )

    # Retrieval strategy; "mix" combines knowledge-graph and vector retrieval.
    mode: Literal["local", "global", "hybrid", "naive", "mix", "bypass"] = Field(
        default="mix",
        description="Query mode",
    )

    # When True, skip LLM generation and return only retrieved context.
    only_need_context: Optional[bool] = Field(
        default=None,
        description="If True, only returns the retrieved context without generating a response.",
    )

    # When True, return the assembled prompt instead of an LLM response.
    only_need_prompt: Optional[bool] = Field(
        default=None,
        description="If True, only returns the generated prompt without producing a response.",
    )

    # Free-form response-format hint passed through to the prompt template.
    response_type: Optional[str] = Field(
        min_length=1,
        default=None,
        description="Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'.",
    )

    # Entity/relationship retrieval breadth (mode-dependent meaning).
    top_k: Optional[int] = Field(
        ge=1,
        default=None,
        description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.",
    )

    # Text-chunk retrieval breadth (pre- and post-rerank).
    chunk_top_k: Optional[int] = Field(
        ge=1,
        default=None,
        description="Number of text chunks to retrieve initially from vector search and keep after reranking.",
    )

    # Token budget for entity context within the unified token control system.
    max_entity_tokens: Optional[int] = Field(
        default=None,
        description="Maximum number of tokens allocated for entity context in unified token control system.",
        ge=1,
    )

    # Token budget for relationship context within the unified token control system.
    max_relation_tokens: Optional[int] = Field(
        default=None,
        description="Maximum number of tokens allocated for relationship context in unified token control system.",
        ge=1,
    )

    # Overall token budget for the whole query context.
    max_total_tokens: Optional[int] = Field(
        default=None,
        description="Maximum total tokens budget for the entire query context (entities + relations + chunks + system prompt).",
        ge=1,
    )

    # Prior conversation turns; each message must carry a valid 'role'
    # (enforced by conversation_history_role_check below).
    conversation_history: Optional[List[Dict[str, Any]]] = Field(
        default=None,
        description="Stores past conversation history to maintain context. Format: [{'role': 'user/assistant', 'content': 'message'}].",
    )

    # Caller-supplied prompt that overrides the default prompt template value.
    user_prompt: Optional[str] = Field(
        default=None,
        description="User-provided prompt for the query. If provided, this will be used instead of the default value from prompt template.",
    )

    # Toggle for reranking of retrieved chunks.
    enable_rerank: Optional[bool] = Field(
        default=None,
        description="Enable reranking for retrieved text chunks. If True but no rerank model is configured, a warning will be issued. Default is True.",
    )

    # Controls whether /query and /query/stream include a reference list.
    # Defaults to True (not None), so it is always present in model_dump output.
    include_references: Optional[bool] = Field(
        default=True,
        description="If True, includes reference list in responses. Affects /query and /query/stream endpoints. /query/data always includes references.",
    )

    @field_validator("query", mode="after")
    @classmethod
    def query_strip_after(cls, query: str) -> str:
        # Normalize surrounding whitespace after field validation.
        return query.strip()

    @field_validator("conversation_history", mode="after")
    @classmethod
    def conversation_history_role_check(
        cls, conversation_history: List[Dict[str, Any]] | None
    ) -> List[Dict[str, Any]] | None:
        # Reject any message whose 'role' is missing or not user/assistant.
        # NOTE(review): the 'content' key mentioned in the field description
        # is not validated here.
        if conversation_history is None:
            return None
        for msg in conversation_history:
            if "role" not in msg or msg["role"] not in {"user", "assistant"}:
                raise ValueError(
                    "Each message must have a 'role' key with value 'user' or 'assistant'."
                )
        return conversation_history

    def to_query_params(self, is_stream: bool) -> "QueryParam":
        """Converts a QueryRequest instance into a QueryParam instance."""
        # Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically.
        # `query` is excluded because it is passed to the RAG call separately.
        # NOTE(review): include_references (default True) is always forwarded —
        # assumes QueryParam accepts that keyword; confirm against QueryParam.
        request_data = self.model_dump(exclude_none=True, exclude={"query"})

        # Ensure `mode` and `stream` are set explicitly
        param = QueryParam(**request_data)
        param.stream = is_stream
        return param
|
|
|
|
|
|
class QueryResponse(BaseModel):
    """Response body for the non-streaming /query endpoint."""

    # The LLM-generated answer text.
    response: str = Field(
        description="The generated response",
    )
    # Reference list backing the answer; None when include_references=False.
    references: Optional[List[Dict[str, str]]] = Field(
        default=None,
        description="Reference list (only included when include_references=True, /query/data always includes references.)",
    )
|
|
|
|
|
|
class QueryDataResponse(BaseModel):
    """Response body for the /query/data endpoint (structured retrieval results, no LLM answer)."""

    # Execution status string (e.g. "failure" in the endpoint's fallback path).
    status: str = Field(description="Query execution status")
    # Human-readable status message.
    message: str = Field(description="Status message")
    # Raw retrieval payload: entities, relationships, chunks and references.
    data: Dict[str, Any] = Field(
        description="Query result data containing entities, relationships, chunks, and references"
    )
    # NOTE(review): this field has no default, so constructing the model
    # without `metadata` raises a ValidationError — callers must supply it.
    metadata: Dict[str, Any] = Field(
        description="Query metadata including mode, keywords, and processing information"
    )
|
|
|
|
|
|
def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
    """Attach the /query, /query/stream and /query/data endpoints to the module router.

    Args:
        rag: The LightRAG instance used to execute queries.
        api_key: Optional API key combined into the auth dependency.
        top_k: Kept for backward compatibility with existing callers; the
            effective top_k is taken from each request (QueryRequest.top_k).

    Returns:
        APIRouter: The router with the query routes registered.
    """
    combined_auth = get_combined_auth_dependency(api_key)

    @router.post(
        "/query", response_model=QueryResponse, dependencies=[Depends(combined_auth)]
    )
    async def query_text(request: QueryRequest):
        """
        This endpoint performs a RAG query with non-streaming response.

        Parameters:
            request (QueryRequest): The request object containing the query parameters.
        Returns:
            QueryResponse: A Pydantic model containing the result of the query processing.
                       If include_references=True, also includes reference list.

        Raises:
            HTTPException: Raised when an error occurs during the request handling process,
                       with status code 500 and detail containing the exception message.
        """
        try:
            # stream=False for the non-streaming endpoint; to_query_params
            # assigns param.stream itself, so no further override is needed.
            param = request.to_query_params(False)

            # Unified approach: aquery_llm returns both the LLM response and
            # the retrieval data (including references) in a single call.
            result = await rag.aquery_llm(request.query, param=param)

            # Extract LLM response and references from unified result
            llm_response = result.get("llm_response", {})
            references = result.get("data", {}).get("references", [])

            # Get the non-streaming response content
            response_content = llm_response.get("content", "")
            if not response_content:
                response_content = "No relevant context found for the query."

            # Return response with or without references based on request
            if request.include_references:
                return QueryResponse(response=response_content, references=references)
            else:
                return QueryResponse(response=response_content, references=None)
        except Exception as e:
            trace_exception(e)
            raise HTTPException(status_code=500, detail=str(e)) from e

    @router.post("/query/stream", dependencies=[Depends(combined_auth)])
    async def query_text_stream(request: QueryRequest):
        """
        This endpoint performs a RAG query with a streaming (NDJSON) response.
        The stream flag is set server-side via to_query_params(True).

        The streaming response includes:
        1. Reference list (sent first as a single message, if include_references=True)
        2. LLM response content (streamed as multiple chunks)

        Args:
            request (QueryRequest): The request object containing the query parameters.

        Returns:
            StreamingResponse: A streaming response containing:
                - First message: {"references": [...]} - Complete reference list (if requested)
                - Subsequent messages: {"response": "..."} - LLM response chunks
                - Error messages: {"error": "..."} - If any errors occur
        """
        try:
            # Ensure stream=True for the streaming endpoint
            param = request.to_query_params(True)

            from fastapi.responses import StreamingResponse

            # Unified approach: always use aquery_llm for all cases
            result = await rag.aquery_llm(request.query, param=param)

            async def stream_generator():
                # Extract references and LLM response from unified result
                references = result.get("data", {}).get("references", [])
                llm_response = result.get("llm_response", {})

                # Send reference list first if requested
                if request.include_references:
                    yield f"{json.dumps({'references': references})}\n"

                # Then stream the LLM response content
                if llm_response.get("is_streaming"):
                    response_stream = llm_response.get("response_iterator")
                    if response_stream:
                        try:
                            async for chunk in response_stream:
                                if chunk:  # Only send non-empty content
                                    yield f"{json.dumps({'response': chunk})}\n"
                        except Exception as e:
                            # Surface streaming failures as an in-band error
                            # message; the HTTP response has already started.
                            logging.error(f"Streaming error: {str(e)}")
                            yield f"{json.dumps({'error': str(e)})}\n"
                else:
                    # Non-streaming response (fallback)
                    response_content = llm_response.get("content", "")
                    if response_content:
                        yield f"{json.dumps({'response': response_content})}\n"
                    else:
                        yield f"{json.dumps({'response': 'No relevant context found for the query.'})}\n"

            return StreamingResponse(
                stream_generator(),
                media_type="application/x-ndjson",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "Content-Type": "application/x-ndjson",
                    "X-Accel-Buffering": "no",  # Ensure proper handling of streaming response when proxied by Nginx
                },
            )
        except Exception as e:
            trace_exception(e)
            raise HTTPException(status_code=500, detail=str(e)) from e

    @router.post(
        "/query/data",
        response_model=QueryDataResponse,
        dependencies=[Depends(combined_auth)],
    )
    async def query_data(request: QueryRequest):
        """
        Retrieve structured data without LLM generation.

        This endpoint returns raw retrieval results including entities, relationships,
        and text chunks that would be used for RAG, but without generating a final response.
        All parameters are compatible with the regular /query endpoint.

        Parameters:
            request (QueryRequest): The request object containing the query parameters.

        Returns:
            QueryDataResponse: A Pydantic model containing structured data with status,
                              message, data (entities, relationships, chunks, references),
                              and metadata.

        Raises:
            HTTPException: Raised when an error occurs during the request handling process,
                          with status code 500 and detail containing the exception message.
        """
        try:
            param = request.to_query_params(False)  # No streaming for data endpoint
            response = await rag.aquery_data(request.query, param=param)

            # aquery_data returns the new format with status, message, data, and metadata
            if isinstance(response, dict):
                return QueryDataResponse(**response)
            else:
                # Handle unexpected response format.
                # metadata is a required field on QueryDataResponse; omitting it
                # here would raise a ValidationError and mask the real problem
                # behind an unrelated 500 error.
                return QueryDataResponse(
                    status="failure",
                    message="Invalid response type",
                    data={},
                    metadata={},
                )
        except Exception as e:
            trace_exception(e)
            raise HTTPException(status_code=500, detail=str(e)) from e

    return router
|