solved more nitpick comments
This commit is contained in:
parent
1e59f1594c
commit
df8b80d4a9
1 changed files with 36 additions and 17 deletions
|
|
@ -9,13 +9,13 @@ from cognee.tasks.schema.models import DatabaseSchema, SchemaTable, SchemaRelati
|
||||||
from cognee.infrastructure.databases.relational.create_relational_engine import (
|
from cognee.infrastructure.databases.relational.create_relational_engine import (
|
||||||
create_relational_engine,
|
create_relational_engine,
|
||||||
)
|
)
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
|
||||||
async def ingest_database_schema(
|
async def ingest_database_schema(
|
||||||
database_config: Dict,
|
database_config: Dict,
|
||||||
schema_name: str = "default",
|
schema_name: str = "default",
|
||||||
max_sample_rows: int = 5,
|
max_sample_rows: int = 0,
|
||||||
) -> Dict[str, List[DataPoint] | DataPoint]:
|
) -> Dict[str, List[DataPoint] | DataPoint]:
|
||||||
"""
|
"""
|
||||||
Ingest database schema with sample data into dedicated nodeset
|
Ingest database schema with sample data into dedicated nodeset
|
||||||
|
|
@ -45,22 +45,41 @@ async def ingest_database_schema(
|
||||||
sample_data = {}
|
sample_data = {}
|
||||||
schema_tables = []
|
schema_tables = []
|
||||||
schema_relationships = []
|
schema_relationships = []
|
||||||
|
qi = engine.engine.dialect.identifier_preparer.quote
|
||||||
|
|
||||||
|
def qname(name: str):
|
||||||
|
split_name = name.split(".")
|
||||||
|
return ".".join(qi(p) for p in split_name)
|
||||||
|
|
||||||
async with engine.engine.begin() as cursor:
|
async with engine.engine.begin() as cursor:
|
||||||
for table_name, details in schema.items():
|
for table_name, details in schema.items():
|
||||||
qi = engine.engine.dialect.identifier_preparer.quote
|
|
||||||
qname = lambda name : ".".join(qi(p) for p in name.split("."))
|
|
||||||
tn = qname(table_name)
|
tn = qname(table_name)
|
||||||
rows_result = await cursor.execute(
|
if max_sample_rows > 0:
|
||||||
text(f"SELECT * FROM {tn} LIMIT :limit;"),
|
rows_result = await cursor.execute(
|
||||||
{"limit": max_sample_rows}
|
text(f"SELECT * FROM {tn} LIMIT :limit;"), {"limit": max_sample_rows}
|
||||||
)
|
)
|
||||||
rows = [
|
rows = [dict(r) for r in rows_result.mappings().all()]
|
||||||
dict(zip([col["name"] for col in details["columns"]], row))
|
else:
|
||||||
for row in rows_result.fetchall()
|
rows = []
|
||||||
]
|
row_count_estimate = 0
|
||||||
count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {tn};"))
|
if engine.engine.dialect.name == "postgresql":
|
||||||
row_count_estimate = count_result.scalar()
|
if "." in table_name:
|
||||||
|
schema_part, table_part = table_name.split(".", 1)
|
||||||
|
else:
|
||||||
|
schema_part, table_part = "public", table_name
|
||||||
|
estimate = await cursor.execute(
|
||||||
|
text(
|
||||||
|
"SELECT reltuples::bigint "
|
||||||
|
"FROM pg_class c "
|
||||||
|
"JOIN pg_namespace n ON n.oid = c.relnamespace "
|
||||||
|
"WHERE n.nspname = :schema AND c.relname = :table"
|
||||||
|
),
|
||||||
|
{"schema": schema_part, "table": table_part},
|
||||||
|
)
|
||||||
|
row_count_estimate = estimate.scalar()
|
||||||
|
else:
|
||||||
|
count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {tn};"))
|
||||||
|
row_count_estimate = count_result.scalar()
|
||||||
|
|
||||||
schema_table = SchemaTable(
|
schema_table = SchemaTable(
|
||||||
id=uuid5(NAMESPACE_OID, name=f"{schema_name}:{table_name}"),
|
id=uuid5(NAMESPACE_OID, name=f"{schema_name}:{table_name}"),
|
||||||
|
|
@ -79,9 +98,9 @@ async def ingest_database_schema(
|
||||||
|
|
||||||
for fk in details.get("foreign_keys", []):
|
for fk in details.get("foreign_keys", []):
|
||||||
ref_table_fq = fk["ref_table"]
|
ref_table_fq = fk["ref_table"]
|
||||||
if '.' not in ref_table_fq and '.' in table_name:
|
if "." not in ref_table_fq and "." in table_name:
|
||||||
ref_table_fq = f"{table_name.split('.', 1)[0]}.{ref_table_fq}"
|
ref_table_fq = f"{table_name.split('.', 1)[0]}.{ref_table_fq}"
|
||||||
|
|
||||||
relationship = SchemaRelationship(
|
relationship = SchemaRelationship(
|
||||||
id=uuid5(
|
id=uuid5(
|
||||||
NAMESPACE_OID,
|
NAMESPACE_OID,
|
||||||
|
|
@ -102,7 +121,7 @@ async def ingest_database_schema(
|
||||||
database_type=database_config.get("migration_db_provider", "sqlite"),
|
database_type=database_config.get("migration_db_provider", "sqlite"),
|
||||||
tables=tables,
|
tables=tables,
|
||||||
sample_data=sample_data,
|
sample_data=sample_data,
|
||||||
extraction_timestamp=datetime.utcnow(),
|
extraction_timestamp=datetime.now(timezone.utc),
|
||||||
description=f"Database schema '{schema_name}' containing {len(schema_tables)} tables and {len(schema_relationships)} relationships.",
|
description=f"Database schema '{schema_name}' containing {len(schema_tables)} tables and {len(schema_relationships)} relationships.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue