refactor: make include_payload use in vector databases a bit more readable

Igor Ilic 2026-01-14 23:33:05 +01:00
parent d5a888e6c0
commit 51a9ff0613
3 changed files with 64 additions and 101 deletions

View file

@@ -355,7 +355,7 @@ class ChromaDBAdapter(VectorDBInterface):
         limit: Optional[int] = 15,
         with_vector: bool = False,
         normalized: bool = True,
-        include_payload: bool = False,  # TODO: Add support for this parameter
+        include_payload: bool = False,  # TODO: Add support for this parameter when set to False
     ):
         """
         Search for items in a collection using either a text or a vector query.
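
The TODO above could plausibly be closed the same way the other two adapters do it: Chroma's Collection.query accepts an include list, so payload-bearing fields can be skipped at query time. A minimal sketch, not part of this commit, assuming it runs inside the adapter's search method (where collection, query_vector, limit, with_vector, and include_payload are in scope) and that payloads live in Chroma's metadatas/documents:

def build_include(include_payload: bool, with_vector: bool) -> list:
    # "distances" is always needed to compute scores; the rest is optional
    include = ["distances"]
    if include_payload:
        include += ["metadatas", "documents"]  # assumed payload location
    if with_vector:
        include.append("embeddings")
    return include

results = collection.query(
    query_embeddings=[query_vector],
    n_results=limit,
    include=build_include(include_payload, with_vector),
)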

View file

@@ -248,40 +248,30 @@ class LanceDBAdapter(VectorDBInterface):
         if limit <= 0:
             return []
-        if include_payload:
-            result_values = await collection.vector_search(query_vector).limit(limit).to_list()
-
-            if not result_values:
-                return []
-
-            normalized_values = normalize_distances(result_values)
-
-            return [
-                ScoredResult(
-                    id=parse_id(result["id"]),
-                    payload=result["payload"],
-                    score=normalized_values[value_index],
-                )
-                for value_index, result in enumerate(result_values)
-            ]
-        else:
-            result_values = await (
-                collection.vector_search(query_vector)
-                .limit(limit)
-                .select(["id", "vector", "_distance"])
-                .to_list()
-            )
-
-            if not result_values:
-                return []
-
-            normalized_values = normalize_distances(result_values)
-
-            return [
-                ScoredResult(
-                    id=parse_id(result["id"]),
-                    score=normalized_values[value_index],
-                )
-                for value_index, result in enumerate(result_values)
-            ]
+        # Note: Exclude payload if not needed to optimize performance
+        select_columns = (
+            ["id", "vector", "payload", "_distance"]
+            if include_payload
+            else ["id", "vector", "_distance"]
+        )
+
+        result_values = (
+            await collection.vector_search(query_vector)
+            .select(select_columns)
+            .limit(limit)
+            .to_list()
+        )
+
+        if not result_values:
+            return []
+
+        normalized_values = normalize_distances(result_values)
+        return [
+            ScoredResult(
+                id=parse_id(result["id"]),
+                payload=result["payload"] if include_payload else None,
+                score=normalized_values[value_index],
+            )
+            for value_index, result in enumerate(result_values)
+        ]
     async def batch_search(
         self,
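
The refactored LanceDB path always returns the same ScoredResult shape; only the payload field toggles. A hypothetical call site, assuming an initialized LanceDBAdapter named adapter and keyword names mirroring the ChromaDB signature above (collection_name and query_vector are guesses):

results = await adapter.search(
    collection_name="documents",
    query_vector=query_embedding,
    limit=10,
    include_payload=False,
)
for result in results:
    # The payload column was never fetched from LanceDB, so this is None
    assert result.payload is None
    print(result.id, result.score)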

View file

@@ -325,81 +325,54 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
 
         # NOTE: This needs to be initialized in case search doesn't return a value
         closest_items = []
-        if include_payload:
-            # Use async session to connect to the database
-            async with self.get_async_session() as session:
-                query = select(
-                    PGVectorDataPoint,
-                    PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
-                ).order_by("similarity")
-
-                if limit > 0:
-                    query = query.limit(limit)
-
-                # Find closest vectors to query_vector
-                closest_items = await session.execute(query)
-
-            vector_list = []
-
-            # Extract distances and find min/max for normalization
-            for vector in closest_items.all():
-                vector_list.append(
-                    {
-                        "id": parse_id(str(vector.id)),
-                        "payload": vector.payload,
-                        "_distance": vector.similarity,
-                    }
-                )
-
-            if len(vector_list) == 0:
-                return []
-
-            # Normalize vector distance and add this as score information to vector_list
-            normalized_values = normalize_distances(vector_list)
-            for i in range(0, len(normalized_values)):
-                vector_list[i]["score"] = normalized_values[i]
-
-            # Create and return ScoredResult objects
-            return [
-                ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
-                for row in vector_list
-            ]
-        else:
-            # Use async session to connect to the database
-            async with self.get_async_session() as session:
-                query = select(
-                    PGVectorDataPoint.c.id,
-                    PGVectorDataPoint.c.vector,
-                    PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
-                ).order_by("similarity")
-
-                if limit > 0:
-                    query = query.limit(limit)
-
-                # Find closest vectors to query_vector
-                closest_items = await session.execute(query)
-
-            vector_list = []
-
-            # Extract distances and find min/max for normalization
-            for vector in closest_items.all():
-                vector_list.append(
-                    {
-                        "id": parse_id(str(vector.id)),
-                        "_distance": vector.similarity,
-                    }
-                )
-
-            if len(vector_list) == 0:
-                return []
-
-            # Normalize vector distance and add this as score information to vector_list
-            normalized_values = normalize_distances(vector_list)
-            for i in range(0, len(normalized_values)):
-                vector_list[i]["score"] = normalized_values[i]
-
-            # Create and return ScoredResult objects
-            return [ScoredResult(id=row.get("id"), score=row.get("score")) for row in vector_list]
+        # Note: Exclude payload from returned columns if not needed to optimize performance
+        select_columns = (
+            [PGVectorDataPoint]
+            if include_payload
+            else [PGVectorDataPoint.c.id, PGVectorDataPoint.c.vector]
+        )
+        # Use async session to connect to the database
+        async with self.get_async_session() as session:
+            query = select(
+                *select_columns,
+                PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
+            ).order_by("similarity")
+
+            if limit > 0:
+                query = query.limit(limit)
+
+            # Find closest vectors to query_vector
+            closest_items = await session.execute(query)
+
+        vector_list = []
+
+        # Extract distances and find min/max for normalization
+        for vector in closest_items.all():
+            vector_list.append(
+                {
+                    "id": parse_id(str(vector.id)),
+                    "payload": vector.payload if include_payload else None,
+                    "_distance": vector.similarity,
+                }
+            )
+
+        if len(vector_list) == 0:
+            return []
+
+        # Normalize vector distance and add this as score information to vector_list
+        normalized_values = normalize_distances(vector_list)
+        for i in range(0, len(normalized_values)):
+            vector_list[i]["score"] = normalized_values[i]
+
+        # Create and return ScoredResult objects
+        return [
+            ScoredResult(
+                id=row.get("id"),
+                payload=row.get("payload") if include_payload else None,
+                score=row.get("score"),
+            )
+            for row in vector_list
+        ]
 
     async def batch_search(
         self,
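
The PGVector version leans on the fact that SQLAlchemy's select() accepts either a whole Table object (which expands to every column) or individual Column objects. A standalone sketch of that pattern with an illustrative table; func.random() stands in for the cosine_distance expression, and all names here are made up:

from sqlalchemy import JSON, Column, MetaData, String, Table, func, select

metadata = MetaData()
data_points = Table(
    "data_points",
    metadata,
    Column("id", String, primary_key=True),
    Column("vector", JSON),  # stand-in for the pgvector column type
    Column("payload", JSON),
)

def build_query(include_payload: bool):
    similarity = func.random().label("similarity")
    select_columns = (
        [data_points]  # whole table: every column, payload included
        if include_payload
        else [data_points.c.id, data_points.c.vector]  # payload never leaves the database
    )
    return select(*select_columns, similarity).order_by("similarity")

# SELECT data_points.id, data_points.vector, random() AS similarity ... ORDER BY similarity
print(build_query(include_payload=False))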