solved more nitpick comments
This commit is contained in:
parent
1e59f1594c
commit
df8b80d4a9
1 changed files with 36 additions and 17 deletions
|
|
@ -9,13 +9,13 @@ from cognee.tasks.schema.models import DatabaseSchema, SchemaTable, SchemaRelati
|
|||
from cognee.infrastructure.databases.relational.create_relational_engine import (
|
||||
create_relational_engine,
|
||||
)
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
async def ingest_database_schema(
|
||||
database_config: Dict,
|
||||
schema_name: str = "default",
|
||||
max_sample_rows: int = 5,
|
||||
max_sample_rows: int = 0,
|
||||
) -> Dict[str, List[DataPoint] | DataPoint]:
|
||||
"""
|
||||
Ingest database schema with sample data into dedicated nodeset
|
||||
|
|
@ -45,22 +45,41 @@ async def ingest_database_schema(
|
|||
sample_data = {}
|
||||
schema_tables = []
|
||||
schema_relationships = []
|
||||
qi = engine.engine.dialect.identifier_preparer.quote
|
||||
|
||||
def qname(name: str) -> str:
    """Return ``name`` with each dot-separated part quoted for the active dialect.

    A possibly schema-qualified identifier (e.g. ``schema.table``) is split on
    ``.``; every part is passed through ``qi`` (the dialect identifier
    preparer's ``quote`` function bound just above) and the quoted parts are
    re-joined with ``.``.

    Args:
        name: Table name, optionally prefixed with a schema and a dot.

    Returns:
        The quoted identifier, ready to be interpolated into a SQL statement.
    """
    # The joined string must be returned: as a bare expression statement the
    # result is discarded, the helper yields None, and the caller would then
    # build SQL such as "SELECT COUNT(*) FROM None".
    return ".".join(qi(part) for part in name.split("."))
|
||||
|
||||
async with engine.engine.begin() as cursor:
|
||||
for table_name, details in schema.items():
|
||||
qi = engine.engine.dialect.identifier_preparer.quote
|
||||
qname = lambda name : ".".join(qi(p) for p in name.split("."))
|
||||
tn = qname(table_name)
|
||||
rows_result = await cursor.execute(
|
||||
text(f"SELECT * FROM {tn} LIMIT :limit;"),
|
||||
{"limit": max_sample_rows}
|
||||
)
|
||||
rows = [
|
||||
dict(zip([col["name"] for col in details["columns"]], row))
|
||||
for row in rows_result.fetchall()
|
||||
]
|
||||
count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {tn};"))
|
||||
row_count_estimate = count_result.scalar()
|
||||
if max_sample_rows > 0:
|
||||
rows_result = await cursor.execute(
|
||||
text(f"SELECT * FROM {tn} LIMIT :limit;"), {"limit": max_sample_rows}
|
||||
)
|
||||
rows = [dict(r) for r in rows_result.mappings().all()]
|
||||
else:
|
||||
rows = []
|
||||
row_count_estimate = 0
|
||||
if engine.engine.dialect.name == "postgresql":
|
||||
if "." in table_name:
|
||||
schema_part, table_part = table_name.split(".", 1)
|
||||
else:
|
||||
schema_part, table_part = "public", table_name
|
||||
estimate = await cursor.execute(
|
||||
text(
|
||||
"SELECT reltuples::bigint "
|
||||
"FROM pg_class c "
|
||||
"JOIN pg_namespace n ON n.oid = c.relnamespace "
|
||||
"WHERE n.nspname = :schema AND c.relname = :table"
|
||||
),
|
||||
{"schema": schema_part, "table": table_part},
|
||||
)
|
||||
row_count_estimate = estimate.scalar()
|
||||
else:
|
||||
count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {tn};"))
|
||||
row_count_estimate = count_result.scalar()
|
||||
|
||||
schema_table = SchemaTable(
|
||||
id=uuid5(NAMESPACE_OID, name=f"{schema_name}:{table_name}"),
|
||||
|
|
@ -79,9 +98,9 @@ async def ingest_database_schema(
|
|||
|
||||
for fk in details.get("foreign_keys", []):
|
||||
ref_table_fq = fk["ref_table"]
|
||||
if '.' not in ref_table_fq and '.' in table_name:
|
||||
if "." not in ref_table_fq and "." in table_name:
|
||||
ref_table_fq = f"{table_name.split('.', 1)[0]}.{ref_table_fq}"
|
||||
|
||||
|
||||
relationship = SchemaRelationship(
|
||||
id=uuid5(
|
||||
NAMESPACE_OID,
|
||||
|
|
@ -102,7 +121,7 @@ async def ingest_database_schema(
|
|||
database_type=database_config.get("migration_db_provider", "sqlite"),
|
||||
tables=tables,
|
||||
sample_data=sample_data,
|
||||
extraction_timestamp=datetime.utcnow(),
|
||||
extraction_timestamp=datetime.now(timezone.utc),
|
||||
description=f"Database schema '{schema_name}' containing {len(schema_tables)} tables and {len(schema_relationships)} relationships.",
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue