refactor: pass the extracted schema into schema-only ingestion and reuse the migration engine/config instead of rebuilding them
This commit is contained in:
parent
8ff58f0278
commit
f93d30ae77
3 changed files with 16 additions and 20 deletions
|
|
@ -32,7 +32,7 @@ async def migrate_relational_database(
|
|||
"""
|
||||
# Create a mapping of node_id to node objects for referencing in edge creation
|
||||
if schema_only:
|
||||
node_mapping, edge_mapping = await schema_only_ingestion()
|
||||
node_mapping, edge_mapping = await schema_only_ingestion(schema)
|
||||
|
||||
else:
|
||||
node_mapping, edge_mapping = await complete_database_ingestion(schema, migrate_column_data)
|
||||
|
|
@ -74,13 +74,13 @@ async def migrate_relational_database(
|
|||
return await graph_db.get_graph_data()
|
||||
|
||||
|
||||
async def schema_only_ingestion():
|
||||
async def schema_only_ingestion(schema):
|
||||
node_mapping = {}
|
||||
edge_mapping = []
|
||||
database_config = get_migration_config().to_dict()
|
||||
|
||||
# Calling the ingest_database_schema function to return DataPoint subclasses
|
||||
result = await ingest_database_schema(
|
||||
database_config=database_config,
|
||||
schema=schema,
|
||||
schema_name="migrated_schema",
|
||||
max_sample_rows=5,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,14 +3,15 @@ from uuid import uuid5, NAMESPACE_OID
|
|||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||
from sqlalchemy import text
|
||||
from cognee.tasks.schema.models import DatabaseSchema, SchemaTable, SchemaRelationship
|
||||
from cognee.infrastructure.databases.relational.create_relational_engine import (
|
||||
create_relational_engine,
|
||||
from cognee.infrastructure.databases.relational.get_migration_relational_engine import (
|
||||
get_migration_relational_engine,
|
||||
)
|
||||
from cognee.infrastructure.databases.relational.config import get_migration_config
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
async def ingest_database_schema(
|
||||
database_config: Dict,
|
||||
schema,
|
||||
schema_name: str = "default",
|
||||
max_sample_rows: int = 0,
|
||||
) -> Dict[str, List[DataPoint] | DataPoint]:
|
||||
|
|
@ -28,20 +29,13 @@ async def ingest_database_schema(
|
|||
"schema_tables": List[SchemaTable]
|
||||
"relationships": List[SchemaRelationship]
|
||||
"""
|
||||
engine = create_relational_engine(
|
||||
db_path=database_config.get("migration_db_path", ""),
|
||||
db_name=database_config.get("migration_db_name", "cognee_db"),
|
||||
db_host=database_config.get("migration_db_host"),
|
||||
db_port=database_config.get("migration_db_port"),
|
||||
db_username=database_config.get("migration_db_username"),
|
||||
db_password=database_config.get("migration_db_password"),
|
||||
db_provider=database_config.get("migration_db_provider", "sqlite"),
|
||||
)
|
||||
schema = await engine.extract_schema()
|
||||
|
||||
tables = {}
|
||||
sample_data = {}
|
||||
schema_tables = []
|
||||
schema_relationships = []
|
||||
|
||||
engine = get_migration_relational_engine()
|
||||
qi = engine.engine.dialect.identifier_preparer.quote
|
||||
try:
|
||||
max_sample_rows = max(0, int(max_sample_rows))
|
||||
|
|
@ -63,7 +57,7 @@ async def ingest_database_schema(
|
|||
rows = [dict(r) for r in rows_result.mappings().all()]
|
||||
else:
|
||||
rows = []
|
||||
row_count_estimate = 0
|
||||
|
||||
if engine.engine.dialect.name == "postgresql":
|
||||
if "." in table_name:
|
||||
schema_part, table_part = table_name.split(".", 1)
|
||||
|
|
@ -117,11 +111,12 @@ async def ingest_database_schema(
|
|||
)
|
||||
schema_relationships.append(relationship)
|
||||
|
||||
id_str = f"{database_config.get('migration_db_provider', 'sqlite')}:{database_config.get('migration_db_name', 'cognee_db')}:{schema_name}"
|
||||
migration_config = get_migration_config()
|
||||
id_str = f"{migration_config.migration_db_provider}:{migration_config.migration_db_name}:{schema_name}"
|
||||
database_schema = DatabaseSchema(
|
||||
id=uuid5(NAMESPACE_OID, name=id_str),
|
||||
schema_name=schema_name,
|
||||
database_type=database_config.get("migration_db_provider", "sqlite"),
|
||||
database_type=migration_config.migration_db_provider,
|
||||
tables=tables,
|
||||
sample_data=sample_data,
|
||||
extraction_timestamp=datetime.now(timezone.utc),
|
||||
|
|
|
|||
|
|
@ -212,6 +212,7 @@ async def test_schema_only_migration():
|
|||
search_results = await cognee.search(
|
||||
query_text="How many tables are there in this database",
|
||||
query_type=cognee.SearchType.GRAPH_COMPLETION,
|
||||
top_k=30,
|
||||
)
|
||||
assert any("11" in r for r in search_results), (
|
||||
"Number of tables in the database reported in search_results is either None or not equal to 11"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue