diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index 2b21f4d05..f77d1ee20 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -1,8 +1,8 @@ from typing import List, Optional, get_type_hints, Any, Dict from sqlalchemy import text, select from sqlalchemy import JSON, Column, Table -from sqlalchemy.dialects.postgresql import ARRAY from ..models.ScoredResult import ScoredResult +import asyncio from ..vector_db_interface import VectorDBInterface, DataPoint from sqlalchemy.orm import Mapped, mapped_column @@ -153,11 +153,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): limit: int = 5, with_vector: bool = False, ) -> List[ScoredResult]: - # Validate inputs if query_text is None and query_vector is None: raise ValueError("One of query_text or query_vector must be provided!") - # Get the vector for query_text if provided if query_text and not query_vector: query_vector = (await self.embedding_engine.embed_text([query_text]))[0] @@ -195,7 +193,16 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): limit: int = None, with_vectors: bool = False, ): - pass + query_vectors = await self.embedding_engine.embed_text(query_texts) + + return asyncio.gather( + *[self.search( + collection_name = collection_name, + query_vector = query_vector, + limit = limit, + with_vector = with_vectors, + ) for query_vector in query_vectors] + ) async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): pass