feat: KG-related chunk selection by vector similarity
- Add env switch to toggle weighted polling vs. vector-similarity strategy
- Implement similarity-based sorting with fallback to weighted polling
- Introduce batch vector read API for vector storage
- Implement vector store and retrieve functions for NanoVectorDB
- Preserve default behavior (weighted polling selection method)
This commit is contained in:
parent
5b0e26d9da
commit
f1dafa0d01
6 changed files with 440 additions and 149 deletions
@@ -290,6 +290,19 @@ class BaseVectorStorage(StorageNameSpace, ABC):
            ids: List of vector IDs to be deleted
        """

    @abstractmethod
    async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
        """Get vectors by their IDs, returning only ID and vector data for efficiency

        Args:
            ids: List of unique identifiers

        Returns:
            Dictionary mapping IDs to their vector embeddings
            Format: {id: [vector_values], ...}
        """
        pass


@dataclass
class BaseKVStorage(StorageNameSpace, ABC):
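For orientation, a hedged sketch of the shape a concrete backend has to provide for this new abstract method, using a hypothetical in-memory dict. The real implementations subclass `BaseVectorStorage`; the base class and its other abstract methods are omitted here for brevity, so this is illustrative only:

```python
class InMemoryVectorStorage:  # hypothetical, not a real LightRAG backend
    def __init__(self) -> None:
        self._vectors: dict[str, list[float]] = {}  # assumed internal id -> embedding map

    async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
        # Return only the requested IDs that exist; missing IDs are simply absent,
        # matching the documented {id: [vector_values], ...} format.
        return {id_: self._vectors[id_] for id_ in ids if id_ in self._vectors}
```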
@@ -27,6 +27,7 @@ DEFAULT_MAX_RELATION_TOKENS = 10000
DEFAULT_MAX_TOTAL_TOKENS = 30000
DEFAULT_COSINE_THRESHOLD = 0.2
DEFAULT_RELATED_CHUNK_NUMBER = 5
DEFAULT_KG_CHUNK_PICK_METHOD = "WEIGHT"
# Deprecated: history messages have a negative effect on query performance
DEFAULT_HISTORY_TURNS = 0
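As a usage note, the query path below reads this switch with `os.getenv("KG_CHUNK_PICK_METHOD", "WEIGHT")`, so a caller can presumably opt into the new strategy like this (a minimal sketch; set before issuing queries):

```python
import os

# Default ("WEIGHT") keeps the existing weighted-polling behavior;
# "VECTOR" enables similarity-based chunk selection with fallback to WEIGHT.
os.environ["KG_CHUNK_PICK_METHOD"] = "VECTOR"
```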
@@ -1,5 +1,7 @@
import asyncio
import base64
import os
import zlib
from typing import Any, final
from dataclasses import dataclass
import numpy as np
@@ -93,8 +95,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
        2. Only one process should be updating the storage at a time before index_done_callback,
           KG-storage-log should be used to avoid data corruption
        """
        logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
        # logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
        if not data:
            return

@@ -120,6 +121,11 @@ class NanoVectorDBStorage(BaseVectorStorage):
        embeddings = np.concatenate(embeddings_list)
        if len(embeddings) == len(list_data):
            for i, d in enumerate(list_data):
                # Compress vector using Float16 + zlib + Base64 for storage optimization
                vector_f16 = embeddings[i].astype(np.float16)
                compressed_vector = zlib.compress(vector_f16.tobytes())
                encoded_vector = base64.b64encode(compressed_vector).decode("utf-8")
                d["vector"] = encoded_vector
                d["__vector__"] = embeddings[i]
            client = await self._get_client()
            results = client.upsert(datas=list_data)
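As an aside, a small self-contained sketch of the Float16 + zlib + Base64 round trip used here (the embedding values are illustrative; in the real code the encoded string is stored under the "vector" field and the raw embedding under "__vector__"):

```python
import base64
import zlib

import numpy as np

embedding = np.random.rand(768).astype(np.float32)  # e.g. one chunk embedding

# Encode: Float16 quantization -> zlib compression -> Base64 text
encoded = base64.b64encode(zlib.compress(embedding.astype(np.float16).tobytes())).decode("utf-8")

# Decode: Base64 -> zlib decompress -> Float16 buffer -> Float32 array
decoded = np.frombuffer(zlib.decompress(base64.b64decode(encoded)), dtype=np.float16).astype(np.float32)

assert decoded.shape == embedding.shape  # lossy in precision (float16), not in length
```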
@@ -147,7 +153,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
        )
        results = [
            {
                **dp,
                **{k: v for k, v in dp.items() if k != "vector"},
                "id": dp["__id__"],
                "distance": dp["__metrics__"],
                "created_at": dp.get("__created_at__"),

@@ -296,7 +302,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
        if result:
            dp = result[0]
            return {
                **dp,
                **{k: v for k, v in dp.items() if k != "vector"},
                "id": dp.get("__id__"),
                "created_at": dp.get("__created_at__"),
            }
@@ -318,13 +324,41 @@ class NanoVectorDBStorage(BaseVectorStorage):
        results = client.get(ids)
        return [
            {
                **dp,
                **{k: v for k, v in dp.items() if k != "vector"},
                "id": dp.get("__id__"),
                "created_at": dp.get("__created_at__"),
            }
            for dp in results
        ]

    async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
        """Get vectors by their IDs, returning only ID and vector data for efficiency

        Args:
            ids: List of unique identifiers

        Returns:
            Dictionary mapping IDs to their vector embeddings
            Format: {id: [vector_values], ...}
        """
        if not ids:
            return {}

        client = await self._get_client()
        results = client.get(ids)

        vectors_dict = {}
        for result in results:
            if result and "vector" in result and "__id__" in result:
                # Decompress vector data (Base64 + zlib + Float16 compressed)
                decoded = base64.b64decode(result["vector"])
                decompressed = zlib.decompress(decoded)
                vector_f16 = np.frombuffer(decompressed, dtype=np.float16)
                vector_f32 = vector_f16.astype(np.float32).tolist()
                vectors_dict[result["__id__"]] = vector_f32

        return vectors_dict

    async def drop(self) -> dict[str, str]:
        """Drop all vector data from storage and clean up resources
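Caller-side, the batch read is a single awaited call; a minimal usage sketch (the chunk IDs are placeholders):

```python
# Inside an async context, with chunks_vdb being a NanoVectorDBStorage instance:
vectors = await chunks_vdb.get_vectors_by_ids(["chunk-1", "chunk-2"])
for chunk_id, embedding in vectors.items():
    print(chunk_id, len(embedding))  # each value is a decompressed float32 list
```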
@@ -31,6 +31,7 @@ from lightrag.constants import (
    DEFAULT_MAX_TOTAL_TOKENS,
    DEFAULT_COSINE_THRESHOLD,
    DEFAULT_RELATED_CHUNK_NUMBER,
    DEFAULT_KG_CHUNK_PICK_METHOD,
    DEFAULT_MIN_RERANK_SCORE,
    DEFAULT_SUMMARY_MAX_TOKENS,
    DEFAULT_MAX_ASYNC,
@@ -175,6 +176,11 @@ class LightRAG:
    )
    """Number of related chunks to grab from a single entity or relation."""

    kg_chunk_pick_method: str = field(
        default=get_env_value("KG_CHUNK_PICK_METHOD", DEFAULT_KG_CHUNK_PICK_METHOD, str)
    )
    """Method for selecting text chunks: 'WEIGHT' for weight-based selection, 'VECTOR' for embedding similarity-based selection."""

    # Entity extraction
    # ---
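Since `kg_chunk_pick_method` is a regular dataclass field, it can presumably also be set at construction time instead of via the environment. A hedged sketch, assuming the usual LightRAG setup arguments (the embedding and LLM callables are placeholders for your own deployment):

```python
from lightrag import LightRAG

rag = LightRAG(
    working_dir="./rag_storage",
    embedding_func=my_embedding_func,   # assumed: your EmbeddingFunc instance
    llm_model_func=my_llm_func,         # assumed: your LLM completion function
    kg_chunk_pick_method="VECTOR",      # default is "WEIGHT" (weighted polling)
)
```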
@@ -28,6 +28,7 @@ from .utils import (
    update_chunk_cache_list,
    remove_think_tags,
    linear_gradient_weighted_polling,
    vector_similarity_sorting,
    process_chunks_unified,
    build_file_path,
)
@@ -2349,19 +2350,23 @@ async def _build_query_context(

    # Get text chunks based on final filtered data
    if final_node_datas:
        entity_chunks = await _find_most_related_text_unit_from_entities(
        entity_chunks = await _find_related_text_unit_from_entities(
            final_node_datas,
            query_param,
            text_chunks_db,
            knowledge_graph_inst,
            query,
            chunks_vdb,
        )

    if final_edge_datas:
        relation_chunks = await _find_related_text_unit_from_relationships(
        relation_chunks = await _find_related_text_unit_from_relations(
            final_edge_datas,
            query_param,
            text_chunks_db,
            entity_chunks,
            query,
            chunks_vdb,
        )

    # Round-robin merge chunks from different sources with deduplication by chunk_id
@@ -2410,7 +2415,7 @@ async def _build_query_context(
            }
        )

    logger.debug(
    logger.info(
        f"Round-robin merged total chunks from {origin_len} to {len(merged_chunks)}"
    )
@@ -2611,104 +2616,6 @@ async def _get_node_data(
    return node_datas, use_relations


async def _find_most_related_text_unit_from_entities(
    node_datas: list[dict],
    query_param: QueryParam,
    text_chunks_db: BaseKVStorage,
    knowledge_graph_inst: BaseGraphStorage,
):
    """
    Find text chunks related to entities using linear gradient weighted polling algorithm.

    This function implements the optimized text chunk selection strategy:
    1. Sort text chunks for each entity by occurrence count in other entities
    2. Use linear gradient weighted polling to select chunks fairly
    """
    logger.debug(f"Searching text chunks for {len(node_datas)} entities")

    if not node_datas:
        return []

    # Step 1: Collect all text chunks for each entity
    entities_with_chunks = []
    for entity in node_datas:
        if entity.get("source_id"):
            chunks = split_string_by_multi_markers(
                entity["source_id"], [GRAPH_FIELD_SEP]
            )
            if chunks:
                entities_with_chunks.append(
                    {
                        "entity_name": entity["entity_name"],
                        "chunks": chunks,
                        "entity_data": entity,
                    }
                )

    if not entities_with_chunks:
        logger.warning("No entities with text chunks found")
        return []

    # Step 2: Count chunk occurrences and deduplicate (keep chunks from earlier positioned entities)
    chunk_occurrence_count = {}
    for entity_info in entities_with_chunks:
        deduplicated_chunks = []
        for chunk_id in entity_info["chunks"]:
            chunk_occurrence_count[chunk_id] = (
                chunk_occurrence_count.get(chunk_id, 0) + 1
            )

            # If this is the first occurrence (count == 1), keep it; otherwise skip (duplicate from later position)
            if chunk_occurrence_count[chunk_id] == 1:
                deduplicated_chunks.append(chunk_id)
            # count > 1 means this chunk appeared in an earlier entity, so skip it

        # Update entity's chunks to deduplicated chunks
        entity_info["chunks"] = deduplicated_chunks

    # Step 3: Sort chunks for each entity by occurrence count (higher count = higher priority)
    for entity_info in entities_with_chunks:
        sorted_chunks = sorted(
            entity_info["chunks"],
            key=lambda chunk_id: chunk_occurrence_count.get(chunk_id, 0),
            reverse=True,
        )
        entity_info["sorted_chunks"] = sorted_chunks

    # Step 4: Apply linear gradient weighted polling algorithm
    max_related_chunks = text_chunks_db.global_config.get(
        "related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER
    )

    selected_chunk_ids = linear_gradient_weighted_polling(
        entities_with_chunks, max_related_chunks, min_related_chunks=1
    )

    logger.debug(
        f"Found {len(selected_chunk_ids)} entity-related chunks using linear gradient weighted polling"
    )

    if not selected_chunk_ids:
        return []

    # Step 5: Batch retrieve chunk data
    unique_chunk_ids = list(
        dict.fromkeys(selected_chunk_ids)
    )  # Remove duplicates while preserving order
    chunk_data_list = await text_chunks_db.get_by_ids(unique_chunk_ids)

    # Step 6: Build result chunks with valid data
    result_chunks = []
    for chunk_id, chunk_data in zip(unique_chunk_ids, chunk_data_list):
        if chunk_data is not None and "content" in chunk_data:
            chunk_data_copy = chunk_data.copy()
            chunk_data_copy["source_type"] = "entity"
            chunk_data_copy["chunk_id"] = chunk_id  # Add chunk_id for deduplication
            result_chunks.append(chunk_data_copy)

    return result_chunks


async def _find_most_related_edges_from_entities(
    node_datas: list[dict],
    query_param: QueryParam,
@@ -2765,6 +2672,162 @@ async def _find_most_related_edges_from_entities(
    return all_edges_data


async def _find_related_text_unit_from_entities(
    node_datas: list[dict],
    query_param: QueryParam,
    text_chunks_db: BaseKVStorage,
    knowledge_graph_inst: BaseGraphStorage,
    query: str = None,
    chunks_vdb: BaseVectorStorage = None,
):
    """
    Find text chunks related to entities using configurable chunk selection method.

    This function supports two chunk selection strategies:
    1. WEIGHT: Linear gradient weighted polling based on chunk occurrence count
    2. VECTOR: Vector similarity-based selection using embedding cosine similarity
    """
    logger.debug(f"Finding text chunks from {len(node_datas)} entities")

    if not node_datas:
        return []

    # Step 1: Collect all text chunks for each entity
    entities_with_chunks = []
    for entity in node_datas:
        if entity.get("source_id"):
            chunks = split_string_by_multi_markers(
                entity["source_id"], [GRAPH_FIELD_SEP]
            )
            if chunks:
                entities_with_chunks.append(
                    {
                        "entity_name": entity["entity_name"],
                        "chunks": chunks,
                        "entity_data": entity,
                    }
                )

    if not entities_with_chunks:
        logger.warning("No entities with text chunks found")
        return []

    # Check chunk selection method from environment variable
    kg_chunk_pick_method = os.getenv("KG_CHUNK_PICK_METHOD", "WEIGHT").upper()

    max_related_chunks = text_chunks_db.global_config.get(
        "related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER
    )

    # Step 2: Count chunk occurrences and deduplicate (keep chunks from earlier positioned entities)
    chunk_occurrence_count = {}
    for entity_info in entities_with_chunks:
        deduplicated_chunks = []
        for chunk_id in entity_info["chunks"]:
            chunk_occurrence_count[chunk_id] = (
                chunk_occurrence_count.get(chunk_id, 0) + 1
            )

            # If this is the first occurrence (count == 1), keep it; otherwise skip (duplicate from later position)
            if chunk_occurrence_count[chunk_id] == 1:
                deduplicated_chunks.append(chunk_id)
            # count > 1 means this chunk appeared in an earlier entity, so skip it

        # Update entity's chunks to deduplicated chunks
        entity_info["chunks"] = deduplicated_chunks

    # Step 3: Sort chunks for each entity by occurrence count (higher count = higher priority)
    total_entity_chunks = 0
    for entity_info in entities_with_chunks:
        sorted_chunks = sorted(
            entity_info["chunks"],
            key=lambda chunk_id: chunk_occurrence_count.get(chunk_id, 0),
            reverse=True,
        )
        entity_info["sorted_chunks"] = sorted_chunks
        total_entity_chunks += len(sorted_chunks)

    # Step 4: Apply the selected chunk selection algorithm
    selected_chunk_ids = []  # Initialize to avoid UnboundLocalError

    if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb:
        # Calculate num_of_chunks = max_related_chunks * len(entity_info) / 2
        num_of_chunks = int(max_related_chunks * len(entities_with_chunks) / 2)

        # Get embedding function from global config
        embedding_func_config = text_chunks_db.embedding_func
        if not embedding_func_config:
            logger.warning("No embedding function found, falling back to WEIGHT method")
            kg_chunk_pick_method = "WEIGHT"
        else:
            try:
                # Extract the actual callable function from EmbeddingFunc
                if hasattr(embedding_func_config, "func"):
                    actual_embedding_func = embedding_func_config.func
                else:
                    logger.warning(
                        "Invalid embedding function format, falling back to WEIGHT method"
                    )
                    kg_chunk_pick_method = "WEIGHT"
                    actual_embedding_func = None

                selected_chunk_ids = None
                if actual_embedding_func:
                    selected_chunk_ids = await vector_similarity_sorting(
                        query=query,
                        text_chunks_storage=text_chunks_db,
                        chunks_vdb=chunks_vdb,
                        num_of_chunks=num_of_chunks,
                        entity_info=entities_with_chunks,
                        embedding_func=actual_embedding_func,
                    )

                if selected_chunk_ids == []:
                    kg_chunk_pick_method = "WEIGHT"
                    logger.warning(
                        "No entity-related chunks selected by vector similarity, falling back to WEIGHT method"
                    )
                else:
                    logger.info(
                        f"Selecting {len(selected_chunk_ids)} from {total_entity_chunks} entity-related chunks by vector similarity"
                    )

            except Exception as e:
                logger.error(
                    f"Error in vector similarity sorting: {e}, falling back to WEIGHT method"
                )
                kg_chunk_pick_method = "WEIGHT"

    if kg_chunk_pick_method == "WEIGHT":
        # Apply linear gradient weighted polling algorithm
        selected_chunk_ids = linear_gradient_weighted_polling(
            entities_with_chunks, max_related_chunks, min_related_chunks=1
        )

        logger.debug(
            f"Selecting {len(selected_chunk_ids)} from {total_entity_chunks} entity-related chunks by weighted polling"
        )

    if not selected_chunk_ids:
        return []

    # Step 5: Batch retrieve chunk data
    unique_chunk_ids = list(
        dict.fromkeys(selected_chunk_ids)
    )  # Remove duplicates while preserving order
    chunk_data_list = await text_chunks_db.get_by_ids(unique_chunk_ids)

    # Step 6: Build result chunks with valid data
    result_chunks = []
    for chunk_id, chunk_data in zip(unique_chunk_ids, chunk_data_list):
        if chunk_data is not None and "content" in chunk_data:
            chunk_data_copy = chunk_data.copy()
            chunk_data_copy["source_type"] = "entity"
            chunk_data_copy["chunk_id"] = chunk_id  # Add chunk_id for deduplication
            result_chunks.append(chunk_data_copy)

    return result_chunks


async def _get_edge_data(
    keywords,
    knowledge_graph_inst: BaseGraphStorage,
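To make the occurrence-count bookkeeping in Steps 2-3 concrete, here is a small standalone sketch with toy chunk IDs (the `entities_with_chunks` contents are invented for illustration):

```python
entities_with_chunks = [
    {"entity_name": "A", "chunks": ["c1", "c2", "c3"]},
    {"entity_name": "B", "chunks": ["c2", "c4"]},  # "c2" already seen under A
]

chunk_occurrence_count: dict[str, int] = {}
for entity_info in entities_with_chunks:
    deduplicated = []
    for chunk_id in entity_info["chunks"]:
        chunk_occurrence_count[chunk_id] = chunk_occurrence_count.get(chunk_id, 0) + 1
        if chunk_occurrence_count[chunk_id] == 1:  # keep only the first occurrence
            deduplicated.append(chunk_id)
    entity_info["chunks"] = deduplicated

for entity_info in entities_with_chunks:
    entity_info["sorted_chunks"] = sorted(
        entity_info["chunks"],
        key=lambda cid: chunk_occurrence_count.get(cid, 0),
        reverse=True,
    )

print(chunk_occurrence_count)                              # {'c1': 1, 'c2': 2, 'c3': 1, 'c4': 1}
print([e["sorted_chunks"] for e in entities_with_chunks])  # [['c2', 'c1', 'c3'], ['c4']]
```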
@@ -2856,20 +2919,22 @@ async def _find_most_related_entities_from_relationships(
    return node_datas


async def _find_related_text_unit_from_relationships(
async def _find_related_text_unit_from_relations(
    edge_datas: list[dict],
    query_param: QueryParam,
    text_chunks_db: BaseKVStorage,
    entity_chunks: list[dict] = None,
    query: str = None,
    chunks_vdb: BaseVectorStorage = None,
):
    """
    Find text chunks related to relationships using linear gradient weighted polling algorithm.
    Find text chunks related to relationships using configurable chunk selection method.

    This function implements the optimized text chunk selection strategy:
    1. Sort text chunks for each relationship by occurrence count in other relationships
    2. Use linear gradient weighted polling to select chunks fairly
    This function supports two chunk selection strategies:
    1. WEIGHT: Linear gradient weighted polling based on chunk occurrence count
    2. VECTOR: Vector similarity-based selection using embedding cosine similarity
    """
    logger.debug(f"Searching text chunks for {len(edge_datas)} relationships")
    logger.debug(f"Finding text chunks from {len(edge_datas)} relations")

    if not edge_datas:
        return []
@@ -2899,14 +2964,40 @@ async def _find_related_text_unit_from_relationships(
            )

    if not relations_with_chunks:
        logger.warning("No relationships with text chunks found")
        logger.warning("No relation-related chunks found")
        return []

    # Check chunk selection method from environment variable
    kg_chunk_pick_method = os.getenv("KG_CHUNK_PICK_METHOD", "WEIGHT").upper()

    max_related_chunks = text_chunks_db.global_config.get(
        "related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER
    )

    # Step 2: Count chunk occurrences and deduplicate (keep chunks from earlier positioned relationships)
    # Also remove duplicates with entity_chunks

    # Extract chunk IDs from entity_chunks for deduplication
    entity_chunk_ids = set()
    if entity_chunks:
        for chunk in entity_chunks:
            chunk_id = chunk.get("chunk_id")
            if chunk_id:
                entity_chunk_ids.add(chunk_id)

    chunk_occurrence_count = {}
    # Track unique chunk_ids that have been removed to avoid double counting
    removed_entity_chunk_ids = set()

    for relation_info in relations_with_chunks:
        deduplicated_chunks = []
        for chunk_id in relation_info["chunks"]:
            # Skip chunks that already exist in entity_chunks
            if chunk_id in entity_chunk_ids:
                # Only count each unique chunk_id once
                removed_entity_chunk_ids.add(chunk_id)
                continue

            chunk_occurrence_count[chunk_id] = (
                chunk_occurrence_count.get(chunk_id, 0) + 1
            )
@@ -2919,6 +3010,20 @@ async def _find_related_text_unit_from_relationships(
        # Update relationship's chunks to deduplicated chunks
        relation_info["chunks"] = deduplicated_chunks

    # Check if any relations still have chunks after deduplication
    relations_with_chunks = [
        relation_info
        for relation_info in relations_with_chunks
        if relation_info["chunks"]
    ]

    logger.info(
        f"Found {len(relations_with_chunks)} additional relation-related chunks ({len(removed_entity_chunk_ids)} duplicated chunks removed)"
    )

    if not relations_with_chunks:
        return []

    # Step 3: Sort chunks for each relationship by occurrence count (higher count = higher priority)
    for relation_info in relations_with_chunks:
        sorted_chunks = sorted(
@@ -2928,50 +3033,73 @@ async def _find_related_text_unit_from_relationships(
        )
        relation_info["sorted_chunks"] = sorted_chunks

    # Step 4: Apply linear gradient weighted polling algorithm
    max_related_chunks = text_chunks_db.global_config.get(
        "related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER
    )
    # Step 4: Apply the selected chunk selection algorithm
    selected_chunk_ids = []  # Initialize to avoid UnboundLocalError

    selected_chunk_ids = linear_gradient_weighted_polling(
        relations_with_chunks, max_related_chunks, min_related_chunks=1
    )
    if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb:
        # Calculate num_of_chunks = max_related_chunks * len(entity_info) / 2
        num_of_chunks = int(max_related_chunks * len(relations_with_chunks) / 2)

        # Get embedding function from global config
        embedding_func_config = text_chunks_db.embedding_func
        if not embedding_func_config:
            logger.warning("No embedding function found, falling back to WEIGHT method")
            kg_chunk_pick_method = "WEIGHT"
        else:
            try:
                # Extract the actual callable function from EmbeddingFunc
                if hasattr(embedding_func_config, "func"):
                    actual_embedding_func = embedding_func_config.func
                else:
                    logger.warning(
                        "Invalid embedding function format, falling back to WEIGHT method"
                    )
                    kg_chunk_pick_method = "WEIGHT"
                    actual_embedding_func = None

                if actual_embedding_func:
                    selected_chunk_ids = await vector_similarity_sorting(
                        query=query,
                        text_chunks_storage=text_chunks_db,
                        chunks_vdb=chunks_vdb,
                        num_of_chunks=num_of_chunks,
                        entity_info=relations_with_chunks,
                        embedding_func=actual_embedding_func,
                    )

                if selected_chunk_ids == []:
                    kg_chunk_pick_method = "WEIGHT"
                    logger.warning(
                        "No relation-related chunks selected by vector similarity, falling back to WEIGHT method"
                    )
                else:
                    logger.info(
                        f"Selecting {len(selected_chunk_ids)} relation-related chunks by vector similarity"
                    )

            except Exception as e:
                logger.error(
                    f"Error in vector similarity sorting: {e}, falling back to WEIGHT method"
                )
                kg_chunk_pick_method = "WEIGHT"

    if kg_chunk_pick_method == "WEIGHT":
        # Apply linear gradient weighted polling algorithm
        selected_chunk_ids = linear_gradient_weighted_polling(
            relations_with_chunks, max_related_chunks, min_related_chunks=1
        )

        logger.info(
            f"Selecting {len(selected_chunk_ids)} relation-related chunks by weighted polling"
        )

    logger.debug(
        f"Found {len(selected_chunk_ids)} relationship-related chunks using linear gradient weighted polling"
    )
    logger.info(
        f"KG related chunks: {len(entity_chunks)} from entities, {len(selected_chunk_ids)} from relations"
    )

    if not selected_chunk_ids:
        return []

    # Step 4.5: Remove duplicates with entity_chunks before batch retrieval
    if entity_chunks:
        # Extract chunk IDs from entity_chunks
        entity_chunk_ids = set()
        for chunk in entity_chunks:
            chunk_id = chunk.get("chunk_id")
            if chunk_id:
                entity_chunk_ids.add(chunk_id)

        # Filter out duplicate chunk IDs
        original_count = len(selected_chunk_ids)
        selected_chunk_ids = [
            chunk_id
            for chunk_id in selected_chunk_ids
            if chunk_id not in entity_chunk_ids
        ]

        logger.debug(
            f"Deduplication relation-chunks with entity-chunks: {original_count} -> {len(selected_chunk_ids)} chunks"
        )

        # Early return if no chunks remain after deduplication
        if not selected_chunk_ids:
            return []

    # Step 5: Batch retrieve chunk data
    unique_chunk_ids = list(
        dict.fromkeys(selected_chunk_ids)
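For illustration, a tiny standalone sketch of the Step 4.5 deduplication above: an order-preserving filter of the relation-selected chunk IDs against the entity-selected ones (the IDs are invented):

```python
entity_chunk_ids = {"c2", "c5"}
selected_chunk_ids = ["c1", "c2", "c3", "c5", "c6"]

# Keep relation chunks not already contributed by entities,
# preserving the similarity/polling order of the remaining IDs.
selected_chunk_ids = [cid for cid in selected_chunk_ids if cid not in entity_chunk_ids]
print(selected_chunk_ids)  # ['c1', 'c3', 'c6']
```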
@@ -60,7 +60,7 @@ def get_env_value(

# Use TYPE_CHECKING to avoid circular imports
if TYPE_CHECKING:
    from lightrag.base import BaseKVStorage, QueryParam
    from lightrag.base import BaseKVStorage, BaseVectorStorage, QueryParam

# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
@@ -1650,6 +1650,115 @@ def linear_gradient_weighted_polling(
    return selected_chunks


async def vector_similarity_sorting(
    query: str,
    text_chunks_storage: "BaseKVStorage",
    chunks_vdb: "BaseVectorStorage",
    num_of_chunks: int,
    entity_info: list[dict[str, Any]],
    embedding_func: callable,
) -> list[str]:
    """
    Vector similarity-based text chunk selection algorithm.

    This algorithm selects text chunks based on cosine similarity between
    the query embedding and text chunk embeddings.

    Args:
        query: User's original query string
        text_chunks_storage: Text chunks storage instance
        chunks_vdb: Vector database storage for chunks
        num_of_chunks: Number of chunks to select
        entity_info: List of entity information containing chunk IDs
        embedding_func: Embedding function to compute query embedding

    Returns:
        List of selected text chunk IDs sorted by similarity (highest first)
    """
    logger.debug(
        f"Vector similarity chunk selection: num_of_chunks={num_of_chunks}, entity_info_count={len(entity_info) if entity_info else 0}"
    )

    if not entity_info or num_of_chunks <= 0:
        return []

    # Collect all unique chunk IDs from entity info
    all_chunk_ids = set()
    for i, entity in enumerate(entity_info):
        chunk_ids = entity.get("sorted_chunks", [])
        all_chunk_ids.update(chunk_ids)

    if not all_chunk_ids:
        logger.warning(
            "Vector similarity chunk selection: no chunk IDs found in entity_info"
        )
        return []

    logger.debug(
        f"Vector similarity chunk selection: {len(all_chunk_ids)} unique chunk IDs collected"
    )

    all_chunk_ids = list(all_chunk_ids)

    try:
        # Get query embedding
        query_embedding = await embedding_func([query])
        query_embedding = query_embedding[
            0
        ]  # Extract first embedding from batch result

        # Get chunk embeddings from vector database
        chunk_vectors = await chunks_vdb.get_vectors_by_ids(all_chunk_ids)
        logger.debug(
            f"Vector similarity chunk selection: {len(chunk_vectors)} chunk vectors retrieved"
        )

        if not chunk_vectors:
            logger.warning(
                "Vector similarity chunk selection: no vectors retrieved from chunks_vdb"
            )
            return []

        # Calculate cosine similarities
        similarities = []
        valid_vectors = 0
        for chunk_id in all_chunk_ids:
            if chunk_id in chunk_vectors:
                chunk_embedding = chunk_vectors[chunk_id]
                try:
                    # Calculate cosine similarity
                    similarity = cosine_similarity(query_embedding, chunk_embedding)
                    similarities.append((chunk_id, similarity))
                    valid_vectors += 1
                except Exception as e:
                    logger.warning(
                        f"Vector similarity chunk selection: failed to calculate similarity for chunk {chunk_id}: {e}"
                    )
            else:
                logger.warning(
                    f"Vector similarity chunk selection: no vector found for chunk {chunk_id}"
                )

        # Sort by similarity (highest first) and select top num_of_chunks
        similarities.sort(key=lambda x: x[1], reverse=True)
        selected_chunks = [chunk_id for chunk_id, _ in similarities[:num_of_chunks]]

        logger.debug(
            f"Vector similarity chunk selection: {len(selected_chunks)} chunks from {len(all_chunk_ids)} candidates"
        )

        return selected_chunks

    except Exception as e:
        logger.error(f"[VECTOR_SIMILARITY] Error in vector similarity sorting: {e}")
        import traceback

        logger.error(f"[VECTOR_SIMILARITY] Traceback: {traceback.format_exc()}")
        # Fallback to simple truncation
        logger.debug("[VECTOR_SIMILARITY] Falling back to simple truncation")
        return all_chunk_ids[:num_of_chunks]


class TokenTracker:
    """Track token usage for LLM calls."""
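To show the core ranking step in isolation, here is a small self-contained sketch of cosine-similarity top-k selection over the batch-read vectors (the numbers and IDs are made up; the real code delegates the similarity computation to the library's `cosine_similarity` helper):

```python
import numpy as np

def cosine_sim(a: np.ndarray, b: np.ndarray) -> float:
    # Cosine similarity between two vectors.
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

query_embedding = np.array([0.1, 0.9, 0.2])
chunk_vectors = {                 # as returned by get_vectors_by_ids()
    "c1": [0.1, 0.8, 0.3],
    "c2": [0.9, 0.1, 0.0],
    "c3": [0.2, 0.9, 0.1],
}

num_of_chunks = 2
similarities = [
    (cid, cosine_sim(query_embedding, np.array(vec)))
    for cid, vec in chunk_vectors.items()
]
similarities.sort(key=lambda x: x[1], reverse=True)
print([cid for cid, _ in similarities[:num_of_chunks]])  # ['c1', 'c3']
```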