<!-- .github/pull_request_template.md --> ## Description Delete by document. ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin
195 lines
7 KiB
Python
195 lines
7 KiB
Python
from typing import Protocol, Optional, Dict, Any, List, Tuple
|
|
from abc import abstractmethod, ABC
|
|
from uuid import UUID, uuid5, NAMESPACE_DNS
|
|
from cognee.modules.graph.relationship_manager import create_relationship
|
|
from functools import wraps
|
|
import inspect
|
|
from cognee.modules.data.models.graph_relationship_ledger import GraphRelationshipLedger
|
|
from cognee.infrastructure.databases.relational.get_relational_engine import get_relational_engine
|
|
from cognee.shared.logging_utils import get_logger
|
|
from datetime import datetime, timezone
|
|
|
|
logger = get_logger()
|
|
|
|
# Type aliases for better readability
|
|
NodeData = Dict[str, Any]
|
|
EdgeData = Tuple[
|
|
str, str, str, Dict[str, Any]
|
|
] # (source_id, target_id, relationship_name, properties)
|
|
Node = Tuple[str, NodeData] # (node_id, properties)
|
|
|
|
|
|
def record_graph_changes(func):
|
|
"""Decorator to record graph changes in the relationship database."""
|
|
db_engine = get_relational_engine()
|
|
|
|
@wraps(func)
|
|
async def wrapper(self, *args, **kwargs):
|
|
frame = inspect.currentframe()
|
|
while frame:
|
|
if frame.f_back and frame.f_back.f_code.co_name != "wrapper":
|
|
caller_frame = frame.f_back
|
|
break
|
|
frame = frame.f_back
|
|
|
|
caller_name = caller_frame.f_code.co_name
|
|
caller_class = (
|
|
caller_frame.f_locals.get("self", None).__class__.__name__
|
|
if caller_frame.f_locals.get("self", None)
|
|
else None
|
|
)
|
|
creator = f"{caller_class}.{caller_name}" if caller_class else caller_name
|
|
|
|
result = await func(self, *args, **kwargs)
|
|
|
|
async with db_engine.get_async_session() as session:
|
|
if func.__name__ == "add_nodes":
|
|
nodes = args[0]
|
|
for node in nodes:
|
|
try:
|
|
node_id = (
|
|
UUID(str(node[0])) if isinstance(node, tuple) else UUID(str(node.id))
|
|
)
|
|
relationship = GraphRelationshipLedger(
|
|
id=uuid5(NAMESPACE_DNS, f"{datetime.now(timezone.utc).timestamp()}"),
|
|
source_node_id=node_id,
|
|
destination_node_id=node_id,
|
|
creator_function=f"{creator}.node",
|
|
node_label=node[1].get("type")
|
|
if isinstance(node, tuple)
|
|
else type(node).__name__,
|
|
)
|
|
session.add(relationship)
|
|
await session.flush()
|
|
except Exception as e:
|
|
logger.error(f"Error adding relationship: {e}")
|
|
await session.rollback()
|
|
continue
|
|
|
|
elif func.__name__ == "add_edges":
|
|
edges = args[0]
|
|
for edge in edges:
|
|
try:
|
|
source_id = UUID(str(edge[0]))
|
|
target_id = UUID(str(edge[1]))
|
|
rel_type = str(edge[2])
|
|
relationship = GraphRelationshipLedger(
|
|
id=uuid5(NAMESPACE_DNS, f"{datetime.now(timezone.utc).timestamp()}"),
|
|
source_node_id=source_id,
|
|
destination_node_id=target_id,
|
|
creator_function=f"{creator}.{rel_type}",
|
|
)
|
|
session.add(relationship)
|
|
await session.flush()
|
|
except Exception as e:
|
|
logger.error(f"Error adding relationship: {e}")
|
|
await session.rollback()
|
|
continue
|
|
|
|
try:
|
|
await session.commit()
|
|
except Exception as e:
|
|
logger.error(f"Error committing session: {e}")
|
|
|
|
return result
|
|
|
|
return wrapper
|
|
|
|
|
|
class GraphDBInterface(ABC):
|
|
"""Interface for graph database operations."""
|
|
|
|
@abstractmethod
|
|
async def query(self, query: str, params: dict) -> List[Any]:
|
|
"""Execute a raw query against the database."""
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
async def add_node(self, node_id: str, properties: Dict[str, Any]) -> None:
|
|
"""Add a single node to the graph."""
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
@record_graph_changes
|
|
async def add_nodes(self, nodes: List[Node]) -> None:
|
|
"""Add multiple nodes to the graph."""
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
async def delete_node(self, node_id: str) -> None:
|
|
"""Delete a node from the graph."""
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
async def delete_nodes(self, node_ids: List[str]) -> None:
|
|
"""Delete multiple nodes from the graph."""
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
async def get_node(self, node_id: str) -> Optional[NodeData]:
|
|
"""Get a single node by ID."""
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
async def get_nodes(self, node_ids: List[str]) -> List[NodeData]:
|
|
"""Get multiple nodes by their IDs."""
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
async def add_edge(
|
|
self,
|
|
source_id: str,
|
|
target_id: str,
|
|
relationship_name: str,
|
|
properties: Optional[Dict[str, Any]] = None,
|
|
) -> None:
|
|
"""Add a single edge to the graph."""
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
@record_graph_changes
|
|
async def add_edges(self, edges: List[EdgeData]) -> None:
|
|
"""Add multiple edges to the graph."""
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
async def delete_graph(self) -> None:
|
|
"""Delete the entire graph."""
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
async def get_graph_data(self) -> Tuple[List[Node], List[EdgeData]]:
|
|
"""Get all nodes and edges in the graph."""
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
async def get_graph_metrics(self, include_optional: bool = False) -> Dict[str, Any]:
|
|
"""Get graph metrics and statistics."""
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
async def has_edge(self, source_id: str, target_id: str, relationship_name: str) -> bool:
|
|
"""Check if an edge exists."""
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
async def has_edges(self, edges: List[EdgeData]) -> List[EdgeData]:
|
|
"""Check if multiple edges exist."""
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
async def get_edges(self, node_id: str) -> List[EdgeData]:
|
|
"""Get all edges connected to a node."""
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
async def get_neighbors(self, node_id: str) -> List[NodeData]:
|
|
"""Get all neighboring nodes."""
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
async def get_connections(
|
|
self, node_id: str
|
|
) -> List[Tuple[NodeData, Dict[str, Any], NodeData]]:
|
|
"""Get all nodes connected to a given node with their relationships."""
|
|
raise NotImplementedError
|