feat: Add batch search to PGVectorAdapter

Added batch search to PGVectorAdapter

Feature #COG-170
This commit is contained in:
Igor Ilic 2024-10-17 18:11:11 +02:00
parent aa26eabdbb
commit 02cd2408d6

View file

@ -1,8 +1,8 @@
from typing import List, Optional, get_type_hints, Any, Dict from typing import List, Optional, get_type_hints, Any, Dict
from sqlalchemy import text, select from sqlalchemy import text, select
from sqlalchemy import JSON, Column, Table from sqlalchemy import JSON, Column, Table
from sqlalchemy.dialects.postgresql import ARRAY
from ..models.ScoredResult import ScoredResult from ..models.ScoredResult import ScoredResult
import asyncio
from ..vector_db_interface import VectorDBInterface, DataPoint from ..vector_db_interface import VectorDBInterface, DataPoint
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
@ -153,11 +153,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
limit: int = 5, limit: int = 5,
with_vector: bool = False, with_vector: bool = False,
) -> List[ScoredResult]: ) -> List[ScoredResult]:
# Validate inputs
if query_text is None and query_vector is None: if query_text is None and query_vector is None:
raise ValueError("One of query_text or query_vector must be provided!") raise ValueError("One of query_text or query_vector must be provided!")
# Get the vector for query_text if provided
if query_text and not query_vector: if query_text and not query_vector:
query_vector = (await self.embedding_engine.embed_text([query_text]))[0] query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
@ -195,7 +193,16 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
limit: int = None, limit: int = None,
with_vectors: bool = False, with_vectors: bool = False,
): ):
pass query_vectors = await self.embedding_engine.embed_text(query_texts)
return asyncio.gather(
*[self.search(
collection_name = collection_name,
query_vector = query_vector,
limit = limit,
with_vector = with_vectors,
) for query_vector in query_vectors]
)
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
pass pass