Fix 3 security issues:

1. SQL Injection Through Unvalidated Schema and Table Creation Parameters (CWE-89)
2. SQL Injection Through Unvalidated Table and Column Names in Database Query Construction (CWE-89)
3. Unrestricted SQL Query Execution Vulnerability in Database Adapter (CWE-89)
This commit is contained in:
pensarapp[bot] 2025-05-22 08:57:16 +00:00 committed by GitHub
parent b1b4ae3d5f
commit 079144bac9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -19,6 +19,33 @@ from ..ModelBase import Base
logger = get_logger()
def _is_identifier(s: str) -> bool:
# Simple SQL identifier check: must be valid Python identifier and not pure digit
if not isinstance(s, str):
return False
if not s.isidentifier():
return False
# Optionally, some SQL reserved words could be blacklisted here
reserved = {
'select', 'table', 'from', 'where', 'insert', 'delete', 'update', 'schema', 'join', 'or', 'and', 'not',
'create', 'drop', 'null', 'default', 'primary', 'key', 'if', 'exists', 'values', 'into'
}
if s.lower() in reserved:
return False
return True
def _validate_sql_type(s: str) -> bool:
# A simple allow-list of basic SQL types to avoid injection via fake type names
# Adjust as needed to match your DB; should be strict.
allowed = {
'INTEGER', 'BIGINT', 'SMALLINT', 'SERIAL', 'BIGSERIAL',
'BOOLEAN', 'BOOL', 'TEXT', 'VARCHAR', 'CHAR', 'DATE', 'TIMESTAMP',
'UUID', 'REAL', 'FLOAT', 'DOUBLE PRECISION', 'BLOB', 'BYTEA'
}
# Accept potential varchar/char length: VARCHAR(255) etc.
prefix = s.strip().split('(')[0].upper()
return prefix in allowed
class SQLAlchemyAdapter:
def __init__(self, connection_string: str):
self.db_path: str = None
@ -56,12 +83,25 @@ class SQLAlchemyAdapter:
return datasets
async def create_table(self, schema_name: str, table_name: str, table_config: list[dict]):
fields_query_parts = [f"{item['name']} {item['type']}" for item in table_config]
# Validate schema and table names as SQL identifiers
if not _is_identifier(schema_name):
raise ValueError(f"Invalid schema name: {schema_name!r}")
if not _is_identifier(table_name):
raise ValueError(f"Invalid table name: {table_name!r}")
fields_query_parts = []
for item in table_config:
col_name, col_type = item.get('name'), item.get('type')
if not _is_identifier(col_name):
raise ValueError(f"Invalid column name: {col_name!r}")
if not isinstance(col_type, str) or not _validate_sql_type(col_type):
raise ValueError(f"Invalid or potentially dangerous column type: {col_type!r}")
fields_query_parts.append(f'"{col_name}" {col_type.strip()}')
async with self.engine.begin() as connection:
await connection.execute(text(f"CREATE SCHEMA IF NOT EXISTS {schema_name};"))
await connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}";'))
await connection.execute(
text(
f'CREATE TABLE IF NOT EXISTS {schema_name}."{table_name}" ({", ".join(fields_query_parts)});'
f'CREATE TABLE IF NOT EXISTS "{schema_name}"."{table_name}" ({", ".join(fields_query_parts)});'
)
)
await connection.close()
@ -248,24 +288,86 @@ class SQLAlchemyAdapter:
return table_names
def _validate_identifier(self, identifier: str) -> bool:
"""
Check if the provided identifier is a valid SQL identifier.
This is used to defend against SQL injection for table and column names.
"""
if not isinstance(identifier, str):
return False
# Disallow any double quotes, semicolons, whitespace, or dangerous chars.
if (
not identifier
or not identifier.isidentifier()
or '"' in identifier
or ";" in identifier
or "--" in identifier
or " " in identifier
):
return False
# Add further SQL-specific restrictions if needed
return True
async def get_data(self, table_name: str, filters: dict = None):
# Validate identifiers
if not self._validate_identifier(table_name):
raise ValueError("Invalid table name")
async with self.engine.begin() as connection:
query = f'SELECT * FROM "{table_name}"'
# Reflect the correct table object and columns
if self.engine.dialect.name == "sqlite":
await connection.run_sync(Base.metadata.reflect)
if table_name not in Base.metadata.tables:
raise ValueError(f"Table '{table_name}' does not exist in schema.")
table = Base.metadata.tables[table_name]
else:
# Assume public schema unless another schema is provided (for future-proofing)
schema_name = "public"
metadata = MetaData()
await connection.run_sync(metadata.reflect, schema=schema_name)
full_table_name = f"{schema_name}.{table_name}"
if full_table_name not in metadata.tables:
raise ValueError(f"Table '{table_name}' does not exist in schema '{schema_name}'.")
table = metadata.tables[full_table_name]
valid_columns = set(column.name for column in table.columns)
# Validate filter column names
filtered_columns = []
filter_params = dict()
filter_conditions = ""
if filters:
filter_conditions = " AND ".join(
[
f"{key} IN ({', '.join([f':{key}{i}' for i in range(len(value))])})"
if isinstance(value, list)
else f"{key} = :{key}"
for key, value in filters.items()
]
)
for key, value in filters.items():
if not self._validate_identifier(key):
raise ValueError(f"Invalid column identifier: {key}")
if key not in valid_columns:
raise ValueError(f"Column '{key}' does not exist in table '{table_name}'.")
filtered_columns.append(key)
filter_items = []
for key, value in filters.items():
if isinstance(value, list):
# Assign unique keys to avoid collisions and support multi-valued IN
param_keys = [f"{key}{i}" for i in range(len(value))]
filter_items.append(
f'"{key}" IN ({", ".join([f":{p}" for p in param_keys])})'
)
filter_params.update({param_keys[i]: value[i] for i in range(len(value))})
else:
filter_items.append(f'"{key}" = :{key}')
filter_params[key] = value
filter_conditions = " AND ".join(filter_items)
query = f'SELECT * FROM "{table_name}"'
if filter_conditions:
query += f" WHERE {filter_conditions};"
query = text(query)
results = await connection.execute(query, filters)
else:
query += ";"
query = text(query)
query = text(query)
if filter_params:
results = await connection.execute(query, filter_params)
else:
results = await connection.execute(query)
return {result["data_id"]: result["status"] for result in results}
@ -290,9 +392,36 @@ class SQLAlchemyAdapter:
rows = result.mappings().all()
return rows
async def execute_query(self, query):
async def execute_query(self, query: str, parameters: Optional[dict] = None):
"""
Securely executes a parameterized SELECT query only.
Rejects any non-SELECT queries.
:param query: SQL SELECT query (must start with 'SELECT', case-insensitive, cannot contain ';', '--', or multiple statements)
:param parameters: dictionary of bind parameters for the query
:return: list of dictionaries representing rows
"""
# Basic whitelist: Only allow SELECT queries. Disallow anything else.
if not isinstance(query, str):
raise ValueError("Query must be a string.")
stripped = query.lstrip().lower()
if not stripped.startswith("select"):
raise ValueError("Only SELECT queries are permitted.")
# Block semicolons (which could be used for stacked queries) and SQL comments
forbidden_patterns = [';', '--', '/*', '*/']
for pattern in forbidden_patterns:
if pattern in stripped:
raise ValueError("Forbidden SQL pattern detected in query.")
# Only allow one statement
if query.count(';') > 0:
raise ValueError("Query must not contain semicolons.")
async with self.engine.begin() as connection:
result = await connection.execute(text(query))
# Use SQLAlchemy parameterization for safety
result = await connection.execute(text(query), parameters or {})
return [dict(row) for row in result]
async def drop_tables(self):
@ -435,4 +564,4 @@ class SQLAlchemyAdapter:
f"Missing value in foreign key information. \nColumn value: {col}\nReference column value: {ref_col}\n"
)
return schema
return schema