feat: Adds in time edge vector similarity calculation and triplet importances

This commit is contained in:
hajdul88 2024-11-20 18:32:03 +01:00
parent 9f557b0c5b
commit 980ae2b22c
3 changed files with 105 additions and 7 deletions

View file

@ -1,9 +1,12 @@
from typing import List, Dict, Union
import numpy as np
from typing import List, Dict, Union
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
from cognee.modules.graph.cognee_graph.CogneeAbstractGraph import CogneeAbstractGraph
from cognee.infrastructure.databases.graph import get_graph_engine
import heapq
from graphistry import edges
class CogneeGraph(CogneeAbstractGraph):
"""
@ -39,13 +42,16 @@ class CogneeGraph(CogneeAbstractGraph):
def get_node(self, node_id: str) -> Node:
return self.nodes.get(node_id, None)
def get_edges(self, node_id: str) -> List[Edge]:
def get_edges_of_node(self, node_id: str) -> List[Edge]:
node = self.get_node(node_id)
if node:
return node.skeleton_edges
else:
raise ValueError(f"Node with id {node_id} does not exist.")
def get_edges(self)-> List[Edge]:
return edges
async def project_graph_from_db(self,
adapter: Union[GraphDBInterface],
node_properties_to_project: List[str],
@ -53,7 +59,7 @@ class CogneeGraph(CogneeAbstractGraph):
directed = True,
node_dimension = 1,
edge_dimension = 1,
memory_fragment_filter = List[Dict[str, List[Union[str, int]]]]) -> None:
memory_fragment_filter = []) -> None:
if node_dimension < 1 or edge_dimension < 1:
raise ValueError("Dimensions must be positive integers")
@ -93,3 +99,81 @@ class CogneeGraph(CogneeAbstractGraph):
print(f"Error projecting graph: {e}")
except Exception as ex:
print(f"Unexpected error: {ex}")
async def map_vector_distances_to_graph_nodes(self, node_distances) -> None:
for category, scored_results in node_distances.items():
for scored_result in scored_results:
node_id = str(scored_result.id)
score = scored_result.score
node =self.get_node(node_id)
if node:
node.add_attribute("vector_distance", score)
else:
print(f"Node with id {node_id} not found in the graph.")
async def map_vector_distances_to_graph_edges(self, vector_engine, query) -> None: # :TODO: When we calculate edge embeddings in vector db change this similarly to node mapping
try:
# Step 1: Generate the query embedding
query_vector = await vector_engine.embed_data([query])
query_vector = query_vector[0]
if query_vector is None or len(query_vector) == 0:
raise ValueError("Failed to generate query embedding.")
# Step 2: Collect all unique relationship types
unique_relationship_types = set()
for edge in self.edges:
relationship_type = edge.attributes.get('relationship_type')
if relationship_type:
unique_relationship_types.add(relationship_type)
# Step 3: Embed all unique relationship types
unique_relationship_types = list(unique_relationship_types)
relationship_type_embeddings = await vector_engine.embed_data(unique_relationship_types)
# Step 4: Map relationship types to their embeddings and calculate distances
embedding_map = {}
for relationship_type, embedding in zip(unique_relationship_types, relationship_type_embeddings):
edge_vector = np.array(embedding)
# Calculate cosine similarity
similarity = np.dot(query_vector, edge_vector) / (
np.linalg.norm(query_vector) * np.linalg.norm(edge_vector)
)
distance = 1 - similarity
# Round the distance to 4 decimal places and store it
embedding_map[relationship_type] = round(distance, 4)
# Step 4: Assign precomputed distances to edges
for edge in self.edges:
relationship_type = edge.attributes.get('relationship_type')
if not relationship_type or relationship_type not in embedding_map:
print(f"Edge {edge} has an unknown or missing relationship type.")
continue
# Assign the precomputed distance
edge.attributes["vector_distance"] = embedding_map[relationship_type]
except Exception as ex:
print(f"Error mapping vector distances to edges: {ex}")
async def calculate_top_triplet_importances(self, k = int) -> List:
min_heap = []
for i, edge in enumerate(self.edges):
source_node = self.get_node(edge.node1.id)
target_node = self.get_node(edge.node2.id)
source_distance = source_node.attributes.get("vector_distance", 0) if source_node else 0
target_distance = target_node.attributes.get("vector_distance", 0) if target_node else 0
edge_distance = edge.attributes.get("vector_distance", 0)
total_distance = source_distance + target_distance + edge_distance
heapq.heappush(min_heap, (-total_distance, i, edge))
if len(min_heap) > k:
heapq.heappop(min_heap)
return [edge for _, _, edge in sorted(min_heap)]

View file

@ -1,5 +1,5 @@
import numpy as np
from typing import List, Dict, Optional, Any
from typing import List, Dict, Optional, Any, Union
class Node:
"""
@ -21,6 +21,7 @@ class Node:
raise ValueError("Dimension must be a positive integer")
self.id = node_id
self.attributes = attributes if attributes is not None else {}
self.attributes["vector_distance"] = float('inf')
self.skeleton_neighbours = []
self.skeleton_edges = []
self.status = np.ones(dimension, dtype=int)
@ -55,6 +56,12 @@ class Node:
raise ValueError(f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.")
return self.status[dimension] == 1
def add_attribute(self, key: str, value: Any) -> None:
self.attributes[key] = value
def get_attribute(self, key: str) -> Union[str, int, float]:
return self.attributes[key]
def __repr__(self) -> str:
return f"Node({self.id}, attributes={self.attributes})"
@ -87,6 +94,7 @@ class Edge:
self.node1 = node1
self.node2 = node2
self.attributes = attributes if attributes is not None else {}
self.attributes["vector_distance"] = float('inf')
self.directed = directed
self.status = np.ones(dimension, dtype=int)
@ -95,6 +103,12 @@ class Edge:
raise ValueError(f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.")
return self.status[dimension] == 1
def add_attribute(self, key: str, value: Any) -> None:
self.attributes[key] = value
def get_attribute(self, key: str, value: Any) -> Union[str, int, float]:
return self.attributes[key]
def __repr__(self) -> str:
direction = "->" if self.directed else "--"
return f"Edge({self.node1.id} {direction} {self.node2.id}, attributes={self.attributes})"

View file

@ -77,11 +77,11 @@ def test_get_edges_success(setup_graph):
graph.add_node(node2)
edge = Edge(node1, node2)
graph.add_edge(edge)
assert edge in graph.get_edges("node1")
assert edge in graph.get_edges_of_node("node1")
def test_get_edges_nonexistent_node(setup_graph):
"""Test retrieving edges for a nonexistent node raises an exception."""
graph = setup_graph
with pytest.raises(ValueError, match="Node with id nonexistent does not exist."):
graph.get_edges("nonexistent")
graph.get_edges_of_node("nonexistent")