refactor: make include_payload use in vector databases a bit more readable
This commit is contained in:
parent
d5a888e6c0
commit
51a9ff0613
3 changed files with 64 additions and 101 deletions
|
|
@ -355,7 +355,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
limit: Optional[int] = 15,
|
||||
with_vector: bool = False,
|
||||
normalized: bool = True,
|
||||
include_payload: bool = False, # TODO: Add support for this parameter
|
||||
include_payload: bool = False, # TODO: Add support for this parameter when set to False
|
||||
):
|
||||
"""
|
||||
Search for items in a collection using either a text or a vector query.
|
||||
|
|
|
|||
|
|
@ -248,40 +248,30 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
if limit <= 0:
|
||||
return []
|
||||
|
||||
if include_payload:
|
||||
result_values = await collection.vector_search(query_vector).limit(limit).to_list()
|
||||
if not result_values:
|
||||
return []
|
||||
normalized_values = normalize_distances(result_values)
|
||||
# Note: Exclude payload if not needed to optimize performance
|
||||
select_columns = (
|
||||
["id", "vector", "payload", "_distance"]
|
||||
if include_payload
|
||||
else ["id", "vector", "_distance"]
|
||||
)
|
||||
result_values = (
|
||||
await collection.vector_search(query_vector)
|
||||
.select(select_columns)
|
||||
.limit(limit)
|
||||
.to_list()
|
||||
)
|
||||
if not result_values:
|
||||
return []
|
||||
normalized_values = normalize_distances(result_values)
|
||||
|
||||
return [
|
||||
ScoredResult(
|
||||
id=parse_id(result["id"]),
|
||||
payload=result["payload"],
|
||||
score=normalized_values[value_index],
|
||||
)
|
||||
for value_index, result in enumerate(result_values)
|
||||
]
|
||||
|
||||
else:
|
||||
result_values = await (
|
||||
collection.vector_search(query_vector)
|
||||
.limit(limit)
|
||||
.select(["id", "vector", "_distance"])
|
||||
.to_list()
|
||||
return [
|
||||
ScoredResult(
|
||||
id=parse_id(result["id"]),
|
||||
payload=result["payload"] if include_payload else None,
|
||||
score=normalized_values[value_index],
|
||||
)
|
||||
if not result_values:
|
||||
return []
|
||||
|
||||
normalized_values = normalize_distances(result_values)
|
||||
|
||||
return [
|
||||
ScoredResult(
|
||||
id=parse_id(result["id"]),
|
||||
score=normalized_values[value_index],
|
||||
)
|
||||
for value_index, result in enumerate(result_values)
|
||||
]
|
||||
for value_index, result in enumerate(result_values)
|
||||
]
|
||||
|
||||
async def batch_search(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -325,81 +325,54 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
# NOTE: This needs to be initialized in case search doesn't return a value
|
||||
closest_items = []
|
||||
|
||||
if include_payload:
|
||||
# Use async session to connect to the database
|
||||
async with self.get_async_session() as session:
|
||||
query = select(
|
||||
PGVectorDataPoint,
|
||||
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
|
||||
).order_by("similarity")
|
||||
# Note: Exclude payload from returned columns if not needed to optimize performance
|
||||
select_columns = (
|
||||
[PGVectorDataPoint]
|
||||
if include_payload
|
||||
else [PGVectorDataPoint.c.id, PGVectorDataPoint.c.vector]
|
||||
)
|
||||
# Use async session to connect to the database
|
||||
async with self.get_async_session() as session:
|
||||
query = select(
|
||||
*select_columns,
|
||||
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
|
||||
).order_by("similarity")
|
||||
|
||||
if limit > 0:
|
||||
query = query.limit(limit)
|
||||
if limit > 0:
|
||||
query = query.limit(limit)
|
||||
|
||||
# Find closest vectors to query_vector
|
||||
closest_items = await session.execute(query)
|
||||
# Find closest vectors to query_vector
|
||||
closest_items = await session.execute(query)
|
||||
|
||||
vector_list = []
|
||||
vector_list = []
|
||||
|
||||
# Extract distances and find min/max for normalization
|
||||
for vector in closest_items.all():
|
||||
vector_list.append(
|
||||
{
|
||||
"id": parse_id(str(vector.id)),
|
||||
"payload": vector.payload,
|
||||
"_distance": vector.similarity,
|
||||
}
|
||||
)
|
||||
# Extract distances and find min/max for normalization
|
||||
for vector in closest_items.all():
|
||||
vector_list.append(
|
||||
{
|
||||
"id": parse_id(str(vector.id)),
|
||||
"payload": vector.payload if include_payload else None,
|
||||
"_distance": vector.similarity,
|
||||
}
|
||||
)
|
||||
|
||||
if len(vector_list) == 0:
|
||||
return []
|
||||
if len(vector_list) == 0:
|
||||
return []
|
||||
|
||||
# Normalize vector distance and add this as score information to vector_list
|
||||
normalized_values = normalize_distances(vector_list)
|
||||
for i in range(0, len(normalized_values)):
|
||||
vector_list[i]["score"] = normalized_values[i]
|
||||
# Normalize vector distance and add this as score information to vector_list
|
||||
normalized_values = normalize_distances(vector_list)
|
||||
for i in range(0, len(normalized_values)):
|
||||
vector_list[i]["score"] = normalized_values[i]
|
||||
|
||||
# Create and return ScoredResult objects
|
||||
return [
|
||||
ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
|
||||
for row in vector_list
|
||||
]
|
||||
else:
|
||||
# Use async session to connect to the database
|
||||
async with self.get_async_session() as session:
|
||||
query = select(
|
||||
PGVectorDataPoint.c.id,
|
||||
PGVectorDataPoint.c.vector,
|
||||
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
|
||||
).order_by("similarity")
|
||||
|
||||
if limit > 0:
|
||||
query = query.limit(limit)
|
||||
|
||||
# Find closest vectors to query_vector
|
||||
closest_items = await session.execute(query)
|
||||
|
||||
vector_list = []
|
||||
|
||||
# Extract distances and find min/max for normalization
|
||||
for vector in closest_items.all():
|
||||
vector_list.append(
|
||||
{
|
||||
"id": parse_id(str(vector.id)),
|
||||
"_distance": vector.similarity,
|
||||
}
|
||||
)
|
||||
|
||||
if len(vector_list) == 0:
|
||||
return []
|
||||
|
||||
# Normalize vector distance and add this as score information to vector_list
|
||||
normalized_values = normalize_distances(vector_list)
|
||||
for i in range(0, len(normalized_values)):
|
||||
vector_list[i]["score"] = normalized_values[i]
|
||||
|
||||
# Create and return ScoredResult objects
|
||||
return [ScoredResult(id=row.get("id"), score=row.get("score")) for row in vector_list]
|
||||
# Create and return ScoredResult objects
|
||||
return [
|
||||
ScoredResult(
|
||||
id=row.get("id"),
|
||||
payload=row.get("payload") if include_payload else None,
|
||||
score=row.get("score"),
|
||||
)
|
||||
for row in vector_list
|
||||
]
|
||||
|
||||
async def batch_search(
|
||||
self,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue