Add utils for graph visualization + classification nodes

This commit is contained in:
Vasilije 2024-03-10 12:15:17 +01:00
parent 4ea9a2c134
commit 2cc4ec7a78
4 changed files with 153 additions and 74 deletions

File diff suppressed because one or more lines are too long

View file

@ -48,6 +48,9 @@ class Config:
openai_key: Optional[str] = os.getenv("OPENAI_API_KEY")
openai_temperature: float = float(os.getenv("OPENAI_TEMPERATURE", 0.0))
graphistry_username = os.getenv("GRAPHISTRY_USERNAME")
graphistry_password = os.getenv("GRAPHISTRY_PASSWORD")
# Embedding parameters
embedding_model: str = "openai"
embedding_dim: int = 1536

View file

@ -7,7 +7,9 @@ from cognitive_architecture.infrastructure.databases.graph.get_graph_client impo
from cognitive_architecture.shared.data_models import GraphDBType, DefaultGraphModel, Document, DocumentType, Category, Relationship, UserProperties, UserLocation
def add_classification_nodes(G, id, classification_data):
async def add_classification_nodes(G, id, classification_data):
await G.load_graph_from_file()
context = classification_data['context_name']
layer = classification_data['layer_name']
@ -15,19 +17,19 @@ def add_classification_nodes(G, id, classification_data):
layer_classification_node_id = f'LLM_LAYER_CLASSIFICATION:{context}:{id}'
# Add the node to the graph, unpacking the node data from the dictionary
G.add_node(layer_classification_node_id, **classification_data)
await G.add_node(layer_classification_node_id, **classification_data)
# Link this node to the corresponding document node
G.add_edge(id, layer_classification_node_id, relationship='classified_as')
await G.add_edge(id, layer_classification_node_id, relationship='classified_as')
# Create the detailed classification node ID using the context_name
detailed_classification_node_id = f'LLM_CLASSIFICATION:LAYER:{layer}:{id}'
# Add the detailed classification node, reusing the same node data
G.add_node(detailed_classification_node_id, **classification_data)
await G.add_node(detailed_classification_node_id, **classification_data)
# Link the detailed classification node to the layer classification node
G.add_edge(layer_classification_node_id, detailed_classification_node_id, relationship='contains_analysis')
await G.add_edge(layer_classification_node_id, detailed_classification_node_id, relationship='contains_analysis')
return G
@ -43,6 +45,10 @@ if __name__ == "__main__":
graph_client = get_graph_client(GraphDBType.NETWORKX)
G = asyncio.run(add_classification_nodes(graph_client, 'document_id', {'data_type': 'text',
G = asyncio.run(add_classification_nodes(graph_client, 'Document:doc1', {'data_type': 'text',
'context_name': 'TEXT',
'layer_name': 'Articles, essays, and reports'}))
'layer_name': 'Articles, essays, and reports'}))
from cognitive_architecture.utils import render_graph
ff = asyncio.run( render_graph(G.graph, graph_type='networkx'))
print(ff)

View file

@ -4,6 +4,7 @@ import os
import random
import string
import uuid
import graphistry
from pathlib import Path
from jinja2 import Environment, FileSystemLoader, select_autoescape
from graphviz import Digraph
@ -24,6 +25,12 @@ from cognitive_architecture.database.relationaldb.database_crud import (
fetch_job_id,
)
from cognitive_architecture.config import Config
config = Config()
config.load()
class Node:
def __init__(self, id, description, color):
self.id = id
@ -346,4 +353,41 @@ async def async_render_template(filename: str, context: dict) -> str:
# Render the template with the provided context
rendered_template = template.render(context)
return rendered_template
return rendered_template
async def render_graph(graph, graph_type):
# Authenticate with your Graphistry API key
import networkx as nx
import pandas as pd
graphistry.register(api=3, username=config.graphistry_username, password=config.graphistry_password)
# Convert the NetworkX graph to a Pandas DataFrame representing the edge list
edges = nx.to_pandas_edgelist(graph)
# Visualize the graph using Graphistry
plotter = graphistry.edges(edges, 'source', 'target')
# Visualize the graph (this will open a URL in your default web browser)
url = plotter.plot(render=False, as_files=True)
print(f"Graph is visualized at: {url}")
# import networkx as nx
# # Create a simple NetworkX graph
# G = nx.Graph()
#
# # Add nodes
# G.add_node(1)
# G.add_node(2)
#
# # Add an edge between nodes
# G.add_edge(1, 2)
#
# import asyncio
#
# # Define the graph type (for this example, it's just a placeholder as the function doesn't use it yet)
# graph_type = "simple"
#
# # Call the render_graph function
# asyncio.run(render_graph(G, graph_type))