fix: Interface fixes (#1206)
<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

---------

Co-authored-by: Daulet Amirkhanov <damirkhanov01@gmail.com>
This commit is contained in:
parent
f65605b575
commit
e3b41e0ed4
2 changed files with 93 additions and 17 deletions
|
|
@ -2,7 +2,7 @@ import inspect
|
|||
from functools import wraps
|
||||
from abc import abstractmethod, ABC
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Dict, Any, List, Tuple, Type
|
||||
from typing import Optional, Dict, Any, List, Tuple, Type, Union
|
||||
from uuid import NAMESPACE_OID, UUID, uuid5
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
|
@ -173,28 +173,31 @@ class GraphDBInterface(ABC):
|
|||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def add_node(self, node_id: str, properties: Dict[str, Any]) -> None:
|
||||
async def add_node(
|
||||
self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None
|
||||
) -> None:
|
||||
"""
|
||||
Add a single node with specified properties to the graph.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- node_id (str): Unique identifier for the node being added.
|
||||
- properties (Dict[str, Any]): A dictionary of properties associated with the node.
|
||||
- node (Union[DataPoint, str]): Either a DataPoint object or a string identifier for the node being added.
|
||||
- properties (Optional[Dict[str, Any]]): A dictionary of properties associated with the node.
|
||||
Required when node is a string, ignored when node is a DataPoint.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
@record_graph_changes
|
||||
async def add_nodes(self, nodes: List[Node]) -> None:
|
||||
async def add_nodes(self, nodes: Union[List[Node], List[DataPoint]]) -> None:
|
||||
"""
|
||||
Add multiple nodes to the graph in a single operation.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- nodes (List[Node]): A list of Node objects to be added to the graph.
|
||||
- nodes (Union[List[Node], List[DataPoint]]): A list of Node objects or DataPoint objects to be added to the graph.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
@ -271,14 +274,16 @@ class GraphDBInterface(ABC):
|
|||
|
||||
@abstractmethod
|
||||
@record_graph_changes
|
||||
async def add_edges(self, edges: List[EdgeData]) -> None:
|
||||
async def add_edges(
|
||||
self, edges: Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]]
|
||||
) -> None:
|
||||
"""
|
||||
Add multiple edges to the graph in a single operation.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- edges (List[EdgeData]): A list of EdgeData objects representing edges to be added.
|
||||
- edges (Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]]): A list of EdgeData objects or tuples representing edges to be added.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
@ -377,7 +382,7 @@ class GraphDBInterface(ABC):
|
|||
|
||||
@abstractmethod
|
||||
async def get_connections(
|
||||
self, node_id: str
|
||||
self, node_id: Union[str, UUID]
|
||||
) -> List[Tuple[NodeData, Dict[str, Any], NodeData]]:
|
||||
"""
|
||||
Get all nodes connected to a specified node and their relationship details.
|
||||
|
|
@ -385,6 +390,6 @@ class GraphDBInterface(ABC):
|
|||
Parameters:
|
||||
-----------
|
||||
|
||||
- node_id (str): Unique identifier of the node for which to retrieve connections.
|
||||
- node_id (Union[str, UUID]): Unique identifier of the node for which to retrieve connections.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List, Protocol, Optional
|
||||
from typing import List, Protocol, Optional, Union, Any
|
||||
from abc import abstractmethod
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from .models.PayloadSchema import PayloadSchema
|
||||
|
|
@ -31,7 +31,7 @@ class VectorDBInterface(Protocol):
|
|||
async def create_collection(
|
||||
self,
|
||||
collection_name: str,
|
||||
payload_schema: Optional[PayloadSchema] = None,
|
||||
payload_schema: Optional[Any] = None,
|
||||
):
|
||||
"""
|
||||
Create a new collection with an optional payload schema.
|
||||
|
|
@ -40,8 +40,8 @@ class VectorDBInterface(Protocol):
|
|||
-----------
|
||||
|
||||
- collection_name (str): The name of the new collection to create.
|
||||
- payload_schema (Optional[PayloadSchema]): An optional schema for the payloads
|
||||
within this collection. (default None)
|
||||
- payload_schema (Optional[Any]): An optional schema for the payloads
|
||||
within this collection. Can be PayloadSchema, BaseModel, or other schema types. (default None)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
@ -71,7 +71,7 @@ class VectorDBInterface(Protocol):
|
|||
|
||||
- collection_name (str): The name of the collection from which to retrieve data
|
||||
points.
|
||||
- data_point_ids (list[str]): A list of IDs of the data points to retrieve.
|
||||
- data_point_ids (Union[List[str], list[str]]): A list of IDs of the data points to retrieve.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
@ -123,7 +123,9 @@ class VectorDBInterface(Protocol):
|
|||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
||||
async def delete_data_points(
|
||||
self, collection_name: str, data_point_ids: Union[List[str], list[str]]
|
||||
):
|
||||
"""
|
||||
Delete specified data points from a collection.
|
||||
|
||||
|
|
@ -132,7 +134,7 @@ class VectorDBInterface(Protocol):
|
|||
|
||||
- collection_name (str): The name of the collection from which to delete data
|
||||
points.
|
||||
- data_point_ids (list[str]): A list of IDs of the data points to delete.
|
||||
- data_point_ids (Union[List[str], list[str]]): A list of IDs of the data points to delete.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
@ -142,3 +144,72 @@ class VectorDBInterface(Protocol):
|
|||
Remove obsolete or unnecessary data from the database.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def embed_data(self, data: List[str]) -> List[List[float]]:
|
||||
"""
|
||||
Embed textual data into vector representations.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- data (List[str]): A list of strings to be embedded.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- List[List[float]]: A list of embedded vectors corresponding to the input data.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
# Optional methods that may be implemented by adapters
|
||||
async def get_connection(self):
|
||||
"""
|
||||
Get a connection to the vector database.
|
||||
This method is optional and may return None for adapters that don't use connections.
|
||||
"""
|
||||
return None
|
||||
|
||||
async def get_collection(self, collection_name: str):
|
||||
"""
|
||||
Get a collection object from the vector database.
|
||||
This method is optional and may return None for adapters that don't expose collection objects.
|
||||
"""
|
||||
return None
|
||||
|
||||
async def create_vector_index(self, index_name: str, index_property_name: str):
|
||||
"""
|
||||
Create a vector index for improved search performance.
|
||||
This method is optional and may be a no-op for adapters that don't support indexing.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def index_data_points(
|
||||
self, index_name: str, index_property_name: str, data_points: List[DataPoint]
|
||||
):
|
||||
"""
|
||||
Index data points for improved search performance.
|
||||
This method is optional and may be a no-op for adapters that don't support separate indexing.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
- index_name (str): Name of the index to create/update
|
||||
- index_property_name (str): Property name to index on
|
||||
- data_points (List[DataPoint]): Data points to index
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_data_point_schema(self, model_type: Any) -> Any:
|
||||
"""
|
||||
Get or transform a data point schema for the specific vector database.
|
||||
This method is optional and may return the input unchanged for simple adapters.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
- model_type (Any): The model type to get schema for
|
||||
|
||||
Returns:
|
||||
--------
|
||||
- Any: The schema object suitable for this vector database
|
||||
"""
|
||||
return model_type
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue