Improve code readability by splitting the code blocks under conditional statements into separate functions.

This commit is contained in:
Geoff-Robin 2025-09-26 00:58:43 +05:30 committed by Igor Ilic
parent 656894370e
commit 2921021ca3

View file

@ -9,7 +9,6 @@ from cognee.infrastructure.databases.relational.config import get_migration_conf
from cognee.tasks.storage.index_data_points import index_data_points from cognee.tasks.storage.index_data_points import index_data_points
from cognee.tasks.storage.index_graph_edges import index_graph_edges from cognee.tasks.storage.index_graph_edges import index_graph_edges
from cognee.tasks.schema.ingest_database_schema import ingest_database_schema from cognee.tasks.schema.ingest_database_schema import ingest_database_schema
from cognee.tasks.schema.models import SchemaTable
from cognee.modules.engine.models import TableRow, TableType, ColumnValue from cognee.modules.engine.models import TableRow, TableType, ColumnValue
@ -31,12 +30,56 @@ async def migrate_relational_database(
Both TableType and TableRow inherit from DataPoint to maintain consistency with Cognee data model. Both TableType and TableRow inherit from DataPoint to maintain consistency with Cognee data model.
""" """
engine = get_migration_relational_engine()
# Create a mapping of node_id to node objects for referencing in edge creation # Create a mapping of node_id to node objects for referencing in edge creation
node_mapping = {} node_mapping = {}
edge_mapping = [] edge_mapping = []
if schema_only: if schema_only:
node_mapping, edge_mapping = await schema_only_ingestion()
else:
node_mapping, edge_mapping = await complete_database_ingestion(schema, migrate_column_data)
def _remove_duplicate_edges(edge_mapping):
seen = set()
unique_original_shape = []
for tup in edge_mapping:
# We go through all the tuples in the edge_mapping and we only add unique tuples to the list
# To eliminate duplicate edges.
source_id, target_id, rel_name, rel_dict = tup
# We need to convert the dictionary to a frozenset to be able to compare values for it
rel_dict_hashable = frozenset(sorted(rel_dict.items()))
hashable_tup = (source_id, target_id, rel_name, rel_dict_hashable)
# We use the seen set to keep track of unique edges
if hashable_tup not in seen:
# A list that has frozensets elements instead of dictionaries is needed to be able to compare values
seen.add(hashable_tup)
# append the original tuple shape (with the dictionary) if it's the first time we see it
unique_original_shape.append(tup)
return unique_original_shape
# Add all nodes and edges to the graph
# NOTE: Nodes and edges have to be added in batch for speed optimization, Especially for NetworkX.
# If we'd create nodes and add them to graph in real time the process would take too long.
# Every node and edge added to NetworkX is saved to file which is very slow when not done in batches.
await graph_db.add_nodes(list(node_mapping.values()))
await graph_db.add_edges(_remove_duplicate_edges(edge_mapping))
# In these steps we calculate the vector embeddings of our nodes and edges and save them to vector database
# Cognee uses this information to perform searches on the knowledge graph.
await index_data_points(list(node_mapping.values()))
await index_graph_edges()
logger.info("Data successfully migrated from relational database to desired graph database.")
return await graph_db.get_graph_data()
async def schema_only_ingestion():
node_mapping = {}
edge_mapping = []
database_config = get_migration_config().to_dict() database_config = get_migration_config().to_dict()
# Calling the ingest_database_schema function to return DataPoint subclasses # Calling the ingest_database_schema function to return DataPoint subclasses
result = await ingest_database_schema( result = await ingest_database_schema(
@ -110,8 +153,14 @@ async def migrate_relational_database(
), ),
) )
) )
return node_mapping, edge_mapping
else:
async def complete_database_ingestion(schema, migrate_column_data):
engine = get_migration_relational_engine()
# Create a mapping of node_id to node objects for referencing in edge creation
node_mapping = {}
edge_mapping = []
async with engine.engine.begin() as cursor: async with engine.engine.begin() as cursor:
# First, create table type nodes for all tables # First, create table type nodes for all tables
for table_name, details in schema.items(): for table_name, details in schema.items():
@ -261,39 +310,4 @@ async def migrate_relational_database(
), ),
) )
) )
return node_mapping, edge_mapping
def _remove_duplicate_edges(edge_mapping):
seen = set()
unique_original_shape = []
for tup in edge_mapping:
# We go through all the tuples in the edge_mapping and we only add unique tuples to the list
# To eliminate duplicate edges.
source_id, target_id, rel_name, rel_dict = tup
# We need to convert the dictionary to a frozenset to be able to compare values for it
rel_dict_hashable = frozenset(sorted(rel_dict.items()))
hashable_tup = (source_id, target_id, rel_name, rel_dict_hashable)
# We use the seen set to keep track of unique edges
if hashable_tup not in seen:
# A list that has frozensets elements instead of dictionaries is needed to be able to compare values
seen.add(hashable_tup)
# append the original tuple shape (with the dictionary) if it's the first time we see it
unique_original_shape.append(tup)
return unique_original_shape
# Add all nodes and edges to the graph
# NOTE: Nodes and edges have to be added in batch for speed optimization, Especially for NetworkX.
# If we'd create nodes and add them to graph in real time the process would take too long.
# Every node and edge added to NetworkX is saved to file which is very slow when not done in batches.
await graph_db.add_nodes(list(node_mapping.values()))
await graph_db.add_edges(_remove_duplicate_edges(edge_mapping))
# In these steps we calculate the vector embeddings of our nodes and edges and save them to vector database
# Cognee uses this information to perform searches on the knowledge graph.
await index_data_points(list(node_mapping.values()))
await index_graph_edges()
logger.info("Data successfully migrated from relational database to desired graph database.")
return await graph_db.get_graph_data()