Fix linting

This commit is contained in:
yangdx 2025-09-25 16:22:00 +08:00
parent b08b8a6a6a
commit b848ca49e6
4 changed files with 44 additions and 64 deletions

View file

@ -198,7 +198,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
async def query_text_stream(request: QueryRequest): async def query_text_stream(request: QueryRequest):
""" """
This endpoint performs a retrieval-augmented generation (RAG) query and streams the response. This endpoint performs a retrieval-augmented generation (RAG) query and streams the response.
The streaming response includes: The streaming response includes:
1. Reference list (sent first as a single message, if include_references=True) 1. Reference list (sent first as a single message, if include_references=True)
2. LLM response content (streamed as multiple chunks) 2. LLM response content (streamed as multiple chunks)
@ -224,18 +224,22 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
if request.include_references: if request.include_references:
try: try:
# Use aquery_data to get reference list independently # Use aquery_data to get reference list independently
data_param = request.to_query_params(False) # Non-streaming for data data_param = request.to_query_params(
data_result = await rag.aquery_data(request.query, param=data_param) False
) # Non-streaming for data
data_result = await rag.aquery_data(
request.query, param=data_param
)
if isinstance(data_result, dict) and "data" in data_result: if isinstance(data_result, dict) and "data" in data_result:
reference_list = data_result["data"].get("references", []) reference_list = data_result["data"].get("references", [])
except Exception as e: except Exception as e:
logging.warning(f"Failed to get reference list: {str(e)}") logging.warning(f"Failed to get reference list: {str(e)}")
reference_list = [] reference_list = []
# Send reference list first (if requested) # Send reference list first (if requested)
if request.include_references: if request.include_references:
yield f"{json.dumps({'references': reference_list})}\n" yield f"{json.dumps({'references': reference_list})}\n"
# Then stream the response content # Then stream the response content
if isinstance(response, str): if isinstance(response, str):
# If it's a string, send it all at once # If it's a string, send it all at once

View file

@ -828,27 +828,29 @@ class DeletionResult:
# Unified Query Result Data Structures for Reference List Support # Unified Query Result Data Structures for Reference List Support
@dataclass @dataclass
class QueryResult: class QueryResult:
""" """
Unified query result data structure for all query modes. Unified query result data structure for all query modes.
Attributes: Attributes:
content: Text content for non-streaming responses content: Text content for non-streaming responses
response_iterator: Streaming response iterator for streaming responses response_iterator: Streaming response iterator for streaming responses
raw_data: Complete structured data including references and metadata raw_data: Complete structured data including references and metadata
is_streaming: Whether this is a streaming result is_streaming: Whether this is a streaming result
""" """
content: Optional[str] = None content: Optional[str] = None
response_iterator: Optional[AsyncIterator[str]] = None response_iterator: Optional[AsyncIterator[str]] = None
raw_data: Optional[Dict[str, Any]] = None raw_data: Optional[Dict[str, Any]] = None
is_streaming: bool = False is_streaming: bool = False
@property @property
def reference_list(self) -> List[Dict[str, str]]: def reference_list(self) -> List[Dict[str, str]]:
""" """
Convenient property to extract reference list from raw_data. Convenient property to extract reference list from raw_data.
Returns: Returns:
List[Dict[str, str]]: Reference list in format: List[Dict[str, str]]: Reference list in format:
[{"reference_id": "1", "file_path": "/path/to/file.pdf"}, ...] [{"reference_id": "1", "file_path": "/path/to/file.pdf"}, ...]
@ -856,12 +858,12 @@ class QueryResult:
if self.raw_data: if self.raw_data:
return self.raw_data.get("data", {}).get("references", []) return self.raw_data.get("data", {}).get("references", [])
return [] return []
@property @property
def metadata(self) -> Dict[str, Any]: def metadata(self) -> Dict[str, Any]:
""" """
Convenient property to extract metadata from raw_data. Convenient property to extract metadata from raw_data.
Returns: Returns:
Dict[str, Any]: Query metadata including query_mode, keywords, etc. Dict[str, Any]: Query metadata including query_mode, keywords, etc.
""" """
@ -874,14 +876,15 @@ class QueryResult:
class QueryContextResult: class QueryContextResult:
""" """
Unified query context result data structure. Unified query context result data structure.
Attributes: Attributes:
context: LLM context string context: LLM context string
raw_data: Complete structured data including reference_list raw_data: Complete structured data including reference_list
""" """
context: str context: str
raw_data: Dict[str, Any] raw_data: Dict[str, Any]
@property @property
def reference_list(self) -> List[Dict[str, str]]: def reference_list(self) -> List[Dict[str, str]]:
"""Convenient property to extract reference list from raw_data.""" """Convenient property to extract reference list from raw_data."""

View file

@ -2077,7 +2077,7 @@ class LightRAG:
global_config = asdict(self) global_config = asdict(self)
query_result = None query_result = None
if param.mode in ["local", "global", "hybrid", "mix"]: if param.mode in ["local", "global", "hybrid", "mix"]:
query_result = await kg_query( query_result = await kg_query(
query.strip(), query.strip(),
@ -2118,13 +2118,13 @@ class LightRAG:
query_result = QueryResult( query_result = QueryResult(
content=response if not param.stream else None, content=response if not param.stream else None,
response_iterator=response if param.stream else None, response_iterator=response if param.stream else None,
is_streaming=param.stream is_streaming=param.stream,
) )
else: else:
raise ValueError(f"Unknown mode {param.mode}") raise ValueError(f"Unknown mode {param.mode}")
await self._query_done() await self._query_done()
# Return appropriate response based on streaming mode # Return appropriate response based on streaming mode
if query_result.is_streaming: if query_result.is_streaming:
return query_result.response_iterator return query_result.response_iterator
@ -2266,7 +2266,7 @@ class LightRAG:
) )
query_result = None query_result = None
if data_param.mode in ["local", "global", "hybrid", "mix"]: if data_param.mode in ["local", "global", "hybrid", "mix"]:
logger.debug(f"[aquery_data] Using kg_query for mode: {data_param.mode}") logger.debug(f"[aquery_data] Using kg_query for mode: {data_param.mode}")
query_result = await kg_query( query_result = await kg_query(
@ -2301,10 +2301,7 @@ class LightRAG:
[], # no references [], # no references
"bypass", "bypass",
) )
query_result = QueryResult( query_result = QueryResult(content="", raw_data=empty_raw_data)
content="",
raw_data=empty_raw_data
)
else: else:
raise ValueError(f"Unknown mode {data_param.mode}") raise ValueError(f"Unknown mode {data_param.mode}")

View file

@ -2282,7 +2282,7 @@ async def kg_query(
) -> QueryResult: ) -> QueryResult:
""" """
Execute knowledge graph query and return unified QueryResult object. Execute knowledge graph query and return unified QueryResult object.
Args: Args:
query: Query string query: Query string
knowledge_graph_inst: Knowledge graph storage instance knowledge_graph_inst: Knowledge graph storage instance
@ -2294,21 +2294,21 @@ async def kg_query(
hashing_kv: Cache storage hashing_kv: Cache storage
system_prompt: System prompt system_prompt: System prompt
chunks_vdb: Document chunks vector database chunks_vdb: Document chunks vector database
Returns: Returns:
QueryResult: Unified query result object containing: QueryResult: Unified query result object containing:
- content: Non-streaming response text content - content: Non-streaming response text content
- response_iterator: Streaming response iterator - response_iterator: Streaming response iterator
- raw_data: Complete structured data (including references and metadata) - raw_data: Complete structured data (including references and metadata)
- is_streaming: Whether this is a streaming result - is_streaming: Whether this is a streaming result
Based on different query_param settings, different fields will be populated: Based on different query_param settings, different fields will be populated:
- only_need_context=True: content contains context string - only_need_context=True: content contains context string
- only_need_prompt=True: content contains complete prompt - only_need_prompt=True: content contains complete prompt
- stream=True: response_iterator contains streaming response, raw_data contains complete data - stream=True: response_iterator contains streaming response, raw_data contains complete data
- default: content contains LLM response text, raw_data contains complete data - default: content contains LLM response text, raw_data contains complete data
""" """
if not query: if not query:
return QueryResult(content=PROMPTS["fail_response"]) return QueryResult(content=PROMPTS["fail_response"])
@ -2386,8 +2386,7 @@ async def kg_query(
# Return different content based on query parameters # Return different content based on query parameters
if query_param.only_need_context and not query_param.only_need_prompt: if query_param.only_need_context and not query_param.only_need_prompt:
return QueryResult( return QueryResult(
content=context_result.context, content=context_result.context, raw_data=context_result.raw_data
raw_data=context_result.raw_data
) )
# Build system prompt # Build system prompt
@ -2405,10 +2404,7 @@ async def kg_query(
if query_param.only_need_prompt: if query_param.only_need_prompt:
prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query]) prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query])
return QueryResult( return QueryResult(content=prompt_content, raw_data=context_result.raw_data)
content=prompt_content,
raw_data=context_result.raw_data
)
# Call LLM # Call LLM
tokenizer: Tokenizer = global_config["tokenizer"] tokenizer: Tokenizer = global_config["tokenizer"]
@ -2466,16 +2462,13 @@ async def kg_query(
), ),
) )
return QueryResult( return QueryResult(content=response, raw_data=context_result.raw_data)
content=response,
raw_data=context_result.raw_data
)
else: else:
# Streaming response (AsyncIterator) # Streaming response (AsyncIterator)
return QueryResult( return QueryResult(
response_iterator=response, response_iterator=response,
raw_data=context_result.raw_data, raw_data=context_result.raw_data,
is_streaming=True is_streaming=True,
) )
@ -3375,7 +3368,7 @@ async def _build_query_context(
""" """
Main query context building function using the new 4-stage architecture: Main query context building function using the new 4-stage architecture:
1. Search -> 2. Truncate -> 3. Merge chunks -> 4. Build LLM context 1. Search -> 2. Truncate -> 3. Merge chunks -> 4. Build LLM context
Returns unified QueryContextResult containing both context and raw_data. Returns unified QueryContextResult containing both context and raw_data.
""" """
@ -3477,11 +3470,8 @@ async def _build_query_context(
logger.debug( logger.debug(
f"[_build_query_context] Raw data entities: {len(raw_data.get('data', {}).get('entities', []))}, relationships: {len(raw_data.get('data', {}).get('relationships', []))}, chunks: {len(raw_data.get('data', {}).get('chunks', []))}" f"[_build_query_context] Raw data entities: {len(raw_data.get('data', {}).get('entities', []))}, relationships: {len(raw_data.get('data', {}).get('relationships', []))}, chunks: {len(raw_data.get('data', {}).get('chunks', []))}"
) )
return QueryContextResult( return QueryContextResult(context=context, raw_data=raw_data)
context=context,
raw_data=raw_data
)
async def _get_node_data( async def _get_node_data(
@ -4090,7 +4080,7 @@ async def naive_query(
) -> QueryResult: ) -> QueryResult:
""" """
Execute naive query and return unified QueryResult object. Execute naive query and return unified QueryResult object.
Args: Args:
query: Query string query: Query string
chunks_vdb: Document chunks vector database chunks_vdb: Document chunks vector database
@ -4098,7 +4088,7 @@ async def naive_query(
global_config: Global configuration global_config: Global configuration
hashing_kv: Cache storage hashing_kv: Cache storage
system_prompt: System prompt system_prompt: System prompt
Returns: Returns:
QueryResult: Unified query result object containing: QueryResult: Unified query result object containing:
- content: Non-streaming response text content - content: Non-streaming response text content
@ -4106,7 +4096,7 @@ async def naive_query(
- raw_data: Complete structured data (including references and metadata) - raw_data: Complete structured data (including references and metadata)
- is_streaming: Whether this is a streaming result - is_streaming: Whether this is a streaming result
""" """
if not query: if not query:
return QueryResult(content=PROMPTS["fail_response"]) return QueryResult(content=PROMPTS["fail_response"])
@ -4157,10 +4147,7 @@ async def naive_query(
"naive", "naive",
) )
empty_raw_data["message"] = "No relevant document chunks found." empty_raw_data["message"] = "No relevant document chunks found."
return QueryResult( return QueryResult(content=PROMPTS["fail_response"], raw_data=empty_raw_data)
content=PROMPTS["fail_response"],
raw_data=empty_raw_data
)
# Calculate dynamic token limit for chunks # Calculate dynamic token limit for chunks
max_total_tokens = getattr( max_total_tokens = getattr(
@ -4275,10 +4262,7 @@ async def naive_query(
""" """
if query_param.only_need_context and not query_param.only_need_prompt: if query_param.only_need_context and not query_param.only_need_prompt:
return QueryResult( return QueryResult(content=context_content, raw_data=raw_data)
content=context_content,
raw_data=raw_data
)
user_query = ( user_query = (
"\n\n".join([query, query_param.user_prompt]) "\n\n".join([query, query_param.user_prompt])
@ -4294,10 +4278,7 @@ async def naive_query(
if query_param.only_need_prompt: if query_param.only_need_prompt:
prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query]) prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query])
return QueryResult( return QueryResult(content=prompt_content, raw_data=raw_data)
content=prompt_content,
raw_data=raw_data
)
len_of_prompts = len(tokenizer.encode(query + sys_prompt)) len_of_prompts = len(tokenizer.encode(query + sys_prompt))
logger.debug( logger.debug(
@ -4354,14 +4335,9 @@ async def naive_query(
), ),
) )
return QueryResult( return QueryResult(content=response, raw_data=raw_data)
content=response,
raw_data=raw_data
)
else: else:
# Streaming response (AsyncIterator) # Streaming response (AsyncIterator)
return QueryResult( return QueryResult(
response_iterator=response, response_iterator=response, raw_data=raw_data, is_streaming=True
raw_data=raw_data,
is_streaming=True
) )