Replace the search API with aquery_data for consistent raw-data retrieval, mirroring aquery's results

• Reuse the existing query logic paths and remove the kg_search function entirely
• Update kg_query/naive_query to optionally return raw data
yangdx 2025-09-13 15:30:29 +08:00
parent c2d064b580
commit 0ffb5d5f2d
3 changed files with 319 additions and 246 deletions
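A minimal usage sketch of the new API (assumes an already-initialized LightRAG instance; storage and LLM setup are omitted):

```python
from lightrag import LightRAG, QueryParam

async def inspect_retrieval(rag: LightRAG) -> None:
    # Retrieve the structured data that aquery would send to the LLM,
    # without performing any LLM generation.
    result = await rag.aquery_data(
        "How does the caching layer work?",
        param=QueryParam(mode="hybrid"),
    )
    print(len(result["entities"]), "entities")
    print(len(result["relationships"]), "relationships")
    print(len(result["chunks"]), "chunks")
    print(result["metadata"]["query_mode"])  # "hybrid"
```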

lightrag/lightrag.py

@@ -78,7 +78,6 @@ from .operate import (
    extract_entities,
    merge_nodes_and_edges,
    kg_query,
-    kg_search,
    naive_query,
    _rebuild_knowledge_from_chunks,
)
@@ -2116,54 +2115,66 @@ class LightRAG:
await self._query_done()
return response
-    def search(
async def aquery_data(
self,
query: str,
param: QueryParam = QueryParam(),
) -> dict[str, Any]:
"""
-        Synchronous search API: returns structured retrieval results without LLM generation.
        Asynchronous data retrieval API: returns structured retrieval results without LLM generation.

        This function reuses the same logic as aquery but stops before LLM generation,
        returning the final processed entities, relationships, and chunk data that would otherwise be sent to the LLM.
Args:
query: Query text.
-            param: Query parameters (reuse the same QueryParam as query/aquery).
param: Query parameters (same as aquery).
Returns:
dict[str, Any]: {"entities": [...], "relationships": [...], "chunks": [...], "metadata": {...}}
"""
-        loop = always_get_an_event_loop()
-        return loop.run_until_complete(self.asearch(query, param))
-    async def asearch(
-        self,
-        query: str,
-        param: QueryParam = QueryParam(),
-    ) -> dict[str, Any]:
-        """
-        Asynchronous search API: calls kg_search and returns retrieval-only results
-        (entities, relationships, and merged chunks).
-        Args:
-            query: Query text.
-            param: Query parameters (reuse the same QueryParam as query/aquery).
-        Returns:
-            dict[str, Any]: Structured search result
dict[str, Any]: Structured data result with entities, relationships, chunks, and metadata
"""
global_config = asdict(self)
-        response = await kg_search(
-            query.strip(),
-            self.chunk_entity_relation_graph,
-            self.entities_vdb,
-            self.relationships_vdb,
-            self.text_chunks,
-            param,
-            global_config,
-            hashing_kv=self.llm_response_cache,
-            chunks_vdb=self.chunks_vdb,
-        )
if param.mode in ["local", "global", "hybrid", "mix"]:
final_data = await kg_query(
query.strip(),
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.text_chunks,
param,
global_config,
hashing_kv=self.llm_response_cache,
system_prompt=None,
chunks_vdb=self.chunks_vdb,
return_raw_data=True, # Get final processed data
)
elif param.mode == "naive":
final_data = await naive_query(
query.strip(),
self.chunks_vdb,
param,
global_config,
hashing_kv=self.llm_response_cache,
system_prompt=None,
return_raw_data=True, # Get final processed data
)
elif param.mode == "bypass":
# bypass mode returns empty data
final_data = {
"entities": [],
"relationships": [],
"chunks": [],
"metadata": {
"query_mode": "bypass",
"keywords": {"high_level": [], "low_level": []}
}
}
else:
raise ValueError(f"Unknown mode {param.mode}")
await self._query_done()
-        return response
return final_data
async def _query_done(self):
await self.llm_response_cache.index_done_callback()
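The mode dispatch above is exhaustive: bypass yields an empty skeleton and an unknown mode raises ValueError. A defensive caller might wrap it like this (sketch; safe_query_data is illustrative, not part of the codebase):

```python
from lightrag import LightRAG, QueryParam

async def safe_query_data(rag: LightRAG, query: str, mode: str) -> dict:
    try:
        return await rag.aquery_data(query, param=QueryParam(mode=mode))
    except ValueError:
        # Unknown mode: mirror the empty skeleton that bypass mode returns.
        return {
            "entities": [],
            "relationships": [],
            "chunks": [],
            "metadata": {"query_mode": mode, "keywords": {"high_level": [], "low_level": []}},
        }
```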

lightrag/operate.py

@@ -4,7 +4,7 @@ from functools import partial
import asyncio
import json
import json_repair
-from typing import Any, AsyncIterator
from typing import Any, AsyncIterator, overload, Literal
from collections import Counter, defaultdict
from .utils import (
@@ -30,6 +30,7 @@ from .utils import (
safe_vdb_operation_with_exception,
create_prefixed_exception,
fix_tuple_delimiter_corruption,
_convert_to_user_format,
)
from .base import (
BaseGraphStorage,
@@ -2154,6 +2155,7 @@ async def extract_entities(
return chunk_results
@overload
async def kg_query(
query: str,
knowledge_graph_inst: BaseGraphStorage,
@@ -2165,7 +2167,41 @@ async def kg_query(
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
chunks_vdb: BaseVectorStorage = None,
    return_raw_data: Literal[True] = True,
) -> dict[str, Any]:
...
@overload
async def kg_query(
query: str,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
chunks_vdb: BaseVectorStorage = None,
return_raw_data: Literal[False] = False,
) -> str | AsyncIterator[str]:
...
async def kg_query(
query: str,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
chunks_vdb: BaseVectorStorage = None,
return_raw_data: bool = False,
) -> str | AsyncIterator[str] | dict[str, Any]:
if not query:
return PROMPTS["fail_response"]
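The @overload pair with Literal lets static type checkers narrow the return type from the value of return_raw_data. A self-contained sketch of the same pattern, with illustrative names:

```python
from typing import Any, Literal, overload

@overload
def fetch(return_raw_data: Literal[True]) -> dict[str, Any]: ...
@overload
def fetch(return_raw_data: Literal[False] = False) -> str: ...

def fetch(return_raw_data: bool = False) -> str | dict[str, Any]:
    # One implementation serves both overloads; the checker picks the
    # signature that matches the literal argument at each call site.
    return {"value": 1} if return_raw_data else "value=1"

raw = fetch(return_raw_data=True)  # inferred as dict[str, Any]
text = fetch()                     # inferred as str
```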
@@ -2221,7 +2257,30 @@ async def kg_query(
ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""
-    # Build context
# If raw data is requested, get both context and raw data
if return_raw_data:
context_result = await _build_query_context(
query,
ll_keywords_str,
hl_keywords_str,
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
text_chunks_db,
query_param,
chunks_vdb,
return_raw_data=True,
)
if isinstance(context_result, tuple):
_, raw_data = context_result
return raw_data
else:
raise RuntimeError(
"Failed to build query context for raw data. Expected a tuple, but got a different type."
)
# Build context (normal flow)
context = await _build_query_context(
query,
ll_keywords_str,
@@ -2687,7 +2746,6 @@ async def _apply_token_truncation(
) -> dict[str, Any]:
"""
Apply token-based truncation to entities and relations for LLM efficiency.
-    This function is only used by kg_query, not kg_search.
"""
tokenizer = global_config.get("tokenizer")
if not tokenizer:
@@ -2833,46 +2891,28 @@ async def _apply_token_truncation(
async def _merge_all_chunks(
search_result: dict[str, Any],
-    filtered_entities: list[dict] = None,
-    filtered_relations: list[dict] = None,
filtered_entities: list[dict],
filtered_relations: list[dict],
vector_chunks: list[dict],
query: str = "",
knowledge_graph_inst: BaseGraphStorage = None,
text_chunks_db: BaseKVStorage = None,
query_param: QueryParam = None,
chunks_vdb: BaseVectorStorage = None,
-    chunk_tracking: dict = None,
query_embedding: list[float] = None,
) -> list[dict]:
"""
Merge chunks from different sources: vector_chunks + entity_chunks + relation_chunks.
-    For kg_search: uses all original entities/relations
-    For kg_query: uses filtered entities/relations based on token truncation
"""
-    if chunk_tracking is None:
-        chunk_tracking = search_result.get("chunk_tracking", {})
-    # Use filtered entities/relations if provided (kg_query), otherwise use all (kg_search)
-    entities_to_use = (
-        filtered_entities
-        if filtered_entities is not None
-        else search_result["final_entities"]
-    )
-    relations_to_use = (
-        filtered_relations
-        if filtered_relations is not None
-        else search_result["final_relations"]
-    )
-    vector_chunks = search_result["vector_chunks"]
chunk_tracking = {}
# Get chunks from entities
entity_chunks = []
-    if entities_to_use and text_chunks_db:
-        # Pre-compute query embedding if needed
-        query_embedding = search_result.get("query_embedding", None)
if filtered_entities and text_chunks_db:
entity_chunks = await _find_related_text_unit_from_entities(
-            entities_to_use,
filtered_entities,
query_param,
text_chunks_db,
knowledge_graph_inst,
@@ -2884,9 +2924,9 @@ async def _merge_all_chunks(
# Get chunks from relations
relation_chunks = []
-    if relations_to_use and text_chunks_db:
if filtered_relations and text_chunks_db:
relation_chunks = await _find_related_text_unit_from_relations(
-            relations_to_use,
filtered_relations,
query_param,
text_chunks_db,
entity_chunks, # For deduplication
@@ -2952,171 +2992,6 @@ async def _merge_all_chunks(
return merged_chunks
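The merge step combines chunks from vector search, entity neighborhoods, and relation neighborhoods. A simplified sketch of the deduplication idea (assumption: first occurrence of a chunk_id wins and source order is preserved; the real function also consults storage and tracks chunk origins):

```python
from typing import Any

def merge_chunk_lists(*sources: list[dict[str, Any]]) -> list[dict[str, Any]]:
    # Deduplicate by chunk_id across all sources, preserving order.
    seen: set[str] = set()
    merged: list[dict[str, Any]] = []
    for chunks in sources:
        for chunk in chunks:
            cid = chunk.get("chunk_id", "")
            if cid and cid not in seen:
                seen.add(cid)
                merged.append(chunk)
    return merged

vector = [{"chunk_id": "c1", "content": "alpha"}]
entity = [{"chunk_id": "c1", "content": "alpha"}, {"chunk_id": "c2", "content": "beta"}]
print(merge_chunk_lists(vector, entity))  # keeps c1 once, then c2
```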
-async def kg_search(
-    query: str,
-    knowledge_graph_inst: BaseGraphStorage,
-    entities_vdb: BaseVectorStorage,
-    relationships_vdb: BaseVectorStorage,
-    text_chunks_db: BaseKVStorage,
-    query_param: QueryParam,
-    global_config: dict[str, str],
-    hashing_kv: BaseKVStorage | None = None,
-    chunks_vdb: BaseVectorStorage = None,
-) -> dict[str, Any]:
-    """Search knowledge graph and return structured results without LLM generation
-    For kg_search: Search + Merge chunks (NO truncation)
-    Returns complete search results for user analysis
-    """
-    if not query:
-        return {
-            "entities": [],
-            "relationships": [],
-            "chunks": [],
-            "metadata": {
-                "query_mode": query_param.mode,
-                "keywords": {"high_level": [], "low_level": []},
-            },
-        }
-    # Handle cache (reuse existing cache logic but for search results)
-    args_hash = compute_args_hash(
-        query_param.mode,
-        query,
-        "search",  # Different cache key for search vs query
-        query_param.top_k,
-        query_param.chunk_top_k,
-        query_param.max_entity_tokens,
-        query_param.max_relation_tokens,
-        query_param.max_total_tokens,
-        query_param.hl_keywords or [],
-        query_param.ll_keywords or [],
-        query_param.user_prompt or "",
-        query_param.enable_rerank,
-    )
-    cached_result = await handle_cache(
-        hashing_kv, args_hash, query, query_param.mode, cache_type="search"
-    )
-    if cached_result is not None:
-        cached_response, _ = cached_result  # Extract content, ignore timestamp
-        try:
-            return json_repair.loads(cached_response)
-        except (json.JSONDecodeError, KeyError):
-            logger.warning(
-                "Invalid cache format for search results, proceeding with fresh search"
-            )
-    # Get keywords (reuse existing logic)
-    hl_keywords, ll_keywords = await get_keywords_from_query(
-        query, query_param, global_config, hashing_kv
-    )
-    logger.debug(f"High-level keywords: {hl_keywords}")
-    logger.debug(f"Low-level keywords: {ll_keywords}")
-    # Handle empty keywords (reuse existing logic)
-    if ll_keywords == [] and query_param.mode in ["local", "hybrid", "mix"]:
-        logger.warning("low_level_keywords is empty")
-    if hl_keywords == [] and query_param.mode in ["global", "hybrid", "mix"]:
-        logger.warning("high_level_keywords is empty")
-    if hl_keywords == [] and ll_keywords == []:
-        if len(query) < 50:
-            logger.warning(f"Forced low_level_keywords to origin query: {query}")
-            ll_keywords = [query]
-        else:
-            return {
-                "entities": [],
-                "relationships": [],
-                "chunks": [],
-                "metadata": {
-                    "query_mode": query_param.mode,
-                    "keywords": {"high_level": hl_keywords, "low_level": ll_keywords},
-                    "error": "Keywords extraction failed",
-                },
-            }
-    ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
-    hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""
-    # Stage 1: Pure search (no truncation for kg_search)
-    search_result = await _perform_kg_search(
-        query,
-        ll_keywords_str,
-        hl_keywords_str,
-        knowledge_graph_inst,
-        entities_vdb,
-        relationships_vdb,
-        text_chunks_db,
-        query_param,
-        chunks_vdb,
-    )
-    if not search_result["final_entities"] and not search_result["final_relations"]:
-        return {
-            "entities": [],
-            "relationships": [],
-            "chunks": [],
-            "metadata": {
-                "query_mode": query_param.mode,
-                "keywords": {"high_level": hl_keywords, "low_level": ll_keywords},
-                "error": "No valid results found",
-            },
-        }
-    # Stage 2: Merge ALL chunks (no filtering, use all entities/relations)
-    merged_chunks = await _merge_all_chunks(
-        search_result,
-        filtered_entities=None,  # Use ALL entities (no filtering)
-        filtered_relations=None,  # Use ALL relations (no filtering)
-        query=query,
-        knowledge_graph_inst=knowledge_graph_inst,
-        text_chunks_db=text_chunks_db,
-        query_param=query_param,
-        chunks_vdb=chunks_vdb,
-        chunk_tracking=search_result["chunk_tracking"],
-    )
-    # Build final structured result
-    final_result = {
-        "entities": search_result["final_entities"],
-        "relationships": search_result["final_relations"],
-        "chunks": merged_chunks,
-        "metadata": {
-            "query_mode": query_param.mode,
-            "keywords": {"high_level": hl_keywords, "low_level": ll_keywords},
-        },
-    }
-    # Cache the results
-    if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
-        queryparam_dict = {
-            "mode": query_param.mode,
-            "response_type": "search",
-            "top_k": query_param.top_k,
-            "chunk_top_k": query_param.chunk_top_k,
-            "max_entity_tokens": query_param.max_entity_tokens,
-            "max_relation_tokens": query_param.max_relation_tokens,
-            "max_total_tokens": query_param.max_total_tokens,
-            "hl_keywords": query_param.hl_keywords or [],
-            "ll_keywords": query_param.ll_keywords or [],
-            "user_prompt": query_param.user_prompt or "",
-            "enable_rerank": query_param.enable_rerank,
-        }
-        await save_to_cache(
-            hashing_kv,
-            CacheData(
-                args_hash=args_hash,
-                content=json.dumps(final_result, ensure_ascii=False),
-                prompt=query,
-                mode=query_param.mode,
-                cache_type="search",
-                queryparam=queryparam_dict,
-            ),
-        )
-    return final_result
async def _build_llm_context(
entities_context: list[dict],
relations_context: list[dict],
@@ -3125,7 +3000,8 @@ async def _build_llm_context(
    query_param: QueryParam,
    global_config: dict[str, str],
    chunk_tracking: dict = None,
-) -> str:
return_final_data: bool = False,
) -> str | tuple[str, dict[str, Any]]:
"""
Build the final LLM context string with token processing.
This includes dynamic token calculation and final chunk truncation.
@@ -3332,7 +3208,18 @@ async def _build_llm_context(
```
"""
-    return result
# If final data is requested, return both context and complete data structure
if return_final_data:
final_data = _convert_to_user_format(
entities_context,
relations_context,
truncated_chunks,
query_param.mode
)
return result, final_data
else:
return result
# Now let's update the old _build_query_context to use the new architecture
@@ -3346,7 +3233,8 @@ async def _build_query_context(
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
chunks_vdb: BaseVectorStorage = None,
-) -> str:
return_raw_data: bool = False,
) -> str | tuple[str, dict[str, Any]]:
"""
Main query context building function using the new 4-stage architecture:
1. Search -> 2. Truncate -> 3. Merge chunks -> 4. Build LLM context
@@ -3385,15 +3273,16 @@ async def _build_query_context(
# Stage 3: Merge chunks using filtered entities/relations
merged_chunks = await _merge_all_chunks(
search_result,
filtered_entities=truncation_result["filtered_entities"],
filtered_relations=truncation_result["filtered_relations"],
vector_chunks=search_result["vector_chunks"],
query=query,
knowledge_graph_inst=knowledge_graph_inst,
text_chunks_db=text_chunks_db,
query_param=query_param,
chunks_vdb=chunks_vdb,
-        chunk_tracking=search_result["chunk_tracking"],
query_embedding=search_result["query_embedding"],
)
if (
@@ -3404,17 +3293,60 @@ async def _build_query_context(
return None
# Stage 4: Build final LLM context with dynamic token processing
-    context = await _build_llm_context(
-        entities_context=truncation_result["entities_context"],
-        relations_context=truncation_result["relations_context"],
-        merged_chunks=merged_chunks,
-        query=query,
-        query_param=query_param,
-        global_config=text_chunks_db.global_config,
-        chunk_tracking=search_result["chunk_tracking"],
-    )
-    return context
if return_raw_data:
# Get both context and final data
context_result = await _build_llm_context(
entities_context=truncation_result["entities_context"],
relations_context=truncation_result["relations_context"],
merged_chunks=merged_chunks,
query=query,
query_param=query_param,
global_config=text_chunks_db.global_config,
chunk_tracking=search_result["chunk_tracking"],
return_final_data=True,
)
if isinstance(context_result, tuple):
context, final_chunks = context_result
else:
# Handle case where no final chunks are returned
context = context_result
final_chunks = merged_chunks
# Build raw data structure with the same data that goes to LLM
raw_data = {
"entities": truncation_result["filtered_entities"], # Use filtered entities (same as LLM)
"relationships": truncation_result["filtered_relations"], # Use filtered relations (same as LLM)
"chunks": final_chunks, # Use final processed chunks (same as LLM)
"metadata": {
"query_mode": query_param.mode,
"keywords": {
"high_level": hl_keywords.split(", ") if hl_keywords else [],
"low_level": ll_keywords.split(", ") if ll_keywords else []
},
"processing_info": {
"total_entities_found": len(search_result["final_entities"]),
"total_relations_found": len(search_result["final_relations"]),
"entities_after_truncation": len(truncation_result["filtered_entities"]),
"relations_after_truncation": len(truncation_result["filtered_relations"]),
"merged_chunks_count": len(merged_chunks),
"final_chunks_count": len(final_chunks)
}
}
}
return context, raw_data
else:
# Normal context building (existing logic)
context = await _build_llm_context(
entities_context=truncation_result["entities_context"],
relations_context=truncation_result["relations_context"],
merged_chunks=merged_chunks,
query=query,
query_param=query_param,
global_config=text_chunks_db.global_config,
chunk_tracking=search_result["chunk_tracking"],
)
return context
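Because _build_query_context now returns either a context string or a (context, raw_data) tuple, call sites branch on the shape. A hedged sketch of the calling pattern (simplified; real callers pass the full storage arguments, and the helper name is illustrative):

```python
from typing import Any, Awaitable, Callable

async def context_with_data(
    build: Callable[..., Awaitable[str | tuple[str, dict[str, Any]] | None]],
    **kwargs: Any,
) -> tuple[str | None, dict[str, Any] | None]:
    # build stands in for _build_query_context.
    result = await build(**kwargs, return_raw_data=True)
    if isinstance(result, tuple):
        return result  # (context, raw_data)
    return result, None  # only a context string (or None) came back
```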
async def _get_node_data(
@@ -3989,6 +3921,7 @@ async def _find_related_text_unit_from_relations(
return result_chunks
@overload
async def naive_query(
query: str,
chunks_vdb: BaseVectorStorage,
@@ -3996,7 +3929,31 @@ async def naive_query(
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
return_raw_data: Literal[True] = True,
) -> dict[str, Any]:
...
@overload
async def naive_query(
query: str,
chunks_vdb: BaseVectorStorage,
query_param: QueryParam,
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
return_raw_data: Literal[False] = False,
) -> str | AsyncIterator[str]:
...
async def naive_query(
query: str,
chunks_vdb: BaseVectorStorage,
query_param: QueryParam,
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
return_raw_data: bool = False,
) -> str | AsyncIterator[str] | dict[str, Any]:
if query_param.model_func:
use_model_func = query_param.model_func
else:
@@ -4032,6 +3989,21 @@ async def naive_query(
chunks = await _get_vector_context(query, chunks_vdb, query_param, None)
if chunks is None or len(chunks) == 0:
# Build empty raw data for consistency
empty_raw_data = {
"entities": [], # naive mode has no entities
"relationships": [], # naive mode has no relationships
"chunks": [],
"metadata": {
"query_mode": "naive",
"keywords": {"high_level": [], "low_level": []}
}
}
# If only raw data is requested, return it directly
if return_raw_data:
return empty_raw_data
return PROMPTS["fail_response"]
# Calculate dynamic token limit for chunks
@@ -4089,6 +4061,20 @@ async def naive_query(
logger.info(f"Final context: {len(processed_chunks)} chunks")
# If only raw data is requested, return it directly
if return_raw_data:
# Build raw data structure for naive mode using processed chunks
raw_data = {
"entities": [], # naive mode has no entities
"relationships": [], # naive mode has no relationships
"chunks": processed_chunks, # Use processed chunks (same as LLM)
"metadata": {
"query_mode": "naive",
"keywords": {"high_level": [], "low_level": []}
}
}
return raw_data
# Build text_units_context from processed chunks
text_units_context = []
for i, chunk in enumerate(processed_chunks):
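In naive mode the structured result carries chunks only. A quick sketch of what a caller can expect (assumes the imports and rag instance from the earlier sketch):

```python
async def inspect_naive(rag: LightRAG) -> None:
    result = await rag.aquery_data("caching", param=QueryParam(mode="naive"))
    assert result["entities"] == []        # naive mode never returns entities
    assert result["relationships"] == []   # ...or relationships
    for chunk in result["chunks"]:
        print(chunk["file_path"], chunk["content"][:80])
```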

lightrag/utils.py

@@ -2723,3 +2723,79 @@ def create_prefixed_exception(original_exception: Exception, prefix: str) -> Exc
f"{prefix}: {type(original_exception).__name__}: {str(original_exception)} "
f"(Original exception could not be reconstructed: {construct_error})"
)
def _convert_to_user_format(
entities_context: list[dict],
relations_context: list[dict],
final_chunks: list[dict],
query_mode: str,
hl_keywords: list[str] = None,
ll_keywords: list[str] = None,
search_result: dict = None,
truncation_result: dict = None,
merged_chunks: list[dict] = None,
) -> dict[str, Any]:
"""Convert internal data format to user-friendly format"""
# Convert entities format
formatted_entities = []
for entity in entities_context:
formatted_entities.append({
"entity_name": entity.get("entity", ""),
"entity_type": entity.get("type", "UNKNOWN"),
"description": entity.get("description", ""),
"source_id": entity.get("source_id", ""),
"file_path": entity.get("file_path", "unknown_source"),
"created_at": entity.get("created_at", ""),
})
# Convert relationships format
formatted_relationships = []
for relation in relations_context:
formatted_relationships.append({
"src_id": relation.get("entity1", ""),
"tgt_id": relation.get("entity2", ""),
"description": relation.get("description", ""),
"keywords": relation.get("keywords", ""),
"weight": relation.get("weight", 1.0),
"source_id": relation.get("source_id", ""),
"file_path": relation.get("file_path", "unknown_source"),
"created_at": relation.get("created_at", ""),
})
# Convert chunks format
formatted_chunks = []
for chunk in final_chunks:
formatted_chunks.append({
"content": chunk.get("content", ""),
"file_path": chunk.get("file_path", "unknown_source"),
"chunk_id": chunk.get("chunk_id", ""),
})
# Build metadata with processing info
metadata = {
"query_mode": query_mode,
"keywords": {
"high_level": hl_keywords or [],
"low_level": ll_keywords or []
}
}
# Add processing info if available
if search_result and truncation_result and merged_chunks is not None:
metadata["processing_info"] = {
"total_entities_found": len(search_result.get("final_entities", [])),
"total_relations_found": len(search_result.get("final_relations", [])),
"entities_after_truncation": len(truncation_result.get("filtered_entities", [])),
"relations_after_truncation": len(truncation_result.get("filtered_relations", [])),
"merged_chunks_count": len(merged_chunks),
"final_chunks_count": len(final_chunks)
}
return {
"entities": formatted_entities,
"relationships": formatted_relationships,
"chunks": formatted_chunks,
"metadata": metadata
}
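A small worked example of the conversion (illustrative values; the internal keys entity/type/entity1/entity2 map to the user-facing entity_name/entity_type/src_id/tgt_id, and missing fields get defaults):

```python
entities_context = [{"entity": "Redis", "type": "TECHNOLOGY", "description": "In-memory store"}]
relations_context = [{"entity1": "App", "entity2": "Redis", "description": "uses", "weight": 1.0}]
final_chunks = [{"content": "The app caches sessions in Redis.", "chunk_id": "c1"}]

result = _convert_to_user_format(entities_context, relations_context, final_chunks, "hybrid")
print(result["entities"][0]["entity_name"])   # "Redis"
print(result["entities"][0]["file_path"])     # "unknown_source" (default)
print(result["relationships"][0]["src_id"])   # "App"
print(result["metadata"]["query_mode"])       # "hybrid"
```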