Remove deprecated ID-based filtering from vector storage queries
- Remove ids param from QueryParam
- Simplify BaseVectorStorage.query signature
- Update all vector storage implementations
- Streamline PostgreSQL query templates
- Remove ID filtering from operate.py calls
This commit is contained in:
parent
20b800d694
commit
a923d378dd
8 changed files with 26 additions and 100 deletions
|
|
@ -142,10 +142,6 @@ class QueryParam:
|
|||
history_turns: int = int(os.getenv("HISTORY_TURNS", str(DEFAULT_HISTORY_TURNS)))
|
||||
"""Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
|
||||
|
||||
# TODO: Deprecated - ID-based filtering only applies to chunks, not entities or relations, and is implemented only in PostgreSQL storage
|
||||
ids: list[str] | None = None
|
||||
"""List of doc ids to filter the results."""
|
||||
|
||||
model_func: Callable[..., object] | None = None
|
||||
"""Optional override for the LLM model function to use for this specific query.
|
||||
If provided, this will be used instead of the global model function.
|
||||
|
|
@ -215,9 +211,7 @@ class BaseVectorStorage(StorageNameSpace, ABC):
|
|||
meta_fields: set[str] = field(default_factory=set)
|
||||
|
||||
@abstractmethod
|
||||
async def query(
|
||||
self, query: str, top_k: int, ids: list[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
"""Query the vector storage and retrieve top_k results."""
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
|||
|
|
@ -179,9 +179,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
|||
)
|
||||
return [m["__id__"] for m in list_data]
|
||||
|
||||
async def query(
|
||||
self, query: str, top_k: int, ids: list[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Search by a textual query; returns top_k results with their metadata + similarity distance.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1046,9 +1046,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
|||
)
|
||||
return results
|
||||
|
||||
async def query(
|
||||
self, query: str, top_k: int, ids: list[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
# Ensure collection is loaded before querying
|
||||
self._ensure_collection_loaded()
|
||||
|
||||
|
|
|
|||
|
|
@ -1809,9 +1809,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
|||
|
||||
return list_data
|
||||
|
||||
async def query(
|
||||
self, query: str, top_k: int, ids: list[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
"""Queries the vector database using Atlas Vector Search."""
|
||||
# Generate the embedding
|
||||
embedding = await self.embedding_func(
|
||||
|
|
|
|||
|
|
@ -136,9 +136,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
|||
f"[{self.workspace}] embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
|
||||
)
|
||||
|
||||
async def query(
|
||||
self, query: str, top_k: int, ids: list[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
# Execute embedding outside of lock to improve concurrency
|
||||
embedding = await self.embedding_func(
|
||||
[query], _priority=5
|
||||
|
|
|
|||
|
|
@ -2004,19 +2004,16 @@ class PGVectorStorage(BaseVectorStorage):
|
|||
await self.db.execute(upsert_sql, data)
|
||||
|
||||
#################### query method ###############
|
||||
async def query(
|
||||
self, query: str, top_k: int, ids: list[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
embeddings = await self.embedding_func(
|
||||
[query], _priority=5
|
||||
) # higher priority for query
|
||||
embedding = embeddings[0]
|
||||
embedding_string = ",".join(map(str, embedding))
|
||||
# Use parameterized document IDs (None means search across all documents)
|
||||
|
||||
sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
|
||||
params = {
|
||||
"workspace": self.workspace,
|
||||
"doc_ids": ids,
|
||||
"closer_than_threshold": 1 - self.cosine_better_than_threshold,
|
||||
"top_k": top_k,
|
||||
}
|
||||
|
|
@ -4582,85 +4579,34 @@ SQL_TEMPLATES = {
|
|||
update_time = EXCLUDED.update_time
|
||||
""",
|
||||
"relationships": """
|
||||
WITH relevant_chunks AS (SELECT id as chunk_id
|
||||
FROM LIGHTRAG_VDB_CHUNKS
|
||||
WHERE $2
|
||||
:: varchar [] IS NULL OR full_doc_id = ANY ($2:: varchar [])
|
||||
)
|
||||
, rc AS (
|
||||
SELECT array_agg(chunk_id) AS chunk_arr
|
||||
FROM relevant_chunks
|
||||
), cand AS (
|
||||
SELECT
|
||||
r.id, r.source_id AS src_id, r.target_id AS tgt_id, r.chunk_ids, r.create_time, r.content_vector <=> '[{embedding_string}]'::vector AS dist
|
||||
SELECT r.source_id AS src_id,
|
||||
r.target_id AS tgt_id,
|
||||
EXTRACT(EPOCH FROM r.create_time)::BIGINT AS created_at
|
||||
FROM LIGHTRAG_VDB_RELATION r
|
||||
WHERE r.workspace = $1
|
||||
AND r.content_vector <=> '[{embedding_string}]'::vector < $2
|
||||
ORDER BY r.content_vector <=> '[{embedding_string}]'::vector
|
||||
LIMIT ($4 * 50)
|
||||
)
|
||||
SELECT c.src_id,
|
||||
c.tgt_id,
|
||||
EXTRACT(EPOCH FROM c.create_time) ::BIGINT AS created_at
|
||||
FROM cand c
|
||||
JOIN rc ON TRUE
|
||||
WHERE c.dist < $3
|
||||
AND c.chunk_ids && (rc.chunk_arr::varchar[])
|
||||
ORDER BY c.dist, c.id
|
||||
LIMIT $4;
|
||||
LIMIT $3;
|
||||
""",
|
||||
"entities": """
|
||||
WITH relevant_chunks AS (SELECT id as chunk_id
|
||||
FROM LIGHTRAG_VDB_CHUNKS
|
||||
WHERE $2
|
||||
:: varchar [] IS NULL OR full_doc_id = ANY ($2:: varchar [])
|
||||
)
|
||||
, rc AS (
|
||||
SELECT array_agg(chunk_id) AS chunk_arr
|
||||
FROM relevant_chunks
|
||||
), cand AS (
|
||||
SELECT
|
||||
e.id, e.entity_name, e.chunk_ids, e.create_time, e.content_vector <=> '[{embedding_string}]'::vector AS dist
|
||||
SELECT e.entity_name,
|
||||
EXTRACT(EPOCH FROM e.create_time)::BIGINT AS created_at
|
||||
FROM LIGHTRAG_VDB_ENTITY e
|
||||
WHERE e.workspace = $1
|
||||
AND e.content_vector <=> '[{embedding_string}]'::vector < $2
|
||||
ORDER BY e.content_vector <=> '[{embedding_string}]'::vector
|
||||
LIMIT ($4 * 50)
|
||||
)
|
||||
SELECT c.entity_name,
|
||||
EXTRACT(EPOCH FROM c.create_time) ::BIGINT AS created_at
|
||||
FROM cand c
|
||||
JOIN rc ON TRUE
|
||||
WHERE c.dist < $3
|
||||
AND c.chunk_ids && (rc.chunk_arr::varchar[])
|
||||
ORDER BY c.dist, c.id
|
||||
LIMIT $4;
|
||||
LIMIT $3;
|
||||
""",
|
||||
"chunks": """
|
||||
WITH relevant_chunks AS (SELECT id as chunk_id
|
||||
FROM LIGHTRAG_VDB_CHUNKS
|
||||
WHERE $2
|
||||
:: varchar [] IS NULL OR full_doc_id = ANY ($2:: varchar [])
|
||||
)
|
||||
, rc AS (
|
||||
SELECT array_agg(chunk_id) AS chunk_arr
|
||||
FROM relevant_chunks
|
||||
), cand AS (
|
||||
SELECT
|
||||
id, content, file_path, create_time, content_vector <=> '[{embedding_string}]'::vector AS dist
|
||||
FROM LIGHTRAG_VDB_CHUNKS
|
||||
WHERE workspace = $1
|
||||
ORDER BY content_vector <=> '[{embedding_string}]'::vector
|
||||
LIMIT ($4 * 50)
|
||||
)
|
||||
SELECT c.id,
|
||||
c.content,
|
||||
c.file_path,
|
||||
EXTRACT(EPOCH FROM c.create_time) ::BIGINT AS created_at
|
||||
FROM cand c
|
||||
JOIN rc ON TRUE
|
||||
WHERE c.dist < $3
|
||||
AND c.id = ANY (rc.chunk_arr)
|
||||
ORDER BY c.dist, c.id
|
||||
LIMIT $4;
|
||||
EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at
|
||||
FROM LIGHTRAG_VDB_CHUNKS c
|
||||
WHERE c.workspace = $1
|
||||
AND c.content_vector <=> '[{embedding_string}]'::vector < $2
|
||||
ORDER BY c.content_vector <=> '[{embedding_string}]'::vector
|
||||
LIMIT $3;
|
||||
""",
|
||||
# DROP tables
|
||||
"drop_specifiy_table_workspace": """
|
||||
|
|
|
|||
|
|
@ -199,9 +199,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
|||
)
|
||||
return results
|
||||
|
||||
async def query(
|
||||
self, query: str, top_k: int, ids: list[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
embedding = await self.embedding_func(
|
||||
[query], _priority=5
|
||||
) # higher priority for query
|
||||
|
|
|
|||
|
|
@ -2253,7 +2253,7 @@ async def _get_vector_context(
|
|||
# Use chunk_top_k if specified, otherwise fall back to top_k
|
||||
search_top_k = query_param.chunk_top_k or query_param.top_k
|
||||
|
||||
results = await chunks_vdb.query(query, top_k=search_top_k, ids=query_param.ids)
|
||||
results = await chunks_vdb.query(query, top_k=search_top_k)
|
||||
if not results:
|
||||
logger.info(f"Naive query: 0 chunks (chunk_top_k: {search_top_k})")
|
||||
return []
|
||||
|
|
@ -2830,9 +2830,7 @@ async def _get_node_data(
|
|||
f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}"
|
||||
)
|
||||
|
||||
results = await entities_vdb.query(
|
||||
query, top_k=query_param.top_k, ids=query_param.ids
|
||||
)
|
||||
results = await entities_vdb.query(query, top_k=query_param.top_k)
|
||||
|
||||
if not len(results):
|
||||
return [], []
|
||||
|
|
@ -3108,9 +3106,7 @@ async def _get_edge_data(
|
|||
f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
|
||||
)
|
||||
|
||||
results = await relationships_vdb.query(
|
||||
keywords, top_k=query_param.top_k, ids=query_param.ids
|
||||
)
|
||||
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
|
||||
|
||||
if not len(results):
|
||||
return [], []
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue