refactor: make include_payload handling in vector database adapters a bit more readable
This commit is contained in:
parent
d5a888e6c0
commit
51a9ff0613
3 changed files with 64 additions and 101 deletions
|
|
@ -355,7 +355,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
limit: Optional[int] = 15,
|
limit: Optional[int] = 15,
|
||||||
with_vector: bool = False,
|
with_vector: bool = False,
|
||||||
normalized: bool = True,
|
normalized: bool = True,
|
||||||
include_payload: bool = False, # TODO: Add support for this parameter
|
include_payload: bool = False, # TODO: Add support for this parameter when set to False
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Search for items in a collection using either a text or a vector query.
|
Search for items in a collection using either a text or a vector query.
|
||||||
|
|
|
||||||
|
|
@ -248,40 +248,30 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
if limit <= 0:
|
if limit <= 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
if include_payload:
|
# Note: Exclude payload if not needed to optimize performance
|
||||||
result_values = await collection.vector_search(query_vector).limit(limit).to_list()
|
select_columns = (
|
||||||
if not result_values:
|
["id", "vector", "payload", "_distance"]
|
||||||
return []
|
if include_payload
|
||||||
normalized_values = normalize_distances(result_values)
|
else ["id", "vector", "_distance"]
|
||||||
|
)
|
||||||
|
result_values = (
|
||||||
|
await collection.vector_search(query_vector)
|
||||||
|
.select(select_columns)
|
||||||
|
.limit(limit)
|
||||||
|
.to_list()
|
||||||
|
)
|
||||||
|
if not result_values:
|
||||||
|
return []
|
||||||
|
normalized_values = normalize_distances(result_values)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ScoredResult(
|
ScoredResult(
|
||||||
id=parse_id(result["id"]),
|
id=parse_id(result["id"]),
|
||||||
payload=result["payload"],
|
payload=result["payload"] if include_payload else None,
|
||||||
score=normalized_values[value_index],
|
score=normalized_values[value_index],
|
||||||
)
|
|
||||||
for value_index, result in enumerate(result_values)
|
|
||||||
]
|
|
||||||
|
|
||||||
else:
|
|
||||||
result_values = await (
|
|
||||||
collection.vector_search(query_vector)
|
|
||||||
.limit(limit)
|
|
||||||
.select(["id", "vector", "_distance"])
|
|
||||||
.to_list()
|
|
||||||
)
|
)
|
||||||
if not result_values:
|
for value_index, result in enumerate(result_values)
|
||||||
return []
|
]
|
||||||
|
|
||||||
normalized_values = normalize_distances(result_values)
|
|
||||||
|
|
||||||
return [
|
|
||||||
ScoredResult(
|
|
||||||
id=parse_id(result["id"]),
|
|
||||||
score=normalized_values[value_index],
|
|
||||||
)
|
|
||||||
for value_index, result in enumerate(result_values)
|
|
||||||
]
|
|
||||||
|
|
||||||
async def batch_search(
|
async def batch_search(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -325,81 +325,54 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
# NOTE: This needs to be initialized in case search doesn't return a value
|
# NOTE: This needs to be initialized in case search doesn't return a value
|
||||||
closest_items = []
|
closest_items = []
|
||||||
|
|
||||||
if include_payload:
|
# Note: Exclude payload from returned columns if not needed to optimize performance
|
||||||
# Use async session to connect to the database
|
select_columns = (
|
||||||
async with self.get_async_session() as session:
|
[PGVectorDataPoint]
|
||||||
query = select(
|
if include_payload
|
||||||
PGVectorDataPoint,
|
else [PGVectorDataPoint.c.id, PGVectorDataPoint.c.vector]
|
||||||
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
|
)
|
||||||
).order_by("similarity")
|
# Use async session to connect to the database
|
||||||
|
async with self.get_async_session() as session:
|
||||||
|
query = select(
|
||||||
|
*select_columns,
|
||||||
|
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
|
||||||
|
).order_by("similarity")
|
||||||
|
|
||||||
if limit > 0:
|
if limit > 0:
|
||||||
query = query.limit(limit)
|
query = query.limit(limit)
|
||||||
|
|
||||||
# Find closest vectors to query_vector
|
# Find closest vectors to query_vector
|
||||||
closest_items = await session.execute(query)
|
closest_items = await session.execute(query)
|
||||||
|
|
||||||
vector_list = []
|
vector_list = []
|
||||||
|
|
||||||
# Extract distances and find min/max for normalization
|
# Extract distances and find min/max for normalization
|
||||||
for vector in closest_items.all():
|
for vector in closest_items.all():
|
||||||
vector_list.append(
|
vector_list.append(
|
||||||
{
|
{
|
||||||
"id": parse_id(str(vector.id)),
|
"id": parse_id(str(vector.id)),
|
||||||
"payload": vector.payload,
|
"payload": vector.payload if include_payload else None,
|
||||||
"_distance": vector.similarity,
|
"_distance": vector.similarity,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(vector_list) == 0:
|
if len(vector_list) == 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Normalize vector distance and add this as score information to vector_list
|
# Normalize vector distance and add this as score information to vector_list
|
||||||
normalized_values = normalize_distances(vector_list)
|
normalized_values = normalize_distances(vector_list)
|
||||||
for i in range(0, len(normalized_values)):
|
for i in range(0, len(normalized_values)):
|
||||||
vector_list[i]["score"] = normalized_values[i]
|
vector_list[i]["score"] = normalized_values[i]
|
||||||
|
|
||||||
# Create and return ScoredResult objects
|
# Create and return ScoredResult objects
|
||||||
return [
|
return [
|
||||||
ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
|
ScoredResult(
|
||||||
for row in vector_list
|
id=row.get("id"),
|
||||||
]
|
payload=row.get("payload") if include_payload else None,
|
||||||
else:
|
score=row.get("score"),
|
||||||
# Use async session to connect to the database
|
)
|
||||||
async with self.get_async_session() as session:
|
for row in vector_list
|
||||||
query = select(
|
]
|
||||||
PGVectorDataPoint.c.id,
|
|
||||||
PGVectorDataPoint.c.vector,
|
|
||||||
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
|
|
||||||
).order_by("similarity")
|
|
||||||
|
|
||||||
if limit > 0:
|
|
||||||
query = query.limit(limit)
|
|
||||||
|
|
||||||
# Find closest vectors to query_vector
|
|
||||||
closest_items = await session.execute(query)
|
|
||||||
|
|
||||||
vector_list = []
|
|
||||||
|
|
||||||
# Extract distances and find min/max for normalization
|
|
||||||
for vector in closest_items.all():
|
|
||||||
vector_list.append(
|
|
||||||
{
|
|
||||||
"id": parse_id(str(vector.id)),
|
|
||||||
"_distance": vector.similarity,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(vector_list) == 0:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Normalize vector distance and add this as score information to vector_list
|
|
||||||
normalized_values = normalize_distances(vector_list)
|
|
||||||
for i in range(0, len(normalized_values)):
|
|
||||||
vector_list[i]["score"] = normalized_values[i]
|
|
||||||
|
|
||||||
# Create and return ScoredResult objects
|
|
||||||
return [ScoredResult(id=row.get("id"), score=row.get("score")) for row in vector_list]
|
|
||||||
|
|
||||||
async def batch_search(
|
async def batch_search(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue