From 1ba9e1df317810b0ba796dd3fe75b8ca4c61cb89 Mon Sep 17 00:00:00 2001 From: Geoff-Robin Date: Sun, 14 Sep 2025 21:56:31 +0530 Subject: [PATCH] done with ruff checks --- .../ingestion/migrate_relational_database.py | 117 ++++++++++-------- cognee/tasks/schema/ingest_database_schema.py | 37 ++++-- cognee/tasks/schema/models.py | 8 +- 3 files changed, 95 insertions(+), 67 deletions(-) diff --git a/cognee/tasks/ingestion/migrate_relational_database.py b/cognee/tasks/ingestion/migrate_relational_database.py index e535a0ed8..62a8a0eac 100644 --- a/cognee/tasks/ingestion/migrate_relational_database.py +++ b/cognee/tasks/ingestion/migrate_relational_database.py @@ -16,7 +16,9 @@ from cognee.modules.engine.models import TableRow, TableType, ColumnValue logger = logging.getLogger(__name__) -async def migrate_relational_database(graph_db, schema, migrate_column_data=True,schema_only=False): +async def migrate_relational_database( + graph_db, schema, migrate_column_data=True, schema_only=False +): """ Migrates data from a relational database into a graph database. @@ -33,15 +35,15 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True # Create a mapping of node_id to node objects for referencing in edge creation node_mapping = {} edge_mapping = [] - + if schema_only: - database_config = get_migration_config().to_dict() + database_config = get_migration_config().to_dict() # Calling the ingest_database_schema function to return DataPoint subclasses result = await ingest_database_schema( database_config=database_config, schema_name="migrated_schema", max_sample_rows=5, - node_set=["database_schema", "schema_tables", "relationships"] + node_set=["database_schema", "schema_tables", "relationships"], ) database_schema = result["database_schema"] schema_tables = result["schema_tables"] @@ -51,57 +53,64 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True for table in schema_tables: table_node_id = table.id # Add TableSchema Datapoint as a node. - node_mapping[table_node_id]=table - edge_mapping.append(( - table_node_id, - database_node_id, - "is_part_of", - dict( - source_node_id=table_node_id, - target_node_id=database_node_id, - relationship_name="is_part_of", - ), - )) - for rel in schema_relationships: - source_table_id = uuid5(NAMESPACE_OID,name=rel.source_table) - target_table_id = uuid5(NAMESPACE_OID,name=rel.target_table) - relationship_id = rel.id - - # Add RelationshipTable DataPoint as a node. - node_mapping[relationship_id]=rel - edge_mapping.append(( - source_table_id, - relationship_id, - "has_relationship", - dict( - source_node_id=source_table_id, - target_node_id=relationship_id, - relationship_name=rel.relationship_type, - ), - )) - edge_mapping.append(( - relationship_id, - target_table_id, - "has_relationship", - dict( - source_node_id=relationship_id, - target_node_id=target_table_id, - relationship_name=rel.relationship_type, + node_mapping[table_node_id] = table + edge_mapping.append( + ( + table_node_id, + database_node_id, + "is_part_of", + dict( + source_node_id=table_node_id, + target_node_id=database_node_id, + relationship_name="is_part_of", + ), ) - )) - edge_mapping.append(( - source_table_id, - target_table_id, - rel.relationship_type, - dict( - source_node_id=source_table_id, - target_node_id=target_table_id, - relationship_name=rel.relationship_type, - ), - )) - - - + ) + for rel in schema_relationships: + source_table_id = uuid5(NAMESPACE_OID, name=rel.source_table) + target_table_id = uuid5(NAMESPACE_OID, name=rel.target_table) + + relationship_id = rel.id + + # Add RelationshipTable DataPoint as a node. + node_mapping[relationship_id] = rel + edge_mapping.append( + ( + source_table_id, + relationship_id, + "has_relationship", + dict( + source_node_id=source_table_id, + target_node_id=relationship_id, + relationship_name=rel.relationship_type, + ), + ) + ) + edge_mapping.append( + ( + relationship_id, + target_table_id, + "has_relationship", + dict( + source_node_id=relationship_id, + target_node_id=target_table_id, + relationship_name=rel.relationship_type, + ), + ) + ) + edge_mapping.append( + ( + source_table_id, + target_table_id, + rel.relationship_type, + dict( + source_node_id=source_table_id, + target_node_id=target_table_id, + relationship_name=rel.relationship_type, + ), + ) + ) + else: async with engine.engine.begin() as cursor: # First, create table type nodes for all tables diff --git a/cognee/tasks/schema/ingest_database_schema.py b/cognee/tasks/schema/ingest_database_schema.py index 2ac57d0ba..c4c13449d 100644 --- a/cognee/tasks/schema/ingest_database_schema.py +++ b/cognee/tasks/schema/ingest_database_schema.py @@ -1,27 +1,32 @@ from typing import List, Dict from uuid import uuid5, NAMESPACE_OID from cognee.infrastructure.engine.models.DataPoint import DataPoint -from cognee.infrastructure.databases.relational.get_migration_relational_engine import get_migration_relational_engine +from cognee.infrastructure.databases.relational.get_migration_relational_engine import ( + get_migration_relational_engine, +) from sqlalchemy import text from cognee.tasks.schema.models import DatabaseSchema, SchemaTable, SchemaRelationship -from cognee.infrastructure.databases.relational.create_relational_engine import create_relational_engine +from cognee.infrastructure.databases.relational.create_relational_engine import ( + create_relational_engine, +) from datetime import datetime + async def ingest_database_schema( database_config: Dict, schema_name: str = "default", max_sample_rows: int = 5, - node_set: List[str] = ["database_schema"] + node_set: List[str] = ["database_schema"], ) -> Dict[str, List[DataPoint] | DataPoint]: """ Ingest database schema with sample data into dedicated nodeset - + Args: database_config: Database connection configuration schema_name: Name identifier for this schema max_sample_rows: Maximum sample rows per table node_set: Target nodeset (default: ["database_schema"]) - + Returns: List of created DataPoint objects """ @@ -42,8 +47,13 @@ async def ingest_database_schema( async with engine.engine.begin() as cursor: for table_name, details in schema.items(): - rows_result = await cursor.execute(text(f"SELECT * FROM {table_name} LIMIT {max_sample_rows}")) - rows = [dict(zip([col["name"] for col in details["columns"]], row)) for row in rows_result.fetchall()] + rows_result = await cursor.execute( + text(f"SELECT * FROM {table_name} LIMIT {max_sample_rows}") + ) + rows = [ + dict(zip([col["name"] for col in details["columns"]], row)) + for row in rows_result.fetchall() + ] count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {table_name};")) row_count_estimate = count_result.scalar() @@ -56,7 +66,7 @@ async def ingest_database_schema( foreign_keys=details.get("foreign_keys", []), sample_rows=rows, row_count_estimate=row_count_estimate, - description=f"" + description=f"Schema table for '{table_name}' with {len(details['columns'])} columns and approx. {row_count_estimate} rows.", ) schema_tables.append(schema_table) tables[table_name] = details @@ -64,13 +74,16 @@ async def ingest_database_schema( for fk in details.get("foreign_keys", []): relationship = SchemaRelationship( - id=uuid5(NAMESPACE_OID, name=f"{fk['column']}:{table_name}:{fk['ref_column']}:{fk['ref_table']}"), + id=uuid5( + NAMESPACE_OID, + name=f"{fk['column']}:{table_name}:{fk['ref_column']}:{fk['ref_table']}", + ), source_table=table_name, target_table=fk["ref_table"], relationship_type="foreign_key", source_column=fk["column"], target_column=fk["ref_column"], - description=f"" + description=f"Foreign key relationship: {table_name}.{fk['column']} → {fk['ref_table']}.{fk['ref_column']}", ) schema_relationships.append(relationship) @@ -81,11 +94,11 @@ async def ingest_database_schema( tables=tables, sample_data=sample_data, extraction_timestamp=datetime.utcnow(), - description=f"" + description=f"Database schema '{schema_name}' containing {len(schema_tables)} tables and {len(schema_relationships)} relationships.", ) return { "database_schema": database_schema, "schema_tables": schema_tables, - "relationships": schema_relationships + "relationships": schema_relationships, } diff --git a/cognee/tasks/schema/models.py b/cognee/tasks/schema/models.py index 0fb248758..423c92050 100644 --- a/cognee/tasks/schema/models.py +++ b/cognee/tasks/schema/models.py @@ -2,8 +2,10 @@ from cognee.infrastructure.engine.models.DataPoint import DataPoint from typing import List, Dict, Optional from datetime import datetime + class DatabaseSchema(DataPoint): """Represents a complete database schema with sample data""" + schema_name: str database_type: str # sqlite, postgres, etc. tables: Dict[str, Dict] # Reuse existing schema format from SqlAlchemyAdapter @@ -12,8 +14,10 @@ class DatabaseSchema(DataPoint): description: str metadata: dict = {"index_fields": ["schema_name", "database_type"]} + class SchemaTable(DataPoint): """Represents an individual table schema with relationships""" + table_name: str schema_name: str columns: List[Dict] # Column definitions with types @@ -24,12 +28,14 @@ class SchemaTable(DataPoint): description: str metadata: dict = {"index_fields": ["table_name", "schema_name"]} + class SchemaRelationship(DataPoint): """Represents relationships between tables""" + source_table: str target_table: str relationship_type: str # "foreign_key", "one_to_many", etc. source_column: str target_column: str description: str - metadata: dict = {"index_fields": ["source_table", "target_table"]} \ No newline at end of file + metadata: dict = {"index_fields": ["source_table", "target_table"]}