refactor: refactor schema migration
This commit is contained in:
parent
8ff58f0278
commit
f93d30ae77
3 changed files with 16 additions and 20 deletions
|
|
@ -32,7 +32,7 @@ async def migrate_relational_database(
|
||||||
"""
|
"""
|
||||||
# Create a mapping of node_id to node objects for referencing in edge creation
|
# Create a mapping of node_id to node objects for referencing in edge creation
|
||||||
if schema_only:
|
if schema_only:
|
||||||
node_mapping, edge_mapping = await schema_only_ingestion()
|
node_mapping, edge_mapping = await schema_only_ingestion(schema)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
node_mapping, edge_mapping = await complete_database_ingestion(schema, migrate_column_data)
|
node_mapping, edge_mapping = await complete_database_ingestion(schema, migrate_column_data)
|
||||||
|
|
@ -74,13 +74,13 @@ async def migrate_relational_database(
|
||||||
return await graph_db.get_graph_data()
|
return await graph_db.get_graph_data()
|
||||||
|
|
||||||
|
|
||||||
async def schema_only_ingestion():
|
async def schema_only_ingestion(schema):
|
||||||
node_mapping = {}
|
node_mapping = {}
|
||||||
edge_mapping = []
|
edge_mapping = []
|
||||||
database_config = get_migration_config().to_dict()
|
|
||||||
# Calling the ingest_database_schema function to return DataPoint subclasses
|
# Calling the ingest_database_schema function to return DataPoint subclasses
|
||||||
result = await ingest_database_schema(
|
result = await ingest_database_schema(
|
||||||
database_config=database_config,
|
schema=schema,
|
||||||
schema_name="migrated_schema",
|
schema_name="migrated_schema",
|
||||||
max_sample_rows=5,
|
max_sample_rows=5,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -3,14 +3,15 @@ from uuid import uuid5, NAMESPACE_OID
|
||||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
from cognee.tasks.schema.models import DatabaseSchema, SchemaTable, SchemaRelationship
|
from cognee.tasks.schema.models import DatabaseSchema, SchemaTable, SchemaRelationship
|
||||||
from cognee.infrastructure.databases.relational.create_relational_engine import (
|
from cognee.infrastructure.databases.relational.get_migration_relational_engine import (
|
||||||
create_relational_engine,
|
get_migration_relational_engine,
|
||||||
)
|
)
|
||||||
|
from cognee.infrastructure.databases.relational.config import get_migration_config
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
|
||||||
async def ingest_database_schema(
|
async def ingest_database_schema(
|
||||||
database_config: Dict,
|
schema,
|
||||||
schema_name: str = "default",
|
schema_name: str = "default",
|
||||||
max_sample_rows: int = 0,
|
max_sample_rows: int = 0,
|
||||||
) -> Dict[str, List[DataPoint] | DataPoint]:
|
) -> Dict[str, List[DataPoint] | DataPoint]:
|
||||||
|
|
@ -28,20 +29,13 @@ async def ingest_database_schema(
|
||||||
"schema_tables": List[SchemaTable]
|
"schema_tables": List[SchemaTable]
|
||||||
"relationships": List[SchemaRelationship]
|
"relationships": List[SchemaRelationship]
|
||||||
"""
|
"""
|
||||||
engine = create_relational_engine(
|
|
||||||
db_path=database_config.get("migration_db_path", ""),
|
|
||||||
db_name=database_config.get("migration_db_name", "cognee_db"),
|
|
||||||
db_host=database_config.get("migration_db_host"),
|
|
||||||
db_port=database_config.get("migration_db_port"),
|
|
||||||
db_username=database_config.get("migration_db_username"),
|
|
||||||
db_password=database_config.get("migration_db_password"),
|
|
||||||
db_provider=database_config.get("migration_db_provider", "sqlite"),
|
|
||||||
)
|
|
||||||
schema = await engine.extract_schema()
|
|
||||||
tables = {}
|
tables = {}
|
||||||
sample_data = {}
|
sample_data = {}
|
||||||
schema_tables = []
|
schema_tables = []
|
||||||
schema_relationships = []
|
schema_relationships = []
|
||||||
|
|
||||||
|
engine = get_migration_relational_engine()
|
||||||
qi = engine.engine.dialect.identifier_preparer.quote
|
qi = engine.engine.dialect.identifier_preparer.quote
|
||||||
try:
|
try:
|
||||||
max_sample_rows = max(0, int(max_sample_rows))
|
max_sample_rows = max(0, int(max_sample_rows))
|
||||||
|
|
@ -63,7 +57,7 @@ async def ingest_database_schema(
|
||||||
rows = [dict(r) for r in rows_result.mappings().all()]
|
rows = [dict(r) for r in rows_result.mappings().all()]
|
||||||
else:
|
else:
|
||||||
rows = []
|
rows = []
|
||||||
row_count_estimate = 0
|
|
||||||
if engine.engine.dialect.name == "postgresql":
|
if engine.engine.dialect.name == "postgresql":
|
||||||
if "." in table_name:
|
if "." in table_name:
|
||||||
schema_part, table_part = table_name.split(".", 1)
|
schema_part, table_part = table_name.split(".", 1)
|
||||||
|
|
@ -117,11 +111,12 @@ async def ingest_database_schema(
|
||||||
)
|
)
|
||||||
schema_relationships.append(relationship)
|
schema_relationships.append(relationship)
|
||||||
|
|
||||||
id_str = f"{database_config.get('migration_db_provider', 'sqlite')}:{database_config.get('migration_db_name', 'cognee_db')}:{schema_name}"
|
migration_config = get_migration_config()
|
||||||
|
id_str = f"{migration_config.migration_db_provider}:{migration_config.migration_db_name}:{schema_name}"
|
||||||
database_schema = DatabaseSchema(
|
database_schema = DatabaseSchema(
|
||||||
id=uuid5(NAMESPACE_OID, name=id_str),
|
id=uuid5(NAMESPACE_OID, name=id_str),
|
||||||
schema_name=schema_name,
|
schema_name=schema_name,
|
||||||
database_type=database_config.get("migration_db_provider", "sqlite"),
|
database_type=migration_config.migration_db_provider,
|
||||||
tables=tables,
|
tables=tables,
|
||||||
sample_data=sample_data,
|
sample_data=sample_data,
|
||||||
extraction_timestamp=datetime.now(timezone.utc),
|
extraction_timestamp=datetime.now(timezone.utc),
|
||||||
|
|
|
||||||
|
|
@ -212,6 +212,7 @@ async def test_schema_only_migration():
|
||||||
search_results = await cognee.search(
|
search_results = await cognee.search(
|
||||||
query_text="How many tables are there in this database",
|
query_text="How many tables are there in this database",
|
||||||
query_type=cognee.SearchType.GRAPH_COMPLETION,
|
query_type=cognee.SearchType.GRAPH_COMPLETION,
|
||||||
|
top_k=30,
|
||||||
)
|
)
|
||||||
assert any("11" in r for r in search_results), (
|
assert any("11" in r for r in search_results), (
|
||||||
"Number of tables in the database reported in search_results is either None or not equal to 11"
|
"Number of tables in the database reported in search_results is either None or not equal to 11"
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue