cognee/cognee/infrastructure/databases/vector/vector_db_interface.py
Daniel Molnar b5ebed1f7d
Docstring infrastructure. (#880)
<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
2025-05-28 17:47:31 +02:00

144 lines
4.6 KiB
Python

from typing import List, Protocol, Optional
from abc import abstractmethod
from cognee.infrastructure.engine import DataPoint
from .models.PayloadSchema import PayloadSchema
class VectorDBInterface(Protocol):
    """
    Defines an interface for interacting with a vector database, including operations for
    managing collections and data points.

    Every method is declared as an abstract coroutine whose default body raises
    ``NotImplementedError``; concrete adapters must override each of them.
    """

    # --- Collections ---

    @abstractmethod
    async def has_collection(self, collection_name: str) -> bool:
        """
        Check if a specified collection exists.

        Parameters:
        -----------

            - collection_name (str): The name of the collection to check for existence.

        Returns:
        --------

            - bool: True if the collection exists, otherwise False.
        """
        raise NotImplementedError

    @abstractmethod
    async def create_collection(
        self,
        collection_name: str,
        payload_schema: Optional[PayloadSchema] = None,
    ):
        """
        Create a new collection with an optional payload schema.

        Parameters:
        -----------

            - collection_name (str): The name of the new collection to create.
            - payload_schema (Optional[PayloadSchema]): An optional schema for the payloads
              within this collection. (default None)
        """
        raise NotImplementedError

    # --- Data points ---

    @abstractmethod
    async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
        """
        Insert new data points into the specified collection.

        Parameters:
        -----------

            - collection_name (str): The name of the collection where data points will be added.
            - data_points (List[DataPoint]): A list of data points to be added to the
              collection.
        """
        raise NotImplementedError

    @abstractmethod
    async def retrieve(self, collection_name: str, data_point_ids: List[str]):
        """
        Retrieve data points from a collection using their IDs.

        Parameters:
        -----------

            - collection_name (str): The name of the collection from which to retrieve data
              points.
            - data_point_ids (List[str]): A list of IDs of the data points to retrieve.
        """
        raise NotImplementedError

    # --- Search ---

    @abstractmethod
    async def search(
        self,
        collection_name: str,
        query_text: Optional[str],
        query_vector: Optional[List[float]],
        limit: int,
        with_vector: bool = False,
    ):
        """
        Perform a search in the specified collection using either a text query or a vector
        query.

        Both ``query_text`` and ``query_vector`` are required positional parameters (they
        have no default); pass ``None`` for the one that is not used.

        Parameters:
        -----------

            - collection_name (str): The name of the collection in which to perform the search.
            - query_text (Optional[str]): An optional text query to search for in the
              collection.
            - query_vector (Optional[List[float]]): An optional vector representation for
              searching the collection.
            - limit (int): The maximum number of results to return from the search.
            - with_vector (bool): Whether to return the vector representations with search
              results. (default False)
        """
        raise NotImplementedError

    @abstractmethod
    async def batch_search(
        self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False
    ):
        """
        Perform a batch search using multiple text queries against a collection.

        Note that the flag is spelled ``with_vectors`` here, whereas ``search`` names the
        equivalent flag ``with_vector``; keyword callers must use the matching spelling.

        Parameters:
        -----------

            - collection_name (str): The name of the collection to conduct the batch search in.
            - query_texts (List[str]): A list of text queries to use for the search.
            - limit (int): The maximum number of results to return for each query.
            - with_vectors (bool): Whether to include vector representations with search
              results. (default False)
        """
        raise NotImplementedError

    # --- Deletion and maintenance ---

    @abstractmethod
    async def delete_data_points(self, collection_name: str, data_point_ids: List[str]):
        """
        Delete specified data points from a collection.

        Parameters:
        -----------

            - collection_name (str): The name of the collection from which to delete data
              points.
            - data_point_ids (List[str]): A list of IDs of the data points to delete.
        """
        raise NotImplementedError

    @abstractmethod
    async def prune(self):
        """
        Remove obsolete or unnecessary data from the database.
        """
        raise NotImplementedError