diff --git a/cognee/modules/visualization/cognee_network_visualization.py b/cognee/modules/visualization/cognee_network_visualization.py
index bbdbc0019..c735e70f1 100644
--- a/cognee/modules/visualization/cognee_network_visualization.py
+++ b/cognee/modules/visualization/cognee_network_visualization.py
@@ -23,6 +23,9 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
         "TableRow": "#f47710",
         "TableType": "#6510f4",
         "ColumnValue": "#13613a",
+        "SchemaTable": "#f47710",
+        "DatabaseSchema": "#6510f4",
+        "SchemaRelationship": "#13613a",
         "default": "#D3D3D3",
     }
 
diff --git a/cognee/tasks/ingestion/migrate_relational_database.py b/cognee/tasks/ingestion/migrate_relational_database.py
index 936ea59e0..53ce176e8 100644
--- a/cognee/tasks/ingestion/migrate_relational_database.py
+++ b/cognee/tasks/ingestion/migrate_relational_database.py
@@ -4,16 +4,20 @@ from sqlalchemy import text
 from cognee.infrastructure.databases.relational.get_migration_relational_engine import (
     get_migration_relational_engine,
 )
+from cognee.infrastructure.databases.relational.config import get_migration_config
 from cognee.tasks.storage.index_data_points import index_data_points
 from cognee.tasks.storage.index_graph_edges import index_graph_edges
+from cognee.tasks.schema.ingest_database_schema import ingest_database_schema
 from cognee.modules.engine.models import TableRow, TableType, ColumnValue
 
 logger = logging.getLogger(__name__)
 
 
-async def migrate_relational_database(graph_db, schema, migrate_column_data=True):
+async def migrate_relational_database(
+    graph_db, schema, migrate_column_data=True, schema_only=False
+):
     """
     Migrates data from a relational database into a graph database.
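The new `schema_only` flag skips the row-level migration and mirrors only the schema structure into the graph. A minimal usage sketch (assuming the migration and graph engines are already configured, as in the test changes further below):

```python
import asyncio

from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.relational import get_migration_relational_engine
from cognee.tasks.ingestion.migrate_relational_database import migrate_relational_database


async def run_schema_only_migration():
    # Extract the relational schema, then mirror only its structure into the graph.
    engine = get_migration_relational_engine()
    schema = await engine.extract_schema()

    graph = await get_graph_engine()
    return await migrate_relational_database(graph, schema=schema, schema_only=True)


asyncio.run(run_schema_only_migration())
```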
@@ -26,11 +30,133 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True
     Both TableType and TableRow inherit from DataPoint to maintain consistency with Cognee data model.
     """
+    # Build the node and edge mappings with the requested ingestion mode
+    if schema_only:
+        node_mapping, edge_mapping = await schema_only_ingestion(schema)
+
+    else:
+        node_mapping, edge_mapping = await complete_database_ingestion(schema, migrate_column_data)
+
+    def _remove_duplicate_edges(edge_mapping):
+        seen = set()
+        unique_original_shape = []
+
+        for tup in edge_mapping:
+            # Go through all tuples in the edge_mapping and add only the first occurrence of each
+            # to the result, eliminating duplicate edges.
+            source_id, target_id, rel_name, rel_dict = tup
+            # The properties dict is unhashable, so convert it to a frozenset to make the tuple comparable
+            rel_dict_hashable = frozenset(sorted(rel_dict.items()))
+            hashable_tup = (source_id, target_id, rel_name, rel_dict_hashable)
+
+            # The seen set keeps track of the edges we have already encountered
+            if hashable_tup not in seen:
+                # Store the hashable form so later duplicates can be detected
+                seen.add(hashable_tup)
+                # Append the original tuple shape (with the dictionary) the first time we see it
+                unique_original_shape.append(tup)
+
+        return unique_original_shape
+
+    # Add all nodes and edges to the graph.
+    # NOTE: Nodes and edges have to be added in batches for speed, especially for NetworkX.
+    # If we created nodes and added them to the graph in real time, the process would take too long:
+    # every node and edge added to NetworkX is saved to file, which is very slow when not done in batches.
+    await graph_db.add_nodes(list(node_mapping.values()))
+    await graph_db.add_edges(_remove_duplicate_edges(edge_mapping))
+
+    # In these steps we calculate the vector embeddings of our nodes and edges and save them to the vector database.
+    # Cognee uses this information to perform searches on the knowledge graph.
+    await index_data_points(list(node_mapping.values()))
+    await index_graph_edges()
+
+    logger.info("Data successfully migrated from relational database to desired graph database.")
+    return await graph_db.get_graph_data()
+
+
+async def schema_only_ingestion(schema):
+    node_mapping = {}
+    edge_mapping = []
+
+    # Call ingest_database_schema to build the DataPoint subclasses for the schema graph
+    result = await ingest_database_schema(
+        schema=schema,
+        max_sample_rows=5,
+    )
+    database_schema = result["database_schema"]
+    schema_tables = result["schema_tables"]
+    schema_relationships = result["relationships"]
+    database_node_id = database_schema.id
+    node_mapping[database_node_id] = database_schema
+    for table in schema_tables:
+        table_node_id = table.id
+        # Add the SchemaTable DataPoint as a node.
+        node_mapping[table_node_id] = table
+        edge_mapping.append(
+            (
+                table_node_id,
+                database_node_id,
+                "is_part_of",
+                dict(
+                    source_node_id=table_node_id,
+                    target_node_id=database_node_id,
+                    relationship_name="is_part_of",
+                ),
+            )
+        )
+    table_name_to_id = {t.name: t.id for t in schema_tables}
+    for rel in schema_relationships:
+        source_table_id = table_name_to_id.get(rel.source_table)
+        target_table_id = table_name_to_id.get(rel.target_table)
+
+        relationship_id = rel.id
+
+        # Add the SchemaRelationship DataPoint as a node.
+        node_mapping[relationship_id] = rel
+        edge_mapping.append(
+            (
+                source_table_id,
+                relationship_id,
+                "has_relationship",
+                dict(
+                    source_node_id=source_table_id,
+                    target_node_id=relationship_id,
+                    relationship_name=rel.relationship_type,
+                ),
+            )
+        )
+        edge_mapping.append(
+            (
+                relationship_id,
+                target_table_id,
+                "has_relationship",
+                dict(
+                    source_node_id=relationship_id,
+                    target_node_id=target_table_id,
+                    relationship_name=rel.relationship_type,
+                ),
+            )
+        )
+        edge_mapping.append(
+            (
+                source_table_id,
+                target_table_id,
+                rel.relationship_type,
+                dict(
+                    source_node_id=source_table_id,
+                    target_node_id=target_table_id,
+                    relationship_name=rel.relationship_type,
+                ),
+            )
+        )
+    return node_mapping, edge_mapping
+
+
+async def complete_database_ingestion(schema, migrate_column_data):
     engine = get_migration_relational_engine()
 
     # Create a mapping of node_id to node objects for referencing in edge creation
     node_mapping = {}
     edge_mapping = []
-
     async with engine.engine.begin() as cursor:
         # First, create table type nodes for all tables
         for table_name, details in schema.items():
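`_remove_duplicate_edges` keeps the first occurrence of each `(source, target, name, properties)` tuple; since the properties dict is unhashable, it compares a `frozenset` of its items instead. A standalone illustration of the same idea (toy data, not part of the patch):

```python
edges = [
    ("a", "b", "is_part_of", {"relationship_name": "is_part_of"}),
    ("a", "b", "is_part_of", {"relationship_name": "is_part_of"}),  # exact duplicate
    ("a", "c", "is_part_of", {"relationship_name": "is_part_of"}),
]

seen = set()
unique = []
for source_id, target_id, rel_name, rel_dict in edges:
    # A frozenset of the dict items is hashable, so the whole tuple can go into a set.
    key = (source_id, target_id, rel_name, frozenset(rel_dict.items()))
    if key not in seen:
        seen.add(key)
        unique.append((source_id, target_id, rel_name, rel_dict))

assert len(unique) == 2
```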
@@ -38,7 +164,7 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True
             table_node = TableType(
                 id=uuid5(NAMESPACE_OID, name=table_name),
                 name=table_name,
-                description=f"Table: {table_name}",
+                description=f'Relational database table with the following name: "{table_name}".',
             )
 
             # Add TableType node to mapping ( node will be added to the graph later based on this mapping )
@@ -75,7 +201,7 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True
                     name=node_id,
                     is_a=table_node,
                     properties=str(row_properties),
-                    description=f"Row in {table_name} with {primary_key_col}={primary_key_value}",
+                    description=f'Row from the relational database table "{table_name}" with the following row data: {str(row_properties)}, where each dictionary key is the column name and each value is the column value. This row has the id: {node_id}',
                 )
 
                 # Store the node object in our mapping
@@ -113,7 +239,7 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True
                             id=uuid5(NAMESPACE_OID, name=column_node_id),
                             name=column_node_id,
                             properties=f"{key} {value} {table_name}",
-                            description=f"Column name={key} and value={value} from column from table={table_name}",
+                            description=f"Column from relational database table={table_name}. Column name={key} and value={value}. The column value belongs to the row with id: {row_node.id}. This column has the id: {column_node_id}",
                         )
 
                         node_mapping[column_node_id] = column_node
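All of these nodes get their identity from `uuid5(NAMESPACE_OID, name=...)`, so IDs are a pure function of the node's name and re-running the migration regenerates the same graph instead of duplicating nodes. For instance:

```python
from uuid import NAMESPACE_OID, uuid5

# Deterministic: the same name always maps to the same UUID
# ("albums" and "artists" are just example table names).
assert uuid5(NAMESPACE_OID, name="albums") == uuid5(NAMESPACE_OID, name="albums")
assert uuid5(NAMESPACE_OID, name="albums") != uuid5(NAMESPACE_OID, name="artists")
```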
@@ -180,39 +306,4 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True
                     ),
                 )
             )
-
-    def _remove_duplicate_edges(edge_mapping):
-        seen = set()
-        unique_original_shape = []
-
-        for tup in edge_mapping:
-            # We go through all the tuples in the edge_mapping and we only add unique tuples to the list
-            # To eliminate duplicate edges.
-            source_id, target_id, rel_name, rel_dict = tup
-            # We need to convert the dictionary to a frozenset to be able to compare values for it
-            rel_dict_hashable = frozenset(sorted(rel_dict.items()))
-            hashable_tup = (source_id, target_id, rel_name, rel_dict_hashable)
-
-            # We use the seen set to keep track of unique edges
-            if hashable_tup not in seen:
-                # A list that has frozensets elements instead of dictionaries is needed to be able to compare values
-                seen.add(hashable_tup)
-                # append the original tuple shape (with the dictionary) if it's the first time we see it
-                unique_original_shape.append(tup)
-
-        return unique_original_shape
-
-    # Add all nodes and edges to the graph
-    # NOTE: Nodes and edges have to be added in batch for speed optimization, Especially for NetworkX.
-    # If we'd create nodes and add them to graph in real time the process would take too long.
-    # Every node and edge added to NetworkX is saved to file which is very slow when not done in batches.
-    await graph_db.add_nodes(list(node_mapping.values()))
-    await graph_db.add_edges(_remove_duplicate_edges(edge_mapping))
-
-    # In these steps we calculate the vector embeddings of our nodes and edges and save them to vector database
-    # Cognee uses this information to perform searches on the knowledge graph.
-    await index_data_points(list(node_mapping.values()))
-    await index_graph_edges()
-
-    logger.info("Data successfully migrated from relational database to desired graph database.")
-    return await graph_db.get_graph_data()
+    return node_mapping, edge_mapping
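Both ingestion paths now feed the same batching/indexing tail, and the schema-only path delegates extraction to the new `ingest_database_schema` task added below. A sketch of calling that task directly (assuming `schema` comes from `engine.extract_schema()`):

```python
from cognee.tasks.schema.ingest_database_schema import ingest_database_schema


async def inspect_schema(schema):
    result = await ingest_database_schema(schema=schema, max_sample_rows=5)

    database_schema = result["database_schema"]  # one DatabaseSchema DataPoint
    schema_tables = result["schema_tables"]      # one SchemaTable per table
    relationships = result["relationships"]      # one SchemaRelationship per foreign key

    print(database_schema.name, len(schema_tables), len(relationships))
```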
diff --git a/cognee/tasks/schema/ingest_database_schema.py b/cognee/tasks/schema/ingest_database_schema.py
new file mode 100644
index 000000000..e3823701c
--- /dev/null
+++ b/cognee/tasks/schema/ingest_database_schema.py
@@ -0,0 +1,134 @@
+from typing import List, Dict
+from uuid import uuid5, NAMESPACE_OID
+from cognee.infrastructure.engine.models.DataPoint import DataPoint
+from sqlalchemy import text
+from cognee.tasks.schema.models import DatabaseSchema, SchemaTable, SchemaRelationship
+from cognee.infrastructure.databases.relational.get_migration_relational_engine import (
+    get_migration_relational_engine,
+)
+from cognee.infrastructure.databases.relational.config import get_migration_config
+from datetime import datetime, timezone
+
+
+async def ingest_database_schema(
+    schema,
+    max_sample_rows: int = 0,
+) -> Dict[str, List[DataPoint] | DataPoint]:
+    """
+    Extract database schema metadata (optionally with sample data) and return DataPoint models for graph construction.
+
+    Args:
+        schema: Database schema
+        max_sample_rows: Maximum sample rows per table (0 means no sampling)
+
+    Returns:
+        Dict with keys:
+            "database_schema": DatabaseSchema
+            "schema_tables": List[SchemaTable]
+            "relationships": List[SchemaRelationship]
+    """
+
+    tables = {}
+    sample_data = {}
+    schema_tables = []
+    schema_relationships = []
+
+    migration_config = get_migration_config()
+    engine = get_migration_relational_engine()
+    qi = engine.engine.dialect.identifier_preparer.quote
+    try:
+        max_sample_rows = max(0, int(max_sample_rows))
+    except (TypeError, ValueError):
+        max_sample_rows = 0
+
+    def qname(name: str):
+        split_name = name.split(".")
+        return ".".join(qi(p) for p in split_name)
+
+    async with engine.engine.begin() as cursor:
+        for table_name, details in schema.items():
+            tn = qname(table_name)
+            if max_sample_rows > 0:
+                rows_result = await cursor.execute(
+                    text(f"SELECT * FROM {tn} LIMIT :limit;"),  # noqa: S608 - tn is fully quoted
+                    {"limit": max_sample_rows},
+                )
+                rows = [dict(r) for r in rows_result.mappings().all()]
+            else:
+                rows = []
+
+            if engine.engine.dialect.name == "postgresql":
+                if "." in table_name:
+                    schema_part, table_part = table_name.split(".", 1)
+                else:
+                    schema_part, table_part = "public", table_name
+                estimate = await cursor.execute(
+                    text(
+                        "SELECT reltuples::bigint AS estimate "
+                        "FROM pg_class c "
+                        "JOIN pg_namespace n ON n.oid = c.relnamespace "
+                        "WHERE n.nspname = :schema AND c.relname = :table"
+                    ),
+                    {"schema": schema_part, "table": table_part},
+                )
+                row_count_estimate = estimate.scalar() or 0
+            else:
+                count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {tn};"))  # noqa: S608 - tn is fully quoted
+                row_count_estimate = count_result.scalar()
+
+            schema_table = SchemaTable(
+                id=uuid5(NAMESPACE_OID, name=f"{table_name}"),
+                name=table_name,
+                columns=details["columns"],
+                primary_key=details.get("primary_key"),
+                foreign_keys=details.get("foreign_keys", []),
+                sample_rows=rows,
+                row_count_estimate=row_count_estimate,
+                description=f"Relational database table '{table_name}' with {len(details['columns'])} columns and approx. {row_count_estimate} rows. "
+                f"Here are the columns this table contains: {details['columns']}. "
+                f"Here are a few sample rows to show the contents of the table: {rows}. "
+                f"Table is part of the database: {migration_config.migration_db_name}",
+            )
+            schema_tables.append(schema_table)
+            tables[table_name] = details
+            sample_data[table_name] = rows
+
+            for fk in details.get("foreign_keys", []):
+                ref_table_fq = fk["ref_table"]
+                if "." not in ref_table_fq and "." in table_name:
+                    ref_table_fq = f"{table_name.split('.', 1)[0]}.{ref_table_fq}"
+
+                relationship_name = (
+                    f"{table_name}:{fk['column']}->{ref_table_fq}:{fk['ref_column']}"
+                )
+                relationship = SchemaRelationship(
+                    id=uuid5(NAMESPACE_OID, name=relationship_name),
+                    name=relationship_name,
+                    source_table=table_name,
+                    target_table=ref_table_fq,
+                    relationship_type="foreign_key",
+                    source_column=fk["column"],
+                    target_column=fk["ref_column"],
+                    description=f"Relational database table foreign key relationship between: {table_name}.{fk['column']} → {ref_table_fq}.{fk['ref_column']}. "
+                    f"This foreign key relationship between table columns is a part of the following database: {migration_config.migration_db_name}",
+                )
+                schema_relationships.append(relationship)
+
+    id_str = f"{migration_config.migration_db_provider}:{migration_config.migration_db_name}"
+    database_schema = DatabaseSchema(
+        id=uuid5(NAMESPACE_OID, name=id_str),
+        name=migration_config.migration_db_name,
+        database_type=migration_config.migration_db_provider,
+        tables=tables,
+        sample_data=sample_data,
+        extraction_timestamp=datetime.now(timezone.utc),
+        description=f"Database schema containing {len(schema_tables)} tables and {len(schema_relationships)} relationships. "
+        f"The database type is {migration_config.migration_db_provider}. "
+        f"The database contains the following tables: {tables}",
+    )
+
+    return {
+        "database_schema": database_schema,
+        "schema_tables": schema_tables,
+        "relationships": schema_relationships,
+    }
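Because table names are interpolated into the sampling and count queries, every dotted part is passed through the dialect's identifier preparer first; only the `LIMIT` value travels as a bound parameter. Roughly how that quoting behaves, shown against an in-memory SQLite engine (exact output depends on the dialect's quoting rules):

```python
from sqlalchemy import create_engine

engine = create_engine("sqlite://")
quote = engine.dialect.identifier_preparer.quote


def qname(name: str) -> str:
    # Quote each dotted part separately so schema-qualified names stay valid.
    return ".".join(quote(part) for part in name.split("."))


print(qname("main.Invoice Lines"))  # parts that need quoting come back quoted
```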
diff --git a/cognee/tasks/schema/models.py b/cognee/tasks/schema/models.py
new file mode 100644
index 000000000..4b13f420b
--- /dev/null
+++ b/cognee/tasks/schema/models.py
@@ -0,0 +1,41 @@
+from cognee.infrastructure.engine.models.DataPoint import DataPoint
+from typing import List, Dict, Optional
+from datetime import datetime
+
+
+class DatabaseSchema(DataPoint):
+    """Represents a complete database schema with sample data"""
+
+    name: str
+    database_type: str  # sqlite, postgres, etc.
+    tables: Dict[str, Dict]  # Reuse existing schema format from SqlAlchemyAdapter
+    sample_data: Dict[str, List[Dict]]  # Limited examples per table
+    extraction_timestamp: datetime
+    description: str
+    metadata: dict = {"index_fields": ["description", "name"]}
+
+
+class SchemaTable(DataPoint):
+    """Represents an individual table schema with relationships"""
+
+    name: str
+    columns: List[Dict]  # Column definitions with types
+    primary_key: Optional[str]
+    foreign_keys: List[Dict]  # Foreign key relationships
+    sample_rows: List[Dict]  # Max 3-5 example rows
+    row_count_estimate: Optional[int]  # Actual table size
+    description: str
+    metadata: dict = {"index_fields": ["description", "name"]}
+
+
+class SchemaRelationship(DataPoint):
+    """Represents relationships between tables"""
+
+    name: str
+    source_table: str
+    target_table: str
+    relationship_type: str  # "foreign_key", "one_to_many", etc.
+    source_column: str
+    target_column: str
+    description: str
+    metadata: dict = {"index_fields": ["description", "name"]}
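Since all three models extend `DataPoint`, they can be built directly like any pydantic-style model. A sketch with illustrative values (field names come from the classes above; the data is made up, and `DataPoint` is assumed to default its own bookkeeping fields):

```python
from uuid import NAMESPACE_OID, uuid5

from cognee.tasks.schema.models import SchemaRelationship

relationship = SchemaRelationship(
    id=uuid5(NAMESPACE_OID, name="albums:ArtistId->artists:ArtistId"),
    name="albums:ArtistId->artists:ArtistId",
    source_table="albums",
    target_table="artists",
    relationship_type="foreign_key",
    source_column="ArtistId",
    target_column="ArtistId",
    description="albums.ArtistId references artists.ArtistId",
)
print(relationship.name)
```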
diff --git a/cognee/tests/test_relational_db_migration.py b/cognee/tests/test_relational_db_migration.py
index 68b46dbf5..2b69ce854 100644
--- a/cognee/tests/test_relational_db_migration.py
+++ b/cognee/tests/test_relational_db_migration.py
@@ -197,6 +197,80 @@ async def relational_db_migration():
     print(f"All checks passed for {graph_db_provider} provider with '{relationship_label}' edges!")
 
 
+async def test_schema_only_migration():
+    # 1. Setup test DB and extract schema
+    migration_engine = await setup_test_db()
+    schema = await migration_engine.extract_schema()
+
+    # 2. Setup graph engine
+    graph_engine = await get_graph_engine()
+
+    # 3. Migrate schema only
+    await migrate_relational_database(graph_engine, schema=schema, schema_only=True)
+
+    # 4. Verify number of tables through search
+    search_results = await cognee.search(
+        query_text="How many tables are there in this database",
+        query_type=cognee.SearchType.GRAPH_COMPLETION,
+        top_k=30,
+    )
+    assert any("11" in r for r in search_results), (
+        "Search results did not report the expected table count of 11"
+    )
+
+    graph_db_provider = os.getenv("GRAPH_DATABASE_PROVIDER", "networkx").lower()
+
+    edge_counts = {
+        "is_part_of": 0,
+        "has_relationship": 0,
+        "foreign_key": 0,
+    }
+
+    if graph_db_provider == "neo4j":
+        for rel_type in edge_counts.keys():
+            query_str = f"""
+            MATCH ()-[r:{rel_type}]->()
+            RETURN count(r) as c
+            """
+            rows = await graph_engine.query(query_str)
+            edge_counts[rel_type] = rows[0]["c"]
+
+    elif graph_db_provider == "kuzu":
+        for rel_type in edge_counts.keys():
+            query_str = f"""
+            MATCH ()-[r:EDGE]->()
+            WHERE r.relationship_name = '{rel_type}'
+            RETURN count(r) as c
+            """
+            rows = await graph_engine.query(query_str)
+            edge_counts[rel_type] = rows[0][0]
+
+    elif graph_db_provider == "networkx":
+        nodes, edges = await graph_engine.get_graph_data()
+        for _, _, key, _ in edges:
+            if key in edge_counts:
+                edge_counts[key] += 1
+
+    else:
+        raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")
+
+    # 5. Assert counts match expected values
+    expected_counts = {
+        "is_part_of": 11,
+        "has_relationship": 22,
+        "foreign_key": 11,
+    }
+
+    for rel_type, expected in expected_counts.items():
+        actual = edge_counts[rel_type]
+        assert actual == expected, (
+            f"Expected {expected} edges for relationship '{rel_type}', but found {actual}"
+        )
+
+    print("Schema-only migration edge counts validated successfully!")
+    print(f"Edge counts: {edge_counts}")
+
+
 async def test_migration_sqlite():
     database_to_migrate_path = os.path.join(pathlib.Path(__file__).parent, "test_data/")
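The expected counts fall out of the test database's shape: 11 tables and 11 foreign keys. Each table contributes one `is_part_of` edge to the database node, and each foreign key contributes two `has_relationship` edges (source table → relationship node → target table) plus one direct `foreign_key` edge:

```python
num_tables = 11        # tables in the test migration database
num_foreign_keys = 11  # foreign keys across those tables

expected = {
    "is_part_of": num_tables,                  # table -> database
    "has_relationship": 2 * num_foreign_keys,  # table -> relationship node -> table
    "foreign_key": num_foreign_keys,           # direct table -> table edge
}
assert expected == {"is_part_of": 11, "has_relationship": 22, "foreign_key": 11}
```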
@@ -209,6 +283,7 @@ async def test_migration_sqlite():
     )
 
     await relational_db_migration()
+    await test_schema_only_migration()
 
 
 async def test_migration_postgres():
@@ -224,6 +299,7 @@ async def test_migration_postgres():
         }
     )
     await relational_db_migration()
+    await test_schema_only_migration()
 
 
 async def main():
diff --git a/examples/python/relational_database_migration_example.py b/examples/python/relational_database_migration_example.py
index fae8cfb3d..6a5c3b78b 100644
--- a/examples/python/relational_database_migration_example.py
+++ b/examples/python/relational_database_migration_example.py
@@ -1,16 +1,15 @@
+from pathlib import Path
 import asyncio
-
-import cognee
 import os
+
+import cognee
+from cognee.infrastructure.databases.relational.config import get_migration_config
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.api.v1.visualize.visualize import visualize_graph
 from cognee.infrastructure.databases.relational import (
     get_migration_relational_engine,
 )
-
 from cognee.modules.search.types import SearchType
-
 from cognee.infrastructure.databases.relational import (
     create_db_and_tables as create_relational_db_and_tables,
 )
@@ -32,16 +31,29 @@ from cognee.infrastructure.databases.vector.pgvector import (
 
 
 async def main():
-    engine = get_migration_relational_engine()
-
     # Clean all data stored in Cognee
     await cognee.prune.prune_data()
     await cognee.prune.prune_system(metadata=True)
 
-    # Needed to create appropriate tables only on the Cognee side
+    # Needed to create appropriate database tables only on the Cognee side
     await create_relational_db_and_tables()
     await create_vector_db_and_tables()
 
+    # In case environment variables are not set, use the example database from the Cognee repo
+    migration_db_provider = os.environ.get("MIGRATION_DB_PROVIDER", "sqlite")
+    migration_db_path = os.environ.get(
+        "MIGRATION_DB_PATH",
+        os.path.join(Path(__file__).resolve().parent.parent.parent, "cognee/tests/test_data"),
+    )
+    migration_db_name = os.environ.get("MIGRATION_DB_NAME", "migration_database.sqlite")
+
+    migration_config = get_migration_config()
+    migration_config.migration_db_provider = migration_db_provider
+    migration_config.migration_db_path = migration_db_path
+    migration_config.migration_db_name = migration_db_name
+
+    engine = get_migration_relational_engine()
+
     print("\nExtracting schema of database to migrate.")
     schema = await engine.extract_schema()
     print(f"Migrated database schema:\n{schema}")
@@ -53,10 +65,6 @@ async def main():
     await migrate_relational_database(graph, schema=schema)
     print("Relational database migration complete.")
 
-    # Define location where to store html visualization of graph of the migrated database
-    home_dir = os.path.expanduser("~")
-    destination_file_path = os.path.join(home_dir, "graph_visualization.html")
-
     # Make sure to set top_k at a high value for a broader search, the default value is only 10!
     # top_k represent the number of graph tripplets to supply to the LLM to answer your question
     search_results = await cognee.search(
@@ -69,13 +77,25 @@ async def main():
     # Having a top_k value set to too high might overwhelm the LLM context when specific questions need to be answered.
     # For this kind of question we've set the top_k to 30
     search_results = await cognee.search(
-        query_type=SearchType.GRAPH_COMPLETION_COT,
+        query_type=SearchType.GRAPH_COMPLETION,
         query_text="What invoices are related to Leonie Köhler?",
         top_k=30,
     )
     print(f"Search results: {search_results}")
 
-    # test.html is a file with visualized data migration
+    search_results = await cognee.search(
+        query_type=SearchType.GRAPH_COMPLETION,
+        query_text="What invoices are related to Luís Gonçalves?",
+        top_k=30,
+    )
+    print(f"Search results: {search_results}")
+
+    # If you check the relational database for this example, you can see that the search successfully found all
+    # the invoices related to the two customers, without any hallucinations or additional information.
+
+    # Define location where to store the html visualization of the migrated database graph
+    home_dir = os.path.expanduser("~")
+    destination_file_path = os.path.join(home_dir, "graph_visualization.html")
     print("Adding html visualization of graph database after migration.")
     await visualize_graph(destination_file_path)
     print(f"Visualization can be found at: {destination_file_path}")
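The same defaults can also be supplied through the environment before cognee reads its settings, instead of mutating the config object in code; a sketch using the `MIGRATION_DB_*` variables the example reads:

```python
import os

# Point the migration engine at the bundled SQLite example database
# (path shown relative to a checkout of the Cognee repo).
os.environ.setdefault("MIGRATION_DB_PROVIDER", "sqlite")
os.environ.setdefault("MIGRATION_DB_PATH", "cognee/tests/test_data")
os.environ.setdefault("MIGRATION_DB_NAME", "migration_database.sqlite")
```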