Compare commits
1 commit
main
...
pensar-aut
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
079144bac9 |
1 changed files with 147 additions and 18 deletions
|
|
@ -19,6 +19,33 @@ from ..ModelBase import Base
|
|||
logger = get_logger()
|
||||
|
||||
|
||||
def _is_identifier(s: str) -> bool:
|
||||
# Simple SQL identifier check: must be valid Python identifier and not pure digit
|
||||
if not isinstance(s, str):
|
||||
return False
|
||||
if not s.isidentifier():
|
||||
return False
|
||||
# Optionally, some SQL reserved words could be blacklisted here
|
||||
reserved = {
|
||||
'select', 'table', 'from', 'where', 'insert', 'delete', 'update', 'schema', 'join', 'or', 'and', 'not',
|
||||
'create', 'drop', 'null', 'default', 'primary', 'key', 'if', 'exists', 'values', 'into'
|
||||
}
|
||||
if s.lower() in reserved:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _validate_sql_type(s: str) -> bool:
|
||||
# A simple allow-list of basic SQL types to avoid injection via fake type names
|
||||
# Adjust as needed to match your DB; should be strict.
|
||||
allowed = {
|
||||
'INTEGER', 'BIGINT', 'SMALLINT', 'SERIAL', 'BIGSERIAL',
|
||||
'BOOLEAN', 'BOOL', 'TEXT', 'VARCHAR', 'CHAR', 'DATE', 'TIMESTAMP',
|
||||
'UUID', 'REAL', 'FLOAT', 'DOUBLE PRECISION', 'BLOB', 'BYTEA'
|
||||
}
|
||||
# Accept potential varchar/char length: VARCHAR(255) etc.
|
||||
prefix = s.strip().split('(')[0].upper()
|
||||
return prefix in allowed
|
||||
|
||||
class SQLAlchemyAdapter:
|
||||
def __init__(self, connection_string: str):
|
||||
self.db_path: str = None
|
||||
|
|
@ -56,12 +83,25 @@ class SQLAlchemyAdapter:
|
|||
return datasets
|
||||
|
||||
async def create_table(self, schema_name: str, table_name: str, table_config: list[dict]):
|
||||
fields_query_parts = [f"{item['name']} {item['type']}" for item in table_config]
|
||||
# Validate schema and table names as SQL identifiers
|
||||
if not _is_identifier(schema_name):
|
||||
raise ValueError(f"Invalid schema name: {schema_name!r}")
|
||||
if not _is_identifier(table_name):
|
||||
raise ValueError(f"Invalid table name: {table_name!r}")
|
||||
|
||||
fields_query_parts = []
|
||||
for item in table_config:
|
||||
col_name, col_type = item.get('name'), item.get('type')
|
||||
if not _is_identifier(col_name):
|
||||
raise ValueError(f"Invalid column name: {col_name!r}")
|
||||
if not isinstance(col_type, str) or not _validate_sql_type(col_type):
|
||||
raise ValueError(f"Invalid or potentially dangerous column type: {col_type!r}")
|
||||
fields_query_parts.append(f'"{col_name}" {col_type.strip()}')
|
||||
async with self.engine.begin() as connection:
|
||||
await connection.execute(text(f"CREATE SCHEMA IF NOT EXISTS {schema_name};"))
|
||||
await connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}";'))
|
||||
await connection.execute(
|
||||
text(
|
||||
f'CREATE TABLE IF NOT EXISTS {schema_name}."{table_name}" ({", ".join(fields_query_parts)});'
|
||||
f'CREATE TABLE IF NOT EXISTS "{schema_name}"."{table_name}" ({", ".join(fields_query_parts)});'
|
||||
)
|
||||
)
|
||||
await connection.close()
|
||||
|
|
@ -248,24 +288,86 @@ class SQLAlchemyAdapter:
|
|||
|
||||
return table_names
|
||||
|
||||
def _validate_identifier(self, identifier: str) -> bool:
|
||||
"""
|
||||
Check if the provided identifier is a valid SQL identifier.
|
||||
This is used to defend against SQL injection for table and column names.
|
||||
"""
|
||||
if not isinstance(identifier, str):
|
||||
return False
|
||||
# Disallow any double quotes, semicolons, whitespace, or dangerous chars.
|
||||
if (
|
||||
not identifier
|
||||
or not identifier.isidentifier()
|
||||
or '"' in identifier
|
||||
or ";" in identifier
|
||||
or "--" in identifier
|
||||
or " " in identifier
|
||||
):
|
||||
return False
|
||||
# Add further SQL-specific restrictions if needed
|
||||
return True
|
||||
|
||||
async def get_data(self, table_name: str, filters: dict = None):
|
||||
# Validate identifiers
|
||||
if not self._validate_identifier(table_name):
|
||||
raise ValueError("Invalid table name")
|
||||
|
||||
async with self.engine.begin() as connection:
|
||||
query = f'SELECT * FROM "{table_name}"'
|
||||
# Reflect the correct table object and columns
|
||||
if self.engine.dialect.name == "sqlite":
|
||||
await connection.run_sync(Base.metadata.reflect)
|
||||
if table_name not in Base.metadata.tables:
|
||||
raise ValueError(f"Table '{table_name}' does not exist in schema.")
|
||||
table = Base.metadata.tables[table_name]
|
||||
else:
|
||||
# Assume public schema unless another schema is provided (for future-proofing)
|
||||
schema_name = "public"
|
||||
metadata = MetaData()
|
||||
await connection.run_sync(metadata.reflect, schema=schema_name)
|
||||
full_table_name = f"{schema_name}.{table_name}"
|
||||
if full_table_name not in metadata.tables:
|
||||
raise ValueError(f"Table '{table_name}' does not exist in schema '{schema_name}'.")
|
||||
table = metadata.tables[full_table_name]
|
||||
|
||||
valid_columns = set(column.name for column in table.columns)
|
||||
|
||||
# Validate filter column names
|
||||
filtered_columns = []
|
||||
filter_params = dict()
|
||||
filter_conditions = ""
|
||||
if filters:
|
||||
filter_conditions = " AND ".join(
|
||||
[
|
||||
f"{key} IN ({', '.join([f':{key}{i}' for i in range(len(value))])})"
|
||||
if isinstance(value, list)
|
||||
else f"{key} = :{key}"
|
||||
for key, value in filters.items()
|
||||
]
|
||||
)
|
||||
for key, value in filters.items():
|
||||
if not self._validate_identifier(key):
|
||||
raise ValueError(f"Invalid column identifier: {key}")
|
||||
if key not in valid_columns:
|
||||
raise ValueError(f"Column '{key}' does not exist in table '{table_name}'.")
|
||||
filtered_columns.append(key)
|
||||
|
||||
filter_items = []
|
||||
for key, value in filters.items():
|
||||
if isinstance(value, list):
|
||||
# Assign unique keys to avoid collisions and support multi-valued IN
|
||||
param_keys = [f"{key}{i}" for i in range(len(value))]
|
||||
filter_items.append(
|
||||
f'"{key}" IN ({", ".join([f":{p}" for p in param_keys])})'
|
||||
)
|
||||
filter_params.update({param_keys[i]: value[i] for i in range(len(value))})
|
||||
else:
|
||||
filter_items.append(f'"{key}" = :{key}')
|
||||
filter_params[key] = value
|
||||
filter_conditions = " AND ".join(filter_items)
|
||||
|
||||
query = f'SELECT * FROM "{table_name}"'
|
||||
if filter_conditions:
|
||||
query += f" WHERE {filter_conditions};"
|
||||
query = text(query)
|
||||
results = await connection.execute(query, filters)
|
||||
else:
|
||||
query += ";"
|
||||
query = text(query)
|
||||
query = text(query)
|
||||
|
||||
if filter_params:
|
||||
results = await connection.execute(query, filter_params)
|
||||
else:
|
||||
results = await connection.execute(query)
|
||||
return {result["data_id"]: result["status"] for result in results}
|
||||
|
||||
|
|
@ -290,9 +392,36 @@ class SQLAlchemyAdapter:
|
|||
rows = result.mappings().all()
|
||||
return rows
|
||||
|
||||
async def execute_query(self, query):
|
||||
async def execute_query(self, query: str, parameters: Optional[dict] = None):
|
||||
"""
|
||||
Securely executes a parameterized SELECT query only.
|
||||
Rejects any non-SELECT queries.
|
||||
|
||||
:param query: SQL SELECT query (must start with 'SELECT', case-insensitive, cannot contain ';', '--', or multiple statements)
|
||||
:param parameters: dictionary of bind parameters for the query
|
||||
:return: list of dictionaries representing rows
|
||||
"""
|
||||
# Basic whitelist: Only allow SELECT queries. Disallow anything else.
|
||||
if not isinstance(query, str):
|
||||
raise ValueError("Query must be a string.")
|
||||
|
||||
stripped = query.lstrip().lower()
|
||||
if not stripped.startswith("select"):
|
||||
raise ValueError("Only SELECT queries are permitted.")
|
||||
|
||||
# Block semicolons (which could be used for stacked queries) and SQL comments
|
||||
forbidden_patterns = [';', '--', '/*', '*/']
|
||||
for pattern in forbidden_patterns:
|
||||
if pattern in stripped:
|
||||
raise ValueError("Forbidden SQL pattern detected in query.")
|
||||
|
||||
# Only allow one statement
|
||||
if query.count(';') > 0:
|
||||
raise ValueError("Query must not contain semicolons.")
|
||||
|
||||
async with self.engine.begin() as connection:
|
||||
result = await connection.execute(text(query))
|
||||
# Use SQLAlchemy parameterization for safety
|
||||
result = await connection.execute(text(query), parameters or {})
|
||||
return [dict(row) for row in result]
|
||||
|
||||
async def drop_tables(self):
|
||||
|
|
@ -435,4 +564,4 @@ class SQLAlchemyAdapter:
|
|||
f"Missing value in foreign key information. \nColumn value: {col}\nReference column value: {ref_col}\n"
|
||||
)
|
||||
|
||||
return schema
|
||||
return schema
|
||||
Loading…
Add table
Reference in a new issue