Update operate.py

This commit is contained in:
zrguo 2025-07-15 18:57:57 +08:00
parent 93b25a65d5
commit 29e82723e6

View file

@ -1892,26 +1892,42 @@ async def _build_query_context(
entities_context = [] entities_context = []
relations_context = [] relations_context = []
# Store original data for later text chunk retrieval
original_node_datas = []
original_edge_datas = []
# Handle local and global modes # Handle local and global modes
if query_param.mode == "local": if query_param.mode == "local":
entities_context, relations_context, entity_chunks = await _get_node_data( (
entities_context,
relations_context,
node_datas,
use_relations,
) = await _get_node_data(
ll_keywords, ll_keywords,
knowledge_graph_inst, knowledge_graph_inst,
entities_vdb, entities_vdb,
text_chunks_db, text_chunks_db,
query_param, query_param,
) )
all_chunks.extend(entity_chunks) original_node_datas = node_datas
original_edge_datas = use_relations
elif query_param.mode == "global": elif query_param.mode == "global":
entities_context, relations_context, relationship_chunks = await _get_edge_data( (
entities_context,
relations_context,
edge_datas,
use_entities,
) = await _get_edge_data(
hl_keywords, hl_keywords,
knowledge_graph_inst, knowledge_graph_inst,
relationships_vdb, relationships_vdb,
text_chunks_db, text_chunks_db,
query_param, query_param,
) )
all_chunks.extend(relationship_chunks) original_edge_datas = edge_datas
original_node_datas = use_entities
else: # hybrid or mix mode else: # hybrid or mix mode
ll_data = await _get_node_data( ll_data = await _get_node_data(
@ -1929,10 +1945,13 @@ async def _build_query_context(
query_param, query_param,
) )
(ll_entities_context, ll_relations_context, ll_chunks) = ll_data (ll_entities_context, ll_relations_context, ll_node_datas, ll_edge_datas) = (
(hl_entities_context, hl_relations_context, hl_chunks) = hl_data ll_data
)
(hl_entities_context, hl_relations_context, hl_edge_datas, hl_node_datas) = (
hl_data
)
# Collect chunks from entity and relationship sources
# Get vector chunks first if in mix mode # Get vector chunks first if in mix mode
if query_param.mode == "mix" and chunks_vdb: if query_param.mode == "mix" and chunks_vdb:
vector_chunks = await _get_vector_context( vector_chunks = await _get_vector_context(
@ -1942,8 +1961,9 @@ async def _build_query_context(
) )
all_chunks.extend(vector_chunks) all_chunks.extend(vector_chunks)
all_chunks.extend(ll_chunks) # Store original data from both sources
all_chunks.extend(hl_chunks) original_node_datas = ll_node_datas + hl_node_datas
original_edge_datas = ll_edge_datas + hl_edge_datas
# Combine entities and relations contexts # Combine entities and relations contexts
entities_context = process_combine_contexts( entities_context = process_combine_contexts(
@ -2027,6 +2047,73 @@ async def _build_query_context(
f"Truncated relations: {original_relation_count} -> {len(relations_context)} (relation max tokens: {max_relation_tokens})" f"Truncated relations: {original_relation_count} -> {len(relations_context)} (relation max tokens: {max_relation_tokens})"
) )
# After truncation, get text chunks based on final entities and relations
logger.info("Getting text chunks based on truncated entities and relations...")
# Create filtered data based on truncated context
final_node_datas = []
final_edge_datas = []
if entities_context and original_node_datas:
# Create a set of entity names from final truncated context
final_entity_names = {entity["entity"] for entity in entities_context}
# Filter original node data based on final entities
final_node_datas = [
node
for node in original_node_datas
if node.get("entity_name") in final_entity_names
]
if relations_context and original_edge_datas:
# Create a set of relation pairs from final truncated context
final_relation_pairs = {
(rel["entity1"], rel["entity2"]) for rel in relations_context
}
# Filter original edge data based on final relations
final_edge_datas = [
edge
for edge in original_edge_datas
if (edge.get("src_id"), edge.get("tgt_id")) in final_relation_pairs
or (
edge.get("src_tgt", (None, None))[0],
edge.get("src_tgt", (None, None))[1],
)
in final_relation_pairs
]
# Get text chunks based on final filtered data
text_chunk_tasks = []
if final_node_datas:
text_chunk_tasks.append(
_find_most_related_text_unit_from_entities(
final_node_datas,
query_param,
text_chunks_db,
knowledge_graph_inst,
)
)
if final_edge_datas:
text_chunk_tasks.append(
_find_related_text_unit_from_relationships(
final_edge_datas,
query_param,
text_chunks_db,
knowledge_graph_inst,
)
)
# Execute text chunk retrieval in parallel
if text_chunk_tasks:
text_chunk_results = await asyncio.gather(*text_chunk_tasks)
for chunks in text_chunk_results:
if chunks:
all_chunks.extend(chunks)
# Apply token processing to chunks if tokenizer is available
text_units_context = []
if tokenizer and all_chunks:
# Calculate dynamic token limit for text chunks # Calculate dynamic token limit for text chunks
entities_str = json.dumps(entities_context, ensure_ascii=False) entities_str = json.dumps(entities_context, ensure_ascii=False)
relations_str = json.dumps(relations_context, ensure_ascii=False) relations_str = json.dumps(relations_context, ensure_ascii=False)
@ -2122,7 +2209,6 @@ async def _build_query_context(
) )
# Rebuild text_units_context with truncated chunks # Rebuild text_units_context with truncated chunks
text_units_context = []
for i, chunk in enumerate(truncated_chunks): for i, chunk in enumerate(truncated_chunks):
text_units_context.append( text_units_context.append(
{ {
@ -2187,7 +2273,7 @@ async def _get_node_data(
) )
if not len(results): if not len(results):
return "", "", "" return "", "", [], []
# Extract all entity IDs from your results list # Extract all entity IDs from your results list
node_ids = [r["entity_name"] for r in results] node_ids = [r["entity_name"] for r in results]
@ -2214,14 +2300,8 @@ async def _get_node_data(
} }
for k, n, d in zip(results, node_datas, node_degrees) for k, n, d in zip(results, node_datas, node_degrees)
if n is not None if n is not None
] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram. ]
# get entitytext chunk
use_text_units = await _find_most_related_text_unit_from_entities(
node_datas,
query_param,
text_chunks_db,
knowledge_graph_inst,
)
use_relations = await _find_most_related_edges_from_entities( use_relations = await _find_most_related_edges_from_entities(
node_datas, node_datas,
query_param, query_param,
@ -2229,7 +2309,7 @@ async def _get_node_data(
) )
logger.info( logger.info(
f"Local query: {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} chunks" f"Local query: {len(node_datas)} entites, {len(use_relations)} relations"
) )
# build prompt # build prompt
@ -2278,7 +2358,7 @@ async def _get_node_data(
} }
) )
return entities_context, relations_context, use_text_units return entities_context, relations_context, node_datas, use_relations
async def _find_most_related_text_unit_from_entities( async def _find_most_related_text_unit_from_entities(
@ -2456,7 +2536,7 @@ async def _get_edge_data(
) )
if not len(results): if not len(results):
return "", "", "" return "", "", [], []
# Prepare edge pairs in two forms: # Prepare edge pairs in two forms:
# For the batch edge properties function, use dicts. # For the batch edge properties function, use dicts.
@ -2495,21 +2575,15 @@ async def _get_edge_data(
edge_datas = sorted( edge_datas = sorted(
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
) )
use_entities, use_text_units = await asyncio.gather(
_find_most_related_entities_from_relationships( use_entities = await _find_most_related_entities_from_relationships(
edge_datas, edge_datas,
query_param, query_param,
knowledge_graph_inst, knowledge_graph_inst,
),
_find_related_text_unit_from_relationships(
edge_datas,
query_param,
text_chunks_db,
knowledge_graph_inst,
),
) )
logger.info( logger.info(
f"Global query: {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} chunks" f"Global query: {len(use_entities)} entites, {len(edge_datas)} relations"
) )
relations_context = [] relations_context = []
@ -2558,16 +2632,8 @@ async def _get_edge_data(
} }
) )
text_units_context = [] # Return original data for later text chunk retrieval
for i, t in enumerate(use_text_units): return entities_context, relations_context, edge_datas, use_entities
text_units_context.append(
{
"id": i + 1,
"content": t["content"],
"file_path": t.get("file_path", "unknown"),
}
)
return entities_context, relations_context, text_units_context
async def _find_most_related_entities_from_relationships( async def _find_most_related_entities_from_relationships(