fix: Interface fixes (#1206)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.

---------

Co-authored-by: Daulet Amirkhanov <damirkhanov01@gmail.com>
This commit is contained in:
Vasilije 2025-08-08 20:41:33 +02:00 committed by GitHub
parent f65605b575
commit e3b41e0ed4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 93 additions and 17 deletions

View file

@ -2,7 +2,7 @@ import inspect
from functools import wraps
from abc import abstractmethod, ABC
from datetime import datetime, timezone
from typing import Optional, Dict, Any, List, Tuple, Type
from typing import Optional, Dict, Any, List, Tuple, Type, Union
from uuid import NAMESPACE_OID, UUID, uuid5
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.engine import DataPoint
@ -173,28 +173,31 @@ class GraphDBInterface(ABC):
raise NotImplementedError
@abstractmethod
async def add_node(self, node_id: str, properties: Dict[str, Any]) -> None:
async def add_node(
self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None
) -> None:
"""
Add a single node with specified properties to the graph.
Parameters:
-----------
- node_id (str): Unique identifier for the node being added.
- properties (Dict[str, Any]): A dictionary of properties associated with the node.
- node (Union[DataPoint, str]): Either a DataPoint object or a string identifier for the node being added.
- properties (Optional[Dict[str, Any]]): A dictionary of properties associated with the node.
Required when node is a string, ignored when node is a DataPoint.
"""
raise NotImplementedError
@abstractmethod
@record_graph_changes
async def add_nodes(self, nodes: List[Node]) -> None:
async def add_nodes(self, nodes: Union[List[Node], List[DataPoint]]) -> None:
"""
Add multiple nodes to the graph in a single operation.
Parameters:
-----------
- nodes (List[Node]): A list of Node objects to be added to the graph.
- nodes (Union[List[Node], List[DataPoint]]): A list of Node objects or DataPoint objects to be added to the graph.
"""
raise NotImplementedError
@ -271,14 +274,16 @@ class GraphDBInterface(ABC):
@abstractmethod
@record_graph_changes
async def add_edges(self, edges: List[EdgeData]) -> None:
async def add_edges(
self, edges: Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]]
) -> None:
"""
Add multiple edges to the graph in a single operation.
Parameters:
-----------
- edges (List[EdgeData]): A list of EdgeData objects representing edges to be added.
- edges (Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]]): A list of EdgeData objects or tuples representing edges to be added.
"""
raise NotImplementedError
@ -377,7 +382,7 @@ class GraphDBInterface(ABC):
@abstractmethod
async def get_connections(
self, node_id: str
self, node_id: Union[str, UUID]
) -> List[Tuple[NodeData, Dict[str, Any], NodeData]]:
"""
Get all nodes connected to a specified node and their relationship details.
@ -385,6 +390,6 @@ class GraphDBInterface(ABC):
Parameters:
-----------
- node_id (str): Unique identifier of the node for which to retrieve connections.
- node_id (Union[str, UUID]): Unique identifier of the node for which to retrieve connections.
"""
raise NotImplementedError

View file

@ -1,4 +1,4 @@
from typing import List, Protocol, Optional
from typing import List, Protocol, Optional, Union, Any
from abc import abstractmethod
from cognee.infrastructure.engine import DataPoint
from .models.PayloadSchema import PayloadSchema
@ -31,7 +31,7 @@ class VectorDBInterface(Protocol):
async def create_collection(
self,
collection_name: str,
payload_schema: Optional[PayloadSchema] = None,
payload_schema: Optional[Any] = None,
):
"""
Create a new collection with an optional payload schema.
@ -40,8 +40,8 @@ class VectorDBInterface(Protocol):
-----------
- collection_name (str): The name of the new collection to create.
- payload_schema (Optional[PayloadSchema]): An optional schema for the payloads
within this collection. (default None)
- payload_schema (Optional[Any]): An optional schema for the payloads
within this collection. Can be PayloadSchema, BaseModel, or other schema types. (default None)
"""
raise NotImplementedError
@ -71,7 +71,7 @@ class VectorDBInterface(Protocol):
- collection_name (str): The name of the collection from which to retrieve data
points.
- data_point_ids (list[str]): A list of IDs of the data points to retrieve.
- data_point_ids (List[str]): A list of IDs of the data points to retrieve.
"""
raise NotImplementedError
@ -123,7 +123,9 @@ class VectorDBInterface(Protocol):
raise NotImplementedError
@abstractmethod
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
async def delete_data_points(
self, collection_name: str, data_point_ids: Union[List[str], list[str]]
):
"""
Delete specified data points from a collection.
@ -132,7 +134,7 @@ class VectorDBInterface(Protocol):
- collection_name (str): The name of the collection from which to delete data
points.
- data_point_ids (list[str]): A list of IDs of the data points to delete.
- data_point_ids (List[str]): A list of IDs of the data points to delete.
"""
raise NotImplementedError
@ -142,3 +144,72 @@ class VectorDBInterface(Protocol):
Remove obsolete or unnecessary data from the database.
"""
raise NotImplementedError
@abstractmethod
async def embed_data(self, data: List[str]) -> List[List[float]]:
    """
    Turn a batch of text strings into their vector representations.

    This interface-level definition carries no implementation: it only
    raises NotImplementedError.

    Parameters:
    -----------

        - data (List[str]): The strings to embed.

    Returns:
    --------

        - List[List[float]]: The embedded vectors corresponding to the input strings.
    """
    raise NotImplementedError
# Optional methods that may be implemented by adapters
async def get_connection(self):
    """
    Return a connection to the vector database.

    Optional hook: this default implementation returns None, which is the
    documented result for adapters that do not use connections.
    """
    return None
async def get_collection(self, collection_name: str):
    """
    Return a collection object from the vector database.

    Optional hook: this default implementation ignores its argument and
    returns None, which is the documented result for adapters that do not
    expose collection objects.

    Parameters:
    -----------

        - collection_name (str): The name of the collection to look up.
    """
    return None
async def create_vector_index(self, index_name: str, index_property_name: str):
    """
    Create a vector index to improve search performance.

    Optional hook: this default implementation does nothing and returns
    None, which suits adapters that do not support indexing.

    Parameters:
    -----------

        - index_name (str): Name of the index to create.
        - index_property_name (str): Name of the property the index applies to.
    """
    return None
async def index_data_points(
    self, index_name: str, index_property_name: str, data_points: List[DataPoint]
):
    """
    Index the given data points to improve search performance.

    Optional hook: this default implementation does nothing and returns
    None, which suits adapters that do not support separate indexing.

    Parameters:
    -----------

        - index_name (str): Name of the index to create/update.
        - index_property_name (str): Property name to index on.
        - data_points (List[DataPoint]): Data points to index.
    """
    return None
def get_data_point_schema(self, model_type: Any) -> Any:
    """
    Produce the data point schema to use with this vector database.

    Optional hook: this default implementation hands back `model_type`
    unchanged, which is sufficient for simple adapters that need no
    schema transformation.

    Parameters:
    -----------

        - model_type (Any): The model type to get a schema for.

    Returns:
    --------

        - Any: The schema object suitable for this vector database; here, the
          very same object that was passed in.
    """
    return model_type