feat: Relational db to graph db [COG-1468] (#644)
## Description

Add the ability to migrate a relational database to a graph database.

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
parent 897a1f3081
commit 9f587a01a4

20 changed files with 441 additions and 24 deletions
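At a high level, the flow added in this PR is: read the `MIGRATION_DB_*` settings, extract the source schema, and pass it together with the configured graph engine to `migrate_relational_database`. A condensed sketch based on the example script added below (`examples/python/relational_database_migration_example.py`); it assumes a configured `.env` and omits the pruning and table-creation steps the full example performs:

```python
import asyncio

from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.relational import get_migration_relational_engine
from cognee.tasks.ingestion import migrate_relational_database


async def run_migration():
    # Engine for the source relational database (driven by MIGRATION_DB_* settings).
    relational_engine = get_migration_relational_engine()

    # Introspect tables, columns, primary keys and foreign keys.
    schema = await relational_engine.extract_schema()

    # Target graph database, as configured for Cognee.
    graph_engine = await get_graph_engine()

    # Create TableType/TableRow nodes and foreign-key edges from the schema.
    await migrate_relational_database(graph_engine, schema=schema)


if __name__ == "__main__":
    asyncio.run(run_migration())
```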
@@ -47,3 +47,14 @@ DB_NAME=cognee_db
# DB_PORT=5432
# DB_USERNAME=cognee
# DB_PASSWORD=cognee

# Params for migrating relational database data to graph / Cognee ( PostgreSQL and SQLite supported )
# MIGRATION_DB_PATH="/path/to/migration/directory"
# MIGRATION_DB_NAME="migration_database.sqlite"
# MIGRATION_DB_PROVIDER="sqlite"
# Postgres specific parameters for migration
# MIGRATION_DB_USERNAME=cognee
# MIGRATION_DB_PASSWORD=cognee
# MIGRATION_DB_HOST="127.0.0.1"
# MIGRATION_DB_PORT=5432
@@ -17,7 +17,6 @@ from kuzu import Connection
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import JSONEncoder
import aiofiles

logger = get_logger()
@@ -199,13 +199,14 @@ class Neo4jAdapter(GraphDBInterface):
        serialized_properties = self.serialize_properties(edge_properties)

        query = dedent(
-            """MATCH (from_node {id: $from_node}),
-            (to_node {id: $to_node})
-            MERGE (from_node)-[r]->(to_node)
-            ON CREATE SET r += $properties, r.updated_at = timestamp(), r.type = $relationship_name
+            f"""\
+            MATCH (from_node {{id: $from_node}}),
+                (to_node {{id: $to_node}})
+            MERGE (from_node)-[r:{relationship_name}]->(to_node)
+            ON CREATE SET r += $properties, r.updated_at = timestamp()
            ON MATCH SET r += $properties, r.updated_at = timestamp()
            RETURN r
            """
        )

        params = {
@@ -5,6 +5,7 @@ import os
import json
import asyncio
from cognee.shared.logging_utils import get_logger
from sqlalchemy import text
from typing import Dict, Any, List, Union
from uuid import UUID
import aiofiles
@@ -88,6 +89,7 @@ class NetworkXAdapter(GraphDBInterface):
            key=relationship_name,
            **(edge_properties if edge_properties else {}),
        )

        await self.save_graph_to_file(self.filename)

    async def add_edges(
@@ -315,11 +317,13 @@ class NetworkXAdapter(GraphDBInterface):
            logger.error(e)
            raise e

-        if isinstance(edge["updated_at"], int):  # Handle timestamp in milliseconds
+        if isinstance(
+            edge.get("updated_at"), int
+        ):  # Handle timestamp in milliseconds
            edge["updated_at"] = datetime.fromtimestamp(
                edge["updated_at"] / 1000, tz=timezone.utc
            )
-        elif isinstance(edge["updated_at"], str):
+        elif isinstance(edge.get("updated_at"), str):
            edge["updated_at"] = datetime.strptime(
                edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z"
            )
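Note on the `updated_at` handling above: `DataPoint.updated_at` is stored as an integer millisecond epoch (see the `DataPoint` change further down), while an edge loaded from file may instead carry an ISO-8601 string, so both branches are needed. A small illustrative sketch (the literal values here are made up):

```python
from datetime import datetime, timezone

# Integer millisecond epoch, as produced by DataPoint.updated_at.
ms_epoch = 1700000000000
as_datetime = datetime.fromtimestamp(ms_epoch / 1000, tz=timezone.utc)

# ISO-8601 string with microseconds and a UTC offset.
iso_string = "2024-01-01T00:00:00.000000+0000"
parsed = datetime.strptime(iso_string, "%Y-%m-%dT%H:%M:%S.%f%z")
```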
@@ -1,4 +1,6 @@
from .ModelBase import Base
from .config import get_relational_config
+from .config import get_migration_config
from .create_db_and_tables import create_db_and_tables
from .get_relational_engine import get_relational_engine
+from .get_migration_relational_engine import get_migration_relational_engine
@@ -31,3 +31,31 @@ class RelationalConfig(BaseSettings):
@lru_cache
def get_relational_config():
    return RelationalConfig()


class MigrationConfig(BaseSettings):
    migration_db_path: Union[str, None] = None
    migration_db_name: str = None
    migration_db_host: Union[str, None] = None
    migration_db_port: Union[str, None] = None
    migration_db_username: Union[str, None] = None
    migration_db_password: Union[str, None] = None
    migration_db_provider: str = None

    model_config = SettingsConfigDict(env_file=".env", extra="allow")

    def to_dict(self) -> dict:
        return {
            "migration_db_path": self.migration_db_path,
            "migration_db_name": self.migration_db_name,
            "migration_db_host": self.migration_db_host,
            "migration_db_port": self.migration_db_port,
            "migration_db_username": self.migration_db_username,
            "migration_db_password": self.migration_db_password,
            "migration_db_provider": self.migration_db_provider,
        }


@lru_cache
def get_migration_config():
    return MigrationConfig()
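The `MigrationConfig` fields map one-to-one onto the `MIGRATION_DB_*` entries in the `.env` template above; pydantic `BaseSettings` reads them from `.env`. A minimal sketch of reading the cached config, assuming those variables are set:

```python
from cognee.infrastructure.databases.relational import get_migration_config

migration_config = get_migration_config()  # cached via @lru_cache

# e.g. "sqlite", per the .env template above
print(migration_config.migration_db_provider)

# Plain dict view of every migration setting
print(migration_config.to_dict())
```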
cognee/infrastructure/databases/relational/get_migration_relational_engine.py (new file)

@@ -0,0 +1,16 @@
from .config import get_migration_config
from .create_relational_engine import create_relational_engine


def get_migration_relational_engine():
    migration_config = get_migration_config()

    return create_relational_engine(
        db_path=migration_config.migration_db_path,
        db_name=migration_config.migration_db_name,
        db_host=migration_config.migration_db_host,
        db_port=migration_config.migration_db_port,
        db_username=migration_config.migration_db_username,
        db_password=migration_config.migration_db_password,
        db_provider=migration_config.migration_db_provider,
    )
@@ -5,7 +5,7 @@ from uuid import UUID
from typing import Optional
from typing import AsyncGenerator, List
from contextlib import asynccontextmanager
-from sqlalchemy import text, select, MetaData, Table, delete
+from sqlalchemy import text, select, MetaData, Table, delete, inspect
from sqlalchemy.orm import joinedload
from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
@@ -229,25 +229,23 @@ class SQLAlchemyAdapter:

    async def get_table_names(self) -> List[str]:
        """
-        Return a list of all tables names in database
+        Return a list of all table names in the database, even if they don't have defined SQLAlchemy models.
        """
        table_names = []
        async with self.engine.begin() as connection:
            if self.engine.dialect.name == "sqlite":
-                await connection.run_sync(Base.metadata.reflect)
-                for table in Base.metadata.tables:
-                    table_names.append(str(table))
+                # Use a new MetaData instance to reflect all tables
+                metadata = MetaData()
+                await connection.run_sync(metadata.reflect)  # Reflect the entire database
+                table_names = list(metadata.tables.keys())  # Get table names
            else:
                schema_list = await self.get_schema_list()
                # Create a MetaData instance to load table information
                metadata = MetaData()
-                # Drop all tables from all schemas
                for schema_name in schema_list:
                    # Load the schema information into the MetaData object
                    await connection.run_sync(metadata.reflect, schema=schema_name)
-                    for table in metadata.sorted_tables:
-                        table_names.append(str(table))
-                    metadata.clear()
+                    table_names.extend(metadata.tables.keys())  # Append table names from schema
+                    metadata.clear()  # Clear metadata for the next schema

        return table_names

    async def get_data(self, table_name: str, filters: dict = None):
@@ -345,3 +343,94 @@ class SQLAlchemyAdapter:
            raise e

        logger.info("Database deleted successfully.")

    async def extract_schema(self):
        async with self.engine.begin() as connection:
            tables = await self.get_table_names()

            schema = {}

            if self.engine.dialect.name == "sqlite":
                for table_name in tables:
                    schema[table_name] = {"columns": [], "primary_key": None, "foreign_keys": []}

                    # Get column details
                    columns_result = await connection.execute(
                        text(f"PRAGMA table_info('{table_name}');")
                    )
                    columns = columns_result.fetchall()
                    for column in columns:
                        column_name = column[1]
                        column_type = column[2]
                        is_pk = column[5] == 1
                        schema[table_name]["columns"].append(
                            {"name": column_name, "type": column_type}
                        )
                        if is_pk:
                            schema[table_name]["primary_key"] = column_name

                    # Get foreign key details
                    foreign_keys_results = await connection.execute(
                        text(f"PRAGMA foreign_key_list('{table_name}');")
                    )
                    foreign_keys = foreign_keys_results.fetchall()
                    for fk in foreign_keys:
                        schema[table_name]["foreign_keys"].append(
                            {
                                "column": fk[3],  # Column in the current table
                                "ref_table": fk[2],  # Referenced table
                                "ref_column": fk[4],  # Referenced column
                            }
                        )
            else:
                schema_list = await self.get_schema_list()
                for schema_name in schema_list:
                    # Get tables for the current schema via the inspector.
                    tables = await connection.run_sync(
                        lambda sync_conn: inspect(sync_conn).get_table_names(schema=schema_name)
                    )
                    for table_name in tables:
                        # Optionally, qualify the table name with the schema if not in the default schema.
                        key = (
                            table_name if schema_name == "public" else f"{schema_name}.{table_name}"
                        )
                        schema[key] = {"columns": [], "primary_key": None, "foreign_keys": []}

                        # Helper function to get table details using the inspector.
                        def get_details(sync_conn, table, schema_name):
                            insp = inspect(sync_conn)
                            cols = insp.get_columns(table, schema=schema_name)
                            pk = insp.get_pk_constraint(table, schema=schema_name)
                            fks = insp.get_foreign_keys(table, schema=schema_name)
                            return cols, pk, fks

                        cols, pk, fks = await connection.run_sync(
                            get_details, table_name, schema_name
                        )

                        for column in cols:
                            # Convert the type to string
                            schema[key]["columns"].append(
                                {"name": column["name"], "type": str(column["type"])}
                            )
                        pk_columns = pk.get("constrained_columns", [])
                        if pk_columns:
                            schema[key]["primary_key"] = pk_columns[0]
                        for fk in fks:
                            for col, ref_col in zip(
                                fk.get("constrained_columns", []), fk.get("referred_columns", [])
                            ):
                                if col and ref_col:
                                    schema[key]["foreign_keys"].append(
                                        {
                                            "column": col,
                                            "ref_table": fk.get("referred_table"),
                                            "ref_column": ref_col,
                                        }
                                    )
                                else:
                                    logger.warning(
                                        f"Missing value in foreign key information. \nColumn value: {col}\nReference column value: {ref_col}\n"
                                    )

            return schema
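For orientation, `extract_schema` returns a plain dict keyed by table name. An illustrative example of its shape for a hypothetical two-table SQLite database (table and column names are made up):

```python
schema = {
    "users": {
        "columns": [
            {"name": "id", "type": "INTEGER"},
            {"name": "name", "type": "TEXT"},
        ],
        "primary_key": "id",
        "foreign_keys": [],
    },
    "orders": {
        "columns": [
            {"name": "id", "type": "INTEGER"},
            {"name": "user_id", "type": "INTEGER"},
        ],
        "primary_key": "id",
        "foreign_keys": [
            # orders.user_id references users.id
            {"column": "user_id", "ref_table": "users", "ref_column": "id"},
        ],
    },
}
```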
@@ -22,6 +22,7 @@ class DataPoint(BaseModel):
    updated_at: int = Field(
        default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)
    )
+    ontology_valid: bool = False
    version: int = 1  # Default version
    topological_rank: Optional[int] = 0
    metadata: Optional[MetaData] = {"index_fields": []}
@@ -7,6 +7,5 @@ class Entity(DataPoint):
    name: str
    is_a: Optional[EntityType] = None
    description: str
-    ontology_valid: bool = False

    metadata: dict = {"index_fields": ["name"]}
@@ -4,6 +4,5 @@ from cognee.infrastructure.engine import DataPoint
class EntityType(DataPoint):
    name: str
    description: str
-    ontology_valid: bool = False

    metadata: dict = {"index_fields": ["name"]}
cognee/modules/engine/models/TableRow.py (new file, 12 lines)

@@ -0,0 +1,12 @@
from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.models.TableType import TableType
from typing import Optional


class TableRow(DataPoint):
    name: str
    is_a: Optional[TableType] = None
    description: str
    properties: str

    metadata: dict = {"index_fields": ["properties"]}
cognee/modules/engine/models/TableType.py (new file, 8 lines)

@@ -0,0 +1,8 @@
from cognee.infrastructure.engine import DataPoint


class TableType(DataPoint):
    name: str
    description: str

    metadata: dict = {"index_fields": ["name"]}
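These two models are what the migration task (added further down in `cognee/tasks/ingestion/migrate_relational_database.py`) instantiates: one `TableType` node per table and one `TableRow` node per row, linked through `is_a`. A sketch with made-up table data:

```python
from uuid import uuid5, NAMESPACE_OID

from cognee.modules.engine.models import TableRow, TableType

# One node per table.
users_table = TableType(
    id=uuid5(NAMESPACE_OID, name="users"),
    name="users",
    description="Table: users",
)

# One node per row; the name encodes "table_name:primary_key_value".
row = TableRow(
    id=uuid5(NAMESPACE_OID, name="users:1"),
    name="users:1",
    is_a=users_table,
    properties=str({"id": 1, "name": "Alice"}),
    description="Row in users with id=1",
)
```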
@@ -1,2 +1,4 @@
from .Entity import Entity
from .EntityType import EntityType
+from .TableRow import TableRow
+from .TableType import TableType
@@ -1,5 +1,6 @@
import json
from uuid import UUID
+from decimal import Decimal
from datetime import datetime
from pydantic_core import PydanticUndefined
from pydantic import create_model
@@ -14,6 +15,8 @@ class JSONEncoder(json.JSONEncoder):
        elif isinstance(obj, UUID):
            # if the obj is uuid, we simply return the value of uuid
            return str(obj)
+        elif isinstance(obj, Decimal):
+            return float(obj)
        return json.JSONEncoder.default(self, obj)
@@ -20,6 +20,8 @@ async def cognee_network_visualization(graph_data, destination_file_path: str = None):
        "EntityType": "#6510f4",
        "DocumentChunk": "#801212",
        "TextSummary": "#1077f4",
+        "TableRow": "#f47710",
+        "TableType": "#6510f4",
        "default": "#D3D3D3",
    }
@@ -182,8 +184,8 @@ async def cognee_network_visualization(graph_data, destination_file_path: str = None):
    </html>
    """

-    html_content = html_template.replace("{nodes}", json.dumps(nodes_list))
-    html_content = html_content.replace("{links}", json.dumps(links_list))
+    html_content = html_template.replace("{nodes}", json.dumps(nodes_list, default=str))
+    html_content = html_content.replace("{links}", json.dumps(links_list, default=str))

    if not destination_file_path:
        home_dir = os.path.expanduser("~")
@@ -1,3 +1,4 @@
from .save_data_item_to_storage import save_data_item_to_storage
from .ingest_data import ingest_data
from .resolve_data_directories import resolve_data_directories
+from .migrate_relational_database import migrate_relational_database
cognee/tasks/ingestion/migrate_relational_database.py (new file, 164 lines)

@@ -0,0 +1,164 @@
import logging
from decimal import Decimal
from uuid import uuid5, NAMESPACE_OID
from sqlalchemy import text
from cognee.infrastructure.databases.relational.get_migration_relational_engine import (
    get_migration_relational_engine,
)

from cognee.tasks.storage.index_data_points import index_data_points
from cognee.tasks.storage.index_graph_edges import index_graph_edges

from uuid import uuid4

from cognee.modules.engine.models import TableRow, TableType

logger = logging.getLogger(__name__)


async def migrate_relational_database(graph_db, schema):
    """
    Migrates data from a relational database into a graph database.

    For each table in the schema:
      - Creates a TableType node representing the table
      - Fetches all rows and creates a TableRow node for each row
      - Links each TableRow node to its TableType node with an "is_part_of" relationship
    Then, for every foreign key defined in the schema:
      - Establishes relationships between TableRow nodes based on foreign key relationships

    Both TableType and TableRow inherit from DataPoint to maintain consistency with Cognee data model.
    """
    engine = get_migration_relational_engine()
    # Create a mapping of node_id to node objects for referencing in edge creation
    node_mapping = {}
    edge_mapping = []

    async with engine.engine.begin() as cursor:
        # First, create table type nodes for all tables
        for table_name, details in schema.items():
            # Create a TableType node for each table
            table_node = TableType(
                id=uuid5(NAMESPACE_OID, name=table_name),
                name=table_name,
                description=f"Table: {table_name}",
            )

            # Add TableType node to mapping ( node will be added to the graph later based on this mapping )
            node_mapping[table_name] = table_node

            # Fetch all rows for the current table
            rows_result = await cursor.execute(text(f"SELECT * FROM {table_name};"))
            rows = rows_result.fetchall()

            for row in rows:
                # Build a dictionary of properties from the row
                row_properties = {
                    col["name"]: row[idx] for idx, col in enumerate(details["columns"])
                }

                # Determine the primary key value
                if not details["primary_key"]:
                    # Use the first column as primary key if not specified
                    primary_key_col = details["columns"][0]["name"]
                    primary_key_value = row_properties[primary_key_col]
                else:
                    # Use value of the specified primary key column
                    primary_key_col = details["primary_key"]
                    primary_key_value = row_properties[primary_key_col]

                # Create a node ID in the format "table_name:primary_key_value"
                node_id = f"{table_name}:{primary_key_value}"

                # Create a TableRow node
                # Node id must uniquely map to the id used in the relational database
                # To catch the foreign key relationships properly
                row_node = TableRow(
                    id=uuid5(NAMESPACE_OID, name=node_id),
                    name=node_id,
                    is_a=table_node,
                    properties=str(row_properties),
                    description=f"Row in {table_name} with {primary_key_col}={primary_key_value}",
                )

                # Store the node object in our mapping
                node_mapping[node_id] = row_node

                # Add edge between row node and table node ( it will be added to the graph later )
                edge_mapping.append(
                    (
                        row_node.id,
                        table_node.id,
                        "is_part_of",
                        dict(
                            relationship_name="is_part_of",
                            source_node_id=row_node.id,
                            target_node_id=table_node.id,
                        ),
                    )
                )

        # Process foreign key relationships after all nodes are created
        for table_name, details in schema.items():
            # Process foreign key relationships for the current table
            for fk in details.get("foreign_keys", []):
                # Aliases needed for self-referencing tables
                alias_1 = f"{table_name}_e1"
                alias_2 = f"{fk['ref_table']}_e2"

                # Determine primary key column
                if not details["primary_key"]:
                    primary_key_col = details["columns"][0]["name"]
                else:
                    primary_key_col = details["primary_key"]

                # Query to find relationships based on foreign keys
                fk_query = text(
                    f"SELECT {alias_1}.{primary_key_col} AS source_id, "
                    f"{alias_2}.{fk['ref_column']} AS ref_value "
                    f"FROM {table_name} AS {alias_1} "
                    f"JOIN {fk['ref_table']} AS {alias_2} "
                    f"ON {alias_1}.{fk['column']} = {alias_2}.{fk['ref_column']};"
                )

                fk_result = await cursor.execute(fk_query)
                relations = fk_result.fetchall()

                for source_id, ref_value in relations:
                    # Construct node ids
                    source_node_id = f"{table_name}:{source_id}"
                    target_node_id = f"{fk['ref_table']}:{ref_value}"

                    # Get the source and target node objects from our mapping
                    source_node = node_mapping[source_node_id]
                    target_node = node_mapping[target_node_id]

                    # Add edge representing the foreign key relationship using the node objects
                    # Create edge to add to graph later
                    edge_mapping.append(
                        (
                            source_node.id,
                            target_node.id,
                            fk["column"],
                            dict(
                                source_node_id=source_node.id,
                                target_node_id=target_node.id,
                                relationship_name=fk["column"],
                            ),
                        )
                    )

    # Add all nodes and edges to the graph
    # NOTE: Nodes and edges have to be added in batch for speed optimization, especially for NetworkX.
    # If we'd create nodes and add them to graph in real time the process would take too long.
    # Every node and edge added to NetworkX is saved to file which is very slow when not done in batches.
    await graph_db.add_nodes(list(node_mapping.values()))
    await graph_db.add_edges(edge_mapping)

    # In these steps we calculate the vector embeddings of our nodes and edges and save them to vector database
    # Cognee uses this information to perform searches on the knowledge graph.
    await index_data_points(list(node_mapping.values()))
    await index_graph_edges()

    logger.info("Data successfully migrated from relational database to desired graph database.")
    return await graph_db.get_graph_data()
@@ -38,7 +38,11 @@ async def index_data_points(data_points: list[DataPoint]):
        index_name = index_name_and_field[:first_occurence]
        field_name = index_name_and_field[first_occurence + 1 :]
        try:
-            await vector_engine.index_data_points(index_name, field_name, indexable_points)
+            # In case the ammount if indexable points is too large we need to send them in batches
+            batch_size = 1000
+            for i in range(0, len(indexable_points), batch_size):
+                batch = indexable_points[i : i + batch_size]
+                await vector_engine.index_data_points(index_name, field_name, batch)
        except EmbeddingException as e:
            logger.warning(f"Failed to index data points for {index_name}.{field_name}: {e}")
examples/python/relational_database_migration_example.py (new file, 72 lines)

@@ -0,0 +1,72 @@
import asyncio
import cognee
import os
import logging

from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.api.v1.visualize.visualize import visualize_graph
from cognee.infrastructure.databases.relational import (
    get_migration_relational_engine,
)

from cognee.modules.search.types import SearchType
from cognee.modules.users.methods import get_default_user

from cognee.infrastructure.databases.relational import (
    create_db_and_tables as create_relational_db_and_tables,
)
from cognee.infrastructure.databases.vector.pgvector import (
    create_db_and_tables as create_pgvector_db_and_tables,
)

# Prerequisites:
# 1. Copy `.env.template` and rename it to `.env`.
# 2. Add your OpenAI API key to the `.env` file in the `LLM_API_KEY` field:
#    LLM_API_KEY = "your_key_here"
# 3. Fill all relevant MIGRATION_DB information for the database you want to migrate to graph / Cognee


async def main():
    engine = get_migration_relational_engine()

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    # Needed to create principals table
    # Create tables for databases
    await create_relational_db_and_tables()
    await create_pgvector_db_and_tables()

    print("\nExtracting schema of database to migrate.")
    schema = await engine.extract_schema()
    print(f"Migrated database schema:\n{schema}")

    graph = await get_graph_engine()
    print("Migrating relational database to graph database based on schema.")
    from cognee.tasks.ingestion import migrate_relational_database

    await migrate_relational_database(graph, schema=schema)
    print("Relational database migration complete.")

    # Define location where to store html visualization of graph of the migrated database
    home_dir = os.path.expanduser("~")
    destination_file_path = os.path.join(home_dir, "graph_visualization.html")

    # test.html is a file with visualized data migration
    print("Adding html visualization of graph database after migration.")
    await visualize_graph(destination_file_path)
    print(f"Visualization can be found at: {destination_file_path}")

    search_results = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION, query_text="What kind of data do you contain?"
    )
    print(f"Search results: {search_results}")


if __name__ == "__main__":
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(main())
    finally:
        loop.run_until_complete(loop.shutdown_asyncgens())