Merge pull request #2025 from danielaskdd/remove-ids-filter

refac: Remove deprecated doc-id based filtering from vector storage queries
This commit is contained in:
Daniel.y 2025-08-29 19:39:42 +08:00 committed by GitHub
commit 163ec26e10
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 102 additions and 121 deletions

View file

@@ -142,10 +142,6 @@ class QueryParam:
history_turns: int = int(os.getenv("HISTORY_TURNS", str(DEFAULT_HISTORY_TURNS))) history_turns: int = int(os.getenv("HISTORY_TURNS", str(DEFAULT_HISTORY_TURNS)))
"""Number of complete conversation turns (user-assistant pairs) to consider in the response context.""" """Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
# TODO: TODO: Deprecated - ID-based filtering only applies to chunks, not entities or relations, and implemented only in PostgreSQL storage
ids: list[str] | None = None
"""List of doc ids to filter the results."""
model_func: Callable[..., object] | None = None model_func: Callable[..., object] | None = None
"""Optional override for the LLM model function to use for this specific query. """Optional override for the LLM model function to use for this specific query.
If provided, this will be used instead of the global model function. If provided, this will be used instead of the global model function.
@@ -216,9 +212,16 @@ class BaseVectorStorage(StorageNameSpace, ABC):
@abstractmethod @abstractmethod
async def query( async def query(
self, query: str, top_k: int, ids: list[str] | None = None self, query: str, top_k: int, query_embedding: list[float] = None
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Query the vector storage and retrieve top_k results.""" """Query the vector storage and retrieve top_k results.
Args:
query: The query string to search for
top_k: Number of top results to return
query_embedding: Optional pre-computed embedding for the query.
If provided, skips embedding computation for better performance.
"""
@abstractmethod @abstractmethod
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:

View file

@@ -180,16 +180,20 @@ class FaissVectorDBStorage(BaseVectorStorage):
return [m["__id__"] for m in list_data] return [m["__id__"] for m in list_data]
async def query( async def query(
self, query: str, top_k: int, ids: list[str] | None = None self, query: str, top_k: int, query_embedding: list[float] = None
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
""" """
Search by a textual query; returns top_k results with their metadata + similarity distance. Search by a textual query; returns top_k results with their metadata + similarity distance.
""" """
embedding = await self.embedding_func( if query_embedding is not None:
[query], _priority=5 embedding = np.array([query_embedding], dtype=np.float32)
) # higher priority for query else:
# embedding is shape (1, dim) embedding = await self.embedding_func(
embedding = np.array(embedding, dtype=np.float32) [query], _priority=5
) # higher priority for query
# embedding is shape (1, dim)
embedding = np.array(embedding, dtype=np.float32)
faiss.normalize_L2(embedding) # we do in-place normalization faiss.normalize_L2(embedding) # we do in-place normalization
# Perform the similarity search # Perform the similarity search

View file

@@ -1047,14 +1047,18 @@ class MilvusVectorDBStorage(BaseVectorStorage):
return results return results
async def query( async def query(
self, query: str, top_k: int, ids: list[str] | None = None self, query: str, top_k: int, query_embedding: list[float] = None
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
# Ensure collection is loaded before querying # Ensure collection is loaded before querying
self._ensure_collection_loaded() self._ensure_collection_loaded()
embedding = await self.embedding_func( # Use provided embedding or compute it
[query], _priority=5 if query_embedding is not None:
) # higher priority for query embedding = [query_embedding] # Milvus expects a list of embeddings
else:
embedding = await self.embedding_func(
[query], _priority=5
) # higher priority for query
# Include all meta_fields (created_at is now always included) # Include all meta_fields (created_at is now always included)
output_fields = list(self.meta_fields) output_fields = list(self.meta_fields)

View file

@@ -1810,16 +1810,22 @@ class MongoVectorDBStorage(BaseVectorStorage):
return list_data return list_data
async def query( async def query(
self, query: str, top_k: int, ids: list[str] | None = None self, query: str, top_k: int, query_embedding: list[float] = None
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Queries the vector database using Atlas Vector Search.""" """Queries the vector database using Atlas Vector Search."""
# Generate the embedding if query_embedding is not None:
embedding = await self.embedding_func( # Convert numpy array to list if needed for MongoDB compatibility
[query], _priority=5 if hasattr(query_embedding, "tolist"):
) # higher priority for query query_vector = query_embedding.tolist()
else:
# Convert numpy array to a list to ensure compatibility with MongoDB query_vector = list(query_embedding)
query_vector = embedding[0].tolist() else:
# Generate the embedding
embedding = await self.embedding_func(
[query], _priority=5
) # higher priority for query
# Convert numpy array to a list to ensure compatibility with MongoDB
query_vector = embedding[0].tolist()
# Define the aggregation pipeline with the converted query vector # Define the aggregation pipeline with the converted query vector
pipeline = [ pipeline = [

View file

@@ -137,13 +137,17 @@ class NanoVectorDBStorage(BaseVectorStorage):
) )
async def query( async def query(
self, query: str, top_k: int, ids: list[str] | None = None self, query: str, top_k: int, query_embedding: list[float] = None
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
# Execute embedding outside of lock to avoid improve cocurrent # Use provided embedding or compute it
embedding = await self.embedding_func( if query_embedding is not None:
[query], _priority=5 embedding = query_embedding
) # higher priority for query else:
embedding = embedding[0] # Execute embedding outside of lock to avoid improve cocurrent
embedding = await self.embedding_func(
[query], _priority=5
) # higher priority for query
embedding = embedding[0]
client = await self._get_client() client = await self._get_client()
results = client.query( results = client.query(

View file

@@ -2005,18 +2005,21 @@ class PGVectorStorage(BaseVectorStorage):
#################### query method ############### #################### query method ###############
async def query( async def query(
self, query: str, top_k: int, ids: list[str] | None = None self, query: str, top_k: int, query_embedding: list[float] = None
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
embeddings = await self.embedding_func( if query_embedding is not None:
[query], _priority=5 embedding = query_embedding
) # higher priority for query else:
embedding = embeddings[0] embeddings = await self.embedding_func(
[query], _priority=5
) # higher priority for query
embedding = embeddings[0]
embedding_string = ",".join(map(str, embedding)) embedding_string = ",".join(map(str, embedding))
# Use parameterized document IDs (None means search across all documents)
sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string) sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
params = { params = {
"workspace": self.workspace, "workspace": self.workspace,
"doc_ids": ids,
"closer_than_threshold": 1 - self.cosine_better_than_threshold, "closer_than_threshold": 1 - self.cosine_better_than_threshold,
"top_k": top_k, "top_k": top_k,
} }
@@ -4582,85 +4585,34 @@ SQL_TEMPLATES = {
update_time = EXCLUDED.update_time update_time = EXCLUDED.update_time
""", """,
"relationships": """ "relationships": """
WITH relevant_chunks AS (SELECT id as chunk_id SELECT r.source_id AS src_id,
FROM LIGHTRAG_VDB_CHUNKS r.target_id AS tgt_id,
WHERE $2 EXTRACT(EPOCH FROM r.create_time)::BIGINT AS created_at
:: varchar [] IS NULL OR full_doc_id = ANY ($2:: varchar [])
)
, rc AS (
SELECT array_agg(chunk_id) AS chunk_arr
FROM relevant_chunks
), cand AS (
SELECT
r.id, r.source_id AS src_id, r.target_id AS tgt_id, r.chunk_ids, r.create_time, r.content_vector <=> '[{embedding_string}]'::vector AS dist
FROM LIGHTRAG_VDB_RELATION r FROM LIGHTRAG_VDB_RELATION r
WHERE r.workspace = $1 WHERE r.workspace = $1
AND r.content_vector <=> '[{embedding_string}]'::vector < $2
ORDER BY r.content_vector <=> '[{embedding_string}]'::vector ORDER BY r.content_vector <=> '[{embedding_string}]'::vector
LIMIT ($4 * 50) LIMIT $3;
)
SELECT c.src_id,
c.tgt_id,
EXTRACT(EPOCH FROM c.create_time) ::BIGINT AS created_at
FROM cand c
JOIN rc ON TRUE
WHERE c.dist < $3
AND c.chunk_ids && (rc.chunk_arr::varchar[])
ORDER BY c.dist, c.id
LIMIT $4;
""", """,
"entities": """ "entities": """
WITH relevant_chunks AS (SELECT id as chunk_id SELECT e.entity_name,
FROM LIGHTRAG_VDB_CHUNKS EXTRACT(EPOCH FROM e.create_time)::BIGINT AS created_at
WHERE $2
:: varchar [] IS NULL OR full_doc_id = ANY ($2:: varchar [])
)
, rc AS (
SELECT array_agg(chunk_id) AS chunk_arr
FROM relevant_chunks
), cand AS (
SELECT
e.id, e.entity_name, e.chunk_ids, e.create_time, e.content_vector <=> '[{embedding_string}]'::vector AS dist
FROM LIGHTRAG_VDB_ENTITY e FROM LIGHTRAG_VDB_ENTITY e
WHERE e.workspace = $1 WHERE e.workspace = $1
AND e.content_vector <=> '[{embedding_string}]'::vector < $2
ORDER BY e.content_vector <=> '[{embedding_string}]'::vector ORDER BY e.content_vector <=> '[{embedding_string}]'::vector
LIMIT ($4 * 50) LIMIT $3;
)
SELECT c.entity_name,
EXTRACT(EPOCH FROM c.create_time) ::BIGINT AS created_at
FROM cand c
JOIN rc ON TRUE
WHERE c.dist < $3
AND c.chunk_ids && (rc.chunk_arr::varchar[])
ORDER BY c.dist, c.id
LIMIT $4;
""", """,
"chunks": """ "chunks": """
WITH relevant_chunks AS (SELECT id as chunk_id
FROM LIGHTRAG_VDB_CHUNKS
WHERE $2
:: varchar [] IS NULL OR full_doc_id = ANY ($2:: varchar [])
)
, rc AS (
SELECT array_agg(chunk_id) AS chunk_arr
FROM relevant_chunks
), cand AS (
SELECT
id, content, file_path, create_time, content_vector <=> '[{embedding_string}]'::vector AS dist
FROM LIGHTRAG_VDB_CHUNKS
WHERE workspace = $1
ORDER BY content_vector <=> '[{embedding_string}]'::vector
LIMIT ($4 * 50)
)
SELECT c.id, SELECT c.id,
c.content, c.content,
c.file_path, c.file_path,
EXTRACT(EPOCH FROM c.create_time) ::BIGINT AS created_at EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at
FROM cand c FROM LIGHTRAG_VDB_CHUNKS c
JOIN rc ON TRUE WHERE c.workspace = $1
WHERE c.dist < $3 AND c.content_vector <=> '[{embedding_string}]'::vector < $2
AND c.id = ANY (rc.chunk_arr) ORDER BY c.content_vector <=> '[{embedding_string}]'::vector
ORDER BY c.dist, c.id LIMIT $3;
LIMIT $4;
""", """,
# DROP tables # DROP tables
"drop_specifiy_table_workspace": """ "drop_specifiy_table_workspace": """

View file

@@ -200,14 +200,19 @@ class QdrantVectorDBStorage(BaseVectorStorage):
return results return results
async def query( async def query(
self, query: str, top_k: int, ids: list[str] | None = None self, query: str, top_k: int, query_embedding: list[float] = None
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
embedding = await self.embedding_func( if query_embedding is not None:
[query], _priority=5 embedding = query_embedding
) # higher priority for query else:
embedding_result = await self.embedding_func(
[query], _priority=5
) # higher priority for query
embedding = embedding_result[0]
results = self._client.search( results = self._client.search(
collection_name=self.final_namespace, collection_name=self.final_namespace,
query_vector=embedding[0], query_vector=embedding,
limit=top_k, limit=top_k,
with_payload=True, with_payload=True,
score_threshold=self.cosine_better_than_threshold, score_threshold=self.cosine_better_than_threshold,

View file

@@ -2234,6 +2234,7 @@ async def _get_vector_context(
query: str, query: str,
chunks_vdb: BaseVectorStorage, chunks_vdb: BaseVectorStorage,
query_param: QueryParam, query_param: QueryParam,
query_embedding: list[float] = None,
) -> list[dict]: ) -> list[dict]:
""" """
Retrieve text chunks from the vector database without reranking or truncation. Retrieve text chunks from the vector database without reranking or truncation.
@@ -2245,6 +2246,7 @@ async def _get_vector_context(
query: The query string to search for query: The query string to search for
chunks_vdb: Vector database containing document chunks chunks_vdb: Vector database containing document chunks
query_param: Query parameters including chunk_top_k and ids query_param: Query parameters including chunk_top_k and ids
query_embedding: Optional pre-computed query embedding to avoid redundant embedding calls
Returns: Returns:
List of text chunks with metadata List of text chunks with metadata
@@ -2253,7 +2255,9 @@ async def _get_vector_context(
# Use chunk_top_k if specified, otherwise fall back to top_k # Use chunk_top_k if specified, otherwise fall back to top_k
search_top_k = query_param.chunk_top_k or query_param.top_k search_top_k = query_param.chunk_top_k or query_param.top_k
results = await chunks_vdb.query(query, top_k=search_top_k, ids=query_param.ids) results = await chunks_vdb.query(
query, top_k=search_top_k, query_embedding=query_embedding
)
if not results: if not results:
logger.info(f"Naive query: 0 chunks (chunk_top_k: {search_top_k})") logger.info(f"Naive query: 0 chunks (chunk_top_k: {search_top_k})")
return [] return []
@@ -2291,6 +2295,10 @@ async def _build_query_context(
query_param: QueryParam, query_param: QueryParam,
chunks_vdb: BaseVectorStorage = None, chunks_vdb: BaseVectorStorage = None,
): ):
if not query:
logger.warning("Query is empty, skipping context building")
return ""
logger.info(f"Process {os.getpid()} building query context...") logger.info(f"Process {os.getpid()} building query context...")
# Collect chunks from different sources separately # Collect chunks from different sources separately
@@ -2309,12 +2317,12 @@ async def _build_query_context(
# Track chunk sources and metadata for final logging # Track chunk sources and metadata for final logging
chunk_tracking = {} # chunk_id -> {source, frequency, order} chunk_tracking = {} # chunk_id -> {source, frequency, order}
# Pre-compute query embedding if vector similarity method is used # Pre-compute query embedding once for all vector operations
kg_chunk_pick_method = text_chunks_db.global_config.get( kg_chunk_pick_method = text_chunks_db.global_config.get(
"kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD "kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD
) )
query_embedding = None query_embedding = None
if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb: if query and (kg_chunk_pick_method == "VECTOR" or chunks_vdb):
embedding_func_config = text_chunks_db.embedding_func embedding_func_config = text_chunks_db.embedding_func
if embedding_func_config and embedding_func_config.func: if embedding_func_config and embedding_func_config.func:
try: try:
@@ -2322,9 +2330,7 @@ async def _build_query_context(
query_embedding = query_embedding[ query_embedding = query_embedding[
0 0
] # Extract first embedding from batch result ] # Extract first embedding from batch result
logger.debug( logger.debug("Pre-computed query embedding for all vector operations")
"Pre-computed query embedding for vector similarity chunk selection"
)
except Exception as e: except Exception as e:
logger.warning(f"Failed to pre-compute query embedding: {e}") logger.warning(f"Failed to pre-compute query embedding: {e}")
query_embedding = None query_embedding = None
@@ -2368,6 +2374,7 @@ async def _build_query_context(
query, query,
chunks_vdb, chunks_vdb,
query_param, query_param,
query_embedding,
) )
# Track vector chunks with source metadata # Track vector chunks with source metadata
for i, chunk in enumerate(vector_chunks): for i, chunk in enumerate(vector_chunks):
@@ -2830,9 +2837,7 @@ async def _get_node_data(
f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}" f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}"
) )
results = await entities_vdb.query( results = await entities_vdb.query(query, top_k=query_param.top_k)
query, top_k=query_param.top_k, ids=query_param.ids
)
if not len(results): if not len(results):
return [], [] return [], []
@@ -3108,9 +3113,7 @@ async def _get_edge_data(
f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}" f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
) )
results = await relationships_vdb.query( results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
keywords, top_k=query_param.top_k, ids=query_param.ids
)
if not len(results): if not len(results):
return [], [] return [], []
@@ -3433,7 +3436,7 @@ async def naive_query(
tokenizer: Tokenizer = global_config["tokenizer"] tokenizer: Tokenizer = global_config["tokenizer"]
chunks = await _get_vector_context(query, chunks_vdb, query_param) chunks = await _get_vector_context(query, chunks_vdb, query_param, None)
if chunks is None or len(chunks) == 0: if chunks is None or len(chunks) == 0:
return PROMPTS["fail_response"] return PROMPTS["fail_response"]