cognee/cognitive_architecture/infrastructure/databases/vector/qdrant/adapter.py
Boris Arzentar 769d6b5080 feat: add create-memory and remember API endpoints
Add possibility to create a new Vector memory and store text data points using openai embeddings.
2024-02-25 23:56:50 +01:00

61 lines
2.2 KiB
Python

from typing import List, Optional
from pydantic import BaseModel, Field
from qdrant_client import AsyncQdrantClient, models
from ..vector_db_interface import VectorDBInterface
class CollectionConfig(BaseModel, extra = "forbid"):
vector_config: models.VectorParams = Field(..., description="Vector configuration")
hnsw_config: Optional[models.HnswConfig] = Field(default = None, description="HNSW vector index configuration")
optimizers_config: Optional[models.OptimizersConfig] = Field(default = None, description="Optimizers configuration")
quantization_config: Optional[models.QuantizationConfig] = Field(default = None, description="Quantization configuration")
class QDrantAdapter(VectorDBInterface):
qdrant_url: str = None
qdrant_path: str = None
qdrant_api_key: str = None
def __init__(self, qdrant_path, qdrant_url, qdrant_api_key):
if qdrant_path is not None:
self.qdrant_path = qdrant_path
else:
self.qdrant_url = qdrant_url
self.qdrant_api_key = qdrant_api_key
def get_qdrant_client(self) -> AsyncQdrantClient:
if self.qdrant_path is not None:
return AsyncQdrantClient(
path = self.qdrant_path,
)
elif self.qdrant_url is not None:
return AsyncQdrantClient(
url = self.qdrant_url,
api_key = self.qdrant_api_key,
)
return AsyncQdrantClient(
location = ":memory:"
)
async def create_collection(
self,
collection_name: str,
collection_config: CollectionConfig,
):
client = self.get_qdrant_client()
return await client.create_collection(
collection_name = collection_name,
vectors_config = collection_config.vector_config,
hnsw_config = collection_config.hnsw_config,
optimizers_config = collection_config.optimizers_config,
quantization_config = collection_config.quantization_config
)
async def create_data_points(self, collection_name: str, data_points: List[any]):
client = self.get_qdrant_client()
return await client.upload_points(
collection_name = collection_name,
points = data_points
)