diff --git a/lightrag/base.py b/lightrag/base.py index fe92e785..c5518d23 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -142,10 +142,6 @@ class QueryParam: history_turns: int = int(os.getenv("HISTORY_TURNS", str(DEFAULT_HISTORY_TURNS))) """Number of complete conversation turns (user-assistant pairs) to consider in the response context.""" - # TODO: TODO: Deprecated - ID-based filtering only applies to chunks, not entities or relations, and implemented only in PostgreSQL storage - ids: list[str] | None = None - """List of doc ids to filter the results.""" - model_func: Callable[..., object] | None = None """Optional override for the LLM model function to use for this specific query. If provided, this will be used instead of the global model function. @@ -216,9 +212,16 @@ class BaseVectorStorage(StorageNameSpace, ABC): @abstractmethod async def query( - self, query: str, top_k: int, ids: list[str] | None = None + self, query: str, top_k: int, query_embedding: list[float] = None ) -> list[dict[str, Any]]: - """Query the vector storage and retrieve top_k results.""" + """Query the vector storage and retrieve top_k results. + + Args: + query: The query string to search for + top_k: Number of top results to return + query_embedding: Optional pre-computed embedding for the query. + If provided, skips embedding computation for better performance. + """ @abstractmethod async def upsert(self, data: dict[str, dict[str, Any]]) -> None: diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index 5098ebf7..7d6a6dac 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -180,16 +180,20 @@ class FaissVectorDBStorage(BaseVectorStorage): return [m["__id__"] for m in list_data] async def query( - self, query: str, top_k: int, ids: list[str] | None = None + self, query: str, top_k: int, query_embedding: list[float] = None ) -> list[dict[str, Any]]: """ Search by a textual query; returns top_k results with their metadata + similarity distance. """ - embedding = await self.embedding_func( - [query], _priority=5 - ) # higher priority for query - # embedding is shape (1, dim) - embedding = np.array(embedding, dtype=np.float32) + if query_embedding is not None: + embedding = np.array([query_embedding], dtype=np.float32) + else: + embedding = await self.embedding_func( + [query], _priority=5 + ) # higher priority for query + # embedding is shape (1, dim) + embedding = np.array(embedding, dtype=np.float32) + faiss.normalize_L2(embedding) # we do in-place normalization # Perform the similarity search diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index 82dce30c..f2368afe 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -1047,14 +1047,18 @@ class MilvusVectorDBStorage(BaseVectorStorage): return results async def query( - self, query: str, top_k: int, ids: list[str] | None = None + self, query: str, top_k: int, query_embedding: list[float] = None ) -> list[dict[str, Any]]: # Ensure collection is loaded before querying self._ensure_collection_loaded() - embedding = await self.embedding_func( - [query], _priority=5 - ) # higher priority for query + # Use provided embedding or compute it + if query_embedding is not None: + embedding = [query_embedding] # Milvus expects a list of embeddings + else: + embedding = await self.embedding_func( + [query], _priority=5 + ) # higher priority for query # Include all meta_fields (created_at is now always included) output_fields = list(self.meta_fields) diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index e7ea9a0a..9e4d7e67 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -1810,16 +1810,22 @@ class MongoVectorDBStorage(BaseVectorStorage): return list_data async def query( - self, query: str, top_k: int, ids: list[str] | None = None + self, query: str, top_k: int, query_embedding: list[float] = None ) -> list[dict[str, Any]]: """Queries the vector database using Atlas Vector Search.""" - # Generate the embedding - embedding = await self.embedding_func( - [query], _priority=5 - ) # higher priority for query - - # Convert numpy array to a list to ensure compatibility with MongoDB - query_vector = embedding[0].tolist() + if query_embedding is not None: + # Convert numpy array to list if needed for MongoDB compatibility + if hasattr(query_embedding, "tolist"): + query_vector = query_embedding.tolist() + else: + query_vector = list(query_embedding) + else: + # Generate the embedding + embedding = await self.embedding_func( + [query], _priority=5 + ) # higher priority for query + # Convert numpy array to a list to ensure compatibility with MongoDB + query_vector = embedding[0].tolist() # Define the aggregation pipeline with the converted query vector pipeline = [ diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 5bec06f4..def5a83d 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -137,13 +137,17 @@ class NanoVectorDBStorage(BaseVectorStorage): ) async def query( - self, query: str, top_k: int, ids: list[str] | None = None + self, query: str, top_k: int, query_embedding: list[float] = None ) -> list[dict[str, Any]]: - # Execute embedding outside of lock to avoid improve cocurrent - embedding = await self.embedding_func( - [query], _priority=5 - ) # higher priority for query - embedding = embedding[0] + # Use provided embedding or compute it + if query_embedding is not None: + embedding = query_embedding + else: + # Execute embedding outside of lock to avoid improve cocurrent + embedding = await self.embedding_func( + [query], _priority=5 + ) # higher priority for query + embedding = embedding[0] client = await self._get_client() results = client.query( diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 88a75ba5..03a26f54 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -2005,18 +2005,21 @@ class PGVectorStorage(BaseVectorStorage): #################### query method ############### async def query( - self, query: str, top_k: int, ids: list[str] | None = None + self, query: str, top_k: int, query_embedding: list[float] = None ) -> list[dict[str, Any]]: - embeddings = await self.embedding_func( - [query], _priority=5 - ) # higher priority for query - embedding = embeddings[0] + if query_embedding is not None: + embedding = query_embedding + else: + embeddings = await self.embedding_func( + [query], _priority=5 + ) # higher priority for query + embedding = embeddings[0] + embedding_string = ",".join(map(str, embedding)) - # Use parameterized document IDs (None means search across all documents) + sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string) params = { "workspace": self.workspace, - "doc_ids": ids, "closer_than_threshold": 1 - self.cosine_better_than_threshold, "top_k": top_k, } @@ -4582,85 +4585,34 @@ SQL_TEMPLATES = { update_time = EXCLUDED.update_time """, "relationships": """ - WITH relevant_chunks AS (SELECT id as chunk_id - FROM LIGHTRAG_VDB_CHUNKS - WHERE $2 - :: varchar [] IS NULL OR full_doc_id = ANY ($2:: varchar []) - ) - , rc AS ( - SELECT array_agg(chunk_id) AS chunk_arr - FROM relevant_chunks - ), cand AS ( - SELECT - r.id, r.source_id AS src_id, r.target_id AS tgt_id, r.chunk_ids, r.create_time, r.content_vector <=> '[{embedding_string}]'::vector AS dist + SELECT r.source_id AS src_id, + r.target_id AS tgt_id, + EXTRACT(EPOCH FROM r.create_time)::BIGINT AS created_at FROM LIGHTRAG_VDB_RELATION r WHERE r.workspace = $1 + AND r.content_vector <=> '[{embedding_string}]'::vector < $2 ORDER BY r.content_vector <=> '[{embedding_string}]'::vector - LIMIT ($4 * 50) - ) - SELECT c.src_id, - c.tgt_id, - EXTRACT(EPOCH FROM c.create_time) ::BIGINT AS created_at - FROM cand c - JOIN rc ON TRUE - WHERE c.dist < $3 - AND c.chunk_ids && (rc.chunk_arr::varchar[]) - ORDER BY c.dist, c.id - LIMIT $4; + LIMIT $3; """, "entities": """ - WITH relevant_chunks AS (SELECT id as chunk_id - FROM LIGHTRAG_VDB_CHUNKS - WHERE $2 - :: varchar [] IS NULL OR full_doc_id = ANY ($2:: varchar []) - ) - , rc AS ( - SELECT array_agg(chunk_id) AS chunk_arr - FROM relevant_chunks - ), cand AS ( - SELECT - e.id, e.entity_name, e.chunk_ids, e.create_time, e.content_vector <=> '[{embedding_string}]'::vector AS dist + SELECT e.entity_name, + EXTRACT(EPOCH FROM e.create_time)::BIGINT AS created_at FROM LIGHTRAG_VDB_ENTITY e WHERE e.workspace = $1 + AND e.content_vector <=> '[{embedding_string}]'::vector < $2 ORDER BY e.content_vector <=> '[{embedding_string}]'::vector - LIMIT ($4 * 50) - ) - SELECT c.entity_name, - EXTRACT(EPOCH FROM c.create_time) ::BIGINT AS created_at - FROM cand c - JOIN rc ON TRUE - WHERE c.dist < $3 - AND c.chunk_ids && (rc.chunk_arr::varchar[]) - ORDER BY c.dist, c.id - LIMIT $4; + LIMIT $3; """, "chunks": """ - WITH relevant_chunks AS (SELECT id as chunk_id - FROM LIGHTRAG_VDB_CHUNKS - WHERE $2 - :: varchar [] IS NULL OR full_doc_id = ANY ($2:: varchar []) - ) - , rc AS ( - SELECT array_agg(chunk_id) AS chunk_arr - FROM relevant_chunks - ), cand AS ( - SELECT - id, content, file_path, create_time, content_vector <=> '[{embedding_string}]'::vector AS dist - FROM LIGHTRAG_VDB_CHUNKS - WHERE workspace = $1 - ORDER BY content_vector <=> '[{embedding_string}]'::vector - LIMIT ($4 * 50) - ) SELECT c.id, c.content, c.file_path, - EXTRACT(EPOCH FROM c.create_time) ::BIGINT AS created_at - FROM cand c - JOIN rc ON TRUE - WHERE c.dist < $3 - AND c.id = ANY (rc.chunk_arr) - ORDER BY c.dist, c.id - LIMIT $4; + EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at + FROM LIGHTRAG_VDB_CHUNKS c + WHERE c.workspace = $1 + AND c.content_vector <=> '[{embedding_string}]'::vector < $2 + ORDER BY c.content_vector <=> '[{embedding_string}]'::vector + LIMIT $3; """, # DROP tables "drop_specifiy_table_workspace": """ diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index 4ece163c..dad95bbc 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -200,14 +200,19 @@ class QdrantVectorDBStorage(BaseVectorStorage): return results async def query( - self, query: str, top_k: int, ids: list[str] | None = None + self, query: str, top_k: int, query_embedding: list[float] = None ) -> list[dict[str, Any]]: - embedding = await self.embedding_func( - [query], _priority=5 - ) # higher priority for query + if query_embedding is not None: + embedding = query_embedding + else: + embedding_result = await self.embedding_func( + [query], _priority=5 + ) # higher priority for query + embedding = embedding_result[0] + results = self._client.search( collection_name=self.final_namespace, - query_vector=embedding[0], + query_vector=embedding, limit=top_k, with_payload=True, score_threshold=self.cosine_better_than_threshold, diff --git a/lightrag/operate.py b/lightrag/operate.py index 91252e00..afa8205f 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -2234,6 +2234,7 @@ async def _get_vector_context( query: str, chunks_vdb: BaseVectorStorage, query_param: QueryParam, + query_embedding: list[float] = None, ) -> list[dict]: """ Retrieve text chunks from the vector database without reranking or truncation. @@ -2245,6 +2246,7 @@ async def _get_vector_context( query: The query string to search for chunks_vdb: Vector database containing document chunks query_param: Query parameters including chunk_top_k and ids + query_embedding: Optional pre-computed query embedding to avoid redundant embedding calls Returns: List of text chunks with metadata @@ -2253,7 +2255,9 @@ async def _get_vector_context( # Use chunk_top_k if specified, otherwise fall back to top_k search_top_k = query_param.chunk_top_k or query_param.top_k - results = await chunks_vdb.query(query, top_k=search_top_k, ids=query_param.ids) + results = await chunks_vdb.query( + query, top_k=search_top_k, query_embedding=query_embedding + ) if not results: logger.info(f"Naive query: 0 chunks (chunk_top_k: {search_top_k})") return [] @@ -2291,6 +2295,10 @@ async def _build_query_context( query_param: QueryParam, chunks_vdb: BaseVectorStorage = None, ): + if not query: + logger.warning("Query is empty, skipping context building") + return "" + logger.info(f"Process {os.getpid()} building query context...") # Collect chunks from different sources separately @@ -2309,12 +2317,12 @@ async def _build_query_context( # Track chunk sources and metadata for final logging chunk_tracking = {} # chunk_id -> {source, frequency, order} - # Pre-compute query embedding if vector similarity method is used + # Pre-compute query embedding once for all vector operations kg_chunk_pick_method = text_chunks_db.global_config.get( "kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD ) query_embedding = None - if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb: + if query and (kg_chunk_pick_method == "VECTOR" or chunks_vdb): embedding_func_config = text_chunks_db.embedding_func if embedding_func_config and embedding_func_config.func: try: @@ -2322,9 +2330,7 @@ async def _build_query_context( query_embedding = query_embedding[ 0 ] # Extract first embedding from batch result - logger.debug( - "Pre-computed query embedding for vector similarity chunk selection" - ) + logger.debug("Pre-computed query embedding for all vector operations") except Exception as e: logger.warning(f"Failed to pre-compute query embedding: {e}") query_embedding = None @@ -2368,6 +2374,7 @@ async def _build_query_context( query, chunks_vdb, query_param, + query_embedding, ) # Track vector chunks with source metadata for i, chunk in enumerate(vector_chunks): @@ -2830,9 +2837,7 @@ async def _get_node_data( f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}" ) - results = await entities_vdb.query( - query, top_k=query_param.top_k, ids=query_param.ids - ) + results = await entities_vdb.query(query, top_k=query_param.top_k) if not len(results): return [], [] @@ -3108,9 +3113,7 @@ async def _get_edge_data( f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}" ) - results = await relationships_vdb.query( - keywords, top_k=query_param.top_k, ids=query_param.ids - ) + results = await relationships_vdb.query(keywords, top_k=query_param.top_k) if not len(results): return [], [] @@ -3433,7 +3436,7 @@ async def naive_query( tokenizer: Tokenizer = global_config["tokenizer"] - chunks = await _get_vector_context(query, chunks_vdb, query_param) + chunks = await _get_vector_context(query, chunks_vdb, query_param, None) if chunks is None or len(chunks) == 0: return PROMPTS["fail_response"]