refactor: Make relational database search more effective (#1477)

<!-- .github/pull_request_template.md -->

## Description
Enhance search results over relational database data by adding more information
about data types and content to node descriptions.

This PR also includes a schema-only migration contributed by Geoff-Robin
as part of the Contribute to Win competition; a usage sketch of the new flag follows below.
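
A minimal sketch of the schema-only path (assumes the migration database is already configured via environment variables; the import path for `migrate_relational_database` is an assumption, while the call signature matches the diff below):

```python
import asyncio

from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.relational import get_migration_relational_engine

# Hypothetical import path; adjust to wherever migrate_relational_database lives in the repo.
from cognee.tasks.graph.migrate_relational_database import migrate_relational_database


async def schema_only_example():
    # Extract the schema of the configured migration database
    engine = get_migration_relational_engine()
    schema = await engine.extract_schema()

    # Migrate only the schema (tables and their relationships), skipping all row data
    graph = await get_graph_engine()
    await migrate_relational_database(graph, schema=schema, schema_only=True)


asyncio.run(schema_only_example())
```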

## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [ ] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [x] Performance improvement
- [ ] Other (please specify):

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [x] My code follows the project's coding standards and style
guidelines
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have added necessary documentation (if applicable)
- [x] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been
submitted already
- [x] I have linked any relevant issues in the description
- [x] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
Commit 3b101ae8f1 by Vasilije, 2025-09-28 15:23:59 +02:00 (committed via GitHub)
6 changed files with 419 additions and 54 deletions

View file

@ -23,6 +23,9 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
"TableRow": "#f47710",
"TableType": "#6510f4",
"ColumnValue": "#13613a",
"SchemaTable": "#f47710",
"DatabaseSchema": "#6510f4",
"SchemaRelationship": "#13613a",
"default": "#D3D3D3",
}

View file

@ -4,16 +4,20 @@ from sqlalchemy import text
from cognee.infrastructure.databases.relational.get_migration_relational_engine import (
get_migration_relational_engine,
)
from cognee.infrastructure.databases.relational.config import get_migration_config
from cognee.tasks.storage.index_data_points import index_data_points
from cognee.tasks.storage.index_graph_edges import index_graph_edges
from cognee.tasks.schema.ingest_database_schema import ingest_database_schema
from cognee.modules.engine.models import TableRow, TableType, ColumnValue
logger = logging.getLogger(__name__)
async def migrate_relational_database(graph_db, schema, migrate_column_data=True):
async def migrate_relational_database(
graph_db, schema, migrate_column_data=True, schema_only=False
):
"""
Migrates data from a relational database into a graph database.
@ -26,11 +30,133 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True
Both TableType and TableRow inherit from DataPoint to maintain consistency with Cognee data model.
"""
# Create a mapping of node_id to node objects for referencing in edge creation
if schema_only:
node_mapping, edge_mapping = await schema_only_ingestion(schema)
else:
node_mapping, edge_mapping = await complete_database_ingestion(schema, migrate_column_data)
def _remove_duplicate_edges(edge_mapping):
seen = set()
unique_original_shape = []
for tup in edge_mapping:
# Go through all the tuples in the edge_mapping and add only unique tuples to the list,
# eliminating duplicate edges.
source_id, target_id, rel_name, rel_dict = tup
# Convert the dictionary to a frozenset so the tuple becomes hashable and comparable
rel_dict_hashable = frozenset(sorted(rel_dict.items()))
hashable_tup = (source_id, target_id, rel_name, rel_dict_hashable)
# We use the seen set to keep track of unique edges
if hashable_tup not in seen:
# Storing frozensets instead of dictionaries is what makes this membership check possible
seen.add(hashable_tup)
# append the original tuple shape (with the dictionary) if it's the first time we see it
unique_original_shape.append(tup)
return unique_original_shape
# Add all nodes and edges to the graph
# NOTE: Nodes and edges have to be added in batches for speed, especially for NetworkX.
# If we created nodes and added them to the graph in real time, the process would take too long.
# Every node and edge added to NetworkX is saved to file, which is very slow when not done in batches.
await graph_db.add_nodes(list(node_mapping.values()))
await graph_db.add_edges(_remove_duplicate_edges(edge_mapping))
# In these steps we calculate the vector embeddings of our nodes and edges and save them to the vector database
# Cognee uses this information to perform searches on the knowledge graph.
await index_data_points(list(node_mapping.values()))
await index_graph_edges()
logger.info("Data successfully migrated from relational database to desired graph database.")
return await graph_db.get_graph_data()
async def schema_only_ingestion(schema):
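"""Build node and edge mappings from schema metadata alone, without migrating any row data."""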
node_mapping = {}
edge_mapping = []
# Call the ingest_database_schema function to get DataPoint subclass instances
result = await ingest_database_schema(
schema=schema,
max_sample_rows=5,
)
database_schema = result["database_schema"]
schema_tables = result["schema_tables"]
schema_relationships = result["relationships"]
database_node_id = database_schema.id
node_mapping[database_node_id] = database_schema
for table in schema_tables:
table_node_id = table.id
# Add SchemaTable DataPoint as a node.
node_mapping[table_node_id] = table
edge_mapping.append(
(
table_node_id,
database_node_id,
"is_part_of",
dict(
source_node_id=table_node_id,
target_node_id=database_node_id,
relationship_name="is_part_of",
),
)
)
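# Map table names to node ids so relationship edges can reference their endpoint tables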
table_name_to_id = {t.name: t.id for t in schema_tables}
for rel in schema_relationships:
source_table_id = table_name_to_id.get(rel.source_table)
target_table_id = table_name_to_id.get(rel.target_table)
relationship_id = rel.id
# Add SchemaRelationship DataPoint as a node.
node_mapping[relationship_id] = rel
edge_mapping.append(
(
source_table_id,
relationship_id,
"has_relationship",
dict(
source_node_id=source_table_id,
target_node_id=relationship_id,
relationship_name=rel.relationship_type,
),
)
)
edge_mapping.append(
(
relationship_id,
target_table_id,
"has_relationship",
dict(
source_node_id=relationship_id,
target_node_id=target_table_id,
relationship_name=rel.relationship_type,
),
)
)
edge_mapping.append(
(
source_table_id,
target_table_id,
rel.relationship_type,
dict(
source_node_id=source_table_id,
target_node_id=target_table_id,
relationship_name=rel.relationship_type,
),
)
)
return node_mapping, edge_mapping
async def complete_database_ingestion(schema, migrate_column_data):
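"""Build node and edge mappings by walking the full database: table types, rows, and (optionally) column values."""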
engine = get_migration_relational_engine()
# Create a mapping of node_id to node objects for referencing in edge creation
node_mapping = {}
edge_mapping = []
async with engine.engine.begin() as cursor:
# First, create table type nodes for all tables
for table_name, details in schema.items():
@ -38,7 +164,7 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True
table_node = TableType(
id=uuid5(NAMESPACE_OID, name=table_name),
name=table_name,
description=f"Table: {table_name}",
description=f'Relational database table with the following name: "{table_name}".',
)
# Add TableType node to mapping ( node will be added to the graph later based on this mapping )
@ -75,7 +201,7 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True
name=node_id,
is_a=table_node,
properties=str(row_properties),
description=f"Row in {table_name} with {primary_key_col}={primary_key_value}",
description=f'Row in relational database table "{table_name}" with the following row data: {str(row_properties)}, where each dictionary key is the column name and each value is the column value. This row has the id: {node_id}',
)
# Store the node object in our mapping
@ -113,7 +239,7 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True
id=uuid5(NAMESPACE_OID, name=column_node_id),
name=column_node_id,
properties=f"{key} {value} {table_name}",
description=f"Column name={key} and value={value} from column from table={table_name}",
description=f"column from relational database table={table_name}. Column name={key} and value={value}. The value of the column is related to the following row with this id: {row_node.id}. This column has the following ID: {column_node_id}",
)
node_mapping[column_node_id] = column_node
@ -180,39 +306,4 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True
),
)
)
def _remove_duplicate_edges(edge_mapping):
seen = set()
unique_original_shape = []
for tup in edge_mapping:
# We go through all the tuples in the edge_mapping and we only add unique tuples to the list
# To eliminate duplicate edges.
source_id, target_id, rel_name, rel_dict = tup
# We need to convert the dictionary to a frozenset to be able to compare values for it
rel_dict_hashable = frozenset(sorted(rel_dict.items()))
hashable_tup = (source_id, target_id, rel_name, rel_dict_hashable)
# We use the seen set to keep track of unique edges
if hashable_tup not in seen:
# A list that has frozensets elements instead of dictionaries is needed to be able to compare values
seen.add(hashable_tup)
# append the original tuple shape (with the dictionary) if it's the first time we see it
unique_original_shape.append(tup)
return unique_original_shape
# Add all nodes and edges to the graph
# NOTE: Nodes and edges have to be added in batch for speed optimization, Especially for NetworkX.
# If we'd create nodes and add them to graph in real time the process would take too long.
# Every node and edge added to NetworkX is saved to file which is very slow when not done in batches.
await graph_db.add_nodes(list(node_mapping.values()))
await graph_db.add_edges(_remove_duplicate_edges(edge_mapping))
# In these steps we calculate the vector embeddings of our nodes and edges and save them to vector database
# Cognee uses this information to perform searches on the knowledge graph.
await index_data_points(list(node_mapping.values()))
await index_graph_edges()
logger.info("Data successfully migrated from relational database to desired graph database.")
return await graph_db.get_graph_data()
return node_mapping, edge_mapping

View file

@ -0,0 +1,134 @@
from typing import List, Dict
from uuid import uuid5, NAMESPACE_OID
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from sqlalchemy import text
from cognee.tasks.schema.models import DatabaseSchema, SchemaTable, SchemaRelationship
from cognee.infrastructure.databases.relational.get_migration_relational_engine import (
get_migration_relational_engine,
)
from cognee.infrastructure.databases.relational.config import get_migration_config
from datetime import datetime, timezone
async def ingest_database_schema(
schema,
max_sample_rows: int = 0,
) -> Dict[str, List[DataPoint] | DataPoint]:
"""
Extract database schema metadata (optionally with sample data) and return DataPoint models for graph construction.
Args:
schema: Database schema
max_sample_rows: Maximum sample rows per table (0 means no sampling)
Returns:
Dict with keys:
"database_schema": DatabaseSchema
"schema_tables": List[SchemaTable]
"relationships": List[SchemaRelationship]
"""
tables = {}
sample_data = {}
schema_tables = []
schema_relationships = []
migration_config = get_migration_config()
engine = get_migration_relational_engine()
qi = engine.engine.dialect.identifier_preparer.quote
try:
max_sample_rows = max(0, int(max_sample_rows))
except (TypeError, ValueError):
max_sample_rows = 0
def qname(name: str):
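# Quote each dotted part of a possibly schema-qualified name (e.g. schema.table) so identifiers are safe in raw SQL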
split_name = name.split(".")
return ".".join(qi(p) for p in split_name)
async with engine.engine.begin() as cursor:
for table_name, details in schema.items():
tn = qname(table_name)
if max_sample_rows > 0:
rows_result = await cursor.execute(
text(f"SELECT * FROM {tn} LIMIT :limit;"), # noqa: S608 - tn is fully quoted
{"limit": max_sample_rows},
)
rows = [dict(r) for r in rows_result.mappings().all()]
else:
rows = []
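# For PostgreSQL, use the planner's reltuples estimate instead of COUNT(*), which can be slow on large tables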
if engine.engine.dialect.name == "postgresql":
if "." in table_name:
schema_part, table_part = table_name.split(".", 1)
else:
schema_part, table_part = "public", table_name
estimate = await cursor.execute(
text(
"SELECT reltuples::bigint AS estimate "
"FROM pg_class c "
"JOIN pg_namespace n ON n.oid = c.relnamespace "
"WHERE n.nspname = :schema AND c.relname = :table"
),
{"schema": schema_part, "table": table_part},
)
row_count_estimate = estimate.scalar() or 0
else:
count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {tn};")) # noqa: S608 - tn is fully quoted
row_count_estimate = count_result.scalar()
schema_table = SchemaTable(
id=uuid5(NAMESPACE_OID, name=f"{table_name}"),
name=table_name,
columns=details["columns"],
primary_key=details.get("primary_key"),
foreign_keys=details.get("foreign_keys", []),
sample_rows=rows,
row_count_estimate=row_count_estimate,
description=f"Relational database table with '{table_name}' with {len(details['columns'])} columns and approx. {row_count_estimate} rows."
f"Here are the columns this table contains: {details['columns']}"
f"Here are a few sample_rows to show the contents of the table: {rows}"
f"Table is part of the database: {migration_config.migration_db_name}",
)
schema_tables.append(schema_table)
tables[table_name] = details
sample_data[table_name] = rows
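# Build a SchemaRelationship DataPoint for each foreign key on this table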
for fk in details.get("foreign_keys", []):
ref_table_fq = fk["ref_table"]
if "." not in ref_table_fq and "." in table_name:
ref_table_fq = f"{table_name.split('.', 1)[0]}.{ref_table_fq}"
relationship_name = (
f"{table_name}:{fk['column']}->{ref_table_fq}:{fk['ref_column']}"
)
relationship = SchemaRelationship(
id=uuid5(NAMESPACE_OID, name=relationship_name),
name=relationship_name,
source_table=table_name,
target_table=ref_table_fq,
relationship_type="foreign_key",
source_column=fk["column"],
target_column=fk["ref_column"],
description=f"Relational database table foreign key relationship between: {table_name}.{fk['column']}{ref_table_fq}.{fk['ref_column']}"
f"This foreing key relationship between table columns is a part of the following database: {migration_config.migration_db_name}",
)
schema_relationships.append(relationship)
id_str = f"{migration_config.migration_db_provider}:{migration_config.migration_db_name}"
database_schema = DatabaseSchema(
id=uuid5(NAMESPACE_OID, name=id_str),
name=migration_config.migration_db_name,
database_type=migration_config.migration_db_provider,
tables=tables,
sample_data=sample_data,
extraction_timestamp=datetime.now(timezone.utc),
description=f"Database schema containing {len(schema_tables)} tables and {len(schema_relationships)} relationships. "
f"The database type is {migration_config.migration_db_provider}."
f"The database contains the following tables: {tables}",
)
return {
"database_schema": database_schema,
"schema_tables": schema_tables,
"relationships": schema_relationships,
}

View file

@ -0,0 +1,41 @@
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from typing import List, Dict, Optional
from datetime import datetime
class DatabaseSchema(DataPoint):
"""Represents a complete database schema with sample data"""
name: str
database_type: str # sqlite, postgres, etc.
tables: Dict[str, Dict] # Reuse existing schema format from SqlAlchemyAdapter
sample_data: Dict[str, List[Dict]] # Limited examples per table
extraction_timestamp: datetime
description: str
metadata: dict = {"index_fields": ["description", "name"]}
class SchemaTable(DataPoint):
"""Represents an individual table schema with relationships"""
name: str
columns: List[Dict] # Column definitions with types
primary_key: Optional[str]
foreign_keys: List[Dict] # Foreign key relationships
sample_rows: List[Dict] # Max 3-5 example rows
row_count_estimate: Optional[int] # Actual table size
description: str
metadata: dict = {"index_fields": ["description", "name"]}
class SchemaRelationship(DataPoint):
"""Represents relationships between tables"""
name: str
source_table: str
target_table: str
relationship_type: str # "foreign_key", "one_to_many", etc.
source_column: str
target_column: str
description: str
metadata: dict = {"index_fields": ["description", "name"]}

View file

@ -197,6 +197,80 @@ async def relational_db_migration():
print(f"All checks passed for {graph_db_provider} provider with '{relationship_label}' edges!")
async def test_schema_only_migration():
# 1. Setup test DB and extract schema
migration_engine = await setup_test_db()
schema = await migration_engine.extract_schema()
# 2. Setup graph engine
graph_engine = await get_graph_engine()
# 3. Migrate schema only
await migrate_relational_database(graph_engine, schema=schema, schema_only=True)
# 4. Verify number of tables through search
search_results = await cognee.search(
query_text="How many tables are there in this database",
query_type=cognee.SearchType.GRAPH_COMPLETION,
top_k=30,
)
assert any("11" in r for r in search_results), (
"Number of tables in the database reported in search_results is either None or not equal to 11"
)
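# 5. Count relationship edges per graph provider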
graph_db_provider = os.getenv("GRAPH_DATABASE_PROVIDER", "networkx").lower()
edge_counts = {
"is_part_of": 0,
"has_relationship": 0,
"foreign_key": 0,
}
if graph_db_provider == "neo4j":
for rel_type in edge_counts.keys():
query_str = f"""
MATCH ()-[r:{rel_type}]->()
RETURN count(r) as c
"""
rows = await graph_engine.query(query_str)
edge_counts[rel_type] = rows[0]["c"]
elif graph_db_provider == "kuzu":
for rel_type in edge_counts.keys():
query_str = f"""
MATCH ()-[r:EDGE]->()
WHERE r.relationship_name = '{rel_type}'
RETURN count(r) as c
"""
rows = await graph_engine.query(query_str)
edge_counts[rel_type] = rows[0][0]
elif graph_db_provider == "networkx":
nodes, edges = await graph_engine.get_graph_data()
for _, _, key, _ in edges:
if key in edge_counts:
edge_counts[key] += 1
else:
raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")
# 6. Assert counts match expected values
expected_counts = {
"is_part_of": 11,
"has_relationship": 22,
"foreign_key": 11,
}
for rel_type, expected in expected_counts.items():
actual = edge_counts[rel_type]
assert actual == expected, (
f"Expected {expected} edges for relationship '{rel_type}', but found {actual}"
)
print("Schema-only migration edge counts validated successfully!")
print(f"Edge counts: {edge_counts}")
async def test_migration_sqlite():
database_to_migrate_path = os.path.join(pathlib.Path(__file__).parent, "test_data/")
@ -209,6 +283,7 @@ async def test_migration_sqlite():
)
await relational_db_migration()
await test_schema_only_migration()
async def test_migration_postgres():
@ -224,6 +299,7 @@ async def test_migration_postgres():
}
)
await relational_db_migration()
await test_schema_only_migration()
async def main():

View file

@ -1,16 +1,15 @@
from pathlib import Path
import asyncio
import cognee
import os
import cognee
from cognee.infrastructure.databases.relational.config import get_migration_config
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.api.v1.visualize.visualize import visualize_graph
from cognee.infrastructure.databases.relational import (
get_migration_relational_engine,
)
from cognee.modules.search.types import SearchType
from cognee.infrastructure.databases.relational import (
create_db_and_tables as create_relational_db_and_tables,
)
@ -32,16 +31,29 @@ from cognee.infrastructure.databases.vector.pgvector import (
async def main():
engine = get_migration_relational_engine()
# Clean all data stored in Cognee
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
# Needed to create appropriate tables only on the Cognee side
# Needed to create appropriate database tables only on the Cognee side
await create_relational_db_and_tables()
await create_vector_db_and_tables()
# In case environment variables are not set, use the example database from the Cognee repo
migration_db_provider = os.environ.get("MIGRATION_DB_PROVIDER", "sqlite")
migration_db_path = os.environ.get(
"MIGRATION_DB_PATH",
os.path.join(Path(__file__).resolve().parent.parent.parent, "cognee/tests/test_data"),
)
migration_db_name = os.environ.get("MIGRATION_DB_NAME", "migration_database.sqlite")
migration_config = get_migration_config()
migration_config.migration_db_provider = migration_db_provider
migration_config.migration_db_path = migration_db_path
migration_config.migration_db_name = migration_db_name
engine = get_migration_relational_engine()
print("\nExtracting schema of database to migrate.")
schema = await engine.extract_schema()
print(f"Migrated database schema:\n{schema}")
@ -53,10 +65,6 @@ async def main():
await migrate_relational_database(graph, schema=schema)
print("Relational database migration complete.")
# Define location where to store html visualization of graph of the migrated database
home_dir = os.path.expanduser("~")
destination_file_path = os.path.join(home_dir, "graph_visualization.html")
# Make sure to set top_k to a high value for a broader search; the default value is only 10!
# top_k represents the number of graph triplets supplied to the LLM to answer your question
search_results = await cognee.search(
@ -69,13 +77,25 @@ async def main():
# Setting top_k too high might overwhelm the LLM context when specific questions need to be answered.
# For this kind of question we've set the top_k to 30
search_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION_COT,
query_type=SearchType.GRAPH_COMPLETION,
query_text="What invoices are related to Leonie Köhler?",
top_k=30,
)
print(f"Search results: {search_results}")
# test.html is a file with visualized data migration
search_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION,
query_text="What invoices are related to Luís Gonçalves?",
top_k=30,
)
print(f"Search results: {search_results}")
# If you check the relational database for this example, you can see that the search results successfully found all
# the invoices related to the two customers, without any hallucinations or additional information.
# Define the location where to store the HTML visualization of the migrated database graph
home_dir = os.path.expanduser("~")
destination_file_path = os.path.join(home_dir, "graph_visualization.html")
print("Adding html visualization of graph database after migration.")
await visualize_graph(destination_file_path)
print(f"Visualization can be found at: {destination_file_path}")