cognee/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py
2025-08-13 14:42:57 +02:00

513 lines
17 KiB
Python

import os
from typing import Dict, List, Optional
from qdrant_client import AsyncQdrantClient, models
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..vector_db_interface import VectorDBInterface
logger = get_logger("QDrantAdapter")
class IndexSchema(DataPoint):
"""
Represents a schema for indexing where each data point contains a text field.
This class inherits from DataPoint and defines a text attribute as well as metadata
containing index fields used for indexing operations.
"""
text: str
metadata: dict = {"index_fields": ["text"]}
# class CollectionConfig(BaseModel, extra = "forbid"):
# vector_config: Dict[str, models.VectorParams] = Field(..., description="Vectors configuration" )
# hnsw_config: Optional[models.HnswConfig] = Field(default = None, description="HNSW vector index configuration")
# optimizers_config: Optional[models.OptimizersConfig] = Field(default = None, description="Optimizers configuration")
# quantization_config: Optional[models.QuantizationConfig] = Field(default = None, description="Quantization configuration")
def create_hnsw_config(hnsw_config: Dict):
"""
Create HNSW configuration.
This function returns an HNSW configuration object if the provided configuration is not
None, otherwise it returns None.
Parameters:
-----------
- hnsw_config (Dict): A dictionary containing HNSW configuration parameters.
Returns:
--------
An instance of models.HnswConfig if hnsw_config is not None, otherwise None.
"""
if hnsw_config is not None:
return models.HnswConfig()
return None
def create_optimizers_config(optimizers_config: Dict):
"""
Create and return an OptimizersConfig instance if the input configuration is provided.
This function checks if the given optimizers configuration is not None. If valid, it
initializes and returns a new instance of the OptimizersConfig class from the models
module. If the configuration is None, it returns None instead.
Parameters:
-----------
- optimizers_config (Dict): A dictionary containing optimizer configuration
settings.
Returns:
--------
Returns an instance of OptimizersConfig if optimizers_config is provided; otherwise,
returns None.
"""
if optimizers_config is not None:
return models.OptimizersConfig()
return None
def create_quantization_config(quantization_config: Dict):
"""
Create a quantization configuration based on the provided settings.
This function generates an instance of `QuantizationConfig` if the provided
`quantization_config` is not None. If it is None, the function returns None.
Parameters:
-----------
- quantization_config (Dict): A dictionary containing the quantization configuration
settings.
Returns:
--------
An instance of `QuantizationConfig` if `quantization_config` is provided; otherwise,
returns None.
"""
if quantization_config is not None:
return models.QuantizationConfig()
return None
class QDrantAdapter(VectorDBInterface):
"""
Adapt to the Qdrant vector database interface.
Public methods:
- get_qdrant_client
- embed_data
- has_collection
- create_collection
- create_data_points
- create_vector_index
- index_data_points
- retrieve
- search
- batch_search
- delete_data_points
- prune
"""
name = "Qdrant"
url: str = None
api_key: str = None
qdrant_path: str = None
def __init__(self, url, api_key, embedding_engine: EmbeddingEngine, qdrant_path=None):
self.embedding_engine = embedding_engine
if qdrant_path is not None:
self.qdrant_path = qdrant_path
else:
self.url = url
self.api_key = api_key
def get_qdrant_client(self) -> AsyncQdrantClient:
"""
Retrieve an instance of AsyncQdrantClient configured with the appropriate
settings based on the instance's attributes.
Returns an instance of AsyncQdrantClient configured to connect to the database.
Returns:
--------
- AsyncQdrantClient: An instance of AsyncQdrantClient configured for database
operations.
"""
is_prod = os.getenv("ENV").lower() == "prod"
if self.qdrant_path is not None:
return AsyncQdrantClient(path=self.qdrant_path, port=6333, https=is_prod)
elif self.url is not None:
return AsyncQdrantClient(url=self.url, api_key=self.api_key, port=6333, https=is_prod)
return AsyncQdrantClient(location=":memory:")
async def embed_data(self, data: List[str]) -> List[float]:
"""
Embed a list of text data into vector representations asynchronously.
Parameters:
-----------
- data (List[str]): A list of strings containing the text data to be embedded.
Returns:
--------
- List[float]: A list of floating-point vectors representing the embedded text data.
"""
return await self.embedding_engine.embed_text(data)
async def has_collection(self, collection_name: str) -> bool:
"""
Check if a specified collection exists in the Qdrant database asynchronously.
Parameters:
-----------
- collection_name (str): The name of the collection to check for existence.
Returns:
--------
- bool: True if the specified collection exists, False otherwise.
"""
client = self.get_qdrant_client()
result = await client.collection_exists(collection_name)
await client.close()
return result
async def create_collection(
self,
collection_name: str,
payload_schema=None,
):
"""
Create a new collection in the Qdrant database if it does not already exist.
If the collection already exists, this operation has no effect.
Parameters:
-----------
- collection_name (str): The name of the collection to create.
- payload_schema: Optional schema for the payload. Defaults to None. (default None)
"""
client = self.get_qdrant_client()
if not await client.collection_exists(collection_name):
await client.create_collection(
collection_name=collection_name,
vectors_config={
"text": models.VectorParams(
size=self.embedding_engine.get_vector_size(), distance="Cosine"
)
},
)
await client.close()
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
"""
Create and upload data points to a specified collection in the database.
Raises CollectionNotFoundError if the collection does not exist.
Parameters:
-----------
- collection_name (str): The name of the collection to which data points will be
uploaded.
- data_points (List[DataPoint]): A list of DataPoint objects to be uploaded.
Returns:
--------
None if the operation is successful; raises exceptions on error.
"""
from qdrant_client.http.exceptions import UnexpectedResponse
client = self.get_qdrant_client()
data_vectors = await self.embed_data(
[DataPoint.get_embeddable_data(data_point) for data_point in data_points]
)
def convert_to_qdrant_point(data_point: DataPoint):
"""
Convert a DataPoint object into the format expected by Qdrant for upload.
Parameters:
-----------
- data_point (DataPoint): The DataPoint object to convert.
Returns:
--------
None; performs an operation without returning a value.
"""
return models.PointStruct(
id=str(data_point.id),
payload=data_point.model_dump(),
vector={"text": data_vectors[data_points.index(data_point)]},
)
points = [convert_to_qdrant_point(point) for point in data_points]
try:
client.upload_points(collection_name=collection_name, points=points)
except UnexpectedResponse as error:
if "Collection not found" in str(error):
raise CollectionNotFoundError(
message=f"Collection {collection_name} not found!"
) from error
else:
raise error
except Exception as error:
logger.error("Error uploading data points to Qdrant: %s", str(error))
raise error
finally:
await client.close()
async def create_vector_index(self, index_name: str, index_property_name: str):
"""
Create a vector index for a specified property name.
This is essentially a wrapper around create_collection, which allows for more
flexibility
in index naming.
Parameters:
-----------
- index_name (str): The base name for the index to be created.
- index_property_name (str): The property name that will be part of the index name.
"""
await self.create_collection(f"{index_name}_{index_property_name}")
async def index_data_points(
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
):
"""
Index data points into a specific collection based on provided metadata.
Transforms DataPoint objects into an appropriate format and uploads them.
Parameters:
-----------
- index_name (str): The base name for the index used for naming the collection.
- index_property_name (str): The property name used for naming the collection.
- data_points (list[DataPoint]): A list of DataPoint objects to index.
"""
await self.create_data_points(
f"{index_name}_{index_property_name}",
[
IndexSchema(
id=data_point.id,
text=getattr(data_point, data_point.metadata["index_fields"][0]),
)
for data_point in data_points
],
)
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
"""
Retrieve data points from a specified collection based on their IDs.
Returns the data corresponding to the provided IDs from the collection.
Parameters:
-----------
- collection_name (str): The name of the collection to retrieve from.
- data_point_ids (list[str]): A list of IDs of the data points to retrieve.
Returns:
--------
The retrieved data points, including payloads for each ID.
"""
client = self.get_qdrant_client()
results = await client.retrieve(collection_name, data_point_ids, with_payload=True)
await client.close()
return results
async def search(
self,
collection_name: str,
query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None,
limit: int = 15,
with_vector: bool = False,
) -> List[ScoredResult]:
"""
Search for data points in a collection based on either a textual query or a vector
query.
Raises MissingQueryParameterError if both query_text and query_vector are None.
Returns a list of scored results that match the search criteria.
Parameters:
-----------
- collection_name (str): The name of the collection to search within.
- query_text (Optional[str]): The text to be used in the search query; optional if
query_vector is provided. (default None)
- query_vector (Optional[List[float]]): The vector to be used in the search query;
optional if query_text is provided. (default None)
- limit (int): The maximum number of results to return; defaults to 15. (default 15)
- with_vector (bool): Indicates whether to return vector data along with results;
defaults to False. (default False)
Returns:
--------
- List[ScoredResult]: A list of ScoredResult objects representing the results of the
search.
"""
from qdrant_client.http.exceptions import UnexpectedResponse
if query_text is None and query_vector is None:
raise MissingQueryParameterError()
if not await self.has_collection(collection_name):
return []
if query_vector is None:
query_vector = (await self.embed_data([query_text]))[0]
try:
client = self.get_qdrant_client()
if limit == 0:
collection_size = await client.count(collection_name=collection_name)
results = await client.search(
collection_name=collection_name,
query_vector=models.NamedVector(
name="text",
vector=query_vector
if query_vector is not None
else (await self.embed_data([query_text]))[0],
),
limit=limit if limit > 0 else collection_size.count,
with_vectors=with_vector,
)
await client.close()
return [
ScoredResult(
id=parse_id(result.id),
payload={
**result.payload,
"id": parse_id(result.id),
},
score=1 - result.score,
)
for result in results
]
finally:
await client.close()
async def batch_search(
self,
collection_name: str,
query_texts: List[str],
limit: int = None,
with_vectors: bool = False,
):
"""
Perform a batch search in a specified collection using multiple query texts.
Returns the results of the search for each query, filtering for results with a score
higher than 0.9.
Parameters:
-----------
- collection_name (str): The name of the collection to search in.
- query_texts (List[str]): A list of query texts to search for in the collection.
- limit (int): The maximum number of results to return for each search request; can
be None. (default None)
- with_vectors (bool): Indicates whether to include vector data in the results;
defaults to False. (default False)
Returns:
--------
A list containing the filtered search results for each query text.
"""
vectors = await self.embed_data(query_texts)
# Generate dynamic search requests based on the provided embeddings
requests = [
models.SearchRequest(
vector=models.NamedVector(name="text", vector=vector),
limit=limit,
with_vector=with_vectors,
)
for vector in vectors
]
client = self.get_qdrant_client()
# Perform batch search with the dynamically generated requests
results = await client.search_batch(collection_name=collection_name, requests=requests)
await client.close()
return [filter(lambda result: result.score > 0.9, result_group) for result_group in results]
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
"""
Delete specific data points from a specified collection based on their IDs.
Parameters:
-----------
- collection_name (str): The name of the collection from which to delete the data
points.
- data_point_ids (list[str]): The list of IDs of data points to be deleted.
Returns:
--------
The result of the delete operation from the database.
"""
client = self.get_qdrant_client()
results = await client.delete(collection_name, data_point_ids)
return results
async def prune(self):
"""
Remove all collections from the Qdrant database asynchronously.
"""
client = self.get_qdrant_client()
response = await client.get_collections()
for collection in response.collections:
await client.delete_collection(collection.name)
await client.close()