# cognee/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py
# ketanjain7981 e1d313a46b move DATABASE_CONNECT_ARGS parsing to RelationalConfig
# Signed-off-by: ketanjain7981 <ketan.jain@think41.com>
# 2025-12-09 10:15:36 +05:30
#
# 672 lines
# 28 KiB
# Python

import os
import asyncio
from os import path
import tempfile
from uuid import UUID
from typing import Optional
from typing import AsyncGenerator, List
from contextlib import asynccontextmanager
from sqlalchemy.orm import joinedload
from sqlalchemy.exc import NoResultFound
from sqlalchemy import NullPool, text, select, MetaData, Table, delete, inspect
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from cognee.modules.data.models.Data import Data
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.utils.run_sync import run_sync
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
from ..ModelBase import Base
logger = get_logger()
class SQLAlchemyAdapter:
"""
Adapt a SQLAlchemy connection for asynchronous operations with database management
functions.
"""
def __init__(self, connection_string: str, connect_args: dict = None):
    """
    Initialize the SQLAlchemy adapter with connection settings.

    Creates the async engine and the async session factory. For SQLite databases whose
    path contains "s3://", a local temporary file is created, the engine is pointed at
    that file, and the remote database is pulled into it.

    Parameters:
    -----------

        - connection_string (str): The database connection string (e.g., 'sqlite:///path/to/db'
          or 'postgresql://user:pass@host:port/db').
        - connect_args (dict, optional): Database driver connection arguments.
          Configuration is loaded from RelationalConfig.database_connect_args, which reads
          from the DATABASE_CONNECT_ARGS environment variable.

    Examples:
        PostgreSQL with SSL:
            DATABASE_CONNECT_ARGS='{"sslmode": "require", "connect_timeout": 10}'
        SQLite with custom timeout:
            DATABASE_CONNECT_ARGS='{"timeout": 60}'
    """
    self.db_path: str = None
    self.db_uri: str = connection_string

    # Use provided connect_args (already parsed from config); copy so the caller's
    # dict is never aliased by the engine configuration.
    final_connect_args = dict(connect_args) if connect_args else {}

    # The dialect name is always the leading part of a SQLAlchemy URL. A plain substring
    # test would also match e.g. a PostgreSQL password or database name containing "sqlite".
    is_sqlite = connection_string.startswith("sqlite")

    if is_sqlite:
        # Split on the first "///" only, so a path that itself contains "///" cannot
        # break the two-element unpacking.
        prefix, db_path = connection_string.split("///", 1)
        self.db_path = db_path

        if "s3://" in self.db_path:
            db_dir_path = path.dirname(self.db_path)
            file_storage = get_file_storage(db_dir_path)
            run_sync(file_storage.ensure_directory_exists())

            # delete=False keeps the file on disk after the handle is closed so the
            # engine can open it by name.
            with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
                self.temp_db_file = temp_file.name
                connection_string = prefix + "///" + self.temp_db_file

            run_sync(self.pull_from_s3())

    if is_sqlite:
        self.engine = create_async_engine(
            connection_string,
            poolclass=NullPool,
            # Caller-supplied arguments override the default timeout.
            connect_args={**{"timeout": 30}, **final_connect_args},
        )
    else:
        self.engine = create_async_engine(
            connection_string,
            pool_size=20,
            max_overflow=20,
            pool_recycle=280,
            pool_pre_ping=True,
            pool_timeout=280,
            connect_args=final_connect_args,
        )

    self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
async def push_to_s3(self) -> None:
    """
    Upload the local temporary database file to the adapter's `db_path` via S3 storage.

    Nothing happens unless the STORAGE_BACKEND environment variable equals "s3"
    (case-insensitive) and this adapter has a `temp_db_file` attribute.
    """
    backend = os.getenv("STORAGE_BACKEND", "").lower()
    if backend != "s3" or not hasattr(self, "temp_db_file"):
        return

    from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage

    storage = S3FileStorage("")
    storage.s3.put(self.temp_db_file, self.db_path, recursive=True)
async def pull_from_s3(self) -> None:
    """
    Download the database at `db_path` into the local `temp_db_file` via S3 storage.

    A `FileNotFoundError` raised by the download is ignored.
    """
    from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage

    storage = S3FileStorage("")
    try:
        storage.s3.get(self.db_path, self.temp_db_file, recursive=True)
    except FileNotFoundError:
        # Presumably the remote database has not been uploaded yet — nothing to pull.
        return
@asynccontextmanager
async def get_async_session(self) -> AsyncGenerator[AsyncSession, None]:
    """
    Async context manager that yields a session created by the adapter's session factory.

    The session is explicitly closed when the caller's block finishes, whether it
    completed normally or raised.
    """
    async with self.sessionmaker() as session:
        try:
            yield session
        finally:
            # Explicit close in addition to the surrounding `async with`.
            await session.close()
def get_session(self):
    """
    Yield a database session obtained from the adapter's session factory and close it
    once the generator is resumed or closed.

    This is a plain generator function: calling it returns a generator object.

    NOTE(review): no `contextmanager` decorator is applied, so the result cannot be used
    directly in a `with` statement — confirm the intended usage with callers.
    NOTE(review): `self.sessionmaker` is created with `async_sessionmaker` in `__init__`;
    verify that its sessions support the synchronous `with` protocol used below.
    """
    with self.sessionmaker() as session:
        try:
            yield session
        finally:
            # Explicit close in addition to the surrounding `with`.
            session.close()
async def get_datasets(self):
    """
    Load every dataset, eagerly joining each dataset's `data` relationship.

    Returns:
    --------

        A list of unique Dataset objects retrieved from the database.
    """
    from cognee.modules.data.models import Dataset

    query = select(Dataset).options(joinedload(Dataset.data))

    async with self.get_async_session() as session:
        rows = await session.execute(query)
        # unique() collapses the duplicate parent rows produced by the joined eager load.
        return rows.unique().scalars().all()
async def create_table(self, schema_name: str, table_name: str, table_config: list[dict]):
    """
    Create a table with the specified schema and configuration if it does not exist.

    The statements run inside a single `engine.begin()` transaction, which commits and
    releases the connection when the block exits.

    NOTE: schema, table and column names/types are interpolated directly into the SQL
    text (identifiers cannot be bound as parameters), so they must come from trusted
    input.

    Parameters:
    -----------

        - schema_name (str): The name of the schema to create the table within. Ignored
          on SQLite, which has no schema namespaces.
        - table_name (str): The name of the table to be created.
        - table_config (list[dict]): A list of dictionaries representing the table's fields
          and their types; each item provides the keys 'name' and 'type'.
    """
    column_definitions = ", ".join(f"{item['name']} {item['type']}" for item in table_config)

    async with self.engine.begin() as connection:
        if self.engine.dialect.name == "sqlite":
            # SQLite has no CREATE SCHEMA and no schema-qualified names; mirror the
            # dialect handling used by delete_table and insert_data.
            await connection.execute(
                text(f'CREATE TABLE IF NOT EXISTS "{table_name}" ({column_definitions});')
            )
        else:
            await connection.execute(text(f"CREATE SCHEMA IF NOT EXISTS {schema_name};"))
            await connection.execute(
                text(
                    f'CREATE TABLE IF NOT EXISTS {schema_name}."{table_name}" ({column_definitions});'
                )
            )
        # Do not close the connection here: closing it inside the block rolls back the
        # still-open transaction, discarding the DDL on databases with transactional
        # DDL. The context manager commits and closes on exit.
async def delete_table(self, table_name: str, schema_name: Optional[str] = "public"):
    """
    Drop a table if it exists.

    Parameters:
    -----------

        - table_name (str): The name of the table to delete.
        - schema_name (Optional[str]): The name of the schema containing the table to
          delete; not used on SQLite. (default 'public')
    """
    if self.engine.dialect.name == "sqlite":
        # SQLite doesn't support schema namespaces and the CASCADE keyword.
        # However, foreign key constraint can be defined with ON DELETE CASCADE during table creation.
        drop_statement = f'DROP TABLE IF EXISTS "{table_name}";'
    else:
        drop_statement = f'DROP TABLE IF EXISTS {schema_name}."{table_name}" CASCADE;'

    async with self.engine.begin() as connection:
        await connection.execute(text(drop_statement))
async def insert_data(
    self,
    table_name: str,
    data: list[dict],
    schema_name: Optional[str] = "public",
) -> int:
    """
    Insert data into a specified table, returning the number of rows inserted.

    Parameters:
    -----------

        - table_name (str): The name of the table to insert data into.
        - data (list[dict]): A list of dictionaries representing the rows to add to the
          table.
        - schema_name (Optional[str]): The name of the schema containing the table; not
          used on SQLite. (default 'public')

    Returns:
    --------

        - int: The row count reported for the insert, or 0 when `data` is empty.

    Raises:
    -------

        Any exception raised while resolving the table or executing the insert is logged
        and re-raised unchanged.
    """
    if not data:
        logger.info("No data provided for insertion")
        return 0

    is_sqlite = self.engine.dialect.name == "sqlite"

    try:
        # Resolve the table before opening the insert transaction: get_table checks out
        # its own connection, and doing that while already holding one needlessly ties
        # up two pooled connections per call.
        if is_sqlite:
            table = await self.get_table(table_name)  # SQLite ignores schemas
        else:
            table = await self.get_table(table_name, schema_name)

        async with self.engine.begin() as conn:
            if is_sqlite:
                # Foreign key constraints are disabled by default in SQLite (for backwards compatibility),
                # so must be enabled for each database connection/session separately.
                await conn.execute(text("PRAGMA foreign_keys=ON"))

            result = await conn.execute(table.insert().values(data))

            # Return rowcount for validation
            return result.rowcount
    except Exception as e:
        logger.error(f"Insert failed: {str(e)}")
        raise  # Re-raise for error handling upstream, keeping the original traceback
async def get_schema_list(self) -> List[str]:
    """
    List the schema names of a PostgreSQL database, leaving out 'pg_catalog', 'pg_toast'
    and 'information_schema'.

    Returns:
    --------

        - List[str]: The schema names; an empty list for any dialect other than
          PostgreSQL.
    """
    if self.engine.dialect.name != "postgresql":
        return []

    schema_query = text(
        "\n"
        "SELECT schema_name FROM information_schema.schemata\n"
        "WHERE schema_name NOT IN ('pg_catalog', 'pg_toast', 'information_schema');\n"
    )

    async with self.engine.begin() as connection:
        result = await connection.execute(schema_query)
        return [row[0] for row in result.fetchall()]
async def delete_entity_by_id(
    self, table_name: str, data_id: UUID, schema_name: Optional[str] = "public"
):
    """
    Delete the row whose `id` column equals `data_id` from the given table and commit.

    Parameters:
    -----------

        - table_name (str): The name of the table from which to delete the entity.
        - data_id (UUID): The unique identifier of the entity to be deleted.
        - schema_name (Optional[str]): The name of the schema where the table resides.
          (default 'public')
    """
    async with self.get_async_session() as session:
        table = await self.get_table(table_name, schema_name)

        if self.engine.dialect.name == "sqlite":
            # Foreign key constraints are disabled by default in SQLite (for backwards compatibility),
            # so must be enabled for each database connection/session separately.
            await session.execute(text("PRAGMA foreign_keys = ON;"))

        await session.execute(table.delete().where(table.c.id == data_id))
        await session.commit()
async def delete_data_entity(self, data_id: UUID):
    """
    Delete a Data row and, when no other row shares its raw data location, the stored
    file it points to.

    The file is only removed when the location contains the configured
    `data_root_directory`.

    Parameters:
    -----------

        - data_id (UUID): The unique identifier of the data entity to be deleted.

    Raises:
    -------

        - EntityNotFoundError: When looking up the entity raises ValueError or
          NoResultFound.
    """
    async with self.get_async_session() as session:
        if self.engine.dialect.name == "sqlite":
            # Foreign key constraints are disabled by default in SQLite (for backwards compatibility),
            # so must be enabled for each database connection/session separately.
            await session.execute(text("PRAGMA foreign_keys = ON;"))

        try:
            lookup = await session.scalars(select(Data).where(Data.id == data_id))
            data_entity = lookup.one()
        except (ValueError, NoResultFound) as e:
            raise EntityNotFoundError(message=f"Entity not found: {str(e)}") from e

        # Find every row that points at the same raw data location as this entity.
        same_location_query = select(Data.raw_data_location).where(
            Data.raw_data_location == data_entity.raw_data_location
        )
        same_location_rows = (await session.execute(same_location_query)).all()

        # Only remove the file when this entity is its sole reference in the database.
        if len(same_location_rows) == 1:
            data_root = get_storage_config()["data_root_directory"]
            raw_location = same_location_rows[0].raw_data_location

            # Only files under the configured data root are considered cognee-created.
            if data_root in raw_location:
                file_storage = get_file_storage(data_root)
                file_name = os.path.basename(raw_location)

                if await file_storage.file_exists(file_name):
                    await file_storage.remove(file_name)
                else:
                    # Report bug as file should exist
                    logger.error("Local file which should exist can't be found.")

        await session.execute(delete(Data).where(Data.id == data_id))
        await session.commit()
async def get_table(self, table_name: str, schema_name: Optional[str] = "public") -> Table:
    """
    Load a table dynamically using the specified name and schema information.

    On SQLite the database is reflected into `Base.metadata` and the table is looked up
    by its bare name. On other dialects the given schema is reflected into a fresh
    `MetaData` instance.

    Parameters:
    -----------

        - table_name (str): The name of the table to load.
        - schema_name (Optional[str]): The name of the schema containing the table; not
          used on SQLite. When None, the reflected tables are looked up by bare name.
          (default 'public')

    Returns:
    --------

        - Table: The SQLAlchemy Table object corresponding to the specified table.

    Raises:
    -------

        - EntityNotFoundError: When the table is not present after reflection.
    """
    async with self.engine.begin() as connection:
        if self.engine.dialect.name == "sqlite":
            # Load the schema information into the MetaData object
            await connection.run_sync(Base.metadata.reflect)
            if table_name in Base.metadata.tables:
                return Base.metadata.tables[table_name]
            raise EntityNotFoundError(message=f"Table '{table_name}' not found.")

        # Create a MetaData instance to load table information
        metadata = MetaData()
        # Load table information from schema into MetaData
        await connection.run_sync(metadata.reflect, schema=schema_name)

        # Reflected tables are keyed "<schema>.<table>" only when a schema was given.
        # With schema_name=None the key is the bare table name, so building
        # "None.<table>" could never match.
        full_table_name = f"{schema_name}.{table_name}" if schema_name else table_name

        # Check if table is in list of tables for the given schema
        if full_table_name in metadata.tables:
            return metadata.tables[full_table_name]
        raise EntityNotFoundError(message=f"Table '{full_table_name}' not found.")
async def get_table_names(self) -> List[str]:
    """
    Collect the names of all tables found by reflecting the database.

    On SQLite the whole database is reflected at once. On other dialects each schema
    returned by `get_schema_list` is reflected in turn.

    Returns:
    --------

        - List[str]: The keys of the reflected tables.
    """
    async with self.engine.begin() as connection:
        metadata = MetaData()

        if self.engine.dialect.name == "sqlite":
            # Reflect the entire database in one go.
            await connection.run_sync(metadata.reflect)
            return list(metadata.tables.keys())

        collected = []
        for schema_name in await self.get_schema_list():
            await connection.run_sync(metadata.reflect, schema=schema_name)
            collected.extend(metadata.tables.keys())
            # Start from an empty MetaData for the next schema.
            metadata.clear()
        return collected
async def get_data(self, table_name: str, filters: dict = None):
    """
    Retrieve data from a specified table using optional filtering conditions.

    Filter values are always passed as bound parameters. The table name and the filter
    keys (column names) are interpolated into the SQL text because identifiers cannot be
    bound, so they must come from trusted input.

    Parameters:
    -----------

        - table_name (str): The name of the table from which to retrieve data.
        - filters (dict): Column name -> value. A list value becomes an `IN (...)`
          condition, any other value an equality condition; conditions are combined with
          AND. An empty list matches no rows. (default None, meaning no filtering)

    Returns:
    --------

        A dictionary mapping each row's `data_id` to its `status`.
    """
    query = f'SELECT * FROM "{table_name}"'
    params = {}

    if filters:
        conditions = []
        for key, value in filters.items():
            if not isinstance(value, list):
                params[key] = value
                conditions.append(f"{key} = :{key}")
            elif not value:
                # "IN ()" is invalid SQL; an empty candidate list can never match.
                conditions.append("1 = 0")
            else:
                # Every list element needs its own bound parameter: the placeholders
                # are :key0, :key1, ... so the values must be supplied under exactly
                # those names rather than as the raw list under `key`.
                placeholders = []
                for index, item in enumerate(value):
                    param_name = f"{key}{index}"
                    params[param_name] = item
                    placeholders.append(f":{param_name}")
                conditions.append(f"{key} IN ({', '.join(placeholders)})")
        query += f" WHERE {' AND '.join(conditions)}"

    query += ";"

    async with self.engine.begin() as connection:
        results = await connection.execute(text(query), params or None)
        # Row objects are tuple-like; mappings() gives access by column name.
        return {row["data_id"]: row["status"] for row in results.mappings()}
async def get_all_data_from_table(self, table_name: str, schema: str = "public"):
    """
    Return every row of a table after checking that the names are valid identifiers.

    Parameters:
    -----------

        - table_name (str): The name of the table to query data from.
        - schema (str): The schema of the table; not used on SQLite. (default 'public')

    Returns:
    --------

        A list of row mappings, one per row in the specified table.

    Raises:
    -------

        - ValueError: When `table_name`, or a non-empty `schema`, is not a valid
          identifier.
    """
    async with self.get_async_session() as session:
        # Validate inputs to prevent SQL injection
        if not table_name.isidentifier():
            raise ValueError("Invalid table name")
        if schema and not schema.isidentifier():
            raise ValueError("Invalid schema name")

        if self.engine.dialect.name == "sqlite":
            lookup_args = (table_name,)
        else:
            lookup_args = (table_name, schema)
        table = await self.get_table(*lookup_args)

        result = await session.execute(select(table))
        return result.mappings().all()
async def execute_query(self, query):
    """
    Execute a raw SQL query against the database asynchronously.

    The statement runs inside an `engine.begin()` transaction that commits on exit.
    The query text is executed as given, so it must come from trusted input.

    Parameters:
    -----------

        - query: The SQL query string to execute.

    Returns:
    --------

        The result set as a list of dictionaries, with each dictionary representing a
        row; an empty list for statements that return no rows.
    """
    async with self.engine.begin() as connection:
        result = await connection.execute(text(query))

        # Statements such as DDL or plain INSERT/UPDATE/DELETE have no rows to iterate.
        if not result.returns_rows:
            return []

        # Row objects are tuple-like, so dict(row) does not produce a column mapping;
        # mappings() yields dict-compatible views keyed by column name.
        return [dict(row) for row in result.mappings()]
async def drop_tables(self):
    """
    Drop the `group_permission` and `permissions` tables if they exist.

    Both statements run in one `engine.begin()` transaction. Any failure is logged and
    re-raised.
    """
    # SQLite doesn't support the CASCADE keyword (see delete_table), so only emit it
    # for other dialects.
    cascade = "" if self.engine.dialect.name == "sqlite" else " CASCADE"

    async with self.engine.begin() as connection:
        try:
            # Add more table names here for other tables as needed
            for table_name in ("group_permission", "permissions"):
                await connection.execute(text(f"DROP TABLE IF EXISTS {table_name}{cascade}"))
            logger.debug("Database tables dropped successfully.")
        except Exception as e:
            logger.error(f"Error dropping database tables: {e}")
            raise
async def create_database(self):
    """
    Make sure the database exists and create all tables registered on `Base.metadata`.

    For SQLite, the directory of the database file is ensured to exist when the file
    itself is not found.
    """
    if self.engine.dialect.name == "sqlite":
        db_directory, db_name = path.split(self.db_path)
        file_storage = get_file_storage(db_directory)

        database_exists = await file_storage.file_exists(db_name)
        if not database_exists:
            await file_storage.ensure_directory_exists()

    async with self.engine.begin() as connection:
        # Nothing to create when no models are registered.
        if Base.metadata.tables:
            await connection.run_sync(Base.metadata.create_all)
async def delete_database(self):
    """
    Remove the database contents.

    For SQLite the engine is disposed and the database file is removed. For other
    dialects every table reflected from the 'public' and 'public_staging' schemas is
    dropped. Any failure is logged and re-raised.
    """
    try:
        if self.engine.dialect.name == "sqlite":
            await self.engine.dispose(close=True)
            # Wait for the database connections to close and release the file (Windows)
            await asyncio.sleep(2)

            db_directory, file_name = path.split(self.db_path)
            await get_file_storage(db_directory).remove(file_name)
        else:
            async with self.engine.begin() as connection:
                metadata = MetaData()

                for schema_name in ("public", "public_staging"):
                    # Load this schema's tables into the MetaData object
                    await connection.run_sync(metadata.reflect, schema=schema_name)

                    for table in metadata.sorted_tables:
                        await connection.execute(
                            text(f'DROP TABLE IF EXISTS {schema_name}."{table.name}" CASCADE')
                        )

                    # Reset before reflecting the next schema.
                    metadata.clear()
    except Exception as e:
        logger.error(f"Error deleting database: {e}")
        raise e

    logger.info("Database deleted successfully.")
async def extract_schema(self):
    """
    Describe the database structure: columns, primary key and foreign keys per table.

    SQLite is inspected through `PRAGMA table_info` / `PRAGMA foreign_key_list`; other
    dialects go through the SQLAlchemy inspector for every schema returned by
    `get_schema_list`. Tables outside the 'public' schema are keyed as
    "<schema>.<table>".

    Returns:
    --------

        A dictionary mapping each table key to
        {"columns": [{"name", "type"}, ...], "primary_key": <column name or None>,
        "foreign_keys": [{"column", "ref_table", "ref_column"}, ...]}.
    """

    def new_entry():
        """Return an empty description for one table."""
        return {"columns": [], "primary_key": None, "foreign_keys": []}

    def list_schema_tables(sync_conn, schema_name):
        """Return the table names of one schema using the inspector."""
        return inspect(sync_conn).get_table_names(schema=schema_name)

    def get_details(sync_conn, table, schema_name):
        """
        Retrieve the columns, primary key constraint and foreign keys of one table.

        Parameters:
        -----------

            - sync_conn: The synchronous connection used to interact with the database.
            - table: The name of the table to retrieve details for.
            - schema_name: The schema name in which the table resides.

        Returns:
        --------

            A (columns, primary key constraint, foreign keys) tuple.
        """
        insp = inspect(sync_conn)
        cols = insp.get_columns(table, schema=schema_name)
        pk = insp.get_pk_constraint(table, schema=schema_name)
        fks = insp.get_foreign_keys(table, schema=schema_name)
        return cols, pk, fks

    async with self.engine.begin() as connection:
        tables = await self.get_table_names()
        schema = {}

        if self.engine.dialect.name == "sqlite":
            for table_name in tables:
                entry = schema[table_name] = new_entry()

                # Column details: index 1 is the name, 2 the type, 5 the pk marker.
                table_info = await connection.execute(
                    text(f"PRAGMA table_info('{table_name}');")
                )
                for column in table_info.fetchall():
                    column_name = column[1]
                    entry["columns"].append({"name": column_name, "type": column[2]})
                    if column[5] == 1:
                        entry["primary_key"] = column_name

                # Foreign key details
                fk_info = await connection.execute(
                    text(f"PRAGMA foreign_key_list('{table_name}');")
                )
                for fk in fk_info.fetchall():
                    entry["foreign_keys"].append(
                        {
                            "column": fk[3],  # Column in the current table
                            "ref_table": fk[2],  # Referenced table
                            "ref_column": fk[4],  # Referenced column
                        }
                    )
            return schema

        for schema_name in await self.get_schema_list():
            # Get tables for the current schema via the inspector.
            tables = await connection.run_sync(list_schema_tables, schema_name)

            for table_name in tables:
                # Qualify the table name with the schema if not in the default schema.
                if schema_name == "public":
                    key = table_name
                else:
                    key = f"{schema_name}.{table_name}"
                entry = schema[key] = new_entry()

                cols, pk, fks = await connection.run_sync(get_details, table_name, schema_name)

                for column in cols:
                    # Convert the type to string
                    entry["columns"].append({"name": column["name"], "type": str(column["type"])})

                pk_columns = pk.get("constrained_columns", [])
                if pk_columns:
                    entry["primary_key"] = pk_columns[0]

                for fk in fks:
                    column_pairs = zip(
                        fk.get("constrained_columns", []), fk.get("referred_columns", [])
                    )
                    for col, ref_col in column_pairs:
                        if not (col and ref_col):
                            logger.warning(
                                f"Missing value in foreign key information. \nColumn value: {col}\nReference column value: {ref_col}\n"
                            )
                            continue
                        entry["foreign_keys"].append(
                            {
                                "column": col,
                                "ref_table": fk.get("referred_table"),
                                "ref_column": ref_col,
                            }
                        )

        return schema