feat: implements DB projection to memory

This commit is contained in:
hajdul88 2024-11-13 16:38:57 +01:00
parent 68bfb87f3a
commit 8e3a991dd0

View file

@ -13,13 +13,22 @@ class CogneeGraph(CogneeAbstractGraph):
This class provides the functionality to manage nodes and edges,
and project a graph from a database using adapters.
"""
nodes: Dict[str, Node]
edges: List[Edge]
directed: bool
def __init__(self, directed: bool = True):
self.nodes = {}
self.edges = []
self.directed = directed
def add_node(self, node: Node) -> None:
if node.id not in self.nodes:
self.nodes[node.id] = node
else:
raise ValueError(f"Node with id {node.id} already exists.")
# :TODO ADD dimension
def add_edge(self, edge: Edge) -> None:
if edge not in self.edges:
self.edges.append(edge)
@ -38,34 +47,42 @@ class CogneeGraph(CogneeAbstractGraph):
else:
raise ValueError(f"Node with id {node_id} does not exist.")
# :TODO This should take also the list of entity types and connection types to keep. (Maybe we dont need all and can keep just an abstraction of the db network)
async def project_graph_from_db(self, adapter: Union[Neo4jAdapter, NetworkXAdapter]) -> None:
async def project_graph_from_db(self,
adapter: Union[NetworkXAdapter, Neo4jAdapter],
node_properties_to_project: List[str],
edge_properties_to_project: List[str],
directed = True,
node_dimension = 1,
edge_dimension = 1) -> None:
try:
nodes_data, edges_data = await adapter.get_graph_data()
# :TODO: Handle networkx and Neo4j separately
nodes_data, edges_data = await adapter.get_graph_data()
if not nodes_data:
raise ValueError("No node data retrieved from the database.")
if not edges_data:
raise ValueError("No edge data retrieved from the database.")
raise NotImplementedError("To be implemented...tomorrow")
for node_id, properties in nodes_data:
node_attributes = {key: properties.get(key) for key in node_properties_to_project}
self.add_node(Node(str(node_id), node_attributes, dimension=node_dimension))
for source_id, target_id, relationship_type, properties in edges_data:
source_node = self.get_node(str(source_id))
target_node = self.get_node(str(target_id))
if source_node and target_node:
edge_attributes = {key: properties.get(key) for key in edge_properties_to_project}
edge_attributes['relationship_type'] = relationship_type
"""
The following code only used for test purposes and will be deleted later
"""
import asyncio
edge = Edge(source_node, target_node, attributes=edge_attributes, directed=directed, dimension=edge_dimension)
self.add_edge(edge)
async def main():
# Choose the adapter (Neo4j or NetworkX)
adapter = await get_graph_engine()
source_node.add_skeleton_edge(edge)
target_node.add_skeleton_edge(edge)
# Create an instance of CogneeGraph
graph = CogneeGraph()
else:
raise ValueError(f"Edge references nonexistent nodes: {source_id} -> {target_id}")
# Project the graph from the database
await graph.project_graph_from_db(adapter)
# Access nodes and edges
print(f"Graph has {len(graph.nodes)} nodes and {len(graph.edges)} edges.")
print("Sample node:", graph.get_node("node1"))
print("Edges for node1:", graph.get_edges("node1"))
# Run the main function
asyncio.run(main())
except (ValueError, TypeError) as e:
print(f"Error projecting graph: {e}")
except Exception as ex:
print(f"Unexpected error: {ex}")