done with ruff checks

This commit is contained in:
Geoff-Robin 2025-09-14 21:56:31 +05:30 committed by Igor Ilic
parent 51dfac359d
commit 1ba9e1df31
3 changed files with 95 additions and 67 deletions

View file

@ -16,7 +16,9 @@ from cognee.modules.engine.models import TableRow, TableType, ColumnValue
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
async def migrate_relational_database(graph_db, schema, migrate_column_data=True,schema_only=False): async def migrate_relational_database(
graph_db, schema, migrate_column_data=True, schema_only=False
):
""" """
Migrates data from a relational database into a graph database. Migrates data from a relational database into a graph database.
@ -33,15 +35,15 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True
# Create a mapping of node_id to node objects for referencing in edge creation # Create a mapping of node_id to node objects for referencing in edge creation
node_mapping = {} node_mapping = {}
edge_mapping = [] edge_mapping = []
if schema_only: if schema_only:
database_config = get_migration_config().to_dict() database_config = get_migration_config().to_dict()
# Calling the ingest_database_schema function to return DataPoint subclasses # Calling the ingest_database_schema function to return DataPoint subclasses
result = await ingest_database_schema( result = await ingest_database_schema(
database_config=database_config, database_config=database_config,
schema_name="migrated_schema", schema_name="migrated_schema",
max_sample_rows=5, max_sample_rows=5,
node_set=["database_schema", "schema_tables", "relationships"] node_set=["database_schema", "schema_tables", "relationships"],
) )
database_schema = result["database_schema"] database_schema = result["database_schema"]
schema_tables = result["schema_tables"] schema_tables = result["schema_tables"]
@ -51,57 +53,64 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True
for table in schema_tables: for table in schema_tables:
table_node_id = table.id table_node_id = table.id
# Add TableSchema Datapoint as a node. # Add TableSchema Datapoint as a node.
node_mapping[table_node_id]=table node_mapping[table_node_id] = table
edge_mapping.append(( edge_mapping.append(
table_node_id, (
database_node_id, table_node_id,
"is_part_of", database_node_id,
dict( "is_part_of",
source_node_id=table_node_id, dict(
target_node_id=database_node_id, source_node_id=table_node_id,
relationship_name="is_part_of", target_node_id=database_node_id,
), relationship_name="is_part_of",
)) ),
for rel in schema_relationships:
source_table_id = uuid5(NAMESPACE_OID,name=rel.source_table)
target_table_id = uuid5(NAMESPACE_OID,name=rel.target_table)
relationship_id = rel.id
# Add RelationshipTable DataPoint as a node.
node_mapping[relationship_id]=rel
edge_mapping.append((
source_table_id,
relationship_id,
"has_relationship",
dict(
source_node_id=source_table_id,
target_node_id=relationship_id,
relationship_name=rel.relationship_type,
),
))
edge_mapping.append((
relationship_id,
target_table_id,
"has_relationship",
dict(
source_node_id=relationship_id,
target_node_id=target_table_id,
relationship_name=rel.relationship_type,
) )
)) )
edge_mapping.append(( for rel in schema_relationships:
source_table_id, source_table_id = uuid5(NAMESPACE_OID, name=rel.source_table)
target_table_id, target_table_id = uuid5(NAMESPACE_OID, name=rel.target_table)
rel.relationship_type,
dict( relationship_id = rel.id
source_node_id=source_table_id,
target_node_id=target_table_id, # Add RelationshipTable DataPoint as a node.
relationship_name=rel.relationship_type, node_mapping[relationship_id] = rel
), edge_mapping.append(
)) (
source_table_id,
relationship_id,
"has_relationship",
dict(
source_node_id=source_table_id,
target_node_id=relationship_id,
relationship_name=rel.relationship_type,
),
)
)
edge_mapping.append(
(
relationship_id,
target_table_id,
"has_relationship",
dict(
source_node_id=relationship_id,
target_node_id=target_table_id,
relationship_name=rel.relationship_type,
),
)
)
edge_mapping.append(
(
source_table_id,
target_table_id,
rel.relationship_type,
dict(
source_node_id=source_table_id,
target_node_id=target_table_id,
relationship_name=rel.relationship_type,
),
)
)
else: else:
async with engine.engine.begin() as cursor: async with engine.engine.begin() as cursor:
# First, create table type nodes for all tables # First, create table type nodes for all tables

View file

@ -1,27 +1,32 @@
from typing import List, Dict from typing import List, Dict
from uuid import uuid5, NAMESPACE_OID from uuid import uuid5, NAMESPACE_OID
from cognee.infrastructure.engine.models.DataPoint import DataPoint from cognee.infrastructure.engine.models.DataPoint import DataPoint
from cognee.infrastructure.databases.relational.get_migration_relational_engine import get_migration_relational_engine from cognee.infrastructure.databases.relational.get_migration_relational_engine import (
get_migration_relational_engine,
)
from sqlalchemy import text from sqlalchemy import text
from cognee.tasks.schema.models import DatabaseSchema, SchemaTable, SchemaRelationship from cognee.tasks.schema.models import DatabaseSchema, SchemaTable, SchemaRelationship
from cognee.infrastructure.databases.relational.create_relational_engine import create_relational_engine from cognee.infrastructure.databases.relational.create_relational_engine import (
create_relational_engine,
)
from datetime import datetime from datetime import datetime
async def ingest_database_schema( async def ingest_database_schema(
database_config: Dict, database_config: Dict,
schema_name: str = "default", schema_name: str = "default",
max_sample_rows: int = 5, max_sample_rows: int = 5,
node_set: List[str] = ["database_schema"] node_set: List[str] = ["database_schema"],
) -> Dict[str, List[DataPoint] | DataPoint]: ) -> Dict[str, List[DataPoint] | DataPoint]:
""" """
Ingest database schema with sample data into dedicated nodeset Ingest database schema with sample data into dedicated nodeset
Args: Args:
database_config: Database connection configuration database_config: Database connection configuration
schema_name: Name identifier for this schema schema_name: Name identifier for this schema
max_sample_rows: Maximum sample rows per table max_sample_rows: Maximum sample rows per table
node_set: Target nodeset (default: ["database_schema"]) node_set: Target nodeset (default: ["database_schema"])
Returns: Returns:
List of created DataPoint objects List of created DataPoint objects
""" """
@ -42,8 +47,13 @@ async def ingest_database_schema(
async with engine.engine.begin() as cursor: async with engine.engine.begin() as cursor:
for table_name, details in schema.items(): for table_name, details in schema.items():
rows_result = await cursor.execute(text(f"SELECT * FROM {table_name} LIMIT {max_sample_rows}")) rows_result = await cursor.execute(
rows = [dict(zip([col["name"] for col in details["columns"]], row)) for row in rows_result.fetchall()] text(f"SELECT * FROM {table_name} LIMIT {max_sample_rows}")
)
rows = [
dict(zip([col["name"] for col in details["columns"]], row))
for row in rows_result.fetchall()
]
count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {table_name};")) count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {table_name};"))
row_count_estimate = count_result.scalar() row_count_estimate = count_result.scalar()
@ -56,7 +66,7 @@ async def ingest_database_schema(
foreign_keys=details.get("foreign_keys", []), foreign_keys=details.get("foreign_keys", []),
sample_rows=rows, sample_rows=rows,
row_count_estimate=row_count_estimate, row_count_estimate=row_count_estimate,
description=f"" description=f"Schema table for '{table_name}' with {len(details['columns'])} columns and approx. {row_count_estimate} rows.",
) )
schema_tables.append(schema_table) schema_tables.append(schema_table)
tables[table_name] = details tables[table_name] = details
@ -64,13 +74,16 @@ async def ingest_database_schema(
for fk in details.get("foreign_keys", []): for fk in details.get("foreign_keys", []):
relationship = SchemaRelationship( relationship = SchemaRelationship(
id=uuid5(NAMESPACE_OID, name=f"{fk['column']}:{table_name}:{fk['ref_column']}:{fk['ref_table']}"), id=uuid5(
NAMESPACE_OID,
name=f"{fk['column']}:{table_name}:{fk['ref_column']}:{fk['ref_table']}",
),
source_table=table_name, source_table=table_name,
target_table=fk["ref_table"], target_table=fk["ref_table"],
relationship_type="foreign_key", relationship_type="foreign_key",
source_column=fk["column"], source_column=fk["column"],
target_column=fk["ref_column"], target_column=fk["ref_column"],
description=f"" description=f"Foreign key relationship: {table_name}.{fk['column']}{fk['ref_table']}.{fk['ref_column']}",
) )
schema_relationships.append(relationship) schema_relationships.append(relationship)
@ -81,11 +94,11 @@ async def ingest_database_schema(
tables=tables, tables=tables,
sample_data=sample_data, sample_data=sample_data,
extraction_timestamp=datetime.utcnow(), extraction_timestamp=datetime.utcnow(),
description=f"" description=f"Database schema '{schema_name}' containing {len(schema_tables)} tables and {len(schema_relationships)} relationships.",
) )
return { return {
"database_schema": database_schema, "database_schema": database_schema,
"schema_tables": schema_tables, "schema_tables": schema_tables,
"relationships": schema_relationships "relationships": schema_relationships,
} }

View file

@ -2,8 +2,10 @@ from cognee.infrastructure.engine.models.DataPoint import DataPoint
from typing import List, Dict, Optional from typing import List, Dict, Optional
from datetime import datetime from datetime import datetime
class DatabaseSchema(DataPoint): class DatabaseSchema(DataPoint):
"""Represents a complete database schema with sample data""" """Represents a complete database schema with sample data"""
schema_name: str schema_name: str
database_type: str # sqlite, postgres, etc. database_type: str # sqlite, postgres, etc.
tables: Dict[str, Dict] # Reuse existing schema format from SqlAlchemyAdapter tables: Dict[str, Dict] # Reuse existing schema format from SqlAlchemyAdapter
@ -12,8 +14,10 @@ class DatabaseSchema(DataPoint):
description: str description: str
metadata: dict = {"index_fields": ["schema_name", "database_type"]} metadata: dict = {"index_fields": ["schema_name", "database_type"]}
class SchemaTable(DataPoint): class SchemaTable(DataPoint):
"""Represents an individual table schema with relationships""" """Represents an individual table schema with relationships"""
table_name: str table_name: str
schema_name: str schema_name: str
columns: List[Dict] # Column definitions with types columns: List[Dict] # Column definitions with types
@ -24,12 +28,14 @@ class SchemaTable(DataPoint):
description: str description: str
metadata: dict = {"index_fields": ["table_name", "schema_name"]} metadata: dict = {"index_fields": ["table_name", "schema_name"]}
class SchemaRelationship(DataPoint): class SchemaRelationship(DataPoint):
"""Represents relationships between tables""" """Represents relationships between tables"""
source_table: str source_table: str
target_table: str target_table: str
relationship_type: str # "foreign_key", "one_to_many", etc. relationship_type: str # "foreign_key", "one_to_many", etc.
source_column: str source_column: str
target_column: str target_column: str
description: str description: str
metadata: dict = {"index_fields": ["source_table", "target_table"]} metadata: dict = {"index_fields": ["source_table", "target_table"]}