feat: Add batch search to PGVectorAdapter
Added batch search to PGVectorAdapter Feature #COG-170
This commit is contained in:
parent
aa26eabdbb
commit
02cd2408d6
1 changed files with 11 additions and 4 deletions
|
|
@ -1,8 +1,8 @@
|
||||||
from typing import List, Optional, get_type_hints, Any, Dict
|
from typing import List, Optional, get_type_hints, Any, Dict
|
||||||
from sqlalchemy import text, select
|
from sqlalchemy import text, select
|
||||||
from sqlalchemy import JSON, Column, Table
|
from sqlalchemy import JSON, Column, Table
|
||||||
from sqlalchemy.dialects.postgresql import ARRAY
|
|
||||||
from ..models.ScoredResult import ScoredResult
|
from ..models.ScoredResult import ScoredResult
|
||||||
|
import asyncio
|
||||||
|
|
||||||
from ..vector_db_interface import VectorDBInterface, DataPoint
|
from ..vector_db_interface import VectorDBInterface, DataPoint
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
@ -153,11 +153,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
limit: int = 5,
|
limit: int = 5,
|
||||||
with_vector: bool = False,
|
with_vector: bool = False,
|
||||||
) -> List[ScoredResult]:
|
) -> List[ScoredResult]:
|
||||||
# Validate inputs
|
|
||||||
if query_text is None and query_vector is None:
|
if query_text is None and query_vector is None:
|
||||||
raise ValueError("One of query_text or query_vector must be provided!")
|
raise ValueError("One of query_text or query_vector must be provided!")
|
||||||
|
|
||||||
# Get the vector for query_text if provided
|
|
||||||
if query_text and not query_vector:
|
if query_text and not query_vector:
|
||||||
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
|
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
|
||||||
|
|
||||||
|
|
@ -195,7 +193,16 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
limit: int = None,
|
limit: int = None,
|
||||||
with_vectors: bool = False,
|
with_vectors: bool = False,
|
||||||
):
|
):
|
||||||
pass
|
query_vectors = await self.embedding_engine.embed_text(query_texts)
|
||||||
|
|
||||||
|
return asyncio.gather(
|
||||||
|
*[self.search(
|
||||||
|
collection_name = collection_name,
|
||||||
|
query_vector = query_vector,
|
||||||
|
limit = limit,
|
||||||
|
with_vector = with_vectors,
|
||||||
|
) for query_vector in query_vectors]
|
||||||
|
)
|
||||||
|
|
||||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue