Add search method to lightrag. Search retrieves structured objects (entities, relations, chunks) in their raw data format.
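For orientation, the structured result produced by the new search path looks roughly like this (a sketch based on the kg_search return value in the diff below; field values are illustrative and depend on the configured storages):

# Sketch of the raw search result shape (values are placeholders, not real data)
search_result = {
    "entities": [{"entity_name": "...", "entity_type": "...", "description": "...", "file_path": "..."}],
    "relationships": [{"src_id": "...", "tgt_id": "...", "description": "...", "file_path": "..."}],
    "chunks": [{"chunk_id": "...", "content": "...", "file_path": "..."}],
    "metadata": {"query_mode": "mix", "keywords": {"high_level": [], "low_level": []}},
}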

Tong Da 2025-09-01 01:57:19 +08:00
parent 5e73896c40
commit a60a8704ba


@@ -1185,9 +1185,9 @@ async def _merge_nodes_then_upsert(
# Log based on actual LLM usage
if llm_was_used:
status_message = f"LLMmrg: `{entity_name}` | {already_fragment}+{num_fragment-already_fragment}{dd_message}"
status_message = f"LLMmrg: `{entity_name}` | {already_fragment}+{num_fragment - already_fragment}{dd_message}"
else:
status_message = f"Merged: `{entity_name}` | {already_fragment}+{num_fragment-already_fragment}{dd_message}"
status_message = f"Merged: `{entity_name}` | {already_fragment}+{num_fragment - already_fragment}{dd_message}"
logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
@@ -1302,9 +1302,9 @@ async def _merge_edges_then_upsert(
# Log based on actual LLM usage
if llm_was_used:
status_message = f"LLMmrg: `{src_id} - {tgt_id}` | {already_fragment}+{num_fragment-already_fragment}{dd_message}"
status_message = f"LLMmrg: `{src_id} - {tgt_id}` | {already_fragment}+{num_fragment - already_fragment}{dd_message}"
else:
status_message = f"Merged: `{src_id} - {tgt_id}` | {already_fragment}+{num_fragment-already_fragment}{dd_message}"
status_message = f"Merged: `{src_id} - {tgt_id}` | {already_fragment}+{num_fragment - already_fragment}{dd_message}"
logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
@@ -2285,7 +2285,7 @@ async def _get_vector_context(
return []
async def _build_query_context(
async def _perform_kg_search(
query: str,
ll_keywords: str,
hl_keywords: str,
@@ -2295,25 +2295,21 @@ async def _build_query_context(
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
chunks_vdb: BaseVectorStorage = None,
):
if not query:
logger.warning("Query is empty, skipping context building")
return ""
) -> dict[str, Any]:
"""
Pure search logic that retrieves raw entities, relations, and vector chunks.
No token truncation or formatting - just raw search results.
"""
logger.info(f"Process {os.getpid()} building query context...")
# Collect chunks from different sources separately
vector_chunks = []
entity_chunks = []
relation_chunks = []
entities_context = []
relations_context = []
# Store original data for later text chunk retrieval
# Initialize result containers
local_entities = []
local_relations = []
global_entities = []
global_relations = []
vector_chunks = []
chunk_tracking = {}
# Handle different query modes
# Track chunk sources and metadata for final logging
chunk_tracking = {} # chunk_id -> {source, frequency, order}
@@ -2369,7 +2365,7 @@ async def _build_query_context(
query_param,
)
# Get vector chunks first if in mix mode
# Get vector chunks for mix mode
if query_param.mode == "mix" and chunks_vdb:
vector_chunks = await _get_vector_context(
query,
@@ -2389,11 +2385,9 @@ async def _build_query_context(
else:
logger.warning(f"Vector chunk missing chunk_id: {chunk}")
# Use round-robin merge to combine local and global data fairly
# Round-robin merge entities
final_entities = []
seen_entities = set()
# Round-robin merge entities
max_len = max(len(local_entities), len(global_entities))
for i in range(max_len):
# First from local
@@ -2415,7 +2409,6 @@ async def _build_query_context(
# Round-robin merge relations
final_relations = []
seen_relations = set()
max_len = max(len(local_relations), len(global_relations))
for i in range(max_len):
# First from local
@@ -2448,91 +2441,107 @@ async def _build_query_context(
final_relations.append(relation)
seen_relations.add(rel_key)
# Generate entities context
logger.info(
f"Raw search results: {len(final_entities)} entities, {len(final_relations)} relations, {len(vector_chunks)} vector chunks"
)
return {
"final_entities": final_entities,
"final_relations": final_relations,
"vector_chunks": vector_chunks,
"chunk_tracking": chunk_tracking,
"query_embedding": query_embedding,
}
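The chunk_tracking map returned above records lightweight provenance per chunk for final logging; a sketch of the assumed entry format (the "C"/"E" source codes are illustrative, not confirmed by this diff):

# Hypothetical chunk_tracking entries: chunk_id -> {source, frequency, order}
chunk_tracking = {
    "chunk-73a9f0": {"source": "C", "frequency": 2, "order": 1},  # found by the vector search
    "chunk-1b44c2": {"source": "E", "frequency": 1, "order": 3},  # reached through a retrieved entity
}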
async def _apply_token_truncation(
search_result: dict[str, Any],
query_param: QueryParam,
global_config: dict[str, str],
) -> dict[str, Any]:
"""
Apply token-based truncation to entities and relations for LLM efficiency.
This function is only used by kg_query, not kg_search.
"""
tokenizer = global_config.get("tokenizer")
if not tokenizer:
logger.warning("No tokenizer found, skipping truncation")
return {
"truncated_entities": search_result["final_entities"],
"truncated_relations": search_result["final_relations"],
"entities_context": [],
"relations_context": [],
"filtered_entities": search_result["final_entities"],
"filtered_relations": search_result["final_relations"],
}
# Get token limits from query_param with fallbacks
max_entity_tokens = getattr(
query_param,
"max_entity_tokens",
global_config.get("max_entity_tokens", DEFAULT_MAX_ENTITY_TOKENS),
)
max_relation_tokens = getattr(
query_param,
"max_relation_tokens",
global_config.get("max_relation_tokens", DEFAULT_MAX_RELATION_TOKENS),
)
final_entities = search_result["final_entities"]
final_relations = search_result["final_relations"]
# Generate entities context for truncation
entities_context = []
for i, n in enumerate(final_entities):
created_at = n.get("created_at", "UNKNOWN")
for i, entity in enumerate(final_entities):
created_at = entity.get("created_at", "UNKNOWN")
if isinstance(created_at, (int, float)):
created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
# Get file path from node data
file_path = n.get("file_path", "unknown_source")
entities_context.append(
{
"id": i + 1,
"entity": n["entity_name"],
"type": n.get("entity_type", "UNKNOWN"),
"description": n.get("description", "UNKNOWN"),
"entity": entity["entity_name"],
"type": entity.get("entity_type", "UNKNOWN"),
"description": entity.get("description", "UNKNOWN"),
"created_at": created_at,
"file_path": file_path,
"file_path": entity.get("file_path", "unknown_source"),
}
)
# Generate relations context
# Generate relations context for truncation
relations_context = []
for i, e in enumerate(final_relations):
created_at = e.get("created_at", "UNKNOWN")
# Convert timestamp to readable format
for i, relation in enumerate(final_relations):
created_at = relation.get("created_at", "UNKNOWN")
if isinstance(created_at, (int, float)):
created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
# Get file path from edge data
file_path = e.get("file_path", "unknown_source")
# Handle different relation data formats
if "src_tgt" in e:
entity1, entity2 = e["src_tgt"]
if "src_tgt" in relation:
entity1, entity2 = relation["src_tgt"]
else:
entity1, entity2 = e.get("src_id"), e.get("tgt_id")
entity1, entity2 = relation.get("src_id"), relation.get("tgt_id")
relations_context.append(
{
"id": i + 1,
"entity1": entity1,
"entity2": entity2,
"description": e.get("description", "UNKNOWN"),
"description": relation.get("description", "UNKNOWN"),
"created_at": created_at,
"file_path": file_path,
"file_path": relation.get("file_path", "unknown_source"),
}
)
logger.debug(
f"Initial KG query results: {len(entities_context)} entities, {len(relations_context)} relations"
f"Before truncation: {len(entities_context)} entities, {len(relations_context)} relations"
)
# Unified token control system - Apply precise token limits to entities and relations
tokenizer = text_chunks_db.global_config.get("tokenizer")
# Get new token limits from query_param (with fallback to global_config)
max_entity_tokens = getattr(
query_param,
"max_entity_tokens",
text_chunks_db.global_config.get(
"max_entity_tokens", DEFAULT_MAX_ENTITY_TOKENS
),
)
max_relation_tokens = getattr(
query_param,
"max_relation_tokens",
text_chunks_db.global_config.get(
"max_relation_tokens", DEFAULT_MAX_RELATION_TOKENS
),
)
max_total_tokens = getattr(
query_param,
"max_total_tokens",
text_chunks_db.global_config.get("max_total_tokens", DEFAULT_MAX_TOTAL_TOKENS),
)
# Truncate entities based on complete JSON serialization
# Apply token-based truncation
if entities_context:
# Process entities context to replace GRAPH_FIELD_SEP with : in file_path fields
# Remove file_path and created_at for token calculation
for entity in entities_context:
# remove file_path and created_at
entity.pop("file_path", None)
entity.pop("created_at", None)
# if "file_path" in entity and entity["file_path"]:
# entity["file_path"] = entity["file_path"].replace(GRAPH_FIELD_SEP, ";")
entities_context = truncate_list_by_token_size(
entities_context,
@@ -2541,17 +2550,11 @@ async def _build_query_context(
tokenizer=tokenizer,
)
# Truncate relations based on complete JSON serialization
if relations_context:
# Process relations context to replace GRAPH_FIELD_SEP with : in file_path fields
# Remove file_path and created_at for token calculation
for relation in relations_context:
# remove file_path and created_at
relation.pop("file_path", None)
relation.pop("created_at", None)
# if "file_path" in relation and relation["file_path"]:
# relation["file_path"] = relation["file_path"].replace(
# GRAPH_FIELD_SEP, ";"
# )
relations_context = truncate_list_by_token_size(
relations_context,
@@ -2560,41 +2563,86 @@ async def _build_query_context(
tokenizer=tokenizer,
)
# After truncation, get text chunks based on final entities and relations
logger.info(
f"Truncated KG query results: {len(entities_context)} entities, {len(relations_context)} relations"
f"After truncation: {len(entities_context)} entities, {len(relations_context)} relations"
)
# Create filtered data based on truncated context
final_node_datas = []
if entities_context and final_entities:
# Create filtered original data based on truncated context
filtered_entities = []
if entities_context:
final_entity_names = {e["entity"] for e in entities_context}
seen_nodes = set()
for node in final_entities:
name = node.get("entity_name")
for entity in final_entities:
name = entity.get("entity_name")
if name in final_entity_names and name not in seen_nodes:
final_node_datas.append(node)
filtered_entities.append(entity)
seen_nodes.add(name)
final_edge_datas = []
if relations_context and final_relations:
filtered_relations = []
if relations_context:
final_relation_pairs = {(r["entity1"], r["entity2"]) for r in relations_context}
seen_edges = set()
for edge in final_relations:
src, tgt = edge.get("src_id"), edge.get("tgt_id")
for relation in final_relations:
src, tgt = relation.get("src_id"), relation.get("tgt_id")
if src is None or tgt is None:
src, tgt = edge.get("src_tgt", (None, None))
src, tgt = relation.get("src_tgt", (None, None))
pair = (src, tgt)
if pair in final_relation_pairs and pair not in seen_edges:
final_edge_datas.append(edge)
filtered_relations.append(relation)
seen_edges.add(pair)
# Get text chunks based on final filtered data
# To preserve the influence of entity order, entity-based chunks should not be deduplicated by vector_chunks
if final_node_datas:
return {
"truncated_entities": final_entities, # Keep original for backward compatibility
"truncated_relations": final_relations, # Keep original for backward compatibility
"entities_context": entities_context, # Formatted and truncated for LLM
"relations_context": relations_context, # Formatted and truncated for LLM
"filtered_entities": filtered_entities, # Original entities that passed truncation
"filtered_relations": filtered_relations, # Original relations that passed truncation
}
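For illustration, a minimal self-contained sketch of the truncation idea applied above, assuming a tokenizer object with an encode method; the actual truncate_list_by_token_size helper may differ in signature:

import json

def truncate_by_token_size(items: list[dict], max_tokens: int, tokenizer) -> list[dict]:
    # Keep leading items whose cumulative JSON-serialized token count stays within max_tokens.
    kept, used = [], 0
    for item in items:
        cost = len(tokenizer.encode(json.dumps(item, ensure_ascii=False)))
        if used + cost > max_tokens:
            break
        kept.append(item)
        used += cost
    return kept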
async def _merge_all_chunks(
search_result: dict[str, Any],
filtered_entities: list[dict] = None,
filtered_relations: list[dict] = None,
query: str = "",
knowledge_graph_inst: BaseGraphStorage = None,
text_chunks_db: BaseKVStorage = None,
query_param: QueryParam = None,
chunks_vdb: BaseVectorStorage = None,
chunk_tracking: dict = None,
) -> list[dict]:
"""
Merge chunks from different sources: vector_chunks + entity_chunks + relation_chunks.
For kg_search: uses all original entities/relations
For kg_query: uses filtered entities/relations based on token truncation
"""
if chunk_tracking is None:
chunk_tracking = search_result.get("chunk_tracking", {})
# Use filtered entities/relations if provided (kg_query), otherwise use all (kg_search)
entities_to_use = (
filtered_entities
if filtered_entities is not None
else search_result["final_entities"]
)
relations_to_use = (
filtered_relations
if filtered_relations is not None
else search_result["final_relations"]
)
vector_chunks = search_result["vector_chunks"]
# Get chunks from entities
entity_chunks = []
if entities_to_use and text_chunks_db:
# Pre-compute query embedding if needed
query_embedding = search_result.get("query_embedding", None)
entity_chunks = await _find_related_text_unit_from_entities(
final_node_datas,
entities_to_use,
query_param,
text_chunks_db,
knowledge_graph_inst,
@@ -2604,21 +2652,21 @@ async def _build_query_context(
query_embedding=query_embedding,
)
# Find deduplicated chunks from edges
# Deduplication causes chunks that are solely relation-based to be prioritized and sent to the LLM when re-ranking is disabled
if final_edge_datas:
# Get chunks from relations
relation_chunks = []
if relations_to_use and text_chunks_db:
relation_chunks = await _find_related_text_unit_from_relations(
final_edge_datas,
relations_to_use,
query_param,
text_chunks_db,
entity_chunks,
entity_chunks, # For deduplication
query,
chunks_vdb,
chunk_tracking=chunk_tracking,
query_embedding=query_embedding,
)
# Round-robin merge chunks from different sources with deduplication by chunk_id
# Round-robin merge chunks from different sources with deduplication
merged_chunks = []
seen_chunk_ids = set()
max_len = max(len(vector_chunks), len(entity_chunks), len(relation_chunks))
@@ -2668,12 +2716,238 @@ async def _build_query_context(
)
logger.info(
f"Round-robin merged total chunks from {origin_len} to {len(merged_chunks)}"
f"Round-robin merged chunks: {origin_len} -> {len(merged_chunks)} (deduplication: {origin_len - len(merged_chunks)})"
)
return merged_chunks
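For reference, a self-contained sketch of the round-robin merge with chunk_id deduplication performed above (chunk dicts and the chunk_id key are simplified; the real code also updates chunk_tracking and logs the result):

def round_robin_merge(*chunk_lists: list[dict]) -> list[dict]:
    # Interleave chunks from each source list, keeping only the first occurrence of each chunk_id.
    merged, seen = [], set()
    longest = max((len(lst) for lst in chunk_lists), default=0)
    for i in range(longest):
        for lst in chunk_lists:
            if i < len(lst):
                chunk = lst[i]
                chunk_id = chunk.get("chunk_id")
                if chunk_id is not None and chunk_id not in seen:
                    merged.append(chunk)
                    seen.add(chunk_id)
    return merged

# e.g. merged = round_robin_merge(vector_chunks, entity_chunks, relation_chunks)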
async def kg_search(
query: str,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
chunks_vdb: BaseVectorStorage = None,
) -> dict[str, Any]:
"""Search knowledge graph and return structured results without LLM generation
For kg_search: Search + Merge chunks (NO truncation)
Returns complete search results for user analysis
"""
if not query:
return {
"entities": [],
"relationships": [],
"chunks": [],
"metadata": {
"query_mode": query_param.mode,
"keywords": {"high_level": [], "low_level": []},
},
}
# Handle cache (reuse existing cache logic but for search results)
args_hash = compute_args_hash(
query_param.mode,
query,
"search", # Different cache key for search vs query
query_param.top_k,
query_param.chunk_top_k,
query_param.max_entity_tokens,
query_param.max_relation_tokens,
query_param.max_total_tokens,
query_param.hl_keywords or [],
query_param.ll_keywords or [],
query_param.user_prompt or "",
query_param.enable_rerank,
)
cached_response = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="search"
)
if cached_response is not None:
try:
return json_repair.loads(cached_response)
except (json.JSONDecodeError, KeyError):
logger.warning(
"Invalid cache format for search results, proceeding with fresh search"
)
# Get keywords (reuse existing logic)
hl_keywords, ll_keywords = await get_keywords_from_query(
query, query_param, global_config, hashing_kv
)
logger.debug(f"High-level keywords: {hl_keywords}")
logger.debug(f"Low-level keywords: {ll_keywords}")
# Handle empty keywords (reuse existing logic)
if ll_keywords == [] and query_param.mode in ["local", "hybrid", "mix"]:
logger.warning("low_level_keywords is empty")
if hl_keywords == [] and query_param.mode in ["global", "hybrid", "mix"]:
logger.warning("high_level_keywords is empty")
if hl_keywords == [] and ll_keywords == []:
if len(query) < 50:
logger.warning(f"Forced low_level_keywords to origin query: {query}")
ll_keywords = [query]
else:
return {
"entities": [],
"relationships": [],
"chunks": [],
"metadata": {
"query_mode": query_param.mode,
"keywords": {"high_level": hl_keywords, "low_level": ll_keywords},
"error": "Keywords extraction failed",
},
}
ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""
# Stage 1: Pure search (no truncation for kg_search)
search_result = await _perform_kg_search(
query,
ll_keywords_str,
hl_keywords_str,
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
query_param,
chunks_vdb,
)
if not search_result["final_entities"] and not search_result["final_relations"]:
return {
"entities": [],
"relationships": [],
"chunks": [],
"metadata": {
"query_mode": query_param.mode,
"keywords": {"high_level": hl_keywords, "low_level": ll_keywords},
"error": "No valid results found",
},
}
# Stage 2: Merge ALL chunks (no filtering, use all entities/relations)
merged_chunks = await _merge_all_chunks(
search_result,
filtered_entities=None, # Use ALL entities (no filtering)
filtered_relations=None, # Use ALL relations (no filtering)
query=query,
knowledge_graph_inst=knowledge_graph_inst,
text_chunks_db=text_chunks_db,
query_param=query_param,
chunks_vdb=chunks_vdb,
chunk_tracking=search_result["chunk_tracking"],
)
# Build final structured result
final_result = {
"entities": search_result["final_entities"],
"relationships": search_result["final_relations"],
"chunks": merged_chunks,
"metadata": {
"query_mode": query_param.mode,
"keywords": {"high_level": hl_keywords, "low_level": ll_keywords},
},
}
# Cache the results
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
queryparam_dict = {
"mode": query_param.mode,
"response_type": "search",
"top_k": query_param.top_k,
"chunk_top_k": query_param.chunk_top_k,
"max_entity_tokens": query_param.max_entity_tokens,
"max_relation_tokens": query_param.max_relation_tokens,
"max_total_tokens": query_param.max_total_tokens,
"hl_keywords": query_param.hl_keywords or [],
"ll_keywords": query_param.ll_keywords or [],
"user_prompt": query_param.user_prompt or "",
"enable_rerank": query_param.enable_rerank,
}
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=json.dumps(final_result, ensure_ascii=False),
prompt=query,
mode=query_param.mode,
cache_type="search",
queryparam=queryparam_dict,
),
)
return final_result
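A minimal caller-side sketch of the new function, assuming the storage backends, global_config, and cache KV are already initialized elsewhere (they are not part of this diff):

result = await kg_search(
    query="Who founded the company?",
    knowledge_graph_inst=kg_storage,
    entities_vdb=entities_vdb,
    relationships_vdb=relationships_vdb,
    text_chunks_db=text_chunks_db,
    query_param=QueryParam(mode="mix", top_k=40),
    global_config=global_config,
    hashing_kv=llm_response_cache,
    chunks_vdb=chunks_vdb,
)
print(len(result["entities"]), len(result["relationships"]), len(result["chunks"]))
print(result["metadata"]["keywords"])  # {"high_level": [...], "low_level": [...]}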
async def _build_llm_context(
entities_context: list[dict],
relations_context: list[dict],
merged_chunks: list[dict],
query: str,
query_param: QueryParam,
global_config: dict[str, str],
chunk_tracking: dict = None,
) -> str:
"""
Build the final LLM context string with token processing.
This includes dynamic token calculation and final chunk truncation.
"""
tokenizer = global_config.get("tokenizer")
if not tokenizer:
logger.warning("No tokenizer found, building context without token limits")
# Build basic context without token processing
entities_str = json.dumps(entities_context, ensure_ascii=False)
relations_str = json.dumps(relations_context, ensure_ascii=False)
text_units_context = []
for i, chunk in enumerate(merged_chunks):
text_units_context.append(
{
"id": i + 1,
"content": chunk["content"],
"file_path": chunk.get("file_path", "unknown_source"),
}
)
text_units_str = json.dumps(text_units_context, ensure_ascii=False)
return f"""-----Entities(KG)-----
```json
{entities_str}
```
-----Relationships(KG)-----
```json
{relations_str}
```
-----Document Chunks(DC)-----
```json
{text_units_str}
```
"""
# Get token limits
max_total_tokens = getattr(
query_param,
"max_total_tokens",
global_config.get("max_total_tokens", DEFAULT_MAX_TOTAL_TOKENS),
)
# Apply token processing to merged chunks
text_units_context = []
truncated_chunks = []
if merged_chunks:
# Calculate dynamic token limit for text chunks
entities_str = json.dumps(entities_context, ensure_ascii=False)
@@ -2704,18 +2978,7 @@ async def _build_query_context(
)
kg_context_tokens = len(tokenizer.encode(kg_context))
# Calculate actual system prompt overhead dynamically
# 1. Conversation history not included in context length calculation
history_context = ""
# if query_param.conversation_history:
# history_context = get_conversation_turns(
# query_param.conversation_history, query_param.history_turns
# )
# history_tokens = (
# len(tokenizer.encode(history_context)) if history_context else 0
# )
# 2. Calculate system prompt template tokens (excluding context_data)
# Calculate system prompt template overhead
user_prompt = query_param.user_prompt if query_param.user_prompt else ""
response_type = (
query_param.response_type
@@ -2723,14 +2986,14 @@ async def _build_query_context(
else "Multiple Paragraphs"
)
# Get the system prompt template from PROMPTS
sys_prompt_template = text_chunks_db.global_config.get(
# Get the system prompt template from PROMPTS or global_config
sys_prompt_template = global_config.get(
"system_prompt_template", PROMPTS["rag_response"]
)
# Create a sample system prompt with placeholders filled (excluding context_data)
# Create sample system prompt for overhead calculation
sample_sys_prompt = sys_prompt_template.format(
history=history_context,
history="", # History not included in context length calculation
context_data="", # Empty for overhead calculation
response_type=response_type,
user_prompt=user_prompt,
@@ -2756,7 +3019,7 @@ async def _build_query_context(
query=query,
unique_chunks=merged_chunks,
query_param=query_param,
global_config=text_chunks_db.global_config,
global_config=global_config,
source_type=query_param.mode,
chunk_token_limit=available_chunk_tokens, # Pass dynamic limit
)
@@ -2827,6 +3090,76 @@ async def _build_query_context(
return result
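The dynamic chunk budget computed above reduces to simple arithmetic; a sketch with illustrative numbers (the constants and the safety buffer are assumptions, not the shipped defaults):

# Illustrative budget calculation mirroring _build_llm_context
max_total_tokens = 32_000        # assumed overall context budget
kg_context_tokens = 6_500        # tokens used by the serialized entities + relations
sys_prompt_overhead = 1_200      # system prompt template without context_data
buffer_tokens = 100              # assumed safety margin

available_chunk_tokens = max_total_tokens - kg_context_tokens - sys_prompt_overhead - buffer_tokens
print(available_chunk_tokens)    # 24200 tokens left for document chunks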
# Reimplement the old _build_query_context on top of the new staged architecture
async def _build_query_context(
query: str,
ll_keywords: str,
hl_keywords: str,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
chunks_vdb: BaseVectorStorage = None,
) -> str:
"""
Main query context building function using the new 4-stage architecture:
1. Search -> 2. Truncate -> 3. Merge chunks -> 4. Build LLM context
"""
if not query:
logger.warning("Query is empty, skipping context building")
return ""
# Stage 1: Pure search
search_result = await _perform_kg_search(
query,
ll_keywords,
hl_keywords,
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
query_param,
chunks_vdb,
)
if not search_result["final_entities"] and not search_result["final_relations"]:
return None
# Stage 2: Apply token truncation for LLM efficiency
truncation_result = await _apply_token_truncation(
search_result,
query_param,
text_chunks_db.global_config,
)
# Stage 3: Merge chunks using filtered entities/relations
merged_chunks = await _merge_all_chunks(
search_result,
filtered_entities=truncation_result["filtered_entities"],
filtered_relations=truncation_result["filtered_relations"],
query=query,
knowledge_graph_inst=knowledge_graph_inst,
text_chunks_db=text_chunks_db,
query_param=query_param,
chunks_vdb=chunks_vdb,
chunk_tracking=search_result["chunk_tracking"],
)
# Stage 4: Build final LLM context with dynamic token processing
context = await _build_llm_context(
entities_context=truncation_result["entities_context"],
relations_context=truncation_result["relations_context"],
merged_chunks=merged_chunks,
query=query,
query_param=query_param,
global_config=text_chunks_db.global_config,
chunk_tracking=search_result["chunk_tracking"],
)
return context
async def _get_node_data(
query: str,
knowledge_graph_inst: BaseGraphStorage,