diff --git a/cognee/tasks/ingestion/migrate_relational_database.py b/cognee/tasks/ingestion/migrate_relational_database.py index e857ab34d..824fef2fa 100644 --- a/cognee/tasks/ingestion/migrate_relational_database.py +++ b/cognee/tasks/ingestion/migrate_relational_database.py @@ -43,7 +43,6 @@ async def migrate_relational_database( database_config=database_config, schema_name="migrated_schema", max_sample_rows=5, - node_set=["database_schema", "schema_tables", "relationships"], ) database_schema = result["database_schema"] schema_tables = result["schema_tables"] diff --git a/cognee/tasks/schema/ingest_database_schema.py b/cognee/tasks/schema/ingest_database_schema.py index e80ce2e75..be544408b 100644 --- a/cognee/tasks/schema/ingest_database_schema.py +++ b/cognee/tasks/schema/ingest_database_schema.py @@ -1,4 +1,4 @@ -from typing import List, Dict +from typing import List, Dict, Optional from uuid import uuid5, NAMESPACE_OID from cognee.infrastructure.engine.models.DataPoint import DataPoint from cognee.infrastructure.databases.relational.get_migration_relational_engine import ( @@ -16,7 +16,6 @@ async def ingest_database_schema( database_config: Dict, schema_name: str = "default", max_sample_rows: int = 5, - node_set: List[str] = ["database_schema"], ) -> Dict[str, List[DataPoint] | DataPoint]: """ Ingest database schema with sample data into dedicated nodeset @@ -25,7 +24,6 @@ async def ingest_database_schema( database_config: Database connection configuration schema_name: Name identifier for this schema max_sample_rows: Maximum sample rows per table - node_set: Target nodeset (default: ["database_schema"]) Returns: List of created DataPoint objects @@ -48,6 +46,8 @@ async def ingest_database_schema( async with engine.engine.begin() as cursor: for table_name, details in schema.items(): qi = engine.engine.dialect.identifier_preparer.quote + qname = lambda name : ".".join(qi(p) for p in name.split(".")) + tn = qname(table_name) tn = qi(table_name) rows_result = await cursor.execute( text(f"SELECT * FROM {tn} LIMIT :limit;"), @@ -57,11 +57,11 @@ async def ingest_database_schema( dict(zip([col["name"] for col in details["columns"]], row)) for row in rows_result.fetchall() ] - count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {table_name};")) + count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {tn};")) row_count_estimate = count_result.scalar() schema_table = SchemaTable( - id=uuid5(NAMESPACE_OID, name=f"{schema_name}:{table_name}"), + id=uuid5(NAMESPACE_OID, name=f"{schema_name}:{tn}"), table_name=table_name, schema_name=schema_name, columns=details["columns"], @@ -76,17 +76,21 @@ async def ingest_database_schema( sample_data[table_name] = rows for fk in details.get("foreign_keys", []): + ref_table_fq = fk["ref_table"] + if '.' not in ref_table_fq and '.' in table_name: + ref_table_fq = f"{table_name.split('.', 1)[0]}.{ref_table_fq}" + relationship = SchemaRelationship( id=uuid5( NAMESPACE_OID, - name=f"{fk['column']}:{table_name}:{fk['ref_column']}:{fk['ref_table']}", + name=f"{schema_name}:{table_name}:{fk['column']}->{ref_table_fq}:{fk['ref_column']}", ), source_table=table_name, - target_table=fk["ref_table"], + target_table=ref_table_fq, relationship_type="foreign_key", source_column=fk["column"], target_column=fk["ref_column"], - description=f"Foreign key relationship: {table_name}.{fk['column']} → {fk['ref_table']}.{fk['ref_column']}", + description=f"Foreign key relationship: {table_name}.{fk['column']} → {ref_table_fq}.{fk['ref_column']}", ) schema_relationships.append(relationship)