cognee/cognee/modules/graph/cognee_graph/CogneeGraphElements.py
2025-12-18 14:52:35 +01:00

202 lines
7.1 KiB
Python

import numpy as np
from typing import List, Dict, Optional, Any, Union
from cognee.modules.graph.exceptions import InvalidDimensionsError, DimensionOutOfRangeError
class Node:
    """
    Represents a node in a graph.

    Attributes:
        id (str): A unique identifier for the node.
        attributes (Dict[str, Any]): A dictionary of attributes associated with the node.
            It always carries a "vector_distance" entry, which is None until
            per-query distances are initialised.
        skeleton_neighbours (List[Node]): Represents the original nodes connected
            to this node through its skeleton edges.
        skeleton_edges (List[Edge]): Represents the original edges registered on this node.
        status (np.ndarray): Integer array with one entry per dimension; an entry
            equal to 1 means the node is alive in that dimension.
    """

    id: str
    attributes: Dict[str, Any]
    skeleton_neighbours: List["Node"]
    skeleton_edges: List["Edge"]
    status: np.ndarray

    def __init__(
        self,
        node_id: str,
        attributes: Optional[Dict[str, Any]] = None,
        dimension: int = 1,
        node_penalty: float = 3.5,
    ):
        """
        Create a node that is alive in every dimension.

        Args:
            node_id: Unique identifier of the node.
            attributes: Optional attribute dictionary. It is stored by reference
                (not copied) and its "vector_distance" entry is reset to None.
            dimension: Number of status dimensions; must be positive.
            node_penalty: Accepted for interface compatibility; not used here.

        Raises:
            InvalidDimensionsError: If ``dimension`` is not positive.
        """
        if dimension <= 0:
            raise InvalidDimensionsError()
        self.id = node_id
        self.attributes = attributes if attributes is not None else {}
        self.attributes["vector_distance"] = None
        self.skeleton_neighbours = []
        self.skeleton_edges = []
        self.status = np.ones(dimension, dtype=int)

    def reset_vector_distances(self, query_count: int, default_penalty: float) -> None:
        """Replace the stored distances with ``query_count`` copies of ``default_penalty``."""
        self.attributes["vector_distance"] = [default_penalty] * query_count

    def ensure_vector_distance_list(self, query_count: int, default_penalty: float) -> List[float]:
        """
        Return the per-query distance list, (re)creating it when needed.

        The list is rebuilt with ``default_penalty`` values whenever the stored
        value is not a list or its length differs from ``query_count``.
        """
        distances = self.attributes.get("vector_distance")
        if not isinstance(distances, list) or len(distances) != query_count:
            distances = [default_penalty] * query_count
            self.attributes["vector_distance"] = distances
        return distances

    def update_distance_for_query(
        self,
        query_index: int,
        score: float,
        query_count: int,
        default_penalty: float,
    ) -> None:
        """Store ``score`` as the distance for the query at ``query_index``."""
        distances = self.ensure_vector_distance_list(query_count, default_penalty)
        distances[query_index] = score

    def add_skeleton_neighbor(self, neighbor: "Node") -> None:
        """Register ``neighbor`` unless it is already listed."""
        if neighbor not in self.skeleton_neighbours:
            self.skeleton_neighbours.append(neighbor)

    def remove_skeleton_neighbor(self, neighbor: "Node") -> None:
        """Remove ``neighbor`` if it is currently listed."""
        if neighbor in self.skeleton_neighbours:
            self.skeleton_neighbours.remove(neighbor)

    def add_skeleton_edge(self, edge: "Edge") -> None:
        """Register ``edge`` and record the node at its other end as a neighbour."""
        self.skeleton_edges.append(edge)
        # Add neighbor
        if edge.node1 == self:
            self.add_skeleton_neighbor(edge.node2)
        elif edge.node2 == self:
            self.add_skeleton_neighbor(edge.node1)

    def remove_skeleton_edge(self, edge: "Edge") -> None:
        """Unregister ``edge``; drop its far-end neighbour when no other edge reaches it."""
        if edge in self.skeleton_edges:
            self.skeleton_edges.remove(edge)
            # Remove neighbor if no other edge connects them
            neighbor = edge.node2 if edge.node1 == self else edge.node1
            if all(e.node1 != neighbor and e.node2 != neighbor for e in self.skeleton_edges):
                self.remove_skeleton_neighbor(neighbor)

    def is_node_alive_in_dimension(self, dimension: int) -> bool:
        """
        Tell whether the node's status is 1 in ``dimension``.

        Raises:
            DimensionOutOfRangeError: If ``dimension`` is negative or not below
                the number of status entries.
        """
        if dimension < 0 or dimension >= len(self.status):
            raise DimensionOutOfRangeError(dimension=dimension, max_index=len(self.status) - 1)
        # Comparing a NumPy element yields numpy.bool_; convert so the declared
        # ``bool`` return type actually holds (identity checks, JSON encoding).
        return bool(self.status[dimension] == 1)

    def add_attribute(self, key: str, value: Any) -> None:
        """Set attribute ``key`` to ``value``, overwriting any previous value."""
        self.attributes[key] = value

    def get_attribute(self, key: str) -> Union[str, int, float]:
        """Return the attribute stored under ``key``; raises KeyError when absent."""
        return self.attributes[key]

    def get_skeleton_edges(self) -> List["Edge"]:
        """Return the list of registered skeleton edges (the live list, not a copy)."""
        return self.skeleton_edges

    def get_skeleton_neighbours(self) -> List["Node"]:
        """Return the list of skeleton neighbours (the live list, not a copy)."""
        return self.skeleton_neighbours

    def __repr__(self) -> str:
        return f"Node({self.id}, attributes={self.attributes})"

    def __hash__(self) -> int:
        # Hash on the id only, matching __eq__.
        return hash(self.id)

    def __eq__(self, other: object) -> bool:
        # Two nodes are equal when they share an id; attributes are ignored.
        return isinstance(other, Node) and self.id == other.id
class Edge:
    """
    Represents an edge in a graph, connecting two nodes.

    Attributes:
        node1 (Node): The starting node of the edge.
        node2 (Node): The ending node of the edge.
        attributes (Dict[str, Any]): A dictionary of attributes associated with the edge.
            It always carries a "vector_distance" entry, which is None until
            per-query distances are initialised.
        directed (bool): A flag indicating whether the edge is directed or undirected.
        status (np.ndarray): Integer array with one entry per dimension; an entry
            equal to 1 means the edge is alive in that dimension.
    """

    node1: "Node"
    node2: "Node"
    attributes: Dict[str, Any]
    directed: bool
    status: np.ndarray

    def __init__(
        self,
        node1: "Node",
        node2: "Node",
        attributes: Optional[Dict[str, Any]] = None,
        directed: bool = True,
        dimension: int = 1,
        edge_penalty: float = 3.5,
    ):
        """
        Create an edge that is alive in every dimension.

        Args:
            node1: Starting node.
            node2: Ending node.
            attributes: Optional attribute dictionary. It is stored by reference
                (not copied) and its "vector_distance" entry is reset to None.
            directed: Whether the edge is directed.
            dimension: Number of status dimensions; must be positive.
            edge_penalty: Accepted for interface compatibility; not used here.

        Raises:
            InvalidDimensionsError: If ``dimension`` is not positive.
        """
        if dimension <= 0:
            raise InvalidDimensionsError()
        self.node1 = node1
        self.node2 = node2
        self.attributes = attributes if attributes is not None else {}
        self.attributes["vector_distance"] = None
        self.directed = directed
        self.status = np.ones(dimension, dtype=int)

    def get_distance_key(self) -> Optional[str]:
        """
        Return the edge's "edge_text" attribute as a string, falling back to
        "relationship_type" when "edge_text" is missing or falsy.

        Returns None when the selected value is None.
        """
        key = self.attributes.get("edge_text") or self.attributes.get("relationship_type")
        if key is None:
            return None
        return str(key)

    def reset_vector_distances(self, query_count: int, default_penalty: float) -> None:
        """Replace the stored distances with ``query_count`` copies of ``default_penalty``."""
        self.attributes["vector_distance"] = [default_penalty] * query_count

    def ensure_vector_distance_list(self, query_count: int, default_penalty: float) -> List[float]:
        """
        Return the per-query distance list, (re)creating it when needed.

        The list is rebuilt with ``default_penalty`` values whenever the stored
        value is not a list or its length differs from ``query_count``.
        """
        distances = self.attributes.get("vector_distance")
        if not isinstance(distances, list) or len(distances) != query_count:
            distances = [default_penalty] * query_count
            self.attributes["vector_distance"] = distances
        return distances

    def update_distance_for_query(
        self,
        query_index: int,
        score: float,
        query_count: int,
        default_penalty: float,
    ) -> None:
        """Store ``score`` as the distance for the query at ``query_index``."""
        distances = self.ensure_vector_distance_list(query_count, default_penalty)
        distances[query_index] = score

    def is_edge_alive_in_dimension(self, dimension: int) -> bool:
        """
        Tell whether the edge's status is 1 in ``dimension``.

        Raises:
            DimensionOutOfRangeError: If ``dimension`` is negative or not below
                the number of status entries.
        """
        if dimension < 0 or dimension >= len(self.status):
            raise DimensionOutOfRangeError(dimension=dimension, max_index=len(self.status) - 1)
        # Comparing a NumPy element yields numpy.bool_; convert so the declared
        # ``bool`` return type actually holds (identity checks, JSON encoding).
        return bool(self.status[dimension] == 1)

    def add_attribute(self, key: str, value: Any) -> None:
        """Set attribute ``key`` to ``value``, overwriting any previous value."""
        self.attributes[key] = value

    def get_attribute(self, key: str) -> Optional[Union[str, int, float]]:
        """Return the attribute stored under ``key``, or None when absent."""
        return self.attributes.get(key)

    def get_source_node(self) -> "Node":
        """Return the edge's first node (``node1``)."""
        return self.node1

    def get_destination_node(self) -> "Node":
        """Return the edge's second node (``node2``)."""
        return self.node2

    def __repr__(self) -> str:
        direction = "->" if self.directed else "--"
        return f"Edge({self.node1.id} {direction} {self.node2.id}, attributes={self.attributes})"

    def __hash__(self) -> int:
        # Directed edges hash on the ordered endpoint pair, undirected ones on
        # the unordered pair, mirroring the comparison done in __eq__.
        if self.directed:
            return hash((self.node1, self.node2))
        else:
            return hash(frozenset({self.node1, self.node2}))

    def __eq__(self, other: object) -> bool:
        """
        Compare edges by endpoints; attributes are ignored.

        A directed and an undirected edge are never equal. Without this check
        equality was asymmetric and equal edges could hash differently, which
        corrupts set and dict lookups.
        """
        if not isinstance(other, Edge):
            return False
        if bool(self.directed) != bool(other.directed):
            return False
        if self.directed:
            return self.node1 == other.node1 and self.node2 == other.node2
        return {self.node1, self.node2} == {other.node1, other.node2}