# LightRAG/lightrag/api/routers/query_routes.py
"""
This module contains all query-related routes for the LightRAG API.
"""
import json
import logging
from typing import Any, Dict, List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException
from lightrag.base import QueryParam
from lightrag.api.utils_api import get_combined_auth_dependency
from pydantic import BaseModel, Field, field_validator
from ascii_colors import trace_exception
router = APIRouter(tags=["query"])
class QueryRequest(BaseModel):
    query: str = Field(
        min_length=3,
        description="The query text",
    )

    mode: Literal["local", "global", "hybrid", "naive", "mix", "bypass"] = Field(
        default="mix",
        description="Query mode",
    )

    only_need_context: Optional[bool] = Field(
        default=None,
        description="If True, only returns the retrieved context without generating a response.",
    )

    only_need_prompt: Optional[bool] = Field(
        default=None,
        description="If True, only returns the generated prompt without producing a response.",
    )

    response_type: Optional[str] = Field(
        min_length=1,
        default=None,
        description="Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'.",
    )

    top_k: Optional[int] = Field(
        ge=1,
        default=None,
        description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.",
    )

    chunk_top_k: Optional[int] = Field(
        ge=1,
        default=None,
        description="Number of text chunks to retrieve initially from vector search and keep after reranking.",
    )

    max_entity_tokens: Optional[int] = Field(
        default=None,
        description="Maximum number of tokens allocated for entity context in the unified token control system.",
        ge=1,
    )

    max_relation_tokens: Optional[int] = Field(
        default=None,
        description="Maximum number of tokens allocated for relationship context in the unified token control system.",
        ge=1,
    )

    max_total_tokens: Optional[int] = Field(
        default=None,
        description="Maximum total token budget for the entire query context (entities + relations + chunks + system prompt).",
        ge=1,
    )

    conversation_history: Optional[List[Dict[str, Any]]] = Field(
        default=None,
        description="Past conversation history used to maintain context. Format: [{'role': 'user/assistant', 'content': 'message'}].",
    )

    user_prompt: Optional[str] = Field(
        default=None,
        description="User-provided prompt for the query. If provided, this is used instead of the default prompt template.",
    )

    enable_rerank: Optional[bool] = Field(
        default=None,
        description="Enable reranking for retrieved text chunks. If True but no rerank model is configured, a warning is issued. Default is True.",
    )

    include_references: Optional[bool] = Field(
        default=True,
        description="If True, includes a reference list in responses. Affects the /query and /query/stream endpoints; /query/data always includes references.",
    )

    @field_validator("query", mode="after")
    @classmethod
    def query_strip_after(cls, query: str) -> str:
        return query.strip()

    @field_validator("conversation_history", mode="after")
    @classmethod
    def conversation_history_role_check(
        cls, conversation_history: List[Dict[str, Any]] | None
    ) -> List[Dict[str, Any]] | None:
        if conversation_history is None:
            return None
        for msg in conversation_history:
            if "role" not in msg or msg["role"] not in {"user", "assistant"}:
                raise ValueError(
                    "Each message must have a 'role' key with value 'user' or 'assistant'."
                )
        return conversation_history

    def to_query_params(self, is_stream: bool) -> "QueryParam":
        """Convert this QueryRequest into a QueryParam instance."""
        # `model_dump(exclude_none=True)` drops unset optional fields so that
        # QueryParam keeps its own defaults for them; `mode` always survives
        # because it has a non-None default.
        request_data = self.model_dump(exclude_none=True, exclude={"query"})
        param = QueryParam(**request_data)
        # Set `stream` explicitly; it is not a QueryRequest field.
        param.stream = is_stream
        return param
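
# Illustrative usage of QueryRequest.to_query_params (payload values are
# hypothetical):
#
#     payload = {"query": "What is LightRAG?", "mode": "hybrid", "top_k": 40}
#     param = QueryRequest(**payload).to_query_params(is_stream=False)
#     # -> param.mode == "hybrid", param.top_k == 40, param.stream is False;
#     #    fields left as None fall back to QueryParam's own defaults.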


class QueryResponse(BaseModel):
    response: str = Field(
        description="The generated response",
    )
    references: Optional[List[Dict[str, str]]] = Field(
        default=None,
        description="Reference list (only included when include_references=True; /query/data always includes references).",
    )


class QueryDataResponse(BaseModel):
    status: str = Field(description="Query execution status")
    message: str = Field(description="Status message")
    data: Dict[str, Any] = Field(
        description="Query result data containing entities, relationships, chunks, and references"
    )
    metadata: Dict[str, Any] = Field(
        description="Query metadata including mode, keywords, and processing information"
    )
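
# Illustrative response shapes (values hypothetical; the exact reference
# dict keys depend on what the backend's aquery_llm returns):
#
#     /query       -> {"response": "...", "references": [{...}, ...]}
#     /query/data  -> {"status": "success", "message": "...",
#                      "data": {"entities": [...], "relationships": [...],
#                               "chunks": [...], "references": [...]},
#                      "metadata": {"mode": "...", "keywords": [...], ...}}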


def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
    combined_auth = get_combined_auth_dependency(api_key)

    @router.post(
        "/query", response_model=QueryResponse, dependencies=[Depends(combined_auth)]
    )
    async def query_text(request: QueryRequest):
        """
        Perform a RAG query with a non-streaming response.

        Parameters:
            request (QueryRequest): The request object containing the query parameters.

        Returns:
            QueryResponse: A Pydantic model containing the result of the query processing.
                           If include_references=True, the reference list is also included.

        Raises:
            HTTPException: Raised when an error occurs during request handling,
                           with status code 500 and detail containing the exception message.
        """
        try:
            # stream=False for the non-streaming endpoint, regardless of the
            # include_references setting
            param = request.to_query_params(False)

            # Unified approach: always use aquery_llm
            result = await rag.aquery_llm(request.query, param=param)

            # Extract the LLM response and references from the unified result
            llm_response = result.get("llm_response", {})
            references = result.get("data", {}).get("references", [])

            response_content = llm_response.get("content", "")
            if not response_content:
                response_content = "No relevant context found for the query."

            # Return the response with or without references, as requested
            if request.include_references:
                return QueryResponse(response=response_content, references=references)
            else:
                return QueryResponse(response=response_content, references=None)
        except Exception as e:
            trace_exception(e)
            raise HTTPException(status_code=500, detail=str(e))
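
    # Illustrative client call for /query (the host/port are an assumption;
    # adjust to your deployment):
    #
    #     curl -X POST http://localhost:9621/query \
    #          -H "Content-Type: application/json" \
    #          -d '{"query": "What is LightRAG?", "mode": "hybrid"}'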
@router.post("/query/stream", dependencies=[Depends(combined_auth)])
async def query_text_stream(request: QueryRequest):
"""
This endpoint performs RAG query with streaming response.
Streaming can be turn off by setting stream=False in QueryRequest.
The streaming response includes:
1. Reference list (sent first as a single message, if include_references=True)
2. LLM response content (streamed as multiple chunks)
Args:
request (QueryRequest): The request object containing the query parameters.
Returns:
StreamingResponse: A streaming response containing:
- First message: {"references": [...]} - Complete reference list (if requested)
- Subsequent messages: {"response": "..."} - LLM response chunks
- Error messages: {"error": "..."} - If any errors occur
"""
try:
param = request.to_query_params(
True
) # Ensure stream=True for streaming endpoint
from fastapi.responses import StreamingResponse
# Unified approach: always use aquery_llm for all cases
result = await rag.aquery_llm(request.query, param=param)
async def stream_generator():
# Extract references and LLM response from unified result
references = result.get("data", {}).get("references", [])
llm_response = result.get("llm_response", {})
# Send reference list first if requested
if request.include_references:
yield f"{json.dumps({'references': references})}\n"
# Then stream the LLM response content
if llm_response.get("is_streaming"):
response_stream = llm_response.get("response_iterator")
if response_stream:
try:
async for chunk in response_stream:
if chunk: # Only send non-empty content
yield f"{json.dumps({'response': chunk})}\n"
except Exception as e:
logging.error(f"Streaming error: {str(e)}")
yield f"{json.dumps({'error': str(e)})}\n"
else:
# Non-streaming response (fallback)
response_content = llm_response.get("content", "")
if response_content:
yield f"{json.dumps({'response': response_content})}\n"
else:
yield f"{json.dumps({'response': 'No relevant context found for the query.'})}\n"
return StreamingResponse(
stream_generator(),
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Content-Type": "application/x-ndjson",
"X-Accel-Buffering": "no", # Ensure proper handling of streaming response when proxied by Nginx
},
)
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
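
    # Illustrative NDJSON consumer for /query/stream (a sketch; httpx is not
    # a dependency of this module, and the server address is an assumption):
    #
    #     import httpx, json
    #     with httpx.stream("POST", "http://localhost:9621/query/stream",
    #                       json={"query": "What is LightRAG?"}) as r:
    #         for line in r.iter_lines():
    #             if line:
    #                 msg = json.loads(line)  # {"references": [...]} first,
    #                                         # then {"response": "..."} chunks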

    @router.post(
        "/query/data",
        response_model=QueryDataResponse,
        dependencies=[Depends(combined_auth)],
    )
    async def query_data(request: QueryRequest):
        """
        Retrieve structured data without LLM generation.

        This endpoint returns the raw retrieval results, including the entities,
        relationships, and text chunks that would be used for RAG, without
        generating a final response. All parameters are compatible with the
        regular /query endpoint.

        Parameters:
            request (QueryRequest): The request object containing the query parameters.

        Returns:
            QueryDataResponse: A Pydantic model containing structured data with status,
                               message, data (entities, relationships, chunks, references),
                               and metadata.

        Raises:
            HTTPException: Raised when an error occurs during request handling,
                           with status code 500 and detail containing the exception message.
        """
        try:
            param = request.to_query_params(False)  # No streaming for the data endpoint
            response = await rag.aquery_data(request.query, param=param)

            # aquery_data returns the new format with status, message, data, and metadata
            if isinstance(response, dict):
                return QueryDataResponse(**response)
            else:
                # Handle an unexpected response format; `metadata` is a required
                # field on the model, so supply an empty dict as well.
                return QueryDataResponse(
                    status="failure",
                    message="Invalid response type",
                    data={},
                    metadata={},
                )
        except Exception as e:
            trace_exception(e)
            raise HTTPException(status_code=500, detail=str(e))
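
    # Illustrative /query/data call (same request schema as /query; the
    # address is an assumption):
    #
    #     curl -X POST http://localhost:9621/query/data \
    #          -H "Content-Type: application/json" \
    #          -d '{"query": "What is LightRAG?", "top_k": 20}'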

    return router