feat: make payload inclusion optional for vector search

Igor Ilic 2026-01-14 23:15:28 +01:00
parent 4765f9e4a0
commit d5a888e6c0
6 changed files with 112 additions and 47 deletions


@@ -355,6 +355,7 @@ class ChromaDBAdapter(VectorDBInterface):
         limit: Optional[int] = 15,
         with_vector: bool = False,
         normalized: bool = True,
+        include_payload: bool = False,  # TODO: Add support for this parameter
     ):
         """
         Search for items in a collection using either a text or a vector query.


@@ -231,6 +231,7 @@ class LanceDBAdapter(VectorDBInterface):
         limit: Optional[int] = 15,
         with_vector: bool = False,
         normalized: bool = True,
+        include_payload: bool = False,
     ):
         if query_text is None and query_vector is None:
             raise MissingQueryParameterError()
@@ -247,21 +248,40 @@ class LanceDBAdapter(VectorDBInterface):
         if limit <= 0:
             return []

-        result_values = await collection.vector_search(query_vector).limit(limit).to_list()
+        if include_payload:
+            result_values = await collection.vector_search(query_vector).limit(limit).to_list()

-        if not result_values:
-            return []
+            if not result_values:
+                return []

-        normalized_values = normalize_distances(result_values)
+            normalized_values = normalize_distances(result_values)

-        return [
-            ScoredResult(
-                id=parse_id(result["id"]),
-                payload=result["payload"],
-                score=normalized_values[value_index],
-            )
-            for value_index, result in enumerate(result_values)
-        ]
+            return [
+                ScoredResult(
+                    id=parse_id(result["id"]),
+                    payload=result["payload"],
+                    score=normalized_values[value_index],
+                )
+                for value_index, result in enumerate(result_values)
+            ]
+        else:
+            result_values = await (
+                collection.vector_search(query_vector)
+                .limit(limit)
+                .select(["id", "vector", "_distance"])
+                .to_list()
+            )
+
+            if not result_values:
+                return []
+
+            normalized_values = normalize_distances(result_values)
+
+            return [
+                ScoredResult(
+                    id=parse_id(result["id"]),
+                    score=normalized_values[value_index],
+                )
+                for value_index, result in enumerate(result_values)
+            ]

     async def batch_search(
         self,
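The speed-up in the new fast path comes from the .select() projection: LanceDB only materializes the requested columns, so the payload blob is never read. A minimal sketch of that path in isolation, assuming `collection` is the open LanceDB async table the adapter works with and `query_vector` is a list of floats:

# Sketch only: `collection` and `query_vector` are assumed to exist as in the
# adapter above. The projection omits "payload", so each returned row is a
# small dict holding just the keys scoring needs.
rows = await (
    collection.vector_search(query_vector)
    .limit(15)
    .select(["id", "vector", "_distance"])
    .to_list()
)
for row in rows:
    print(row["id"], row["_distance"])  # no "payload" key present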


@@ -1,4 +1,4 @@
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 from uuid import UUID

 from pydantic import BaseModel
@@ -12,10 +12,10 @@ class ScoredResult(BaseModel):
     - id (UUID): Unique identifier for the scored result.
     - score (float): The score associated with the result, where a lower score indicates a
       better outcome.
-    - payload (Dict[str, Any]): Additional information related to the score, stored as
+    - payload (Optional[Dict[str, Any]]): Additional information related to the score, stored as
       key-value pairs in a dictionary.
     """

     id: UUID
     score: float  # Lower score is better
-    payload: Dict[str, Any]
+    payload: Optional[Dict[str, Any]] = None
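Making payload optional with a None default is what lets the projection-only paths construct results at all. A quick sketch of the model's behavior before and after this change (payload contents are illustrative; ScoredResult is the model above):

from uuid import uuid4

# Before this change, omitting payload raised a pydantic ValidationError;
# now it validates and payload defaults to None.
slim = ScoredResult(id=uuid4(), score=0.12)
rich = ScoredResult(id=uuid4(), score=0.12, payload={"text": "chunk contents"})

assert slim.payload is None
assert rich.payload["text"] == "chunk contents"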


@@ -301,6 +301,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         query_vector: Optional[List[float]] = None,
         limit: Optional[int] = 15,
         with_vector: bool = False,
+        include_payload: bool = False,
     ) -> List[ScoredResult]:
         if query_text is None and query_vector is None:
             raise MissingQueryParameterError()
@@ -324,44 +325,81 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         # NOTE: This needs to be initialized in case search doesn't return a value
         closest_items = []

-        # Use async session to connect to the database
-        async with self.get_async_session() as session:
-            query = select(
-                PGVectorDataPoint,
-                PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
-            ).order_by("similarity")
-
-            if limit > 0:
-                query = query.limit(limit)
-
-            # Find closest vectors to query_vector
-            closest_items = await session.execute(query)
-
-        vector_list = []
-
-        # Extract distances and find min/max for normalization
-        for vector in closest_items.all():
-            vector_list.append(
-                {
-                    "id": parse_id(str(vector.id)),
-                    "payload": vector.payload,
-                    "_distance": vector.similarity,
-                }
-            )
-
-        if len(vector_list) == 0:
-            return []
-
-        # Normalize vector distance and add this as score information to vector_list
-        normalized_values = normalize_distances(vector_list)
-        for i in range(0, len(normalized_values)):
-            vector_list[i]["score"] = normalized_values[i]
-
-        # Create and return ScoredResult objects
-        return [
-            ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
-            for row in vector_list
-        ]
+        if include_payload:
+            # Use async session to connect to the database
+            async with self.get_async_session() as session:
+                query = select(
+                    PGVectorDataPoint,
+                    PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
+                ).order_by("similarity")
+
+                if limit > 0:
+                    query = query.limit(limit)
+
+                # Find closest vectors to query_vector
+                closest_items = await session.execute(query)
+
+            vector_list = []
+
+            # Extract distances and find min/max for normalization
+            for vector in closest_items.all():
+                vector_list.append(
+                    {
+                        "id": parse_id(str(vector.id)),
+                        "payload": vector.payload,
+                        "_distance": vector.similarity,
+                    }
+                )
+
+            if len(vector_list) == 0:
+                return []
+
+            # Normalize vector distance and add this as score information to vector_list
+            normalized_values = normalize_distances(vector_list)
+            for i in range(0, len(normalized_values)):
+                vector_list[i]["score"] = normalized_values[i]
+
+            # Create and return ScoredResult objects
+            return [
+                ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
+                for row in vector_list
+            ]
+        else:
+            # Use async session to connect to the database
+            async with self.get_async_session() as session:
+                query = select(
+                    PGVectorDataPoint.c.id,
+                    PGVectorDataPoint.c.vector,
+                    PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
+                ).order_by("similarity")
+
+                if limit > 0:
+                    query = query.limit(limit)
+
+                # Find closest vectors to query_vector
+                closest_items = await session.execute(query)
+
+            vector_list = []
+
+            # Extract distances and find min/max for normalization
+            for vector in closest_items.all():
+                vector_list.append(
+                    {
+                        "id": parse_id(str(vector.id)),
+                        "_distance": vector.similarity,
+                    }
+                )
+
+            if len(vector_list) == 0:
+                return []
+
+            # Normalize vector distance and add this as score information to vector_list
+            normalized_values = normalize_distances(vector_list)
+            for i in range(0, len(normalized_values)):
+                vector_list[i]["score"] = normalized_values[i]
+
+            # Create and return ScoredResult objects
+            return [ScoredResult(id=row.get("id"), score=row.get("score")) for row in vector_list]

     async def batch_search(
         self,
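In the PGVector branch the saving comes from selecting individual columns instead of the whole PGVectorDataPoint row, so Postgres never ships the payload column over the wire. A stand-alone SQLAlchemy Core sketch of the projection difference (the table here is a stand-in; the real table and the pgvector cosine_distance operator live in the adapter above):

import sqlalchemy as sa

# Stand-in table, purely to illustrate what the two branches select.
metadata = sa.MetaData()
data_point = sa.Table(
    "data_point",
    metadata,
    sa.Column("id", sa.String),
    sa.Column("payload", sa.JSON),
    sa.Column("vector", sa.LargeBinary),  # stand-in for the pgvector column type
)

# include_payload=True: the whole row is selected, payload included.
full_query = sa.select(data_point).limit(10)
# include_payload=False: only the columns needed for scoring.
slim_query = sa.select(data_point.c.id, data_point.c.vector).limit(10)

print(full_query)  # SELECT data_point.id, data_point.payload, data_point.vector ...
print(slim_query)  # SELECT data_point.id, data_point.vector ...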


@@ -87,6 +87,7 @@ class VectorDBInterface(Protocol):
         query_vector: Optional[List[float]],
         limit: Optional[int],
         with_vector: bool = False,
+        include_payload: bool = False,
     ):
         """
         Perform a search in the specified collection using either a text query or a vector
@@ -103,6 +104,9 @@ class VectorDBInterface(Protocol):
         - limit (Optional[int]): The maximum number of results to return from the search.
         - with_vector (bool): Whether to return the vector representations with search
           results. (default False)
+        - include_payload (bool): Whether to include payload data with the search results.
+          Search is faster when set to False. The payload holds metadata about the data point;
+          it is needed for pure embedding-distance searches (e.g. RAG_COMPLETION), not for graph-augmented ones.
         """
         raise NotImplementedError
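Because the default is False, any caller that needs metadata must now opt in explicitly; that is exactly why the retriever below is updated. A hypothetical caller-side sketch (query string and limit are illustrative; get_vector_engine is the factory used by the retriever):

vector_engine = get_vector_engine()

# Fast path: distances only, ScoredResult.payload stays None.
fast = await vector_engine.search("DocumentChunk_text", "some query", limit=10)

# Opt-in path: payload metadata travels with each result.
rich = await vector_engine.search(
    "DocumentChunk_text", "some query", limit=10, include_payload=True
)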


@@ -62,7 +62,9 @@ class CompletionRetriever(BaseRetriever):
         vector_engine = get_vector_engine()

         try:
-            found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
+            found_chunks = await vector_engine.search(
+                "DocumentChunk_text", query, limit=self.top_k, include_payload=True
+            )

             if len(found_chunks) == 0:
                 return ""