feat: Relational db to graph db [COG-1468] (#644)


## Description
Add the ability to migrate a relational database into a graph database.
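A minimal end-to-end sketch of the new flow, mirroring the example script added in this PR (it assumes the `MIGRATION_DB_*` settings are already configured and omits the prune/bootstrap steps the full example performs):

```python
import asyncio

from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.relational import get_migration_relational_engine
from cognee.tasks.ingestion import migrate_relational_database


async def run_migration():
    # Read the MIGRATION_DB_* settings and connect to the source relational database.
    engine = get_migration_relational_engine()
    # Extract tables, columns, primary keys and foreign keys from the source.
    schema = await engine.extract_schema()
    # Move rows and foreign-key relationships into the configured graph database.
    graph = await get_graph_engine()
    await migrate_relational_database(graph, schema=schema)


asyncio.run(run_migration())
```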

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
Igor Ilic, 2025-03-26 11:40:06 +01:00 (committed by GitHub)
commit 9f587a01a4, parent 897a1f3081
20 changed files with 441 additions and 24 deletions


@ -47,3 +47,14 @@ DB_NAME=cognee_db
# DB_PORT=5432
# DB_USERNAME=cognee
# DB_PASSWORD=cognee
# Params for migrating relational database data to graph / Cognee ( PostgreSQL and SQLite supported )
# MIGRATION_DB_PATH="/path/to/migration/directory"
# MIGRATION_DB_NAME="migration_database.sqlite"
# MIGRATION_DB_PROVIDER="sqlite"
# Postgres specific parameters for migration
# MIGRATION_DB_USERNAME=cognee
# MIGRATION_DB_PASSWORD=cognee
# MIGRATION_DB_HOST="127.0.0.1"
# MIGRATION_DB_PORT=5432
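For a PostgreSQL source the host/port/credential fields above apply instead of the path; an illustrative filled-in block (values are examples, and the provider string for PostgreSQL is assumed here to be `postgres`):

```
MIGRATION_DB_PROVIDER="postgres"
MIGRATION_DB_NAME="migration_database"
MIGRATION_DB_HOST="127.0.0.1"
MIGRATION_DB_PORT=5432
MIGRATION_DB_USERNAME=cognee
MIGRATION_DB_PASSWORD=cognee
```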


@ -17,7 +17,6 @@ from kuzu import Connection
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import JSONEncoder
import aiofiles
logger = get_logger()


@ -199,13 +199,14 @@ class Neo4jAdapter(GraphDBInterface):
serialized_properties = self.serialize_properties(edge_properties)
query = dedent(
"""MATCH (from_node {id: $from_node}),
(to_node {id: $to_node})
MERGE (from_node)-[r]->(to_node)
ON CREATE SET r += $properties, r.updated_at = timestamp(), r.type = $relationship_name
f"""\
MATCH (from_node {{id: $from_node}}),
(to_node {{id: $to_node}})
MERGE (from_node)-[r:{relationship_name}]->(to_node)
ON CREATE SET r += $properties, r.updated_at = timestamp()
ON MATCH SET r += $properties, r.updated_at = timestamp()
RETURN r
"""
"""
)
params = {

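Cypher cannot take the relationship type as a bound parameter, which is presumably why `relationship_name` is now interpolated into the MERGE pattern instead of being stored as an `r.type` property. A small sketch of the query the adapter would now generate, for a hypothetical relationship name:

```python
from textwrap import dedent

relationship_name = "is_part_of"  # hypothetical example value

query = dedent(
    f"""\
    MATCH (from_node {{id: $from_node}}),
          (to_node {{id: $to_node}})
    MERGE (from_node)-[r:{relationship_name}]->(to_node)
    ON CREATE SET r += $properties, r.updated_at = timestamp()
    ON MATCH SET r += $properties, r.updated_at = timestamp()
    RETURN r
    """
)
print(query)  # the MERGE line renders as: MERGE (from_node)-[r:is_part_of]->(to_node)
```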

@ -5,6 +5,7 @@ import os
import json
import asyncio
from cognee.shared.logging_utils import get_logger
from sqlalchemy import text
from typing import Dict, Any, List, Union
from uuid import UUID
import aiofiles
@ -88,6 +89,7 @@ class NetworkXAdapter(GraphDBInterface):
key=relationship_name,
**(edge_properties if edge_properties else {}),
)
await self.save_graph_to_file(self.filename)
async def add_edges(
@ -315,11 +317,13 @@ class NetworkXAdapter(GraphDBInterface):
logger.error(e)
raise e
if isinstance(edge["updated_at"], int): # Handle timestamp in milliseconds
if isinstance(
edge.get("updated_at"), int
): # Handle timestamp in milliseconds
edge["updated_at"] = datetime.fromtimestamp(
edge["updated_at"] / 1000, tz=timezone.utc
)
elif isinstance(edge["updated_at"], str):
elif isinstance(edge.get("updated_at"), str):
edge["updated_at"] = datetime.strptime(
edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z"
)

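The change above makes the adapter tolerate edges without an `updated_at` key (via `.get()`) while still normalizing both stored representations into timezone-aware datetimes; an illustrative sketch with made-up values:

```python
from datetime import datetime, timezone

ms_timestamp = 1742985606000                    # integer: milliseconds since the epoch
iso_string = "2025-03-26T10:40:06.000000+0000"  # string: ISO-8601 with microseconds and offset

as_datetime_ms = datetime.fromtimestamp(ms_timestamp / 1000, tz=timezone.utc)
as_datetime_str = datetime.strptime(iso_string, "%Y-%m-%dT%H:%M:%S.%f%z")
```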

@ -1,4 +1,6 @@
from .ModelBase import Base
from .config import get_relational_config
from .config import get_migration_config
from .create_db_and_tables import create_db_and_tables
from .get_relational_engine import get_relational_engine
from .get_migration_relational_engine import get_migration_relational_engine


@ -31,3 +31,31 @@ class RelationalConfig(BaseSettings):
@lru_cache
def get_relational_config():
return RelationalConfig()
class MigrationConfig(BaseSettings):
migration_db_path: Union[str, None] = None
migration_db_name: str = None
migration_db_host: Union[str, None] = None
migration_db_port: Union[str, None] = None
migration_db_username: Union[str, None] = None
migration_db_password: Union[str, None] = None
migration_db_provider: str = None
model_config = SettingsConfigDict(env_file=".env", extra="allow")
def to_dict(self) -> dict:
return {
"migration_db_path": self.migration_db_path,
"migration_db_name": self.migration_db_name,
"migration_db_host": self.migration_db_host,
"migration_db_port": self.migration_db_port,
"migration_db_username": self.migration_db_username,
"migration_db_password": self.migration_db_password,
"migration_db_provider": self.migration_db_provider,
}
@lru_cache
def get_migration_config():
return MigrationConfig()
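A short sketch of how the new config is consumed; the `MIGRATION_DB_*` variables map one-to-one onto the lowercase field names, and `get_migration_config` is cached with `lru_cache`, so `.env` is read once per process:

```python
from cognee.infrastructure.databases.relational import get_migration_config

config = get_migration_config()
print(config.migration_db_provider)  # e.g. "sqlite"
print(config.to_dict())              # plain dict, convenient for passing to the engine factory
```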


@ -0,0 +1,16 @@
from .config import get_migration_config
from .create_relational_engine import create_relational_engine
def get_migration_relational_engine():
migration_config = get_migration_config()
return create_relational_engine(
db_path=migration_config.migration_db_path,
db_name=migration_config.migration_db_name,
db_host=migration_config.migration_db_host,
db_port=migration_config.migration_db_port,
db_username=migration_config.migration_db_username,
db_password=migration_config.migration_db_password,
db_provider=migration_config.migration_db_provider,
)


@ -5,7 +5,7 @@ from uuid import UUID
from typing import Optional
from typing import AsyncGenerator, List
from contextlib import asynccontextmanager
from sqlalchemy import text, select, MetaData, Table, delete
from sqlalchemy import text, select, MetaData, Table, delete, inspect
from sqlalchemy.orm import joinedload
from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
@ -229,25 +229,23 @@ class SQLAlchemyAdapter:
async def get_table_names(self) -> List[str]:
"""
Return a list of all tables names in database
Return a list of all table names in the database, even if they don't have defined SQLAlchemy models.
"""
table_names = []
async with self.engine.begin() as connection:
if self.engine.dialect.name == "sqlite":
await connection.run_sync(Base.metadata.reflect)
for table in Base.metadata.tables:
table_names.append(str(table))
# Use a new MetaData instance to reflect all tables
metadata = MetaData()
await connection.run_sync(metadata.reflect) # Reflect the entire database
table_names = list(metadata.tables.keys()) # Get table names
else:
schema_list = await self.get_schema_list()
# Create a MetaData instance to load table information
metadata = MetaData()
# Collect table names from all schemas
for schema_name in schema_list:
# Load the schema information into the MetaData object
await connection.run_sync(metadata.reflect, schema=schema_name)
for table in metadata.sorted_tables:
table_names.append(str(table))
metadata.clear()
table_names.extend(metadata.tables.keys()) # Append table names from schema
metadata.clear() # Clear metadata for the next schema
return table_names
async def get_data(self, table_name: str, filters: dict = None):
@ -345,3 +343,94 @@ class SQLAlchemyAdapter:
raise e
logger.info("Database deleted successfully.")
async def extract_schema(self):
async with self.engine.begin() as connection:
tables = await self.get_table_names()
schema = {}
if self.engine.dialect.name == "sqlite":
for table_name in tables:
schema[table_name] = {"columns": [], "primary_key": None, "foreign_keys": []}
# Get column details
columns_result = await connection.execute(
text(f"PRAGMA table_info('{table_name}');")
)
columns = columns_result.fetchall()
for column in columns:
column_name = column[1]
column_type = column[2]
is_pk = column[5] == 1
schema[table_name]["columns"].append(
{"name": column_name, "type": column_type}
)
if is_pk:
schema[table_name]["primary_key"] = column_name
# Get foreign key details
foreign_keys_results = await connection.execute(
text(f"PRAGMA foreign_key_list('{table_name}');")
)
foreign_keys = foreign_keys_results.fetchall()
for fk in foreign_keys:
schema[table_name]["foreign_keys"].append(
{
"column": fk[3], # Column in the current table
"ref_table": fk[2], # Referenced table
"ref_column": fk[4], # Referenced column
}
)
else:
schema_list = await self.get_schema_list()
for schema_name in schema_list:
# Get tables for the current schema via the inspector.
tables = await connection.run_sync(
lambda sync_conn: inspect(sync_conn).get_table_names(schema=schema_name)
)
for table_name in tables:
# Optionally, qualify the table name with the schema if not in the default schema.
key = (
table_name if schema_name == "public" else f"{schema_name}.{table_name}"
)
schema[key] = {"columns": [], "primary_key": None, "foreign_keys": []}
# Helper function to get table details using the inspector.
def get_details(sync_conn, table, schema_name):
insp = inspect(sync_conn)
cols = insp.get_columns(table, schema=schema_name)
pk = insp.get_pk_constraint(table, schema=schema_name)
fks = insp.get_foreign_keys(table, schema=schema_name)
return cols, pk, fks
cols, pk, fks = await connection.run_sync(
get_details, table_name, schema_name
)
for column in cols:
# Convert the type to string
schema[key]["columns"].append(
{"name": column["name"], "type": str(column["type"])}
)
pk_columns = pk.get("constrained_columns", [])
if pk_columns:
schema[key]["primary_key"] = pk_columns[0]
for fk in fks:
for col, ref_col in zip(
fk.get("constrained_columns", []), fk.get("referred_columns", [])
):
if col and ref_col:
schema[key]["foreign_keys"].append(
{
"column": col,
"ref_table": fk.get("referred_table"),
"ref_column": ref_col,
}
)
else:
logger.warning(
f"Missing value in foreign key information. \nColumn value: {col}\nReference column value: {ref_col}\n"
)
return schema
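To make the return shape concrete, a hypothetical two-table SQLite database would yield something like:

```python
schema = {
    "users": {
        "columns": [{"name": "id", "type": "INTEGER"}, {"name": "name", "type": "VARCHAR"}],
        "primary_key": "id",
        "foreign_keys": [],
    },
    "orders": {
        "columns": [{"name": "id", "type": "INTEGER"}, {"name": "user_id", "type": "INTEGER"}],
        "primary_key": "id",
        "foreign_keys": [{"column": "user_id", "ref_table": "users", "ref_column": "id"}],
    },
}
```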


@ -22,6 +22,7 @@ class DataPoint(BaseModel):
updated_at: int = Field(
default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)
)
ontology_valid: bool = False
version: int = 1 # Default version
topological_rank: Optional[int] = 0
metadata: Optional[MetaData] = {"index_fields": []}


@ -7,6 +7,5 @@ class Entity(DataPoint):
name: str
is_a: Optional[EntityType] = None
description: str
ontology_valid: bool = False
metadata: dict = {"index_fields": ["name"]}


@ -4,6 +4,5 @@ from cognee.infrastructure.engine import DataPoint
class EntityType(DataPoint):
name: str
description: str
ontology_valid: bool = False
metadata: dict = {"index_fields": ["name"]}


@ -0,0 +1,12 @@
from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.models.TableType import TableType
from typing import Optional
class TableRow(DataPoint):
name: str
is_a: Optional[TableType] = None
description: str
properties: str
metadata: dict = {"index_fields": ["properties"]}


@ -0,0 +1,8 @@
from cognee.infrastructure.engine import DataPoint
class TableType(DataPoint):
name: str
description: str
metadata: dict = {"index_fields": ["name"]}


@ -1,2 +1,4 @@
from .Entity import Entity
from .EntityType import EntityType
from .TableRow import TableRow
from .TableType import TableType


@ -1,5 +1,6 @@
import json
from uuid import UUID
from decimal import Decimal
from datetime import datetime
from pydantic_core import PydanticUndefined
from pydantic import create_model
@ -14,6 +15,8 @@ class JSONEncoder(json.JSONEncoder):
elif isinstance(obj, UUID):
# if the obj is uuid, we simply return the value of uuid
return str(obj)
elif isinstance(obj, Decimal):
return float(obj)
return json.JSONEncoder.default(self, obj)
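A quick sketch of the effect: relational rows often contain `Decimal` columns, and with this change they serialize cleanly instead of raising `TypeError`:

```python
import json
from decimal import Decimal
from uuid import uuid4

from cognee.modules.storage.utils import JSONEncoder

row = {"id": uuid4(), "price": Decimal("19.99")}
print(json.dumps(row, cls=JSONEncoder))  # UUID -> str, Decimal -> float
```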


@ -20,6 +20,8 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
"EntityType": "#6510f4",
"DocumentChunk": "#801212",
"TextSummary": "#1077f4",
"TableRow": "#f47710",
"TableType": "#6510f4",
"default": "#D3D3D3",
}
@ -182,8 +184,8 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
</html>
"""
html_content = html_template.replace("{nodes}", json.dumps(nodes_list))
html_content = html_content.replace("{links}", json.dumps(links_list))
html_content = html_template.replace("{nodes}", json.dumps(nodes_list, default=str))
html_content = html_content.replace("{links}", json.dumps(links_list, default=str))
if not destination_file_path:
home_dir = os.path.expanduser("~")

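Migrated node and edge payloads can now carry values the stock encoder rejects (UUIDs, datetimes, Decimals); `default=str` is the catch-all that keeps the visualization from failing. Roughly:

```python
import json
from datetime import datetime, timezone
from uuid import uuid4

node = {"id": uuid4(), "updated_at": datetime.now(timezone.utc)}
# json.dumps(node) would raise TypeError; default=str stringifies anything it cannot encode.
print(json.dumps(node, default=str))
```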

@ -1,3 +1,4 @@
from .save_data_item_to_storage import save_data_item_to_storage
from .ingest_data import ingest_data
from .resolve_data_directories import resolve_data_directories
from .migrate_relational_database import migrate_relational_database


@ -0,0 +1,164 @@
import logging
from decimal import Decimal
from uuid import uuid5, NAMESPACE_OID
from sqlalchemy import text
from cognee.infrastructure.databases.relational.get_migration_relational_engine import (
get_migration_relational_engine,
)
from cognee.tasks.storage.index_data_points import index_data_points
from cognee.tasks.storage.index_graph_edges import index_graph_edges
from uuid import uuid4
from cognee.modules.engine.models import TableRow, TableType
logger = logging.getLogger(__name__)
async def migrate_relational_database(graph_db, schema):
"""
Migrates data from a relational database into a graph database.
For each table in the schema:
- Creates a TableType node representing the table
- Fetches all rows and creates a TableRow node for each row
- Links each TableRow node to its TableType node with an "is_part_of" relationship
Then, for every foreign key defined in the schema:
- Establishes relationships between TableRow nodes based on foreign key relationships
Both TableType and TableRow inherit from DataPoint to maintain consistency with the Cognee data model.
"""
engine = get_migration_relational_engine()
# Create a mapping of node_id to node objects for referencing in edge creation
node_mapping = {}
edge_mapping = []
async with engine.engine.begin() as cursor:
# First, create table type nodes for all tables
for table_name, details in schema.items():
# Create a TableType node for each table
table_node = TableType(
id=uuid5(NAMESPACE_OID, name=table_name),
name=table_name,
description=f"Table: {table_name}",
)
# Add TableType node to mapping ( node will be added to the graph later based on this mapping )
node_mapping[table_name] = table_node
# Fetch all rows for the current table
rows_result = await cursor.execute(text(f"SELECT * FROM {table_name};"))
rows = rows_result.fetchall()
for row in rows:
# Build a dictionary of properties from the row
row_properties = {
col["name"]: row[idx] for idx, col in enumerate(details["columns"])
}
# Determine the primary key value
if not details["primary_key"]:
# Use the first column as primary key if not specified
primary_key_col = details["columns"][0]["name"]
primary_key_value = row_properties[primary_key_col]
else:
# Use value of the specified primary key column
primary_key_col = details["primary_key"]
primary_key_value = row_properties[primary_key_col]
# Create a node ID in the format "table_name:primary_key_value"
node_id = f"{table_name}:{primary_key_value}"
# Create a TableRow node
# The node id must map uniquely to the id used in the relational database
# so that foreign key relationships can be resolved correctly
row_node = TableRow(
id=uuid5(NAMESPACE_OID, name=node_id),
name=node_id,
is_a=table_node,
properties=str(row_properties),
description=f"Row in {table_name} with {primary_key_col}={primary_key_value}",
)
# Store the node object in our mapping
node_mapping[node_id] = row_node
# Add edge between row node and table node ( it will be added to the graph later )
edge_mapping.append(
(
row_node.id,
table_node.id,
"is_part_of",
dict(
relationship_name="is_part_of",
source_node_id=row_node.id,
target_node_id=table_node.id,
),
)
)
# Process foreign key relationships after all nodes are created
for table_name, details in schema.items():
# Process foreign key relationships for the current table
for fk in details.get("foreign_keys", []):
# Aliases needed for self-referencing tables
alias_1 = f"{table_name}_e1"
alias_2 = f"{fk['ref_table']}_e2"
# Determine primary key column
if not details["primary_key"]:
primary_key_col = details["columns"][0]["name"]
else:
primary_key_col = details["primary_key"]
# Query to find relationships based on foreign keys
fk_query = text(
f"SELECT {alias_1}.{primary_key_col} AS source_id, "
f"{alias_2}.{fk['ref_column']} AS ref_value "
f"FROM {table_name} AS {alias_1} "
f"JOIN {fk['ref_table']} AS {alias_2} "
f"ON {alias_1}.{fk['column']} = {alias_2}.{fk['ref_column']};"
)
fk_result = await cursor.execute(fk_query)
relations = fk_result.fetchall()
for source_id, ref_value in relations:
# Construct node ids
source_node_id = f"{table_name}:{source_id}"
target_node_id = f"{fk['ref_table']}:{ref_value}"
# Get the source and target node objects from our mapping
source_node = node_mapping[source_node_id]
target_node = node_mapping[target_node_id]
# Add edge representing the foreign key relationship using the node objects
# Create edge to add to graph later
edge_mapping.append(
(
source_node.id,
target_node.id,
fk["column"],
dict(
source_node_id=source_node.id,
target_node_id=target_node.id,
relationship_name=fk["column"],
),
)
)
# Add all nodes and edges to the graph
# NOTE: Nodes and edges have to be added in batches for speed, especially for NetworkX.
# If we created nodes and added them to the graph one at a time, the process would take too long:
# every node and edge added to NetworkX is saved to file, which is very slow when not done in batches.
await graph_db.add_nodes(list(node_mapping.values()))
await graph_db.add_edges(edge_mapping)
# In these steps we calculate the vector embeddings of our nodes and edges and save them to the vector database.
# Cognee uses this information to perform searches on the knowledge graph.
await index_data_points(list(node_mapping.values()))
await index_graph_edges()
logger.info("Data successfully migrated from relational database to desired graph database.")
return await graph_db.get_graph_data()
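One detail worth calling out: row nodes get deterministic ids derived from the `"table_name:primary_key_value"` string, which is what lets the foreign-key pass rebuild the same UUID and find the node again in `node_mapping`. A tiny sketch with hypothetical values:

```python
from uuid import uuid5, NAMESPACE_OID

node_id = "users:42"  # hypothetical table "users", primary key value 42
row_uuid = uuid5(NAMESPACE_OID, name=node_id)

# uuid5 is deterministic: the same "table:primary_key" string always yields the same UUID.
assert row_uuid == uuid5(NAMESPACE_OID, name="users:42")
```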


@ -38,7 +38,11 @@ async def index_data_points(data_points: list[DataPoint]):
index_name = index_name_and_field[:first_occurence]
field_name = index_name_and_field[first_occurence + 1 :]
try:
await vector_engine.index_data_points(index_name, field_name, indexable_points)
# In case the amount of indexable points is too large, we need to send them in batches
batch_size = 1000
for i in range(0, len(indexable_points), batch_size):
batch = indexable_points[i : i + batch_size]
await vector_engine.index_data_points(index_name, field_name, batch)
except EmbeddingException as e:
logger.warning(f"Failed to index data points for {index_name}.{field_name}: {e}")
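The slicing pattern sends full batches of 1000 points and a shorter final remainder; a tiny sketch of the behavior:

```python
indexable_points = list(range(2500))  # hypothetical number of points
batch_size = 1000

batches = [
    indexable_points[i : i + batch_size]
    for i in range(0, len(indexable_points), batch_size)
]
assert [len(batch) for batch in batches] == [1000, 1000, 500]
```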


@ -0,0 +1,72 @@
import asyncio
import cognee
import os
import logging
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.api.v1.visualize.visualize import visualize_graph
from cognee.infrastructure.databases.relational import (
get_migration_relational_engine,
)
from cognee.modules.search.types import SearchType
from cognee.modules.users.methods import get_default_user
from cognee.infrastructure.databases.relational import (
create_db_and_tables as create_relational_db_and_tables,
)
from cognee.infrastructure.databases.vector.pgvector import (
create_db_and_tables as create_pgvector_db_and_tables,
)
# Prerequisites:
# 1. Copy `.env.template` and rename it to `.env`.
# 2. Add your OpenAI API key to the `.env` file in the `LLM_API_KEY` field:
# LLM_API_KEY = "your_key_here"
# 3. Fill in all relevant MIGRATION_DB settings for the database you want to migrate to graph / Cognee
async def main():
engine = get_migration_relational_engine()
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
# Create system tables for the relational and vector databases (needed for the principals table, among others)
await create_relational_db_and_tables()
await create_pgvector_db_and_tables()
print("\nExtracting schema of database to migrate.")
schema = await engine.extract_schema()
print(f"Migrated database schema:\n{schema}")
graph = await get_graph_engine()
print("Migrating relational database to graph database based on schema.")
from cognee.tasks.ingestion import migrate_relational_database
await migrate_relational_database(graph, schema=schema)
print("Relational database migration complete.")
# Define the location where the HTML visualization of the migrated database graph will be stored
home_dir = os.path.expanduser("~")
destination_file_path = os.path.join(home_dir, "graph_visualization.html")
# graph_visualization.html will contain a visualization of the migrated data
print("Adding html visualization of graph database after migration.")
await visualize_graph(destination_file_path)
print(f"Visualization can be found at: {destination_file_path}")
search_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION, query_text="What kind of data do you contain?"
)
print(f"Search results: {search_results}")
if __name__ == "__main__":
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(main())
finally:
loop.run_until_complete(loop.shutdown_asyncgens())
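For anyone adapting the example, an equivalent and slightly shorter entrypoint is possible, since `asyncio.run` already shuts down async generators on exit:

```python
if __name__ == "__main__":
    asyncio.run(main())
```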