Added networkx to the graph
This commit is contained in:
parent
4a8db1fe51
commit
0a07b1e96b
1 changed files with 53 additions and 6 deletions
|
|
@ -17,7 +17,7 @@ import openai
|
|||
import instructor
|
||||
from openai import OpenAI
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
import pickle
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
|
@ -582,10 +582,23 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
return None
|
||||
|
||||
|
||||
class NetworkXGraphDB(AbstractGraphDB):
|
||||
def __init__(self):
|
||||
self.graph = nx.Graph()
|
||||
# Initialize other necessary properties or configurations
|
||||
class NetworkXGraphDB:
|
||||
def __init__(self, filename='networkx_graph.pkl'):
|
||||
self.filename = filename
|
||||
try:
|
||||
self.graph = self.load_graph() # Attempt to load an existing graph
|
||||
except (FileNotFoundError, EOFError, pickle.UnpicklingError):
|
||||
self.graph = nx.Graph() # Create a new graph if loading failed
|
||||
|
||||
def save_graph(self):
|
||||
""" Save the graph to a file using pickle """
|
||||
with open(self.filename, 'wb') as f:
|
||||
pickle.dump(self.graph, f)
|
||||
|
||||
def load_graph(self):
|
||||
""" Load the graph from a file using pickle """
|
||||
with open(self.filename, 'rb') as f:
|
||||
return pickle.load(f)
|
||||
|
||||
def create_base_cognitive_architecture(self, user_id: str):
|
||||
# Add nodes for user and memory types if they don't exist
|
||||
|
|
@ -599,19 +612,24 @@ class NetworkXGraphDB(AbstractGraphDB):
|
|||
self.graph.add_edge(user_id, f"{user_id}_episodic", relation='HAS_EPISODIC_MEMORY')
|
||||
self.graph.add_edge(user_id, f"{user_id}_buffer", relation='HAS_BUFFER')
|
||||
|
||||
self.save_graph() # Save the graph after modifying it
|
||||
|
||||
def delete_all_user_memories(self, user_id: str):
|
||||
# Remove nodes and edges related to the user's memories
|
||||
for memory_type in ['semantic', 'episodic', 'buffer']:
|
||||
memory_node = f"{user_id}_{memory_type}"
|
||||
self.graph.remove_node(memory_node)
|
||||
|
||||
self.save_graph() # Save the graph after modifying it
|
||||
|
||||
def delete_specific_memory_type(self, user_id: str, memory_type: str):
|
||||
# Remove a specific type of memory node and its related edges
|
||||
memory_node = f"{user_id}_{memory_type.lower()}"
|
||||
if memory_node in self.graph:
|
||||
self.graph.remove_node(memory_node)
|
||||
|
||||
# Methods for retrieving semantic, episodic, and buffer memories
|
||||
self.save_graph() # Save the graph after modifying it
|
||||
|
||||
def retrieve_semantic_memory(self, user_id: str):
|
||||
return [n for n in self.graph.neighbors(f"{user_id}_semantic")]
|
||||
|
||||
|
|
@ -621,6 +639,35 @@ class NetworkXGraphDB(AbstractGraphDB):
|
|||
def retrieve_buffer_memory(self, user_id: str):
|
||||
return [n for n in self.graph.neighbors(f"{user_id}_buffer")]
|
||||
|
||||
def generate_graph_semantic_memory_document_summary(self, document_summary, unique_graphdb_mapping_values, document_namespace, user_id):
|
||||
for node, attributes in unique_graphdb_mapping_values.items():
|
||||
self.graph.add_node(node, **attributes)
|
||||
self.graph.add_edge(f"{user_id}_semantic", node, relation='HAS_KNOWLEDGE')
|
||||
self.save_graph()
|
||||
|
||||
def generate_document_summary(self, document_summary, unique_graphdb_mapping_values, document_namespace, user_id):
|
||||
self.generate_graph_semantic_memory_document_summary(document_summary, unique_graphdb_mapping_values, document_namespace, user_id)
|
||||
|
||||
async def get_document_categories(self, user_id):
|
||||
return [self.graph.nodes[n]['category'] for n in self.graph.neighbors(f"{user_id}_semantic") if 'category' in self.graph.nodes[n]]
|
||||
|
||||
async def get_document_ids(self, user_id, category):
|
||||
return [n for n in self.graph.neighbors(f"{user_id}_semantic") if self.graph.nodes[n].get('category') == category]
|
||||
|
||||
def create_document_node(self, document_summary, user_id):
|
||||
d_id = document_summary['d_id']
|
||||
self.graph.add_node(d_id, **document_summary)
|
||||
self.graph.add_edge(f"{user_id}_semantic", d_id, relation='HAS_DOCUMENT')
|
||||
self.save_graph()
|
||||
|
||||
def update_document_node_with_namespace(self, user_id, vectordb_namespace, document_id):
|
||||
if self.graph.has_node(document_id):
|
||||
self.graph.nodes[document_id]['vectordbNamespace'] = vectordb_namespace
|
||||
self.save_graph()
|
||||
|
||||
def get_namespaces_by_document_category(self, user_id, category):
|
||||
return [self.graph.nodes[n].get('vectordbNamespace') for n in self.graph.neighbors(f"{user_id}_semantic") if self.graph.nodes[n].get('category') == category]
|
||||
|
||||
class GraphDBFactory:
|
||||
def create_graph_db(self, db_type, **kwargs):
|
||||
if db_type == 'neo4j':
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue