cognee/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py
Boris 219afbce68
feat: add lancedb vector storage [COG-176] (#90)
* feat: integrate lancedb

* fix: use futures in weaviate adapter to enable async behaviour
2024-05-03 10:35:41 +02:00

196 lines
6.5 KiB
Python

from typing import List, Dict, Optional
from qdrant_client import AsyncQdrantClient, models
from ..vector_db_interface import VectorDBInterface
from ..models.DataPoint import DataPoint
from ..embeddings.EmbeddingEngine import EmbeddingEngine
# class CollectionConfig(BaseModel, extra = "forbid"):
# vector_config: Dict[str, models.VectorParams] = Field(..., description="Vectors configuration" )
# hnsw_config: Optional[models.HnswConfig] = Field(default = None, description="HNSW vector index configuration")
# optimizers_config: Optional[models.OptimizersConfig] = Field(default = None, description="Optimizers configuration")
# quantization_config: Optional[models.QuantizationConfig] = Field(default = None, description="Quantization configuration")
def create_hnsw_config(hnsw_config: Optional[Dict]):
    """Build a Qdrant HNSW index configuration from a plain dict.

    Args:
    - hnsw_config (Dict, optional): Keyword arguments for ``models.HnswConfig``.

    Returns:
    - A ``models.HnswConfig`` built from ``hnsw_config``, or None when
      ``hnsw_config`` is None.
    """
    if hnsw_config is None:
        return None

    # Forward the caller's settings; the previous no-argument call ignored them.
    return models.HnswConfig(**hnsw_config)
def create_optimizers_config(optimizers_config: Optional[Dict]):
    """Build a Qdrant optimizers configuration from a plain dict.

    Args:
    - optimizers_config (Dict, optional): Keyword arguments for
      ``models.OptimizersConfig``.

    Returns:
    - A ``models.OptimizersConfig`` built from ``optimizers_config``, or None
      when ``optimizers_config`` is None.
    """
    if optimizers_config is None:
        return None

    # Forward the caller's settings; the previous no-argument call ignored them.
    return models.OptimizersConfig(**optimizers_config)
def create_quantization_config(quantization_config: Optional[Dict]):
    """Build a Qdrant quantization configuration from a plain dict.

    ``models.QuantizationConfig`` is a type union of the scalar, product and
    binary quantization models and cannot be instantiated directly. The
    concrete model is therefore chosen from the dict's top-level key.

    Args:
    - quantization_config (Dict, optional): A dict containing one of the keys
      ``"scalar"``, ``"product"`` or ``"binary"`` (mapped to that variant's
      settings), e.g. ``{"scalar": {"type": "int8"}}``.

    Returns:
    - The matching quantization model, or None when ``quantization_config``
      is None.

    Raises:
    - ValueError: If none of the known variant keys is present.
    """
    if quantization_config is None:
        return None

    # Resolve model classes lazily by name: BinaryQuantization may be missing
    # from older qdrant-client releases, and only the matched one is needed.
    for variant_key, model_name in (
        ("scalar", "ScalarQuantization"),
        ("product", "ProductQuantization"),
        ("binary", "BinaryQuantization"),
    ):
        if variant_key in quantization_config:
            return getattr(models, model_name)(**quantization_config)

    raise ValueError(
        "quantization_config must contain one of the keys 'scalar', 'product' or 'binary'"
    )
class QDrantAdapter(VectorDBInterface):
    """Vector store adapter backed by Qdrant.

    Every operation opens its own ``AsyncQdrantClient`` via
    ``get_qdrant_client`` and closes it when the operation finishes, including
    when it raises. Vectors are stored under the named vector ``"text"`` and
    are produced by the injected embedding engine.
    """

    # Connection settings. get_qdrant_client prefers qdrant_path, then
    # qdrant_url, and otherwise falls back to an in-memory instance.
    qdrant_url: str = None
    qdrant_path: str = None
    qdrant_api_key: str = None

    # Result limit used when callers pass limit=None (matches the default of
    # qdrant-client's search); Qdrant has no "unlimited" value for limit.
    DEFAULT_SEARCH_LIMIT: int = 10

    def __init__(self, qdrant_url, qdrant_api_key, embedding_engine: EmbeddingEngine, qdrant_path = None):
        """
        Args:
        - qdrant_url: URL of a Qdrant server; ignored when qdrant_path is given.
        - qdrant_api_key: API key for that server; ignored when qdrant_path is given.
        - embedding_engine (EmbeddingEngine): Turns texts into vectors.
        - qdrant_path (optional): Local storage path; takes precedence over qdrant_url.
        """
        self.embedding_engine = embedding_engine

        if qdrant_path is not None:
            self.qdrant_path = qdrant_path
        else:
            self.qdrant_url = qdrant_url
            self.qdrant_api_key = qdrant_api_key

    def get_qdrant_client(self) -> AsyncQdrantClient:
        """Create a new client for the configured storage.

        Returns:
        - A client opened on ``qdrant_path`` when it is set, otherwise one
          connected to ``qdrant_url`` when that is set, otherwise an in-memory
          client. The caller must close it.
        """
        if self.qdrant_path is not None:
            return AsyncQdrantClient(
                path = self.qdrant_path, port = 6333
            )
        elif self.qdrant_url is not None:
            return AsyncQdrantClient(
                url = self.qdrant_url,
                api_key = self.qdrant_api_key,
                port = 6333
            )

        return AsyncQdrantClient(
            location = ":memory:"
        )

    async def embed_data(self, data: List[str]) -> List[List[float]]:
        """Embed each text in ``data`` and return one vector per text."""
        return await self.embedding_engine.embed_text(data)

    async def collection_exists(self, collection_name: str) -> bool:
        """Return whether a collection named ``collection_name`` exists."""
        client = self.get_qdrant_client()
        try:
            return await client.collection_exists(collection_name)
        finally:
            await client.close()

    async def create_collection(
        self,
        collection_name: str,
        payload_schema = None,
    ):
        """Create a collection with a single cosine-distance ``"text"`` vector.

        Args:
        - collection_name (str): Name of the collection to create.
        - payload_schema (optional): Accepted for interface compatibility;
          currently unused.

        Returns:
        - The result of ``client.create_collection``.
        """
        client = self.get_qdrant_client()
        try:
            return await client.create_collection(
                collection_name = collection_name,
                vectors_config = {
                    "text": models.VectorParams(
                        size = self.embedding_engine.get_vector_size(),
                        distance = "Cosine"
                    )
                }
            )
        finally:
            await client.close()

    async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
        """Embed ``data_points`` and upload them into ``collection_name``.

        Raises:
        - ValueError: If the embedding engine does not return exactly one
          vector per data point.
        """
        # Embed before opening a client so an embedding failure cannot leak it.
        data_vectors = await self.embed_data(
            [data_point.get_embeddable_data() for data_point in data_points]
        )

        if len(data_vectors) != len(data_points):
            raise ValueError(
                f"Embedding engine returned {len(data_vectors)} vectors for {len(data_points)} data points"
            )

        # Pair points with vectors positionally. A data_points.index() lookup
        # would be O(n^2) and would pick the first equal point for duplicates.
        points = [
            models.PointStruct(
                id = data_point.id,
                payload = data_point.payload.dict(),
                vector = {
                    "text": data_vector
                }
            )
            for data_point, data_vector in zip(data_points, data_vectors)
        ]

        client = self.get_qdrant_client()
        try:
            return await client.upload_points(
                collection_name = collection_name,
                points = points
            )
        finally:
            await client.close()

    async def retrieve(self, collection_name: str, data_point_id: str):
        """Return the point with ``data_point_id`` (with payload), or None if absent."""
        client = self.get_qdrant_client()
        try:
            results = await client.retrieve(collection_name, [data_point_id], with_payload = True)
        finally:
            await client.close()

        return results[0] if results else None

    async def search(
        self,
        collection_name: str,
        query_text: Optional[str] = None,
        query_vector: Optional[List[float]] = None,
        limit: int = None,
        with_vector: bool = False
    ):
        """Search ``collection_name`` by text or by a precomputed vector.

        Args:
        - collection_name (str): Name of the collection to search in.
        - query_text (str, optional): Text to embed and search with.
        - query_vector (List[float], optional): Vector to search with; takes
          precedence over ``query_text``.
        - limit (int, optional): Maximum number of results; DEFAULT_SEARCH_LIMIT
          is used when None.
        - with_vector (bool, optional): Whether to return stored vectors.

        Returns:
        - The scored points returned by ``client.search``.

        Raises:
        - ValueError: If neither ``query_text`` nor ``query_vector`` is given.
        """
        if query_text is None and query_vector is None:
            raise ValueError("One of query_text or query_vector must be provided!")

        # Embed before opening a client so an embedding failure cannot leak it.
        if query_vector is None:
            query_vector = (await self.embed_data([query_text]))[0]

        client = self.get_qdrant_client()
        try:
            return await client.search(
                collection_name = collection_name,
                query_vector = models.NamedVector(
                    name = "text",
                    vector = query_vector,
                ),
                limit = limit if limit is not None else self.DEFAULT_SEARCH_LIMIT,
                with_vectors = with_vector
            )
        finally:
            await client.close()

    async def batch_search(
        self,
        collection_name: str,
        query_texts: List[str],
        limit: int = None,
        with_vectors: bool = False,
        score_threshold: float = 0.9,
    ):
        """
        Perform batch search in a Qdrant collection with dynamic search requests.

        Args:
        - collection_name (str): Name of the collection to search in.
        - query_texts (List[str]): List of query texts to search for.
        - limit (int, optional): Maximum number of results per query;
          DEFAULT_SEARCH_LIMIT is used when None.
        - with_vectors (bool, optional): Whether to return stored vectors with the results.
        - score_threshold (float, optional): Only results whose score is strictly
          greater than this value are kept (default 0.9).

        Returns:
        - results: One list of scored points per search request, as returned by
          ``client.search_batch`` and filtered by ``score_threshold``.
        """
        # Nothing to embed or search for; skip the embedding and network calls.
        if not query_texts:
            return []

        vectors = await self.embed_data(query_texts)
        result_limit = limit if limit is not None else self.DEFAULT_SEARCH_LIMIT

        # Generate dynamic search requests based on the provided embeddings
        requests = [
            models.SearchRequest(
                vector = models.NamedVector(
                    name = "text",
                    vector = vector
                ),
                limit = result_limit,
                with_vector = with_vectors
            ) for vector in vectors
        ]

        client = self.get_qdrant_client()
        try:
            # Perform batch search with the dynamically generated requests
            results = await client.search_batch(
                collection_name = collection_name,
                requests = requests
            )
        finally:
            await client.close()

        # Materialize as lists: lazy filter objects can only be iterated once
        # and do not support len() or indexing.
        return [
            [result for result in result_group if result.score > score_threshold]
            for result_group in results
        ]

    async def prune(self):
        """Delete every collection in the configured Qdrant storage."""
        client = self.get_qdrant_client()
        try:
            response = await client.get_collections()

            for collection in response.collections:
                await client.delete_collection(collection.name)
        finally:
            await client.close()