refactor: Make relational database search more effective (#1477)

<!-- .github/pull_request_template.md -->

## Description
Enhance search results of relational db data by adding more information
on data type and content

PR also includes a schema migration contributed by Geoff-Robin as part of the Contribute to Win competition.

## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [ ] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [x] Performance improvement
- [ ] Other (please specify):

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [x] My code follows the project's coding standards and style
guidelines
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have added necessary documentation (if applicable)
- [x] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been
submitted already
- [x] I have linked any relevant issues in the description
- [x] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
Vasilije 2025-09-28 15:23:59 +02:00 committed by GitHub
commit 3b101ae8f1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 419 additions and 54 deletions

View file

@ -23,6 +23,9 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
"TableRow": "#f47710", "TableRow": "#f47710",
"TableType": "#6510f4", "TableType": "#6510f4",
"ColumnValue": "#13613a", "ColumnValue": "#13613a",
"SchemaTable": "#f47710",
"DatabaseSchema": "#6510f4",
"SchemaRelationship": "#13613a",
"default": "#D3D3D3", "default": "#D3D3D3",
} }

View file

@ -4,16 +4,20 @@ from sqlalchemy import text
from cognee.infrastructure.databases.relational.get_migration_relational_engine import ( from cognee.infrastructure.databases.relational.get_migration_relational_engine import (
get_migration_relational_engine, get_migration_relational_engine,
) )
from cognee.infrastructure.databases.relational.config import get_migration_config
from cognee.tasks.storage.index_data_points import index_data_points from cognee.tasks.storage.index_data_points import index_data_points
from cognee.tasks.storage.index_graph_edges import index_graph_edges from cognee.tasks.storage.index_graph_edges import index_graph_edges
from cognee.tasks.schema.ingest_database_schema import ingest_database_schema
from cognee.modules.engine.models import TableRow, TableType, ColumnValue from cognee.modules.engine.models import TableRow, TableType, ColumnValue
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
async def migrate_relational_database(graph_db, schema, migrate_column_data=True): async def migrate_relational_database(
graph_db, schema, migrate_column_data=True, schema_only=False
):
""" """
Migrates data from a relational database into a graph database. Migrates data from a relational database into a graph database.
@ -26,11 +30,133 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True
Both TableType and TableRow inherit from DataPoint to maintain consistency with Cognee data model. Both TableType and TableRow inherit from DataPoint to maintain consistency with Cognee data model.
""" """
# Create a mapping of node_id to node objects for referencing in edge creation
if schema_only:
node_mapping, edge_mapping = await schema_only_ingestion(schema)
else:
node_mapping, edge_mapping = await complete_database_ingestion(schema, migrate_column_data)
def _remove_duplicate_edges(edge_mapping):
seen = set()
unique_original_shape = []
for tup in edge_mapping:
# We go through all the tuples in the edge_mapping and we only add unique tuples to the list
# To eliminate duplicate edges.
source_id, target_id, rel_name, rel_dict = tup
# We need to convert the dictionary to a frozenset to be able to compare values for it
rel_dict_hashable = frozenset(sorted(rel_dict.items()))
hashable_tup = (source_id, target_id, rel_name, rel_dict_hashable)
# We use the seen set to keep track of unique edges
if hashable_tup not in seen:
# A list that has frozensets elements instead of dictionaries is needed to be able to compare values
seen.add(hashable_tup)
# append the original tuple shape (with the dictionary) if it's the first time we see it
unique_original_shape.append(tup)
return unique_original_shape
# Add all nodes and edges to the graph
# NOTE: Nodes and edges have to be added in batch for speed optimization, Especially for NetworkX.
# If we'd create nodes and add them to graph in real time the process would take too long.
# Every node and edge added to NetworkX is saved to file which is very slow when not done in batches.
await graph_db.add_nodes(list(node_mapping.values()))
await graph_db.add_edges(_remove_duplicate_edges(edge_mapping))
# In these steps we calculate the vector embeddings of our nodes and edges and save them to vector database
# Cognee uses this information to perform searches on the knowledge graph.
await index_data_points(list(node_mapping.values()))
await index_graph_edges()
logger.info("Data successfully migrated from relational database to desired graph database.")
return await graph_db.get_graph_data()
async def schema_only_ingestion(schema):
    """Build graph nodes and edges from database schema metadata only (no row data).

    Delegates extraction to ingest_database_schema (sampling up to 5 rows per
    table for descriptions), then wires the resulting DataPoints together:
      - each SchemaTable is linked to the DatabaseSchema via an "is_part_of" edge,
      - each SchemaRelationship becomes an intermediate node connected to its
        source and target tables, plus one direct table-to-table edge labeled
        with the relationship type (e.g. "foreign_key").

    Args:
        schema: Database schema mapping as produced by the migration engine.

    Returns:
        Tuple of (node_mapping, edge_mapping) where node_mapping maps node id ->
        DataPoint and edge_mapping is a list of
        (source_id, target_id, relationship_name, properties_dict) tuples.
    """
    node_mapping = {}
    edge_mapping = []
    # Calling the ingest_database_schema function to return DataPoint subclasses
    result = await ingest_database_schema(
        schema=schema,
        max_sample_rows=5,
    )
    database_schema = result["database_schema"]
    schema_tables = result["schema_tables"]
    schema_relationships = result["relationships"]
    # The database node is the hub every table node attaches to.
    database_node_id = database_schema.id
    node_mapping[database_node_id] = database_schema
    for table in schema_tables:
        table_node_id = table.id
        # Add TableSchema Datapoint as a node.
        node_mapping[table_node_id] = table
        edge_mapping.append(
            (
                table_node_id,
                database_node_id,
                "is_part_of",
                dict(
                    source_node_id=table_node_id,
                    target_node_id=database_node_id,
                    relationship_name="is_part_of",
                ),
            )
        )
    # Lookup from table name to node id for resolving relationship endpoints.
    table_name_to_id = {t.name: t.id for t in schema_tables}
    for rel in schema_relationships:
        # NOTE(review): .get() returns None for a table name not present in
        # schema_tables (e.g. a dangling foreign key reference); edges with a
        # None endpoint would then be emitted — confirm upstream guarantees
        # every referenced table appears in the schema.
        source_table_id = table_name_to_id.get(rel.source_table)
        target_table_id = table_name_to_id.get(rel.target_table)
        relationship_id = rel.id
        # Add RelationshipTable DataPoint as a node.
        node_mapping[relationship_id] = rel
        # Source table -> relationship node.
        edge_mapping.append(
            (
                source_table_id,
                relationship_id,
                "has_relationship",
                dict(
                    source_node_id=source_table_id,
                    target_node_id=relationship_id,
                    relationship_name=rel.relationship_type,
                ),
            )
        )
        # Relationship node -> target table.
        edge_mapping.append(
            (
                relationship_id,
                target_table_id,
                "has_relationship",
                dict(
                    source_node_id=relationship_id,
                    target_node_id=target_table_id,
                    relationship_name=rel.relationship_type,
                ),
            )
        )
        # Direct table-to-table shortcut edge labeled with the relationship type.
        edge_mapping.append(
            (
                source_table_id,
                target_table_id,
                rel.relationship_type,
                dict(
                    source_node_id=source_table_id,
                    target_node_id=target_table_id,
                    relationship_name=rel.relationship_type,
                ),
            )
        )
    return node_mapping, edge_mapping
async def complete_database_ingestion(schema, migrate_column_data):
engine = get_migration_relational_engine() engine = get_migration_relational_engine()
# Create a mapping of node_id to node objects for referencing in edge creation # Create a mapping of node_id to node objects for referencing in edge creation
node_mapping = {} node_mapping = {}
edge_mapping = [] edge_mapping = []
async with engine.engine.begin() as cursor: async with engine.engine.begin() as cursor:
# First, create table type nodes for all tables # First, create table type nodes for all tables
for table_name, details in schema.items(): for table_name, details in schema.items():
@ -38,7 +164,7 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True
table_node = TableType( table_node = TableType(
id=uuid5(NAMESPACE_OID, name=table_name), id=uuid5(NAMESPACE_OID, name=table_name),
name=table_name, name=table_name,
description=f"Table: {table_name}", description=f'Relational database table with the following name: "{table_name}".',
) )
# Add TableType node to mapping ( node will be added to the graph later based on this mapping ) # Add TableType node to mapping ( node will be added to the graph later based on this mapping )
@ -75,7 +201,7 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True
name=node_id, name=node_id,
is_a=table_node, is_a=table_node,
properties=str(row_properties), properties=str(row_properties),
description=f"Row in {table_name} with {primary_key_col}={primary_key_value}", description=f'Row in relational database table from the table with the name: "{table_name}" with the following row data {str(row_properties)} where the dictionary key value is the column name and the value is the column value. This row has the id of: {node_id}',
) )
# Store the node object in our mapping # Store the node object in our mapping
@ -113,7 +239,7 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True
id=uuid5(NAMESPACE_OID, name=column_node_id), id=uuid5(NAMESPACE_OID, name=column_node_id),
name=column_node_id, name=column_node_id,
properties=f"{key} {value} {table_name}", properties=f"{key} {value} {table_name}",
description=f"Column name={key} and value={value} from column from table={table_name}", description=f"column from relational database table={table_name}. Column name={key} and value={value}. The value of the column is related to the following row with this id: {row_node.id}. This column has the following ID: {column_node_id}",
) )
node_mapping[column_node_id] = column_node node_mapping[column_node_id] = column_node
@ -180,39 +306,4 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True
), ),
) )
) )
return node_mapping, edge_mapping
def _remove_duplicate_edges(edge_mapping):
seen = set()
unique_original_shape = []
for tup in edge_mapping:
# We go through all the tuples in the edge_mapping and we only add unique tuples to the list
# To eliminate duplicate edges.
source_id, target_id, rel_name, rel_dict = tup
# We need to convert the dictionary to a frozenset to be able to compare values for it
rel_dict_hashable = frozenset(sorted(rel_dict.items()))
hashable_tup = (source_id, target_id, rel_name, rel_dict_hashable)
# We use the seen set to keep track of unique edges
if hashable_tup not in seen:
# A list that has frozensets elements instead of dictionaries is needed to be able to compare values
seen.add(hashable_tup)
# append the original tuple shape (with the dictionary) if it's the first time we see it
unique_original_shape.append(tup)
return unique_original_shape
# Add all nodes and edges to the graph
# NOTE: Nodes and edges have to be added in batch for speed optimization, Especially for NetworkX.
# If we'd create nodes and add them to graph in real time the process would take too long.
# Every node and edge added to NetworkX is saved to file which is very slow when not done in batches.
await graph_db.add_nodes(list(node_mapping.values()))
await graph_db.add_edges(_remove_duplicate_edges(edge_mapping))
# In these steps we calculate the vector embeddings of our nodes and edges and save them to vector database
# Cognee uses this information to perform searches on the knowledge graph.
await index_data_points(list(node_mapping.values()))
await index_graph_edges()
logger.info("Data successfully migrated from relational database to desired graph database.")
return await graph_db.get_graph_data()

View file

@ -0,0 +1,134 @@
from typing import List, Dict
from uuid import uuid5, NAMESPACE_OID
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from sqlalchemy import text
from cognee.tasks.schema.models import DatabaseSchema, SchemaTable, SchemaRelationship
from cognee.infrastructure.databases.relational.get_migration_relational_engine import (
get_migration_relational_engine,
)
from cognee.infrastructure.databases.relational.config import get_migration_config
from datetime import datetime, timezone
async def ingest_database_schema(
    schema,
    max_sample_rows: int = 0,
) -> Dict[str, List[DataPoint] | DataPoint]:
    """
    Extract database schema metadata (optionally with sample data) and return DataPoint models for graph construction.

    Args:
        schema: Database schema mapping table name -> details
            (expects "columns"; optionally "primary_key" and "foreign_keys")
        max_sample_rows: Maximum sample rows per table (0 means no sampling)

    Returns:
        Dict with keys:
            "database_schema": DatabaseSchema
            "schema_tables": List[SchemaTable]
            "relationships": List[SchemaRelationship]
    """
    tables = {}
    sample_data = {}
    schema_tables = []
    schema_relationships = []
    migration_config = get_migration_config()
    engine = get_migration_relational_engine()
    # Dialect-aware identifier quoting so table/schema names are safely embedded in SQL.
    qi = engine.engine.dialect.identifier_preparer.quote

    # Defensive normalization: any invalid value degrades to "no sampling".
    try:
        max_sample_rows = max(0, int(max_sample_rows))
    except (TypeError, ValueError):
        max_sample_rows = 0

    def qname(name: str):
        # Quote each dotted part separately so "schema.table" stays a valid reference.
        split_name = name.split(".")
        return ".".join(qi(p) for p in split_name)

    async with engine.engine.begin() as cursor:
        for table_name, details in schema.items():
            tn = qname(table_name)
            if max_sample_rows > 0:
                rows_result = await cursor.execute(
                    text(f"SELECT * FROM {tn} LIMIT :limit;"),  # noqa: S608 - tn is fully quoted
                    {"limit": max_sample_rows},
                )
                rows = [dict(r) for r in rows_result.mappings().all()]
            else:
                rows = []

            if engine.engine.dialect.name == "postgresql":
                # Use the planner's statistics instead of a full COUNT(*) scan.
                if "." in table_name:
                    schema_part, table_part = table_name.split(".", 1)
                else:
                    schema_part, table_part = "public", table_name
                estimate = await cursor.execute(
                    text(
                        "SELECT reltuples::bigint AS estimate "
                        "FROM pg_class c "
                        "JOIN pg_namespace n ON n.oid = c.relnamespace "
                        "WHERE n.nspname = :schema AND c.relname = :table"
                    ),
                    {"schema": schema_part, "table": table_part},
                )
                row_count_estimate = estimate.scalar() or 0
            else:
                count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {tn};"))  # noqa: S608 - tn is fully quoted
                row_count_estimate = count_result.scalar()

            schema_table = SchemaTable(
                id=uuid5(NAMESPACE_OID, name=f"{table_name}"),
                name=table_name,
                columns=details["columns"],
                primary_key=details.get("primary_key"),
                foreign_keys=details.get("foreign_keys", []),
                sample_rows=rows,
                row_count_estimate=row_count_estimate,
                # This text is embedded for vector search — keep sentences
                # separated (previously the concatenated parts ran together).
                description=f"Relational database table '{table_name}' with {len(details['columns'])} columns and approx. {row_count_estimate} rows. "
                f"Here are the columns this table contains: {details['columns']}. "
                f"Here are a few sample_rows to show the contents of the table: {rows}. "
                f"Table is part of the database: {migration_config.migration_db_name}",
            )
            schema_tables.append(schema_table)
            tables[table_name] = details
            sample_data[table_name] = rows

            for fk in details.get("foreign_keys", []):
                ref_table_fq = fk["ref_table"]
                # Qualify an unqualified referenced table with the source table's schema.
                if "." not in ref_table_fq and "." in table_name:
                    ref_table_fq = f"{table_name.split('.', 1)[0]}.{ref_table_fq}"
                relationship_name = (
                    f"{table_name}:{fk['column']}->{ref_table_fq}:{fk['ref_column']}"
                )
                relationship = SchemaRelationship(
                    id=uuid5(NAMESPACE_OID, name=relationship_name),
                    name=relationship_name,
                    source_table=table_name,
                    target_table=ref_table_fq,
                    relationship_type="foreign_key",
                    source_column=fk["column"],
                    target_column=fk["ref_column"],
                    # Fixed: added the missing "->" separator between the two
                    # column references and corrected the "foreing" typo.
                    description=f"Relational database table foreign key relationship between: {table_name}.{fk['column']} -> {ref_table_fq}.{fk['ref_column']}. "
                    f"This foreign key relationship between table columns is a part of the following database: {migration_config.migration_db_name}",
                )
                schema_relationships.append(relationship)

    # Database-level node id is stable per provider + database name.
    id_str = f"{migration_config.migration_db_provider}:{migration_config.migration_db_name}"
    database_schema = DatabaseSchema(
        id=uuid5(NAMESPACE_OID, name=id_str),
        name=migration_config.migration_db_name,
        database_type=migration_config.migration_db_provider,
        tables=tables,
        sample_data=sample_data,
        extraction_timestamp=datetime.now(timezone.utc),
        description=f"Database schema containing {len(schema_tables)} tables and {len(schema_relationships)} relationships. "
        f"The database type is {migration_config.migration_db_provider}. "
        f"The database contains the following tables: {tables}",
    )
    return {
        "database_schema": database_schema,
        "schema_tables": schema_tables,
        "relationships": schema_relationships,
    }

View file

@ -0,0 +1,41 @@
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from typing import List, Dict, Optional
from datetime import datetime
class DatabaseSchema(DataPoint):
    """Represents a complete database schema with sample data"""

    # Database name (used as the node's display name).
    name: str
    database_type: str  # sqlite, postgres, etc.
    tables: Dict[str, Dict]  # Reuse existing schema format from SqlAlchemyAdapter
    sample_data: Dict[str, List[Dict]]  # Limited examples per table
    # When the schema was extracted — presumably UTC; verify against producer.
    extraction_timestamp: datetime
    # Natural-language summary; indexed for vector search (see metadata).
    description: str
    # Fields the vector store embeds/indexes for this DataPoint.
    metadata: dict = {"index_fields": ["description", "name"]}
class SchemaTable(DataPoint):
    """Represents an individual table schema with relationships"""

    # Table name, possibly schema-qualified ("schema.table").
    name: str
    columns: List[Dict]  # Column definitions with types
    primary_key: Optional[str]
    foreign_keys: List[Dict]  # Foreign key relationships
    sample_rows: List[Dict]  # Max 3-5 example rows
    row_count_estimate: Optional[int]  # Actual table size
    # Natural-language summary; indexed for vector search (see metadata).
    description: str
    # Fields the vector store embeds/indexes for this DataPoint.
    metadata: dict = {"index_fields": ["description", "name"]}
class SchemaRelationship(DataPoint):
    """Represents relationships between tables"""

    # Human-readable identifier, e.g. "table:col->ref_table:ref_col".
    name: str
    source_table: str
    target_table: str
    relationship_type: str  # "foreign_key", "one_to_many", etc.
    source_column: str
    target_column: str
    # Natural-language summary; indexed for vector search (see metadata).
    description: str
    # Fields the vector store embeds/indexes for this DataPoint.
    metadata: dict = {"index_fields": ["description", "name"]}

View file

@ -197,6 +197,80 @@ async def relational_db_migration():
print(f"All checks passed for {graph_db_provider} provider with '{relationship_label}' edges!") print(f"All checks passed for {graph_db_provider} provider with '{relationship_label}' edges!")
async def test_schema_only_migration():
    """End-to-end check of schema-only migration: runs the migration with
    schema_only=True, then verifies table count via search and edge counts
    per relationship type on the configured graph provider."""
    # 1. Setup test DB and extract schema
    migration_engine = await setup_test_db()
    schema = await migration_engine.extract_schema()

    # 2. Setup graph engine
    graph_engine = await get_graph_engine()

    # 3. Migrate schema only
    await migrate_relational_database(graph_engine, schema=schema, schema_only=True)

    # 4. Verify number of tables through search
    search_results = await cognee.search(
        query_text="How many tables are there in this database",
        query_type=cognee.SearchType.GRAPH_COMPLETION,
        top_k=30,
    )
    assert any("11" in r for r in search_results), (
        "Number of tables in the database reported in search_results is either None or not equal to 11"
    )

    # 5. Count edges by relationship type, using the provider-specific query API
    graph_db_provider = os.getenv("GRAPH_DATABASE_PROVIDER", "networkx").lower()
    edge_counts = {
        "is_part_of": 0,
        "has_relationship": 0,
        "foreign_key": 0,
    }
    if graph_db_provider == "neo4j":
        # Neo4j stores the relationship name as the edge type.
        for rel_type in edge_counts.keys():
            query_str = f"""
            MATCH ()-[r:{rel_type}]->()
            RETURN count(r) as c
            """
            rows = await graph_engine.query(query_str)
            edge_counts[rel_type] = rows[0]["c"]
    elif graph_db_provider == "kuzu":
        # Kuzu uses a single EDGE table with relationship_name as a property.
        for rel_type in edge_counts.keys():
            query_str = f"""
            MATCH ()-[r:EDGE]->()
            WHERE r.relationship_name = '{rel_type}'
            RETURN count(r) as c
            """
            rows = await graph_engine.query(query_str)
            edge_counts[rel_type] = rows[0][0]
    elif graph_db_provider == "networkx":
        # NetworkX exposes edges as (source, target, key, data) tuples.
        nodes, edges = await graph_engine.get_graph_data()
        for _, _, key, _ in edges:
            if key in edge_counts:
                edge_counts[key] += 1
    else:
        raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")

    # 6. Assert counts match expected values
    expected_counts = {
        "is_part_of": 11,
        "has_relationship": 22,
        "foreign_key": 11,
    }
    for rel_type, expected in expected_counts.items():
        actual = edge_counts[rel_type]
        assert actual == expected, (
            f"Expected {expected} edges for relationship '{rel_type}', but found {actual}"
        )

    print("Schema-only migration edge counts validated successfully!")
    print(f"Edge counts: {edge_counts}")
async def test_migration_sqlite(): async def test_migration_sqlite():
database_to_migrate_path = os.path.join(pathlib.Path(__file__).parent, "test_data/") database_to_migrate_path = os.path.join(pathlib.Path(__file__).parent, "test_data/")
@ -209,6 +283,7 @@ async def test_migration_sqlite():
) )
await relational_db_migration() await relational_db_migration()
await test_schema_only_migration()
async def test_migration_postgres(): async def test_migration_postgres():
@ -224,6 +299,7 @@ async def test_migration_postgres():
} }
) )
await relational_db_migration() await relational_db_migration()
await test_schema_only_migration()
async def main(): async def main():

View file

@ -1,16 +1,15 @@
from pathlib import Path
import asyncio import asyncio
import cognee
import os import os
import cognee
from cognee.infrastructure.databases.relational.config import get_migration_config
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.api.v1.visualize.visualize import visualize_graph from cognee.api.v1.visualize.visualize import visualize_graph
from cognee.infrastructure.databases.relational import ( from cognee.infrastructure.databases.relational import (
get_migration_relational_engine, get_migration_relational_engine,
) )
from cognee.modules.search.types import SearchType from cognee.modules.search.types import SearchType
from cognee.infrastructure.databases.relational import ( from cognee.infrastructure.databases.relational import (
create_db_and_tables as create_relational_db_and_tables, create_db_and_tables as create_relational_db_and_tables,
) )
@ -32,16 +31,29 @@ from cognee.infrastructure.databases.vector.pgvector import (
async def main(): async def main():
engine = get_migration_relational_engine()
# Clean all data stored in Cognee # Clean all data stored in Cognee
await cognee.prune.prune_data() await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True) await cognee.prune.prune_system(metadata=True)
# Needed to create appropriate tables only on the Cognee side # Needed to create appropriate database tables only on the Cognee side
await create_relational_db_and_tables() await create_relational_db_and_tables()
await create_vector_db_and_tables() await create_vector_db_and_tables()
# In case environment variables are not set use the example database from the Cognee repo
migration_db_provider = os.environ.get("MIGRATION_DB_PROVIDER", "sqlite")
migration_db_path = os.environ.get(
"MIGRATION_DB_PATH",
os.path.join(Path(__file__).resolve().parent.parent.parent, "cognee/tests/test_data"),
)
migration_db_name = os.environ.get("MIGRATION_DB_NAME", "migration_database.sqlite")
migration_config = get_migration_config()
migration_config.migration_db_provider = migration_db_provider
migration_config.migration_db_path = migration_db_path
migration_config.migration_db_name = migration_db_name
engine = get_migration_relational_engine()
print("\nExtracting schema of database to migrate.") print("\nExtracting schema of database to migrate.")
schema = await engine.extract_schema() schema = await engine.extract_schema()
print(f"Migrated database schema:\n{schema}") print(f"Migrated database schema:\n{schema}")
@ -53,10 +65,6 @@ async def main():
await migrate_relational_database(graph, schema=schema) await migrate_relational_database(graph, schema=schema)
print("Relational database migration complete.") print("Relational database migration complete.")
# Define location where to store html visualization of graph of the migrated database
home_dir = os.path.expanduser("~")
destination_file_path = os.path.join(home_dir, "graph_visualization.html")
# Make sure to set top_k at a high value for a broader search, the default value is only 10! # Make sure to set top_k at a high value for a broader search, the default value is only 10!
# top_k represent the number of graph tripplets to supply to the LLM to answer your question # top_k represent the number of graph tripplets to supply to the LLM to answer your question
search_results = await cognee.search( search_results = await cognee.search(
@ -69,13 +77,25 @@ async def main():
# Having a top_k value set to too high might overwhelm the LLM context when specific questions need to be answered. # Having a top_k value set to too high might overwhelm the LLM context when specific questions need to be answered.
# For this kind of question we've set the top_k to 30 # For this kind of question we've set the top_k to 30
search_results = await cognee.search( search_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION_COT, query_type=SearchType.GRAPH_COMPLETION,
query_text="What invoices are related to Leonie Köhler?", query_text="What invoices are related to Leonie Köhler?",
top_k=30, top_k=30,
) )
print(f"Search results: {search_results}") print(f"Search results: {search_results}")
# test.html is a file with visualized data migration search_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION,
query_text="What invoices are related to Luís Gonçalves?",
top_k=30,
)
print(f"Search results: {search_results}")
# If you check the relational database for this example you can see that the search results successfully found all
# the invoices related to the two customers, without any hallucinations or additional information
# Define location where to store html visualization of graph of the migrated database
home_dir = os.path.expanduser("~")
destination_file_path = os.path.join(home_dir, "graph_visualization.html")
print("Adding html visualization of graph database after migration.") print("Adding html visualization of graph database after migration.")
await visualize_graph(destination_file_path) await visualize_graph(destination_file_path)
print(f"Visualization can be found at: {destination_file_path}") print(f"Visualization can be found at: {destination_file_path}")