# File viewer metadata (extraction residue): 400 lines, 13 KiB, Python
import inspect
|
|
from functools import wraps
|
|
from abc import abstractmethod, ABC
|
|
from datetime import datetime, timezone
|
|
from typing import Optional, Dict, Any, List, Tuple, Type, Union
|
|
from uuid import NAMESPACE_OID, UUID, uuid5
|
|
from cognee.shared.logging_utils import get_logger
|
|
from cognee.infrastructure.engine import DataPoint
|
|
from cognee.modules.data.models.graph_relationship_ledger import GraphRelationshipLedger
|
|
from cognee.infrastructure.databases.relational.get_relational_engine import get_relational_engine
|
|
|
|
# Module-level logger: used by record_graph_changes for ledger/commit failures (debug level)
# and by GraphDBInterface.is_empty for its "not implemented" warning.
logger = get_logger()
|
|
|
|
# Shorthand types used in the graph interface signatures.

# Property mapping of a single node.
NodeData = Dict[str, Any]

# (node_id, properties)
Node = Tuple[str, NodeData]

# (source_id, target_id, relationship_name, properties)
EdgeData = Tuple[str, str, str, Dict[str, Any]]
|
|
|
|
|
|
def _describe_caller(frame):
    """
    Build a label naming the code that invoked a recorded graph operation.

    Starting at ``frame``, walks the ``f_back`` chain and selects the first parent frame
    whose code object is not named "wrapper", so decorator wrappers are skipped.

    Parameters:
    -----------

        - frame: Frame to start from (the wrapper's own frame). May be None, which is what
          ``inspect.currentframe()`` returns on interpreters without frame support.

    Returns:
    --------

        "<ClassName>.<function>" when the selected frame has a local named ``self``, the
        bare function name when it does not, or "unknown" when no frame qualifies.
    """
    while frame is not None:
        outer = frame.f_back
        if outer is not None and outer.f_code.co_name != "wrapper":
            function_name = outer.f_code.co_name
            owner = outer.f_locals.get("self")
            # Identity check, not truthiness: an instance that defines __bool__/__len__
            # may be falsy (or refuse to be tested) and must still contribute its class.
            if owner is not None:
                return f"{owner.__class__.__name__}.{function_name}"
            return function_name
        frame = outer
    # The chain was exhausted without finding a caller; keep recording with a placeholder
    # instead of failing on an unbound variable.
    return "unknown"


def _ledger_id(batch_stamp, position, source_id, target_id, creator_function):
    """
    Derive the primary key for one ledger row.

    A timestamp alone is not a safe seed: rows built within the same clock tick would get
    identical ids. The position inside the batch and the row's own fields are mixed in so
    every row of a batch gets a distinct name-based UUID.
    """
    seed = f"{batch_stamp}:{position}:{source_id}:{target_id}:{creator_function}"
    return uuid5(NAMESPACE_OID, seed)


def _node_ledger_entries(nodes, creator):
    """
    Create one GraphRelationshipLedger row per node.

    Each row uses the node id as both source and destination, labels the creator as
    "<creator>.node" and stores the node's ``name`` (or its id when the name is missing
    or empty) as ``node_label``. Raises whatever ``UUID(...)`` or attribute access raises
    for items that do not expose a UUID-compatible ``id``.
    """
    batch_stamp = datetime.now(timezone.utc).timestamp()
    creator_function = f"{creator}.node"
    entries = []
    for position, node in enumerate(nodes):
        node_id = UUID(str(node.id))
        entries.append(
            GraphRelationshipLedger(
                id=_ledger_id(batch_stamp, position, node_id, node_id, creator_function),
                source_node_id=node_id,
                destination_node_id=node_id,
                creator_function=creator_function,
                node_label=getattr(node, "name", None) or str(node.id),
            )
        )
    return entries


def _edge_ledger_entries(edges, creator):
    """
    Create one GraphRelationshipLedger row per edge.

    Only the first three items of each edge are read: source id, target id and relationship
    name. The creator is labelled "<creator>.<relationship name>". Raises whatever
    ``UUID(...)`` or indexing raises for malformed edges.
    """
    batch_stamp = datetime.now(timezone.utc).timestamp()
    entries = []
    for position, edge in enumerate(edges):
        source_id = UUID(str(edge[0]))
        target_id = UUID(str(edge[1]))
        creator_function = f"{creator}.{str(edge[2])}"
        entries.append(
            GraphRelationshipLedger(
                id=_ledger_id(batch_stamp, position, source_id, target_id, creator_function),
                source_node_id=source_id,
                destination_node_id=target_id,
                creator_function=creator_function,
            )
        )
    return entries


def record_graph_changes(func):
    """
    Decorator to record graph changes in the relationship database.

    Parameters:
    -----------

        - func: The asynchronous method to wrap. Ledger rows are written only when its
          ``__name__`` is "add_nodes" or "add_edges"; any other function is awaited and
          its result returned without adding rows.

    Returns:
    --------

        Returns the wrapped function that manages database relationships.
    """

    @wraps(func)
    async def wrapper(self, *args, **kwargs):
        """
        Await the wrapped method, then write ledger rows describing what it added.

        The caller's function (and class, when available) is captured before the wrapped
        method runs and stored in each row's ``creator_function``. Failures while building,
        adding or flushing rows are logged at debug level and rolled back; commit failures
        are logged at debug level. Neither replaces the wrapped method's result.

        Parameters:
        -----------

            - *args: Positional arguments passed to the wrapped function. The first one is
              taken as the node/edge batch.
            - **kwargs: Keyword arguments passed to the wrapped function. ``nodes`` or
              ``edges`` is used as the batch when no positional argument was given.

        Returns:
        --------

            Returns the result of the wrapped function call.
        """
        db_engine = get_relational_engine()
        # Resolved before the first await, while the calling frame is still on the stack.
        creator = _describe_caller(inspect.currentframe())

        result = await func(self, *args, **kwargs)

        if func.__name__ == "add_nodes":
            build_entries, keyword = _node_ledger_entries, "nodes"
        elif func.__name__ == "add_edges":
            build_entries, keyword = _edge_ledger_entries, "edges"
        else:
            build_entries, keyword = None, None

        async with db_engine.get_async_session() as session:
            if build_entries is not None:
                try:
                    # The batch may arrive positionally or by keyword.
                    payload = args[0] if args else kwargs.get(keyword)
                    if payload is None:
                        payload = []
                    # Row construction stays inside the try: the graph write has already
                    # succeeded, so a malformed id must not surface as a failure of it.
                    session.add_all(build_entries(payload, creator))
                    await session.flush()
                except Exception as e:
                    logger.debug(f"Error adding relationship: {e}")
                    await session.rollback()

            try:
                await session.commit()
            except Exception as e:
                logger.debug(f"Error committing session: {e}")

        return result

    return wrapper
|
|
|
|
|
|
class GraphDBInterface(ABC):
    """
    Abstract contract that every concrete graph database adapter has to fulfil.

    Public methods include:
    - is_empty
    - query
    - add_node
    - add_nodes
    - delete_node
    - delete_nodes
    - get_node
    - get_nodes
    - add_edge
    - add_edges
    - delete_graph
    - get_graph_data
    - get_graph_metrics
    - has_edge
    - has_edges
    - get_edges
    - get_neighbors
    - get_nodeset_subgraph
    - get_connections
    """

    @abstractmethod
    async def is_empty(self) -> bool:
        """
        Fallback body for the emptiness check.

        The method is abstract, so subclasses must override it. This body only runs when
        invoked explicitly (e.g. through ``super()``); it logs a warning and returns True.
        """
        logger.warning("is_empty() is not implemented")
        return True

    @abstractmethod
    async def query(self, query: str, params: dict) -> List[Any]:
        """
        Run a raw query against the database and hand back its results.

        Parameters:
        -----------

            - query (str): Query text to execute.
            - params (dict): Parameters to be used in the query.
        """
        raise NotImplementedError()

    @abstractmethod
    async def add_node(
        self,
        node: Union[DataPoint, str],
        properties: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Insert one node, with its properties, into the graph.

        Parameters:
        -----------

            - node (Union[DataPoint, str]): The node to add, given either as a DataPoint
              object or as a string identifier.
            - properties (Optional[Dict[str, Any]]): Properties associated with the node.
              Required when node is a string, ignored when node is a DataPoint.
        """
        raise NotImplementedError()

    # record_graph_changes wraps this stub only. An override in a subclass is a new
    # function object and is recorded only if it applies the decorator itself.
    @abstractmethod
    @record_graph_changes
    async def add_nodes(self, nodes: Union[List[Node], List[DataPoint]]) -> None:
        """
        Insert a batch of nodes into the graph in one operation.

        Parameters:
        -----------

            - nodes (Union[List[Node], List[DataPoint]]): The Node or DataPoint objects
              to add.
        """
        raise NotImplementedError()

    @abstractmethod
    async def delete_node(self, node_id: str) -> None:
        """
        Remove the node with the given ID from the graph.

        Parameters:
        -----------

            - node_id (str): Unique identifier of the node to delete.
        """
        raise NotImplementedError()

    @abstractmethod
    async def delete_nodes(self, node_ids: List[str]) -> None:
        """
        Remove several nodes from the graph, selected by identifier.

        Parameters:
        -----------

            - node_ids (List[str]): Unique identifiers of the nodes to delete.
        """
        raise NotImplementedError()

    @abstractmethod
    async def get_node(self, node_id: str) -> Optional[NodeData]:
        """
        Look up one node by its ID.

        Parameters:
        -----------

            - node_id (str): Unique identifier of the node to retrieve.
        """
        raise NotImplementedError()

    @abstractmethod
    async def get_nodes(self, node_ids: List[str]) -> List[NodeData]:
        """
        Look up several nodes by their IDs.

        Parameters:
        -----------

            - node_ids (List[str]): Unique identifiers of the nodes to retrieve.
        """
        raise NotImplementedError()

    @abstractmethod
    async def add_edge(
        self,
        source_id: str,
        target_id: str,
        relationship_name: str,
        properties: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Connect two nodes of the graph with a new edge.

        Parameters:
        -----------

            - source_id (str): Unique identifier of the source node.
            - target_id (str): Unique identifier of the target node.
            - relationship_name (str): Name of the relationship the edge establishes.
            - properties (Optional[Dict[str, Any]]): Optional properties associated with
              the edge. (default None)
        """
        raise NotImplementedError()

    # record_graph_changes wraps this stub only. An override in a subclass is a new
    # function object and is recorded only if it applies the decorator itself.
    @abstractmethod
    @record_graph_changes
    async def add_edges(
        self, edges: Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]]
    ) -> None:
        """
        Insert a batch of edges into the graph in one operation.

        Parameters:
        -----------

            - edges (Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]]):
              The edges to add, as EdgeData objects or equivalent tuples.
        """
        raise NotImplementedError()

    @abstractmethod
    async def delete_graph(self) -> None:
        """
        Drop the whole graph: every node and every edge.
        """
        raise NotImplementedError()

    @abstractmethod
    async def get_graph_data(self) -> Tuple[List[Node], List[EdgeData]]:
        """
        Return all nodes and all edges of the graph.
        """
        raise NotImplementedError()

    @abstractmethod
    async def get_graph_metrics(self, include_optional: bool = False) -> Dict[str, Any]:
        """
        Return metrics and statistics describing the graph.

        Parameters:
        -----------

            - include_optional (bool): Whether optional metrics should be included.
              (default False)
        """
        raise NotImplementedError()

    @abstractmethod
    async def has_edge(self, source_id: str, target_id: str, relationship_name: str) -> bool:
        """
        Check whether an edge exists between two given nodes.

        Parameters:
        -----------

            - source_id (str): Unique identifier of the source node.
            - target_id (str): Unique identifier of the target node.
            - relationship_name (str): Name of the relationship to verify.
        """
        raise NotImplementedError()

    @abstractmethod
    async def has_edges(self, edges: List[EdgeData]) -> List[EdgeData]:
        """
        Check the existence of several edges in the graph at once.

        Parameters:
        -----------

            - edges (List[EdgeData]): The EdgeData objects whose existence is checked.
        """
        raise NotImplementedError()

    @abstractmethod
    async def get_edges(self, node_id: str) -> List[EdgeData]:
        """
        Return every edge connected to the given node.

        Parameters:
        -----------

            - node_id (str): Unique identifier of the node whose edges are retrieved.
        """
        raise NotImplementedError()

    @abstractmethod
    async def get_neighbors(self, node_id: str) -> List[NodeData]:
        """
        Return every node that neighbors the given node.

        Parameters:
        -----------

            - node_id (str): Unique identifier of the node whose neighbors are retrieved.
        """
        raise NotImplementedError()

    @abstractmethod
    async def get_nodeset_subgraph(
        self, node_type: Type[Any], node_name: List[str]
    ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
        """
        Return the subgraph made of a specific set of nodes and their relationships.

        Parameters:
        -----------

            - node_type (Type[Any]): Type of the nodes to include in the subgraph.
            - node_name (List[str]): Names of the nodes to include in the subgraph.
        """
        raise NotImplementedError()

    @abstractmethod
    async def get_connections(
        self, node_id: Union[str, UUID]
    ) -> List[Tuple[NodeData, Dict[str, Any], NodeData]]:
        """
        Return the nodes connected to the given node, with their relationship details.

        Parameters:
        -----------

            - node_id (Union[str, UUID]): Unique identifier of the node whose connections
              are retrieved.
        """
        raise NotImplementedError()
|