diff --git a/cognee/tasks/schema/ingest_database_schema.py b/cognee/tasks/schema/ingest_database_schema.py
index 2a343ea0d..e362734fc 100644
--- a/cognee/tasks/schema/ingest_database_schema.py
+++ b/cognee/tasks/schema/ingest_database_schema.py
@@ -9,13 +9,13 @@ from cognee.tasks.schema.models import DatabaseSchema, SchemaTable, SchemaRelati
 from cognee.infrastructure.databases.relational.create_relational_engine import (
     create_relational_engine,
 )
-from datetime import datetime
+from datetime import datetime, timezone
 
 
 async def ingest_database_schema(
     database_config: Dict,
     schema_name: str = "default",
-    max_sample_rows: int = 5,
+    max_sample_rows: int = 0,
 ) -> Dict[str, List[DataPoint] | DataPoint]:
     """
     Ingest database schema with sample data into dedicated nodeset
@@ -45,22 +45,41 @@ async def ingest_database_schema(
     sample_data = {}
     schema_tables = []
     schema_relationships = []
+    qi = engine.engine.dialect.identifier_preparer.quote
+
+    def qname(name: str) -> str:
+        split_name = name.split(".")
+        return ".".join(qi(p) for p in split_name)
 
     async with engine.engine.begin() as cursor:
         for table_name, details in schema.items():
-            qi = engine.engine.dialect.identifier_preparer.quote
-            qname = lambda name : ".".join(qi(p) for p in name.split("."))
             tn = qname(table_name)
-            rows_result = await cursor.execute(
-                text(f"SELECT * FROM {tn} LIMIT :limit;"),
-                {"limit": max_sample_rows}
-            )
-            rows = [
-                dict(zip([col["name"] for col in details["columns"]], row))
-                for row in rows_result.fetchall()
-            ]
-            count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {tn};"))
-            row_count_estimate = count_result.scalar()
+            if max_sample_rows > 0:
+                rows_result = await cursor.execute(
+                    text(f"SELECT * FROM {tn} LIMIT :limit;"), {"limit": max_sample_rows}
+                )
+                rows = [dict(r) for r in rows_result.mappings().all()]
+            else:
+                rows = []
+            row_count_estimate = 0
+            if engine.engine.dialect.name == "postgresql":
+                if "." in table_name:
+                    schema_part, table_part = table_name.split(".", 1)
+                else:
+                    schema_part, table_part = "public", table_name
+                estimate = await cursor.execute(
+                    text(
+                        "SELECT reltuples::bigint "
+                        "FROM pg_class c "
+                        "JOIN pg_namespace n ON n.oid = c.relnamespace "
+                        "WHERE n.nspname = :schema AND c.relname = :table"
+                    ),
+                    {"schema": schema_part, "table": table_part},
+                )
+                row_count_estimate = max(estimate.scalar() or 0, 0)
+            else:
+                count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {tn};"))
+                row_count_estimate = count_result.scalar()
 
             schema_table = SchemaTable(
                 id=uuid5(NAMESPACE_OID, name=f"{schema_name}:{table_name}"),
@@ -79,9 +98,9 @@ async def ingest_database_schema(
 
             for fk in details.get("foreign_keys", []):
                 ref_table_fq = fk["ref_table"]
-                if '.' not in ref_table_fq and '.' in table_name:
+                if "." not in ref_table_fq and "." in table_name:
                     ref_table_fq = f"{table_name.split('.', 1)[0]}.{ref_table_fq}"
-
+
                 relationship = SchemaRelationship(
                     id=uuid5(
                         NAMESPACE_OID,
@@ -102,7 +121,7 @@ async def ingest_database_schema(
         database_type=database_config.get("migration_db_provider", "sqlite"),
         tables=tables,
         sample_data=sample_data,
-        extraction_timestamp=datetime.utcnow(),
+        extraction_timestamp=datetime.now(timezone.utc),
         description=f"Database schema '{schema_name}' containing {len(schema_tables)} tables and {len(schema_relationships)} relationships.",
     )