cognee/cognee/modules/graph/cognee_graph/CogneeGraph.py
Igor Ilic 1260fc7db0
fix: Add reraising of general exception handling in cognee [COG-1062] (#490)
<!-- .github/pull_request_template.md -->

## Description
Add re-raising of errors in general exception handling 

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **Bug Fixes & Stability Improvements**
  - Enhanced error handling throughout the system, ensuring that issues
  during operations such as server startup, data processing, and graph
  management are properly logged and reported.

- **Refactor**
  - Replaced basic output statements with standardized logging,
  improving traceability and providing better insights for
  troubleshooting.

- **New Features**
  - Search now returns only unique results, improving data consistency
  and the overall user experience.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: holchan <61059652+holchan@users.noreply.github.com>
Co-authored-by: Boris <boris@topoteretes.com>
2025-02-04 10:51:05 +01:00

173 lines
6.7 KiB
Python

import asyncio
import heapq
import logging
from typing import Dict, List, Union

import numpy as np

from cognee.exceptions import InvalidValueError
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
from cognee.modules.graph.cognee_graph.CogneeAbstractGraph import CogneeAbstractGraph
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
from cognee.modules.graph.exceptions import EntityAlreadyExistsError, EntityNotFoundError
logger = logging.getLogger(__name__)


class CogneeGraph(CogneeAbstractGraph):
    """
    Concrete implementation of the AbstractGraph class for Cognee.

    This class provides the functionality to manage nodes and edges,
    and project a graph from a database using adapters.

    Attributes:
        nodes: Mapping of node id to ``Node``.
        edges: Edges added to the graph, in insertion order.
        directed: Directedness flag given at construction. Note that
            ``project_graph_from_db`` takes its own ``directed`` argument for
            the edges it creates and does not read this attribute.
    """

    nodes: Dict[str, Node]
    edges: List[Edge]
    directed: bool

    def __init__(self, directed: bool = True):
        self.nodes = {}
        self.edges = []
        self.directed = directed

    def add_node(self, node: Node) -> None:
        """
        Register ``node`` under its id.

        Raises:
            EntityAlreadyExistsError: If a node with the same id is already present.
        """
        if node.id in self.nodes:
            raise EntityAlreadyExistsError(message=f"Node with id {node.id} already exists.")
        self.nodes[node.id] = node

    def add_edge(self, edge: Edge) -> None:
        """
        Append ``edge`` to the graph and register it on both endpoint nodes.

        An edge that compares equal to one already in the graph is not added
        again; a warning is logged instead.
        """
        # NOTE(review): list membership is O(len(edges)), making bulk loads
        # quadratic; switching to a set requires Edge to be hashable — confirm first.
        if edge in self.edges:
            logger.warning("Edge %s already exists in the graph.", edge)
            return
        self.edges.append(edge)
        edge.node1.add_skeleton_edge(edge)
        edge.node2.add_skeleton_edge(edge)

    def get_node(self, node_id: str) -> Union[Node, None]:
        """Return the node stored under ``node_id``, or ``None`` if absent."""
        return self.nodes.get(node_id)

    def get_edges_from_node(self, node_id: str) -> List[Edge]:
        """
        Return the skeleton edges registered on the node with ``node_id``.

        Raises:
            EntityNotFoundError: If no such node exists in the graph.
        """
        node = self.get_node(node_id)
        if not node:
            raise EntityNotFoundError(message=f"Node with id {node_id} does not exist.")
        return node.skeleton_edges

    def get_edges(self) -> List[Edge]:
        """Return the graph's edge list (the internal list, not a copy)."""
        return self.edges

    async def project_graph_from_db(
        self,
        adapter: GraphDBInterface,
        node_properties_to_project: List[str],
        edge_properties_to_project: List[str],
        directed=True,
        node_dimension=1,
        edge_dimension=1,
        memory_fragment_filter=None,
    ) -> None:
        """
        Populate this graph with nodes and edges read from ``adapter``.

        Args:
            adapter: Graph database adapter to read from.
            node_properties_to_project: Property keys copied into each node's
                attributes (keys missing from the properties become ``None``).
            edge_properties_to_project: Property keys copied into each edge's
                attributes; ``relationship_type`` is always added as well.
            directed: Directedness passed to every created ``Edge``.
            node_dimension: Dimension passed to every created ``Node``; must be >= 1.
            edge_dimension: Dimension passed to every created ``Edge``; must be >= 1.
            memory_fragment_filter: Optional attribute filters. When ``None`` or
                empty, the whole graph is fetched with ``get_graph_data``;
                otherwise ``get_filtered_graph_data`` is used.

        Raises:
            InvalidValueError: If either dimension is below 1.
            EntityNotFoundError: If the adapter returns no nodes or no edges, or
                an edge references a node that was not returned.
            EntityAlreadyExistsError: If the adapter returns the same node id twice.
        """
        if node_dimension < 1 or edge_dimension < 1:
            raise InvalidValueError(message="Dimensions must be positive integers")

        try:
            if not memory_fragment_filter:
                nodes_data, edges_data = await adapter.get_graph_data()
            else:
                nodes_data, edges_data = await adapter.get_filtered_graph_data(
                    attribute_filters=memory_fragment_filter
                )

            if not nodes_data:
                raise EntityNotFoundError(message="No node data retrieved from the database.")
            if not edges_data:
                raise EntityNotFoundError(message="No edge data retrieved from the database.")

            for node_id, properties in nodes_data:
                node_attributes = {key: properties.get(key) for key in node_properties_to_project}
                self.add_node(Node(str(node_id), node_attributes, dimension=node_dimension))

            for source_id, target_id, relationship_type, properties in edges_data:
                source_node = self.get_node(str(source_id))
                target_node = self.get_node(str(target_id))
                if not (source_node and target_node):
                    raise EntityNotFoundError(
                        message=f"Edge references nonexistent nodes: {source_id} -> {target_id}"
                    )

                edge_attributes = {key: properties.get(key) for key in edge_properties_to_project}
                edge_attributes["relationship_type"] = relationship_type
                # add_edge registers the edge on both endpoints itself; calling
                # add_skeleton_edge again here would register it twice.
                self.add_edge(
                    Edge(
                        source_node,
                        target_node,
                        attributes=edge_attributes,
                        directed=directed,
                        dimension=edge_dimension,
                    )
                )
        except (ValueError, TypeError) as e:
            logger.error("Error projecting graph: %s", e)
            raise
        except Exception as ex:
            logger.error("Unexpected error: %s", ex)
            raise

    async def map_vector_distances_to_graph_nodes(self, node_distances) -> None:
        """
        Store each scored result's ``score`` as the ``vector_distance``
        attribute of the node whose id equals ``str(result.id)``.

        Args:
            node_distances: Mapping whose values are iterables of scored results
                exposing ``id`` and ``score``; the keys are not used.
        """
        for scored_results in node_distances.values():
            for scored_result in scored_results:
                node_id = str(scored_result.id)
                score = scored_result.score
                node = self.get_node(node_id)
                if node:
                    node.add_attribute("vector_distance", score)
                else:
                    logger.warning("Node with id %s not found in the graph.", node_id)

    async def map_vector_distances_to_graph_edges(self, vector_engine, query) -> None:
        """
        Attach a ``vector_distance`` to every edge whose ``relationship_type``
        has a score in the ``edge_type_relationship_name`` collection.

        The query is embedded first only to check that an embedding can be
        produced. Distances are then fetched with ``query_text=query`` and keyed
        by each result's ``payload["text"]``. Edges with a missing or unknown
        relationship type are skipped with a warning.

        Raises:
            ValueError: If ``embed_data`` yields no embedding for ``query``.
            Exception: Any other error is logged and re-raised unchanged.
        """
        try:
            query_vectors = await vector_engine.embed_data([query])
            # len()-based checks so this works for lists and array-like results alike.
            query_vector = (
                query_vectors[0]
                if query_vectors is not None and len(query_vectors) > 0
                else None
            )
            if query_vector is None or len(query_vector) == 0:
                raise ValueError("Failed to generate query embedding.")

            edge_distances = await vector_engine.get_distance_from_collection_elements(
                "edge_type_relationship_name", query_text=query
            )
            embedding_map = {result.payload["text"]: result.score for result in edge_distances}

            for edge in self.edges:
                relationship_type = edge.attributes.get("relationship_type")
                if not relationship_type or relationship_type not in embedding_map:
                    logger.warning("Edge %s has an unknown or missing relationship type.", edge)
                    continue
                edge.attributes["vector_distance"] = embedding_map[relationship_type]
        except Exception as ex:
            logger.error("Error mapping vector distances to edges: %s", ex)
            raise

    async def calculate_top_triplet_importances(self, k: int) -> List[Edge]:
        """
        Return the ``k`` edges with the smallest summed vector distance, closest first.

        An edge's score is the sum of the ``vector_distance`` attributes of its
        two endpoint nodes and of the edge itself. A missing distance, or an
        endpoint no longer present in the graph, counts as 1. Ties are broken
        in favour of edges added earlier. ``k <= 0`` yields an empty list.
        """

        def total_distance(edge: Edge):
            source_node = self.get_node(edge.node1.id)
            target_node = self.get_node(edge.node2.id)
            source_distance = source_node.attributes.get("vector_distance", 1) if source_node else 1
            target_distance = target_node.attributes.get("vector_distance", 1) if target_node else 1
            edge_distance = edge.attributes.get("vector_distance", 1)
            return source_distance + target_distance + edge_distance

        # nsmallest is stable and returns results in ascending key order, so the
        # most relevant (closest) triplet comes first.
        return heapq.nsmallest(k, self.edges, key=total_distance)