From f9f18d1b0cb4b2e4815c7cb8e2e9cf2f80e2b019 Mon Sep 17 00:00:00 2001 From: Igor Ilic <30923996+dexters1@users.noreply.github.com> Date: Thu, 15 May 2025 07:31:31 -0400 Subject: [PATCH] feat: Add columns as nodes in relational db migration (#826) ## Description Add ability to map column values from relational databases to graph ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. --- cognee/modules/engine/models/ColumnValue.py | 9 +++++ cognee/modules/engine/models/__init__.py | 1 + .../cognee_network_visualization.py | 1 + .../ingestion/migrate_relational_database.py | 40 ++++++++++++++++++- cognee/tests/test_relational_db_migration.py | 14 +++---- 5 files changed, 56 insertions(+), 9 deletions(-) create mode 100644 cognee/modules/engine/models/ColumnValue.py diff --git a/cognee/modules/engine/models/ColumnValue.py b/cognee/modules/engine/models/ColumnValue.py new file mode 100644 index 000000000..6ad8d992c --- /dev/null +++ b/cognee/modules/engine/models/ColumnValue.py @@ -0,0 +1,9 @@ +from cognee.infrastructure.engine import DataPoint + + +class ColumnValue(DataPoint): + name: str + description: str + properties: str + + metadata: dict = {"index_fields": ["properties"]} diff --git a/cognee/modules/engine/models/__init__.py b/cognee/modules/engine/models/__init__.py index 4ab2de0de..2535f00f3 100644 --- a/cognee/modules/engine/models/__init__.py +++ b/cognee/modules/engine/models/__init__.py @@ -3,3 +3,4 @@ from .EntityType import EntityType from .TableRow import TableRow from .TableType import TableType from .node_set import NodeSet +from .ColumnValue import ColumnValue diff --git a/cognee/modules/visualization/cognee_network_visualization.py b/cognee/modules/visualization/cognee_network_visualization.py index b0f2f0a1a..83a0f237e 100644 --- a/cognee/modules/visualization/cognee_network_visualization.py +++ b/cognee/modules/visualization/cognee_network_visualization.py @@ -21,6 +21,7 @@ async def cognee_network_visualization(graph_data, destination_file_path: str = "TextSummary": "#1077f4", "TableRow": "#f47710", "TableType": "#6510f4", + "ColumnValue": "#13613a", "default": "#D3D3D3", } diff --git a/cognee/tasks/ingestion/migrate_relational_database.py b/cognee/tasks/ingestion/migrate_relational_database.py index 946b5d721..936ea59e0 100644 --- a/cognee/tasks/ingestion/migrate_relational_database.py +++ b/cognee/tasks/ingestion/migrate_relational_database.py @@ -8,12 +8,12 @@ from cognee.infrastructure.databases.relational.get_migration_relational_engine from cognee.tasks.storage.index_data_points import index_data_points from cognee.tasks.storage.index_graph_edges import index_graph_edges -from cognee.modules.engine.models import TableRow, TableType +from cognee.modules.engine.models import TableRow, TableType, ColumnValue logger = logging.getLogger(__name__) -async def migrate_relational_database(graph_db, schema): +async def migrate_relational_database(graph_db, schema, migrate_column_data=True): """ Migrates data from a relational database into a graph database. @@ -95,6 +95,42 @@ async def migrate_relational_database(graph_db, schema): ) ) + # Migrate data stored in columns of table rows + if migrate_column_data: + # Get foreign key columns to filter them out from column migration + foreign_keys = [] + for fk in details.get("foreign_keys", []): + foreign_keys.append(fk["ref_column"]) + + for key, value in row_properties.items(): + # Skip mapping primary key information to itself and mapping of foreign key information (as it will be mapped bellow) + if key is primary_key_col or key in foreign_keys: + continue + + # Create column value node + column_node_id = f"{table_name}:{key}:{value}" + column_node = ColumnValue( + id=uuid5(NAMESPACE_OID, name=column_node_id), + name=column_node_id, + properties=f"{key} {value} {table_name}", + description=f"Column name={key} and value={value} from column from table={table_name}", + ) + node_mapping[column_node_id] = column_node + + # Create relationship between column value of table row and table row + edge_mapping.append( + ( + row_node.id, + column_node.id, + key, + dict( + relationship_name=key, + source_node_id=row_node.id, + target_node_id=column_node.id, + ), + ) + ) + # Process foreign key relationships after all nodes are created for table_name, details in schema.items(): # Process foreign key relationships for the current table diff --git a/cognee/tests/test_relational_db_migration.py b/cognee/tests/test_relational_db_migration.py index b9c77716f..8a9670a7c 100644 --- a/cognee/tests/test_relational_db_migration.py +++ b/cognee/tests/test_relational_db_migration.py @@ -112,10 +112,10 @@ async def relational_db_migration(): else: raise ValueError(f"Unsupported graph database provider: {graph_db_provider}") - assert len(distinct_node_names) == 8, ( - f"Expected 8 distinct node references, found {len(distinct_node_names)}" + assert len(distinct_node_names) == 12, ( + f"Expected 12 distinct node references, found {len(distinct_node_names)}" ) - assert len(found_edges) == 7, f"Expected 7 {relationship_label} edges, got {len(found_edges)}" + assert len(found_edges) == 15, f"Expected 15 {relationship_label} edges, got {len(found_edges)}" expected_edges = { ("Employee:5", "Employee:2"), @@ -158,8 +158,8 @@ async def relational_db_migration(): # NOTE: Because of the different size of the postgres and sqlite databases, # different number of nodes and edges are expected - assert node_count == 227, f"Expected 227 nodes, got {node_count}" - assert edge_count == 580, f"Expected 580 edges, got {edge_count}" + assert node_count == 543, f"Expected 543 nodes, got {node_count}" + assert edge_count == 1317, f"Expected 1317 edges, got {edge_count}" elif migration_db_provider == "postgresql": if graph_db_provider == "neo4j": @@ -189,8 +189,8 @@ async def relational_db_migration(): # NOTE: Because of the different size of the postgres and sqlite databases, # different number of nodes and edges are expected - assert node_count == 115, f"Expected 115 nodes, got {node_count}" - assert edge_count == 356, f"Expected 356 edges, got {edge_count}" + assert node_count == 522, f"Expected 522 nodes, got {node_count}" + assert edge_count == 961, f"Expected 961 edges, got {edge_count}" print(f"Node & edge count validated: node_count={node_count}, edge_count={edge_count}.")