feat: simplify citations, add reference merging, and restructure API response format

yangdx 2025-09-24 14:30:10 +08:00
parent 18968c6b6b
commit 5eb4a4b799
6 changed files with 452 additions and 187 deletions
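For orientation, the restructured /query/data response now wraps all retrieval results in a status/message/data/metadata envelope. The sketch below is assembled from the diff that follows; the field values are illustrative, not output from a real query.

```python
# Illustrative shape of the restructured response (values are invented):
example_response = {
    "status": "success",
    "message": "Query executed successfully",
    "data": {
        "entities": [],        # each entity now carries a reference_id for citations
        "relationships": [],   # same for relationships
        "chunks": [],          # chunks carry a reference_id instead of a positional id
        "references": [        # new: merged reference list, one entry per source file
            {"reference_id": "1", "file_path": "docs/example.md"},
        ],
    },
    "metadata": {
        "query_mode": "mix",
        "keywords": {"high_level": [], "low_level": []},
    },
}
```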

View file

@@ -2,7 +2,9 @@
 LightRAG FastAPI Server
 """
-from fastapi import FastAPI, Depends, HTTPException
+from fastapi import FastAPI, Depends, HTTPException, Request
+from fastapi.exceptions import RequestValidationError
+from fastapi.responses import JSONResponse
 import os
 import logging
 import logging.config
@@ -245,6 +247,35 @@ def create_app(args):
     app = FastAPI(**app_kwargs)

+    # Add custom validation error handler for /query/data endpoint
+    @app.exception_handler(RequestValidationError)
+    async def validation_exception_handler(
+        request: Request, exc: RequestValidationError
+    ):
+        # Check if this is a request to /query/data endpoint
+        if request.url.path.endswith("/query/data"):
+            # Extract error details
+            error_details = []
+            for error in exc.errors():
+                field_path = " -> ".join(str(loc) for loc in error["loc"])
+                error_details.append(f"{field_path}: {error['msg']}")
+            error_message = "; ".join(error_details)
+
+            # Return in the expected format for /query/data
+            return JSONResponse(
+                status_code=400,
+                content={
+                    "status": "failure",
+                    "message": f"Validation error: {error_message}",
+                    "data": {},
+                    "metadata": {},
+                },
+            )
+        else:
+            # For other endpoints, return the default FastAPI validation error
+            return JSONResponse(status_code=422, content={"detail": exc.errors()})
+
     def get_cors_origins():
         """Get allowed origins from global_args

         Returns a list of allowed origins, defaults to ["*"] if not set
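As a rough sketch of what the handler above produces (not part of the commit): a request that fails Pydantic validation on /query/data now gets a 400 with the same envelope as regular failures, while other endpoints keep FastAPI's default 422. The base URL below is an assumption.

```python
# Sketch only: exercising the custom validation handler against a locally running server.
import requests

resp = requests.post(
    "http://localhost:9621/query/data",   # base URL/port is an assumption, not from the diff
    json={"query": ""},                   # empty query fails the min_length constraint
)
print(resp.status_code)  # 400 from the handler above, not FastAPI's default 422
print(resp.json())
# Expected shape:
# {"status": "failure", "message": "Validation error: body -> query: ...", "data": {}, "metadata": {}}
```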

View file

@@ -8,7 +8,7 @@ from typing import Any, Dict, List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException
from lightrag.base import QueryParam
-from ..utils_api import get_combined_auth_dependency
+from lightrag.api.utils_api import get_combined_auth_dependency
from pydantic import BaseModel, Field, field_validator
from ascii_colors import trace_exception

@@ -18,7 +18,7 @@ router = APIRouter(tags=["query"])
class QueryRequest(BaseModel):
    query: str = Field(
-       min_length=1,
+       min_length=3,
        description="The query text",
    )
@@ -135,14 +135,10 @@
class QueryDataResponse(BaseModel):
-   entities: List[Dict[str, Any]] = Field(
-       description="Retrieved entities from knowledge graph"
-   )
-   relationships: List[Dict[str, Any]] = Field(
-       description="Retrieved relationships from knowledge graph"
-   )
-   chunks: List[Dict[str, Any]] = Field(
-       description="Retrieved text chunks from documents"
+   status: str = Field(description="Query execution status")
+   message: str = Field(description="Status message")
+   data: Dict[str, Any] = Field(
+       description="Query result data containing entities, relationships, chunks, and references"
    )
    metadata: Dict[str, Any] = Field(
        description="Query metadata including mode, keywords, and processing information"
@@ -253,8 +249,9 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
            request (QueryRequest): The request object containing the query parameters.

        Returns:
-           QueryDataResponse: A Pydantic model containing structured data with entities,
-               relationships, chunks, and metadata.
+           QueryDataResponse: A Pydantic model containing structured data with status,
+               message, data (entities, relationships, chunks, references),
+               and metadata.

        Raises:
            HTTPException: Raised when an error occurs during the request handling process,
@@ -264,40 +261,15 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
            param = request.to_query_params(False)  # No streaming for data endpoint
            response = await rag.aquery_data(request.query, param=param)

-           # The aquery_data method returns a dict with entities, relationships, chunks, and metadata
+           # aquery_data returns the new format with status, message, data, and metadata
            if isinstance(response, dict):
-               # Ensure all required fields exist and are lists/dicts
-               entities = response.get("entities", [])
-               relationships = response.get("relationships", [])
-               chunks = response.get("chunks", [])
-               metadata = response.get("metadata", {})
-               # Validate data types
-               if not isinstance(entities, list):
-                   entities = []
-               if not isinstance(relationships, list):
-                   relationships = []
-               if not isinstance(chunks, list):
-                   chunks = []
-               if not isinstance(metadata, dict):
-                   metadata = {}
-               return QueryDataResponse(
-                   entities=entities,
-                   relationships=relationships,
-                   chunks=chunks,
-                   metadata=metadata,
-               )
+               return QueryDataResponse(**response)
            else:
-               # Fallback for unexpected response format
+               # Handle unexpected response format
                return QueryDataResponse(
-                   entities=[],
-                   relationships=[],
-                   chunks=[],
-                   metadata={
-                       "error": "Unexpected response format",
-                       "raw_response": str(response),
-                   },
+                   status="failure",
+                   message="Invalid response type",
+                   data={},
                )
        except Exception as e:
            trace_exception(e)

View file

@@ -59,7 +59,7 @@ from lightrag.kg.shared_storage import (
    get_data_init_lock,
)

-from .base import (
+from lightrag.base import (
    BaseGraphStorage,
    BaseKVStorage,
    BaseVectorStorage,

@@ -72,8 +72,8 @@ from .base import (
    DeletionResult,
    OllamaServerInfos,
)
-from .namespace import NameSpace
-from .operate import (
+from lightrag.namespace import NameSpace
+from lightrag.operate import (
    chunking_by_token_size,
    extract_entities,
    merge_nodes_and_edges,

@@ -81,8 +81,8 @@ from .operate import (
    naive_query,
    _rebuild_knowledge_from_chunks,
)
-from .constants import GRAPH_FIELD_SEP
-from .utils import (
+from lightrag.constants import GRAPH_FIELD_SEP
+from lightrag.utils import (
    Tokenizer,
    TiktokenTokenizer,
    EmbeddingFunc,

@@ -94,9 +94,10 @@ from .utils import (
    sanitize_text_for_encoding,
    check_storage_env_vars,
    generate_track_id,
+   convert_to_user_format,
    logger,
)
-from .types import KnowledgeGraph
+from lightrag.types import KnowledgeGraph

from dotenv import load_dotenv

# use the .env that is inside the current folder
@@ -2127,11 +2128,104 @@ class LightRAG:
 returning the final processed entities, relationships, and chunks data that would be sent to LLM.

 Args:
-    query: Query text.
-    param: Query parameters (same as aquery).
+    query: Query text for retrieval.
+    param: Query parameters controlling retrieval behavior (same as aquery).

 Returns:
-    dict[str, Any]: Structured data result with entities, relationships, chunks, and metadata
+    dict[str, Any]: Structured data result in the following format:
**Success Response:**
```python
{
"status": "success",
"message": "Query executed successfully",
"data": {
"entities": [
{
"entity_name": str, # Entity identifier
"entity_type": str, # Entity category/type
"description": str, # Entity description
"source_id": str, # Source chunk references
"file_path": str, # Origin file path
"created_at": str, # Creation timestamp
"reference_id": str # Reference identifier for citations
}
],
"relationships": [
{
"src_id": str, # Source entity name
"tgt_id": str, # Target entity name
"description": str, # Relationship description
"keywords": str, # Relationship keywords
"weight": float, # Relationship strength
"source_id": str, # Source chunk references
"file_path": str, # Origin file path
"created_at": str, # Creation timestamp
"reference_id": str # Reference identifier for citations
}
],
"chunks": [
{
"content": str, # Document chunk content
"file_path": str, # Origin file path
"chunk_id": str, # Unique chunk identifier
"reference_id": str # Reference identifier for citations
}
],
"references": [
{
"reference_id": str, # Reference identifier
"file_path": str # Corresponding file path
}
]
},
"metadata": {
"query_mode": str, # Query mode used ("local", "global", "hybrid", "mix", "naive", "bypass")
"keywords": {
"high_level": List[str], # High-level keywords extracted
"low_level": List[str] # Low-level keywords extracted
},
"processing_info": {
"total_entities_found": int, # Total entities before truncation
"total_relations_found": int, # Total relations before truncation
"entities_after_truncation": int, # Entities after token truncation
"relations_after_truncation": int, # Relations after token truncation
"merged_chunks_count": int, # Chunks before final processing
"final_chunks_count": int # Final chunks in result
}
}
}
```
**Query Mode Differences:**
- **local**: Focuses on entities and their related chunks based on low-level keywords
- **global**: Focuses on relationships and their connected entities based on high-level keywords
- **hybrid**: Combines local and global results using round-robin merging
- **mix**: Includes knowledge graph data plus vector-retrieved document chunks
- **naive**: Only vector-retrieved chunks, entities and relationships arrays are empty
- **bypass**: All data arrays are empty, used for direct LLM queries
**Note:** processing_info is optional and may not be present in all responses, especially when the query result is empty.
**Failure Response:**
```python
{
"status": "failure",
"message": str, # Error description
"data": {} # Empty data object
}
```
**Common Failure Cases:**
- Empty query string
- Both high-level and low-level keywords are empty
- Query returns empty dataset
- Missing tokenizer or system configuration errors
Note:
The function adapts to the new data format from convert_to_user_format where
actual data is nested under the 'data' field, with 'status' and 'message'
fields at the top level.
""" """
global_config = asdict(self) global_config = asdict(self)
@@ -2163,23 +2257,30 @@ class LightRAG:
            )
        elif param.mode == "bypass":
            logger.debug("[aquery_data] Using bypass mode")
-           # bypass mode returns empty data
-           final_data = {
-               "entities": [],
-               "relationships": [],
-               "chunks": [],
-               "metadata": {
-                   "query_mode": "bypass",
-                   "keywords": {"high_level": [], "low_level": []},
-               },
-           }
+           # bypass mode returns empty data using convert_to_user_format
+           final_data = convert_to_user_format(
+               [],  # no entities
+               [],  # no relationships
+               [],  # no chunks
+               [],  # no references
+               "bypass",
+           )
        else:
            raise ValueError(f"Unknown mode {param.mode}")

-       # Log final result counts
-       entities_count = len(final_data.get("entities", []))
-       relationships_count = len(final_data.get("relationships", []))
-       chunks_count = len(final_data.get("chunks", []))
+       # Log final result counts - adapt to new data format from convert_to_user_format
+       if isinstance(final_data, dict) and "data" in final_data:
+           # New format: data is nested under 'data' field
+           data_section = final_data["data"]
+           entities_count = len(data_section.get("entities", []))
+           relationships_count = len(data_section.get("relationships", []))
+           chunks_count = len(data_section.get("chunks", []))
+       else:
+           # Fallback for other formats
+           entities_count = len(final_data.get("entities", []))
+           relationships_count = len(final_data.get("relationships", []))
+           chunks_count = len(final_data.get("chunks", []))
        logger.debug(
            f"[aquery_data] Final result: {entities_count} entities, {relationships_count} relationships, {chunks_count} chunks"
        )
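A short usage sketch (assumptions: an initialized LightRAG instance named `rag`, called from inside an async function) showing how the reshaped aquery_data result is read:

```python
# Inside an async function, with `rag` an initialized LightRAG instance (assumption):
param = QueryParam(mode="mix")
result = await rag.aquery_data("Who is the author of LightRAG?", param=param)

if result.get("status") == "success":
    data = result["data"]  # entities/relationships/chunks/references now live here
    for ref in data.get("references", []):
        print(f"[{ref['reference_id']}] {ref['file_path']}")
else:
    print("Query failed:", result.get("message"))
```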
@@ -2676,7 +2777,7 @@ class LightRAG:
        Returns:
            DeletionResult: An object containing the outcome of the deletion process.
        """
-       from .utils_graph import adelete_by_entity
+       from lightrag.utils_graph import adelete_by_entity

        return await adelete_by_entity(
            self.chunk_entity_relation_graph,

@@ -2709,7 +2810,7 @@
        Returns:
            DeletionResult: An object containing the outcome of the deletion process.
        """
-       from .utils_graph import adelete_by_relation
+       from lightrag.utils_graph import adelete_by_relation

        return await adelete_by_relation(
            self.chunk_entity_relation_graph,

@@ -2760,7 +2861,7 @@
        self, entity_name: str, include_vector_data: bool = False
    ) -> dict[str, str | None | dict[str, str]]:
        """Get detailed information of an entity"""
-       from .utils_graph import get_entity_info
+       from lightrag.utils_graph import get_entity_info

        return await get_entity_info(
            self.chunk_entity_relation_graph,

@@ -2773,7 +2874,7 @@
        self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
    ) -> dict[str, str | None | dict[str, str]]:
        """Get detailed information of a relationship"""
-       from .utils_graph import get_relation_info
+       from lightrag.utils_graph import get_relation_info

        return await get_relation_info(
            self.chunk_entity_relation_graph,

@@ -2798,7 +2899,7 @@
        Returns:
            Dictionary containing updated entity information
        """
-       from .utils_graph import aedit_entity
+       from lightrag.utils_graph import aedit_entity

        return await aedit_entity(
            self.chunk_entity_relation_graph,

@@ -2832,7 +2933,7 @@
        Returns:
            Dictionary containing updated relation information
        """
-       from .utils_graph import aedit_relation
+       from lightrag.utils_graph import aedit_relation

        return await aedit_relation(
            self.chunk_entity_relation_graph,

@@ -2865,7 +2966,7 @@
        Returns:
            Dictionary containing created entity information
        """
-       from .utils_graph import acreate_entity
+       from lightrag.utils_graph import acreate_entity

        return await acreate_entity(
            self.chunk_entity_relation_graph,

@@ -2896,7 +2997,7 @@
        Returns:
            Dictionary containing created relation information
        """
-       from .utils_graph import acreate_relation
+       from lightrag.utils_graph import acreate_relation

        return await acreate_relation(
            self.chunk_entity_relation_graph,

@@ -2942,7 +3043,7 @@
        Returns:
            Dictionary containing the merged entity information
        """
-       from .utils_graph import amerge_entities
+       from lightrag.utils_graph import amerge_entities

        return await amerge_entities(
            self.chunk_entity_relation_graph,

@@ -2986,7 +3087,7 @@
                - table: Print formatted tables to console
            include_vector_data: Whether to include data from the vector database.
        """
-       from .utils import aexport_data as utils_aexport_data
+       from lightrag.utils import aexport_data as utils_aexport_data

        await utils_aexport_data(
            self.chunk_entity_relation_graph,

View file

@@ -30,7 +30,8 @@ from .utils import (
    safe_vdb_operation_with_exception,
    create_prefixed_exception,
    fix_tuple_delimiter_corruption,
-   _convert_to_user_format,
+   convert_to_user_format,
+   generate_reference_list_from_chunks,
)
from .base import (
    BaseGraphStorage,
@@ -2279,6 +2280,12 @@ async def kg_query(
    return_raw_data: bool = False,
) -> str | AsyncIterator[str] | dict[str, Any]:
    if not query:
+       if return_raw_data:
+           return {
+               "status": "failure",
+               "message": "Query string is empty.",
+               "data": {},
+           }
        return PROMPTS["fail_response"]

    if query_param.model_func:
@@ -2306,10 +2313,14 @@ async def kg_query(
    cached_result = await handle_cache(
        hashing_kv, args_hash, query, query_param.mode, cache_type="query"
    )
-   if cached_result is not None:
+   if (
+       cached_result is not None
+       and not return_raw_data
+       and not query_param.only_need_context
+       and not query_param.only_need_prompt
+   ):
        cached_response, _ = cached_result  # Extract content, ignore timestamp
-       if not query_param.only_need_context and not query_param.only_need_prompt:
-           return cached_response
+       return cached_response

    hl_keywords, ll_keywords = await get_keywords_from_query(
        query, query_param, global_config, hashing_kv
@@ -2328,6 +2339,12 @@ async def kg_query(
            logger.warning(f"Forced low_level_keywords to origin query: {query}")
            ll_keywords = [query]
        else:
+           if return_raw_data:
+               return {
+                   "status": "failure",
+                   "message": "Both high_level_keywords and low_level_keywords are empty",
+                   "data": {},
+               }
            return PROMPTS["fail_response"]

    ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
@@ -2356,9 +2373,16 @@ async def kg_query(
        )
        return raw_data
    else:
-       raise RuntimeError(
-           "Failed to build query context for raw data. Expected a tuple, but got a different type."
-       )
+       if not context_result:
+           return {
+               "status": "failure",
+               "message": "Query returned an empty data set.",
+               "data": {},
+           }
+       else:
+           raise ValueError(
+               "Failed to build raw data query result. Invalid return from _build_query_context"
+           )

    # Build context (normal flow)
    context = await _build_query_context(
@@ -2870,7 +2894,6 @@ async def _apply_token_truncation(
        entities_context.append(
            {
-               "id": f"EN{i + 1}",
                "entity": entity_name,
                "type": entity.get("entity_type", "UNKNOWN"),
                "description": entity.get("description", "UNKNOWN"),

@@ -2898,7 +2921,6 @@ async def _apply_token_truncation(
        relations_context.append(
            {
-               "id": f"RE{i + 1}",
                "entity1": entity1,
                "entity2": entity2,
                "description": relation.get("description", "UNKNOWN"),
@@ -2956,26 +2978,19 @@ async def _apply_token_truncation(
    filtered_entities = []
    filtered_entity_id_to_original = {}
    if entities_context:
-       entity_name_to_id = {e["entity"]: e["id"] for e in entities_context}
-       final_entity_names = set(entity_name_to_id.keys())
+       final_entity_names = {e["entity"] for e in entities_context}
        seen_nodes = set()
        for entity in final_entities:
            name = entity.get("entity_name")
            if name in final_entity_names and name not in seen_nodes:
-               entity_with_id = entity.copy()
-               entity_with_id["id"] = entity_name_to_id.get(name)
-               filtered_entities.append(entity_with_id)
-               filtered_entity_id_to_original[name] = entity_with_id
+               filtered_entities.append(entity)
+               filtered_entity_id_to_original[name] = entity
                seen_nodes.add(name)

    filtered_relations = []
    filtered_relation_id_to_original = {}
    if relations_context:
-       relation_pair_to_id = {
-           (r["entity1"], r["entity2"]): r["id"] for r in relations_context
-       }
-       final_relation_pairs = set(relation_pair_to_id.keys())
+       final_relation_pairs = {(r["entity1"], r["entity2"]) for r in relations_context}
        seen_edges = set()
        for relation in final_relations:
            src, tgt = relation.get("src_id"), relation.get("tgt_id")

@@ -2984,11 +2999,8 @@ async def _apply_token_truncation(
            pair = (src, tgt)
            if pair in final_relation_pairs and pair not in seen_edges:
-               relation_with_id = relation.copy()
-               relation_with_id["id"] = relation_pair_to_id.get(pair)
-               filtered_relations.append(relation_with_id)
-               filtered_relation_id_to_original[pair] = relation_with_id
+               filtered_relations.append(relation)
+               filtered_relation_id_to_original[pair] = relation
                seen_edges.add(pair)

    return {
@@ -3121,47 +3133,23 @@ async def _build_llm_context(
    """
    tokenizer = global_config.get("tokenizer")
    if not tokenizer:
-       logger.warning("No tokenizer found, building context without token limits")
-       # Build basic context without token processing
-       entities_str = "\n".join(
-           json.dumps(entity, ensure_ascii=False) for entity in entities_context
-       )
-       relations_str = "\n".join(
-           json.dumps(relation, ensure_ascii=False) for relation in relations_context
-       )
-       text_units_context = []
-       for i, chunk in enumerate(merged_chunks):
-           text_units_context.append(
-               {
-                   "id": i + 1,
-                   "content": chunk["content"],
-                   "file_path": chunk.get("file_path", "unknown_source"),
-               }
-           )
-       text_units_str = json.dumps(text_units_context, ensure_ascii=False)
-       return f"""-----Entities(KG)-----
-```json
-{entities_str}
-```
------Relationships(KG)-----
-```json
-{relations_str}
-```
------Document Chunks(DC)-----
-```json
-{text_units_str}
-```
-"""
+       logger.error("Missing tokenizer, cannot build LLM context")
+       if return_raw_data:
+           # Return empty raw data structure when no entities/relations
+           empty_raw_data = convert_to_user_format(
+               [],
+               [],
+               [],
+               [],
+               query_param.mode,
+           )
+           empty_raw_data["status"] = "failure"
+           empty_raw_data["message"] = "Missing tokenizer, cannot build LLM context."
+           return None, empty_raw_data
+       else:
+           logger.error("Tokenizer not found in global configuration.")
+           return None

    # Get token limits
    max_total_tokens = getattr(
@@ -3198,9 +3186,12 @@
-----Document Chunks(DC)-----

```json
+[]
```

+-----Reference Document List-----
+The reference documents list in Document Chunks(DC) is as follows (reference_id in square brackets):
+
"""
    kg_context = kg_context_template.format(
        entities_str=entities_str, relations_str=relations_str
@@ -3252,13 +3243,18 @@
        chunk_token_limit=available_chunk_tokens,  # Pass dynamic limit
    )

+   # Generate reference list from truncated chunks using the new common function
+   reference_list, truncated_chunks = generate_reference_list_from_chunks(
+       truncated_chunks
+   )
+
    # Rebuild text_units_context with truncated chunks
+   # The actual tokens may be slightly less than available_chunk_tokens due to deduplication logic
    for i, chunk in enumerate(truncated_chunks):
        text_units_context.append(
            {
-               "id": chunk["id"],
+               "reference_id": chunk["reference_id"],
                "content": chunk["content"],
-               "file_path": chunk.get("file_path", "unknown_source"),
            }
        )
@@ -3274,12 +3270,15 @@
    if not entities_context and not relations_context:
        if return_raw_data:
            # Return empty raw data structure when no entities/relations
-           empty_raw_data = _convert_to_user_format(
+           empty_raw_data = convert_to_user_format(
+               [],
                [],
                [],
                [],
                query_param.mode,
            )
+           empty_raw_data["status"] = "failure"
+           empty_raw_data["message"] = "Query returned empty dataset."
            return None, empty_raw_data
        else:
            return None
@@ -3311,6 +3310,11 @@
    text_units_str = "\n".join(
        json.dumps(text_unit, ensure_ascii=False) for text_unit in text_units_context
    )
+   reference_list_str = "\n\n".join(
+       f"[{ref['reference_id']}] {ref['file_path']}"
+       for ref in reference_list
+       if ref["reference_id"]
+   )

    result = f"""-----Entities(KG)-----

@@ -3330,6 +3334,12 @@
{text_units_str}
```

+-----Reference Document List-----
+Document Chunks (DC) reference documents: (each entry begins with [reference_id])
+{reference_list_str}
+
"""

    # If final data is requested, return both context and complete data structure
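For illustration, given a hypothetical reference_list produced by generate_reference_list_from_chunks, the appended "Reference Document List" section renders like this:

```python
# Illustration only; the file paths are invented.
reference_list = [
    {"reference_id": "1", "file_path": "docs/lightrag_overview.md"},
    {"reference_id": "2", "file_path": "README.md"},
]
reference_list_str = "\n\n".join(
    f"[{ref['reference_id']}] {ref['file_path']}"
    for ref in reference_list
    if ref["reference_id"]
)
print(reference_list_str)
# [1] docs/lightrag_overview.md
#
# [2] README.md
```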
@@ -3337,10 +3347,11 @@
        logger.debug(
            f"[_build_llm_context] Converting to user format: {len(entities_context)} entities, {len(relations_context)} relations, {len(truncated_chunks)} chunks"
        )
-       final_data = _convert_to_user_format(
+       final_data = convert_to_user_format(
            entities_context,
            relations_context,
            truncated_chunks,
+           reference_list,
            query_param.mode,
            entity_id_to_original,
            relation_id_to_original,
@@ -3365,7 +3376,7 @@ async def _build_query_context(
    query_param: QueryParam,
    chunks_vdb: BaseVectorStorage = None,
    return_raw_data: bool = False,
-) -> str | tuple[str, dict[str, Any]]:
+) -> str | None | tuple[str, dict[str, Any]]:
    """
    Main query context building function using the new 4-stage architecture:
    1. Search -> 2. Truncate -> 3. Merge chunks -> 4. Build LLM context
@@ -3448,7 +3459,11 @@ async def _build_query_context(
        hl_keywords_list = hl_keywords.split(", ") if hl_keywords else []
        ll_keywords_list = ll_keywords.split(", ") if ll_keywords else []

-       # Add complete metadata to raw_data
+       # Add complete metadata to raw_data (preserve existing metadata including query_mode)
+       if "metadata" not in raw_data:
+           raw_data["metadata"] = {}
+
+       # Update keywords while preserving existing metadata
        raw_data["metadata"]["keywords"] = {
            "high_level": hl_keywords_list,
            "low_level": ll_keywords_list,
@@ -4092,6 +4107,18 @@ async def naive_query(
    system_prompt: str | None = None,
    return_raw_data: bool = False,
) -> str | AsyncIterator[str] | dict[str, Any]:
+   if not query:
+       if return_raw_data:
+           # Return empty raw data structure when query is empty
+           empty_raw_data = {
+               "status": "failure",
+               "message": "Query string is empty.",
+               "data": {},
+           }
+           return empty_raw_data
+       else:
+           return PROMPTS["fail_response"]
+
    if query_param.model_func:
        use_model_func = query_param.model_func
    else:
@@ -4123,26 +4150,35 @@ async def naive_query(
            return cached_response

    tokenizer: Tokenizer = global_config["tokenizer"]
+   if not tokenizer:
+       if return_raw_data:
+           # Return empty raw data structure when tokenizer is missing
+           empty_raw_data = {
+               "status": "failure",
+               "message": "Tokenizer not found in global configuration.",
+               "data": {},
+           }
+           return empty_raw_data
+       else:
+           logger.error("Tokenizer not found in global configuration.")
+           return PROMPTS["fail_response"]

    chunks = await _get_vector_context(query, chunks_vdb, query_param, None)

    if chunks is None or len(chunks) == 0:
-       # Build empty raw data for consistency
-       empty_raw_data = {
-           "entities": [],  # naive mode has no entities
-           "relationships": [],  # naive mode has no relationships
-           "chunks": [],
-           "metadata": {
-               "query_mode": "naive",
-               "keywords": {"high_level": [], "low_level": []},
-           },
-       }
        # If only raw data is requested, return it directly
        if return_raw_data:
+           empty_raw_data = convert_to_user_format(
+               [],  # naive mode has no entities
+               [],  # naive mode has no relationships
+               [],  # no chunks
+               [],  # no references
+               "naive",
+           )
+           empty_raw_data["message"] = "No relevant document chunks found."
            return empty_raw_data
-       return PROMPTS["fail_response"]
+       else:
+           return PROMPTS["fail_response"]

    # Calculate dynamic token limit for chunks
    # Get token limits from query_param (with fallback to global_config)
@@ -4197,44 +4233,56 @@ async def naive_query(
        chunk_token_limit=available_chunk_tokens,  # Pass dynamic limit
    )

-   logger.info(f"Final context: {len(processed_chunks)} chunks")
+   # Generate reference list from processed chunks using the new common function
+   reference_list, processed_chunks_with_ref_ids = generate_reference_list_from_chunks(
+       processed_chunks
+   )
+
+   logger.info(f"Final context: {len(processed_chunks_with_ref_ids)} chunks")

    # If only raw data is requested, return it directly
    if return_raw_data:
-       # Build raw data structure for naive mode using processed chunks
-       raw_data = _convert_to_user_format(
+       # Build raw data structure for naive mode using processed chunks with reference IDs
+       raw_data = convert_to_user_format(
            [],  # naive mode has no entities
            [],  # naive mode has no relationships
-           processed_chunks,
+           processed_chunks_with_ref_ids,
+           reference_list,
            "naive",
        )

        # Add complete metadata for naive mode
+       if "metadata" not in raw_data:
+           raw_data["metadata"] = {}
        raw_data["metadata"]["keywords"] = {
            "high_level": [],  # naive mode has no keyword extraction
            "low_level": [],  # naive mode has no keyword extraction
        }
        raw_data["metadata"]["processing_info"] = {
            "total_chunks_found": len(chunks),
-           "final_chunks_count": len(processed_chunks),
+           "final_chunks_count": len(processed_chunks_with_ref_ids),
        }
        return raw_data

-   # Build text_units_context from processed chunks
+   # Build text_units_context from processed chunks with reference IDs
    text_units_context = []
-   for i, chunk in enumerate(processed_chunks):
+   for i, chunk in enumerate(processed_chunks_with_ref_ids):
        text_units_context.append(
            {
-               "id": chunk["id"],
+               "reference_id": chunk["reference_id"],
                "content": chunk["content"],
-               "file_path": chunk.get("file_path", "unknown_source"),
            }
        )

    text_units_str = "\n".join(
        json.dumps(text_unit, ensure_ascii=False) for text_unit in text_units_context
    )
+   reference_list_str = "\n\n".join(
+       f"[{ref['reference_id']}] {ref['file_path']}"
+       for ref in reference_list
+       if ref["reference_id"]
+   )

    if query_param.only_need_context and not query_param.only_need_prompt:
        return f"""

@@ -4244,6 +4292,10 @@ async def naive_query(
{text_units_str}
```

+-----Reference Document List-----
+{reference_list_str}
+
"""

    user_query = (
        "\n\n".join([query, query_param.user_prompt])

View file

@@ -2720,10 +2720,11 @@ def create_prefixed_exception(original_exception: Exception, prefix: str) -> Exception:
    )

-def _convert_to_user_format(
+def convert_to_user_format(
    entities_context: list[dict],
    relations_context: list[dict],
-   final_chunks: list[dict],
+   chunks: list[dict],
+   references: list[dict],
    query_mode: str,
    entity_id_to_original: dict = None,
    relation_id_to_original: dict = None,
@@ -2744,7 +2745,6 @@ def _convert_to_user_format(
            # Use original database data
            formatted_entities.append(
                {
-                   "id": original_entity.get("id", "unknown"),
                    "entity_name": original_entity.get("entity_name", entity_name),
                    "entity_type": original_entity.get("entity_type", "UNKNOWN"),
                    "description": original_entity.get("description", ""),

@@ -2757,7 +2757,6 @@
            # Fallback to LLM context data (for backward compatibility)
            formatted_entities.append(
                {
-                   "id": entity.get("id", "unknown"),
                    "entity_name": entity_name,
                    "entity_type": entity.get("type", "UNKNOWN"),
                    "description": entity.get("description", ""),

@@ -2783,7 +2782,6 @@
            # Use original database data
            formatted_relationships.append(
                {
-                   "id": original_relation.get("id", "unknown"),
                    "src_id": original_relation.get("src_id", entity1),
                    "tgt_id": original_relation.get("tgt_id", entity2),
                    "description": original_relation.get("description", ""),

@@ -2798,7 +2796,6 @@
            # Fallback to LLM context data (for backward compatibility)
            formatted_relationships.append(
                {
-                   "id": relation.get("id", "unknown"),
                    "src_id": entity1,
                    "tgt_id": entity2,
                    "description": relation.get("description", ""),

@@ -2812,9 +2809,9 @@
    # Convert chunks format (chunks already contain complete data)
    formatted_chunks = []
-   for i, chunk in enumerate(final_chunks):
+   for i, chunk in enumerate(chunks):
        chunk_data = {
-           "id": chunk.get("id", "unknown"),
+           "reference_id": chunk.get("reference_id", ""),
            "content": chunk.get("content", ""),
            "file_path": chunk.get("file_path", "unknown_source"),
            "chunk_id": chunk.get("chunk_id", ""),

@@ -2822,7 +2819,7 @@
        formatted_chunks.append(chunk_data)

    logger.debug(
-       f"[_convert_to_user_format] Formatted {len(formatted_chunks)}/{len(final_chunks)} chunks"
+       f"[convert_to_user_format] Formatted {len(formatted_chunks)}/{len(chunks)} chunks"
    )

    # Build basic metadata (metadata details will be added by calling functions)
@@ -2835,8 +2832,79 @@
    }

    return {
-       "entities": formatted_entities,
-       "relationships": formatted_relationships,
-       "chunks": formatted_chunks,
+       "status": "success",
+       "message": "Query processed successfully",
+       "data": {
+           "entities": formatted_entities,
+           "relationships": formatted_relationships,
+           "chunks": formatted_chunks,
+           "references": references,
+       },
        "metadata": metadata,
    }
def generate_reference_list_from_chunks(
    chunks: list[dict],
) -> tuple[list[dict], list[dict]]:
    """
    Generate reference list from chunks, prioritizing by occurrence frequency.

    This function extracts file_paths from chunks, counts their occurrences,
    sorts by frequency and first appearance order, creates reference_id mappings,
    and builds a reference_list structure.

    Args:
        chunks: List of chunk dictionaries with file_path information

    Returns:
        tuple: (reference_list, updated_chunks_with_reference_ids)
            - reference_list: List of dicts with reference_id and file_path
            - updated_chunks_with_reference_ids: Original chunks with reference_id field added
    """
    if not chunks:
        return [], []

    # 1. Extract all valid file_paths and count their occurrences
    file_path_counts = {}
    for chunk in chunks:
        file_path = chunk.get("file_path", "")
        if file_path and file_path != "unknown_source":
            file_path_counts[file_path] = file_path_counts.get(file_path, 0) + 1

    # 2. Sort file paths by frequency (descending), then by first appearance order
    # Create a list of (file_path, count, first_index) tuples
    file_path_with_indices = []
    seen_paths = set()
    for i, chunk in enumerate(chunks):
        file_path = chunk.get("file_path", "")
        if file_path and file_path != "unknown_source" and file_path not in seen_paths:
            file_path_with_indices.append((file_path, file_path_counts[file_path], i))
            seen_paths.add(file_path)

    # Sort by count (descending), then by first appearance index (ascending)
    sorted_file_paths = sorted(file_path_with_indices, key=lambda x: (-x[1], x[2]))
    unique_file_paths = [item[0] for item in sorted_file_paths]

    # 3. Create mapping from file_path to reference_id (prioritized by frequency)
    file_path_to_ref_id = {}
    for i, file_path in enumerate(unique_file_paths):
        file_path_to_ref_id[file_path] = str(i + 1)

    # 4. Add reference_id field to each chunk
    updated_chunks = []
    for chunk in chunks:
        chunk_copy = chunk.copy()
        file_path = chunk_copy.get("file_path", "")
        if file_path and file_path != "unknown_source":
            chunk_copy["reference_id"] = file_path_to_ref_id[file_path]
        else:
            chunk_copy["reference_id"] = ""
        updated_chunks.append(chunk_copy)

    # 5. Build reference_list
    reference_list = []
    for i, file_path in enumerate(unique_file_paths):
        reference_list.append({"reference_id": str(i + 1), "file_path": file_path})

    return reference_list, updated_chunks
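A small usage sketch (chunk contents invented) of how the two helpers compose: generate_reference_list_from_chunks assigns reference_ids by file frequency, and convert_to_user_format wraps everything in the status/message/data envelope.

```python
# Sketch only; chunk contents and file paths are invented.
chunks = [
    {"content": "LightRAG was created by ...", "file_path": "docs/authors.md", "chunk_id": "c1"},
    {"content": "Installation steps ...", "file_path": "README.md", "chunk_id": "c2"},
    {"content": "More about the authors ...", "file_path": "docs/authors.md", "chunk_id": "c3"},
]

reference_list, chunks_with_refs = generate_reference_list_from_chunks(chunks)
# docs/authors.md occurs twice, so it is cited first:
# reference_list == [{"reference_id": "1", "file_path": "docs/authors.md"},
#                    {"reference_id": "2", "file_path": "README.md"}]
# chunks_with_refs[0]["reference_id"] == "1", chunks_with_refs[1]["reference_id"] == "2"

result = convert_to_user_format([], [], chunks_with_refs, reference_list, "naive")
# result["status"] == "success"
# result["data"]["references"] is the reference_list above
```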

View file

@@ -2,6 +2,11 @@
"""
Test script: Demonstrates usage of aquery_data FastAPI endpoint
Query content: Who is the author of LightRAG
+
+Updated to handle the new data format where:
+- Response includes status, message, data, and metadata fields at top level
+- Actual query results (entities, relationships, chunks, references) are nested under 'data' field
+- Includes backward compatibility with legacy format
"""

import requests
@@ -80,17 +85,37 @@ def test_aquery_data_endpoint():
def print_query_results(data: Dict[str, Any]):
    """Format and print query results"""
-   entities = data.get("entities", [])
-   relationships = data.get("relationships", [])
-   chunks = data.get("chunks", [])
-   metadata = data.get("metadata", {})
+   # Check for new data format with status and message
+   status = data.get("status", "unknown")
+   message = data.get("message", "")
+
+   print(f"\n📋 Query Status: {status}")
+   if message:
+       print(f"📋 Message: {message}")
+
+   # Handle new nested data format
+   query_data = data.get("data", {})
+
+   # Fallback to old format if new format is not present
+   if not query_data and any(
+       key in data for key in ["entities", "relationships", "chunks"]
+   ):
+       print("   (Using legacy data format)")
+       query_data = data
+
+   entities = query_data.get("entities", [])
+   relationships = query_data.get("relationships", [])
+   chunks = query_data.get("chunks", [])
+   references = query_data.get("references", [])

    print("\n📊 Query result statistics:")
    print(f"   Entity count: {len(entities)}")
    print(f"   Relationship count: {len(relationships)}")
    print(f"   Text chunk count: {len(chunks)}")
+   print(f"   Reference count: {len(references)}")

-   # Print metadata
+   # Print metadata (now at top level in new format)
+   metadata = data.get("metadata", {})
    if metadata:
        print("\n🔍 Query metadata:")
        print(f"   Query mode: {metadata.get('query_mode', 'unknown')}")
@@ -118,12 +143,14 @@ def print_query_results(data: Dict[str, Any]):
            entity_type = entity.get("entity_type", "Unknown")
            description = entity.get("description", "No description")
            file_path = entity.get("file_path", "Unknown source")
+           reference_id = entity.get("reference_id", "No reference")

            print(f"   {i+1}. {entity_name} ({entity_type})")
            print(
                f"      Description: {description[:100]}{'...' if len(description) > 100 else ''}"
            )
            print(f"      Source: {file_path}")
+           print(f"      Reference ID: {reference_id}")
            print()

    # Print relationship information

@@ -135,6 +162,7 @@
            description = rel.get("description", "No description")
            keywords = rel.get("keywords", "No keywords")
            file_path = rel.get("file_path", "Unknown source")
+           reference_id = rel.get("reference_id", "No reference")

            print(f"   {i+1}. {src} → {tgt}")
            print(f"      Keywords: {keywords}")

@@ -142,6 +170,7 @@
                f"      Description: {description[:100]}{'...' if len(description) > 100 else ''}"
            )
            print(f"      Source: {file_path}")
+           print(f"      Reference ID: {reference_id}")
            print()

    # Print text chunk information
@@ -151,14 +180,26 @@ def print_query_results(data: Dict[str, Any]):
            content = chunk.get("content", "No content")
            file_path = chunk.get("file_path", "Unknown source")
            chunk_id = chunk.get("chunk_id", "Unknown ID")
+           reference_id = chunk.get("reference_id", "No reference")

            print(f"   {i+1}. Text chunk ID: {chunk_id}")
            print(f"      Source: {file_path}")
+           print(f"      Reference ID: {reference_id}")
            print(
                f"      Content: {content[:200]}{'...' if len(content) > 200 else ''}"
            )
            print()

+   # Print references information (new in updated format)
+   if references:
+       print("📚 References:")
+       for i, ref in enumerate(references):
+           reference_id = ref.get("reference_id", "Unknown ID")
+           file_path = ref.get("file_path", "Unknown source")
+           print(f"   {i+1}. Reference ID: {reference_id}")
+           print(f"      File Path: {file_path}")
+           print()
+
    print("=" * 60)