From f93d30ae77f7232e03110a817226539d5eb4d483 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Sat, 27 Sep 2025 00:41:58 +0200 Subject: [PATCH] refactor: refactor schema migration --- .../ingestion/migrate_relational_database.py | 8 +++--- cognee/tasks/schema/ingest_database_schema.py | 27 ++++++++----------- cognee/tests/test_relational_db_migration.py | 1 + 3 files changed, 16 insertions(+), 20 deletions(-) diff --git a/cognee/tasks/ingestion/migrate_relational_database.py b/cognee/tasks/ingestion/migrate_relational_database.py index 5ee9f5973..83ad452c3 100644 --- a/cognee/tasks/ingestion/migrate_relational_database.py +++ b/cognee/tasks/ingestion/migrate_relational_database.py @@ -32,7 +32,7 @@ async def migrate_relational_database( """ # Create a mapping of node_id to node objects for referencing in edge creation if schema_only: - node_mapping, edge_mapping = await schema_only_ingestion() + node_mapping, edge_mapping = await schema_only_ingestion(schema) else: node_mapping, edge_mapping = await complete_database_ingestion(schema, migrate_column_data) @@ -74,13 +74,13 @@ async def migrate_relational_database( return await graph_db.get_graph_data() -async def schema_only_ingestion(): +async def schema_only_ingestion(schema): node_mapping = {} edge_mapping = [] - database_config = get_migration_config().to_dict() + # Calling the ingest_database_schema function to return DataPoint subclasses result = await ingest_database_schema( - database_config=database_config, + schema=schema, schema_name="migrated_schema", max_sample_rows=5, ) diff --git a/cognee/tasks/schema/ingest_database_schema.py b/cognee/tasks/schema/ingest_database_schema.py index be9bf6ff1..e89b679d2 100644 --- a/cognee/tasks/schema/ingest_database_schema.py +++ b/cognee/tasks/schema/ingest_database_schema.py @@ -3,14 +3,15 @@ from uuid import uuid5, NAMESPACE_OID from cognee.infrastructure.engine.models.DataPoint import DataPoint from sqlalchemy import text from cognee.tasks.schema.models import DatabaseSchema, SchemaTable, SchemaRelationship -from cognee.infrastructure.databases.relational.create_relational_engine import ( - create_relational_engine, +from cognee.infrastructure.databases.relational.get_migration_relational_engine import ( + get_migration_relational_engine, ) +from cognee.infrastructure.databases.relational.config import get_migration_config from datetime import datetime, timezone async def ingest_database_schema( - database_config: Dict, + schema, schema_name: str = "default", max_sample_rows: int = 0, ) -> Dict[str, List[DataPoint] | DataPoint]: @@ -28,20 +29,13 @@ async def ingest_database_schema( "schema_tables": List[SchemaTable] "relationships": List[SchemaRelationship] """ - engine = create_relational_engine( - db_path=database_config.get("migration_db_path", ""), - db_name=database_config.get("migration_db_name", "cognee_db"), - db_host=database_config.get("migration_db_host"), - db_port=database_config.get("migration_db_port"), - db_username=database_config.get("migration_db_username"), - db_password=database_config.get("migration_db_password"), - db_provider=database_config.get("migration_db_provider", "sqlite"), - ) - schema = await engine.extract_schema() + tables = {} sample_data = {} schema_tables = [] schema_relationships = [] + + engine = get_migration_relational_engine() qi = engine.engine.dialect.identifier_preparer.quote try: max_sample_rows = max(0, int(max_sample_rows)) @@ -63,7 +57,7 @@ async def ingest_database_schema( rows = [dict(r) for r in rows_result.mappings().all()] else: rows = [] - row_count_estimate = 0 + if engine.engine.dialect.name == "postgresql": if "." in table_name: schema_part, table_part = table_name.split(".", 1) @@ -117,11 +111,12 @@ async def ingest_database_schema( ) schema_relationships.append(relationship) - id_str = f"{database_config.get('migration_db_provider', 'sqlite')}:{database_config.get('migration_db_name', 'cognee_db')}:{schema_name}" + migration_config = get_migration_config() + id_str = f"{migration_config.migration_db_provider}:{migration_config.migration_db_name}:{schema_name}" database_schema = DatabaseSchema( id=uuid5(NAMESPACE_OID, name=id_str), schema_name=schema_name, - database_type=database_config.get("migration_db_provider", "sqlite"), + database_type=migration_config.migration_db_provider, tables=tables, sample_data=sample_data, extraction_timestamp=datetime.now(timezone.utc), diff --git a/cognee/tests/test_relational_db_migration.py b/cognee/tests/test_relational_db_migration.py index cb360f1c2..2b69ce854 100644 --- a/cognee/tests/test_relational_db_migration.py +++ b/cognee/tests/test_relational_db_migration.py @@ -212,6 +212,7 @@ async def test_schema_only_migration(): search_results = await cognee.search( query_text="How many tables are there in this database", query_type=cognee.SearchType.GRAPH_COMPLETION, + top_k=30, ) assert any("11" in r for r in search_results), ( "Number of tables in the database reported in search_results is either None or not equal to 11"