diff --git a/cognee/infrastructure/databases/graph/graph_db_interface.py b/cognee/infrastructure/databases/graph/graph_db_interface.py index 16600b386..65afdf275 100644 --- a/cognee/infrastructure/databases/graph/graph_db_interface.py +++ b/cognee/infrastructure/databases/graph/graph_db_interface.py @@ -2,7 +2,7 @@ import inspect from functools import wraps from abc import abstractmethod, ABC from datetime import datetime, timezone -from typing import Optional, Dict, Any, List, Tuple, Type +from typing import Optional, Dict, Any, List, Tuple, Type, Union from uuid import NAMESPACE_OID, UUID, uuid5 from cognee.shared.logging_utils import get_logger from cognee.infrastructure.engine import DataPoint @@ -173,28 +173,31 @@ class GraphDBInterface(ABC): raise NotImplementedError @abstractmethod - async def add_node(self, node_id: str, properties: Dict[str, Any]) -> None: + async def add_node( + self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None + ) -> None: """ Add a single node with specified properties to the graph. Parameters: ----------- - - node_id (str): Unique identifier for the node being added. - - properties (Dict[str, Any]): A dictionary of properties associated with the node. + - node (Union[DataPoint, str]): Either a DataPoint object or a string identifier for the node being added. + - properties (Optional[Dict[str, Any]]): A dictionary of properties associated with the node. + Required when node is a string, ignored when node is a DataPoint. """ raise NotImplementedError @abstractmethod @record_graph_changes - async def add_nodes(self, nodes: List[Node]) -> None: + async def add_nodes(self, nodes: Union[List[Node], List[DataPoint]]) -> None: """ Add multiple nodes to the graph in a single operation. Parameters: ----------- - - nodes (List[Node]): A list of Node objects to be added to the graph. + - nodes (Union[List[Node], List[DataPoint]]): A list of Node objects or DataPoint objects to be added to the graph. """ raise NotImplementedError @@ -271,14 +274,16 @@ class GraphDBInterface(ABC): @abstractmethod @record_graph_changes - async def add_edges(self, edges: List[EdgeData]) -> None: + async def add_edges( + self, edges: Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]] + ) -> None: """ Add multiple edges to the graph in a single operation. Parameters: ----------- - - edges (List[EdgeData]): A list of EdgeData objects representing edges to be added. + - edges (Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]]): A list of EdgeData objects or tuples representing edges to be added. """ raise NotImplementedError @@ -377,7 +382,7 @@ class GraphDBInterface(ABC): @abstractmethod async def get_connections( - self, node_id: str + self, node_id: Union[str, UUID] ) -> List[Tuple[NodeData, Dict[str, Any], NodeData]]: """ Get all nodes connected to a specified node and their relationship details. @@ -385,6 +390,6 @@ class GraphDBInterface(ABC): Parameters: ----------- - - node_id (str): Unique identifier of the node for which to retrieve connections. + - node_id (Union[str, UUID]): Unique identifier of the node for which to retrieve connections. """ raise NotImplementedError diff --git a/cognee/infrastructure/databases/vector/vector_db_interface.py b/cognee/infrastructure/databases/vector/vector_db_interface.py index 93c19adec..96b6bbd6f 100644 --- a/cognee/infrastructure/databases/vector/vector_db_interface.py +++ b/cognee/infrastructure/databases/vector/vector_db_interface.py @@ -1,4 +1,4 @@ -from typing import List, Protocol, Optional +from typing import List, Protocol, Optional, Union, Any from abc import abstractmethod from cognee.infrastructure.engine import DataPoint from .models.PayloadSchema import PayloadSchema @@ -31,7 +31,7 @@ class VectorDBInterface(Protocol): async def create_collection( self, collection_name: str, - payload_schema: Optional[PayloadSchema] = None, + payload_schema: Optional[Any] = None, ): """ Create a new collection with an optional payload schema. @@ -40,8 +40,8 @@ class VectorDBInterface(Protocol): ----------- - collection_name (str): The name of the new collection to create. - - payload_schema (Optional[PayloadSchema]): An optional schema for the payloads - within this collection. (default None) + - payload_schema (Optional[Any]): An optional schema for the payloads + within this collection. Can be PayloadSchema, BaseModel, or other schema types. (default None) """ raise NotImplementedError @@ -71,7 +71,7 @@ class VectorDBInterface(Protocol): - collection_name (str): The name of the collection from which to retrieve data points. - - data_point_ids (list[str]): A list of IDs of the data points to retrieve. + - data_point_ids (Union[List[str], list[str]]): A list of IDs of the data points to retrieve. """ raise NotImplementedError @@ -123,7 +123,9 @@ class VectorDBInterface(Protocol): raise NotImplementedError @abstractmethod - async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): + async def delete_data_points( + self, collection_name: str, data_point_ids: Union[List[str], list[str]] + ): """ Delete specified data points from a collection. @@ -132,7 +134,7 @@ class VectorDBInterface(Protocol): - collection_name (str): The name of the collection from which to delete data points. - - data_point_ids (list[str]): A list of IDs of the data points to delete. + - data_point_ids (Union[List[str], list[str]]): A list of IDs of the data points to delete. """ raise NotImplementedError @@ -142,3 +144,72 @@ class VectorDBInterface(Protocol): Remove obsolete or unnecessary data from the database. """ raise NotImplementedError + + @abstractmethod + async def embed_data(self, data: List[str]) -> List[List[float]]: + """ + Embed textual data into vector representations. + + Parameters: + ----------- + + - data (List[str]): A list of strings to be embedded. + + Returns: + -------- + + - List[List[float]]: A list of embedded vectors corresponding to the input data. + """ + raise NotImplementedError + + # Optional methods that may be implemented by adapters + async def get_connection(self): + """ + Get a connection to the vector database. + This method is optional and may return None for adapters that don't use connections. + """ + return None + + async def get_collection(self, collection_name: str): + """ + Get a collection object from the vector database. + This method is optional and may return None for adapters that don't expose collection objects. + """ + return None + + async def create_vector_index(self, index_name: str, index_property_name: str): + """ + Create a vector index for improved search performance. + This method is optional and may be a no-op for adapters that don't support indexing. + """ + pass + + async def index_data_points( + self, index_name: str, index_property_name: str, data_points: List[DataPoint] + ): + """ + Index data points for improved search performance. + This method is optional and may be a no-op for adapters that don't support separate indexing. + + Parameters: + ----------- + - index_name (str): Name of the index to create/update + - index_property_name (str): Property name to index on + - data_points (List[DataPoint]): Data points to index + """ + pass + + def get_data_point_schema(self, model_type: Any) -> Any: + """ + Get or transform a data point schema for the specific vector database. + This method is optional and may return the input unchanged for simple adapters. + + Parameters: + ----------- + - model_type (Any): The model type to get schema for + + Returns: + -------- + - Any: The schema object suitable for this vector database + """ + return model_type