<!-- .github/pull_request_template.md --> ## Description <!-- Provide a clear description of the changes in this PR --> ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. --------- Co-authored-by: vasilije <vas.markovic@gmail.com> Co-authored-by: Igor Ilic <30923996+dexters1@users.noreply.github.com> Co-authored-by: Vasilije <8619304+Vasilije1990@users.noreply.github.com> Co-authored-by: Igor Ilic <igorilic03@gmail.com> Co-authored-by: Hande <159312713+hande-k@users.noreply.github.com> Co-authored-by: Matea Pesic <80577904+matea16@users.noreply.github.com> Co-authored-by: hajdul88 <52442977+hajdul88@users.noreply.github.com> Co-authored-by: Daniel Molnar <soobrosa@gmail.com> Co-authored-by: Diego Baptista Theuerkauf <34717973+diegoabt@users.noreply.github.com>
261 lines
9 KiB
Python
261 lines
9 KiB
Python
from __future__ import annotations
|
|
import asyncio
|
|
from uuid import UUID
|
|
from typing import List, Optional
|
|
|
|
from cognee.shared.logging_utils import get_logger
|
|
from cognee.infrastructure.engine import DataPoint
|
|
from cognee.infrastructure.engine.utils import parse_id
|
|
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
|
|
|
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
|
from ..models.ScoredResult import ScoredResult
|
|
from ..vector_db_interface import VectorDBInterface
|
|
|
|
# Module-level logger used by the MilvusAdapter methods below.
logger = get_logger("MilvusAdapter")
|
|
|
|
|
|
class IndexSchema(DataPoint):
    """Data point shape used when writing entries into an index collection.

    Carries the text to be indexed on top of whatever fields ``DataPoint``
    itself provides.
    """

    # The text value stored (and embedded) for this index entry.
    text: str

    # Names ``text`` as the sole entry in ``index_fields``.
    metadata: dict = {"index_fields": ["text"]}
|
|
|
|
|
|
class MilvusAdapter(VectorDBInterface):
    """Vector database adapter that talks to Milvus through ``pymilvus.MilvusClient``.

    Collections created by this adapter have three fields: ``id`` (VARCHAR
    primary key), ``vector`` (FLOAT_VECTOR sized by the embedding engine) and
    ``text`` (VARCHAR). The vector field is indexed with the COSINE metric.

    A fresh ``MilvusClient`` is built for every operation via
    :meth:`get_milvus_client`.
    """

    name = "Milvus"
    url: str
    api_key: Optional[str]
    embedding_engine: Optional[EmbeddingEngine] = None

    def __init__(self, url: str, api_key: Optional[str], embedding_engine: EmbeddingEngine):
        """Store connection settings and the engine used to embed text.

        Args:
            url: URI handed to ``MilvusClient``.
            api_key: Optional token; when falsy the client is created without one.
            embedding_engine: Engine used to embed text and to size vectors.
        """
        self.url = url
        self.api_key = api_key

        self.embedding_engine = embedding_engine

    @staticmethod
    def _build_id_filter(data_point_ids) -> str:
        """Return a Milvus filter expression matching any of the given ids."""
        quoted_ids = ", ".join(f'"{data_point_id}"' for data_point_id in data_point_ids)
        return f"id in [{quoted_ids}]"

    def get_milvus_client(self):
        """Create and return a new ``MilvusClient`` for the configured URL.

        The token is only passed when ``api_key`` is set.
        """
        # Imported lazily so pymilvus is only required when the adapter is used.
        from pymilvus import MilvusClient

        if self.api_key:
            client = MilvusClient(uri=self.url, token=self.api_key)
        else:
            client = MilvusClient(uri=self.url)
        return client

    async def embed_data(self, data: List[str]) -> list[list[float]]:
        """Embed ``data`` with the configured embedding engine."""
        return await self.embedding_engine.embed_text(data)

    async def has_collection(self, collection_name: str) -> bool:
        """Return whether ``collection_name`` exists in Milvus."""
        client = self.get_milvus_client()
        # The client call is synchronous; return its result directly instead of
        # routing it through a manually resolved asyncio.Future.
        return client.has_collection(collection_name=collection_name)

    async def create_collection(
        self,
        collection_name: str,
        payload_schema=None,
    ):
        """Create ``collection_name`` with the id/vector/text schema if it is missing.

        Args:
            collection_name: Name of the collection to create.
            payload_schema: Accepted for interface compatibility; not used here.

        Returns:
            ``True`` when the collection already existed or was created.

        Raises:
            ValueError: If the embedding engine reports a non-positive vector size.
            MilvusException: If Milvus fails while creating the collection.
        """
        from pymilvus import DataType, MilvusException

        client = self.get_milvus_client()
        if client.has_collection(collection_name=collection_name):
            logger.info(f"Collection '{collection_name}' already exists.")
            return True

        dimension = self.embedding_engine.get_vector_size()
        # Explicit check instead of `assert`, which disappears under `python -O`.
        if dimension <= 0:
            raise ValueError("Embedding dimension must be greater than 0.")

        try:
            schema = client.create_schema(
                auto_id=False,
                enable_dynamic_field=False,
            )

            # max_length=36 matches the length of a canonical UUID string.
            schema.add_field(
                field_name="id", datatype=DataType.VARCHAR, is_primary=True, max_length=36
            )

            schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dimension)

            # NOTE(review): 60535 may be a typo for 65535 (Milvus' VARCHAR cap) — confirm
            # before changing, since it alters the schema of newly created collections.
            schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=60535)

            index_params = client.prepare_index_params()
            index_params.add_index(field_name="vector", metric_type="COSINE")

            client.create_collection(
                collection_name=collection_name, schema=schema, index_params=index_params
            )

            client.load_collection(collection_name)

            logger.info(f"Collection '{collection_name}' created successfully.")
            return True
        except MilvusException as e:
            logger.error(f"Error creating collection '{collection_name}': {str(e)}")
            raise

    async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
        """Embed ``data_points`` and insert them into ``collection_name``.

        Each inserted row holds the stringified id, the embedding vector and
        the data point's ``text`` attribute.

        Returns:
            The result object returned by ``MilvusClient.insert``.

        Raises:
            CollectionNotFoundError: If the collection does not exist.
            MilvusException: For any other Milvus failure during insert.
        """
        from pymilvus import MilvusException, exceptions

        client = self.get_milvus_client()
        data_vectors = await self.embed_data(
            [data_point.get_embeddable_data(data_point) for data_point in data_points]
        )

        insert_data = [
            {
                "id": str(data_point.id),
                "vector": data_vectors[index],
                "text": data_point.text,
            }
            for index, data_point in enumerate(data_points)
        ]

        try:
            result = client.insert(collection_name=collection_name, data=insert_data)
            logger.info(
                f"Inserted {result.get('insert_count', 0)} data points into collection '{collection_name}'."
            )
            return result
        # Must precede the generic handler so a missing collection is reported
        # with the adapter-level exception type.
        except exceptions.CollectionNotExistException as error:
            raise CollectionNotFoundError(
                f"Collection '{collection_name}' does not exist!"
            ) from error
        except MilvusException as e:
            logger.error(
                f"Error inserting data points into collection '{collection_name}': {str(e)}"
            )
            raise

    async def create_vector_index(self, index_name: str, index_property_name: str):
        """Create the collection named ``{index_name}_{index_property_name}``."""
        await self.create_collection(f"{index_name}_{index_property_name}")

    async def index_data_points(
        self, index_name: str, index_property_name: str, data_points: List[DataPoint]
    ):
        """Insert ``data_points`` into the ``{index_name}_{index_property_name}`` collection.

        The indexed text of each data point is read from the attribute named by
        the first entry of its ``metadata["index_fields"]``.
        """
        formatted_data_points = [
            IndexSchema(
                id=data_point.id,
                text=getattr(data_point, data_point.metadata["index_fields"][0]),
            )
            for data_point in data_points
        ]
        collection_name = f"{index_name}_{index_property_name}"
        await self.create_data_points(collection_name, formatted_data_points)

    async def retrieve(self, collection_name: str, data_point_ids: list[UUID]):
        """Fetch the rows of ``collection_name`` whose id is in ``data_point_ids``.

        Returns:
            The list returned by ``MilvusClient.query`` with all output fields.

        Raises:
            CollectionNotFoundError: If the collection does not exist.
            MilvusException: For any other Milvus failure during the query.
        """
        from pymilvus import MilvusException, exceptions

        client = self.get_milvus_client()
        try:
            filter_expression = self._build_id_filter(data_point_ids)

            # `filter` is the documented MilvusClient.query keyword (and what
            # delete_data_points uses); `expr` belongs to the lower-level API.
            results = client.query(
                collection_name=collection_name,
                filter=filter_expression,
                output_fields=["*"],
            )
            return results
        except exceptions.CollectionNotExistException as error:
            raise CollectionNotFoundError(
                f"Collection '{collection_name}' does not exist!"
            ) from error
        except MilvusException as e:
            logger.error(
                f"Error retrieving data points from collection '{collection_name}': {str(e)}"
            )
            raise

    async def search(
        self,
        collection_name: str,
        query_text: Optional[str] = None,
        query_vector: Optional[List[float]] = None,
        limit: int = 15,
        with_vector: bool = False,
    ):
        """Run a COSINE similarity search against ``collection_name``.

        Args:
            collection_name: Collection to search.
            query_text: Text to embed when no ``query_vector`` is supplied.
            query_vector: Pre-computed query embedding; takes precedence when truthy.
            limit: Maximum number of hits requested.
            with_vector: Also return the stored ``vector`` field in each payload.

        Returns:
            A list of ``ScoredResult`` built from the hits of the single query.

        Raises:
            ValueError: If neither ``query_text`` nor ``query_vector`` is given.
            CollectionNotFoundError: If the collection does not exist.
            MilvusException: For any other Milvus failure during the search.
        """
        from pymilvus import MilvusException, exceptions

        # Validate before opening a client so bad input never costs a connection.
        if query_text is None and query_vector is None:
            raise ValueError("One of query_text or query_vector must be provided!")

        client = self.get_milvus_client()

        try:
            query_vector = query_vector or (await self.embed_data([query_text]))[0]

            output_fields = ["id", "text"]
            if with_vector:
                output_fields.append("vector")

            # NOTE(review): a non-positive limit is forwarded as None — confirm
            # that pymilvus treats this as intended rather than rejecting it.
            results = client.search(
                collection_name=collection_name,
                data=[query_vector],
                anns_field="vector",
                limit=limit if limit > 0 else None,
                output_fields=output_fields,
                search_params={
                    "metric_type": "COSINE",
                },
            )

            # One query vector was sent, so only the first hit list is relevant.
            return [
                ScoredResult(
                    id=parse_id(result["id"]),
                    score=result["distance"],
                    payload=result.get("entity", {}),
                )
                for result in results[0]
            ]
        except exceptions.CollectionNotExistException as error:
            raise CollectionNotFoundError(
                f"Collection '{collection_name}' does not exist!"
            ) from error
        except MilvusException as e:
            logger.error(f"Error during search in collection '{collection_name}': {str(e)}")
            raise

    async def batch_search(
        self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False
    ):
        """Embed ``query_texts`` in one call and run one :meth:`search` per vector.

        Returns:
            A list with one search result list per query text, in input order.
        """
        query_vectors = await self.embed_data(query_texts)

        return await asyncio.gather(
            *[
                self.search(
                    collection_name=collection_name,
                    query_vector=query_vector,
                    limit=limit,
                    with_vector=with_vectors,
                )
                for query_vector in query_vectors
            ]
        )

    async def delete_data_points(self, collection_name: str, data_point_ids: list[UUID]):
        """Delete the rows of ``collection_name`` whose id is in ``data_point_ids``.

        Returns:
            The result returned by ``MilvusClient.delete``.

        Raises:
            MilvusException: If Milvus fails while deleting.
        """
        from pymilvus import MilvusException

        client = self.get_milvus_client()
        try:
            filter_expression = self._build_id_filter(data_point_ids)

            delete_result = client.delete(collection_name=collection_name, filter=filter_expression)

            logger.info(
                f"Deleted data points with IDs {data_point_ids} from collection '{collection_name}'."
            )
            return delete_result
        except MilvusException as e:
            logger.error(
                f"Error deleting data points from collection '{collection_name}': {str(e)}"
            )
            raise

    async def prune(self):
        """Drop every collection visible to the client, then close the client."""
        client = self.get_milvus_client()
        if not client:
            return

        try:
            for collection_name in client.list_collections():
                client.drop_collection(collection_name=collection_name)
        finally:
            # Close even when listing or dropping fails, so the client is not leaked.
            client.close()
|