Merge pull request #1959 from danielaskdd/pick-trunk-by-vector

Feat: add KG related chunks selection by vector similarity
This commit is contained in:
Daniel.y 2025-08-15 19:33:51 +08:00 committed by GitHub
commit bdd1169cfb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 742 additions and 161 deletions

.gitignore vendored
View file

@ -72,3 +72,4 @@ test_*
# Cline files
memory-bank
memory-bank/
.clinerules

View file

@ -71,10 +71,20 @@ ENABLE_LLM_CACHE=true
# MAX_RELATION_TOKENS=10000
### control the maximum tokens sent to the LLM (including entities, relations and chunks)
# MAX_TOTAL_TOKENS=30000
### maximum number of related chunks per source entity or relation (higher values increase re-ranking time)
### maximum number of related chunks per source entity or relation
### The chunk picker uses this value to determine the total number of chunks selected from the KG (knowledge graph)
### Higher values increase re-ranking time
# RELATED_CHUNK_NUMBER=5
### Reranker configuration (Set ENABLE_RERANK to true if a reranking model is configured)
### Chunk selection strategies
### VECTOR: Pick KG chunks by vector similarity; the chunks delivered to the LLM align more closely with naive retrieval
### WEIGHT: Pick KG chunks by entity and chunk weight; delivers more purely KG-related chunks to the LLM
### If reranking is enabled, the impact of chunk selection strategies will be diminished.
# KG_CHUNK_PICK_METHOD=VECTOR
### Reranking configuration
### Set ENABLE_RERANK to true if a reranking model is configured
# ENABLE_RERANK=True
### Minimum rerank score for document chunk exclusion (set to 0.0 to keep all chunks, 0.6 or above if the LLM is not strong enough)
# MIN_RERANK_SCORE=0.0
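
For a concrete sense of how RELATED_CHUNK_NUMBER and KG_CHUNK_PICK_METHOD interact, here is a small illustrative sketch (not part of the diff; the helper name is hypothetical). In VECTOR mode the code added later in this PR requests roughly RELATED_CHUNK_NUMBER * number_of_entities_or_relations / 2 chunks before re-ranking and token truncation:

import os

def kg_chunk_budget(num_sources: int) -> int:
    """Rough chunk budget used by the VECTOR picker, mirroring
    int(max_related_chunks * len(entities_with_chunks) / 2) from the new code."""
    related_chunk_number = int(os.getenv("RELATED_CHUNK_NUMBER", "5"))
    return int(related_chunk_number * num_sources / 2)

# Example: 20 final entities with the default RELATED_CHUNK_NUMBER=5
# gives a budget of up to 50 entity-related chunks.
print(kg_chunk_budget(20))  # -> 50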
@ -258,7 +268,7 @@ POSTGRES_IVFFLAT_LISTS=100
NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io
NEO4J_USERNAME=neo4j
NEO4J_PASSWORD='your_password'
# NEO4J_DATABASE=chunk_entity_relation
# NEO4J_DATABASE=chunk-entity-relation
NEO4J_MAX_CONNECTION_POOL_SIZE=100
NEO4J_CONNECTION_TIMEOUT=30
NEO4J_CONNECTION_ACQUISITION_TIMEOUT=30

View file

@ -290,6 +290,19 @@ class BaseVectorStorage(StorageNameSpace, ABC):
ids: List of vector IDs to be deleted
"""
@abstractmethod
async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
"""Get vectors by their IDs, returning only ID and vector data for efficiency
Args:
ids: List of unique identifiers
Returns:
Dictionary mapping IDs to their vector embeddings
Format: {id: [vector_values], ...}
"""
pass
@dataclass
class BaseKVStorage(StorageNameSpace, ABC):
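
The contract of the new abstract method is simple: return only {id: vector} pairs and silently skip IDs that are not found. A minimal, hypothetical illustration of that contract (not one of the shipped backends, which follow below):

import asyncio

async def get_vectors_by_ids_demo(
    store: dict[str, list[float]], ids: list[str]
) -> dict[str, list[float]]:
    # Missing IDs are skipped rather than raising, matching the concrete backends.
    if not ids:
        return {}
    return {id_: store[id_] for id_ in ids if id_ in store}

# asyncio.run(get_vectors_by_ids_demo({"chunk-1": [0.1, 0.2]}, ["chunk-1", "chunk-x"]))
# -> {"chunk-1": [0.1, 0.2]}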

View file

@ -27,6 +27,7 @@ DEFAULT_MAX_RELATION_TOKENS = 10000
DEFAULT_MAX_TOTAL_TOKENS = 30000
DEFAULT_COSINE_THRESHOLD = 0.2
DEFAULT_RELATED_CHUNK_NUMBER = 5
DEFAULT_KG_CHUNK_PICK_METHOD = "VECTOR"
# Deprecated: history messages have a negative effect on query performance
DEFAULT_HISTORY_TURNS = 0

View file

@ -210,9 +210,11 @@ class FaissVectorDBStorage(BaseVectorStorage):
continue
meta = self._id_to_meta.get(idx, {})
# Filter out __vector__ from query results to avoid returning large vector data
filtered_meta = {k: v for k, v in meta.items() if k != "__vector__"}
results.append(
{
**meta,
**filtered_meta,
"id": meta.get("__id__"),
"distance": float(dist),
"created_at": meta.get("__created_at__"),
@ -424,8 +426,10 @@ class FaissVectorDBStorage(BaseVectorStorage):
if not metadata:
return None
# Filter out __vector__ from metadata to avoid returning large vector data
filtered_metadata = {k: v for k, v in metadata.items() if k != "__vector__"}
return {
**metadata,
**filtered_metadata,
"id": metadata.get("__id__"),
"created_at": metadata.get("__created_at__"),
}
@ -448,9 +452,13 @@ class FaissVectorDBStorage(BaseVectorStorage):
if fid is not None:
metadata = self._id_to_meta.get(fid, {})
if metadata:
# Filter out __vector__ from metadata to avoid returning large vector data
filtered_metadata = {
k: v for k, v in metadata.items() if k != "__vector__"
}
results.append(
{
**metadata,
**filtered_metadata,
"id": metadata.get("__id__"),
"created_at": metadata.get("__created_at__"),
}
@ -458,6 +466,31 @@ class FaissVectorDBStorage(BaseVectorStorage):
return results
async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
"""Get vectors by their IDs, returning only ID and vector data for efficiency
Args:
ids: List of unique identifiers
Returns:
Dictionary mapping IDs to their vector embeddings
Format: {id: [vector_values], ...}
"""
if not ids:
return {}
vectors_dict = {}
for id in ids:
# Find the Faiss internal ID for the custom ID
fid = self._find_faiss_id_by_custom_id(id)
if fid is not None and fid in self._id_to_meta:
metadata = self._id_to_meta[fid]
# Get the stored vector from metadata
if "__vector__" in metadata:
vectors_dict[id] = metadata["__vector__"]
return vectors_dict
async def drop(self) -> dict[str, str]:
"""Drop all vector data from storage and clean up resources

View file

@ -1018,6 +1018,50 @@ class MilvusVectorDBStorage(BaseVectorStorage):
)
return []
async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
"""Get vectors by their IDs, returning only ID and vector data for efficiency
Args:
ids: List of unique identifiers
Returns:
Dictionary mapping IDs to their vector embeddings
Format: {id: [vector_values], ...}
"""
if not ids:
return {}
try:
# Ensure collection is loaded before querying
self._ensure_collection_loaded()
# Prepare the ID filter expression
id_list = '", "'.join(ids)
filter_expr = f'id in ["{id_list}"]'
# Query Milvus with the filter, requesting only vector field
result = self._client.query(
collection_name=self.final_namespace,
filter=filter_expr,
output_fields=["vector"],
)
vectors_dict = {}
for item in result:
if item and "vector" in item and "id" in item:
# Convert numpy array to list if needed
vector_data = item["vector"]
if isinstance(vector_data, np.ndarray):
vector_data = vector_data.tolist()
vectors_dict[item["id"]] = vector_data
return vectors_dict
except Exception as e:
logger.error(
f"[{self.workspace}] Error retrieving vectors by IDs from {self.namespace}: {e}"
)
return {}
async def drop(self) -> dict[str, str]:
"""Drop all vector data from storage and clean up resources

View file

@ -1967,6 +1967,37 @@ class MongoVectorDBStorage(BaseVectorStorage):
)
return []
async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
"""Get vectors by their IDs, returning only ID and vector data for efficiency
Args:
ids: List of unique identifiers
Returns:
Dictionary mapping IDs to their vector embeddings
Format: {id: [vector_values], ...}
"""
if not ids:
return {}
try:
# Query MongoDB for the specified IDs, only retrieving the vector field
cursor = self._data.find({"_id": {"$in": ids}}, {"vector": 1})
results = await cursor.to_list(length=None)
vectors_dict = {}
for result in results:
if result and "vector" in result and "_id" in result:
# MongoDB stores vectors as arrays, so they should already be lists
vectors_dict[result["_id"]] = result["vector"]
return vectors_dict
except PyMongoError as e:
logger.error(
f"[{self.workspace}] Error retrieving vectors by IDs from {self.namespace}: {e}"
)
return {}
async def drop(self) -> dict[str, str]:
"""Drop the storage by removing all documents in the collection and recreating vector index.

View file

@ -1,5 +1,7 @@
import asyncio
import base64
import os
import zlib
from typing import Any, final
from dataclasses import dataclass
import numpy as np
@ -93,8 +95,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
2. Only one process should be updating the storage at a time before index_done_callback,
KG-storage-log should be used to avoid data corruption
"""
logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
# logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
if not data:
return
@ -120,6 +121,11 @@ class NanoVectorDBStorage(BaseVectorStorage):
embeddings = np.concatenate(embeddings_list)
if len(embeddings) == len(list_data):
for i, d in enumerate(list_data):
# Compress vector using Float16 + zlib + Base64 for storage optimization
vector_f16 = embeddings[i].astype(np.float16)
compressed_vector = zlib.compress(vector_f16.tobytes())
encoded_vector = base64.b64encode(compressed_vector).decode("utf-8")
d["vector"] = encoded_vector
d["__vector__"] = embeddings[i]
client = await self._get_client()
results = client.upsert(datas=list_data)
@ -147,7 +153,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
)
results = [
{
**dp,
**{k: v for k, v in dp.items() if k != "vector"},
"id": dp["__id__"],
"distance": dp["__metrics__"],
"created_at": dp.get("__created_at__"),
@ -296,7 +302,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
if result:
dp = result[0]
return {
**dp,
**{k: v for k, v in dp.items() if k != "vector"},
"id": dp.get("__id__"),
"created_at": dp.get("__created_at__"),
}
@ -318,13 +324,41 @@ class NanoVectorDBStorage(BaseVectorStorage):
results = client.get(ids)
return [
{
**dp,
**{k: v for k, v in dp.items() if k != "vector"},
"id": dp.get("__id__"),
"created_at": dp.get("__created_at__"),
}
for dp in results
]
async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
"""Get vectors by their IDs, returning only ID and vector data for efficiency
Args:
ids: List of unique identifiers
Returns:
Dictionary mapping IDs to their vector embeddings
Format: {id: [vector_values], ...}
"""
if not ids:
return {}
client = await self._get_client()
results = client.get(ids)
vectors_dict = {}
for result in results:
if result and "vector" in result and "__id__" in result:
# Decompress vector data (Base64 + zlib + Float16 compressed)
decoded = base64.b64decode(result["vector"])
decompressed = zlib.decompress(decoded)
vector_f16 = np.frombuffer(decompressed, dtype=np.float16)
vector_f32 = vector_f16.astype(np.float32).tolist()
vectors_dict[result["__id__"]] = vector_f32
return vectors_dict
async def drop(self) -> dict[str, str]:
"""Drop all vector data from storage and clean up resources

View file

@ -155,7 +155,7 @@ class Neo4JStorage(BaseGraphStorage):
except neo4jExceptions.ServiceUnavailable as e:
logger.error(
f"[{self.workspace}] "
+ f"{database} at {URI} is not available".capitalize()
+ f"Database {database} at {URI} is not available"
)
raise e
except neo4jExceptions.AuthError as e:
@ -167,7 +167,7 @@ class Neo4JStorage(BaseGraphStorage):
if e.code == "Neo.ClientError.Database.DatabaseNotFound":
logger.info(
f"[{self.workspace}] "
+ f"{database} at {URI} not found. Try to create specified database.".capitalize()
+ f"Database {database} at {URI} not found. Try to create specified database."
)
try:
async with self._driver.session() as session:
@ -177,7 +177,7 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() # Ensure result is consumed
logger.info(
f"[{self.workspace}] "
+ f"{database} at {URI} created".capitalize()
+ f"Database {database} at {URI} created"
)
connected = True
except (

View file

@ -2158,6 +2158,53 @@ class PGVectorStorage(BaseVectorStorage):
)
return []
async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
"""Get vectors by their IDs, returning only ID and vector data for efficiency
Args:
ids: List of unique identifiers
Returns:
Dictionary mapping IDs to their vector embeddings
Format: {id: [vector_values], ...}
"""
if not ids:
return {}
table_name = namespace_to_table_name(self.namespace)
if not table_name:
logger.error(
f"[{self.workspace}] Unknown namespace for vector lookup: {self.namespace}"
)
return {}
ids_str = ",".join([f"'{id}'" for id in ids])
query = f"SELECT id, content_vector FROM {table_name} WHERE workspace=$1 AND id IN ({ids_str})"
params = {"workspace": self.workspace}
try:
results = await self.db.query(query, params, multirows=True)
vectors_dict = {}
for result in results:
if result and "content_vector" in result and "id" in result:
try:
# Parse JSON string to get vector as list of floats
vector_data = json.loads(result["content_vector"])
if isinstance(vector_data, list):
vectors_dict[result["id"]] = vector_data
except (json.JSONDecodeError, TypeError) as e:
logger.warning(
f"[{self.workspace}] Failed to parse vector data for ID {result['id']}: {e}"
)
return vectors_dict
except Exception as e:
logger.error(
f"[{self.workspace}] Error retrieving vectors by IDs from {self.namespace}: {e}"
)
return {}
async def drop(self) -> dict[str, str]:
"""Drop the storage"""
async with get_storage_lock():

View file

@ -402,6 +402,50 @@ class QdrantVectorDBStorage(BaseVectorStorage):
)
return []
async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
"""Get vectors by their IDs, returning only ID and vector data for efficiency
Args:
ids: List of unique identifiers
Returns:
Dictionary mapping IDs to their vector embeddings
Format: {id: [vector_values], ...}
"""
if not ids:
return {}
try:
# Convert to Qdrant compatible IDs
qdrant_ids = [compute_mdhash_id_for_qdrant(id) for id in ids]
# Retrieve the points by IDs with vectors
results = self._client.retrieve(
collection_name=self.final_namespace,
ids=qdrant_ids,
with_vectors=True, # Important: request vectors
with_payload=True,
)
vectors_dict = {}
for point in results:
if point and point.vector is not None and point.payload:
# Get original ID from payload
original_id = point.payload.get("id")
if original_id:
# Convert numpy array to list if needed
vector_data = point.vector
if isinstance(vector_data, np.ndarray):
vector_data = vector_data.tolist()
vectors_dict[original_id] = vector_data
return vectors_dict
except Exception as e:
logger.error(
f"[{self.workspace}] Error retrieving vectors by IDs from {self.namespace}: {e}"
)
return {}
async def drop(self) -> dict[str, str]:
"""Drop all vector data from storage and clean up resources

View file

@ -31,6 +31,7 @@ from lightrag.constants import (
DEFAULT_MAX_TOTAL_TOKENS,
DEFAULT_COSINE_THRESHOLD,
DEFAULT_RELATED_CHUNK_NUMBER,
DEFAULT_KG_CHUNK_PICK_METHOD,
DEFAULT_MIN_RERANK_SCORE,
DEFAULT_SUMMARY_MAX_TOKENS,
DEFAULT_MAX_ASYNC,
@ -175,6 +176,11 @@ class LightRAG:
)
"""Number of related chunks to grab from single entity or relation."""
kg_chunk_pick_method: str = field(
default=get_env_value("KG_CHUNK_PICK_METHOD", DEFAULT_KG_CHUNK_PICK_METHOD, str)
)
"""Method for selecting text chunks: 'WEIGHT' for weight-based selection, 'VECTOR' for embedding similarity-based selection."""
# Entity extraction
# ---
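
A hedged usage sketch of the new field: assuming the usual dataclass-style LightRAG constructor, the pick method can also be set directly instead of via the KG_CHUNK_PICK_METHOD environment variable (other required arguments such as the LLM and embedding functions are omitted; the related_chunk_number field name is assumed from the global_config key used below):

rag = LightRAG(
    working_dir="./rag_storage",
    kg_chunk_pick_method="VECTOR",  # or "WEIGHT"
    related_chunk_number=5,         # assumed field backing RELATED_CHUNK_NUMBER
)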

View file

@ -27,7 +27,8 @@ from .utils import (
use_llm_func_with_cache,
update_chunk_cache_list,
remove_think_tags,
linear_gradient_weighted_polling,
pick_by_weighted_polling,
pick_by_vector_similarity,
process_chunks_unified,
build_file_path,
)
@ -45,6 +46,7 @@ from .constants import (
DEFAULT_MAX_RELATION_TOKENS,
DEFAULT_MAX_TOTAL_TOKENS,
DEFAULT_RELATED_CHUNK_NUMBER,
DEFAULT_KG_CHUNK_PICK_METHOD,
)
from .kg.shared_storage import get_storage_keyed_lock
import time
@ -2105,6 +2107,9 @@ async def _build_query_context(
global_entities = []
global_relations = []
# Track chunk sources and metadata for final logging
chunk_tracking = {} # chunk_id -> {source, frequency, order}
# Handle local and global modes
if query_param.mode == "local":
local_entities, local_relations = await _get_node_data(
@ -2143,6 +2148,15 @@ async def _build_query_context(
chunks_vdb,
query_param,
)
# Track vector chunks with source metadata
for i, chunk in enumerate(vector_chunks):
chunk_id = chunk.get("chunk_id") or chunk.get("id")
if chunk_id:
chunk_tracking[chunk_id] = {
"source": "C",
"frequency": 1, # Vector chunks always have frequency 1
"order": i + 1, # 1-based order in vector search results
}
# Use round-robin merge to combine local and global data fairly
final_entities = []
@ -2340,20 +2354,29 @@ async def _build_query_context(
seen_edges.add(pair)
# Get text chunks based on final filtered data
# To preserve the influence of entity order, entity-based chunks should not be deduplicated against vector_chunks
if final_node_datas:
entity_chunks = await _find_most_related_text_unit_from_entities(
entity_chunks = await _find_related_text_unit_from_entities(
final_node_datas,
query_param,
text_chunks_db,
knowledge_graph_inst,
query,
chunks_vdb,
chunk_tracking=chunk_tracking,
)
# Find deduplicated chunks from edges
# Deduplication causes solely relation-based chunks to be prioritized and sent to the LLM when re-ranking is disabled
if final_edge_datas:
relation_chunks = await _find_related_text_unit_from_relationships(
relation_chunks = await _find_related_text_unit_from_relations(
final_edge_datas,
query_param,
text_chunks_db,
entity_chunks,
query,
chunks_vdb,
chunk_tracking=chunk_tracking,
)
# Round-robin merge chunks from different sources with deduplication by chunk_id
@ -2373,6 +2396,7 @@ async def _build_query_context(
{
"content": chunk["content"],
"file_path": chunk.get("file_path", "unknown_source"),
"chunk_id": chunk_id,
}
)
@ -2386,6 +2410,7 @@ async def _build_query_context(
{
"content": chunk["content"],
"file_path": chunk.get("file_path", "unknown_source"),
"chunk_id": chunk_id,
}
)
@ -2399,10 +2424,11 @@ async def _build_query_context(
{
"content": chunk["content"],
"file_path": chunk.get("file_path", "unknown_source"),
"chunk_id": chunk_id,
}
)
logger.debug(
logger.info(
f"Round-robin merged total chunks from {origin_len} to {len(merged_chunks)}"
)
@ -2517,6 +2543,24 @@ async def _build_query_context(
if not entities_context and not relations_context:
return None
# Output chunk tracking information
# format: <source><frequency>/<order> (e.g., E5/2 R2/1 C1/1)
if truncated_chunks and chunk_tracking:
chunk_tracking_log = []
for chunk in truncated_chunks:
chunk_id = chunk.get("chunk_id")
if chunk_id and chunk_id in chunk_tracking:
tracking_info = chunk_tracking[chunk_id]
source = tracking_info["source"]
frequency = tracking_info["frequency"]
order = tracking_info["order"]
chunk_tracking_log.append(f"{source}{frequency}/{order}")
else:
chunk_tracking_log.append("?0/0")
if chunk_tracking_log:
logger.info(f"chunks: {' '.join(chunk_tracking_log)}")
entities_str = json.dumps(entities_context, ensure_ascii=False)
relations_str = json.dumps(relations_context, ensure_ascii=False)
text_units_str = json.dumps(text_units_context, ensure_ascii=False)
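
For reference, a small sketch (hypothetical helper, not part of the diff) of how the tracking entries built above map to the logged tokens such as E5/2 R2/1 C1/1:

def format_chunk_tracking(tracking: dict[str, dict]) -> str:
    # Token layout: <source><frequency>/<order>
    #   source:    E = entity-derived, R = relation-derived, C = vector (naive) search
    #   frequency: how many entities/relations referenced the chunk (always 1 for C)
    #   order:     1-based position within that source's selection results
    return " ".join(
        f"{t['source']}{t['frequency']}/{t['order']}" for t in tracking.values()
    )

# format_chunk_tracking({"c1": {"source": "E", "frequency": 5, "order": 2}})  # -> "E5/2"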
@ -2603,104 +2647,6 @@ async def _get_node_data(
return node_datas, use_relations
async def _find_most_related_text_unit_from_entities(
node_datas: list[dict],
query_param: QueryParam,
text_chunks_db: BaseKVStorage,
knowledge_graph_inst: BaseGraphStorage,
):
"""
Find text chunks related to entities using linear gradient weighted polling algorithm.
This function implements the optimized text chunk selection strategy:
1. Sort text chunks for each entity by occurrence count in other entities
2. Use linear gradient weighted polling to select chunks fairly
"""
logger.debug(f"Searching text chunks for {len(node_datas)} entities")
if not node_datas:
return []
# Step 1: Collect all text chunks for each entity
entities_with_chunks = []
for entity in node_datas:
if entity.get("source_id"):
chunks = split_string_by_multi_markers(
entity["source_id"], [GRAPH_FIELD_SEP]
)
if chunks:
entities_with_chunks.append(
{
"entity_name": entity["entity_name"],
"chunks": chunks,
"entity_data": entity,
}
)
if not entities_with_chunks:
logger.warning("No entities with text chunks found")
return []
# Step 2: Count chunk occurrences and deduplicate (keep chunks from earlier positioned entities)
chunk_occurrence_count = {}
for entity_info in entities_with_chunks:
deduplicated_chunks = []
for chunk_id in entity_info["chunks"]:
chunk_occurrence_count[chunk_id] = (
chunk_occurrence_count.get(chunk_id, 0) + 1
)
# If this is the first occurrence (count == 1), keep it; otherwise skip (duplicate from later position)
if chunk_occurrence_count[chunk_id] == 1:
deduplicated_chunks.append(chunk_id)
# count > 1 means this chunk appeared in an earlier entity, so skip it
# Update entity's chunks to deduplicated chunks
entity_info["chunks"] = deduplicated_chunks
# Step 3: Sort chunks for each entity by occurrence count (higher count = higher priority)
for entity_info in entities_with_chunks:
sorted_chunks = sorted(
entity_info["chunks"],
key=lambda chunk_id: chunk_occurrence_count.get(chunk_id, 0),
reverse=True,
)
entity_info["sorted_chunks"] = sorted_chunks
# Step 4: Apply linear gradient weighted polling algorithm
max_related_chunks = text_chunks_db.global_config.get(
"related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER
)
selected_chunk_ids = linear_gradient_weighted_polling(
entities_with_chunks, max_related_chunks, min_related_chunks=1
)
logger.debug(
f"Found {len(selected_chunk_ids)} entity-related chunks using linear gradient weighted polling"
)
if not selected_chunk_ids:
return []
# Step 5: Batch retrieve chunk data
unique_chunk_ids = list(
dict.fromkeys(selected_chunk_ids)
) # Remove duplicates while preserving order
chunk_data_list = await text_chunks_db.get_by_ids(unique_chunk_ids)
# Step 6: Build result chunks with valid data
result_chunks = []
for chunk_id, chunk_data in zip(unique_chunk_ids, chunk_data_list):
if chunk_data is not None and "content" in chunk_data:
chunk_data_copy = chunk_data.copy()
chunk_data_copy["source_type"] = "entity"
chunk_data_copy["chunk_id"] = chunk_id # Add chunk_id for deduplication
result_chunks.append(chunk_data_copy)
return result_chunks
async def _find_most_related_edges_from_entities(
node_datas: list[dict],
query_param: QueryParam,
@ -2757,6 +2703,167 @@ async def _find_most_related_edges_from_entities(
return all_edges_data
async def _find_related_text_unit_from_entities(
node_datas: list[dict],
query_param: QueryParam,
text_chunks_db: BaseKVStorage,
knowledge_graph_inst: BaseGraphStorage,
query: str = None,
chunks_vdb: BaseVectorStorage = None,
chunk_tracking: dict = None,
):
"""
Find text chunks related to entities using configurable chunk selection method.
This function supports two chunk selection strategies:
1. WEIGHT: Linear gradient weighted polling based on chunk occurrence count
2. VECTOR: Vector similarity-based selection using embedding cosine similarity
"""
logger.debug(f"Finding text chunks from {len(node_datas)} entities")
if not node_datas:
return []
# Step 1: Collect all text chunks for each entity
entities_with_chunks = []
for entity in node_datas:
if entity.get("source_id"):
chunks = split_string_by_multi_markers(
entity["source_id"], [GRAPH_FIELD_SEP]
)
if chunks:
entities_with_chunks.append(
{
"entity_name": entity["entity_name"],
"chunks": chunks,
"entity_data": entity,
}
)
if not entities_with_chunks:
logger.warning("No entities with text chunks found")
return []
kg_chunk_pick_method = text_chunks_db.global_config.get(
"kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD
)
max_related_chunks = text_chunks_db.global_config.get(
"related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER
)
# Step 2: Count chunk occurrences and deduplicate (keep chunks from earlier positioned entities)
chunk_occurrence_count = {}
for entity_info in entities_with_chunks:
deduplicated_chunks = []
for chunk_id in entity_info["chunks"]:
chunk_occurrence_count[chunk_id] = (
chunk_occurrence_count.get(chunk_id, 0) + 1
)
# If this is the first occurrence (count == 1), keep it; otherwise skip (duplicate from later position)
if chunk_occurrence_count[chunk_id] == 1:
deduplicated_chunks.append(chunk_id)
# count > 1 means this chunk appeared in an earlier entity, so skip it
# Update entity's chunks to deduplicated chunks
entity_info["chunks"] = deduplicated_chunks
# Step 3: Sort chunks for each entity by occurrence count (higher count = higher priority)
total_entity_chunks = 0
for entity_info in entities_with_chunks:
sorted_chunks = sorted(
entity_info["chunks"],
key=lambda chunk_id: chunk_occurrence_count.get(chunk_id, 0),
reverse=True,
)
entity_info["sorted_chunks"] = sorted_chunks
total_entity_chunks += len(sorted_chunks)
selected_chunk_ids = [] # Initialize to avoid UnboundLocalError
# Step 4: Apply the selected chunk selection algorithm
# Pick by vector similarity:
# The ordering of text chunks aligns more closely with what naive retrieval would return.
# When reranking is disabled, the text chunks delivered to the LLM therefore tend to favor naive retrieval.
if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb:
num_of_chunks = int(max_related_chunks * len(entities_with_chunks) / 2)
# Get embedding function from global config
embedding_func_config = text_chunks_db.embedding_func
if not embedding_func_config:
logger.warning("No embedding function found, falling back to WEIGHT method")
kg_chunk_pick_method = "WEIGHT"
else:
try:
actual_embedding_func = embedding_func_config.func
selected_chunk_ids = None
if actual_embedding_func:
selected_chunk_ids = await pick_by_vector_similarity(
query=query,
text_chunks_storage=text_chunks_db,
chunks_vdb=chunks_vdb,
num_of_chunks=num_of_chunks,
entity_info=entities_with_chunks,
embedding_func=actual_embedding_func,
)
if selected_chunk_ids == []:
kg_chunk_pick_method = "WEIGHT"
logger.warning(
"No entity-related chunks selected by vector similarity, falling back to WEIGHT method"
)
else:
logger.info(
f"Selecting {len(selected_chunk_ids)} from {total_entity_chunks} entity-related chunks by vector similarity"
)
except Exception as e:
logger.error(
f"Error in vector similarity sorting: {e}, falling back to WEIGHT method"
)
kg_chunk_pick_method = "WEIGHT"
if kg_chunk_pick_method == "WEIGHT":
# Pick by entity and chunk weight:
# When reranking is disabled, this delivers more purely KG-related chunks to the LLM
selected_chunk_ids = pick_by_weighted_polling(
entities_with_chunks, max_related_chunks, min_related_chunks=1
)
logger.info(
f"Selecting {len(selected_chunk_ids)} from {total_entity_chunks} entity-related chunks by weighted polling"
)
if not selected_chunk_ids:
return []
# Step 5: Batch retrieve chunk data
unique_chunk_ids = list(
dict.fromkeys(selected_chunk_ids)
) # Remove duplicates while preserving order
chunk_data_list = await text_chunks_db.get_by_ids(unique_chunk_ids)
# Step 6: Build result chunks with valid data and update chunk tracking
result_chunks = []
for i, (chunk_id, chunk_data) in enumerate(zip(unique_chunk_ids, chunk_data_list)):
if chunk_data is not None and "content" in chunk_data:
chunk_data_copy = chunk_data.copy()
chunk_data_copy["source_type"] = "entity"
chunk_data_copy["chunk_id"] = chunk_id # Add chunk_id for deduplication
result_chunks.append(chunk_data_copy)
# Update chunk tracking if provided
if chunk_tracking is not None:
chunk_tracking[chunk_id] = {
"source": "E",
"frequency": chunk_occurrence_count.get(chunk_id, 1),
"order": i + 1, # 1-based order in final entity-related results
}
return result_chunks
async def _get_edge_data(
keywords,
knowledge_graph_inst: BaseGraphStorage,
@ -2848,20 +2955,23 @@ async def _find_most_related_entities_from_relationships(
return node_datas
async def _find_related_text_unit_from_relationships(
async def _find_related_text_unit_from_relations(
edge_datas: list[dict],
query_param: QueryParam,
text_chunks_db: BaseKVStorage,
entity_chunks: list[dict] = None,
query: str = None,
chunks_vdb: BaseVectorStorage = None,
chunk_tracking: dict = None,
):
"""
Find text chunks related to relationships using linear gradient weighted polling algorithm.
Find text chunks related to relationships using configurable chunk selection method.
This function implements the optimized text chunk selection strategy:
1. Sort text chunks for each relationship by occurrence count in other relationships
2. Use linear gradient weighted polling to select chunks fairly
This function supports two chunk selection strategies:
1. WEIGHT: Linear gradient weighted polling based on chunk occurrence count
2. VECTOR: Vector similarity-based selection using embedding cosine similarity
"""
logger.debug(f"Searching text chunks for {len(edge_datas)} relationships")
logger.debug(f"Finding text chunks from {len(edge_datas)} relations")
if not edge_datas:
return []
@ -2891,14 +3001,40 @@ async def _find_related_text_unit_from_relationships(
)
if not relations_with_chunks:
logger.warning("No relationships with text chunks found")
logger.warning("No relation-related chunks found")
return []
kg_chunk_pick_method = text_chunks_db.global_config.get(
"kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD
)
max_related_chunks = text_chunks_db.global_config.get(
"related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER
)
# Step 2: Count chunk occurrences and deduplicate (keep chunks from earlier positioned relationships)
# Also remove duplicates with entity_chunks
# Extract chunk IDs from entity_chunks for deduplication
entity_chunk_ids = set()
if entity_chunks:
for chunk in entity_chunks:
chunk_id = chunk.get("chunk_id")
if chunk_id:
entity_chunk_ids.add(chunk_id)
chunk_occurrence_count = {}
# Track unique chunk_ids that have been removed to avoid double counting
removed_entity_chunk_ids = set()
for relation_info in relations_with_chunks:
deduplicated_chunks = []
for chunk_id in relation_info["chunks"]:
# Skip chunks that already exist in entity_chunks
if chunk_id in entity_chunk_ids:
# Only count each unique chunk_id once
removed_entity_chunk_ids.add(chunk_id)
continue
chunk_occurrence_count[chunk_id] = (
chunk_occurrence_count.get(chunk_id, 0) + 1
)
@ -2911,7 +3047,21 @@ async def _find_related_text_unit_from_relationships(
# Update relationship's chunks to deduplicated chunks
relation_info["chunks"] = deduplicated_chunks
# Check if any relations still have chunks after deduplication
relations_with_chunks = [
relation_info
for relation_info in relations_with_chunks
if relation_info["chunks"]
]
if not relations_with_chunks:
logger.info(
f"Find no additional relations-related chunks from {len(edge_datas)} relations"
)
return []
# Step 3: Sort chunks for each relationship by occurrence count (higher count = higher priority)
total_relation_chunks = 0
for relation_info in relations_with_chunks:
sorted_chunks = sorted(
relation_info["chunks"],
@ -2919,66 +3069,93 @@ async def _find_related_text_unit_from_relationships(
reverse=True,
)
relation_info["sorted_chunks"] = sorted_chunks
total_relation_chunks += len(sorted_chunks)
# Step 4: Apply linear gradient weighted polling algorithm
max_related_chunks = text_chunks_db.global_config.get(
"related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER
logger.info(
f"Find {total_relation_chunks} additional chunks in {len(relations_with_chunks)} relations ({len(removed_entity_chunk_ids)} duplicated chunks removed)"
)
selected_chunk_ids = linear_gradient_weighted_polling(
relations_with_chunks, max_related_chunks, min_related_chunks=1
)
# Step 4: Apply the selected chunk selection algorithm
selected_chunk_ids = [] # Initialize to avoid UnboundLocalError
if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb:
num_of_chunks = int(max_related_chunks * len(relations_with_chunks) / 2)
# Get embedding function from global config
embedding_func_config = text_chunks_db.embedding_func
if not embedding_func_config:
logger.warning("No embedding function found, falling back to WEIGHT method")
kg_chunk_pick_method = "WEIGHT"
else:
try:
actual_embedding_func = embedding_func_config.func
if actual_embedding_func:
selected_chunk_ids = await pick_by_vector_similarity(
query=query,
text_chunks_storage=text_chunks_db,
chunks_vdb=chunks_vdb,
num_of_chunks=num_of_chunks,
entity_info=relations_with_chunks,
embedding_func=actual_embedding_func,
)
if selected_chunk_ids == []:
kg_chunk_pick_method = "WEIGHT"
logger.warning(
"No relation-related chunks selected by vector similarity, falling back to WEIGHT method"
)
else:
logger.info(
f"Selecting {len(selected_chunk_ids)} from {total_relation_chunks} relation-related chunks by vector similarity"
)
except Exception as e:
logger.error(
f"Error in vector similarity sorting: {e}, falling back to WEIGHT method"
)
kg_chunk_pick_method = "WEIGHT"
if kg_chunk_pick_method == "WEIGHT":
# Apply linear gradient weighted polling algorithm
selected_chunk_ids = pick_by_weighted_polling(
relations_with_chunks, max_related_chunks, min_related_chunks=1
)
logger.info(
f"Selecting {len(selected_chunk_ids)} from {total_relation_chunks} relation-related chunks by weighted polling"
)
logger.debug(
f"Found {len(selected_chunk_ids)} relationship-related chunks using linear gradient weighted polling"
)
logger.info(
f"KG related chunks: {len(entity_chunks)} from entitys, {len(selected_chunk_ids)} from relations"
)
if not selected_chunk_ids:
return []
# Step 4.5: Remove duplicates with entity_chunks before batch retrieval
if entity_chunks:
# Extract chunk IDs from entity_chunks
entity_chunk_ids = set()
for chunk in entity_chunks:
chunk_id = chunk.get("chunk_id")
if chunk_id:
entity_chunk_ids.add(chunk_id)
# Filter out duplicate chunk IDs
original_count = len(selected_chunk_ids)
selected_chunk_ids = [
chunk_id
for chunk_id in selected_chunk_ids
if chunk_id not in entity_chunk_ids
]
logger.debug(
f"Deduplication relation-chunks with entity-chunks: {original_count} -> {len(selected_chunk_ids)} chunks "
)
# Early return if no chunks remain after deduplication
if not selected_chunk_ids:
return []
# Step 5: Batch retrieve chunk data
unique_chunk_ids = list(
dict.fromkeys(selected_chunk_ids)
) # Remove duplicates while preserving order
chunk_data_list = await text_chunks_db.get_by_ids(unique_chunk_ids)
# Step 6: Build result chunks with valid data
# Step 6: Build result chunks with valid data and update chunk tracking
result_chunks = []
for chunk_id, chunk_data in zip(unique_chunk_ids, chunk_data_list):
for i, (chunk_id, chunk_data) in enumerate(zip(unique_chunk_ids, chunk_data_list)):
if chunk_data is not None and "content" in chunk_data:
chunk_data_copy = chunk_data.copy()
chunk_data_copy["source_type"] = "relationship"
chunk_data_copy["chunk_id"] = chunk_id # Add chunk_id for deduplication
result_chunks.append(chunk_data_copy)
# Update chunk tracking if provided
if chunk_tracking is not None:
chunk_tracking[chunk_id] = {
"source": "R",
"frequency": chunk_occurrence_count.get(chunk_id, 1),
"order": i + 1, # 1-based order in final relation-related results
}
return result_chunks
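
The "round-robin merge ... with deduplication by chunk_id" referenced in _build_query_context is not shown in this diff; a simplified sketch of the idea (an assumption about the shape of that step, not the repository's actual helper):

from itertools import zip_longest

def round_robin_merge(*chunk_lists: list[dict]) -> list[dict]:
    """Interleave chunks from several sources fairly, keeping the first
    occurrence of each chunk_id and dropping later duplicates."""
    merged, seen = [], set()
    for group in zip_longest(*chunk_lists):
        for chunk in group:
            if chunk is None:
                continue
            chunk_id = chunk.get("chunk_id")
            if chunk_id and chunk_id in seen:
                continue
            if chunk_id:
                seen.add(chunk_id)
            merged.append(chunk)
    return merged

# e.g. round_robin_merge(vector_chunks, entity_chunks, relation_chunks)
# takes one chunk from each source per round, with duplicates removed.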

View file

@ -60,7 +60,7 @@ def get_env_value(
# Use TYPE_CHECKING to avoid circular imports
if TYPE_CHECKING:
from lightrag.base import BaseKVStorage, QueryParam
from lightrag.base import BaseKVStorage, BaseVectorStorage, QueryParam
# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
@ -1570,7 +1570,7 @@ def check_storage_env_vars(storage_name: str) -> None:
)
def linear_gradient_weighted_polling(
def pick_by_weighted_polling(
entities_or_relations: list[dict],
max_related_chunks: int,
min_related_chunks: int = 1,
@ -1650,6 +1650,120 @@ def linear_gradient_weighted_polling(
return selected_chunks
async def pick_by_vector_similarity(
query: str,
text_chunks_storage: "BaseKVStorage",
chunks_vdb: "BaseVectorStorage",
num_of_chunks: int,
entity_info: list[dict[str, Any]],
embedding_func: callable,
) -> list[str]:
"""
Vector similarity-based text chunk selection algorithm.
This algorithm selects text chunks based on cosine similarity between
the query embedding and text chunk embeddings.
Args:
query: User's original query string
text_chunks_storage: Text chunks storage instance
chunks_vdb: Vector database storage for chunks
num_of_chunks: Number of chunks to select
entity_info: List of entity information containing chunk IDs
embedding_func: Embedding function to compute query embedding
Returns:
List of selected text chunk IDs sorted by similarity (highest first)
"""
logger.debug(
f"Vector similarity chunk selection: num_of_chunks={num_of_chunks}, entity_info_count={len(entity_info) if entity_info else 0}"
)
if not entity_info or num_of_chunks <= 0:
return []
# Collect all unique chunk IDs from entity info
all_chunk_ids = set()
for i, entity in enumerate(entity_info):
chunk_ids = entity.get("sorted_chunks", [])
all_chunk_ids.update(chunk_ids)
if not all_chunk_ids:
logger.warning(
"Vector similarity chunk selection: no chunk IDs found in entity_info"
)
return []
logger.debug(
f"Vector similarity chunk selection: {len(all_chunk_ids)} unique chunk IDs collected"
)
all_chunk_ids = list(all_chunk_ids)
try:
# Get query embedding
query_embedding = await embedding_func([query])
query_embedding = query_embedding[
0
] # Extract first embedding from batch result
# Get chunk embeddings from vector database
chunk_vectors = await chunks_vdb.get_vectors_by_ids(all_chunk_ids)
logger.debug(
f"Vector similarity chunk selection: {len(chunk_vectors)} chunk vectors Retrieved"
)
if not chunk_vectors or len(chunk_vectors) != len(all_chunk_ids):
if not chunk_vectors:
logger.warning(
"Vector similarity chunk selection: no vectors retrieved from chunks_vdb"
)
else:
logger.warning(
f"Vector similarity chunk selection: found {len(chunk_vectors)} but expecting {len(all_chunk_ids)}"
)
return []
# Calculate cosine similarities
similarities = []
valid_vectors = 0
for chunk_id in all_chunk_ids:
if chunk_id in chunk_vectors:
chunk_embedding = chunk_vectors[chunk_id]
try:
# Calculate cosine similarity
similarity = cosine_similarity(query_embedding, chunk_embedding)
similarities.append((chunk_id, similarity))
valid_vectors += 1
except Exception as e:
logger.warning(
f"Vector similarity chunk selection: failed to calculate similarity for chunk {chunk_id}: {e}"
)
else:
logger.warning(
f"Vector similarity chunk selection: no vector found for chunk {chunk_id}"
)
# Sort by similarity (highest first) and select top num_of_chunks
similarities.sort(key=lambda x: x[1], reverse=True)
selected_chunks = [chunk_id for chunk_id, _ in similarities[:num_of_chunks]]
logger.debug(
f"Vector similarity chunk selection: {len(selected_chunks)} chunks from {len(all_chunk_ids)} candidates"
)
return selected_chunks
except Exception as e:
logger.error(f"[VECTOR_SIMILARITY] Error in vector similarity sorting: {e}")
import traceback
logger.error(f"[VECTOR_SIMILARITY] Traceback: {traceback.format_exc()}")
# Fallback to simple truncation
logger.debug("[VECTOR_SIMILARITY] Falling back to simple truncation")
return all_chunk_ids[:num_of_chunks]
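
pick_by_vector_similarity relies on a cosine_similarity helper that is not part of this diff; a minimal numpy version of what such a helper computes (a sketch, not necessarily the repository's implementation):

import numpy as np

def cosine_similarity(a, b) -> float:
    """Cosine of the angle between two embeddings; 1.0 means identical direction."""
    a = np.asarray(a, dtype=np.float32)
    b = np.asarray(b, dtype=np.float32)
    denom = float(np.linalg.norm(a) * np.linalg.norm(b))
    return float(np.dot(a, b) / denom) if denom else 0.0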
class TokenTracker:
"""Track token usage for LLM calls."""
@ -1787,6 +1901,13 @@ async def process_chunks_unified(
# 1. Apply reranking if enabled and query is provided
if query_param.enable_rerank and query and unique_chunks:
# Preserve the chunk_id field, since rerank may drop it
chunk_ids = {}
for chunk in unique_chunks:
chunk_id = chunk.get("chunk_id")
if chunk_id:
chunk_ids[id(chunk)] = chunk_id
rerank_top_k = query_param.chunk_top_k or len(unique_chunks)
unique_chunks = await apply_rerank_if_enabled(
query=query,
@ -1796,6 +1917,11 @@ async def process_chunks_unified(
top_n=rerank_top_k,
)
# Restore the chunk_id field
for chunk in unique_chunks:
if id(chunk) in chunk_ids:
chunk["chunk_id"] = chunk_ids[id(chunk)]
# 2. Filter by minimum rerank score if reranking is enabled
if query_param.enable_rerank and unique_chunks:
min_rerank_score = global_config.get("min_rerank_score", 0.5)
@ -1842,12 +1968,26 @@ async def process_chunks_unified(
)
original_count = len(unique_chunks)
# Keep the chunk_id field, because truncate_list_by_token_size will drop it
chunk_ids_map = {}
for i, chunk in enumerate(unique_chunks):
chunk_id = chunk.get("chunk_id")
if chunk_id:
chunk_ids_map[i] = chunk_id
unique_chunks = truncate_list_by_token_size(
unique_chunks,
key=lambda x: x.get("content", ""),
max_token_size=chunk_token_limit,
tokenizer=tokenizer,
)
# Restore the chunk_id field
for i, chunk in enumerate(unique_chunks):
if i in chunk_ids_map:
chunk["chunk_id"] = chunk_ids_map[i]
logger.debug(
f"Token truncation: {len(unique_chunks)} chunks from {original_count} "
f"(chunk available tokens: {chunk_token_limit}, source: {source_type})"