feat: make payload inclusion optional for vector search
parent 4765f9e4a0
commit d5a888e6c0

6 changed files with 112 additions and 47 deletions
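At the call-site level, the change comes down to a new include_payload keyword on the adapters' search methods. The snippet below is a minimal usage sketch and is not part of the commit: the get_vector_engine import path and the demo coroutine are assumptions, while the collection name "DocumentChunk_text" and the search signature are taken from the hunks that follow.

# Usage sketch (assumed setup; only the search signature comes from this commit).
from cognee.infrastructure.databases.vector import get_vector_engine  # assumed import path


async def demo(query: str):
    vector_engine = get_vector_engine()

    # Payload-free search: adapters fetch only ids, vectors and distances,
    # so ScoredResult.payload stays None. Intended to be the faster path.
    slim_results = await vector_engine.search(
        "DocumentChunk_text", query, limit=10, include_payload=False
    )

    # Payload-bearing search: chunk metadata is returned alongside the score,
    # as the completion retriever below now requests explicitly.
    full_results = await vector_engine.search(
        "DocumentChunk_text", query, limit=10, include_payload=True
    )

    for result in full_results:
        print(result.id, result.score, result.payload)

    return slim_results, full_results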
@@ -355,6 +355,7 @@ class ChromaDBAdapter(VectorDBInterface):
         limit: Optional[int] = 15,
         with_vector: bool = False,
         normalized: bool = True,
+        include_payload: bool = False,  # TODO: Add support for this parameter
     ):
         """
         Search for items in a collection using either a text or a vector query.
@@ -231,6 +231,7 @@ class LanceDBAdapter(VectorDBInterface):
         limit: Optional[int] = 15,
         with_vector: bool = False,
         normalized: bool = True,
+        include_payload: bool = False,
     ):
         if query_text is None and query_vector is None:
             raise MissingQueryParameterError()
@@ -247,21 +248,40 @@ class LanceDBAdapter(VectorDBInterface):
         if limit <= 0:
             return []

-        result_values = await collection.vector_search(query_vector).limit(limit).to_list()
-
-        if not result_values:
-            return []
-
-        normalized_values = normalize_distances(result_values)
-
-        return [
-            ScoredResult(
-                id=parse_id(result["id"]),
-                payload=result["payload"],
-                score=normalized_values[value_index],
-            )
-            for value_index, result in enumerate(result_values)
-        ]
+        if include_payload:
+            result_values = await collection.vector_search(query_vector).limit(limit).to_list()
+
+            if not result_values:
+                return []
+
+            normalized_values = normalize_distances(result_values)
+
+            return [
+                ScoredResult(
+                    id=parse_id(result["id"]),
+                    payload=result["payload"],
+                    score=normalized_values[value_index],
+                )
+                for value_index, result in enumerate(result_values)
+            ]
+        else:
+            result_values = await (
+                collection.vector_search(query_vector)
+                .limit(limit)
+                .select(["id", "vector", "_distance"])
+                .to_list()
+            )
+
+            if not result_values:
+                return []
+
+            normalized_values = normalize_distances(result_values)
+
+            return [
+                ScoredResult(
+                    id=parse_id(result["id"]),
+                    score=normalized_values[value_index],
+                )
+                for value_index, result in enumerate(result_values)
+            ]

     async def batch_search(
         self,
@@ -1,4 +1,4 @@
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 from uuid import UUID
 from pydantic import BaseModel

@@ -12,10 +12,10 @@ class ScoredResult(BaseModel):
     - id (UUID): Unique identifier for the scored result.
     - score (float): The score associated with the result, where a lower score indicates a
       better outcome.
-    - payload (Dict[str, Any]): Additional information related to the score, stored as
+    - payload (Optional[Dict[str, Any]]): Additional information related to the score, stored as
       key-value pairs in a dictionary.
     """

     id: UUID
     score: float  # Lower score is better
-    payload: Dict[str, Any]
+    payload: Optional[Dict[str, Any]] = None
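A small illustration of what the model change allows (a sketch, not part of the commit; the import path is an assumption and may differ in the repository): payload now defaults to None, which is what the payload-free adapter branches below rely on when they build results.

from uuid import uuid4

# Assumed import path for the model shown in the hunk above.
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult

# Before this commit, payload was required; now it can be omitted.
slim = ScoredResult(id=uuid4(), score=0.12)
full = ScoredResult(id=uuid4(), score=0.12, payload={"text": "chunk contents"})

assert slim.payload is None
assert full.payload["text"] == "chunk contents"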
@@ -301,6 +301,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         query_vector: Optional[List[float]] = None,
         limit: Optional[int] = 15,
         with_vector: bool = False,
+        include_payload: bool = False,
     ) -> List[ScoredResult]:
         if query_text is None and query_vector is None:
             raise MissingQueryParameterError()
@@ -324,44 +325,81 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         # NOTE: This needs to be initialized in case search doesn't return a value
         closest_items = []

-        # Use async session to connect to the database
-        async with self.get_async_session() as session:
-            query = select(
-                PGVectorDataPoint,
-                PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
-            ).order_by("similarity")
-
-            if limit > 0:
-                query = query.limit(limit)
-
-            # Find closest vectors to query_vector
-            closest_items = await session.execute(query)
-
-            vector_list = []
-
-            # Extract distances and find min/max for normalization
-            for vector in closest_items.all():
-                vector_list.append(
-                    {
-                        "id": parse_id(str(vector.id)),
-                        "payload": vector.payload,
-                        "_distance": vector.similarity,
-                    }
-                )
-
-            if len(vector_list) == 0:
-                return []
-
-            # Normalize vector distance and add this as score information to vector_list
-            normalized_values = normalize_distances(vector_list)
-            for i in range(0, len(normalized_values)):
-                vector_list[i]["score"] = normalized_values[i]
-
-            # Create and return ScoredResult objects
-            return [
-                ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
-                for row in vector_list
-            ]
+        if include_payload:
+            # Use async session to connect to the database
+            async with self.get_async_session() as session:
+                query = select(
+                    PGVectorDataPoint,
+                    PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
+                ).order_by("similarity")
+
+                if limit > 0:
+                    query = query.limit(limit)
+
+                # Find closest vectors to query_vector
+                closest_items = await session.execute(query)
+
+                vector_list = []
+
+                # Extract distances and find min/max for normalization
+                for vector in closest_items.all():
+                    vector_list.append(
+                        {
+                            "id": parse_id(str(vector.id)),
+                            "payload": vector.payload,
+                            "_distance": vector.similarity,
+                        }
+                    )
+
+                if len(vector_list) == 0:
+                    return []
+
+                # Normalize vector distance and add this as score information to vector_list
+                normalized_values = normalize_distances(vector_list)
+                for i in range(0, len(normalized_values)):
+                    vector_list[i]["score"] = normalized_values[i]
+
+                # Create and return ScoredResult objects
+                return [
+                    ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
+                    for row in vector_list
+                ]
+        else:
+            # Use async session to connect to the database
+            async with self.get_async_session() as session:
+                query = select(
+                    PGVectorDataPoint.c.id,
+                    PGVectorDataPoint.c.vector,
+                    PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
+                ).order_by("similarity")
+
+                if limit > 0:
+                    query = query.limit(limit)
+
+                # Find closest vectors to query_vector
+                closest_items = await session.execute(query)
+
+                vector_list = []
+
+                # Extract distances and find min/max for normalization
+                for vector in closest_items.all():
+                    vector_list.append(
+                        {
+                            "id": parse_id(str(vector.id)),
+                            "_distance": vector.similarity,
+                        }
+                    )
+
+                if len(vector_list) == 0:
+                    return []
+
+                # Normalize vector distance and add this as score information to vector_list
+                normalized_values = normalize_distances(vector_list)
+                for i in range(0, len(normalized_values)):
+                    vector_list[i]["score"] = normalized_values[i]
+
+                # Create and return ScoredResult objects
+                return [ScoredResult(id=row.get("id"), score=row.get("score")) for row in vector_list]

     async def batch_search(
         self,
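Both the LanceDB and PGVector branches pass the raw cosine distances through normalize_distances before building ScoredResult objects. That helper is not touched by this commit and its implementation is not shown here; the sketch below only illustrates one plausible behaviour (min-max scaling of the "_distance" field, so that lower stays better) and should not be read as the repository's actual code.

from typing import Any, Dict, List


def normalize_distances_sketch(result_values: List[Dict[str, Any]]) -> List[float]:
    """Illustrative stand-in for normalize_distances: min-max scale "_distance" into [0, 1]."""
    distances = [row["_distance"] for row in result_values]
    min_distance, max_distance = min(distances), max(distances)

    if max_distance == min_distance:
        # Every hit is equally distant; treat them all as equally good.
        return [0.0 for _ in distances]

    return [(d - min_distance) / (max_distance - min_distance) for d in distances]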
@@ -87,6 +87,7 @@ class VectorDBInterface(Protocol):
         query_vector: Optional[List[float]],
         limit: Optional[int],
         with_vector: bool = False,
+        include_payload: bool = False,
     ):
         """
         Perform a search in the specified collection using either a text query or a vector
@@ -103,6 +104,9 @@ class VectorDBInterface(Protocol):
         - limit (Optional[int]): The maximum number of results to return from the search.
         - with_vector (bool): Whether to return the vector representations with search
           results. (default False)
+        - include_payload (bool): Whether to include payload data with the search results. Search is faster when set to False.
+          The payload contains metadata about the data point; it is useful for searches based only on embedding distances,
+          such as the RAG_COMPLETION search type, but it is not needed when the search also includes graph data.
         """
         raise NotImplementedError
@@ -62,7 +62,9 @@ class CompletionRetriever(BaseRetriever):
         vector_engine = get_vector_engine()

         try:
-            found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
+            found_chunks = await vector_engine.search(
+                "DocumentChunk_text", query, limit=self.top_k, include_payload=True
+            )

             if len(found_chunks) == 0:
                 return ""