Remove deprecated ID-based filtering from vector storage queries

- Remove ids param from QueryParam
- Simplify BaseVectorStorage.query signature
- Update all vector storage implementations
- Streamline PostgreSQL query templates
- Remove ID filtering from operate.py calls
This commit is contained in:
yangdx 2025-08-29 17:06:48 +08:00
parent 20b800d694
commit a923d378dd
8 changed files with 26 additions and 100 deletions

View file

@ -142,10 +142,6 @@ class QueryParam:
history_turns: int = int(os.getenv("HISTORY_TURNS", str(DEFAULT_HISTORY_TURNS))) history_turns: int = int(os.getenv("HISTORY_TURNS", str(DEFAULT_HISTORY_TURNS)))
"""Number of complete conversation turns (user-assistant pairs) to consider in the response context.""" """Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
# TODO: Deprecated - ID-based filtering only applies to chunks, not entities or relations, and implemented only in PostgreSQL storage
ids: list[str] | None = None
"""List of doc ids to filter the results."""
model_func: Callable[..., object] | None = None model_func: Callable[..., object] | None = None
"""Optional override for the LLM model function to use for this specific query. """Optional override for the LLM model function to use for this specific query.
If provided, this will be used instead of the global model function. If provided, this will be used instead of the global model function.
@ -215,9 +211,7 @@ class BaseVectorStorage(StorageNameSpace, ABC):
meta_fields: set[str] = field(default_factory=set) meta_fields: set[str] = field(default_factory=set)
@abstractmethod @abstractmethod
async def query( async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
"""Query the vector storage and retrieve top_k results.""" """Query the vector storage and retrieve top_k results."""
@abstractmethod @abstractmethod

View file

@ -179,9 +179,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
) )
return [m["__id__"] for m in list_data] return [m["__id__"] for m in list_data]
async def query( async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
""" """
Search by a textual query; returns top_k results with their metadata + similarity distance. Search by a textual query; returns top_k results with their metadata + similarity distance.
""" """

View file

@ -1046,9 +1046,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
) )
return results return results
async def query( async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
# Ensure collection is loaded before querying # Ensure collection is loaded before querying
self._ensure_collection_loaded() self._ensure_collection_loaded()

View file

@ -1809,9 +1809,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
return list_data return list_data
async def query( async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
"""Queries the vector database using Atlas Vector Search.""" """Queries the vector database using Atlas Vector Search."""
# Generate the embedding # Generate the embedding
embedding = await self.embedding_func( embedding = await self.embedding_func(

View file

@ -136,9 +136,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
f"[{self.workspace}] embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}" f"[{self.workspace}] embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
) )
async def query( async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
# Execute embedding outside of lock to improve concurrency # Execute embedding outside of lock to improve concurrency
embedding = await self.embedding_func( embedding = await self.embedding_func(
[query], _priority=5 [query], _priority=5

View file

@ -2004,19 +2004,16 @@ class PGVectorStorage(BaseVectorStorage):
await self.db.execute(upsert_sql, data) await self.db.execute(upsert_sql, data)
#################### query method ############### #################### query method ###############
async def query( async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
embeddings = await self.embedding_func( embeddings = await self.embedding_func(
[query], _priority=5 [query], _priority=5
) # higher priority for query ) # higher priority for query
embedding = embeddings[0] embedding = embeddings[0]
embedding_string = ",".join(map(str, embedding)) embedding_string = ",".join(map(str, embedding))
# Use parameterized document IDs (None means search across all documents)
sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string) sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
params = { params = {
"workspace": self.workspace, "workspace": self.workspace,
"doc_ids": ids,
"closer_than_threshold": 1 - self.cosine_better_than_threshold, "closer_than_threshold": 1 - self.cosine_better_than_threshold,
"top_k": top_k, "top_k": top_k,
} }
@ -4582,85 +4579,34 @@ SQL_TEMPLATES = {
update_time = EXCLUDED.update_time update_time = EXCLUDED.update_time
""", """,
"relationships": """ "relationships": """
WITH relevant_chunks AS (SELECT id as chunk_id SELECT r.source_id AS src_id,
FROM LIGHTRAG_VDB_CHUNKS r.target_id AS tgt_id,
WHERE $2 EXTRACT(EPOCH FROM r.create_time)::BIGINT AS created_at
:: varchar [] IS NULL OR full_doc_id = ANY ($2:: varchar [])
)
, rc AS (
SELECT array_agg(chunk_id) AS chunk_arr
FROM relevant_chunks
), cand AS (
SELECT
r.id, r.source_id AS src_id, r.target_id AS tgt_id, r.chunk_ids, r.create_time, r.content_vector <=> '[{embedding_string}]'::vector AS dist
FROM LIGHTRAG_VDB_RELATION r FROM LIGHTRAG_VDB_RELATION r
WHERE r.workspace = $1 WHERE r.workspace = $1
AND r.content_vector <=> '[{embedding_string}]'::vector < $2
ORDER BY r.content_vector <=> '[{embedding_string}]'::vector ORDER BY r.content_vector <=> '[{embedding_string}]'::vector
LIMIT ($4 * 50) LIMIT $3;
)
SELECT c.src_id,
c.tgt_id,
EXTRACT(EPOCH FROM c.create_time) ::BIGINT AS created_at
FROM cand c
JOIN rc ON TRUE
WHERE c.dist < $3
AND c.chunk_ids && (rc.chunk_arr::varchar[])
ORDER BY c.dist, c.id
LIMIT $4;
""", """,
"entities": """ "entities": """
WITH relevant_chunks AS (SELECT id as chunk_id SELECT e.entity_name,
FROM LIGHTRAG_VDB_CHUNKS EXTRACT(EPOCH FROM e.create_time)::BIGINT AS created_at
WHERE $2
:: varchar [] IS NULL OR full_doc_id = ANY ($2:: varchar [])
)
, rc AS (
SELECT array_agg(chunk_id) AS chunk_arr
FROM relevant_chunks
), cand AS (
SELECT
e.id, e.entity_name, e.chunk_ids, e.create_time, e.content_vector <=> '[{embedding_string}]'::vector AS dist
FROM LIGHTRAG_VDB_ENTITY e FROM LIGHTRAG_VDB_ENTITY e
WHERE e.workspace = $1 WHERE e.workspace = $1
AND e.content_vector <=> '[{embedding_string}]'::vector < $2
ORDER BY e.content_vector <=> '[{embedding_string}]'::vector ORDER BY e.content_vector <=> '[{embedding_string}]'::vector
LIMIT ($4 * 50) LIMIT $3;
)
SELECT c.entity_name,
EXTRACT(EPOCH FROM c.create_time) ::BIGINT AS created_at
FROM cand c
JOIN rc ON TRUE
WHERE c.dist < $3
AND c.chunk_ids && (rc.chunk_arr::varchar[])
ORDER BY c.dist, c.id
LIMIT $4;
""", """,
"chunks": """ "chunks": """
WITH relevant_chunks AS (SELECT id as chunk_id
FROM LIGHTRAG_VDB_CHUNKS
WHERE $2
:: varchar [] IS NULL OR full_doc_id = ANY ($2:: varchar [])
)
, rc AS (
SELECT array_agg(chunk_id) AS chunk_arr
FROM relevant_chunks
), cand AS (
SELECT
id, content, file_path, create_time, content_vector <=> '[{embedding_string}]'::vector AS dist
FROM LIGHTRAG_VDB_CHUNKS
WHERE workspace = $1
ORDER BY content_vector <=> '[{embedding_string}]'::vector
LIMIT ($4 * 50)
)
SELECT c.id, SELECT c.id,
c.content, c.content,
c.file_path, c.file_path,
EXTRACT(EPOCH FROM c.create_time) ::BIGINT AS created_at EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at
FROM cand c FROM LIGHTRAG_VDB_CHUNKS c
JOIN rc ON TRUE WHERE c.workspace = $1
WHERE c.dist < $3 AND c.content_vector <=> '[{embedding_string}]'::vector < $2
AND c.id = ANY (rc.chunk_arr) ORDER BY c.content_vector <=> '[{embedding_string}]'::vector
ORDER BY c.dist, c.id LIMIT $3;
LIMIT $4;
""", """,
# DROP tables # DROP tables
"drop_specifiy_table_workspace": """ "drop_specifiy_table_workspace": """

View file

@ -199,9 +199,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
) )
return results return results
async def query( async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
embedding = await self.embedding_func( embedding = await self.embedding_func(
[query], _priority=5 [query], _priority=5
) # higher priority for query ) # higher priority for query

View file

@ -2253,7 +2253,7 @@ async def _get_vector_context(
# Use chunk_top_k if specified, otherwise fall back to top_k # Use chunk_top_k if specified, otherwise fall back to top_k
search_top_k = query_param.chunk_top_k or query_param.top_k search_top_k = query_param.chunk_top_k or query_param.top_k
results = await chunks_vdb.query(query, top_k=search_top_k, ids=query_param.ids) results = await chunks_vdb.query(query, top_k=search_top_k)
if not results: if not results:
logger.info(f"Naive query: 0 chunks (chunk_top_k: {search_top_k})") logger.info(f"Naive query: 0 chunks (chunk_top_k: {search_top_k})")
return [] return []
@ -2830,9 +2830,7 @@ async def _get_node_data(
f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}" f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}"
) )
results = await entities_vdb.query( results = await entities_vdb.query(query, top_k=query_param.top_k)
query, top_k=query_param.top_k, ids=query_param.ids
)
if not len(results): if not len(results):
return [], [] return [], []
@ -3108,9 +3106,7 @@ async def _get_edge_data(
f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}" f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
) )
results = await relationships_vdb.query( results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
keywords, top_k=query_param.top_k, ids=query_param.ids
)
if not len(results): if not len(results):
return [], [] return [], []