refactor: pass the extracted schema into schema-only ingestion and reuse the migration relational engine

This commit is contained in:
Igor Ilic 2025-09-27 00:41:58 +02:00
parent 8ff58f0278
commit f93d30ae77
3 changed files with 16 additions and 20 deletions

View file

@ -32,7 +32,7 @@ async def migrate_relational_database(
""" """
# Create a mapping of node_id to node objects for referencing in edge creation # Create a mapping of node_id to node objects for referencing in edge creation
if schema_only: if schema_only:
node_mapping, edge_mapping = await schema_only_ingestion() node_mapping, edge_mapping = await schema_only_ingestion(schema)
else: else:
node_mapping, edge_mapping = await complete_database_ingestion(schema, migrate_column_data) node_mapping, edge_mapping = await complete_database_ingestion(schema, migrate_column_data)
@ -74,13 +74,13 @@ async def migrate_relational_database(
return await graph_db.get_graph_data() return await graph_db.get_graph_data()
async def schema_only_ingestion(): async def schema_only_ingestion(schema):
node_mapping = {} node_mapping = {}
edge_mapping = [] edge_mapping = []
database_config = get_migration_config().to_dict()
# Calling the ingest_database_schema function to return DataPoint subclasses # Calling the ingest_database_schema function to return DataPoint subclasses
result = await ingest_database_schema( result = await ingest_database_schema(
database_config=database_config, schema=schema,
schema_name="migrated_schema", schema_name="migrated_schema",
max_sample_rows=5, max_sample_rows=5,
) )

View file

@ -3,14 +3,15 @@ from uuid import uuid5, NAMESPACE_OID
from cognee.infrastructure.engine.models.DataPoint import DataPoint from cognee.infrastructure.engine.models.DataPoint import DataPoint
from sqlalchemy import text from sqlalchemy import text
from cognee.tasks.schema.models import DatabaseSchema, SchemaTable, SchemaRelationship from cognee.tasks.schema.models import DatabaseSchema, SchemaTable, SchemaRelationship
from cognee.infrastructure.databases.relational.create_relational_engine import ( from cognee.infrastructure.databases.relational.get_migration_relational_engine import (
create_relational_engine, get_migration_relational_engine,
) )
from cognee.infrastructure.databases.relational.config import get_migration_config
from datetime import datetime, timezone from datetime import datetime, timezone
async def ingest_database_schema( async def ingest_database_schema(
database_config: Dict, schema,
schema_name: str = "default", schema_name: str = "default",
max_sample_rows: int = 0, max_sample_rows: int = 0,
) -> Dict[str, List[DataPoint] | DataPoint]: ) -> Dict[str, List[DataPoint] | DataPoint]:
@ -28,20 +29,13 @@ async def ingest_database_schema(
"schema_tables": List[SchemaTable] "schema_tables": List[SchemaTable]
"relationships": List[SchemaRelationship] "relationships": List[SchemaRelationship]
""" """
engine = create_relational_engine(
db_path=database_config.get("migration_db_path", ""),
db_name=database_config.get("migration_db_name", "cognee_db"),
db_host=database_config.get("migration_db_host"),
db_port=database_config.get("migration_db_port"),
db_username=database_config.get("migration_db_username"),
db_password=database_config.get("migration_db_password"),
db_provider=database_config.get("migration_db_provider", "sqlite"),
)
schema = await engine.extract_schema()
tables = {} tables = {}
sample_data = {} sample_data = {}
schema_tables = [] schema_tables = []
schema_relationships = [] schema_relationships = []
engine = get_migration_relational_engine()
qi = engine.engine.dialect.identifier_preparer.quote qi = engine.engine.dialect.identifier_preparer.quote
try: try:
max_sample_rows = max(0, int(max_sample_rows)) max_sample_rows = max(0, int(max_sample_rows))
@ -63,7 +57,7 @@ async def ingest_database_schema(
rows = [dict(r) for r in rows_result.mappings().all()] rows = [dict(r) for r in rows_result.mappings().all()]
else: else:
rows = [] rows = []
row_count_estimate = 0
if engine.engine.dialect.name == "postgresql": if engine.engine.dialect.name == "postgresql":
if "." in table_name: if "." in table_name:
schema_part, table_part = table_name.split(".", 1) schema_part, table_part = table_name.split(".", 1)
@ -117,11 +111,12 @@ async def ingest_database_schema(
) )
schema_relationships.append(relationship) schema_relationships.append(relationship)
id_str = f"{database_config.get('migration_db_provider', 'sqlite')}:{database_config.get('migration_db_name', 'cognee_db')}:{schema_name}" migration_config = get_migration_config()
id_str = f"{migration_config.migration_db_provider}:{migration_config.migration_db_name}:{schema_name}"
database_schema = DatabaseSchema( database_schema = DatabaseSchema(
id=uuid5(NAMESPACE_OID, name=id_str), id=uuid5(NAMESPACE_OID, name=id_str),
schema_name=schema_name, schema_name=schema_name,
database_type=database_config.get("migration_db_provider", "sqlite"), database_type=migration_config.migration_db_provider,
tables=tables, tables=tables,
sample_data=sample_data, sample_data=sample_data,
extraction_timestamp=datetime.now(timezone.utc), extraction_timestamp=datetime.now(timezone.utc),

View file

@ -212,6 +212,7 @@ async def test_schema_only_migration():
search_results = await cognee.search( search_results = await cognee.search(
query_text="How many tables are there in this database", query_text="How many tables are there in this database",
query_type=cognee.SearchType.GRAPH_COMPLETION, query_type=cognee.SearchType.GRAPH_COMPLETION,
top_k=30,
) )
assert any("11" in r for r in search_results), ( assert any("11" in r for r in search_results), (
"Number of tables in the database reported in search_results is either None or not equal to 11" "Number of tables in the database reported in search_results is either None or not equal to 11"