<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

---------

Co-authored-by: vasilije <vas.markovic@gmail.com>
Co-authored-by: Vasilije <8619304+Vasilije1990@users.noreply.github.com>
511 lines
18 KiB
Python
511 lines
18 KiB
Python
from __future__ import annotations
|
|
import asyncio
|
|
import os
|
|
from uuid import UUID
|
|
from typing import List, Optional
|
|
|
|
from cognee.shared.logging_utils import get_logger
|
|
from cognee.infrastructure.engine import DataPoint
|
|
from cognee.infrastructure.engine.utils import parse_id
|
|
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
|
from cognee.infrastructure.files.storage import get_file_storage
|
|
|
|
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
|
from ..models.ScoredResult import ScoredResult
|
|
from ..vector_db_interface import VectorDBInterface
|
|
|
|
logger = get_logger("MilvusAdapter")
|
|
|
|
|
|
class IndexSchema(DataPoint):
    """
    Shape of the records that are written into an index collection.

    Extends DataPoint with two fields: the text to be indexed, stored as a string, and a
    metadata dictionary whose "index_fields" entry names "text" as the indexed field.
    """

    text: str

    metadata: dict = dict(index_fields=["text"])
|
|
|
|
|
|
class MilvusAdapter(VectorDBInterface):
    """
    Adapter implementing the vector database interface on top of Milvus.

    Each operation obtains a client through ``get_milvus_client``. URLs starting with
    http://, https:// or grpc:// are passed to the client as-is; any other URL is treated
    as a local file path whose parent directory is created on demand.

    Public methods:

    - __init__
    - get_milvus_client
    - embed_data
    - has_collection
    - create_collection
    - create_data_points
    - create_vector_index
    - index_data_points
    - retrieve
    - search
    - batch_search
    - delete_data_points
    - prune
    """

    name = "Milvus"
    url: str
    api_key: Optional[str]
    embedding_engine: EmbeddingEngine = None

    def __init__(self, url: str, api_key: Optional[str], embedding_engine: EmbeddingEngine):
        """
        Store the connection settings and the embedding engine.

        Parameters:
        -----------

            - url (str): Milvus URI, either a server address or a local file path.
            - api_key (Optional[str]): Token passed to the client when set.
            - embedding_engine (EmbeddingEngine): Engine used to embed text and to report
              the vector size.
        """
        self.url = url
        self.api_key = api_key

        self.embedding_engine = embedding_engine

    def _ensure_local_db_directory(self) -> None:
        """
        Create the parent directory of a local, file-based Milvus database if it is missing.

        Does nothing for http/https/grpc URLs, for URLs without a directory component, or
        when the directory already exists. Any failure while going through file storage is
        logged and followed by a plain ``os.makedirs`` fallback.
        """
        if not self.url or self.url.startswith(("http://", "https://", "grpc://")):
            return

        # Anything that is not a known remote scheme is likely a local file path.
        db_dir = os.path.dirname(self.url)
        if not db_dir or os.path.exists(db_dir):
            return

        try:
            file_storage = get_file_storage(db_dir)

            if not hasattr(file_storage, "ensure_directory_exists"):
                # Fallback to os.makedirs if file_storage doesn't have ensure_directory_exists
                os.makedirs(db_dir, exist_ok=True)
                return

            if not asyncio.iscoroutinefunction(file_storage.ensure_directory_exists):
                file_storage.ensure_directory_exists()
                return

            try:
                asyncio.get_running_loop()
            except RuntimeError:
                # No loop is running in this thread: drive the coroutine on a private,
                # short-lived loop. This avoids the deprecated asyncio.get_event_loop()
                # and leaves the thread's current-loop setting untouched.
                private_loop = asyncio.new_event_loop()
                try:
                    private_loop.run_until_complete(file_storage.ensure_directory_exists())
                finally:
                    private_loop.close()
            else:
                # Already inside a running loop: this sync method cannot await the
                # coroutine, so create the directory directly instead.
                os.makedirs(db_dir, exist_ok=True)
        except Exception as e:
            logger.warning(
                f"Could not create directory {db_dir} using file_storage, falling back to os.makedirs: {e}"
            )
            os.makedirs(db_dir, exist_ok=True)

    def get_milvus_client(self):
        """
        Retrieve a Milvus client instance.

        Ensures the parent directory exists for local file-based databases, then builds a
        new MilvusClient for the configured URL. The API key is passed as ``token`` only
        when it is set.

        Returns:
        --------

            A MilvusClient instance.
        """
        from pymilvus import MilvusClient

        self._ensure_local_db_directory()

        if self.api_key:
            return MilvusClient(uri=self.url, token=self.api_key)
        return MilvusClient(uri=self.url)

    async def embed_data(self, data: List[str]) -> list[list[float]]:
        """
        Embed a list of text data into vectors asynchronously.

        Delegates to the embedding engine's ``embed_text``.

        Parameters:
        -----------

            - data (List[str]): A list of textual data to be embedded into vectors.

        Returns:
        --------

            - list[list[float]]: A list of lists containing embedded vectors.
        """
        return await self.embedding_engine.embed_text(data)

    async def has_collection(self, collection_name: str) -> bool:
        """
        Check if a collection exists in the database asynchronously.

        Parameters:
        -----------

            - collection_name (str): The name of the collection to check for its existence.

        Returns:
        --------

            - bool: True if the collection exists, False otherwise.
        """
        client = self.get_milvus_client()
        # The client call is synchronous, so its result is returned directly.
        return client.has_collection(collection_name=collection_name)

    async def create_collection(
        self,
        collection_name: str,
        payload_schema=None,
    ):
        """
        Create a new collection in the vector database asynchronously.

        If the collection already exists nothing is created and True is returned.
        Otherwise a schema with a VARCHAR primary key ``id``, a FLOAT_VECTOR ``vector``
        field sized by the embedding engine and a VARCHAR ``text`` field is created with a
        COSINE index on ``vector``, and the collection is loaded.

        Raises ValueError if the embedding engine reports a non-positive vector size.
        Raises MilvusException (after logging it) if Milvus fails to create the collection.

        Parameters:
        -----------

            - collection_name (str): The name of the collection to be created.
            - payload_schema: Accepted for interface compatibility; not used by this
              adapter. (default None)

        Returns:
        --------

            True if the collection already exists or was created successfully.
        """
        from pymilvus import DataType, MilvusException

        client = self.get_milvus_client()
        if client.has_collection(collection_name=collection_name):
            logger.info(f"Collection '{collection_name}' already exists.")
            return True

        dimension = self.embedding_engine.get_vector_size()
        # Explicit check instead of `assert`, which is stripped when running with -O.
        if dimension <= 0:
            raise ValueError("Embedding dimension must be greater than 0.")

        try:
            schema = client.create_schema(
                auto_id=False,
                enable_dynamic_field=False,
            )

            schema.add_field(
                field_name="id", datatype=DataType.VARCHAR, is_primary=True, max_length=36
            )
            schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dimension)
            schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=60535)

            index_params = client.prepare_index_params()
            index_params.add_index(field_name="vector", metric_type="COSINE")

            client.create_collection(
                collection_name=collection_name, schema=schema, index_params=index_params
            )

            client.load_collection(collection_name)

            logger.info(f"Collection '{collection_name}' created successfully.")
            return True
        except MilvusException as e:
            logger.error(f"Error creating collection '{collection_name}': {str(e)}")
            raise

    async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
        """
        Insert multiple data points into a specified collection asynchronously.

        Each data point is embedded through ``embed_data`` and inserted as a row with its
        id (as a string), its vector and its ``text`` attribute. An empty list is a no-op.

        Raises CollectionNotFoundError if the specified collection does not exist. Other
        MilvusException errors are logged and re-raised.

        Parameters:
        -----------

            - collection_name (str): The name of the collection where data points will be
              inserted.
            - data_points (List[DataPoint]): A list of DataPoint objects to be inserted into
              the collection.

        Returns:
        --------

            The result of the insert operation, includes count of inserted data points.
        """
        from pymilvus import MilvusException, exceptions

        if not data_points:
            # Nothing to embed or insert: report zero rows without opening a client.
            return {"insert_count": 0, "ids": []}

        client = self.get_milvus_client()
        data_vectors = await self.embed_data(
            [data_point.get_embeddable_data(data_point) for data_point in data_points]
        )

        insert_data = [
            {
                "id": str(data_point.id),
                "vector": data_vectors[index],
                "text": data_point.text,
            }
            for index, data_point in enumerate(data_points)
        ]

        try:
            result = client.insert(collection_name=collection_name, data=insert_data)
            logger.info(
                f"Inserted {result.get('insert_count', 0)} data points into collection '{collection_name}'."
            )
            return result
        except exceptions.CollectionNotExistException as error:
            raise CollectionNotFoundError(
                f"Collection '{collection_name}' does not exist!"
            ) from error
        except MilvusException as e:
            logger.error(
                f"Error inserting data points into collection '{collection_name}': {str(e)}"
            )
            raise

    async def create_vector_index(self, index_name: str, index_property_name: str):
        """
        Create a vector index for a given collection asynchronously.

        The index is backed by a collection named ``{index_name}_{index_property_name}``.

        Parameters:
        -----------

            - index_name (str): The name of the vector index being created.
            - index_property_name (str): The property name associated with the index.
        """
        await self.create_collection(f"{index_name}_{index_property_name}")

    async def index_data_points(
        self, index_name: str, index_property_name: str, data_points: List[DataPoint]
    ):
        """
        Index the provided data points into the collection based on index names asynchronously.

        Each data point is converted to an IndexSchema whose text is read from the attribute
        named by the first entry of the data point's ``metadata["index_fields"]``, then
        inserted into the collection ``{index_name}_{index_property_name}``.

        Parameters:
        -----------

            - index_name (str): The name of the index where data points will be indexed.
            - index_property_name (str): The property name associated with the index.
            - data_points (List[DataPoint]): A list of DataPoint objects to be indexed.
        """
        collection_name = f"{index_name}_{index_property_name}"
        formatted_data_points = [
            IndexSchema(
                id=data_point.id,
                text=getattr(data_point, data_point.metadata["index_fields"][0]),
            )
            for data_point in data_points
        ]
        await self.create_data_points(collection_name, formatted_data_points)

    @staticmethod
    def _build_id_filter(data_point_ids) -> str:
        """Build the ``id in ["...", ...]`` filter expression for the given ids."""
        quoted_ids = ", ".join(f'"{data_point_id}"' for data_point_id in data_point_ids)
        return f"id in [{quoted_ids}]"

    async def retrieve(self, collection_name: str, data_point_ids: list[UUID]):
        """
        Retrieve data points from a collection based on their IDs asynchronously.

        Raises CollectionNotFoundError if the specified collection does not exist. Other
        MilvusException errors are logged and re-raised.

        Parameters:
        -----------

            - collection_name (str): The name of the collection from which data points will
              be retrieved.
            - data_point_ids (list[UUID]): A list of UUIDs representing the IDs of the data
              points to be retrieved.

        Returns:
        --------

            The results of the query, including the requested data points.
        """
        from pymilvus import MilvusException, exceptions

        client = self.get_milvus_client()
        try:
            filter_expression = self._build_id_filter(data_point_ids)

            # `filter` is the keyword MilvusClient.query documents for the expression; it
            # is also the keyword delete_data_points uses.
            return client.query(
                collection_name=collection_name,
                filter=filter_expression,
                output_fields=["*"],
            )
        except exceptions.CollectionNotExistException as error:
            raise CollectionNotFoundError(
                f"Collection '{collection_name}' does not exist!"
            ) from error
        except MilvusException as e:
            logger.error(
                f"Error retrieving data points from collection '{collection_name}': {str(e)}"
            )
            raise

    async def search(
        self,
        collection_name: str,
        query_text: Optional[str] = None,
        query_vector: Optional[List[float]] = None,
        limit: int = 15,
        with_vector: bool = False,
    ):
        """
        Search for data points in a collection based on a text query or vector asynchronously.

        When no (or an empty) query vector is given, ``query_text`` is embedded and used
        instead. A missing collection is not an error: a warning is logged and an empty
        list is returned. The same applies to a MilvusException whose message contains
        "collection not found" or "schema".

        Raises ValueError if neither query_text nor query_vector is provided. Raises
        MilvusException for other errors during the search process.

        Parameters:
        -----------

            - collection_name (str): The name of the collection to search within.
            - query_text (Optional[str]): Optional text query used for searching, defaults
              to None. (default None)
            - query_vector (Optional[List[float]]): Optional vector query used for
              searching, defaults to None. (default None)
            - limit (int): Maximum number of results to return; a non-positive value is
              passed to the client as None. (default 15)
            - with_vector (bool): Flag to indicate if the vector should be included in the
              results, defaults to False. (default False)

        Returns:
        --------

            A list of scored results that match the query; may include vector data if
            requested.
        """
        from pymilvus import MilvusException, exceptions

        # Validate the arguments before a client (and possibly a local directory) is created.
        if query_text is None and query_vector is None:
            raise ValueError("One of query_text or query_vector must be provided!")

        client = self.get_milvus_client()

        if not client.has_collection(collection_name=collection_name):
            logger.warning(
                f"Collection '{collection_name}' not found in MilvusAdapter.search; returning []."
            )
            return []

        try:
            # Explicit None/empty check: plain truthiness raises for array-like vectors.
            if query_vector is None or len(query_vector) == 0:
                query_vector = (await self.embed_data([query_text]))[0]

            output_fields = ["id", "text"]
            if with_vector:
                output_fields.append("vector")

            results = client.search(
                collection_name=collection_name,
                data=[query_vector],
                anns_field="vector",
                limit=limit if limit > 0 else None,
                output_fields=output_fields,
                search_params={
                    "metric_type": "COSINE",
                },
            )

            # A single query vector was sent, so only the first result set is relevant.
            return [
                ScoredResult(
                    id=parse_id(result["id"]),
                    score=result["distance"],
                    payload=result.get("entity", {}),
                )
                for result in results[0]
            ]
        except exceptions.CollectionNotExistException:
            logger.warning(
                f"Collection '{collection_name}' not found (exception) in MilvusAdapter.search; returning []."
            )
            return []
        except MilvusException as e:
            # Catch other Milvus errors that are "collection not found" (paranoid safety)
            error_text = str(e).lower()
            if "collection not found" in error_text or "schema" in error_text:
                logger.warning(
                    f"Collection '{collection_name}' not found (MilvusException) in MilvusAdapter.search; returning []."
                )
                return []
            logger.error(f"Error searching Milvus collection '{collection_name}': {e}")
            raise

    async def batch_search(
        self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False
    ):
        """
        Perform a batch search in a collection for multiple textual queries asynchronously.

        Utilizes embed_data to convert texts into vectors, then runs one ``search`` per
        vector concurrently. An empty ``query_texts`` yields an empty list without
        embedding anything.

        Parameters:
        -----------

            - collection_name (str): The name of the collection where the search will be
              performed.
            - query_texts (List[str]): A list of texts to search for in the collection.
            - limit (int): Maximum number of results to return per query.
            - with_vectors (bool): Specifies if the vectors should be included in the search
              results, defaults to False. (default False)

        Returns:
        --------

            A list of search result sets, one for each query input.
        """
        if not query_texts:
            return []

        query_vectors = await self.embed_data(query_texts)

        return await asyncio.gather(
            *[
                self.search(
                    collection_name=collection_name,
                    query_vector=query_vector,
                    limit=limit,
                    with_vector=with_vectors,
                )
                for query_vector in query_vectors
            ]
        )

    async def delete_data_points(self, collection_name: str, data_point_ids: list[UUID]):
        """
        Delete specific data points from a collection based on their IDs asynchronously.

        Raises MilvusException (after logging it) for errors during the deletion process.

        Parameters:
        -----------

            - collection_name (str): The name of the collection from which data points will
              be deleted.
            - data_point_ids (list[UUID]): A list of UUIDs representing the IDs of the data
              points to be deleted.

        Returns:
        --------

            The result of the delete operation as returned by the client.
        """
        from pymilvus import MilvusException

        client = self.get_milvus_client()
        try:
            filter_expression = self._build_id_filter(data_point_ids)

            delete_result = client.delete(collection_name=collection_name, filter=filter_expression)

            logger.info(
                f"Deleted data points with IDs {data_point_ids} from collection '{collection_name}'."
            )
            return delete_result
        except MilvusException as e:
            logger.error(
                f"Error deleting data points from collection '{collection_name}': {str(e)}"
            )
            raise

    async def prune(self):
        """
        Remove all collections from the connected Milvus client asynchronously.

        The client is closed afterwards, including when listing or dropping a collection
        raises.
        """
        client = self.get_milvus_client()
        try:
            for collection_name in client.list_collections():
                client.drop_collection(collection_name=collection_name)
        finally:
            client.close()
|