feat: Add batch search to PGVectorAdapter

Added batch search to PGVectorAdapter

Feature #COG-170
This commit is contained in:
Igor Ilic 2024-10-17 18:11:11 +02:00
parent aa26eabdbb
commit 02cd2408d6

View file

@ -1,8 +1,8 @@
from typing import List, Optional, get_type_hints, Any, Dict
from sqlalchemy import text, select
from sqlalchemy import JSON, Column, Table
from sqlalchemy.dialects.postgresql import ARRAY
from ..models.ScoredResult import ScoredResult
import asyncio
from ..vector_db_interface import VectorDBInterface, DataPoint
from sqlalchemy.orm import Mapped, mapped_column
@ -153,11 +153,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
limit: int = 5,
with_vector: bool = False,
) -> List[ScoredResult]:
# Validate inputs
if query_text is None and query_vector is None:
raise ValueError("One of query_text or query_vector must be provided!")
# Get the vector for query_text if provided
if query_text and not query_vector:
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
@ -195,7 +193,16 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
limit: int = None,
with_vectors: bool = False,
):
pass
query_vectors = await self.embedding_engine.embed_text(query_texts)
return asyncio.gather(
*[self.search(
collection_name = collection_name,
query_vector = query_vector,
limit = limit,
with_vector = with_vectors,
) for query_vector in query_vectors]
)
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
pass