Add the ability to create a new vector memory and store text data points using OpenAI embeddings.
61 lines
2.2 KiB
Python
61 lines
2.2 KiB
Python
import inspect
from typing import Any, List, Optional

from pydantic import BaseModel, Field
from qdrant_client import AsyncQdrantClient, models

from ..vector_db_interface import VectorDBInterface
|
|
|
|
class CollectionConfig(BaseModel, extra="forbid"):
    """Settings used when creating a Qdrant collection.

    Only ``vector_config`` is required; the other sections default to
    ``None``. Unknown fields are rejected (``extra="forbid"``).
    """

    # Required: parameters of the vectors stored in the collection.
    vector_config: models.VectorParams = Field(
        ...,
        description="Vector configuration",
    )

    # NOTE(review): the qdrant-client docs list ``HnswConfigDiff`` /
    # ``OptimizersConfigDiff`` as the types accepted by ``create_collection``
    # -- confirm the model types below are accepted where this is consumed.
    hnsw_config: Optional[models.HnswConfig] = Field(
        None,
        description="HNSW vector index configuration",
    )
    optimizers_config: Optional[models.OptimizersConfig] = Field(
        None,
        description="Optimizers configuration",
    )
    quantization_config: Optional[models.QuantizationConfig] = Field(
        None,
        description="Quantization configuration",
    )
|
|
|
|
class QDrantAdapter(VectorDBInterface):
    """Vector-database adapter that talks to Qdrant through ``AsyncQdrantClient``.

    The connection target is chosen in this order:

    1. ``qdrant_path`` -- client constructed with ``path=...``;
    2. ``qdrant_url`` (with ``qdrant_api_key``) -- client constructed with
       ``url=...``;
    3. neither -- client constructed with ``location=":memory:"``.
    """

    # Connection settings; each stays ``None`` unless ``__init__`` sets it.
    qdrant_url: Optional[str] = None
    qdrant_path: Optional[str] = None
    qdrant_api_key: Optional[str] = None

    # Lazily created in-memory client. It is kept on the instance because a
    # fresh ``:memory:`` client starts empty, so building one per call would
    # drop every collection created by a previous call.
    _memory_client: Optional[AsyncQdrantClient] = None

    def __init__(self, qdrant_path, qdrant_url, qdrant_api_key):
        """Store the connection settings.

        Args:
            qdrant_path: Path handed to the client as ``path``. When it is not
                ``None`` it takes precedence and ``qdrant_url`` is ignored.
            qdrant_url: URL handed to the client as ``url``; only stored when
                ``qdrant_path`` is ``None``.
            qdrant_api_key: API key passed along with ``qdrant_url``.
        """
        if qdrant_path is not None:
            self.qdrant_path = qdrant_path
        else:
            self.qdrant_url = qdrant_url

        self.qdrant_api_key = qdrant_api_key

    def get_qdrant_client(self) -> AsyncQdrantClient:
        """Return a client for the configured target.

        Path and URL targets get a new client on every call. The in-memory
        fallback is created once per adapter and reused, so that data written
        by one call is visible to the next.
        """
        if self.qdrant_path is not None:
            return AsyncQdrantClient(path=self.qdrant_path)

        if self.qdrant_url is not None:
            return AsyncQdrantClient(
                url=self.qdrant_url,
                api_key=self.qdrant_api_key,
            )

        if self._memory_client is None:
            self._memory_client = AsyncQdrantClient(location=":memory:")
        return self._memory_client

    async def create_collection(
        self,
        collection_name: str,
        collection_config: CollectionConfig,
    ):
        """Create ``collection_name`` using the sections of ``collection_config``.

        Returns:
            Whatever ``AsyncQdrantClient.create_collection`` returns.
        """
        client = self.get_qdrant_client()

        return await client.create_collection(
            collection_name=collection_name,
            vectors_config=collection_config.vector_config,
            hnsw_config=collection_config.hnsw_config,
            optimizers_config=collection_config.optimizers_config,
            quantization_config=collection_config.quantization_config,
        )

    async def create_data_points(self, collection_name: str, data_points: List[Any]):
        """Upload ``data_points`` into ``collection_name``.

        Args:
            collection_name: Target collection.
            data_points: Points forwarded unchanged to ``upload_points``.

        Returns:
            The (awaited, if applicable) result of ``upload_points``.
        """
        client = self.get_qdrant_client()

        outcome = client.upload_points(
            collection_name=collection_name,
            points=data_points,
        )

        # ``upload_points`` is documented as a plain (non-coroutine) method on
        # AsyncQdrantClient; awaiting its ``None`` result would raise
        # TypeError. Await only when an awaitable actually comes back.
        if inspect.isawaitable(outcome):
            outcome = await outcome

        return outcome
|