feat: KG related chunks selection by vector similarity

- Add env switch to toggle weighted polling vs vector-similarity strategy
- Implement similarity-based sorting with fallback to weighted
- Introduce batch vector read API for vector storage
- Implement vector store and retrieve functions for NanoVectorDB
- Preserve default behavior (weighted polling selection method)
yangdx 2025-08-13 18:16:42 +08:00
parent 5b0e26d9da
commit f1dafa0d01
6 changed files with 440 additions and 149 deletions

View file

@@ -290,6 +290,19 @@ class BaseVectorStorage(StorageNameSpace, ABC):
ids: List of vector IDs to be deleted
"""
@abstractmethod
async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
"""Get vectors by their IDs, returning only ID and vector data for efficiency
Args:
ids: List of unique identifiers
Returns:
Dictionary mapping IDs to their vector embeddings
Format: {id: [vector_values], ...}
"""
pass
@dataclass
class BaseKVStorage(StorageNameSpace, ABC):
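A minimal caller-side sketch of the new batch read API (the storage handle and helper below are hypothetical; any concrete BaseVectorStorage implementation is assumed):

# Hypothetical helper: batch-fetch embeddings for a set of chunk IDs.
async def fetch_embeddings(vdb, chunk_ids: list[str]) -> dict[str, list[float]]:
    vectors = await vdb.get_vectors_by_ids(chunk_ids)
    # IDs without a stored vector are expected to be absent from the mapping
    missing = [cid for cid in chunk_ids if cid not in vectors]
    if missing:
        print(f"{len(missing)} of {len(chunk_ids)} IDs had no stored vector")
    return vectors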

View file

@@ -27,6 +27,7 @@ DEFAULT_MAX_RELATION_TOKENS = 10000
DEFAULT_MAX_TOTAL_TOKENS = 30000
DEFAULT_COSINE_THRESHOLD = 0.2
DEFAULT_RELATED_CHUNK_NUMBER = 5
DEFAULT_KG_CHUNK_PICK_METHOD = "WEIGHT"
# Deprecated: history messages have a negative effect on query performance
DEFAULT_HISTORY_TURNS = 0

View file

@@ -1,5 +1,7 @@
import asyncio
import base64
import os
import zlib
from typing import Any, final
from dataclasses import dataclass
import numpy as np
@@ -93,8 +95,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
2. Only one process should be updating the storage at a time before index_done_callback;
KG-storage-log should be used to avoid data corruption
"""
logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
# logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
if not data:
return
@@ -120,6 +121,11 @@ class NanoVectorDBStorage(BaseVectorStorage):
embeddings = np.concatenate(embeddings_list)
if len(embeddings) == len(list_data):
for i, d in enumerate(list_data):
# Compress vector using Float16 + zlib + Base64 for storage optimization
vector_f16 = embeddings[i].astype(np.float16)
compressed_vector = zlib.compress(vector_f16.tobytes())
encoded_vector = base64.b64encode(compressed_vector).decode("utf-8")
d["vector"] = encoded_vector
d["__vector__"] = embeddings[i]
client = await self._get_client()
results = client.upsert(datas=list_data)
@@ -147,7 +153,7 @@
)
results = [
{
**dp,
**{k: v for k, v in dp.items() if k != "vector"},
"id": dp["__id__"],
"distance": dp["__metrics__"],
"created_at": dp.get("__created_at__"),
@@ -296,7 +302,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
if result:
dp = result[0]
return {
**dp,
**{k: v for k, v in dp.items() if k != "vector"},
"id": dp.get("__id__"),
"created_at": dp.get("__created_at__"),
}
@@ -318,13 +324,41 @@ class NanoVectorDBStorage(BaseVectorStorage):
results = client.get(ids)
return [
{
**dp,
**{k: v for k, v in dp.items() if k != "vector"},
"id": dp.get("__id__"),
"created_at": dp.get("__created_at__"),
}
for dp in results
]
async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
"""Get vectors by their IDs, returning only ID and vector data for efficiency
Args:
ids: List of unique identifiers
Returns:
Dictionary mapping IDs to their vector embeddings
Format: {id: [vector_values], ...}
"""
if not ids:
return {}
client = await self._get_client()
results = client.get(ids)
vectors_dict = {}
for result in results:
if result and "vector" in result and "__id__" in result:
# Decompress vector data (Base64 + zlib + Float16 compressed)
decoded = base64.b64decode(result["vector"])
decompressed = zlib.decompress(decoded)
vector_f16 = np.frombuffer(decompressed, dtype=np.float16)
vector_f32 = vector_f16.astype(np.float32).tolist()
vectors_dict[result["__id__"]] = vector_f32
return vectors_dict
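The decode path above is the exact inverse of the encode path in upsert; a self-contained round-trip sketch of the Float16 + zlib + Base64 scheme (vector values are illustrative):

import base64
import zlib
import numpy as np

# Encode: Float32 -> Float16 -> zlib -> Base64 string (as in upsert)
vec = np.random.rand(768).astype(np.float32)
encoded = base64.b64encode(zlib.compress(vec.astype(np.float16).tobytes())).decode("utf-8")

# Decode: Base64 -> zlib -> Float16 buffer -> Float32 (as in get_vectors_by_ids)
restored = np.frombuffer(zlib.decompress(base64.b64decode(encoded)), dtype=np.float16).astype(np.float32)

# Float16 quantization bounds the round-trip error; exact equality is not preserved
assert np.allclose(vec, restored, atol=1e-2)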
async def drop(self) -> dict[str, str]:
"""Drop all vector data from storage and clean up resources

View file

@@ -31,6 +31,7 @@ from lightrag.constants import (
DEFAULT_MAX_TOTAL_TOKENS,
DEFAULT_COSINE_THRESHOLD,
DEFAULT_RELATED_CHUNK_NUMBER,
DEFAULT_KG_CHUNK_PICK_METHOD,
DEFAULT_MIN_RERANK_SCORE,
DEFAULT_SUMMARY_MAX_TOKENS,
DEFAULT_MAX_ASYNC,
@@ -175,6 +176,11 @@ class LightRAG:
)
"""Number of related chunks to grab from single entity or relation."""
kg_chunk_pick_method: str = field(
default=get_env_value("KG_CHUNK_PICK_METHOD", DEFAULT_KG_CHUNK_PICK_METHOD, str)
)
"""Method for selecting text chunks: 'WEIGHT' for weight-based selection, 'VECTOR' for embedding similarity-based selection."""
# Entity extraction
# ---
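A hedged configuration sketch; only kg_chunk_pick_method and its env variable are taken from this diff, the other constructor arguments are illustrative:

import os

# Option 1: environment variable (read at LightRAG construction and,
# per the query-path code below, again via os.getenv at query time)
os.environ["KG_CHUNK_PICK_METHOD"] = "VECTOR"  # default is "WEIGHT"

# Option 2: set the dataclass field directly (other arguments assumed)
# rag = LightRAG(working_dir="./rag_storage", kg_chunk_pick_method="VECTOR")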

View file

@@ -28,6 +28,7 @@ from .utils import (
update_chunk_cache_list,
remove_think_tags,
linear_gradient_weighted_polling,
vector_similarity_sorting,
process_chunks_unified,
build_file_path,
)
@@ -2349,19 +2350,23 @@ async def _build_query_context(
# Get text chunks based on final filtered data
if final_node_datas:
entity_chunks = await _find_most_related_text_unit_from_entities(
entity_chunks = await _find_related_text_unit_from_entities(
final_node_datas,
query_param,
text_chunks_db,
knowledge_graph_inst,
query,
chunks_vdb,
)
if final_edge_datas:
relation_chunks = await _find_related_text_unit_from_relationships(
relation_chunks = await _find_related_text_unit_from_relations(
final_edge_datas,
query_param,
text_chunks_db,
entity_chunks,
query,
chunks_vdb,
)
# Round-robin merge chunks from different sources with deduplication by chunk_id
@@ -2410,7 +2415,7 @@
}
)
logger.debug(
logger.info(
f"Round-robin merged total chunks from {origin_len} to {len(merged_chunks)}"
)
@@ -2611,104 +2616,6 @@ async def _get_node_data(
return node_datas, use_relations
async def _find_most_related_text_unit_from_entities(
node_datas: list[dict],
query_param: QueryParam,
text_chunks_db: BaseKVStorage,
knowledge_graph_inst: BaseGraphStorage,
):
"""
Find text chunks related to entities using linear gradient weighted polling algorithm.
This function implements the optimized text chunk selection strategy:
1. Sort text chunks for each entity by occurrence count in other entities
2. Use linear gradient weighted polling to select chunks fairly
"""
logger.debug(f"Searching text chunks for {len(node_datas)} entities")
if not node_datas:
return []
# Step 1: Collect all text chunks for each entity
entities_with_chunks = []
for entity in node_datas:
if entity.get("source_id"):
chunks = split_string_by_multi_markers(
entity["source_id"], [GRAPH_FIELD_SEP]
)
if chunks:
entities_with_chunks.append(
{
"entity_name": entity["entity_name"],
"chunks": chunks,
"entity_data": entity,
}
)
if not entities_with_chunks:
logger.warning("No entities with text chunks found")
return []
# Step 2: Count chunk occurrences and deduplicate (keep chunks from earlier positioned entities)
chunk_occurrence_count = {}
for entity_info in entities_with_chunks:
deduplicated_chunks = []
for chunk_id in entity_info["chunks"]:
chunk_occurrence_count[chunk_id] = (
chunk_occurrence_count.get(chunk_id, 0) + 1
)
# If this is the first occurrence (count == 1), keep it; otherwise skip (duplicate from later position)
if chunk_occurrence_count[chunk_id] == 1:
deduplicated_chunks.append(chunk_id)
# count > 1 means this chunk appeared in an earlier entity, so skip it
# Update entity's chunks to deduplicated chunks
entity_info["chunks"] = deduplicated_chunks
# Step 3: Sort chunks for each entity by occurrence count (higher count = higher priority)
for entity_info in entities_with_chunks:
sorted_chunks = sorted(
entity_info["chunks"],
key=lambda chunk_id: chunk_occurrence_count.get(chunk_id, 0),
reverse=True,
)
entity_info["sorted_chunks"] = sorted_chunks
# Step 4: Apply linear gradient weighted polling algorithm
max_related_chunks = text_chunks_db.global_config.get(
"related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER
)
selected_chunk_ids = linear_gradient_weighted_polling(
entities_with_chunks, max_related_chunks, min_related_chunks=1
)
logger.debug(
f"Found {len(selected_chunk_ids)} entity-related chunks using linear gradient weighted polling"
)
if not selected_chunk_ids:
return []
# Step 5: Batch retrieve chunk data
unique_chunk_ids = list(
dict.fromkeys(selected_chunk_ids)
) # Remove duplicates while preserving order
chunk_data_list = await text_chunks_db.get_by_ids(unique_chunk_ids)
# Step 6: Build result chunks with valid data
result_chunks = []
for chunk_id, chunk_data in zip(unique_chunk_ids, chunk_data_list):
if chunk_data is not None and "content" in chunk_data:
chunk_data_copy = chunk_data.copy()
chunk_data_copy["source_type"] = "entity"
chunk_data_copy["chunk_id"] = chunk_id # Add chunk_id for deduplication
result_chunks.append(chunk_data_copy)
return result_chunks
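The occurrence-count bookkeeping of Steps 2 and 3, which the replacement function below keeps verbatim, can be isolated in a short sketch (input shape assumed to match entities_with_chunks):

def dedup_and_rank(entities_with_chunks: list[dict]) -> None:
    """Sketch of Steps 2-3: first-seen dedup, then occurrence-count ranking."""
    counts: dict[str, int] = {}
    for info in entities_with_chunks:
        kept = []
        for cid in info["chunks"]:
            counts[cid] = counts.get(cid, 0) + 1
            if counts[cid] == 1:  # first entity to mention this chunk keeps it
                kept.append(cid)
        info["chunks"] = kept
    for info in entities_with_chunks:
        # chunks shared by more entities rank first within each entity
        info["sorted_chunks"] = sorted(
            info["chunks"], key=lambda c: counts.get(c, 0), reverse=True
        )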
async def _find_most_related_edges_from_entities(
node_datas: list[dict],
query_param: QueryParam,
@@ -2765,6 +2672,162 @@ async def _find_most_related_edges_from_entities(
return all_edges_data
async def _find_related_text_unit_from_entities(
node_datas: list[dict],
query_param: QueryParam,
text_chunks_db: BaseKVStorage,
knowledge_graph_inst: BaseGraphStorage,
query: str = None,
chunks_vdb: BaseVectorStorage = None,
):
"""
Find text chunks related to entities using configurable chunk selection method.
This function supports two chunk selection strategies:
1. WEIGHT: Linear gradient weighted polling based on chunk occurrence count
2. VECTOR: Vector similarity-based selection using embedding cosine similarity
"""
logger.debug(f"Finding text chunks from {len(node_datas)} entities")
if not node_datas:
return []
# Step 1: Collect all text chunks for each entity
entities_with_chunks = []
for entity in node_datas:
if entity.get("source_id"):
chunks = split_string_by_multi_markers(
entity["source_id"], [GRAPH_FIELD_SEP]
)
if chunks:
entities_with_chunks.append(
{
"entity_name": entity["entity_name"],
"chunks": chunks,
"entity_data": entity,
}
)
if not entities_with_chunks:
logger.warning("No entities with text chunks found")
return []
# Check chunk selection method from environment variable
kg_chunk_pick_method = os.getenv("KG_CHUNK_PICK_METHOD", "WEIGHT").upper()
max_related_chunks = text_chunks_db.global_config.get(
"related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER
)
# Step 2: Count chunk occurrences and deduplicate (keep chunks from earlier positioned entities)
chunk_occurrence_count = {}
for entity_info in entities_with_chunks:
deduplicated_chunks = []
for chunk_id in entity_info["chunks"]:
chunk_occurrence_count[chunk_id] = (
chunk_occurrence_count.get(chunk_id, 0) + 1
)
# If this is the first occurrence (count == 1), keep it; otherwise skip (duplicate from later position)
if chunk_occurrence_count[chunk_id] == 1:
deduplicated_chunks.append(chunk_id)
# count > 1 means this chunk appeared in an earlier entity, so skip it
# Update entity's chunks to deduplicated chunks
entity_info["chunks"] = deduplicated_chunks
# Step 3: Sort chunks for each entity by occurrence count (higher count = higher priority)
total_entity_chunks = 0
for entity_info in entities_with_chunks:
sorted_chunks = sorted(
entity_info["chunks"],
key=lambda chunk_id: chunk_occurrence_count.get(chunk_id, 0),
reverse=True,
)
entity_info["sorted_chunks"] = sorted_chunks
total_entity_chunks += len(sorted_chunks)
# Step 4: Apply the selected chunk selection algorithm
selected_chunk_ids = [] # Initialize to avoid UnboundLocalError
if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb:
# Calculate num_of_chunks = max_related_chunks * len(entities_with_chunks) / 2
num_of_chunks = int(max_related_chunks * len(entities_with_chunks) / 2)
# Get embedding function from global config
embedding_func_config = text_chunks_db.embedding_func
if not embedding_func_config:
logger.warning("No embedding function found, falling back to WEIGHT method")
kg_chunk_pick_method = "WEIGHT"
else:
try:
# Extract the actual callable function from EmbeddingFunc
if hasattr(embedding_func_config, "func"):
actual_embedding_func = embedding_func_config.func
else:
logger.warning(
"Invalid embedding function format, falling back to WEIGHT method"
)
kg_chunk_pick_method = "WEIGHT"
actual_embedding_func = None
selected_chunk_ids = None
if actual_embedding_func:
selected_chunk_ids = await vector_similarity_sorting(
query=query,
text_chunks_storage=text_chunks_db,
chunks_vdb=chunks_vdb,
num_of_chunks=num_of_chunks,
entity_info=entities_with_chunks,
embedding_func=actual_embedding_func,
)
if not selected_chunk_ids:
kg_chunk_pick_method = "WEIGHT"
logger.warning(
"No entity-related chunks selected by vector similarity, falling back to WEIGHT method"
)
else:
logger.info(
f"Selecting {len(selected_chunk_ids)} from {total_entity_chunks} entity-related chunks by vector similarity"
)
except Exception as e:
logger.error(
f"Error in vector similarity sorting: {e}, falling back to WEIGHT method"
)
kg_chunk_pick_method = "WEIGHT"
if kg_chunk_pick_method == "WEIGHT":
# Apply linear gradient weighted polling algorithm
selected_chunk_ids = linear_gradient_weighted_polling(
entities_with_chunks, max_related_chunks, min_related_chunks=1
)
logger.debug(
f"Selecting {len(selected_chunk_ids)} from {total_entity_chunks} entity-related chunks by weighted polling"
)
if not selected_chunk_ids:
return []
# Step 5: Batch retrieve chunk data
unique_chunk_ids = list(
dict.fromkeys(selected_chunk_ids)
) # Remove duplicates while preserving order
chunk_data_list = await text_chunks_db.get_by_ids(unique_chunk_ids)
# Step 6: Build result chunks with valid data
result_chunks = []
for chunk_id, chunk_data in zip(unique_chunk_ids, chunk_data_list):
if chunk_data is not None and "content" in chunk_data:
chunk_data_copy = chunk_data.copy()
chunk_data_copy["source_type"] = "entity"
chunk_data_copy["chunk_id"] = chunk_id # Add chunk_id for deduplication
result_chunks.append(chunk_data_copy)
return result_chunks
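Stripped of logging, the dispatch-with-fallback above (mirrored in the relations path further down) reduces to this shape; vector_select and weighted_select are hypothetical closures over vector_similarity_sorting and linear_gradient_weighted_polling:

async def pick_chunk_ids(method, candidates, max_related_chunks,
                         vector_select, weighted_select) -> list[str]:
    """Sketch: try VECTOR, degrade to WEIGHT on empty result or any error."""
    if method == "VECTOR":
        try:
            # num_of_chunks = max_related_chunks * len(candidates) / 2
            ids = await vector_select(int(max_related_chunks * len(candidates) / 2))
            if ids:
                return ids
        except Exception:
            pass  # any failure falls through to weighted polling
    return weighted_select(candidates, max_related_chunks)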
async def _get_edge_data(
keywords,
knowledge_graph_inst: BaseGraphStorage,
@@ -2856,20 +2919,22 @@ async def _find_most_related_entities_from_relationships(
return node_datas
async def _find_related_text_unit_from_relationships(
async def _find_related_text_unit_from_relations(
edge_datas: list[dict],
query_param: QueryParam,
text_chunks_db: BaseKVStorage,
entity_chunks: list[dict] = None,
query: str = None,
chunks_vdb: BaseVectorStorage = None,
):
"""
Find text chunks related to relationships using linear gradient weighted polling algorithm.
Find text chunks related to relationships using configurable chunk selection method.
This function implements the optimized text chunk selection strategy:
1. Sort text chunks for each relationship by occurrence count in other relationships
2. Use linear gradient weighted polling to select chunks fairly
This function supports two chunk selection strategies:
1. WEIGHT: Linear gradient weighted polling based on chunk occurrence count
2. VECTOR: Vector similarity-based selection using embedding cosine similarity
"""
logger.debug(f"Searching text chunks for {len(edge_datas)} relationships")
logger.debug(f"Finding text chunks from {len(edge_datas)} relations")
if not edge_datas:
return []
@@ -2899,14 +2964,40 @@ async def _find_related_text_unit_from_relationships(
)
if not relations_with_chunks:
logger.warning("No relationships with text chunks found")
logger.warning("No relation-related chunks found")
return []
# Check chunk selection method from environment variable
kg_chunk_pick_method = os.getenv("KG_CHUNK_PICK_METHOD", "WEIGHT").upper()
max_related_chunks = text_chunks_db.global_config.get(
"related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER
)
# Step 2: Count chunk occurrences and deduplicate (keep chunks from earlier positioned relationships)
# Also remove duplicates with entity_chunks
# Extract chunk IDs from entity_chunks for deduplication
entity_chunk_ids = set()
if entity_chunks:
for chunk in entity_chunks:
chunk_id = chunk.get("chunk_id")
if chunk_id:
entity_chunk_ids.add(chunk_id)
chunk_occurrence_count = {}
# Track unique chunk_ids that have been removed to avoid double counting
removed_entity_chunk_ids = set()
for relation_info in relations_with_chunks:
deduplicated_chunks = []
for chunk_id in relation_info["chunks"]:
# Skip chunks that already exist in entity_chunks
if chunk_id in entity_chunk_ids:
# Only count each unique chunk_id once
removed_entity_chunk_ids.add(chunk_id)
continue
chunk_occurrence_count[chunk_id] = (
chunk_occurrence_count.get(chunk_id, 0) + 1
)
@@ -2919,6 +3010,20 @@ async def _find_related_text_unit_from_relationships(
# Update relationship's chunks to deduplicated chunks
relation_info["chunks"] = deduplicated_chunks
# Check if any relations still have chunks after deduplication
relations_with_chunks = [
relation_info
for relation_info in relations_with_chunks
if relation_info["chunks"]
]
logger.info(
f"Found {len(relations_with_chunks)} additional relation-related chunks ({len(removed_entity_chunk_ids)} duplicate chunks removed)"
)
if not relations_with_chunks:
return []
# Step 3: Sort chunks for each relationship by occurrence count (higher count = higher priority)
for relation_info in relations_with_chunks:
sorted_chunks = sorted(
@@ -2928,50 +3033,73 @@
)
relation_info["sorted_chunks"] = sorted_chunks
# Step 4: Apply linear gradient weighted polling algorithm
max_related_chunks = text_chunks_db.global_config.get(
"related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER
)
# Step 4: Apply the selected chunk selection algorithm
selected_chunk_ids = [] # Initialize to avoid UnboundLocalError
selected_chunk_ids = linear_gradient_weighted_polling(
relations_with_chunks, max_related_chunks, min_related_chunks=1
)
if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb:
# Calculate num_of_chunks = max_related_chunks * len(relations_with_chunks) / 2
num_of_chunks = int(max_related_chunks * len(relations_with_chunks) / 2)
# Get embedding function from global config
embedding_func_config = text_chunks_db.embedding_func
if not embedding_func_config:
logger.warning("No embedding function found, falling back to WEIGHT method")
kg_chunk_pick_method = "WEIGHT"
else:
try:
# Extract the actual callable function from EmbeddingFunc
if hasattr(embedding_func_config, "func"):
actual_embedding_func = embedding_func_config.func
else:
logger.warning(
"Invalid embedding function format, falling back to WEIGHT method"
)
kg_chunk_pick_method = "WEIGHT"
actual_embedding_func = None
if actual_embedding_func:
selected_chunk_ids = await vector_similarity_sorting(
query=query,
text_chunks_storage=text_chunks_db,
chunks_vdb=chunks_vdb,
num_of_chunks=num_of_chunks,
entity_info=relations_with_chunks,
embedding_func=actual_embedding_func,
)
if not selected_chunk_ids:
kg_chunk_pick_method = "WEIGHT"
logger.warning(
"No relation-related chunks selected by vector similarity, falling back to WEIGHT method"
)
else:
logger.info(
f"Selecting {len(selected_chunk_ids)} relation-related chunks by vector similarity"
)
except Exception as e:
logger.error(
f"Error in vector similarity sorting: {e}, falling back to WEIGHT method"
)
kg_chunk_pick_method = "WEIGHT"
if kg_chunk_pick_method == "WEIGHT":
# Apply linear gradient weighted polling algorithm
selected_chunk_ids = linear_gradient_weighted_polling(
relations_with_chunks, max_related_chunks, min_related_chunks=1
)
logger.info(
f"Selecting {len(selected_chunk_ids)} relation-related chunks by weighted polling"
)
logger.debug(
f"Found {len(selected_chunk_ids)} relationship-related chunks using linear gradient weighted polling"
)
logger.info(
f"KG related chunks: {len(entity_chunks)} from entities, {len(selected_chunk_ids)} from relations"
)
if not selected_chunk_ids:
return []
# Step 4.5: Remove duplicates with entity_chunks before batch retrieval
if entity_chunks:
# Extract chunk IDs from entity_chunks
entity_chunk_ids = set()
for chunk in entity_chunks:
chunk_id = chunk.get("chunk_id")
if chunk_id:
entity_chunk_ids.add(chunk_id)
# Filter out duplicate chunk IDs
original_count = len(selected_chunk_ids)
selected_chunk_ids = [
chunk_id
for chunk_id in selected_chunk_ids
if chunk_id not in entity_chunk_ids
]
logger.debug(
f"Deduplicated relation chunks against entity chunks: {original_count} -> {len(selected_chunk_ids)} chunks"
)
# Early return if no chunks remain after deduplication
if not selected_chunk_ids:
return []
# Step 5: Batch retrieve chunk data
unique_chunk_ids = list(
dict.fromkeys(selected_chunk_ids)

View file

@@ -60,7 +60,7 @@ def get_env_value(
# Use TYPE_CHECKING to avoid circular imports
if TYPE_CHECKING:
from lightrag.base import BaseKVStorage, QueryParam
from lightrag.base import BaseKVStorage, BaseVectorStorage, QueryParam
# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
@@ -1650,6 +1650,115 @@ def linear_gradient_weighted_polling(
return selected_chunks
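linear_gradient_weighted_polling itself is outside this diff; purely as an assumption suggested by its name and call sites (per-source quotas fading linearly from max_related_chunks for the top-ranked source down to min_related_chunks for the last, drained round-robin), a sketch might look like:

def linear_gradient_weighted_polling_sketch(
    sources: list[dict], max_related_chunks: int, min_related_chunks: int = 1
) -> list[str]:
    """Assumed behavior only -- the real implementation lives above in utils.py."""
    n = len(sources)
    if n == 0:
        return []
    # Quota for source i interpolates linearly between max and min
    quotas = [
        max(
            min_related_chunks,
            round(max_related_chunks - (max_related_chunks - min_related_chunks) * i / max(n - 1, 1)),
        )
        for i in range(n)
    ]
    selected, seen = [], set()
    for rank in range(max(quotas)):  # round-robin so early sources cannot monopolize
        for src, quota in zip(sources, quotas):
            chunks = src.get("sorted_chunks", [])
            if rank < quota and rank < len(chunks) and chunks[rank] not in seen:
                seen.add(chunks[rank])
                selected.append(chunks[rank])
    return selected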
async def vector_similarity_sorting(
query: str,
text_chunks_storage: "BaseKVStorage",
chunks_vdb: "BaseVectorStorage",
num_of_chunks: int,
entity_info: list[dict[str, Any]],
embedding_func: callable,
) -> list[str]:
"""
Vector similarity-based text chunk selection algorithm.
This algorithm selects text chunks based on cosine similarity between
the query embedding and text chunk embeddings.
Args:
query: User's original query string
text_chunks_storage: Text chunks storage instance
chunks_vdb: Vector database storage for chunks
num_of_chunks: Number of chunks to select
entity_info: List of entity or relation info dicts containing sorted chunk IDs
embedding_func: Embedding function to compute query embedding
Returns:
List of selected text chunk IDs sorted by similarity (highest first)
"""
logger.debug(
f"Vector similarity chunk selection: num_of_chunks={num_of_chunks}, entity_info_count={len(entity_info) if entity_info else 0}"
)
if not entity_info or num_of_chunks <= 0:
return []
# Collect all unique chunk IDs from entity info
all_chunk_ids = set()
for entity in entity_info:
chunk_ids = entity.get("sorted_chunks", [])
all_chunk_ids.update(chunk_ids)
if not all_chunk_ids:
logger.warning(
"Vector similarity chunk selection: no chunk IDs found in entity_info"
)
return []
logger.debug(
f"Vector similarity chunk selection: {len(all_chunk_ids)} unique chunk IDs collected"
)
all_chunk_ids = list(all_chunk_ids)
try:
# Get query embedding
query_embedding = await embedding_func([query])
query_embedding = query_embedding[0]  # Extract first embedding from batch result
# Get chunk embeddings from vector database
chunk_vectors = await chunks_vdb.get_vectors_by_ids(all_chunk_ids)
logger.debug(
f"Vector similarity chunk selection: {len(chunk_vectors)} chunk vectors Retrieved"
)
if not chunk_vectors:
logger.warning(
"Vector similarity chunk selection: no vectors retrieved from chunks_vdb"
)
return []
# Calculate cosine similarities
similarities = []
valid_vectors = 0
for chunk_id in all_chunk_ids:
if chunk_id in chunk_vectors:
chunk_embedding = chunk_vectors[chunk_id]
try:
# Calculate cosine similarity
similarity = cosine_similarity(query_embedding, chunk_embedding)
similarities.append((chunk_id, similarity))
valid_vectors += 1
except Exception as e:
logger.warning(
f"Vector similarity chunk selection: failed to calculate similarity for chunk {chunk_id}: {e}"
)
else:
logger.warning(
f"Vector similarity chunk selection: no vector found for chunk {chunk_id}"
)
# Sort by similarity (highest first) and select top num_of_chunks
similarities.sort(key=lambda x: x[1], reverse=True)
selected_chunks = [chunk_id for chunk_id, _ in similarities[:num_of_chunks]]
logger.debug(
f"Vector similarity chunk selection: {len(selected_chunks)} chunks from {len(all_chunk_ids)} candidates"
)
return selected_chunks
except Exception as e:
logger.error(f"[VECTOR_SIMILARITY] Error in vector similarity sorting: {e}")
import traceback
logger.error(f"[VECTOR_SIMILARITY] Traceback: {traceback.format_exc()}")
# Fallback to simple truncation
logger.debug("[VECTOR_SIMILARITY] Falling back to simple truncation")
return all_chunk_ids[:num_of_chunks]
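The core of the function above is plain cosine ranking; a self-contained numpy sketch, with cosine_similarity_sketch standing in for the helper this module imports:

import numpy as np

def cosine_similarity_sketch(a, b) -> float:
    """Local stand-in for the cosine_similarity helper used above."""
    a = np.asarray(a, dtype=np.float32)
    b = np.asarray(b, dtype=np.float32)
    denom = float(np.linalg.norm(a) * np.linalg.norm(b))
    return float(np.dot(a, b)) / denom if denom else 0.0

def top_k_by_similarity(query_vec, chunk_vectors: dict[str, list[float]], k: int) -> list[str]:
    """Rank chunk IDs by cosine similarity to the query, highest first."""
    scored = [(cid, cosine_similarity_sketch(query_vec, vec)) for cid, vec in chunk_vectors.items()]
    scored.sort(key=lambda x: x[1], reverse=True)
    return [cid for cid, _ in scored[:k]]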
class TokenTracker:
"""Track token usage for LLM calls."""