Add search method to lightrag. Search retrieves structured objects (entities, relations, chunks) in their raw data format.
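A minimal sketch of how the new search entry point might be driven. Only `kg_search`'s signature comes from the diff below; the `rag.*` storage attribute names, the `asdict(rag)` global-config wiring, and the import path are assumptions for illustration:

```python
# Hypothetical driver for the kg_search() added in this commit.
# The rag.* storage attributes and asdict(rag) are assumptions; only
# kg_search's parameter list is taken from the diff itself.
import asyncio
from dataclasses import asdict

from lightrag import LightRAG, QueryParam
from lightrag.operate import kg_search  # assumed module for this commit


async def main() -> None:
    rag = LightRAG(working_dir="./rag_storage")  # embedding/LLM setup omitted
    result = await kg_search(
        query="How are entity fragments merged?",
        knowledge_graph_inst=rag.chunk_entity_relation_graph,
        entities_vdb=rag.entities_vdb,
        relationships_vdb=rag.relationships_vdb,
        text_chunks_db=rag.text_chunks,
        query_param=QueryParam(mode="mix"),
        global_config=asdict(rag),
        hashing_kv=rag.llm_response_cache,
        chunks_vdb=rag.chunks_vdb,
    )
    # Raw structured objects: no token truncation, no LLM call
    print(len(result["entities"]), len(result["relationships"]), len(result["chunks"]))


asyncio.run(main())
```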
This commit is contained in:
parent 5e73896c40
commit a60a8704ba

1 changed file with 458 additions and 125 deletions
@@ -1185,9 +1185,9 @@ async def _merge_nodes_then_upsert(
 
     # Log based on actual LLM usage
     if llm_was_used:
-        status_message = f"LLMmrg: `{entity_name}` | {already_fragment}+{num_fragment-already_fragment}{dd_message}"
+        status_message = f"LLMmrg: `{entity_name}` | {already_fragment}+{num_fragment - already_fragment}{dd_message}"
     else:
-        status_message = f"Merged: `{entity_name}` | {already_fragment}+{num_fragment-already_fragment}{dd_message}"
+        status_message = f"Merged: `{entity_name}` | {already_fragment}+{num_fragment - already_fragment}{dd_message}"
 
     logger.info(status_message)
     if pipeline_status is not None and pipeline_status_lock is not None:
@@ -1302,9 +1302,9 @@ async def _merge_edges_then_upsert(
 
     # Log based on actual LLM usage
     if llm_was_used:
-        status_message = f"LLMmrg: `{src_id} - {tgt_id}` | {already_fragment}+{num_fragment-already_fragment}{dd_message}"
+        status_message = f"LLMmrg: `{src_id} - {tgt_id}` | {already_fragment}+{num_fragment - already_fragment}{dd_message}"
     else:
-        status_message = f"Merged: `{src_id} - {tgt_id}` | {already_fragment}+{num_fragment-already_fragment}{dd_message}"
+        status_message = f"Merged: `{src_id} - {tgt_id}` | {already_fragment}+{num_fragment - already_fragment}{dd_message}"
 
     logger.info(status_message)
     if pipeline_status is not None and pipeline_status_lock is not None:
@@ -2285,7 +2285,7 @@ async def _get_vector_context(
         return []
 
 
-async def _build_query_context(
+async def _perform_kg_search(
     query: str,
     ll_keywords: str,
    hl_keywords: str,
@@ -2295,25 +2295,21 @@ async def _build_query_context(
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
     chunks_vdb: BaseVectorStorage = None,
-):
-    if not query:
-        logger.warning("Query is empty, skipping context building")
-        return ""
+) -> dict[str, Any]:
+    """
+    Pure search logic that retrieves raw entities, relations, and vector chunks.
+    No token truncation or formatting - just raw search results.
+    """
 
     logger.info(f"Process {os.getpid()} building query context...")
 
-    # Collect chunks from different sources separately
-    vector_chunks = []
-    entity_chunks = []
-    relation_chunks = []
-    entities_context = []
-    relations_context = []
-
-    # Store original data for later text chunk retrieval
+    # Initialize result containers
     local_entities = []
     local_relations = []
     global_entities = []
     global_relations = []
+    vector_chunks = []
+    chunk_tracking = {}
 
     # Handle different query modes
 
-    # Track chunk sources and metadata for final logging
-    chunk_tracking = {}  # chunk_id -> {source, frequency, order}
@@ -2369,7 +2365,7 @@ async def _build_query_context(
         query_param,
     )
 
-    # Get vector chunks first if in mix mode
+    # Get vector chunks for mix mode
     if query_param.mode == "mix" and chunks_vdb:
         vector_chunks = await _get_vector_context(
             query,
@@ -2389,11 +2385,9 @@
             else:
                 logger.warning(f"Vector chunk missing chunk_id: {chunk}")
 
-    # Use round-robin merge to combine local and global data fairly
+    # Round-robin merge entities
     final_entities = []
     seen_entities = set()
-
-    # Round-robin merge entities
     max_len = max(len(local_entities), len(global_entities))
     for i in range(max_len):
         # First from local
@@ -2415,7 +2409,6 @@
     # Round-robin merge relations
     final_relations = []
     seen_relations = set()
-
     max_len = max(len(local_relations), len(global_relations))
     for i in range(max_len):
         # First from local
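The two hunks above interleave local (low-level keyword) and global (high-level keyword) results rank by rank, deduplicating on first occurrence. A standalone sketch of the pattern; the names are illustrative, and the commit inlines this logic rather than calling a helper:

```python
# Standalone sketch of the round-robin merge used above for entities and
# relations: interleave two ranked lists, keep the first occurrence of each key.
def round_robin_merge(local: list[dict], global_: list[dict], key: str) -> list[dict]:
    merged: list[dict] = []
    seen: set = set()
    for i in range(max(len(local), len(global_))):
        for source in (local, global_):  # local first at each rank
            if i < len(source):
                k = source[i].get(key)
                if k is not None and k not in seen:
                    merged.append(source[i])
                    seen.add(k)
    return merged


# e.g. round_robin_merge(local_entities, global_entities, key="entity_name")
```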
@@ -2448,91 +2441,107 @@
                 final_relations.append(relation)
                 seen_relations.add(rel_key)
 
-    # Generate entities context
+    logger.info(
+        f"Raw search results: {len(final_entities)} entities, {len(final_relations)} relations, {len(vector_chunks)} vector chunks"
+    )
+
+    return {
+        "final_entities": final_entities,
+        "final_relations": final_relations,
+        "vector_chunks": vector_chunks,
+        "chunk_tracking": chunk_tracking,
+        "query_embedding": query_embedding,
+    }
+
+
+async def _apply_token_truncation(
+    search_result: dict[str, Any],
+    query_param: QueryParam,
+    global_config: dict[str, str],
+) -> dict[str, Any]:
+    """
+    Apply token-based truncation to entities and relations for LLM efficiency.
+    This function is only used by kg_query, not kg_search.
+    """
+    tokenizer = global_config.get("tokenizer")
+    if not tokenizer:
+        logger.warning("No tokenizer found, skipping truncation")
+        return {
+            "truncated_entities": search_result["final_entities"],
+            "truncated_relations": search_result["final_relations"],
+            "entities_context": [],
+            "relations_context": [],
+            "filtered_entities": search_result["final_entities"],
+            "filtered_relations": search_result["final_relations"],
+        }
+
+    # Get token limits from query_param with fallbacks
+    max_entity_tokens = getattr(
+        query_param,
+        "max_entity_tokens",
+        global_config.get("max_entity_tokens", DEFAULT_MAX_ENTITY_TOKENS),
+    )
+    max_relation_tokens = getattr(
+        query_param,
+        "max_relation_tokens",
+        global_config.get("max_relation_tokens", DEFAULT_MAX_RELATION_TOKENS),
+    )
+
+    final_entities = search_result["final_entities"]
+    final_relations = search_result["final_relations"]
+
+    # Generate entities context for truncation
     entities_context = []
-    for i, n in enumerate(final_entities):
-        created_at = n.get("created_at", "UNKNOWN")
+    for i, entity in enumerate(final_entities):
+        created_at = entity.get("created_at", "UNKNOWN")
         if isinstance(created_at, (int, float)):
            created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
 
-        # Get file path from node data
-        file_path = n.get("file_path", "unknown_source")
-
         entities_context.append(
             {
                 "id": i + 1,
-                "entity": n["entity_name"],
-                "type": n.get("entity_type", "UNKNOWN"),
-                "description": n.get("description", "UNKNOWN"),
+                "entity": entity["entity_name"],
+                "type": entity.get("entity_type", "UNKNOWN"),
+                "description": entity.get("description", "UNKNOWN"),
                 "created_at": created_at,
-                "file_path": file_path,
+                "file_path": entity.get("file_path", "unknown_source"),
             }
         )
 
-    # Generate relations context
+    # Generate relations context for truncation
     relations_context = []
-    for i, e in enumerate(final_relations):
-        created_at = e.get("created_at", "UNKNOWN")
-        # Convert timestamp to readable format
+    for i, relation in enumerate(final_relations):
+        created_at = relation.get("created_at", "UNKNOWN")
         if isinstance(created_at, (int, float)):
            created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
 
-        # Get file path from edge data
-        file_path = e.get("file_path", "unknown_source")
-
         # Handle different relation data formats
-        if "src_tgt" in e:
-            entity1, entity2 = e["src_tgt"]
+        if "src_tgt" in relation:
+            entity1, entity2 = relation["src_tgt"]
         else:
-            entity1, entity2 = e.get("src_id"), e.get("tgt_id")
+            entity1, entity2 = relation.get("src_id"), relation.get("tgt_id")
 
         relations_context.append(
             {
                 "id": i + 1,
                 "entity1": entity1,
                 "entity2": entity2,
-                "description": e.get("description", "UNKNOWN"),
+                "description": relation.get("description", "UNKNOWN"),
                 "created_at": created_at,
-                "file_path": file_path,
+                "file_path": relation.get("file_path", "unknown_source"),
             }
         )
 
     logger.debug(
-        f"Initial KG query results: {len(entities_context)} entities, {len(relations_context)} relations"
+        f"Before truncation: {len(entities_context)} entities, {len(relations_context)} relations"
     )
 
-    # Unified token control system - Apply precise token limits to entities and relations
-    tokenizer = text_chunks_db.global_config.get("tokenizer")
-    # Get new token limits from query_param (with fallback to global_config)
-    max_entity_tokens = getattr(
-        query_param,
-        "max_entity_tokens",
-        text_chunks_db.global_config.get(
-            "max_entity_tokens", DEFAULT_MAX_ENTITY_TOKENS
-        ),
-    )
-    max_relation_tokens = getattr(
-        query_param,
-        "max_relation_tokens",
-        text_chunks_db.global_config.get(
-            "max_relation_tokens", DEFAULT_MAX_RELATION_TOKENS
-        ),
-    )
-    max_total_tokens = getattr(
-        query_param,
-        "max_total_tokens",
-        text_chunks_db.global_config.get("max_total_tokens", DEFAULT_MAX_TOTAL_TOKENS),
-    )
-
-    # Truncate entities based on complete JSON serialization
+    # Apply token-based truncation
     if entities_context:
-        # Process entities context to replace GRAPH_FIELD_SEP with : in file_path fields
+        # Remove file_path and created_at for token calculation
        for entity in entities_context:
-            # remove file_path and created_at
             entity.pop("file_path", None)
             entity.pop("created_at", None)
-            # if "file_path" in entity and entity["file_path"]:
-            #     entity["file_path"] = entity["file_path"].replace(GRAPH_FIELD_SEP, ";")
 
         entities_context = truncate_list_by_token_size(
             entities_context,
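Both truncation calls rely on `truncate_list_by_token_size`, which keeps the longest prefix of a ranked list that fits within a token budget. A minimal sketch of that behavior, assuming a tokenizer with an `encode` method; lightrag's actual helper may differ in signature and edge-case handling (its full argument list is cut off by the hunk boundary):

```python
# Minimal sketch of token-budget truncation as used above; an illustration,
# not lightrag's implementation of truncate_list_by_token_size.
import json


def truncate_by_tokens_sketch(items: list[dict], max_tokens: int, tokenizer) -> list[dict]:
    kept: list[dict] = []
    total = 0
    for item in items:
        # Budget against the JSON serialization actually sent to the LLM
        total += len(tokenizer.encode(json.dumps(item, ensure_ascii=False)))
        if total > max_tokens:
            break  # drop this item and everything ranked below it
        kept.append(item)
    return kept
```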
@@ -2541,17 +2550,11 @@
             tokenizer=tokenizer,
         )
 
-    # Truncate relations based on complete JSON serialization
     if relations_context:
-        # Process relations context to replace GRAPH_FIELD_SEP with : in file_path fields
+        # Remove file_path and created_at for token calculation
        for relation in relations_context:
-            # remove file_path and created_at
             relation.pop("file_path", None)
             relation.pop("created_at", None)
-            # if "file_path" in relation and relation["file_path"]:
-            #     relation["file_path"] = relation["file_path"].replace(
-            #         GRAPH_FIELD_SEP, ";"
-            #     )
 
         relations_context = truncate_list_by_token_size(
             relations_context,
@@ -2560,41 +2563,86 @@
             tokenizer=tokenizer,
         )
 
-    # After truncation, get text chunks based on final entities and relations
     logger.info(
-        f"Truncated KG query results: {len(entities_context)} entities, {len(relations_context)} relations"
+        f"After truncation: {len(entities_context)} entities, {len(relations_context)} relations"
     )
 
-    # Create filtered data based on truncated context
-    final_node_datas = []
-    if entities_context and final_entities:
+    # Create filtered original data based on truncated context
+    filtered_entities = []
+    if entities_context:
         final_entity_names = {e["entity"] for e in entities_context}
         seen_nodes = set()
-        for node in final_entities:
-            name = node.get("entity_name")
+        for entity in final_entities:
+            name = entity.get("entity_name")
             if name in final_entity_names and name not in seen_nodes:
-                final_node_datas.append(node)
+                filtered_entities.append(entity)
                 seen_nodes.add(name)
 
-    final_edge_datas = []
-    if relations_context and final_relations:
+    filtered_relations = []
+    if relations_context:
         final_relation_pairs = {(r["entity1"], r["entity2"]) for r in relations_context}
         seen_edges = set()
-        for edge in final_relations:
-            src, tgt = edge.get("src_id"), edge.get("tgt_id")
+        for relation in final_relations:
+            src, tgt = relation.get("src_id"), relation.get("tgt_id")
             if src is None or tgt is None:
-                src, tgt = edge.get("src_tgt", (None, None))
+                src, tgt = relation.get("src_tgt", (None, None))
 
             pair = (src, tgt)
             if pair in final_relation_pairs and pair not in seen_edges:
-                final_edge_datas.append(edge)
+                filtered_relations.append(relation)
                 seen_edges.add(pair)
 
-    # Get text chunks based on final filtered data
-    # To preserve the influence of entity order, entiy-based chunks should not be deduplcicated by vector_chunks
-    if final_node_datas:
+    return {
+        "truncated_entities": final_entities,  # Keep original for backward compatibility
+        "truncated_relations": final_relations,  # Keep original for backward compatibility
+        "entities_context": entities_context,  # Formatted and truncated for LLM
+        "relations_context": relations_context,  # Formatted and truncated for LLM
+        "filtered_entities": filtered_entities,  # Original entities that passed truncation
+        "filtered_relations": filtered_relations,  # Original relations that passed truncation
+    }
+
+
+async def _merge_all_chunks(
+    search_result: dict[str, Any],
+    filtered_entities: list[dict] = None,
+    filtered_relations: list[dict] = None,
+    query: str = "",
+    knowledge_graph_inst: BaseGraphStorage = None,
+    text_chunks_db: BaseKVStorage = None,
+    query_param: QueryParam = None,
+    chunks_vdb: BaseVectorStorage = None,
+    chunk_tracking: dict = None,
+) -> list[dict]:
+    """
+    Merge chunks from different sources: vector_chunks + entity_chunks + relation_chunks.
+
+    For kg_search: uses all original entities/relations
+    For kg_query: uses filtered entities/relations based on token truncation
+    """
+    if chunk_tracking is None:
+        chunk_tracking = search_result.get("chunk_tracking", {})
+
+    # Use filtered entities/relations if provided (kg_query), otherwise use all (kg_search)
+    entities_to_use = (
+        filtered_entities
+        if filtered_entities is not None
+        else search_result["final_entities"]
+    )
+    relations_to_use = (
+        filtered_relations
+        if filtered_relations is not None
+        else search_result["final_relations"]
+    )
+    vector_chunks = search_result["vector_chunks"]
+
+    # Get chunks from entities
+    entity_chunks = []
+    if entities_to_use and text_chunks_db:
+        # Pre-compute query embedding if needed
+        query_embedding = search_result.get("query_embedding", None)
+
         entity_chunks = await _find_related_text_unit_from_entities(
-            final_node_datas,
+            entities_to_use,
             query_param,
             text_chunks_db,
             knowledge_graph_inst,
@@ -2604,21 +2652,21 @@
             query_embedding=query_embedding,
         )
 
-    # Find deduplcicated chunks from edge
-    # Deduplication cause chunks solely relation-based to be prioritized and sent to the LLM when re-ranking is disabled
-    if final_edge_datas:
+    # Get chunks from relations
+    relation_chunks = []
+    if relations_to_use and text_chunks_db:
         relation_chunks = await _find_related_text_unit_from_relations(
-            final_edge_datas,
+            relations_to_use,
             query_param,
             text_chunks_db,
-            entity_chunks,
+            entity_chunks,  # For deduplication
             query,
             chunks_vdb,
             chunk_tracking=chunk_tracking,
             query_embedding=query_embedding,
         )
 
-    # Round-robin merge chunks from different sources with deduplication by chunk_id
+    # Round-robin merge chunks from different sources with deduplication
     merged_chunks = []
     seen_chunk_ids = set()
     max_len = max(len(vector_chunks), len(entity_chunks), len(relation_chunks))
@@ -2668,12 +2716,238 @@
         )
 
     logger.info(
-        f"Round-robin merged total chunks from {origin_len} to {len(merged_chunks)}"
+        f"Round-robin merged chunks: {origin_len} -> {len(merged_chunks)} (deduplication: {origin_len - len(merged_chunks)})"
     )
 
+    return merged_chunks
+
+
+async def kg_search(
+    query: str,
+    knowledge_graph_inst: BaseGraphStorage,
+    entities_vdb: BaseVectorStorage,
+    relationships_vdb: BaseVectorStorage,
+    text_chunks_db: BaseKVStorage,
+    query_param: QueryParam,
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
+    chunks_vdb: BaseVectorStorage = None,
+) -> dict[str, Any]:
+    """Search knowledge graph and return structured results without LLM generation
+
+    For kg_search: Search + Merge chunks (NO truncation)
+    Returns complete search results for user analysis
+    """
+    if not query:
+        return {
+            "entities": [],
+            "relationships": [],
+            "chunks": [],
+            "metadata": {
+                "query_mode": query_param.mode,
+                "keywords": {"high_level": [], "low_level": []},
+            },
+        }
+
+    # Handle cache (reuse existing cache logic but for search results)
+    args_hash = compute_args_hash(
+        query_param.mode,
+        query,
+        "search",  # Different cache key for search vs query
+        query_param.top_k,
+        query_param.chunk_top_k,
+        query_param.max_entity_tokens,
+        query_param.max_relation_tokens,
+        query_param.max_total_tokens,
+        query_param.hl_keywords or [],
+        query_param.ll_keywords or [],
+        query_param.user_prompt or "",
+        query_param.enable_rerank,
+    )
+    cached_response = await handle_cache(
+        hashing_kv, args_hash, query, query_param.mode, cache_type="search"
+    )
+    if cached_response is not None:
+        try:
+            return json_repair.loads(cached_response)
+        except (json.JSONDecodeError, KeyError):
+            logger.warning(
+                "Invalid cache format for search results, proceeding with fresh search"
+            )
+
+    # Get keywords (reuse existing logic)
+    hl_keywords, ll_keywords = await get_keywords_from_query(
+        query, query_param, global_config, hashing_kv
+    )
+
+    logger.debug(f"High-level keywords: {hl_keywords}")
+    logger.debug(f"Low-level keywords: {ll_keywords}")
+
+    # Handle empty keywords (reuse existing logic)
+    if ll_keywords == [] and query_param.mode in ["local", "hybrid", "mix"]:
+        logger.warning("low_level_keywords is empty")
+    if hl_keywords == [] and query_param.mode in ["global", "hybrid", "mix"]:
+        logger.warning("high_level_keywords is empty")
+    if hl_keywords == [] and ll_keywords == []:
+        if len(query) < 50:
+            logger.warning(f"Forced low_level_keywords to origin query: {query}")
+            ll_keywords = [query]
+        else:
+            return {
+                "entities": [],
+                "relationships": [],
+                "chunks": [],
+                "metadata": {
+                    "query_mode": query_param.mode,
+                    "keywords": {"high_level": hl_keywords, "low_level": ll_keywords},
+                    "error": "Keywords extraction failed",
+                },
+            }
+
+    ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
+    hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""
+
+    # Stage 1: Pure search (no truncation for kg_search)
+    search_result = await _perform_kg_search(
+        query,
+        ll_keywords_str,
+        hl_keywords_str,
+        knowledge_graph_inst,
+        entities_vdb,
+        relationships_vdb,
+        query_param,
+        chunks_vdb,
+    )
+
+    if not search_result["final_entities"] and not search_result["final_relations"]:
+        return {
+            "entities": [],
+            "relationships": [],
+            "chunks": [],
+            "metadata": {
+                "query_mode": query_param.mode,
+                "keywords": {"high_level": hl_keywords, "low_level": ll_keywords},
+                "error": "No valid results found",
+            },
+        }
+
+    # Stage 2: Merge ALL chunks (no filtering, use all entities/relations)
+    merged_chunks = await _merge_all_chunks(
+        search_result,
+        filtered_entities=None,  # Use ALL entities (no filtering)
+        filtered_relations=None,  # Use ALL relations (no filtering)
+        query=query,
+        knowledge_graph_inst=knowledge_graph_inst,
+        text_chunks_db=text_chunks_db,
+        query_param=query_param,
+        chunks_vdb=chunks_vdb,
+        chunk_tracking=search_result["chunk_tracking"],
+    )
+
+    # Build final structured result
+    final_result = {
+        "entities": search_result["final_entities"],
+        "relationships": search_result["final_relations"],
+        "chunks": merged_chunks,
+        "metadata": {
+            "query_mode": query_param.mode,
+            "keywords": {"high_level": hl_keywords, "low_level": ll_keywords},
+        },
+    }
+
+    # Cache the results
+    if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
+        queryparam_dict = {
+            "mode": query_param.mode,
+            "response_type": "search",
+            "top_k": query_param.top_k,
+            "chunk_top_k": query_param.chunk_top_k,
+            "max_entity_tokens": query_param.max_entity_tokens,
+            "max_relation_tokens": query_param.max_relation_tokens,
+            "max_total_tokens": query_param.max_total_tokens,
+            "hl_keywords": query_param.hl_keywords or [],
+            "ll_keywords": query_param.ll_keywords or [],
+            "user_prompt": query_param.user_prompt or "",
+            "enable_rerank": query_param.enable_rerank,
+        }
+        await save_to_cache(
+            hashing_kv,
+            CacheData(
+                args_hash=args_hash,
+                content=json.dumps(final_result, ensure_ascii=False),
+                prompt=query,
+                mode=query_param.mode,
+                cache_type="search",
+                queryparam=queryparam_dict,
+            ),
+        )
+
+    return final_result
+
+
+async def _build_llm_context(
+    entities_context: list[dict],
+    relations_context: list[dict],
+    merged_chunks: list[dict],
+    query: str,
+    query_param: QueryParam,
+    global_config: dict[str, str],
+    chunk_tracking: dict = None,
+) -> str:
+    """
+    Build the final LLM context string with token processing.
+    This includes dynamic token calculation and final chunk truncation.
+    """
+    tokenizer = global_config.get("tokenizer")
+    if not tokenizer:
+        logger.warning("No tokenizer found, building context without token limits")
+
+        # Build basic context without token processing
+        entities_str = json.dumps(entities_context, ensure_ascii=False)
+        relations_str = json.dumps(relations_context, ensure_ascii=False)
+
+        text_units_context = []
+        for i, chunk in enumerate(merged_chunks):
+            text_units_context.append(
+                {
+                    "id": i + 1,
+                    "content": chunk["content"],
+                    "file_path": chunk.get("file_path", "unknown_source"),
+                }
+            )
+
+        text_units_str = json.dumps(text_units_context, ensure_ascii=False)
+
+        return f"""-----Entities(KG)-----
+
+```json
+{entities_str}
+```
+
+-----Relationships(KG)-----
+
+```json
+{relations_str}
+```
+
+-----Document Chunks(DC)-----
+
+```json
+{text_units_str}
+```
+
+"""
+
+    # Get token limits
+    max_total_tokens = getattr(
+        query_param,
+        "max_total_tokens",
+        global_config.get("max_total_tokens", DEFAULT_MAX_TOTAL_TOKENS),
+    )
+
+    # Apply token processing to merged chunks
+    text_units_context = []
+    truncated_chunks = []
+
     if merged_chunks:
         # Calculate dynamic token limit for text chunks
         entities_str = json.dumps(entities_context, ensure_ascii=False)
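For reference, the structured payload `kg_search` returns (and caches) has the shape below; the key names are taken from the code above, while the example values are invented purely for illustration:

```python
# Shape of kg_search's return value as assembled above; keys come from the
# diff, the example values are illustrative only.
example_result = {
    "entities": [
        {"entity_name": "LightRAG", "entity_type": "UNKNOWN",
         "description": "...", "file_path": "docs/readme.md"},
    ],
    "relationships": [
        {"src_id": "LightRAG", "tgt_id": "GraphRAG",
         "description": "...", "file_path": "docs/readme.md"},
    ],
    "chunks": [
        {"content": "...", "file_path": "docs/readme.md"},
    ],
    "metadata": {
        "query_mode": "mix",
        "keywords": {"high_level": ["architecture"], "low_level": ["entity merge"]},
    },
}
```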
@@ -2704,18 +2978,7 @@
         )
         kg_context_tokens = len(tokenizer.encode(kg_context))
 
-        # Calculate actual system prompt overhead dynamically
-        # 1. Converstion history not included in context length calculation
-        history_context = ""
-        # if query_param.conversation_history:
-        #     history_context = get_conversation_turns(
-        #         query_param.conversation_history, query_param.history_turns
-        #     )
-        # history_tokens = (
-        #     len(tokenizer.encode(history_context)) if history_context else 0
-        # )
-
-        # 2. Calculate system prompt template tokens (excluding context_data)
+        # Calculate system prompt template overhead
         user_prompt = query_param.user_prompt if query_param.user_prompt else ""
         response_type = (
             query_param.response_type
@@ -2723,14 +2986,14 @@
             else "Multiple Paragraphs"
         )
 
-        # Get the system prompt template from PROMPTS
-        sys_prompt_template = text_chunks_db.global_config.get(
+        # Get the system prompt template from PROMPTS or global_config
+        sys_prompt_template = global_config.get(
             "system_prompt_template", PROMPTS["rag_response"]
         )
 
-        # Create a sample system prompt with placeholders filled (excluding context_data)
+        # Create sample system prompt for overhead calculation
         sample_sys_prompt = sys_prompt_template.format(
-            history=history_context,
+            history="",  # History not included in context length calculation
             context_data="",  # Empty for overhead calculation
             response_type=response_type,
             user_prompt=user_prompt,
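The `sample_sys_prompt` built above is encoded to measure template overhead, which feeds the dynamic chunk budget a few lines below this hunk. A hedged sketch of that arithmetic; the exact formula, and any safety margin, sits outside the hunks shown here:

```python
# Hedged sketch of the dynamic chunk budget implied by the surrounding code;
# the real computation (including any reserved buffer) is outside this diff.
def available_chunk_budget(max_total_tokens: int,
                           kg_context_tokens: int,
                           sys_prompt_overhead: int) -> int:
    # Chunks receive whatever remains after the KG context and the
    # system prompt template have claimed their share of the budget.
    return max(0, max_total_tokens - kg_context_tokens - sys_prompt_overhead)
```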
@@ -2756,7 +3019,7 @@
             query=query,
             unique_chunks=merged_chunks,
             query_param=query_param,
-            global_config=text_chunks_db.global_config,
+            global_config=global_config,
             source_type=query_param.mode,
             chunk_token_limit=available_chunk_tokens,  # Pass dynamic limit
         )
@@ -2827,6 +3090,76 @@
     return result
 
 
+# Now let's update the old _build_query_context to use the new architecture
+async def _build_query_context(
+    query: str,
+    ll_keywords: str,
+    hl_keywords: str,
+    knowledge_graph_inst: BaseGraphStorage,
+    entities_vdb: BaseVectorStorage,
+    relationships_vdb: BaseVectorStorage,
+    text_chunks_db: BaseKVStorage,
+    query_param: QueryParam,
+    chunks_vdb: BaseVectorStorage = None,
+) -> str:
+    """
+    Main query context building function using the new 4-stage architecture:
+    1. Search -> 2. Truncate -> 3. Merge chunks -> 4. Build LLM context
+    """
+
+    if not query:
+        logger.warning("Query is empty, skipping context building")
+        return ""
+
+    # Stage 1: Pure search
+    search_result = await _perform_kg_search(
+        query,
+        ll_keywords,
+        hl_keywords,
+        knowledge_graph_inst,
+        entities_vdb,
+        relationships_vdb,
+        query_param,
+        chunks_vdb,
+    )
+
+    if not search_result["final_entities"] and not search_result["final_relations"]:
+        return None
+
+    # Stage 2: Apply token truncation for LLM efficiency
+    truncation_result = await _apply_token_truncation(
+        search_result,
+        query_param,
+        text_chunks_db.global_config,
+    )
+
+    # Stage 3: Merge chunks using filtered entities/relations
+    merged_chunks = await _merge_all_chunks(
+        search_result,
+        filtered_entities=truncation_result["filtered_entities"],
+        filtered_relations=truncation_result["filtered_relations"],
+        query=query,
+        knowledge_graph_inst=knowledge_graph_inst,
+        text_chunks_db=text_chunks_db,
+        query_param=query_param,
+        chunks_vdb=chunks_vdb,
+        chunk_tracking=search_result["chunk_tracking"],
+    )
+
+    # Stage 4: Build final LLM context with dynamic token processing
+    context = await _build_llm_context(
+        entities_context=truncation_result["entities_context"],
+        relations_context=truncation_result["relations_context"],
+        merged_chunks=merged_chunks,
+        query=query,
+        query_param=query_param,
+        global_config=text_chunks_db.global_config,
+        chunk_tracking=search_result["chunk_tracking"],
+    )
+
+    return context
+
+
 async def _get_node_data(
     query: str,
     knowledge_graph_inst: BaseGraphStorage,
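Taken together, the commit leaves two entry points over shared stages; a comment-only schematic of the flow, not literal code from the diff:

```python
# Schematic of the refactored pipeline this commit establishes:
#
# kg_search (new, raw results):
#     _perform_kg_search -> _merge_all_chunks(all entities/relations)
#         -> {"entities", "relationships", "chunks", "metadata"}
#
# kg_query / _build_query_context (LLM path):
#     _perform_kg_search -> _apply_token_truncation
#         -> _merge_all_chunks(filtered) -> _build_llm_context -> prompt string
```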