feat: simplify citations, add reference merging, and restructure API response format
Parent: 18968c6b6b
Commit: 5eb4a4b799
6 changed files with 452 additions and 187 deletions
@@ -2,7 +2,9 @@
LightRAG FastAPI Server
"""

from fastapi import FastAPI, Depends, HTTPException
from fastapi import FastAPI, Depends, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
import os
import logging
import logging.config

@@ -245,6 +247,35 @@ def create_app(args):
    app = FastAPI(**app_kwargs)

    # Add custom validation error handler for /query/data endpoint
    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(
        request: Request, exc: RequestValidationError
    ):
        # Check if this is a request to /query/data endpoint
        if request.url.path.endswith("/query/data"):
            # Extract error details
            error_details = []
            for error in exc.errors():
                field_path = " -> ".join(str(loc) for loc in error["loc"])
                error_details.append(f"{field_path}: {error['msg']}")

            error_message = "; ".join(error_details)

            # Return in the expected format for /query/data
            return JSONResponse(
                status_code=400,
                content={
                    "status": "failure",
                    "message": f"Validation error: {error_message}",
                    "data": {},
                    "metadata": {},
                },
            )
        else:
            # For other endpoints, return the default FastAPI validation error
            return JSONResponse(status_code=422, content={"detail": exc.errors()})

    def get_cors_origins():
        """Get allowed origins from global_args

        Returns a list of allowed origins, defaults to ["*"] if not set
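For illustration only, here is a client-side sketch of how the two error paths above differ. The server address, port, and the too-short query string are assumptions for the example, not part of this commit.

```python
import requests

BASE_URL = "http://localhost:9621"  # assumed address of a running LightRAG API server

# A query below the new 3-character minimum sent to /query/data is caught by the
# custom handler and comes back as a 400 with the status/message/data/metadata envelope.
resp = requests.post(f"{BASE_URL}/query/data", json={"query": "hi"})
print(resp.status_code)                              # 400
print(resp.json()["status"], resp.json()["message"])  # failure, Validation error: ...

# The same invalid body on any other endpoint keeps FastAPI's default 422 response.
resp = requests.post(f"{BASE_URL}/query", json={"query": "hi"})
print(resp.status_code)                              # 422 with a "detail" list
```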
@@ -8,7 +8,7 @@ from typing import Any, Dict, List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException
from lightrag.base import QueryParam
from ..utils_api import get_combined_auth_dependency
from lightrag.api.utils_api import get_combined_auth_dependency
from pydantic import BaseModel, Field, field_validator

from ascii_colors import trace_exception

@@ -18,7 +18,7 @@ router = APIRouter(tags=["query"])
class QueryRequest(BaseModel):
    query: str = Field(
        min_length=1,
        min_length=3,
        description="The query text",
    )

@@ -135,14 +135,10 @@ class QueryResponse(BaseModel):

class QueryDataResponse(BaseModel):
    entities: List[Dict[str, Any]] = Field(
        description="Retrieved entities from knowledge graph"
    )
    relationships: List[Dict[str, Any]] = Field(
        description="Retrieved relationships from knowledge graph"
    )
    chunks: List[Dict[str, Any]] = Field(
        description="Retrieved text chunks from documents"
    status: str = Field(description="Query execution status")
    message: str = Field(description="Status message")
    data: Dict[str, Any] = Field(
        description="Query result data containing entities, relationships, chunks, and references"
    )
    metadata: Dict[str, Any] = Field(
        description="Query metadata including mode, keywords, and processing information"

@@ -253,8 +249,9 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
            request (QueryRequest): The request object containing the query parameters.

        Returns:
            QueryDataResponse: A Pydantic model containing structured data with entities,
                               relationships, chunks, and metadata.
            QueryDataResponse: A Pydantic model containing structured data with status,
                               message, data (entities, relationships, chunks, references),
                               and metadata.

        Raises:
            HTTPException: Raised when an error occurs during the request handling process,

@@ -264,40 +261,15 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
            param = request.to_query_params(False)  # No streaming for data endpoint
            response = await rag.aquery_data(request.query, param=param)

            # The aquery_data method returns a dict with entities, relationships, chunks, and metadata
            # aquery_data returns the new format with status, message, data, and metadata
            if isinstance(response, dict):
                # Ensure all required fields exist and are lists/dicts
                entities = response.get("entities", [])
                relationships = response.get("relationships", [])
                chunks = response.get("chunks", [])
                metadata = response.get("metadata", {})

                # Validate data types
                if not isinstance(entities, list):
                    entities = []
                if not isinstance(relationships, list):
                    relationships = []
                if not isinstance(chunks, list):
                    chunks = []
                if not isinstance(metadata, dict):
                    metadata = {}

                return QueryDataResponse(
                    entities=entities,
                    relationships=relationships,
                    chunks=chunks,
                    metadata=metadata,
                )
                return QueryDataResponse(**response)
            else:
                # Fallback for unexpected response format
                # Handle unexpected response format
                return QueryDataResponse(
                    entities=[],
                    relationships=[],
                    chunks=[],
                    metadata={
                        "error": "Unexpected response format",
                        "raw_response": str(response),
                    },
                    status="failure",
                    message="Invalid response type",
                    data={},
                )
        except Exception as e:
            trace_exception(e)
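To make the restructured payload concrete, here is an illustrative dict in the new shape; the values are invented for the example. In the route above, such a dict is handed straight to `QueryDataResponse(**response)`.

```python
# Illustrative aquery_data result in the new envelope (values are examples only).
example_response = {
    "status": "success",
    "message": "Query processed successfully",
    "data": {
        "entities": [],
        "relationships": [],
        "chunks": [],
        "references": [],
    },
    "metadata": {
        "query_mode": "mix",
        "keywords": {"high_level": [], "low_level": []},
    },
}

# The endpoint can now validate and return it in one step:
# QueryDataResponse(**example_response)
```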
@@ -59,7 +59,7 @@ from lightrag.kg.shared_storage import (
    get_data_init_lock,
)

from .base import (
from lightrag.base import (
    BaseGraphStorage,
    BaseKVStorage,
    BaseVectorStorage,

@@ -72,8 +72,8 @@ from .base import (
    DeletionResult,
    OllamaServerInfos,
)
from .namespace import NameSpace
from .operate import (
from lightrag.namespace import NameSpace
from lightrag.operate import (
    chunking_by_token_size,
    extract_entities,
    merge_nodes_and_edges,

@@ -81,8 +81,8 @@ from .operate import (
    naive_query,
    _rebuild_knowledge_from_chunks,
)
from .constants import GRAPH_FIELD_SEP
from .utils import (
from lightrag.constants import GRAPH_FIELD_SEP
from lightrag.utils import (
    Tokenizer,
    TiktokenTokenizer,
    EmbeddingFunc,

@@ -94,9 +94,10 @@ from .utils import (
    sanitize_text_for_encoding,
    check_storage_env_vars,
    generate_track_id,
    convert_to_user_format,
    logger,
)
from .types import KnowledgeGraph
from lightrag.types import KnowledgeGraph
from dotenv import load_dotenv

# use the .env that is inside the current folder
@@ -2127,11 +2128,104 @@ class LightRAG:
        returning the final processed entities, relationships, and chunks data that would be sent to LLM.

        Args:
            query: Query text.
            param: Query parameters (same as aquery).
            query: Query text for retrieval.
            param: Query parameters controlling retrieval behavior (same as aquery).

        Returns:
            dict[str, Any]: Structured data result with entities, relationships, chunks, and metadata
            dict[str, Any]: Structured data result in the following format:

            **Success Response:**
            ```python
            {
                "status": "success",
                "message": "Query executed successfully",
                "data": {
                    "entities": [
                        {
                            "entity_name": str,     # Entity identifier
                            "entity_type": str,     # Entity category/type
                            "description": str,     # Entity description
                            "source_id": str,       # Source chunk references
                            "file_path": str,       # Origin file path
                            "created_at": str,      # Creation timestamp
                            "reference_id": str     # Reference identifier for citations
                        }
                    ],
                    "relationships": [
                        {
                            "src_id": str,          # Source entity name
                            "tgt_id": str,          # Target entity name
                            "description": str,     # Relationship description
                            "keywords": str,        # Relationship keywords
                            "weight": float,        # Relationship strength
                            "source_id": str,       # Source chunk references
                            "file_path": str,       # Origin file path
                            "created_at": str,      # Creation timestamp
                            "reference_id": str     # Reference identifier for citations
                        }
                    ],
                    "chunks": [
                        {
                            "content": str,         # Document chunk content
                            "file_path": str,       # Origin file path
                            "chunk_id": str,        # Unique chunk identifier
                            "reference_id": str     # Reference identifier for citations
                        }
                    ],
                    "references": [
                        {
                            "reference_id": str,    # Reference identifier
                            "file_path": str        # Corresponding file path
                        }
                    ]
                },
                "metadata": {
                    "query_mode": str,              # Query mode used ("local", "global", "hybrid", "mix", "naive", "bypass")
                    "keywords": {
                        "high_level": List[str],    # High-level keywords extracted
                        "low_level": List[str]      # Low-level keywords extracted
                    },
                    "processing_info": {
                        "total_entities_found": int,        # Total entities before truncation
                        "total_relations_found": int,       # Total relations before truncation
                        "entities_after_truncation": int,   # Entities after token truncation
                        "relations_after_truncation": int,  # Relations after token truncation
                        "merged_chunks_count": int,         # Chunks before final processing
                        "final_chunks_count": int           # Final chunks in result
                    }
                }
            }
            ```

            **Query Mode Differences:**
            - **local**: Focuses on entities and their related chunks based on low-level keywords
            - **global**: Focuses on relationships and their connected entities based on high-level keywords
            - **hybrid**: Combines local and global results using round-robin merging
            - **mix**: Includes knowledge graph data plus vector-retrieved document chunks
            - **naive**: Only vector-retrieved chunks; entities and relationships arrays are empty
            - **bypass**: All data arrays are empty, used for direct LLM queries

            **processing_info is optional and may not be present in all responses, especially when the query result is empty.**

            **Failure Response:**
            ```python
            {
                "status": "failure",
                "message": str,   # Error description
                "data": {}        # Empty data object
            }
            ```

            **Common Failure Cases:**
            - Empty query string
            - Both high-level and low-level keywords are empty
            - Query returns an empty dataset
            - Missing tokenizer or other system configuration errors

        Note:
            The function adapts to the new data format from convert_to_user_format, where
            the actual data is nested under the 'data' field, with 'status' and 'message'
            fields at the top level.
        """
        global_config = asdict(self)

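A minimal consumer-side sketch of the envelope documented above; the function name and the `mix` mode are illustrative, and `rag` is assumed to be an already initialized LightRAG instance.

```python
from lightrag.base import QueryParam

async def show_references(rag, question: str) -> None:
    """Call aquery_data and walk the new status/message/data/metadata envelope."""
    result = await rag.aquery_data(question, param=QueryParam(mode="mix"))

    if result.get("status") != "success":
        print(f"Query failed: {result.get('message')}")
        return

    data = result["data"]  # entities / relationships / chunks / references now live here
    for ref in data.get("references", []):
        print(f"[{ref['reference_id']}] {ref['file_path']}")
    print("mode:", result["metadata"].get("query_mode"))
```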
@@ -2163,23 +2257,30 @@ class LightRAG:
            )
        elif param.mode == "bypass":
            logger.debug("[aquery_data] Using bypass mode")
            # bypass mode returns empty data
            final_data = {
                "entities": [],
                "relationships": [],
                "chunks": [],
                "metadata": {
                    "query_mode": "bypass",
                    "keywords": {"high_level": [], "low_level": []},
                },
            }
            # bypass mode returns empty data using convert_to_user_format
            final_data = convert_to_user_format(
                [],  # no entities
                [],  # no relationships
                [],  # no chunks
                [],  # no references
                "bypass",
            )
        else:
            raise ValueError(f"Unknown mode {param.mode}")

        # Log final result counts
        entities_count = len(final_data.get("entities", []))
        relationships_count = len(final_data.get("relationships", []))
        chunks_count = len(final_data.get("chunks", []))
        # Log final result counts - adapt to new data format from convert_to_user_format
        if isinstance(final_data, dict) and "data" in final_data:
            # New format: data is nested under 'data' field
            data_section = final_data["data"]
            entities_count = len(data_section.get("entities", []))
            relationships_count = len(data_section.get("relationships", []))
            chunks_count = len(data_section.get("chunks", []))
        else:
            # Fallback for other formats
            entities_count = len(final_data.get("entities", []))
            relationships_count = len(final_data.get("relationships", []))
            chunks_count = len(final_data.get("chunks", []))

        logger.debug(
            f"[aquery_data] Final result: {entities_count} entities, {relationships_count} relationships, {chunks_count} chunks"
        )

@@ -2676,7 +2777,7 @@ class LightRAG:
        Returns:
            DeletionResult: An object containing the outcome of the deletion process.
        """
        from .utils_graph import adelete_by_entity
        from lightrag.utils_graph import adelete_by_entity

        return await adelete_by_entity(
            self.chunk_entity_relation_graph,

@@ -2709,7 +2810,7 @@ class LightRAG:
        Returns:
            DeletionResult: An object containing the outcome of the deletion process.
        """
        from .utils_graph import adelete_by_relation
        from lightrag.utils_graph import adelete_by_relation

        return await adelete_by_relation(
            self.chunk_entity_relation_graph,

@@ -2760,7 +2861,7 @@ class LightRAG:
        self, entity_name: str, include_vector_data: bool = False
    ) -> dict[str, str | None | dict[str, str]]:
        """Get detailed information of an entity"""
        from .utils_graph import get_entity_info
        from lightrag.utils_graph import get_entity_info

        return await get_entity_info(
            self.chunk_entity_relation_graph,

@@ -2773,7 +2874,7 @@ class LightRAG:
        self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
    ) -> dict[str, str | None | dict[str, str]]:
        """Get detailed information of a relationship"""
        from .utils_graph import get_relation_info
        from lightrag.utils_graph import get_relation_info

        return await get_relation_info(
            self.chunk_entity_relation_graph,

@@ -2798,7 +2899,7 @@ class LightRAG:
        Returns:
            Dictionary containing updated entity information
        """
        from .utils_graph import aedit_entity
        from lightrag.utils_graph import aedit_entity

        return await aedit_entity(
            self.chunk_entity_relation_graph,

@@ -2832,7 +2933,7 @@ class LightRAG:
        Returns:
            Dictionary containing updated relation information
        """
        from .utils_graph import aedit_relation
        from lightrag.utils_graph import aedit_relation

        return await aedit_relation(
            self.chunk_entity_relation_graph,

@@ -2865,7 +2966,7 @@ class LightRAG:
        Returns:
            Dictionary containing created entity information
        """
        from .utils_graph import acreate_entity
        from lightrag.utils_graph import acreate_entity

        return await acreate_entity(
            self.chunk_entity_relation_graph,

@@ -2896,7 +2997,7 @@ class LightRAG:
        Returns:
            Dictionary containing created relation information
        """
        from .utils_graph import acreate_relation
        from lightrag.utils_graph import acreate_relation

        return await acreate_relation(
            self.chunk_entity_relation_graph,

@@ -2942,7 +3043,7 @@ class LightRAG:
        Returns:
            Dictionary containing the merged entity information
        """
        from .utils_graph import amerge_entities
        from lightrag.utils_graph import amerge_entities

        return await amerge_entities(
            self.chunk_entity_relation_graph,

@@ -2986,7 +3087,7 @@ class LightRAG:
                - table: Print formatted tables to console
            include_vector_data: Whether to include data from the vector database.
        """
        from .utils import aexport_data as utils_aexport_data
        from lightrag.utils import aexport_data as utils_aexport_data

        await utils_aexport_data(
            self.chunk_entity_relation_graph,
@@ -30,7 +30,8 @@ from .utils import (
    safe_vdb_operation_with_exception,
    create_prefixed_exception,
    fix_tuple_delimiter_corruption,
    _convert_to_user_format,
    convert_to_user_format,
    generate_reference_list_from_chunks,
)
from .base import (
    BaseGraphStorage,

@@ -2279,6 +2280,12 @@ async def kg_query(
    return_raw_data: bool = False,
) -> str | AsyncIterator[str] | dict[str, Any]:
    if not query:
        if return_raw_data:
            return {
                "status": "failure",
                "message": "Query string is empty.",
                "data": {},
            }
        return PROMPTS["fail_response"]

    if query_param.model_func:

@@ -2306,10 +2313,14 @@ async def kg_query(
    cached_result = await handle_cache(
        hashing_kv, args_hash, query, query_param.mode, cache_type="query"
    )
    if cached_result is not None:
    if (
        cached_result is not None
        and not return_raw_data
        and not query_param.only_need_context
        and not query_param.only_need_prompt
    ):
        cached_response, _ = cached_result  # Extract content, ignore timestamp
        if not query_param.only_need_context and not query_param.only_need_prompt:
            return cached_response
        return cached_response

    hl_keywords, ll_keywords = await get_keywords_from_query(
        query, query_param, global_config, hashing_kv

@@ -2328,6 +2339,12 @@ async def kg_query(
            logger.warning(f"Forced low_level_keywords to origin query: {query}")
            ll_keywords = [query]
        else:
            if return_raw_data:
                return {
                    "status": "failure",
                    "message": "Both high_level_keywords and low_level_keywords are empty",
                    "data": {},
                }
            return PROMPTS["fail_response"]

    ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""

@@ -2356,9 +2373,16 @@ async def kg_query(
            )
            return raw_data
        else:
            raise RuntimeError(
                "Failed to build query context for raw data. Expected a tuple, but got a different type."
            )
        if not context_result:
            return {
                "status": "failure",
                "message": "Query return empty data set.",
                "data": {},
            }
        else:
            raise ValueError(
                "Fail to build raw data query result. Invalid return from _build_query_context"
            )

    # Build context (normal flow)
    context = await _build_query_context(
@@ -2870,7 +2894,6 @@ async def _apply_token_truncation(

        entities_context.append(
            {
                "id": f"EN{i + 1}",
                "entity": entity_name,
                "type": entity.get("entity_type", "UNKNOWN"),
                "description": entity.get("description", "UNKNOWN"),

@@ -2898,7 +2921,6 @@ async def _apply_token_truncation(

        relations_context.append(
            {
                "id": f"RE{i + 1}",
                "entity1": entity1,
                "entity2": entity2,
                "description": relation.get("description", "UNKNOWN"),

@@ -2956,26 +2978,19 @@ async def _apply_token_truncation(
    filtered_entities = []
    filtered_entity_id_to_original = {}
    if entities_context:
        entity_name_to_id = {e["entity"]: e["id"] for e in entities_context}
        final_entity_names = set(entity_name_to_id.keys())
        final_entity_names = {e["entity"] for e in entities_context}
        seen_nodes = set()
        for entity in final_entities:
            name = entity.get("entity_name")
            if name in final_entity_names and name not in seen_nodes:
                entity_with_id = entity.copy()
                entity_with_id["id"] = entity_name_to_id.get(name)

                filtered_entities.append(entity_with_id)
                filtered_entity_id_to_original[name] = entity_with_id
                filtered_entities.append(entity)
                filtered_entity_id_to_original[name] = entity
                seen_nodes.add(name)

    filtered_relations = []
    filtered_relation_id_to_original = {}
    if relations_context:
        relation_pair_to_id = {
            (r["entity1"], r["entity2"]): r["id"] for r in relations_context
        }
        final_relation_pairs = set(relation_pair_to_id.keys())
        final_relation_pairs = {(r["entity1"], r["entity2"]) for r in relations_context}
        seen_edges = set()
        for relation in final_relations:
            src, tgt = relation.get("src_id"), relation.get("tgt_id")

@@ -2984,11 +2999,8 @@ async def _apply_token_truncation(

            pair = (src, tgt)
            if pair in final_relation_pairs and pair not in seen_edges:
                relation_with_id = relation.copy()
                relation_with_id["id"] = relation_pair_to_id.get(pair)

                filtered_relations.append(relation_with_id)
                filtered_relation_id_to_original[pair] = relation_with_id
                filtered_relations.append(relation)
                filtered_relation_id_to_original[pair] = relation
                seen_edges.add(pair)

    return {

@@ -3121,47 +3133,23 @@ async def _build_llm_context(
    """
    tokenizer = global_config.get("tokenizer")
    if not tokenizer:
        logger.warning("No tokenizer found, building context without token limits")
        logger.error("Missing tokenizer, cannot build LLM context")

        # Build basic context without token processing
        entities_str = "\n".join(
            json.dumps(entity, ensure_ascii=False) for entity in entities_context
        )
        relations_str = "\n".join(
            json.dumps(relation, ensure_ascii=False) for relation in relations_context
        )

        text_units_context = []
        for i, chunk in enumerate(merged_chunks):
            text_units_context.append(
                {
                    "id": i + 1,
                    "content": chunk["content"],
                    "file_path": chunk.get("file_path", "unknown_source"),
                }
        if return_raw_data:
            # Return empty raw data structure when no entities/relations
            empty_raw_data = convert_to_user_format(
                [],
                [],
                [],
                [],
                query_param.mode,
            )

        text_units_str = json.dumps(text_units_context, ensure_ascii=False)

        return f"""-----Entities(KG)-----

```json
{entities_str}
```

-----Relationships(KG)-----

```json
{relations_str}
```

-----Document Chunks(DC)-----

```json
{text_units_str}
```

"""
            empty_raw_data["status"] = "failure"
            empty_raw_data["message"] = "Missing tokenizer, cannot build LLM context."
            return None, empty_raw_data
        else:
            logger.error("Tokenizer not found in global configuration.")
            return None

    # Get token limits
    max_total_tokens = getattr(

@@ -3198,9 +3186,12 @@ async def _build_llm_context(
-----Document Chunks(DC)-----

```json
[]
```

-----Refrence Document List-----

The reference documents list in Document Chunks(DC) is as follows (reference_id in square brackets):

"""
    kg_context = kg_context_template.format(
        entities_str=entities_str, relations_str=relations_str

@@ -3252,13 +3243,18 @@ async def _build_llm_context(
        chunk_token_limit=available_chunk_tokens,  # Pass dynamic limit
    )

    # Generate reference list from truncated chunks using the new common function
    reference_list, truncated_chunks = generate_reference_list_from_chunks(
        truncated_chunks
    )

    # Rebuild text_units_context with truncated chunks
    # The actual tokens may be slightly less than available_chunk_tokens due to deduplication logic
    for i, chunk in enumerate(truncated_chunks):
        text_units_context.append(
            {
                "id": chunk["id"],
                "reference_id": chunk["reference_id"],
                "content": chunk["content"],
                "file_path": chunk.get("file_path", "unknown_source"),
            }
        )

@@ -3274,12 +3270,15 @@ async def _build_llm_context(
    if not entities_context and not relations_context:
        if return_raw_data:
            # Return empty raw data structure when no entities/relations
            empty_raw_data = _convert_to_user_format(
            empty_raw_data = convert_to_user_format(
                [],
                [],
                [],
                [],
                query_param.mode,
            )
            empty_raw_data["status"] = "failure"
            empty_raw_data["message"] = "Query returned empty dataset."
            return None, empty_raw_data
        else:
            return None

@@ -3311,6 +3310,11 @@ async def _build_llm_context(
    text_units_str = "\n".join(
        json.dumps(text_unit, ensure_ascii=False) for text_unit in text_units_context
    )
    reference_list_str = "\n\n".join(
        f"[{ref['reference_id']}] {ref['file_path']}"
        for ref in reference_list
        if ref["reference_id"]
    )

    result = f"""-----Entities(KG)-----

@@ -3330,6 +3334,12 @@ async def _build_llm_context(
{text_units_str}
```

-----Refrence Document List-----

Document Chunks (DC) reference documents : (Each entry begins with [reference_id])

{reference_list_str}

"""

    # If final data is requested, return both context and complete data structure

@@ -3337,10 +3347,11 @@ async def _build_llm_context(
        logger.debug(
            f"[_build_llm_context] Converting to user format: {len(entities_context)} entities, {len(relations_context)} relations, {len(truncated_chunks)} chunks"
        )
        final_data = _convert_to_user_format(
        final_data = convert_to_user_format(
            entities_context,
            relations_context,
            truncated_chunks,
            reference_list,
            query_param.mode,
            entity_id_to_original,
            relation_id_to_original,

@@ -3365,7 +3376,7 @@ async def _build_query_context(
    query_param: QueryParam,
    chunks_vdb: BaseVectorStorage = None,
    return_raw_data: bool = False,
) -> str | tuple[str, dict[str, Any]]:
) -> str | None | tuple[str, dict[str, Any]]:
    """
    Main query context building function using the new 4-stage architecture:
    1. Search -> 2. Truncate -> 3. Merge chunks -> 4. Build LLM context

@@ -3448,7 +3459,11 @@ async def _build_query_context(
        hl_keywords_list = hl_keywords.split(", ") if hl_keywords else []
        ll_keywords_list = ll_keywords.split(", ") if ll_keywords else []

        # Add complete metadata to raw_data
        # Add complete metadata to raw_data (preserve existing metadata including query_mode)
        if "metadata" not in raw_data:
            raw_data["metadata"] = {}

        # Update keywords while preserving existing metadata
        raw_data["metadata"]["keywords"] = {
            "high_level": hl_keywords_list,
            "low_level": ll_keywords_list,

@@ -4092,6 +4107,18 @@ async def naive_query(
    system_prompt: str | None = None,
    return_raw_data: bool = False,
) -> str | AsyncIterator[str] | dict[str, Any]:
    if not query:
        if return_raw_data:
            # Return empty raw data structure when query is empty
            empty_raw_data = {
                "status": "failure",
                "message": "Query string is empty.",
                "data": {},
            }
            return empty_raw_data
        else:
            return PROMPTS["fail_response"]

    if query_param.model_func:
        use_model_func = query_param.model_func
    else:

@@ -4123,26 +4150,35 @@ async def naive_query(
        return cached_response

    tokenizer: Tokenizer = global_config["tokenizer"]
    if not tokenizer:
        if return_raw_data:
            # Return empty raw data structure when tokenizer is missing
            empty_raw_data = {
                "status": "failure",
                "message": "Tokenizer not found in global configuration.",
                "data": {},
            }
            return empty_raw_data
        else:
            logger.error("Tokenizer not found in global configuration.")
            return PROMPTS["fail_response"]

    chunks = await _get_vector_context(query, chunks_vdb, query_param, None)

    if chunks is None or len(chunks) == 0:
        # Build empty raw data for consistency
        empty_raw_data = {
            "entities": [],  # naive mode has no entities
            "relationships": [],  # naive mode has no relationships
            "chunks": [],
            "metadata": {
                "query_mode": "naive",
                "keywords": {"high_level": [], "low_level": []},
            },
        }

        # If only raw data is requested, return it directly
        if return_raw_data:
            empty_raw_data = convert_to_user_format(
                [],  # naive mode has no entities
                [],  # naive mode has no relationships
                [],  # no chunks
                [],  # no references
                "naive",
            )
            empty_raw_data["message"] = "No relevant document chunks found."
            return empty_raw_data

        return PROMPTS["fail_response"]
        else:
            return PROMPTS["fail_response"]

    # Calculate dynamic token limit for chunks
    # Get token limits from query_param (with fallback to global_config)

@@ -4197,44 +4233,56 @@ async def naive_query(
        chunk_token_limit=available_chunk_tokens,  # Pass dynamic limit
    )

    logger.info(f"Final context: {len(processed_chunks)} chunks")
    # Generate reference list from processed chunks using the new common function
    reference_list, processed_chunks_with_ref_ids = generate_reference_list_from_chunks(
        processed_chunks
    )

    logger.info(f"Final context: {len(processed_chunks_with_ref_ids)} chunks")

    # If only raw data is requested, return it directly
    if return_raw_data:
        # Build raw data structure for naive mode using processed chunks
        raw_data = _convert_to_user_format(
        # Build raw data structure for naive mode using processed chunks with reference IDs
        raw_data = convert_to_user_format(
            [],  # naive mode has no entities
            [],  # naive mode has no relationships
            processed_chunks,
            processed_chunks_with_ref_ids,
            reference_list,
            "naive",
        )

        # Add complete metadata for naive mode
        if "metadata" not in raw_data:
            raw_data["metadata"] = {}
        raw_data["metadata"]["keywords"] = {
            "high_level": [],  # naive mode has no keyword extraction
            "low_level": [],  # naive mode has no keyword extraction
        }
        raw_data["metadata"]["processing_info"] = {
            "total_chunks_found": len(chunks),
            "final_chunks_count": len(processed_chunks),
            "final_chunks_count": len(processed_chunks_with_ref_ids),
        }

        return raw_data

    # Build text_units_context from processed chunks
    # Build text_units_context from processed chunks with reference IDs
    text_units_context = []
    for i, chunk in enumerate(processed_chunks):
    for i, chunk in enumerate(processed_chunks_with_ref_ids):
        text_units_context.append(
            {
                "id": chunk["id"],
                "reference_id": chunk["reference_id"],
                "content": chunk["content"],
                "file_path": chunk.get("file_path", "unknown_source"),
            }
        )

    text_units_str = "\n".join(
        json.dumps(text_unit, ensure_ascii=False) for text_unit in text_units_context
    )
    reference_list_str = "\n\n".join(
        f"[{ref['reference_id']}] {ref['file_path']}"
        for ref in reference_list
        if ref["reference_id"]
    )

    if query_param.only_need_context and not query_param.only_need_prompt:
        return f"""

@@ -4244,6 +4292,10 @@ async def naive_query(
{text_units_str}
```

-----Refrence Document List-----

{reference_list_str}

"""
    user_query = (
        "\n\n".join([query, query_param.user_prompt])
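As a small, self-contained illustration of the reference block that the joins above append to the LLM context, here is the output for a toy reference list; the file names are invented for the example.

```python
reference_list = [
    {"reference_id": "1", "file_path": "architecture_overview.md"},
    {"reference_id": "2", "file_path": "release_notes.txt"},
]

# Same join expression as used in _build_llm_context / naive_query above.
reference_list_str = "\n\n".join(
    f"[{ref['reference_id']}] {ref['file_path']}"
    for ref in reference_list
    if ref["reference_id"]
)
print(reference_list_str)
# [1] architecture_overview.md
#
# [2] release_notes.txt
```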
@@ -2720,10 +2720,11 @@ def create_prefixed_exception(original_exception: Exception, prefix: str) -> Exc
    )


def _convert_to_user_format(
def convert_to_user_format(
    entities_context: list[dict],
    relations_context: list[dict],
    final_chunks: list[dict],
    chunks: list[dict],
    references: list[dict],
    query_mode: str,
    entity_id_to_original: dict = None,
    relation_id_to_original: dict = None,

@@ -2744,7 +2745,6 @@ def _convert_to_user_format(
            # Use original database data
            formatted_entities.append(
                {
                    "id": original_entity.get("id", "unknown"),
                    "entity_name": original_entity.get("entity_name", entity_name),
                    "entity_type": original_entity.get("entity_type", "UNKNOWN"),
                    "description": original_entity.get("description", ""),

@@ -2757,7 +2757,6 @@ def _convert_to_user_format(
            # Fallback to LLM context data (for backward compatibility)
            formatted_entities.append(
                {
                    "id": entity.get("id", "unknown"),
                    "entity_name": entity_name,
                    "entity_type": entity.get("type", "UNKNOWN"),
                    "description": entity.get("description", ""),

@@ -2783,7 +2782,6 @@ def _convert_to_user_format(
            # Use original database data
            formatted_relationships.append(
                {
                    "id": original_relation.get("id", "unknown"),
                    "src_id": original_relation.get("src_id", entity1),
                    "tgt_id": original_relation.get("tgt_id", entity2),
                    "description": original_relation.get("description", ""),

@@ -2798,7 +2796,6 @@ def _convert_to_user_format(
            # Fallback to LLM context data (for backward compatibility)
            formatted_relationships.append(
                {
                    "id": relation.get("id", "unknown"),
                    "src_id": entity1,
                    "tgt_id": entity2,
                    "description": relation.get("description", ""),

@@ -2812,9 +2809,9 @@ def _convert_to_user_format(

    # Convert chunks format (chunks already contain complete data)
    formatted_chunks = []
    for i, chunk in enumerate(final_chunks):
    for i, chunk in enumerate(chunks):
        chunk_data = {
            "id": chunk.get("id", "unknown"),
            "reference_id": chunk.get("reference_id", ""),
            "content": chunk.get("content", ""),
            "file_path": chunk.get("file_path", "unknown_source"),
            "chunk_id": chunk.get("chunk_id", ""),

@@ -2822,7 +2819,7 @@ def _convert_to_user_format(
        formatted_chunks.append(chunk_data)

    logger.debug(
        f"[_convert_to_user_format] Formatted {len(formatted_chunks)}/{len(final_chunks)} chunks"
        f"[convert_to_user_format] Formatted {len(formatted_chunks)}/{len(chunks)} chunks"
    )

    # Build basic metadata (metadata details will be added by calling functions)

@@ -2835,8 +2832,79 @@ def _convert_to_user_format(
    }

    return {
        "entities": formatted_entities,
        "relationships": formatted_relationships,
        "chunks": formatted_chunks,
        "status": "success",
        "message": "Query processed successfully",
        "data": {
            "entities": formatted_entities,
            "relationships": formatted_relationships,
            "chunks": formatted_chunks,
            "references": references,
        },
        "metadata": metadata,
    }

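For orientation, a sketch of the envelope the renamed helper returns when called with empty inputs, mirroring the bypass path in LightRAG.aquery_data; the metadata contents shown in the comment are abbreviated.

```python
from lightrag.utils import convert_to_user_format

empty = convert_to_user_format([], [], [], [], "bypass")
# Expected shape (abbreviated):
# {
#     "status": "success",
#     "message": "Query processed successfully",
#     "data": {"entities": [], "relationships": [], "chunks": [], "references": []},
#     "metadata": {"query_mode": "bypass", ...},
# }
print(empty["status"], empty["data"]["references"])
```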
def generate_reference_list_from_chunks(
    chunks: list[dict],
) -> tuple[list[dict], list[dict]]:
    """
    Generate reference list from chunks, prioritizing by occurrence frequency.

    This function extracts file_paths from chunks, counts their occurrences,
    sorts by frequency and first appearance order, creates reference_id mappings,
    and builds a reference_list structure.

    Args:
        chunks: List of chunk dictionaries with file_path information

    Returns:
        tuple: (reference_list, updated_chunks_with_reference_ids)
            - reference_list: List of dicts with reference_id and file_path
            - updated_chunks_with_reference_ids: Original chunks with reference_id field added
    """
    if not chunks:
        return [], []

    # 1. Extract all valid file_paths and count their occurrences
    file_path_counts = {}
    for chunk in chunks:
        file_path = chunk.get("file_path", "")
        if file_path and file_path != "unknown_source":
            file_path_counts[file_path] = file_path_counts.get(file_path, 0) + 1

    # 2. Sort file paths by frequency (descending), then by first appearance order
    # Create a list of (file_path, count, first_index) tuples
    file_path_with_indices = []
    seen_paths = set()
    for i, chunk in enumerate(chunks):
        file_path = chunk.get("file_path", "")
        if file_path and file_path != "unknown_source" and file_path not in seen_paths:
            file_path_with_indices.append((file_path, file_path_counts[file_path], i))
            seen_paths.add(file_path)

    # Sort by count (descending), then by first appearance index (ascending)
    sorted_file_paths = sorted(file_path_with_indices, key=lambda x: (-x[1], x[2]))
    unique_file_paths = [item[0] for item in sorted_file_paths]

    # 3. Create mapping from file_path to reference_id (prioritized by frequency)
    file_path_to_ref_id = {}
    for i, file_path in enumerate(unique_file_paths):
        file_path_to_ref_id[file_path] = str(i + 1)

    # 4. Add reference_id field to each chunk
    updated_chunks = []
    for chunk in chunks:
        chunk_copy = chunk.copy()
        file_path = chunk_copy.get("file_path", "")
        if file_path and file_path != "unknown_source":
            chunk_copy["reference_id"] = file_path_to_ref_id[file_path]
        else:
            chunk_copy["reference_id"] = ""
        updated_chunks.append(chunk_copy)

    # 5. Build reference_list
    reference_list = []
    for i, file_path in enumerate(unique_file_paths):
        reference_list.append({"reference_id": str(i + 1), "file_path": file_path})

    return reference_list, updated_chunks
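A quick usage sketch of the new helper with made-up chunks; it shows that the most frequently cited file_path receives reference_id "1" and that chunks without a usable file_path get an empty reference_id.

```python
from lightrag.utils import generate_reference_list_from_chunks

chunks = [
    {"content": "a", "file_path": "guide.md"},
    {"content": "b", "file_path": "readme.md"},
    {"content": "c", "file_path": "guide.md"},
    {"content": "d", "file_path": "unknown_source"},
]

reference_list, chunks_with_ids = generate_reference_list_from_chunks(chunks)
print(reference_list)
# [{'reference_id': '1', 'file_path': 'guide.md'},
#  {'reference_id': '2', 'file_path': 'readme.md'}]
print([c["reference_id"] for c in chunks_with_ids])
# ['1', '2', '1', '']
```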
@@ -2,6 +2,11 @@
"""
Test script: Demonstrates usage of aquery_data FastAPI endpoint
Query content: Who is the author of LightRAG

Updated to handle the new data format where:
- Response includes status, message, data, and metadata fields at top level
- Actual query results (entities, relationships, chunks, references) are nested under 'data' field
- Includes backward compatibility with legacy format
"""

import requests

@@ -80,17 +85,37 @@ def test_aquery_data_endpoint():
def print_query_results(data: Dict[str, Any]):
    """Format and print query results"""

    entities = data.get("entities", [])
    relationships = data.get("relationships", [])
    chunks = data.get("chunks", [])
    metadata = data.get("metadata", {})
    # Check for new data format with status and message
    status = data.get("status", "unknown")
    message = data.get("message", "")

    print(f"\n📋 Query Status: {status}")
    if message:
        print(f"📋 Message: {message}")

    # Handle new nested data format
    query_data = data.get("data", {})

    # Fallback to old format if new format is not present
    if not query_data and any(
        key in data for key in ["entities", "relationships", "chunks"]
    ):
        print(" (Using legacy data format)")
        query_data = data

    entities = query_data.get("entities", [])
    relationships = query_data.get("relationships", [])
    chunks = query_data.get("chunks", [])
    references = query_data.get("references", [])

    print("\n📊 Query result statistics:")
    print(f" Entity count: {len(entities)}")
    print(f" Relationship count: {len(relationships)}")
    print(f" Text chunk count: {len(chunks)}")
    print(f" Reference count: {len(references)}")

    # Print metadata
    # Print metadata (now at top level in new format)
    metadata = data.get("metadata", {})
    if metadata:
        print("\n🔍 Query metadata:")
        print(f" Query mode: {metadata.get('query_mode', 'unknown')}")

@@ -118,12 +143,14 @@ def print_query_results(data: Dict[str, Any]):
        entity_type = entity.get("entity_type", "Unknown")
        description = entity.get("description", "No description")
        file_path = entity.get("file_path", "Unknown source")
        reference_id = entity.get("reference_id", "No reference")

        print(f" {i+1}. {entity_name} ({entity_type})")
        print(
            f" Description: {description[:100]}{'...' if len(description) > 100 else ''}"
        )
        print(f" Source: {file_path}")
        print(f" Reference ID: {reference_id}")
        print()

    # Print relationship information

@@ -135,6 +162,7 @@ def print_query_results(data: Dict[str, Any]):
        description = rel.get("description", "No description")
        keywords = rel.get("keywords", "No keywords")
        file_path = rel.get("file_path", "Unknown source")
        reference_id = rel.get("reference_id", "No reference")

        print(f" {i+1}. {src} → {tgt}")
        print(f" Keywords: {keywords}")

@@ -142,6 +170,7 @@ def print_query_results(data: Dict[str, Any]):
            f" Description: {description[:100]}{'...' if len(description) > 100 else ''}"
        )
        print(f" Source: {file_path}")
        print(f" Reference ID: {reference_id}")
        print()

    # Print text chunk information

@@ -151,14 +180,26 @@ def print_query_results(data: Dict[str, Any]):
        content = chunk.get("content", "No content")
        file_path = chunk.get("file_path", "Unknown source")
        chunk_id = chunk.get("chunk_id", "Unknown ID")
        reference_id = chunk.get("reference_id", "No reference")

        print(f" {i+1}. Text chunk ID: {chunk_id}")
        print(f" Source: {file_path}")
        print(f" Reference ID: {reference_id}")
        print(
            f" Content: {content[:200]}{'...' if len(content) > 200 else ''}"
        )
        print()

    # Print references information (new in updated format)
    if references:
        print("📚 References:")
        for i, ref in enumerate(references):
            reference_id = ref.get("reference_id", "Unknown ID")
            file_path = ref.get("file_path", "Unknown source")
            print(f" {i+1}. Reference ID: {reference_id}")
            print(f" File Path: {file_path}")
            print()

    print("=" * 60)