cognee/cognee/modules/graph/cognee_graph/CogneeGraph.py
2025-12-17 15:48:24 +01:00

254 lines
9.8 KiB
Python

import heapq
import numbers
import time
from typing import List, Dict, Union, Optional, Type

from cognee.shared.logging_utils import get_logger
from cognee.modules.graph.exceptions import (
    EntityNotFoundError,
    EntityAlreadyExistsError,
    InvalidDimensionsError,
)
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
from cognee.modules.graph.cognee_graph.CogneeAbstractGraph import CogneeAbstractGraph

logger = get_logger("CogneeGraph")


class CogneeGraph(CogneeAbstractGraph):
    """
    Concrete implementation of the AbstractGraph class for Cognee.

    This class provides the functionality to manage nodes and edges,
    and project a graph from a database using adapters.

    Attributes:
        nodes: Mapping of node id -> Node.
        edges: Every edge added to the graph, in insertion order.
        directed: Directedness flag given at construction time.
        triplet_distance_penalty: Distance substituted for a node or edge that has no
            usable "vector_distance" attribute when triplets are scored.
    """

    nodes: Dict[str, Node]
    edges: List[Edge]
    directed: bool
    triplet_distance_penalty: float

    def __init__(self, directed: bool = True):
        self.nodes = {}
        self.edges = []
        self.directed = directed
        # Default penalty; project_graph_from_db overwrites it with its own argument.
        self.triplet_distance_penalty = 3.5

    def add_node(self, node: Node) -> None:
        """Register ``node`` under its id.

        Raises:
            EntityAlreadyExistsError: if a node with the same id is already present.
        """
        if node.id in self.nodes:
            raise EntityAlreadyExistsError(message=f"Node with id {node.id} already exists.")
        self.nodes[node.id] = node

    def add_edge(self, edge: Edge) -> None:
        """Append ``edge`` to the graph and attach it to both endpoint nodes.

        This is where skeleton edges get attached to nodes; callers should not call
        ``add_skeleton_edge`` again for the same edge or it will be registered twice.
        """
        self.edges.append(edge)
        edge.node1.add_skeleton_edge(edge)
        edge.node2.add_skeleton_edge(edge)

    def get_node(self, node_id: str) -> Optional[Node]:
        """Return the node with ``node_id``, or ``None`` if it is not in the graph."""
        return self.nodes.get(node_id)

    def get_edges_from_node(self, node_id: str) -> List[Edge]:
        """Return the skeleton edges attached to the node ``node_id``.

        Raises:
            EntityNotFoundError: if no node with that id exists.
        """
        node = self.get_node(node_id)
        if node is None:
            raise EntityNotFoundError(message=f"Node with id {node_id} does not exist.")
        return node.skeleton_edges

    def get_edges(self) -> List[Edge]:
        """Return the graph's edge list (the internal list itself, not a copy)."""
        return self.edges

    async def _get_nodeset_subgraph(
        self,
        adapter,
        node_type,
        node_name,
    ):
        """Retrieve subgraph based on node type and name.

        Raises:
            EntityNotFoundError: if the adapter returns no nodes or no edges.
        """
        logger.info("Retrieving graph filtered by node type and node name (NodeSet).")
        nodes_data, edges_data = await adapter.get_nodeset_subgraph(
            node_type=node_type, node_name=node_name
        )
        if not nodes_data or not edges_data:
            raise EntityNotFoundError(
                message="Nodeset does not exist, or empty nodeset projected from the database."
            )
        return nodes_data, edges_data

    async def _get_full_or_id_filtered_graph(
        self,
        adapter,
        relevant_ids_to_filter,
    ):
        """Retrieve full or ID-filtered graph with fallback.

        With ``relevant_ids_to_filter`` set to None the full graph is fetched. Otherwise
        ``adapter.get_id_filtered_graph_data(target_ids=...)`` is used when the adapter's
        class defines it (else the full graph is fetched). If the filtered call yields no
        nodes or no edges, the full graph is fetched as a fallback.

        Raises:
            EntityNotFoundError: if the final result has no nodes or no edges.
        """
        if relevant_ids_to_filter is None:
            logger.info("Retrieving full graph.")
            nodes_data, edges_data = await adapter.get_graph_data()
            if not nodes_data or not edges_data:
                raise EntityNotFoundError(message="Empty graph projected from the database.")
            return nodes_data, edges_data

        get_graph_data_fn = getattr(adapter, "get_id_filtered_graph_data", adapter.get_graph_data)
        # NOTE(review): support is checked on the class rather than the instance,
        # presumably so that instance-level attribute magic does not count as support.
        if getattr(adapter.__class__, "get_id_filtered_graph_data", None):
            logger.info("Retrieving ID-filtered graph from database.")
            nodes_data, edges_data = await get_graph_data_fn(target_ids=relevant_ids_to_filter)
        else:
            logger.info("Retrieving full graph from database.")
            nodes_data, edges_data = await get_graph_data_fn()

        if hasattr(adapter, "get_id_filtered_graph_data") and (not nodes_data or not edges_data):
            logger.warning(
                "Id filtered graph returned empty, falling back to full graph retrieval."
            )
            logger.info("Retrieving full graph")
            nodes_data, edges_data = await adapter.get_graph_data()

        if not nodes_data or not edges_data:
            raise EntityNotFoundError(message="Empty graph projected from the database.")
        return nodes_data, edges_data

    async def _get_filtered_graph(
        self,
        adapter,
        memory_fragment_filter,
    ):
        """Retrieve graph filtered by attributes.

        Raises:
            EntityNotFoundError: if the adapter returns no nodes or no edges.
        """
        logger.info("Retrieving graph filtered by memory fragment")
        nodes_data, edges_data = await adapter.get_filtered_graph_data(
            attribute_filters=memory_fragment_filter
        )
        if not nodes_data or not edges_data:
            raise EntityNotFoundError(message="Empty filtered graph projected from the database.")
        return nodes_data, edges_data

    async def project_graph_from_db(
        self,
        adapter: GraphDBInterface,
        node_properties_to_project: List[str],
        edge_properties_to_project: List[str],
        directed=True,
        node_dimension=1,
        edge_dimension=1,
        memory_fragment_filter=None,
        node_type: Optional[Type] = None,
        node_name: Optional[List[str]] = None,
        relevant_ids_to_filter: Optional[List[str]] = None,
        triplet_distance_penalty: float = 3.5,
    ) -> None:
        """Load nodes and edges from ``adapter`` into this graph.

        The data source is chosen in this order:
          1. ``node_type`` given and ``node_name`` not None/[]/"" -> NodeSet subgraph.
          2. ``memory_fragment_filter`` None or empty -> full or ID-filtered graph.
          3. otherwise -> attribute-filtered graph.

        Args:
            adapter: Graph database adapter to read from.
            node_properties_to_project: Property keys copied onto each Node
                (keys missing from the stored properties become None).
            edge_properties_to_project: Property keys copied onto each Edge; the edge's
                relationship type is always stored under "relationship_type".
            directed: Passed to every created Edge (``self.directed`` is not changed).
            node_dimension: Passed to every Node; must be >= 1.
            edge_dimension: Passed to every Edge; must be >= 1.
            memory_fragment_filter: Attribute filters; None or empty means no filtering.
            node_type: NodeSet node type.
            node_name: NodeSet node names.
            relevant_ids_to_filter: Optional ids for ID-filtered retrieval.
            triplet_distance_penalty: Stored on the graph and passed to every Node/Edge.

        Raises:
            InvalidDimensionsError: if ``node_dimension`` or ``edge_dimension`` is < 1.
            EntityNotFoundError: if the projected graph is empty, or an edge references
                a node that was not projected.
            EntityAlreadyExistsError: if a projected node id is already in the graph.
        """
        if node_dimension < 1 or edge_dimension < 1:
            raise InvalidDimensionsError()

        try:
            if node_type is not None and node_name not in [None, [], ""]:
                nodes_data, edges_data = await self._get_nodeset_subgraph(
                    adapter, node_type, node_name
                )
            elif not memory_fragment_filter:
                nodes_data, edges_data = await self._get_full_or_id_filtered_graph(
                    adapter, relevant_ids_to_filter
                )
            else:
                nodes_data, edges_data = await self._get_filtered_graph(
                    adapter, memory_fragment_filter
                )

            self.triplet_distance_penalty = triplet_distance_penalty

            start_time = time.perf_counter()

            # Process nodes
            for node_id, properties in nodes_data:
                node_attributes = {key: properties.get(key) for key in node_properties_to_project}
                self.add_node(
                    Node(
                        str(node_id),
                        node_attributes,
                        dimension=node_dimension,
                        node_penalty=triplet_distance_penalty,
                    )
                )

            # Process edges
            for source_id, target_id, relationship_type, properties in edges_data:
                source_node = self.get_node(str(source_id))
                target_node = self.get_node(str(target_id))
                if source_node is None or target_node is None:
                    raise EntityNotFoundError(
                        message=f"Edge references nonexistent nodes: {source_id} -> {target_id}"
                    )

                edge_attributes = {key: properties.get(key) for key in edge_properties_to_project}
                edge_attributes["relationship_type"] = relationship_type
                # add_edge attaches the edge to both endpoints; attaching it again here
                # would register every edge twice on each node.
                self.add_edge(
                    Edge(
                        source_node,
                        target_node,
                        attributes=edge_attributes,
                        directed=directed,
                        dimension=edge_dimension,
                        edge_penalty=triplet_distance_penalty,
                    )
                )

            # Final statistics
            projection_time = time.perf_counter() - start_time
            logger.info(
                f"Graph projection completed: {len(self.nodes)} nodes, {len(self.edges)} edges in {projection_time:.2f}s"
            )

        except Exception as e:
            logger.error(f"Error during graph projection: {str(e)}")
            raise

    async def map_vector_distances_to_graph_nodes(self, node_distances) -> None:
        """Store vector-search scores on matching nodes as the "vector_distance" attribute.

        Args:
            node_distances: Mapping whose values are iterables of scored results; each
                result is read via ``.id`` and ``.score``. Results whose id (as str) is
                not a node of this graph are ignored.
        """
        mapped_nodes = 0
        for scored_results in node_distances.values():
            for scored_result in scored_results:
                node_id = str(scored_result.id)
                score = scored_result.score
                node = self.get_node(node_id)
                if node is not None:
                    node.add_attribute("vector_distance", score)
                    mapped_nodes += 1
        logger.debug(f"Mapped vector distances onto {mapped_nodes} graph nodes.")

    async def map_vector_distances_to_graph_edges(self, edge_distances) -> None:
        """Store vector-search scores on matching edges as the "vector_distance" attribute.

        Each result is keyed by ``result.payload["text"]``; an edge matches on its
        "edge_text" attribute, falling back to "relationship_type". If several results
        share the same text, the last one wins. ``None`` input is a no-op.

        Raises:
            Any exception raised while reading the results is logged and re-raised.
        """
        try:
            if edge_distances is None:
                return

            embedding_map = {result.payload["text"]: result.score for result in edge_distances}

            for edge in self.edges:
                edge_key = edge.attributes.get("edge_text") or edge.attributes.get(
                    "relationship_type"
                )
                distance = embedding_map.get(edge_key)
                if distance is not None:
                    edge.attributes["vector_distance"] = distance

        except Exception as ex:
            logger.error(f"Error mapping vector distances to edges: {str(ex)}")
            raise

    def _as_distance(self, value: Union[float, List[float], None]) -> float:
        """Normalize distance value to float, handling None, lists, and scalars.

        - None -> ``self.triplet_distance_penalty``
        - non-empty list -> its first element as float
        - any real number (int, float, bool, numpy floating types, ...) -> float(value)
        - anything else (including an empty list) -> ``self.triplet_distance_penalty``
        """
        if value is None:
            return self.triplet_distance_penalty
        if isinstance(value, list) and value:
            return float(value[0])
        # numbers.Real (not just int/float) so e.g. numpy.float32, which is not a
        # float subclass, is used instead of silently becoming the penalty.
        if isinstance(value, numbers.Real):
            return float(value)
        return self.triplet_distance_penalty

    async def calculate_top_triplet_importances(self, k: int) -> List[Edge]:
        """Return the ``k`` edges with the smallest triplet score, best first.

        A triplet's score is the sum of the distances of both endpoint nodes and the edge
        itself (see ``_as_distance``). ``k <= 0`` yields an empty list.
        """

        def triplet_score(edge: Edge) -> float:
            n1 = self._as_distance(edge.node1.attributes.get("vector_distance"))
            n2 = self._as_distance(edge.node2.attributes.get("vector_distance"))
            e = self._as_distance(edge.attributes.get("vector_distance"))
            return n1 + n2 + e

        return heapq.nsmallest(k, self.edges, key=triplet_score)