diff --git a/lightrag/base.py b/lightrag/base.py index 0e651f7b..9ba34280 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -290,6 +290,19 @@ class BaseVectorStorage(StorageNameSpace, ABC): ids: List of vector IDs to be deleted """ + @abstractmethod + async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]: + """Get vectors by their IDs, returning only ID and vector data for efficiency + + Args: + ids: List of unique identifiers + + Returns: + Dictionary mapping IDs to their vector embeddings + Format: {id: [vector_values], ...} + """ + pass + @dataclass class BaseKVStorage(StorageNameSpace, ABC): diff --git a/lightrag/constants.py b/lightrag/constants.py index e66fe0ae..895852bd 100644 --- a/lightrag/constants.py +++ b/lightrag/constants.py @@ -27,6 +27,7 @@ DEFAULT_MAX_RELATION_TOKENS = 10000 DEFAULT_MAX_TOTAL_TOKENS = 30000 DEFAULT_COSINE_THRESHOLD = 0.2 DEFAULT_RELATED_CHUNK_NUMBER = 5 +DEFAULT_KG_CHUNK_PICK_METHOD = "WEIGHT" # Deprated: history message have negtive effect on query performance DEFAULT_HISTORY_TURNS = 0 diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index b5ce2aa3..5bec06f4 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -1,5 +1,7 @@ import asyncio +import base64 import os +import zlib from typing import Any, final from dataclasses import dataclass import numpy as np @@ -93,8 +95,7 @@ class NanoVectorDBStorage(BaseVectorStorage): 2. Only one process should updating the storage at a time before index_done_callback, KG-storage-log should be used to avoid data corruption """ - - logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}") + # logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}") if not data: return @@ -120,6 +121,11 @@ class NanoVectorDBStorage(BaseVectorStorage): embeddings = np.concatenate(embeddings_list) if len(embeddings) == len(list_data): for i, d in enumerate(list_data): + # Compress vector using Float16 + zlib + Base64 for storage optimization + vector_f16 = embeddings[i].astype(np.float16) + compressed_vector = zlib.compress(vector_f16.tobytes()) + encoded_vector = base64.b64encode(compressed_vector).decode("utf-8") + d["vector"] = encoded_vector d["__vector__"] = embeddings[i] client = await self._get_client() results = client.upsert(datas=list_data) @@ -147,7 +153,7 @@ class NanoVectorDBStorage(BaseVectorStorage): ) results = [ { - **dp, + **{k: v for k, v in dp.items() if k != "vector"}, "id": dp["__id__"], "distance": dp["__metrics__"], "created_at": dp.get("__created_at__"), @@ -296,7 +302,7 @@ class NanoVectorDBStorage(BaseVectorStorage): if result: dp = result[0] return { - **dp, + **{k: v for k, v in dp.items() if k != "vector"}, "id": dp.get("__id__"), "created_at": dp.get("__created_at__"), } @@ -318,13 +324,41 @@ class NanoVectorDBStorage(BaseVectorStorage): results = client.get(ids) return [ { - **dp, + **{k: v for k, v in dp.items() if k != "vector"}, "id": dp.get("__id__"), "created_at": dp.get("__created_at__"), } for dp in results ] + async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]: + """Get vectors by their IDs, returning only ID and vector data for efficiency + + Args: + ids: List of unique identifiers + + Returns: + Dictionary mapping IDs to their vector embeddings + Format: {id: [vector_values], ...} + """ + if not ids: + return {} + + client = await self._get_client() + results = client.get(ids) + + vectors_dict = {} + for result in results: + if 
result and "vector" in result and "__id__" in result: + # Decompress vector data (Base64 + zlib + Float16 compressed) + decoded = base64.b64decode(result["vector"]) + decompressed = zlib.decompress(decoded) + vector_f16 = np.frombuffer(decompressed, dtype=np.float16) + vector_f32 = vector_f16.astype(np.float32).tolist() + vectors_dict[result["__id__"]] = vector_f32 + + return vectors_dict + async def drop(self) -> dict[str, str]: """Drop all vector data from storage and clean up resources diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index d6822296..4ffea343 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -31,6 +31,7 @@ from lightrag.constants import ( DEFAULT_MAX_TOTAL_TOKENS, DEFAULT_COSINE_THRESHOLD, DEFAULT_RELATED_CHUNK_NUMBER, + DEFAULT_KG_CHUNK_PICK_METHOD, DEFAULT_MIN_RERANK_SCORE, DEFAULT_SUMMARY_MAX_TOKENS, DEFAULT_MAX_ASYNC, @@ -175,6 +176,11 @@ class LightRAG: ) """Number of related chunks to grab from single entity or relation.""" + kg_chunk_pick_method: str = field( + default=get_env_value("KG_CHUNK_PICK_METHOD", DEFAULT_KG_CHUNK_PICK_METHOD, str) + ) + """Method for selecting text chunks: 'WEIGHT' for weight-based selection, 'VECTOR' for embedding similarity-based selection.""" + # Entity extraction # --- diff --git a/lightrag/operate.py b/lightrag/operate.py index dd8d54be..60f5bc6c 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -28,6 +28,7 @@ from .utils import ( update_chunk_cache_list, remove_think_tags, linear_gradient_weighted_polling, + vector_similarity_sorting, process_chunks_unified, build_file_path, ) @@ -2349,19 +2350,23 @@ async def _build_query_context( # Get text chunks based on final filtered data if final_node_datas: - entity_chunks = await _find_most_related_text_unit_from_entities( + entity_chunks = await _find_related_text_unit_from_entities( final_node_datas, query_param, text_chunks_db, knowledge_graph_inst, + query, + chunks_vdb, ) if final_edge_datas: - relation_chunks = await _find_related_text_unit_from_relationships( + relation_chunks = await _find_related_text_unit_from_relations( final_edge_datas, query_param, text_chunks_db, entity_chunks, + query, + chunks_vdb, ) # Round-robin merge chunks from different sources with deduplication by chunk_id @@ -2410,7 +2415,7 @@ async def _build_query_context( } ) - logger.debug( + logger.info( f"Round-robin merged total chunks from {origin_len} to {len(merged_chunks)}" ) @@ -2611,104 +2616,6 @@ async def _get_node_data( return node_datas, use_relations -async def _find_most_related_text_unit_from_entities( - node_datas: list[dict], - query_param: QueryParam, - text_chunks_db: BaseKVStorage, - knowledge_graph_inst: BaseGraphStorage, -): - """ - Find text chunks related to entities using linear gradient weighted polling algorithm. - - This function implements the optimized text chunk selection strategy: - 1. Sort text chunks for each entity by occurrence count in other entities - 2. 
Use linear gradient weighted polling to select chunks fairly - """ - logger.debug(f"Searching text chunks for {len(node_datas)} entities") - - if not node_datas: - return [] - - # Step 1: Collect all text chunks for each entity - entities_with_chunks = [] - for entity in node_datas: - if entity.get("source_id"): - chunks = split_string_by_multi_markers( - entity["source_id"], [GRAPH_FIELD_SEP] - ) - if chunks: - entities_with_chunks.append( - { - "entity_name": entity["entity_name"], - "chunks": chunks, - "entity_data": entity, - } - ) - - if not entities_with_chunks: - logger.warning("No entities with text chunks found") - return [] - - # Step 2: Count chunk occurrences and deduplicate (keep chunks from earlier positioned entities) - chunk_occurrence_count = {} - for entity_info in entities_with_chunks: - deduplicated_chunks = [] - for chunk_id in entity_info["chunks"]: - chunk_occurrence_count[chunk_id] = ( - chunk_occurrence_count.get(chunk_id, 0) + 1 - ) - - # If this is the first occurrence (count == 1), keep it; otherwise skip (duplicate from later position) - if chunk_occurrence_count[chunk_id] == 1: - deduplicated_chunks.append(chunk_id) - # count > 1 means this chunk appeared in an earlier entity, so skip it - - # Update entity's chunks to deduplicated chunks - entity_info["chunks"] = deduplicated_chunks - - # Step 3: Sort chunks for each entity by occurrence count (higher count = higher priority) - for entity_info in entities_with_chunks: - sorted_chunks = sorted( - entity_info["chunks"], - key=lambda chunk_id: chunk_occurrence_count.get(chunk_id, 0), - reverse=True, - ) - entity_info["sorted_chunks"] = sorted_chunks - - # Step 4: Apply linear gradient weighted polling algorithm - max_related_chunks = text_chunks_db.global_config.get( - "related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER - ) - - selected_chunk_ids = linear_gradient_weighted_polling( - entities_with_chunks, max_related_chunks, min_related_chunks=1 - ) - - logger.debug( - f"Found {len(selected_chunk_ids)} entity-related chunks using linear gradient weighted polling" - ) - - if not selected_chunk_ids: - return [] - - # Step 5: Batch retrieve chunk data - unique_chunk_ids = list( - dict.fromkeys(selected_chunk_ids) - ) # Remove duplicates while preserving order - chunk_data_list = await text_chunks_db.get_by_ids(unique_chunk_ids) - - # Step 6: Build result chunks with valid data - result_chunks = [] - for chunk_id, chunk_data in zip(unique_chunk_ids, chunk_data_list): - if chunk_data is not None and "content" in chunk_data: - chunk_data_copy = chunk_data.copy() - chunk_data_copy["source_type"] = "entity" - chunk_data_copy["chunk_id"] = chunk_id # Add chunk_id for deduplication - result_chunks.append(chunk_data_copy) - - return result_chunks - - async def _find_most_related_edges_from_entities( node_datas: list[dict], query_param: QueryParam, @@ -2765,6 +2672,162 @@ async def _find_most_related_edges_from_entities( return all_edges_data +async def _find_related_text_unit_from_entities( + node_datas: list[dict], + query_param: QueryParam, + text_chunks_db: BaseKVStorage, + knowledge_graph_inst: BaseGraphStorage, + query: str = None, + chunks_vdb: BaseVectorStorage = None, +): + """ + Find text chunks related to entities using configurable chunk selection method. + + This function supports two chunk selection strategies: + 1. WEIGHT: Linear gradient weighted polling based on chunk occurrence count + 2. 
VECTOR: Vector similarity-based selection using embedding cosine similarity + """ + logger.debug(f"Finding text chunks from {len(node_datas)} entities") + + if not node_datas: + return [] + + # Step 1: Collect all text chunks for each entity + entities_with_chunks = [] + for entity in node_datas: + if entity.get("source_id"): + chunks = split_string_by_multi_markers( + entity["source_id"], [GRAPH_FIELD_SEP] + ) + if chunks: + entities_with_chunks.append( + { + "entity_name": entity["entity_name"], + "chunks": chunks, + "entity_data": entity, + } + ) + + if not entities_with_chunks: + logger.warning("No entities with text chunks found") + return [] + + # Check chunk selection method from environment variable + kg_chunk_pick_method = os.getenv("KG_CHUNK_PICK_METHOD", "WEIGHT").upper() + + max_related_chunks = text_chunks_db.global_config.get( + "related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER + ) + # Step 2: Count chunk occurrences and deduplicate (keep chunks from earlier positioned entities) + chunk_occurrence_count = {} + for entity_info in entities_with_chunks: + deduplicated_chunks = [] + for chunk_id in entity_info["chunks"]: + chunk_occurrence_count[chunk_id] = ( + chunk_occurrence_count.get(chunk_id, 0) + 1 + ) + + # If this is the first occurrence (count == 1), keep it; otherwise skip (duplicate from later position) + if chunk_occurrence_count[chunk_id] == 1: + deduplicated_chunks.append(chunk_id) + # count > 1 means this chunk appeared in an earlier entity, so skip it + + # Update entity's chunks to deduplicated chunks + entity_info["chunks"] = deduplicated_chunks + + # Step 3: Sort chunks for each entity by occurrence count (higher count = higher priority) + total_entity_chunks = 0 + for entity_info in entities_with_chunks: + sorted_chunks = sorted( + entity_info["chunks"], + key=lambda chunk_id: chunk_occurrence_count.get(chunk_id, 0), + reverse=True, + ) + entity_info["sorted_chunks"] = sorted_chunks + total_entity_chunks += len(sorted_chunks) + + # Step 4: Apply the selected chunk selection algorithm + selected_chunk_ids = [] # Initialize to avoid UnboundLocalError + + if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb: + # Calculate num_of_chunks = max_related_chunks * len(entity_info) / 2 + num_of_chunks = int(max_related_chunks * len(entities_with_chunks) / 2) + + # Get embedding function from global config + embedding_func_config = text_chunks_db.embedding_func + if not embedding_func_config: + logger.warning("No embedding function found, falling back to WEIGHT method") + kg_chunk_pick_method = "WEIGHT" + else: + try: + # Extract the actual callable function from EmbeddingFunc + if hasattr(embedding_func_config, "func"): + actual_embedding_func = embedding_func_config.func + else: + logger.warning( + "Invalid embedding function format, falling back to WEIGHT method" + ) + kg_chunk_pick_method = "WEIGHT" + actual_embedding_func = None + + selected_chunk_ids = None + if actual_embedding_func: + selected_chunk_ids = await vector_similarity_sorting( + query=query, + text_chunks_storage=text_chunks_db, + chunks_vdb=chunks_vdb, + num_of_chunks=num_of_chunks, + entity_info=entities_with_chunks, + embedding_func=actual_embedding_func, + ) + + if selected_chunk_ids == []: + kg_chunk_pick_method = "WEIGHT" + logger.warning( + "No entity-related chunks selected by vector similarity, falling back to WEIGHT method" + ) + else: + logger.info( + f"Selecting {len(selected_chunk_ids)} from {total_entity_chunks} entity-related chunks by vector similarity" + ) + + except 
Exception as e: + logger.error( + f"Error in vector similarity sorting: {e}, falling back to WEIGHT method" + ) + kg_chunk_pick_method = "WEIGHT" + + if kg_chunk_pick_method == "WEIGHT": + # Apply linear gradient weighted polling algorithm + selected_chunk_ids = linear_gradient_weighted_polling( + entities_with_chunks, max_related_chunks, min_related_chunks=1 + ) + + logger.debug( + f"Selecting {len(selected_chunk_ids)} from {total_entity_chunks} entity-related chunks by weighted polling" + ) + + if not selected_chunk_ids: + return [] + + # Step 5: Batch retrieve chunk data + unique_chunk_ids = list( + dict.fromkeys(selected_chunk_ids) + ) # Remove duplicates while preserving order + chunk_data_list = await text_chunks_db.get_by_ids(unique_chunk_ids) + + # Step 6: Build result chunks with valid data + result_chunks = [] + for chunk_id, chunk_data in zip(unique_chunk_ids, chunk_data_list): + if chunk_data is not None and "content" in chunk_data: + chunk_data_copy = chunk_data.copy() + chunk_data_copy["source_type"] = "entity" + chunk_data_copy["chunk_id"] = chunk_id # Add chunk_id for deduplication + result_chunks.append(chunk_data_copy) + + return result_chunks + + async def _get_edge_data( keywords, knowledge_graph_inst: BaseGraphStorage, @@ -2856,20 +2919,22 @@ async def _find_most_related_entities_from_relationships( return node_datas -async def _find_related_text_unit_from_relationships( +async def _find_related_text_unit_from_relations( edge_datas: list[dict], query_param: QueryParam, text_chunks_db: BaseKVStorage, entity_chunks: list[dict] = None, + query: str = None, + chunks_vdb: BaseVectorStorage = None, ): """ - Find text chunks related to relationships using linear gradient weighted polling algorithm. + Find text chunks related to relationships using configurable chunk selection method. - This function implements the optimized text chunk selection strategy: - 1. Sort text chunks for each relationship by occurrence count in other relationships - 2. Use linear gradient weighted polling to select chunks fairly + This function supports two chunk selection strategies: + 1. WEIGHT: Linear gradient weighted polling based on chunk occurrence count + 2. 
VECTOR: Vector similarity-based selection using embedding cosine similarity """ - logger.debug(f"Searching text chunks for {len(edge_datas)} relationships") + logger.debug(f"Finding text chunks from {len(edge_datas)} relations") if not edge_datas: return [] @@ -2899,14 +2964,40 @@ async def _find_related_text_unit_from_relationships( ) if not relations_with_chunks: - logger.warning("No relationships with text chunks found") + logger.warning("No relation-related chunks found") return [] + # Check chunk selection method from environment variable + kg_chunk_pick_method = os.getenv("KG_CHUNK_PICK_METHOD", "WEIGHT").upper() + + max_related_chunks = text_chunks_db.global_config.get( + "related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER + ) + # Step 2: Count chunk occurrences and deduplicate (keep chunks from earlier positioned relationships) + # Also remove duplicates with entity_chunks + + # Extract chunk IDs from entity_chunks for deduplication + entity_chunk_ids = set() + if entity_chunks: + for chunk in entity_chunks: + chunk_id = chunk.get("chunk_id") + if chunk_id: + entity_chunk_ids.add(chunk_id) + chunk_occurrence_count = {} + # Track unique chunk_ids that have been removed to avoid double counting + removed_entity_chunk_ids = set() + for relation_info in relations_with_chunks: deduplicated_chunks = [] for chunk_id in relation_info["chunks"]: + # Skip chunks that already exist in entity_chunks + if chunk_id in entity_chunk_ids: + # Only count each unique chunk_id once + removed_entity_chunk_ids.add(chunk_id) + continue + chunk_occurrence_count[chunk_id] = ( chunk_occurrence_count.get(chunk_id, 0) + 1 ) @@ -2919,6 +3010,20 @@ async def _find_related_text_unit_from_relationships( # Update relationship's chunks to deduplicated chunks relation_info["chunks"] = deduplicated_chunks + # Check if any relations still have chunks after deduplication + relations_with_chunks = [ + relation_info + for relation_info in relations_with_chunks + if relation_info["chunks"] + ] + + logger.info( + f"Find {len(relations_with_chunks)} additional relations-related chunks ({len(removed_entity_chunk_ids)} duplicated chunks removed)" + ) + + if not relations_with_chunks: + return [] + # Step 3: Sort chunks for each relationship by occurrence count (higher count = higher priority) for relation_info in relations_with_chunks: sorted_chunks = sorted( @@ -2928,50 +3033,73 @@ async def _find_related_text_unit_from_relationships( ) relation_info["sorted_chunks"] = sorted_chunks - # Step 4: Apply linear gradient weighted polling algorithm - max_related_chunks = text_chunks_db.global_config.get( - "related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER - ) + # Step 4: Apply the selected chunk selection algorithm + selected_chunk_ids = [] # Initialize to avoid UnboundLocalError - selected_chunk_ids = linear_gradient_weighted_polling( - relations_with_chunks, max_related_chunks, min_related_chunks=1 - ) + if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb: + # Calculate num_of_chunks = max_related_chunks * len(entity_info) / 2 + num_of_chunks = int(max_related_chunks * len(relations_with_chunks) / 2) + + # Get embedding function from global config + embedding_func_config = text_chunks_db.embedding_func + if not embedding_func_config: + logger.warning("No embedding function found, falling back to WEIGHT method") + kg_chunk_pick_method = "WEIGHT" + else: + try: + # Extract the actual callable function from EmbeddingFunc + if hasattr(embedding_func_config, "func"): + actual_embedding_func = embedding_func_config.func + 
else: + logger.warning( + "Invalid embedding function format, falling back to WEIGHT method" + ) + kg_chunk_pick_method = "WEIGHT" + actual_embedding_func = None + + if actual_embedding_func: + selected_chunk_ids = await vector_similarity_sorting( + query=query, + text_chunks_storage=text_chunks_db, + chunks_vdb=chunks_vdb, + num_of_chunks=num_of_chunks, + entity_info=relations_with_chunks, + embedding_func=actual_embedding_func, + ) + + if selected_chunk_ids == []: + kg_chunk_pick_method = "WEIGHT" + logger.warning( + "No relation-related chunks selected by vector similarity, falling back to WEIGHT method" + ) + else: + logger.info( + f"Selecting {len(selected_chunk_ids)} relation-related chunks by vector similarity" + ) + + except Exception as e: + logger.error( + f"Error in vector similarity sorting: {e}, falling back to WEIGHT method" + ) + kg_chunk_pick_method = "WEIGHT" + + if kg_chunk_pick_method == "WEIGHT": + # Apply linear gradient weighted polling algorithm + selected_chunk_ids = linear_gradient_weighted_polling( + relations_with_chunks, max_related_chunks, min_related_chunks=1 + ) + + logger.info( + f"Selecting {len(selected_chunk_ids)} relation-related chunks by weighted polling" + ) logger.debug( - f"Found {len(selected_chunk_ids)} relationship-related chunks using linear gradient weighted polling" - ) - logger.info( f"KG related chunks: {len(entity_chunks)} from entitys, {len(selected_chunk_ids)} from relations" ) if not selected_chunk_ids: return [] - # Step 4.5: Remove duplicates with entity_chunks before batch retrieval - if entity_chunks: - # Extract chunk IDs from entity_chunks - entity_chunk_ids = set() - for chunk in entity_chunks: - chunk_id = chunk.get("chunk_id") - if chunk_id: - entity_chunk_ids.add(chunk_id) - - # Filter out duplicate chunk IDs - original_count = len(selected_chunk_ids) - selected_chunk_ids = [ - chunk_id - for chunk_id in selected_chunk_ids - if chunk_id not in entity_chunk_ids - ] - - logger.debug( - f"Deduplication relation-chunks with entity-chunks: {original_count} -> {len(selected_chunk_ids)} chunks " - ) - - # Early return if no chunks remain after deduplication - if not selected_chunk_ids: - return [] - # Step 5: Batch retrieve chunk data unique_chunk_ids = list( dict.fromkeys(selected_chunk_ids) diff --git a/lightrag/utils.py b/lightrag/utils.py index 96b7bdc3..a1d0dca9 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -60,7 +60,7 @@ def get_env_value( # Use TYPE_CHECKING to avoid circular imports if TYPE_CHECKING: - from lightrag.base import BaseKVStorage, QueryParam + from lightrag.base import BaseKVStorage, BaseVectorStorage, QueryParam # use the .env that is inside the current folder # allows to use different .env file for each lightrag instance @@ -1650,6 +1650,115 @@ def linear_gradient_weighted_polling( return selected_chunks +async def vector_similarity_sorting( + query: str, + text_chunks_storage: "BaseKVStorage", + chunks_vdb: "BaseVectorStorage", + num_of_chunks: int, + entity_info: list[dict[str, Any]], + embedding_func: callable, +) -> list[str]: + """ + Vector similarity-based text chunk selection algorithm. + + This algorithm selects text chunks based on cosine similarity between + the query embedding and text chunk embeddings. 
+ + Args: + query: User's original query string + text_chunks_storage: Text chunks storage instance + chunks_vdb: Vector database storage for chunks + num_of_chunks: Number of chunks to select + entity_info: List of entity information containing chunk IDs + embedding_func: Embedding function to compute query embedding + + Returns: + List of selected text chunk IDs sorted by similarity (highest first) + """ + logger.debug( + f"Vector similarity chunk selection: num_of_chunks={num_of_chunks}, entity_info_count={len(entity_info) if entity_info else 0}" + ) + + if not entity_info or num_of_chunks <= 0: + return [] + + # Collect all unique chunk IDs from entity info + all_chunk_ids = set() + for i, entity in enumerate(entity_info): + chunk_ids = entity.get("sorted_chunks", []) + all_chunk_ids.update(chunk_ids) + + if not all_chunk_ids: + logger.warning( + "Vector similarity chunk selection: no chunk IDs found in entity_info" + ) + return [] + + logger.debug( + f"Vector similarity chunk selection: {len(all_chunk_ids)} unique chunk IDs collected" + ) + + all_chunk_ids = list(all_chunk_ids) + + try: + # Get query embedding + query_embedding = await embedding_func([query]) + query_embedding = query_embedding[ + 0 + ] # Extract first embedding from batch result + + # Get chunk embeddings from vector database + chunk_vectors = await chunks_vdb.get_vectors_by_ids(all_chunk_ids) + logger.debug( + f"Vector similarity chunk selection: {len(chunk_vectors)} chunk vectors Retrieved" + ) + + if not chunk_vectors: + logger.warning( + "Vector similarity chunk selection: no vectors retrieved from chunks_vdb" + ) + return [] + + # Calculate cosine similarities + similarities = [] + valid_vectors = 0 + for chunk_id in all_chunk_ids: + if chunk_id in chunk_vectors: + chunk_embedding = chunk_vectors[chunk_id] + try: + # Calculate cosine similarity + similarity = cosine_similarity(query_embedding, chunk_embedding) + similarities.append((chunk_id, similarity)) + valid_vectors += 1 + except Exception as e: + logger.warning( + f"Vector similarity chunk selection: failed to calculate similarity for chunk {chunk_id}: {e}" + ) + else: + logger.warning( + f"Vector similarity chunk selection: no vector found for chunk {chunk_id}" + ) + + # Sort by similarity (highest first) and select top num_of_chunks + similarities.sort(key=lambda x: x[1], reverse=True) + selected_chunks = [chunk_id for chunk_id, _ in similarities[:num_of_chunks]] + + logger.debug( + f"Vector similarity chunk selection: {len(selected_chunks)} chunks from {len(all_chunk_ids)} candidates" + ) + + return selected_chunks + + except Exception as e: + logger.error(f"[VECTOR_SIMILARITY] Error in vector similarity sorting: {e}") + import traceback + + logger.error(f"[VECTOR_SIMILARITY] Traceback: {traceback.format_exc()}") + # Fallback to simple truncation + logger.debug("[VECTOR_SIMILARITY] Falling back to simple truncation") + return all_chunk_ids[:num_of_chunks] + + class TokenTracker: """Track token usage for LLM calls."""
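
The new `get_vectors_by_ids` abstract method on `BaseVectorStorage` means every vector backend now has to return raw embeddings keyed by id. Below is a minimal, hypothetical in-memory store showing the expected `{id: [float, ...]}` contract; it is an illustrative sketch, not one of LightRAG's real storage implementations, and it omits the rest of the `BaseVectorStorage` interface.

```python
# Illustrative sketch only: a toy async store satisfying the get_vectors_by_ids contract.
import asyncio

import numpy as np


class InMemoryVectorStore:
    def __init__(self) -> None:
        self._vectors: dict[str, np.ndarray] = {}

    async def upsert(self, data: dict[str, np.ndarray]) -> None:
        self._vectors.update(data)

    async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
        # Missing ids are simply omitted, mirroring the NanoVectorDB implementation,
        # which only returns entries that actually carry a stored vector.
        return {
            id_: self._vectors[id_].astype(np.float32).tolist()
            for id_ in ids
            if id_ in self._vectors
        }


if __name__ == "__main__":
    store = InMemoryVectorStore()
    asyncio.run(store.upsert({"chunk-1": np.ones(4, dtype=np.float32)}))
    print(asyncio.run(store.get_vectors_by_ids(["chunk-1", "missing-id"])))
```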
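The NanoVectorDB change stores a compressed copy of each embedding under the `vector` key (Float16, then zlib, then Base64) and strips that key from query results; `get_vectors_by_ids` reverses the encoding back to float32. A standalone round-trip sketch of that scheme, with illustrative helper names:

```python
# Standalone sketch of the Float16 + zlib + Base64 round trip used for the stored
# "vector" field; helper names are illustrative, not LightRAG API.
import base64
import zlib

import numpy as np


def encode_vector(vector: np.ndarray) -> str:
    """Compress a float vector to a Base64 string (lossy: Float16 precision)."""
    vector_f16 = vector.astype(np.float16)            # halve per-value storage
    compressed = zlib.compress(vector_f16.tobytes())  # squeeze repeated byte patterns
    return base64.b64encode(compressed).decode("utf-8")


def decode_vector(encoded: str) -> list[float]:
    """Reverse the encoding to a float32 list, as get_vectors_by_ids does."""
    decompressed = zlib.decompress(base64.b64decode(encoded))
    vector_f16 = np.frombuffer(decompressed, dtype=np.float16)
    return vector_f16.astype(np.float32).tolist()


if __name__ == "__main__":
    original = np.random.rand(1024).astype(np.float32)
    restored = np.asarray(decode_vector(encode_vector(original)), dtype=np.float32)
    # Float16 keeps roughly three decimal digits, so the round trip is close but not exact.
    print(np.max(np.abs(original - restored)))  # typically below 1e-3
```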
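Chunk selection is switched by the `KG_CHUNK_PICK_METHOD` environment variable (default `WEIGHT`, also surfaced as the new `kg_chunk_pick_method` field on `LightRAG`); the operate.py helpers read the environment variable directly and fall back to `WEIGHT` whenever the query, the chunks vector DB, or a usable embedding function is missing, or when vector selection returns nothing. A condensed sketch of that selection-and-fallback decision, with hypothetical names:

```python
# Condensed, hypothetical sketch of the method selection used in
# _find_related_text_unit_from_entities / _from_relations; the real code also
# falls back when vector similarity selection yields no chunks.
import os


def pick_chunk_method(query, chunks_vdb, embedding_func) -> str:
    method = os.getenv("KG_CHUNK_PICK_METHOD", "WEIGHT").upper()
    if method == "VECTOR" and not (query and chunks_vdb and embedding_func):
        # Missing prerequisites: degrade to the weight-based polling path.
        method = "WEIGHT"
    return method


if __name__ == "__main__":
    os.environ["KG_CHUNK_PICK_METHOD"] = "VECTOR"
    print(pick_chunk_method("example query", None, None))          # -> WEIGHT
    print(pick_chunk_method("example query", object(), len))       # -> VECTOR
```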
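For the `VECTOR` path, `vector_similarity_sorting` embeds the query once, fetches candidate chunk vectors via `get_vectors_by_ids`, scores them by cosine similarity, and keeps the top `num_of_chunks` (computed as `int(related_chunk_number * number_of_entities_or_relations / 2)`). The sketch below reproduces just the ranking step; the local `cosine_similarity` helper stands in for the one in `lightrag.utils`.

```python
# Sketch of the ranking step inside vector_similarity_sorting: score candidate
# chunk vectors against the query embedding and keep the top-N ids.
import numpy as np


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))


def rank_chunks_by_similarity(
    query_embedding: np.ndarray,
    chunk_vectors: dict[str, list[float]],  # shape of get_vectors_by_ids() output
    num_of_chunks: int,
) -> list[str]:
    scored = [
        (chunk_id, cosine_similarity(query_embedding, np.asarray(vec, dtype=np.float32)))
        for chunk_id, vec in chunk_vectors.items()
    ]
    scored.sort(key=lambda item: item[1], reverse=True)  # highest similarity first
    return [chunk_id for chunk_id, _ in scored[:num_of_chunks]]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    query_vec = rng.random(8)
    candidates = {f"chunk-{i}": rng.random(8).tolist() for i in range(5)}
    print(rank_chunks_by_similarity(query_vec, candidates, num_of_chunks=3))
```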