feat: Add columns as nodes in relational db migration (#826)

<!-- .github/pull_request_template.md -->

## Description
Add the ability to map column values from relational databases to the graph.

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
Igor Ilic 2025-05-15 07:31:31 -04:00 committed by GitHub
parent 7ac5761040
commit f9f18d1b0c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 56 additions and 9 deletions

View file

@ -0,0 +1,9 @@
from cognee.infrastructure.engine import DataPoint
class ColumnValue(DataPoint):
    """Graph node representing a single column value migrated from a relational table.

    One instance is created per non-key column value of a table row during
    relational-to-graph migration.
    """

    # Unique node name in the form "<table>:<column>:<value>" (see the id
    # construction in the migration task).
    name: str
    # Human-readable description of where the value came from.
    description: str
    # Text payload used for vector indexing (declared in metadata below).
    properties: str
    # DataPoint metadata: index the "properties" field for embedding/search.
    metadata: dict = {"index_fields": ["properties"]}

View file

@ -3,3 +3,4 @@ from .EntityType import EntityType
from .TableRow import TableRow
from .TableType import TableType
from .node_set import NodeSet
from .ColumnValue import ColumnValue

View file

@ -21,6 +21,7 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
"TextSummary": "#1077f4",
"TableRow": "#f47710",
"TableType": "#6510f4",
"ColumnValue": "#13613a",
"default": "#D3D3D3",
}

View file

@ -8,12 +8,12 @@ from cognee.infrastructure.databases.relational.get_migration_relational_engine
from cognee.tasks.storage.index_data_points import index_data_points
from cognee.tasks.storage.index_graph_edges import index_graph_edges
from cognee.modules.engine.models import TableRow, TableType
from cognee.modules.engine.models import TableRow, TableType, ColumnValue
logger = logging.getLogger(__name__)
async def migrate_relational_database(graph_db, schema):
async def migrate_relational_database(graph_db, schema, migrate_column_data=True):
"""
Migrates data from a relational database into a graph database.
@ -95,6 +95,42 @@ async def migrate_relational_database(graph_db, schema):
)
)
# Migrate data stored in columns of table rows
if migrate_column_data:
    # Foreign key columns become relationships (processed below), so they
    # are excluded from per-column value nodes.
    foreign_keys = [fk["ref_column"] for fk in details.get("foreign_keys", [])]

    for key, value in row_properties.items():
        # Skip the primary key (it identifies the row node itself) and
        # foreign keys (mapped as edges below).
        # BUGFIX: compare with ==, not `is` — identity comparison of strings
        # only appears to work via CPython interning and is not reliable.
        if key == primary_key_col or key in foreign_keys:
            continue

        # One node per (table, column, value) triple; uuid5 over a stable
        # name keeps ids deterministic across migration runs.
        column_node_id = f"{table_name}:{key}:{value}"
        column_node = ColumnValue(
            id=uuid5(NAMESPACE_OID, name=column_node_id),
            name=column_node_id,
            properties=f"{key} {value} {table_name}",
            description=f"Column name={key} and value={value} from column from table={table_name}",
        )
        node_mapping[column_node_id] = column_node

        # Link the owning table-row node to its column value node; the edge
        # is labeled with the column name.
        edge_mapping.append(
            (
                row_node.id,
                column_node.id,
                key,
                dict(
                    relationship_name=key,
                    source_node_id=row_node.id,
                    target_node_id=column_node.id,
                ),
            )
        )
# Process foreign key relationships after all nodes are created
for table_name, details in schema.items():
# Process foreign key relationships for the current table

View file

@ -112,10 +112,10 @@ async def relational_db_migration():
else:
raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")
assert len(distinct_node_names) == 8, (
f"Expected 8 distinct node references, found {len(distinct_node_names)}"
assert len(distinct_node_names) == 12, (
f"Expected 12 distinct node references, found {len(distinct_node_names)}"
)
assert len(found_edges) == 7, f"Expected 7 {relationship_label} edges, got {len(found_edges)}"
assert len(found_edges) == 15, f"Expected 15 {relationship_label} edges, got {len(found_edges)}"
expected_edges = {
("Employee:5", "Employee:2"),
@ -158,8 +158,8 @@ async def relational_db_migration():
# NOTE: Because of the different size of the postgres and sqlite databases,
# different number of nodes and edges are expected
assert node_count == 227, f"Expected 227 nodes, got {node_count}"
assert edge_count == 580, f"Expected 580 edges, got {edge_count}"
assert node_count == 543, f"Expected 543 nodes, got {node_count}"
assert edge_count == 1317, f"Expected 1317 edges, got {edge_count}"
elif migration_db_provider == "postgresql":
if graph_db_provider == "neo4j":
@ -189,8 +189,8 @@ async def relational_db_migration():
# NOTE: Because of the different size of the postgres and sqlite databases,
# different number of nodes and edges are expected
assert node_count == 115, f"Expected 115 nodes, got {node_count}"
assert edge_count == 356, f"Expected 356 edges, got {edge_count}"
assert node_count == 522, f"Expected 522 nodes, got {node_count}"
assert edge_count == 961, f"Expected 961 edges, got {edge_count}"
print(f"Node & edge count validated: node_count={node_count}, edge_count={edge_count}.")