solved more nitpick comments
This commit is contained in:
parent
df8b80d4a9
commit
e7bcf9043f
1 changed file with 12 additions and 13 deletions
|
|
@ -1,9 +1,6 @@
|
||||||
from typing import List, Dict, Optional
|
from typing import List, Dict
|
||||||
from uuid import uuid5, NAMESPACE_OID
|
from uuid import uuid5, NAMESPACE_OID
|
||||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||||
from cognee.infrastructure.databases.relational.get_migration_relational_engine import (
|
|
||||||
get_migration_relational_engine,
|
|
||||||
)
|
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
from cognee.tasks.schema.models import DatabaseSchema, SchemaTable, SchemaRelationship
|
from cognee.tasks.schema.models import DatabaseSchema, SchemaTable, SchemaRelationship
|
||||||
from cognee.infrastructure.databases.relational.create_relational_engine import (
|
from cognee.infrastructure.databases.relational.create_relational_engine import (
|
||||||
|
|
@ -18,12 +15,12 @@ async def ingest_database_schema(
|
||||||
max_sample_rows: int = 0,
|
max_sample_rows: int = 0,
|
||||||
) -> Dict[str, List[DataPoint] | DataPoint]:
|
) -> Dict[str, List[DataPoint] | DataPoint]:
|
||||||
"""
|
"""
|
||||||
Ingest database schema with sample data into dedicated nodeset
|
Extract database schema metadata (optionally with sample data) and return DataPoint models for graph construction.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
database_config: Database connection configuration
|
database_config: Database connection configuration
|
||||||
schema_name: Name identifier for this schema
|
schema_name: Name identifier for this schema
|
||||||
max_sample_rows: Maximum sample rows per table
|
max_sample_rows: Maximum sample rows per table (0 means no sampling)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with keys:
|
Dict with keys:
|
||||||
|
|
@ -49,36 +46,37 @@ async def ingest_database_schema(
|
||||||
|
|
||||||
def qname(name: str):
|
def qname(name: str):
|
||||||
split_name = name.split(".")
|
split_name = name.split(".")
|
||||||
".".join(qi(p) for p in split_name)
|
return ".".join(qi(p) for p in split_name)
|
||||||
|
|
||||||
async with engine.engine.begin() as cursor:
|
async with engine.engine.begin() as cursor:
|
||||||
for table_name, details in schema.items():
|
for table_name, details in schema.items():
|
||||||
tn = qname(table_name)
|
tn = qname(table_name)
|
||||||
if max_sample_rows > 0:
|
if max_sample_rows > 0:
|
||||||
rows_result = await cursor.execute(
|
rows_result = await cursor.execute(
|
||||||
text(f"SELECT * FROM {tn} LIMIT :limit;"), {"limit": max_sample_rows}
|
text(f"SELECT * FROM {tn} LIMIT :limit;"),
|
||||||
|
{"limit": max_sample_rows}, # noqa: S608 - tn is fully quoted
|
||||||
)
|
)
|
||||||
rows = [dict(r) for r in rows_result.mappings().all()]
|
rows = [dict(r) for r in rows_result.mappings().all()]
|
||||||
else:
|
else:
|
||||||
rows = []
|
rows = []
|
||||||
row_count_estimate = 0
|
row_count_estimate = 0
|
||||||
if engine.engine.dialect.name == "postegresql":
|
if engine.engine.dialect.name == "postgresql":
|
||||||
if "." in table_name:
|
if "." in table_name:
|
||||||
schema_part, table_part = table_name.split(".", 1)
|
schema_part, table_part = table_name.split(".", 1)
|
||||||
else:
|
else:
|
||||||
schema_part, table_part = "public", table_name
|
schema_part, table_part = "public", table_name
|
||||||
estimate = await cursor.execute(
|
estimate = await cursor.execute(
|
||||||
text(
|
text(
|
||||||
"SELECT reltuples::bigint "
|
"SELECT reltuples::bigint AS estimate "
|
||||||
"FROM pg_class c "
|
"FROM pg_class c "
|
||||||
"JOIN pg_namespace n ON n.oid = c.relnamespace "
|
"JOIN pg_namespace n ON n.oid = c.relnamespace "
|
||||||
"WHERE n.nspname = :schema AND c.relname = :table"
|
"WHERE n.nspname = :schema AND c.relname = :table"
|
||||||
),
|
),
|
||||||
{"schema": schema_part, "table": table_part},
|
{"schema": schema_part, "table": table_part},
|
||||||
)
|
)
|
||||||
row_count_estimate = estimate.scalar()
|
row_count_estimate = estimate.scalar() or 0
|
||||||
else:
|
else:
|
||||||
count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {tn};"))
|
count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {tn};")) # noqa: S608 - tn is fully quoted
|
||||||
row_count_estimate = count_result.scalar()
|
row_count_estimate = count_result.scalar()
|
||||||
|
|
||||||
schema_table = SchemaTable(
|
schema_table = SchemaTable(
|
||||||
|
|
@ -115,8 +113,9 @@ async def ingest_database_schema(
|
||||||
)
|
)
|
||||||
schema_relationships.append(relationship)
|
schema_relationships.append(relationship)
|
||||||
|
|
||||||
|
id_str = f"{database_config.get('migration_db_provider', 'sqlite')}:{database_config.get('migration_db_name', 'cognee_db')}:{schema_name}"
|
||||||
database_schema = DatabaseSchema(
|
database_schema = DatabaseSchema(
|
||||||
id=uuid5(NAMESPACE_OID, name=schema_name),
|
id=uuid5(NAMESPACE_OID, name=id_str),
|
||||||
schema_name=schema_name,
|
schema_name=schema_name,
|
||||||
database_type=database_config.get("migration_db_provider", "sqlite"),
|
database_type=database_config.get("migration_db_provider", "sqlite"),
|
||||||
tables=tables,
|
tables=tables,
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue