fix: Interface fixes (#1206)
<!-- .github/pull_request_template.md --> ## Description <!-- Provide a clear description of the changes in this PR --> ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. --------- Co-authored-by: Daulet Amirkhanov <damirkhanov01@gmail.com>
This commit is contained in:
parent
f65605b575
commit
e3b41e0ed4
2 changed files with 93 additions and 17 deletions
|
|
@ -2,7 +2,7 @@ import inspect
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from abc import abstractmethod, ABC
|
from abc import abstractmethod, ABC
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Optional, Dict, Any, List, Tuple, Type
|
from typing import Optional, Dict, Any, List, Tuple, Type, Union
|
||||||
from uuid import NAMESPACE_OID, UUID, uuid5
|
from uuid import NAMESPACE_OID, UUID, uuid5
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
|
@ -173,28 +173,31 @@ class GraphDBInterface(ABC):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def add_node(self, node_id: str, properties: Dict[str, Any]) -> None:
|
async def add_node(
|
||||||
|
self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Add a single node with specified properties to the graph.
|
Add a single node with specified properties to the graph.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
-----------
|
-----------
|
||||||
|
|
||||||
- node_id (str): Unique identifier for the node being added.
|
- node (Union[DataPoint, str]): Either a DataPoint object or a string identifier for the node being added.
|
||||||
- properties (Dict[str, Any]): A dictionary of properties associated with the node.
|
- properties (Optional[Dict[str, Any]]): A dictionary of properties associated with the node.
|
||||||
|
Required when node is a string, ignored when node is a DataPoint.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@record_graph_changes
|
@record_graph_changes
|
||||||
async def add_nodes(self, nodes: List[Node]) -> None:
|
async def add_nodes(self, nodes: Union[List[Node], List[DataPoint]]) -> None:
|
||||||
"""
|
"""
|
||||||
Add multiple nodes to the graph in a single operation.
|
Add multiple nodes to the graph in a single operation.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
-----------
|
-----------
|
||||||
|
|
||||||
- nodes (List[Node]): A list of Node objects to be added to the graph.
|
- nodes (Union[List[Node], List[DataPoint]]): A list of Node objects or DataPoint objects to be added to the graph.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
@ -271,14 +274,16 @@ class GraphDBInterface(ABC):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@record_graph_changes
|
@record_graph_changes
|
||||||
async def add_edges(self, edges: List[EdgeData]) -> None:
|
async def add_edges(
|
||||||
|
self, edges: Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]]
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Add multiple edges to the graph in a single operation.
|
Add multiple edges to the graph in a single operation.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
-----------
|
-----------
|
||||||
|
|
||||||
- edges (List[EdgeData]): A list of EdgeData objects representing edges to be added.
|
- edges (Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]]): A list of EdgeData objects or tuples representing edges to be added.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
@ -377,7 +382,7 @@ class GraphDBInterface(ABC):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_connections(
|
async def get_connections(
|
||||||
self, node_id: str
|
self, node_id: Union[str, UUID]
|
||||||
) -> List[Tuple[NodeData, Dict[str, Any], NodeData]]:
|
) -> List[Tuple[NodeData, Dict[str, Any], NodeData]]:
|
||||||
"""
|
"""
|
||||||
Get all nodes connected to a specified node and their relationship details.
|
Get all nodes connected to a specified node and their relationship details.
|
||||||
|
|
@ -385,6 +390,6 @@ class GraphDBInterface(ABC):
|
||||||
Parameters:
|
Parameters:
|
||||||
-----------
|
-----------
|
||||||
|
|
||||||
- node_id (str): Unique identifier of the node for which to retrieve connections.
|
- node_id (Union[str, UUID]): Unique identifier of the node for which to retrieve connections.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import List, Protocol, Optional
|
from typing import List, Protocol, Optional, Union, Any
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from .models.PayloadSchema import PayloadSchema
|
from .models.PayloadSchema import PayloadSchema
|
||||||
|
|
@ -31,7 +31,7 @@ class VectorDBInterface(Protocol):
|
||||||
async def create_collection(
|
async def create_collection(
|
||||||
self,
|
self,
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
payload_schema: Optional[PayloadSchema] = None,
|
payload_schema: Optional[Any] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a new collection with an optional payload schema.
|
Create a new collection with an optional payload schema.
|
||||||
|
|
@ -40,8 +40,8 @@ class VectorDBInterface(Protocol):
|
||||||
-----------
|
-----------
|
||||||
|
|
||||||
- collection_name (str): The name of the new collection to create.
|
- collection_name (str): The name of the new collection to create.
|
||||||
- payload_schema (Optional[PayloadSchema]): An optional schema for the payloads
|
- payload_schema (Optional[Any]): An optional schema for the payloads
|
||||||
within this collection. (default None)
|
within this collection. Can be PayloadSchema, BaseModel, or other schema types. (default None)
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
@ -71,7 +71,7 @@ class VectorDBInterface(Protocol):
|
||||||
|
|
||||||
- collection_name (str): The name of the collection from which to retrieve data
|
- collection_name (str): The name of the collection from which to retrieve data
|
||||||
points.
|
points.
|
||||||
- data_point_ids (list[str]): A list of IDs of the data points to retrieve.
|
- data_point_ids (Union[List[str], list[str]]): A list of IDs of the data points to retrieve.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
@ -123,7 +123,9 @@ class VectorDBInterface(Protocol):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
async def delete_data_points(
|
||||||
|
self, collection_name: str, data_point_ids: Union[List[str], list[str]]
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Delete specified data points from a collection.
|
Delete specified data points from a collection.
|
||||||
|
|
||||||
|
|
@ -132,7 +134,7 @@ class VectorDBInterface(Protocol):
|
||||||
|
|
||||||
- collection_name (str): The name of the collection from which to delete data
|
- collection_name (str): The name of the collection from which to delete data
|
||||||
points.
|
points.
|
||||||
- data_point_ids (list[str]): A list of IDs of the data points to delete.
|
- data_point_ids (Union[List[str], list[str]]): A list of IDs of the data points to delete.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
@ -142,3 +144,72 @@ class VectorDBInterface(Protocol):
|
||||||
Remove obsolete or unnecessary data from the database.
|
Remove obsolete or unnecessary data from the database.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def embed_data(self, data: List[str]) -> List[List[float]]:
|
||||||
|
"""
|
||||||
|
Embed textual data into vector representations.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
|
||||||
|
- data (List[str]): A list of strings to be embedded.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
|
||||||
|
- List[List[float]]: A list of embedded vectors corresponding to the input data.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
# Optional methods that may be implemented by adapters
|
||||||
|
async def get_connection(self):
|
||||||
|
"""
|
||||||
|
Get a connection to the vector database.
|
||||||
|
This method is optional and may return None for adapters that don't use connections.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_collection(self, collection_name: str):
|
||||||
|
"""
|
||||||
|
Get a collection object from the vector database.
|
||||||
|
This method is optional and may return None for adapters that don't expose collection objects.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def create_vector_index(self, index_name: str, index_property_name: str):
|
||||||
|
"""
|
||||||
|
Create a vector index for improved search performance.
|
||||||
|
This method is optional and may be a no-op for adapters that don't support indexing.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def index_data_points(
|
||||||
|
self, index_name: str, index_property_name: str, data_points: List[DataPoint]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Index data points for improved search performance.
|
||||||
|
This method is optional and may be a no-op for adapters that don't support separate indexing.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
- index_name (str): Name of the index to create/update
|
||||||
|
- index_property_name (str): Property name to index on
|
||||||
|
- data_points (List[DataPoint]): Data points to index
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_data_point_schema(self, model_type: Any) -> Any:
|
||||||
|
"""
|
||||||
|
Get or transform a data point schema for the specific vector database.
|
||||||
|
This method is optional and may return the input unchanged for simple adapters.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
- model_type (Any): The model type to get schema for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
- Any: The schema object suitable for this vector database
|
||||||
|
"""
|
||||||
|
return model_type
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue