update chunks truncation method
This commit is contained in:
parent
f5c80d7cde
commit
04a57445da
5 changed files with 211 additions and 180 deletions
10
README-zh.md
10
README-zh.md
|
|
@ -294,6 +294,16 @@ class QueryParam:
|
||||||
top_k: int = int(os.getenv("TOP_K", "60"))
|
top_k: int = int(os.getenv("TOP_K", "60"))
|
||||||
"""Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
|
"""Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
|
||||||
|
|
||||||
|
chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5"))
|
||||||
|
"""Number of text chunks to retrieve initially from vector search.
|
||||||
|
If None, defaults to top_k value.
|
||||||
|
"""
|
||||||
|
|
||||||
|
chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5"))
|
||||||
|
"""Number of text chunks to keep after reranking.
|
||||||
|
If None, keeps all chunks returned from initial retrieval.
|
||||||
|
"""
|
||||||
|
|
||||||
max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))
|
max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))
|
||||||
"""Maximum number of tokens allowed for each retrieved text chunk."""
|
"""Maximum number of tokens allowed for each retrieved text chunk."""
|
||||||
|
|
||||||
|
|
|
||||||
12
README.md
12
README.md
|
|
@ -153,7 +153,7 @@ curl https://raw.githubusercontent.com/gusye1234/nano-graphrag/main/tests/mock_d
|
||||||
python examples/lightrag_openai_demo.py
|
python examples/lightrag_openai_demo.py
|
||||||
```
|
```
|
||||||
|
|
||||||
For a streaming response implementation example, please see `examples/lightrag_openai_compatible_demo.py`. Prior to execution, ensure you modify the sample code’s LLM and embedding configurations accordingly.
|
For a streaming response implementation example, please see `examples/lightrag_openai_compatible_demo.py`. Prior to execution, ensure you modify the sample code's LLM and embedding configurations accordingly.
|
||||||
|
|
||||||
**Note 1**: When running the demo program, please be aware that different test scripts may use different embedding models. If you switch to a different embedding model, you must clear the data directory (`./dickens`); otherwise, the program may encounter errors. If you wish to retain the LLM cache, you can preserve the `kv_store_llm_response_cache.json` file while clearing the data directory.
|
**Note 1**: When running the demo program, please be aware that different test scripts may use different embedding models. If you switch to a different embedding model, you must clear the data directory (`./dickens`); otherwise, the program may encounter errors. If you wish to retain the LLM cache, you can preserve the `kv_store_llm_response_cache.json` file while clearing the data directory.
|
||||||
|
|
||||||
|
|
@ -300,6 +300,16 @@ class QueryParam:
|
||||||
top_k: int = int(os.getenv("TOP_K", "60"))
|
top_k: int = int(os.getenv("TOP_K", "60"))
|
||||||
"""Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
|
"""Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
|
||||||
|
|
||||||
|
chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5"))
|
||||||
|
"""Number of text chunks to retrieve initially from vector search.
|
||||||
|
If None, defaults to top_k value.
|
||||||
|
"""
|
||||||
|
|
||||||
|
chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5"))
|
||||||
|
"""Number of text chunks to keep after reranking.
|
||||||
|
If None, keeps all chunks returned from initial retrieval.
|
||||||
|
"""
|
||||||
|
|
||||||
max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))
|
max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))
|
||||||
"""Maximum number of tokens allowed for each retrieved text chunk."""
|
"""Maximum number of tokens allowed for each retrieved text chunk."""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -46,7 +46,9 @@ OLLAMA_EMULATING_MODEL_TAG=latest
|
||||||
# HISTORY_TURNS=3
|
# HISTORY_TURNS=3
|
||||||
# COSINE_THRESHOLD=0.2
|
# COSINE_THRESHOLD=0.2
|
||||||
# TOP_K=60
|
# TOP_K=60
|
||||||
# MAX_TOKEN_TEXT_CHUNK=4000
|
# CHUNK_TOP_K=5
|
||||||
|
# CHUNK_RERANK_TOP_K=5
|
||||||
|
# MAX_TOKEN_TEXT_CHUNK=6000
|
||||||
# MAX_TOKEN_RELATION_DESC=4000
|
# MAX_TOKEN_RELATION_DESC=4000
|
||||||
# MAX_TOKEN_ENTITY_DESC=4000
|
# MAX_TOKEN_ENTITY_DESC=4000
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -60,7 +60,17 @@ class QueryParam:
|
||||||
top_k: int = int(os.getenv("TOP_K", "60"))
|
top_k: int = int(os.getenv("TOP_K", "60"))
|
||||||
"""Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
|
"""Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
|
||||||
|
|
||||||
max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))
|
chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5"))
|
||||||
|
"""Number of text chunks to retrieve initially from vector search.
|
||||||
|
If None, defaults to top_k value.
|
||||||
|
"""
|
||||||
|
|
||||||
|
chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5"))
|
||||||
|
"""Number of text chunks to keep after reranking.
|
||||||
|
If None, keeps all chunks returned from initial retrieval.
|
||||||
|
"""
|
||||||
|
|
||||||
|
max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "6000"))
|
||||||
"""Maximum number of tokens allowed for each retrieved text chunk."""
|
"""Maximum number of tokens allowed for each retrieved text chunk."""
|
||||||
|
|
||||||
max_token_for_global_context: int = int(
|
max_token_for_global_context: int = int(
|
||||||
|
|
@ -280,21 +290,6 @@ class BaseKVStorage(StorageNameSpace, ABC):
|
||||||
False: if the cache drop failed, or the cache mode is not supported
|
False: if the cache drop failed, or the cache mode is not supported
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# async def drop_cache_by_chunk_ids(self, chunk_ids: list[str] | None = None) -> bool:
|
|
||||||
# """Delete specific cache records from storage by chunk IDs
|
|
||||||
|
|
||||||
# Importance notes for in-memory storage:
|
|
||||||
# 1. Changes will be persisted to disk during the next index_done_callback
|
|
||||||
# 2. update flags to notify other processes that data persistence is needed
|
|
||||||
|
|
||||||
# Args:
|
|
||||||
# chunk_ids (list[str]): List of chunk IDs to be dropped from storage
|
|
||||||
|
|
||||||
# Returns:
|
|
||||||
# True: if the cache drop successfully
|
|
||||||
# False: if the cache drop failed, or the operation is not supported
|
|
||||||
# """
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseGraphStorage(StorageNameSpace, ABC):
|
class BaseGraphStorage(StorageNameSpace, ABC):
|
||||||
|
|
|
||||||
|
|
@ -1526,6 +1526,7 @@ async def kg_query(
|
||||||
|
|
||||||
# Build context
|
# Build context
|
||||||
context = await _build_query_context(
|
context = await _build_query_context(
|
||||||
|
query,
|
||||||
ll_keywords_str,
|
ll_keywords_str,
|
||||||
hl_keywords_str,
|
hl_keywords_str,
|
||||||
knowledge_graph_inst,
|
knowledge_graph_inst,
|
||||||
|
|
@ -1744,93 +1745,52 @@ async def _get_vector_context(
|
||||||
query: str,
|
query: str,
|
||||||
chunks_vdb: BaseVectorStorage,
|
chunks_vdb: BaseVectorStorage,
|
||||||
query_param: QueryParam,
|
query_param: QueryParam,
|
||||||
tokenizer: Tokenizer,
|
) -> list[dict]:
|
||||||
) -> tuple[list, list, list] | None:
|
|
||||||
"""
|
"""
|
||||||
Retrieve vector context from the vector database.
|
Retrieve text chunks from the vector database without reranking or truncation.
|
||||||
|
|
||||||
This function performs vector search to find relevant text chunks for a query,
|
This function performs vector search to find relevant text chunks for a query.
|
||||||
formats them with file path and creation time information.
|
Reranking and truncation will be handled later in the unified processing.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: The query string to search for
|
query: The query string to search for
|
||||||
chunks_vdb: Vector database containing document chunks
|
chunks_vdb: Vector database containing document chunks
|
||||||
query_param: Query parameters including top_k and ids
|
query_param: Query parameters including chunk_top_k and ids
|
||||||
tokenizer: Tokenizer for counting tokens
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple (empty_entities, empty_relations, text_units) for combine_contexts,
|
List of text chunks with metadata
|
||||||
compatible with _get_edge_data and _get_node_data format
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
results = await chunks_vdb.query(
|
# Use chunk_top_k if specified, otherwise fall back to top_k
|
||||||
query, top_k=query_param.top_k, ids=query_param.ids
|
search_top_k = query_param.chunk_top_k or query_param.top_k
|
||||||
)
|
|
||||||
|
results = await chunks_vdb.query(query, top_k=search_top_k, ids=query_param.ids)
|
||||||
if not results:
|
if not results:
|
||||||
return [], [], []
|
return []
|
||||||
|
|
||||||
valid_chunks = []
|
valid_chunks = []
|
||||||
for result in results:
|
for result in results:
|
||||||
if "content" in result:
|
if "content" in result:
|
||||||
# Directly use content from chunks_vdb.query result
|
chunk_with_metadata = {
|
||||||
chunk_with_time = {
|
|
||||||
"content": result["content"],
|
"content": result["content"],
|
||||||
"created_at": result.get("created_at", None),
|
"created_at": result.get("created_at", None),
|
||||||
"file_path": result.get("file_path", "unknown_source"),
|
"file_path": result.get("file_path", "unknown_source"),
|
||||||
|
"source_type": "vector", # Mark the source type
|
||||||
}
|
}
|
||||||
valid_chunks.append(chunk_with_time)
|
valid_chunks.append(chunk_with_metadata)
|
||||||
|
|
||||||
if not valid_chunks:
|
|
||||||
return [], [], []
|
|
||||||
|
|
||||||
# Apply reranking if enabled
|
|
||||||
global_config = chunks_vdb.global_config
|
|
||||||
valid_chunks = await apply_rerank_if_enabled(
|
|
||||||
query=query,
|
|
||||||
retrieved_docs=valid_chunks,
|
|
||||||
global_config=global_config,
|
|
||||||
top_k=query_param.top_k,
|
|
||||||
)
|
|
||||||
|
|
||||||
maybe_trun_chunks = truncate_list_by_token_size(
|
|
||||||
valid_chunks,
|
|
||||||
key=lambda x: x["content"],
|
|
||||||
max_token_size=query_param.max_token_for_text_unit,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Truncate chunks from {len(valid_chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
|
f"Vector search retrieved {len(valid_chunks)} chunks (top_k: {search_top_k})"
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"Query chunks: {len(maybe_trun_chunks)} chunks, top_k: {query_param.top_k}"
|
|
||||||
)
|
)
|
||||||
|
return valid_chunks
|
||||||
|
|
||||||
if not maybe_trun_chunks:
|
|
||||||
return [], [], []
|
|
||||||
|
|
||||||
# Create empty entities and relations contexts
|
|
||||||
entities_context = []
|
|
||||||
relations_context = []
|
|
||||||
|
|
||||||
# Create text_units_context directly as a list of dictionaries
|
|
||||||
text_units_context = []
|
|
||||||
for i, chunk in enumerate(maybe_trun_chunks):
|
|
||||||
text_units_context.append(
|
|
||||||
{
|
|
||||||
"id": i + 1,
|
|
||||||
"content": chunk["content"],
|
|
||||||
"file_path": chunk["file_path"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return entities_context, relations_context, text_units_context
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in _get_vector_context: {e}")
|
logger.error(f"Error in _get_vector_context: {e}")
|
||||||
return [], [], []
|
return []
|
||||||
|
|
||||||
|
|
||||||
async def _build_query_context(
|
async def _build_query_context(
|
||||||
|
query: str,
|
||||||
ll_keywords: str,
|
ll_keywords: str,
|
||||||
hl_keywords: str,
|
hl_keywords: str,
|
||||||
knowledge_graph_inst: BaseGraphStorage,
|
knowledge_graph_inst: BaseGraphStorage,
|
||||||
|
|
@ -1838,27 +1798,36 @@ async def _build_query_context(
|
||||||
relationships_vdb: BaseVectorStorage,
|
relationships_vdb: BaseVectorStorage,
|
||||||
text_chunks_db: BaseKVStorage,
|
text_chunks_db: BaseKVStorage,
|
||||||
query_param: QueryParam,
|
query_param: QueryParam,
|
||||||
chunks_vdb: BaseVectorStorage = None, # Add chunks_vdb parameter for mix mode
|
chunks_vdb: BaseVectorStorage = None,
|
||||||
):
|
):
|
||||||
logger.info(f"Process {os.getpid()} building query context...")
|
logger.info(f"Process {os.getpid()} building query context...")
|
||||||
|
|
||||||
# Handle local and global modes as before
|
# Collect all chunks from different sources
|
||||||
|
all_chunks = []
|
||||||
|
entities_context = []
|
||||||
|
relations_context = []
|
||||||
|
|
||||||
|
# Handle local and global modes
|
||||||
if query_param.mode == "local":
|
if query_param.mode == "local":
|
||||||
entities_context, relations_context, text_units_context = await _get_node_data(
|
entities_context, relations_context, entity_chunks = await _get_node_data(
|
||||||
ll_keywords,
|
ll_keywords,
|
||||||
knowledge_graph_inst,
|
knowledge_graph_inst,
|
||||||
entities_vdb,
|
entities_vdb,
|
||||||
text_chunks_db,
|
text_chunks_db,
|
||||||
query_param,
|
query_param,
|
||||||
)
|
)
|
||||||
|
all_chunks.extend(entity_chunks)
|
||||||
|
|
||||||
elif query_param.mode == "global":
|
elif query_param.mode == "global":
|
||||||
entities_context, relations_context, text_units_context = await _get_edge_data(
|
entities_context, relations_context, relationship_chunks = await _get_edge_data(
|
||||||
hl_keywords,
|
hl_keywords,
|
||||||
knowledge_graph_inst,
|
knowledge_graph_inst,
|
||||||
relationships_vdb,
|
relationships_vdb,
|
||||||
text_chunks_db,
|
text_chunks_db,
|
||||||
query_param,
|
query_param,
|
||||||
)
|
)
|
||||||
|
all_chunks.extend(relationship_chunks)
|
||||||
|
|
||||||
else: # hybrid or mix mode
|
else: # hybrid or mix mode
|
||||||
ll_data = await _get_node_data(
|
ll_data = await _get_node_data(
|
||||||
ll_keywords,
|
ll_keywords,
|
||||||
|
|
@ -1875,61 +1844,58 @@ async def _build_query_context(
|
||||||
query_param,
|
query_param,
|
||||||
)
|
)
|
||||||
|
|
||||||
(
|
(ll_entities_context, ll_relations_context, ll_chunks) = ll_data
|
||||||
ll_entities_context,
|
(hl_entities_context, hl_relations_context, hl_chunks) = hl_data
|
||||||
ll_relations_context,
|
|
||||||
ll_text_units_context,
|
|
||||||
) = ll_data
|
|
||||||
|
|
||||||
(
|
# Collect chunks from entity and relationship sources
|
||||||
hl_entities_context,
|
all_chunks.extend(ll_chunks)
|
||||||
hl_relations_context,
|
all_chunks.extend(hl_chunks)
|
||||||
hl_text_units_context,
|
|
||||||
) = hl_data
|
|
||||||
|
|
||||||
# Initialize vector data with empty lists
|
# Get vector chunks if in mix mode
|
||||||
vector_entities_context, vector_relations_context, vector_text_units_context = (
|
if query_param.mode == "mix" and chunks_vdb:
|
||||||
[],
|
vector_chunks = await _get_vector_context(
|
||||||
[],
|
query,
|
||||||
[],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Only get vector data if in mix mode
|
|
||||||
if query_param.mode == "mix" and hasattr(query_param, "original_query"):
|
|
||||||
# Get tokenizer from text_chunks_db
|
|
||||||
tokenizer = text_chunks_db.global_config.get("tokenizer")
|
|
||||||
|
|
||||||
# Get vector context in triple format
|
|
||||||
vector_data = await _get_vector_context(
|
|
||||||
query_param.original_query, # We need to pass the original query
|
|
||||||
chunks_vdb,
|
chunks_vdb,
|
||||||
query_param,
|
query_param,
|
||||||
tokenizer,
|
|
||||||
)
|
)
|
||||||
|
all_chunks.extend(vector_chunks)
|
||||||
|
|
||||||
# If vector_data is not None, unpack it
|
# Combine entities and relations contexts
|
||||||
if vector_data is not None:
|
|
||||||
(
|
|
||||||
vector_entities_context,
|
|
||||||
vector_relations_context,
|
|
||||||
vector_text_units_context,
|
|
||||||
) = vector_data
|
|
||||||
|
|
||||||
# Combine and deduplicate the entities, relationships, and sources
|
|
||||||
entities_context = process_combine_contexts(
|
entities_context = process_combine_contexts(
|
||||||
hl_entities_context, ll_entities_context, vector_entities_context
|
hl_entities_context, ll_entities_context
|
||||||
)
|
)
|
||||||
relations_context = process_combine_contexts(
|
relations_context = process_combine_contexts(
|
||||||
hl_relations_context, ll_relations_context, vector_relations_context
|
hl_relations_context, ll_relations_context
|
||||||
)
|
)
|
||||||
text_units_context = process_combine_contexts(
|
|
||||||
hl_text_units_context, ll_text_units_context, vector_text_units_context
|
# Process all chunks uniformly: deduplication, reranking, and token truncation
|
||||||
|
processed_chunks = await process_chunks_unified(
|
||||||
|
query=query,
|
||||||
|
chunks=all_chunks,
|
||||||
|
query_param=query_param,
|
||||||
|
global_config=text_chunks_db.global_config,
|
||||||
|
source_type="mixed",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build final text_units_context from processed chunks
|
||||||
|
text_units_context = []
|
||||||
|
for i, chunk in enumerate(processed_chunks):
|
||||||
|
text_units_context.append(
|
||||||
|
{
|
||||||
|
"id": i + 1,
|
||||||
|
"content": chunk["content"],
|
||||||
|
"file_path": chunk.get("file_path", "unknown_source"),
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Final context: {len(entities_context)} entities, {len(relations_context)} relations, {len(text_units_context)} chunks"
|
||||||
|
)
|
||||||
|
|
||||||
# not necessary to use LLM to generate a response
|
# not necessary to use LLM to generate a response
|
||||||
if not entities_context and not relations_context:
|
if not entities_context and not relations_context:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 转换为 JSON 字符串
|
|
||||||
entities_str = json.dumps(entities_context, ensure_ascii=False)
|
entities_str = json.dumps(entities_context, ensure_ascii=False)
|
||||||
relations_str = json.dumps(relations_context, ensure_ascii=False)
|
relations_str = json.dumps(relations_context, ensure_ascii=False)
|
||||||
text_units_str = json.dumps(text_units_context, ensure_ascii=False)
|
text_units_str = json.dumps(text_units_context, ensure_ascii=False)
|
||||||
|
|
@ -1975,15 +1941,6 @@ async def _get_node_data(
|
||||||
if not len(results):
|
if not len(results):
|
||||||
return "", "", ""
|
return "", "", ""
|
||||||
|
|
||||||
# Apply reranking if enabled for entity results
|
|
||||||
global_config = entities_vdb.global_config
|
|
||||||
results = await apply_rerank_if_enabled(
|
|
||||||
query=query,
|
|
||||||
retrieved_docs=results,
|
|
||||||
global_config=global_config,
|
|
||||||
top_k=query_param.top_k,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract all entity IDs from your results list
|
# Extract all entity IDs from your results list
|
||||||
node_ids = [r["entity_name"] for r in results]
|
node_ids = [r["entity_name"] for r in results]
|
||||||
|
|
||||||
|
|
@ -2085,16 +2042,7 @@ async def _get_node_data(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
text_units_context = []
|
return entities_context, relations_context, use_text_units
|
||||||
for i, t in enumerate(use_text_units):
|
|
||||||
text_units_context.append(
|
|
||||||
{
|
|
||||||
"id": i + 1,
|
|
||||||
"content": t["content"],
|
|
||||||
"file_path": t.get("file_path", "unknown_source"),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return entities_context, relations_context, text_units_context
|
|
||||||
|
|
||||||
|
|
||||||
async def _find_most_related_text_unit_from_entities(
|
async def _find_most_related_text_unit_from_entities(
|
||||||
|
|
@ -2183,23 +2131,21 @@ async def _find_most_related_text_unit_from_entities(
|
||||||
logger.warning("No valid text units found")
|
logger.warning("No valid text units found")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
|
# Sort by relation counts and order, but don't truncate
|
||||||
all_text_units = sorted(
|
all_text_units = sorted(
|
||||||
all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
|
all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
|
||||||
)
|
)
|
||||||
all_text_units = truncate_list_by_token_size(
|
|
||||||
all_text_units,
|
|
||||||
key=lambda x: x["data"]["content"],
|
|
||||||
max_token_size=query_param.max_token_for_text_unit,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(f"Found {len(all_text_units)} entity-related chunks")
|
||||||
f"Truncate chunks from {len(all_text_units_lookup)} to {len(all_text_units)} (max tokens:{query_param.max_token_for_text_unit})"
|
|
||||||
)
|
|
||||||
|
|
||||||
all_text_units = [t["data"] for t in all_text_units]
|
# Add source type marking and return chunk data
|
||||||
return all_text_units
|
result_chunks = []
|
||||||
|
for t in all_text_units:
|
||||||
|
chunk_data = t["data"].copy()
|
||||||
|
chunk_data["source_type"] = "entity"
|
||||||
|
result_chunks.append(chunk_data)
|
||||||
|
|
||||||
|
return result_chunks
|
||||||
|
|
||||||
|
|
||||||
async def _find_most_related_edges_from_entities(
|
async def _find_most_related_edges_from_entities(
|
||||||
|
|
@ -2287,15 +2233,6 @@ async def _get_edge_data(
|
||||||
if not len(results):
|
if not len(results):
|
||||||
return "", "", ""
|
return "", "", ""
|
||||||
|
|
||||||
# Apply reranking if enabled for relationship results
|
|
||||||
global_config = relationships_vdb.global_config
|
|
||||||
results = await apply_rerank_if_enabled(
|
|
||||||
query=keywords,
|
|
||||||
retrieved_docs=results,
|
|
||||||
global_config=global_config,
|
|
||||||
top_k=query_param.top_k,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prepare edge pairs in two forms:
|
# Prepare edge pairs in two forms:
|
||||||
# For the batch edge properties function, use dicts.
|
# For the batch edge properties function, use dicts.
|
||||||
edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results]
|
edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results]
|
||||||
|
|
@ -2510,21 +2447,16 @@ async def _find_related_text_unit_from_relationships(
|
||||||
logger.warning("No valid text chunks after filtering")
|
logger.warning("No valid text chunks after filtering")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
|
logger.debug(f"Found {len(valid_text_units)} relationship-related chunks")
|
||||||
truncated_text_units = truncate_list_by_token_size(
|
|
||||||
valid_text_units,
|
|
||||||
key=lambda x: x["data"]["content"],
|
|
||||||
max_token_size=query_param.max_token_for_text_unit,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(
|
# Add source type marking and return chunk data
|
||||||
f"Truncate chunks from {len(valid_text_units)} to {len(truncated_text_units)} (max tokens:{query_param.max_token_for_text_unit})"
|
result_chunks = []
|
||||||
)
|
for t in valid_text_units:
|
||||||
|
chunk_data = t["data"].copy()
|
||||||
|
chunk_data["source_type"] = "relationship"
|
||||||
|
result_chunks.append(chunk_data)
|
||||||
|
|
||||||
all_text_units: list[TextChunkSchema] = [t["data"] for t in truncated_text_units]
|
return result_chunks
|
||||||
|
|
||||||
return all_text_units
|
|
||||||
|
|
||||||
|
|
||||||
async def naive_query(
|
async def naive_query(
|
||||||
|
|
@ -2552,12 +2484,30 @@ async def naive_query(
|
||||||
|
|
||||||
tokenizer: Tokenizer = global_config["tokenizer"]
|
tokenizer: Tokenizer = global_config["tokenizer"]
|
||||||
|
|
||||||
_, _, text_units_context = await _get_vector_context(
|
chunks = await _get_vector_context(query, chunks_vdb, query_param)
|
||||||
query, chunks_vdb, query_param, tokenizer
|
|
||||||
|
if chunks is None or len(chunks) == 0:
|
||||||
|
return PROMPTS["fail_response"]
|
||||||
|
|
||||||
|
# Process chunks using unified processing
|
||||||
|
processed_chunks = await process_chunks_unified(
|
||||||
|
query=query,
|
||||||
|
chunks=chunks,
|
||||||
|
query_param=query_param,
|
||||||
|
global_config=global_config,
|
||||||
|
source_type="vector",
|
||||||
)
|
)
|
||||||
|
|
||||||
if text_units_context is None or len(text_units_context) == 0:
|
# Build text_units_context from processed chunks
|
||||||
return PROMPTS["fail_response"]
|
text_units_context = []
|
||||||
|
for i, chunk in enumerate(processed_chunks):
|
||||||
|
text_units_context.append(
|
||||||
|
{
|
||||||
|
"id": i + 1,
|
||||||
|
"content": chunk["content"],
|
||||||
|
"file_path": chunk.get("file_path", "unknown_source"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
text_units_str = json.dumps(text_units_context, ensure_ascii=False)
|
text_units_str = json.dumps(text_units_context, ensure_ascii=False)
|
||||||
if query_param.only_need_context:
|
if query_param.only_need_context:
|
||||||
|
|
@ -2683,6 +2633,7 @@ async def kg_query_with_keywords(
|
||||||
hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""
|
hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""
|
||||||
|
|
||||||
context = await _build_query_context(
|
context = await _build_query_context(
|
||||||
|
query,
|
||||||
ll_keywords_str,
|
ll_keywords_str,
|
||||||
hl_keywords_str,
|
hl_keywords_str,
|
||||||
knowledge_graph_inst,
|
knowledge_graph_inst,
|
||||||
|
|
@ -2805,8 +2756,6 @@ async def query_with_keywords(
|
||||||
f"{prompt}\n\n### Keywords\n\n{keywords_str}\n\n### Query\n\n{query}"
|
f"{prompt}\n\n### Keywords\n\n{keywords_str}\n\n### Query\n\n{query}"
|
||||||
)
|
)
|
||||||
|
|
||||||
param.original_query = query
|
|
||||||
|
|
||||||
# Use appropriate query method based on mode
|
# Use appropriate query method based on mode
|
||||||
if param.mode in ["local", "global", "hybrid", "mix"]:
|
if param.mode in ["local", "global", "hybrid", "mix"]:
|
||||||
return await kg_query_with_keywords(
|
return await kg_query_with_keywords(
|
||||||
|
|
@ -2887,3 +2836,68 @@ async def apply_rerank_if_enabled(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error during reranking: {e}, using original documents")
|
logger.error(f"Error during reranking: {e}, using original documents")
|
||||||
return retrieved_docs
|
return retrieved_docs
|
||||||
|
|
||||||
|
|
||||||
|
async def process_chunks_unified(
|
||||||
|
query: str,
|
||||||
|
chunks: list[dict],
|
||||||
|
query_param: QueryParam,
|
||||||
|
global_config: dict,
|
||||||
|
source_type: str = "mixed",
|
||||||
|
) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Unified processing for text chunks: deduplication, reranking, and token truncation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Search query for reranking
|
||||||
|
chunks: List of text chunks to process
|
||||||
|
query_param: Query parameters containing configuration
|
||||||
|
global_config: Global configuration dictionary
|
||||||
|
source_type: Source type for logging ("vector", "entity", "relationship", "mixed")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Processed and filtered list of text chunks
|
||||||
|
"""
|
||||||
|
if not chunks:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 1. Deduplication based on content
|
||||||
|
seen_content = set()
|
||||||
|
unique_chunks = []
|
||||||
|
for chunk in chunks:
|
||||||
|
content = chunk.get("content", "")
|
||||||
|
if content and content not in seen_content:
|
||||||
|
seen_content.add(content)
|
||||||
|
unique_chunks.append(chunk)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Deduplication: {len(unique_chunks)} chunks (original: {len(chunks)})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Apply reranking if enabled and query is provided
|
||||||
|
if global_config.get("enable_rerank", False) and query and unique_chunks:
|
||||||
|
rerank_top_k = query_param.chunk_rerank_top_k or len(unique_chunks)
|
||||||
|
unique_chunks = await apply_rerank_if_enabled(
|
||||||
|
query=query,
|
||||||
|
retrieved_docs=unique_chunks,
|
||||||
|
global_config=global_config,
|
||||||
|
top_k=rerank_top_k,
|
||||||
|
)
|
||||||
|
logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})")
|
||||||
|
|
||||||
|
# 3. Token-based final truncation
|
||||||
|
tokenizer = global_config.get("tokenizer")
|
||||||
|
if tokenizer and unique_chunks:
|
||||||
|
original_count = len(unique_chunks)
|
||||||
|
unique_chunks = truncate_list_by_token_size(
|
||||||
|
unique_chunks,
|
||||||
|
key=lambda x: x.get("content", ""),
|
||||||
|
max_token_size=query_param.max_token_for_text_unit,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"Token truncation: {len(unique_chunks)} chunks from {original_count} "
|
||||||
|
f"(max tokens: {query_param.max_token_for_text_unit}, source: {source_type})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return unique_chunks
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue