253 lines
8.5 KiB
Python
253 lines
8.5 KiB
Python
from typing import List, Protocol, Optional, Union, Any
|
|
from abc import abstractmethod
|
|
from cognee.infrastructure.engine import DataPoint
|
|
from .models.PayloadSchema import PayloadSchema
|
|
from uuid import UUID
|
|
from cognee.modules.users.models import User
|
|
|
|
|
|
class VectorDBInterface(Protocol):
    """
    Define an interface for interacting with a vector database, including operations for
    managing collections and data points.

    Methods decorated with ``@abstractmethod`` must be implemented by every adapter; their
    bodies here only raise ``NotImplementedError``. The remaining methods are optional and
    come with no-op / pass-through defaults. Note that, as with any ``typing.Protocol``,
    those defaults are only inherited by classes that explicitly subclass this interface;
    purely structural implementations do not receive them.
    """

    # ------------------------------------------------------------------ #
    # Collections
    # ------------------------------------------------------------------ #

    @abstractmethod
    async def has_collection(self, collection_name: str) -> bool:
        """
        Check if a specified collection exists.

        Parameters:
        -----------

            - collection_name (str): The name of the collection to check for existence.

        Returns:
        --------

            - bool: True if the collection exists, otherwise False.
        """
        raise NotImplementedError

    @abstractmethod
    async def create_collection(
        self,
        collection_name: str,
        payload_schema: Optional[Any] = None,
    ):
        """
        Create a new collection with an optional payload schema.

        Parameters:
        -----------

            - collection_name (str): The name of the new collection to create.
            - payload_schema (Optional[Any]): An optional schema for the payloads
              within this collection. Can be PayloadSchema, BaseModel, or other schema
              types. (default None)
        """
        raise NotImplementedError

    # ------------------------------------------------------------------ #
    # Data points
    # ------------------------------------------------------------------ #

    @abstractmethod
    async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
        """
        Insert new data points into the specified collection.

        Parameters:
        -----------

            - collection_name (str): The name of the collection where data points will be
              added.
            - data_points (List[DataPoint]): A list of data points to be added to the
              collection.
        """
        raise NotImplementedError

    @abstractmethod
    async def retrieve(self, collection_name: str, data_point_ids: List[str]):
        """
        Retrieve data points from a collection using their IDs.

        Parameters:
        -----------

            - collection_name (str): The name of the collection from which to retrieve
              data points.
            - data_point_ids (List[str]): A list of IDs of the data points to retrieve.
        """
        raise NotImplementedError

    # ------------------------------------------------------------------ #
    # Search
    # ------------------------------------------------------------------ #

    @abstractmethod
    async def search(
        self,
        collection_name: str,
        query_text: Optional[str],
        query_vector: Optional[List[float]],
        limit: Optional[int],
        with_vector: bool = False,
    ):
        """
        Perform a search in the specified collection using either a text query or a vector
        query.

        Parameters:
        -----------

            - collection_name (str): The name of the collection in which to perform the
              search.
            - query_text (Optional[str]): An optional text query to search for in the
              collection.
            - query_vector (Optional[List[float]]): An optional vector representation for
              searching the collection.
            - limit (Optional[int]): The maximum number of results to return from the
              search.
            - with_vector (bool): Whether to return the vector representations with search
              results. (default False)
        """
        raise NotImplementedError

    @abstractmethod
    async def batch_search(
        self,
        collection_name: str,
        query_texts: List[str],
        limit: Optional[int],
        with_vectors: bool = False,
    ):
        """
        Perform a batch search using multiple text queries against a collection.

        Parameters:
        -----------

            - collection_name (str): The name of the collection to conduct the batch search
              in.
            - query_texts (List[str]): A list of text queries to use for the search.
            - limit (Optional[int]): The maximum number of results to return for each query.
            - with_vectors (bool): Whether to include vector representations with search
              results. (default False)
        """
        raise NotImplementedError

    @abstractmethod
    async def delete_data_points(self, collection_name: str, data_point_ids: List[str]):
        """
        Delete specified data points from a collection.

        Parameters:
        -----------

            - collection_name (str): The name of the collection from which to delete data
              points.
            - data_point_ids (List[str]): A list of IDs of the data points to delete.
        """
        raise NotImplementedError

    @abstractmethod
    async def prune(self):
        """
        Remove obsolete or unnecessary data from the database.
        """
        raise NotImplementedError

    @abstractmethod
    async def embed_data(self, data: List[str]) -> List[List[float]]:
        """
        Embed textual data into vector representations.

        Parameters:
        -----------

            - data (List[str]): A list of strings to be embedded.

        Returns:
        --------

            - List[List[float]]: A list of embedded vectors corresponding to the input data.
        """
        raise NotImplementedError

    # ------------------------------------------------------------------ #
    # Optional methods that may be implemented by adapters
    # ------------------------------------------------------------------ #

    async def get_connection(self):
        """
        Get a connection to the vector database.

        Optional: the default implementation returns None, for adapters that don't use
        connections.
        """
        return None

    async def get_collection(self, collection_name: str):
        """
        Get a collection object from the vector database.

        Optional: the default implementation returns None, for adapters that don't expose
        collection objects.
        """
        return None

    async def create_vector_index(self, index_name: str, index_property_name: str):
        """
        Create a vector index for improved search performance.

        Optional: the default implementation is a no-op, for adapters that don't support
        indexing.
        """
        pass

    async def index_data_points(
        self, index_name: str, index_property_name: str, data_points: List[DataPoint]
    ):
        """
        Index data points for improved search performance.

        Optional: the default implementation is a no-op, for adapters that don't support
        separate indexing.

        Parameters:
        -----------

            - index_name (str): Name of the index to create/update.
            - index_property_name (str): Property name to index on.
            - data_points (List[DataPoint]): Data points to index.
        """
        pass

    def get_data_point_schema(self, model_type: Any) -> Any:
        """
        Get or transform a data point schema for the specific vector database.

        Optional: the default implementation returns ``model_type`` unchanged.

        Parameters:
        -----------

            - model_type (Any): The model type to get the schema for.

        Returns:
        --------

            - Any: The schema object suitable for this vector database.
        """
        return model_type

    @classmethod
    async def create_database(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        """
        Return a dictionary with connection info for a vector database for the given
        dataset and user.

        Implementations should handle deploying the actual database if needed. This is
        required for Cognee multi-tenant / multi-user and backend access control support.

        The returned dictionary is used to create a DatasetDatabase row in the relational
        database, from which the internal dataset -> database connection mapping is built.
        Each dataset needs to map to a unique vector database instance when backend access
        control is enabled.

        Args:
            dataset_id: UUID of the dataset, if needed by the database creation logic.
            user: User object, if needed by the database creation logic.

        Returns:
            dict: Connection info for the created vector database instance.

        NOTE(review): this default body returns None despite the ``-> dict`` annotation;
        adapters that support backend access control must override it. Also note it is a
        classmethod while ``delete_database`` is an instance method — confirm the
        asymmetry is intentional.
        """
        pass

    async def delete_database(self, dataset_id: UUID, user: User) -> None:
        """
        Delete the vector database instance for the given dataset and user.

        Implementations should handle deleting the actual database. This is required for
        maintaining per-dataset databases under Cognee multi-tenant / multi-user and
        backend access control. The default implementation is a no-op.

        Args:
            dataset_id: UUID of the dataset.
            user: User object.
        """
        pass
|