From a43f19cc5914f46d9a83e1b3e7a25189f89cf252 Mon Sep 17 00:00:00 2001 From: Geoff-Robin Date: Sat, 13 Sep 2025 14:33:06 +0530 Subject: [PATCH] ingest_database_schema with a slight alteration with return value as Dict[str,List[DataPoint] | DataPoint]] --- cognee/tasks/schema/ingest_database_schema.py | 63 ++++++++++++++++++- 1 file changed, 61 insertions(+), 2 deletions(-) diff --git a/cognee/tasks/schema/ingest_database_schema.py b/cognee/tasks/schema/ingest_database_schema.py index 6f9c538cd..2b9cd38c5 100644 --- a/cognee/tasks/schema/ingest_database_schema.py +++ b/cognee/tasks/schema/ingest_database_schema.py @@ -1,12 +1,17 @@ from typing import List, Dict from cognee.infrastructure.engine.models.DataPoint import DataPoint +from cognee.infrastructure.databases.relational.get_migration_relational_engine import get_migration_relational_engine +from sqlalchemy import text +from cognee.tasks.schema.models import DatabaseSchema, SchemaTable, SchemaRelationship +from cognee.infrastructure.databases.relational.create_relational_engine import create_relational_engine +from datetime import datetime async def ingest_database_schema( database_config: Dict, schema_name: str = "default", max_sample_rows: int = 5, node_set: List[str] = ["database_schema"] -) -> List[DataPoint]: +) -> Dict[str, List[DataPoint]|DataPoint]: """ Ingest database schema with sample data into dedicated nodeset @@ -19,4 +24,58 @@ async def ingest_database_schema( Returns: List of created DataPoint objects """ - pass \ No newline at end of file + engine = create_relational_engine( + db_path=database_config.get("db_path", ""), + db_name=database_config.get("db_name", "cognee_db"), + db_host=database_config.get("db_host"), + db_port=database_config.get("db_port"), + db_username=database_config.get("db_username"), + db_password=database_config.get("db_password"), + db_provider=database_config.get("db_provider", "sqlite"), + ) + schema = await engine.extract_schema() + tables={} + sample_data={} + schema_tables = [] + schema_relationships = [] + async with engine.engine.begin() as cursor: + for table_name, details in schema.items(): + rows_result = await cursor.execute(text(f"SELECT * FROM {table_name} LIMIT {max_sample_rows}")) + rows = [dict(zip([col["name"] for col in details["columns"]], row)) for row in rows_result.fetchall()] + count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {table_name};")) + row_count_estimate = count_result.scalar() + schema_table = SchemaTable( + table_name=table_name, + schema_name=schema_name, + columns=details["columns"], + primary_key=details.get("primary_key"), + foreign_keys=details.get("foreign_keys", []), + sample_rows=rows, + row_count_estimate=row_count_estimate + ) + schema_tables.append(schema_table) + tables[table_name] = details + sample_data[table_name] = rows + + for fk in details.get("foreign_keys",[]): + relationship = SchemaRelationship( + source_table=table_name, + target_table=fk["ref_table"], + relationship_type=fk["type"], + source_column=fk["source_column"], + target_column=fk["target_column"] + ) + schema_relationships.append(relationship) + database_schema = DatabaseSchema( + schema_name=schema_name, + database_type=database_config.get("db_provider", "sqlite"), + tables=tables, + sample_data=sample_data, + extraction_timestamp=datetime.utcnow() + ) + + return{ + "database_schema": database_schema, + "schema_tables": schema_tables, + "relationships": schema_relationships + } \ No newline at end of file