feat: make payload inclusion optional for vector search

parent 4765f9e4a0
commit d5a888e6c0

6 changed files with 112 additions and 47 deletions

@@ -355,6 +355,7 @@ class ChromaDBAdapter(VectorDBInterface):
         limit: Optional[int] = 15,
         with_vector: bool = False,
         normalized: bool = True,
+        include_payload: bool = False,  # TODO: Add support for this parameter
     ):
         """
         Search for items in a collection using either a text or a vector query.

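The ChromaDB adapter only gains the parameter for interface parity; per the TODO, it still returns payloads unconditionally. A minimal sketch of a guard it could grow until real support lands (hypothetical helper, not part of this commit):

    import warnings

    def _warn_unsupported_include_payload(include_payload: bool) -> None:
        # Hypothetical guard: ChromaDBAdapter accepts include_payload but does
        # not yet act on it, so warn callers expecting the faster
        # payload-free path.
        if not include_payload:
            warnings.warn(
                "include_payload=False is not yet supported by ChromaDBAdapter; "
                "payloads are still returned.",
                stacklevel=2,
            )
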
@@ -231,6 +231,7 @@ class LanceDBAdapter(VectorDBInterface):
         limit: Optional[int] = 15,
         with_vector: bool = False,
         normalized: bool = True,
+        include_payload: bool = False,
     ):
         if query_text is None and query_vector is None:
             raise MissingQueryParameterError()

@@ -247,21 +248,40 @@ class LanceDBAdapter(VectorDBInterface):
         if limit <= 0:
             return []

-        result_values = await collection.vector_search(query_vector).limit(limit).to_list()
-
-        if not result_values:
-            return []
-
-        normalized_values = normalize_distances(result_values)
-
-        return [
-            ScoredResult(
-                id=parse_id(result["id"]),
-                payload=result["payload"],
-                score=normalized_values[value_index],
-            )
-            for value_index, result in enumerate(result_values)
-        ]
+        if include_payload:
+            result_values = await collection.vector_search(query_vector).limit(limit).to_list()
+
+            if not result_values:
+                return []
+
+            normalized_values = normalize_distances(result_values)
+
+            return [
+                ScoredResult(
+                    id=parse_id(result["id"]),
+                    payload=result["payload"],
+                    score=normalized_values[value_index],
+                )
+                for value_index, result in enumerate(result_values)
+            ]
+        else:
+            result_values = await (
+                collection.vector_search(query_vector)
+                .limit(limit)
+                .select(["id", "vector", "_distance"])
+                .to_list()
+            )
+
+            if not result_values:
+                return []
+
+            normalized_values = normalize_distances(result_values)
+
+            return [
+                ScoredResult(
+                    id=parse_id(result["id"]),
+                    score=normalized_values[value_index],
+                )
+                for value_index, result in enumerate(result_values)
+            ]

     async def batch_search(
         self,

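The branch above keeps the payload path unchanged and adds a leaner path that projects only the "id", "vector", and "_distance" columns before materializing results. A caller-side sketch of the two result shapes (the adapter instance and vector values are placeholders, and LanceDB's async API is assumed to match the calls in the diff):

    async def compare_search_paths(adapter):
        # Sketch: `adapter` is assumed to be a connected LanceDBAdapter.
        with_payload = await adapter.search(
            "DocumentChunk_text", query_vector=[0.1, 0.2, 0.3], limit=5, include_payload=True
        )
        print(with_payload[0].payload)  # stored metadata dictionary

        lean = await adapter.search(
            "DocumentChunk_text", query_vector=[0.1, 0.2, 0.3], limit=5, include_payload=False
        )
        assert lean[0].payload is None  # payload column was never fetched
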
@@ -1,4 +1,4 @@
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 from uuid import UUID

 from pydantic import BaseModel

@@ -12,10 +12,10 @@ class ScoredResult(BaseModel):
     - id (UUID): Unique identifier for the scored result.
     - score (float): The score associated with the result, where a lower score indicates a
       better outcome.
-    - payload (Dict[str, Any]): Additional information related to the score, stored as
+    - payload (Optional[Dict[str, Any]]): Additional information related to the score, stored as
       key-value pairs in a dictionary.
     """

     id: UUID
     score: float  # Lower score is better
-    payload: Dict[str, Any]
+    payload: Optional[Dict[str, Any]] = None

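Making payload optional with a None default is what lets the payload-free branches construct results at all. A quick standalone check of the new model behavior (the model is re-declared here for illustration):

    from typing import Any, Dict, Optional
    from uuid import UUID, uuid4

    from pydantic import BaseModel

    class ScoredResult(BaseModel):
        id: UUID
        score: float  # Lower score is better
        payload: Optional[Dict[str, Any]] = None

    # Previously, omitting payload raised a validation error; now it
    # defaults to None for results produced without payload data.
    lean = ScoredResult(id=uuid4(), score=0.12)
    assert lean.payload is None
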
@@ -301,6 +301,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         query_vector: Optional[List[float]] = None,
         limit: Optional[int] = 15,
         with_vector: bool = False,
+        include_payload: bool = False,
     ) -> List[ScoredResult]:
         if query_text is None and query_vector is None:
             raise MissingQueryParameterError()

@@ -324,44 +325,81 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         # NOTE: This needs to be initialized in case search doesn't return a value
         closest_items = []

-        # Use async session to connect to the database
-        async with self.get_async_session() as session:
-            query = select(
-                PGVectorDataPoint,
-                PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
-            ).order_by("similarity")
-
-            if limit > 0:
-                query = query.limit(limit)
-
-            # Find closest vectors to query_vector
-            closest_items = await session.execute(query)
-
-        vector_list = []
-
-        # Extract distances and find min/max for normalization
-        for vector in closest_items.all():
-            vector_list.append(
-                {
-                    "id": parse_id(str(vector.id)),
-                    "payload": vector.payload,
-                    "_distance": vector.similarity,
-                }
-            )
-
-        if len(vector_list) == 0:
-            return []
-
-        # Normalize vector distance and add this as score information to vector_list
-        normalized_values = normalize_distances(vector_list)
-        for i in range(0, len(normalized_values)):
-            vector_list[i]["score"] = normalized_values[i]
-
-        # Create and return ScoredResult objects
-        return [
-            ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
-            for row in vector_list
-        ]
+        if include_payload:
+            # Use async session to connect to the database
+            async with self.get_async_session() as session:
+                query = select(
+                    PGVectorDataPoint,
+                    PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
+                ).order_by("similarity")
+
+                if limit > 0:
+                    query = query.limit(limit)
+
+                # Find closest vectors to query_vector
+                closest_items = await session.execute(query)
+
+            vector_list = []
+
+            # Extract distances and find min/max for normalization
+            for vector in closest_items.all():
+                vector_list.append(
+                    {
+                        "id": parse_id(str(vector.id)),
+                        "payload": vector.payload,
+                        "_distance": vector.similarity,
+                    }
+                )
+
+            if len(vector_list) == 0:
+                return []
+
+            # Normalize vector distance and add this as score information to vector_list
+            normalized_values = normalize_distances(vector_list)
+            for i in range(0, len(normalized_values)):
+                vector_list[i]["score"] = normalized_values[i]
+
+            # Create and return ScoredResult objects
+            return [
+                ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
+                for row in vector_list
+            ]
+        else:
+            # Use async session to connect to the database
+            async with self.get_async_session() as session:
+                query = select(
+                    PGVectorDataPoint.c.id,
+                    PGVectorDataPoint.c.vector,
+                    PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
+                ).order_by("similarity")
+
+                if limit > 0:
+                    query = query.limit(limit)
+
+                # Find closest vectors to query_vector
+                closest_items = await session.execute(query)
+
+            vector_list = []
+
+            # Extract distances and find min/max for normalization
+            for vector in closest_items.all():
+                vector_list.append(
+                    {
+                        "id": parse_id(str(vector.id)),
+                        "_distance": vector.similarity,
+                    }
+                )
+
+            if len(vector_list) == 0:
+                return []
+
+            # Normalize vector distance and add this as score information to vector_list
+            normalized_values = normalize_distances(vector_list)
+            for i in range(0, len(normalized_values)):
+                vector_list[i]["score"] = normalized_values[i]
+
+            # Create and return ScoredResult objects
+            return [ScoredResult(id=row.get("id"), score=row.get("score")) for row in vector_list]

     async def batch_search(
         self,

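In the PGVector adapter the two branches differ only in their projection: the payload branch selects the full PGVectorDataPoint row, while the lean branch selects just the id, vector, and distance columns, keeping payload JSON out of the result set entirely. A condensed sketch of that difference (assumes SQLAlchemy Core and the pgvector cosine_distance comparator used above):

    from sqlalchemy import select

    def build_search_query(PGVectorDataPoint, query_vector, include_payload: bool):
        # Sketch: both queries order by cosine distance; only the selected
        # columns change between the two branches of the diff above.
        distance = PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity")
        if include_payload:
            return select(PGVectorDataPoint, distance).order_by("similarity")
        return select(
            PGVectorDataPoint.c.id, PGVectorDataPoint.c.vector, distance
        ).order_by("similarity")
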
@@ -87,6 +87,7 @@ class VectorDBInterface(Protocol):
         query_vector: Optional[List[float]],
         limit: Optional[int],
         with_vector: bool = False,
+        include_payload: bool = False,
     ):
         """
         Perform a search in the specified collection using either a text query or a vector
@@ -103,6 +104,9 @@ class VectorDBInterface(Protocol):
         - limit (Optional[int]): The maximum number of results to return from the search.
         - with_vector (bool): Whether to return the vector representations with search
           results. (default False)
+        - include_payload (bool): Whether to include payload data with search results. Search is faster when set to False.
+          The payload contains metadata about the data point, useful for searches based only on embedding distances,
+          like the RAG_COMPLETION search type, but not needed when the search also includes graph data.
         """
         raise NotImplementedError

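With the flag on the protocol, each caller can pick the cheaper path when the payload is redundant. A hedged sketch of a caller that only needs identifiers (the engine variable and collection name are placeholders):

    async def search_ids_only(vector_engine, collection: str, query: str, limit: int = 15):
        # Sketch: when the surrounding pipeline already carries the metadata
        # (e.g. it is joined in from graph data), skipping the payload makes
        # the vector search faster, as the interface docstring above notes.
        results = await vector_engine.search(collection, query, limit=limit, include_payload=False)
        return [result.id for result in results]
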
@@ -62,7 +62,9 @@ class CompletionRetriever(BaseRetriever):
         vector_engine = get_vector_engine()

         try:
-            found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
+            found_chunks = await vector_engine.search(
+                "DocumentChunk_text", query, limit=self.top_k, include_payload=True
+            )

             if len(found_chunks) == 0:
                 return ""

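CompletionRetriever is the call site that opts in: it builds completion context from the chunk payloads, so it cannot take the lean path. A sketch of what the opt-in enables downstream (the "text" payload key is an assumption, not shown in this diff):

    # Sketch: read chunk bodies out of the payloads the retriever asked for.
    # The payload schema ("text") is assumed here, not established by this diff.
    context = "\n".join(str(chunk.payload["text"]) for chunk in found_chunks)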