solved more nitpick comments

This commit is contained in:
Geoff-Robin 2025-09-15 19:05:00 +05:30 committed by Igor Ilic
parent 1e59f1594c
commit df8b80d4a9

View file

@ -9,13 +9,13 @@ from cognee.tasks.schema.models import DatabaseSchema, SchemaTable, SchemaRelati
from cognee.infrastructure.databases.relational.create_relational_engine import (
create_relational_engine,
)
from datetime import datetime
from datetime import datetime, timezone
async def ingest_database_schema(
database_config: Dict,
schema_name: str = "default",
max_sample_rows: int = 5,
max_sample_rows: int = 0,
) -> Dict[str, List[DataPoint] | DataPoint]:
"""
Ingest database schema with sample data into dedicated nodeset
@ -45,22 +45,41 @@ async def ingest_database_schema(
sample_data = {}
schema_tables = []
schema_relationships = []
qi = engine.engine.dialect.identifier_preparer.quote
def qname(name: str):
split_name = name.split(".")
".".join(qi(p) for p in split_name)
async with engine.engine.begin() as cursor:
for table_name, details in schema.items():
qi = engine.engine.dialect.identifier_preparer.quote
qname = lambda name : ".".join(qi(p) for p in name.split("."))
tn = qname(table_name)
rows_result = await cursor.execute(
text(f"SELECT * FROM {tn} LIMIT :limit;"),
{"limit": max_sample_rows}
)
rows = [
dict(zip([col["name"] for col in details["columns"]], row))
for row in rows_result.fetchall()
]
count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {tn};"))
row_count_estimate = count_result.scalar()
if max_sample_rows > 0:
rows_result = await cursor.execute(
text(f"SELECT * FROM {tn} LIMIT :limit;"), {"limit": max_sample_rows}
)
rows = [dict(r) for r in rows_result.mappings().all()]
else:
rows = []
row_count_estimate = 0
if engine.engine.dialect.name == "postegresql":
if "." in table_name:
schema_part, table_part = table_name.split(".", 1)
else:
schema_part, table_part = "public", table_name
estimate = await cursor.execute(
text(
"SELECT reltuples:bigint "
"FROM pg_class c "
"JOIN pg_namespace n ON n.oid = c.relnamespace "
"WHERE n.nspname = :schema AND c.relname = :table"
),
{"schema": schema_part, "table": table_part},
)
row_count_estimate = estimate.scalar()
else:
count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {tn};"))
row_count_estimate = count_result.scalar()
schema_table = SchemaTable(
id=uuid5(NAMESPACE_OID, name=f"{schema_name}:{table_name}"),
@ -79,9 +98,9 @@ async def ingest_database_schema(
for fk in details.get("foreign_keys", []):
ref_table_fq = fk["ref_table"]
if '.' not in ref_table_fq and '.' in table_name:
if "." not in ref_table_fq and "." in table_name:
ref_table_fq = f"{table_name.split('.', 1)[0]}.{ref_table_fq}"
relationship = SchemaRelationship(
id=uuid5(
NAMESPACE_OID,
@ -102,7 +121,7 @@ async def ingest_database_schema(
database_type=database_config.get("migration_db_provider", "sqlite"),
tables=tables,
sample_data=sample_data,
extraction_timestamp=datetime.utcnow(),
extraction_timestamp=datetime.now(timezone.utc),
description=f"Database schema '{schema_name}' containing {len(schema_tables)} tables and {len(schema_relationships)} relationships.",
)