Compare commits
1 commit
main
...
pensar-aut
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
079144bac9 |
1 changed file with 147 additions and 18 deletions
|
|
@ -19,6 +19,33 @@ from ..ModelBase import Base
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _is_identifier(s: str) -> bool:
|
||||||
|
# Simple SQL identifier check: must be valid Python identifier and not pure digit
|
||||||
|
if not isinstance(s, str):
|
||||||
|
return False
|
||||||
|
if not s.isidentifier():
|
||||||
|
return False
|
||||||
|
# Optionally, some SQL reserved words could be blacklisted here
|
||||||
|
reserved = {
|
||||||
|
'select', 'table', 'from', 'where', 'insert', 'delete', 'update', 'schema', 'join', 'or', 'and', 'not',
|
||||||
|
'create', 'drop', 'null', 'default', 'primary', 'key', 'if', 'exists', 'values', 'into'
|
||||||
|
}
|
||||||
|
if s.lower() in reserved:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _validate_sql_type(s: str) -> bool:
|
||||||
|
# A simple allow-list of basic SQL types to avoid injection via fake type names
|
||||||
|
# Adjust as needed to match your DB; should be strict.
|
||||||
|
allowed = {
|
||||||
|
'INTEGER', 'BIGINT', 'SMALLINT', 'SERIAL', 'BIGSERIAL',
|
||||||
|
'BOOLEAN', 'BOOL', 'TEXT', 'VARCHAR', 'CHAR', 'DATE', 'TIMESTAMP',
|
||||||
|
'UUID', 'REAL', 'FLOAT', 'DOUBLE PRECISION', 'BLOB', 'BYTEA'
|
||||||
|
}
|
||||||
|
# Accept potential varchar/char length: VARCHAR(255) etc.
|
||||||
|
prefix = s.strip().split('(')[0].upper()
|
||||||
|
return prefix in allowed
|
||||||
|
|
||||||
class SQLAlchemyAdapter:
|
class SQLAlchemyAdapter:
|
||||||
def __init__(self, connection_string: str):
|
def __init__(self, connection_string: str):
|
||||||
self.db_path: str = None
|
self.db_path: str = None
|
||||||
|
|
@ -56,12 +83,25 @@ class SQLAlchemyAdapter:
|
||||||
return datasets
|
return datasets
|
||||||
|
|
||||||
async def create_table(self, schema_name: str, table_name: str, table_config: list[dict]):
    """Create a schema (if missing) and a table inside it.

    DDL identifiers cannot be bound as parameters, so every name is
    validated before being interpolated into the SQL text.

    :param schema_name: target schema; must pass _is_identifier
    :param table_name: target table; must pass _is_identifier
    :param table_config: list of {'name': ..., 'type': ...} column specs;
        names must be valid identifiers and types must pass the
        _validate_sql_type allow-list.
    :raises ValueError: on any invalid identifier or column type
    """
    # Validate schema and table names as SQL identifiers
    if not _is_identifier(schema_name):
        raise ValueError(f"Invalid schema name: {schema_name!r}")
    if not _is_identifier(table_name):
        raise ValueError(f"Invalid table name: {table_name!r}")

    fields_query_parts = []
    for item in table_config:
        col_name, col_type = item.get('name'), item.get('type')
        if not _is_identifier(col_name):
            raise ValueError(f"Invalid column name: {col_name!r}")
        if not isinstance(col_type, str) or not _validate_sql_type(col_type):
            raise ValueError(f"Invalid or potentially dangerous column type: {col_type!r}")
        # Quote the (validated) column name; the type is allow-listed text.
        fields_query_parts.append(f'"{col_name}" {col_type.strip()}')

    async with self.engine.begin() as connection:
        await connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}";'))
        await connection.execute(
            text(
                f'CREATE TABLE IF NOT EXISTS "{schema_name}"."{table_name}" ({", ".join(fields_query_parts)});'
            )
        )
        # NOTE: the former explicit `await connection.close()` was removed:
        # `engine.begin()` commits and releases the connection when the
        # block exits, so closing it manually inside the context was
        # redundant.
|
||||||
|
|
@ -248,24 +288,86 @@ class SQLAlchemyAdapter:
|
||||||
|
|
||||||
return table_names
|
return table_names
|
||||||
|
|
||||||
|
def _validate_identifier(self, identifier: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the provided identifier is a valid SQL identifier.
|
||||||
|
This is used to defend against SQL injection for table and column names.
|
||||||
|
"""
|
||||||
|
if not isinstance(identifier, str):
|
||||||
|
return False
|
||||||
|
# Disallow any double quotes, semicolons, whitespace, or dangerous chars.
|
||||||
|
if (
|
||||||
|
not identifier
|
||||||
|
or not identifier.isidentifier()
|
||||||
|
or '"' in identifier
|
||||||
|
or ";" in identifier
|
||||||
|
or "--" in identifier
|
||||||
|
or " " in identifier
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
# Add further SQL-specific restrictions if needed
|
||||||
|
return True
|
||||||
|
|
||||||
async def get_data(self, table_name: str, filters: dict = None):
    """Fetch rows from *table_name*, optionally filtered by column values.

    :param table_name: table to read; validated with _validate_identifier
        because identifiers cannot be bound as parameters and must be
        interpolated into the SQL text.
    :param filters: optional {column: value} mapping; a list value becomes
        a SQL ``IN`` clause, any other value an equality test. Column
        names are validated against the reflected table schema and all
        values are passed as bind parameters, never interpolated.
    :return: dict mapping each row's "data_id" to its "status".
        NOTE(review): this assumes every queried table exposes
        data_id/status columns even though the query is SELECT * --
        confirm against callers.
    :raises ValueError: on an invalid/unknown table or filter column.
    """
    # Validate identifiers
    if not self._validate_identifier(table_name):
        raise ValueError("Invalid table name")

    async with self.engine.begin() as connection:
        # Reflect the correct table object and columns
        if self.engine.dialect.name == "sqlite":
            # SQLite has no schemas: reflect into the shared Base metadata.
            await connection.run_sync(Base.metadata.reflect)
            if table_name not in Base.metadata.tables:
                raise ValueError(f"Table '{table_name}' does not exist in schema.")
            table = Base.metadata.tables[table_name]
        else:
            # Assume public schema unless another schema is provided (for future-proofing)
            schema_name = "public"
            metadata = MetaData()
            await connection.run_sync(metadata.reflect, schema=schema_name)
            full_table_name = f"{schema_name}.{table_name}"
            if full_table_name not in metadata.tables:
                raise ValueError(f"Table '{table_name}' does not exist in schema '{schema_name}'.")
            table = metadata.tables[full_table_name]

        # Columns that actually exist on the reflected table; every filter
        # key is checked against this set before appearing in SQL text.
        valid_columns = set(column.name for column in table.columns)

        # Validate filter column names
        filtered_columns = []
        filter_params = dict()
        filter_conditions = ""
        if filters:
            for key, value in filters.items():
                if not self._validate_identifier(key):
                    raise ValueError(f"Invalid column identifier: {key}")
                if key not in valid_columns:
                    raise ValueError(f"Column '{key}' does not exist in table '{table_name}'.")
                filtered_columns.append(key)

            filter_items = []
            for key, value in filters.items():
                if isinstance(value, list):
                    # Assign unique keys to avoid collisions and support multi-valued IN
                    param_keys = [f"{key}{i}" for i in range(len(value))]
                    filter_items.append(
                        f'"{key}" IN ({", ".join([f":{p}" for p in param_keys])})'
                    )
                    filter_params.update({param_keys[i]: value[i] for i in range(len(value))})
                else:
                    # Scalar value: simple equality with a single bind param.
                    filter_items.append(f'"{key}" = :{key}')
                    filter_params[key] = value
            filter_conditions = " AND ".join(filter_items)

        query = f'SELECT * FROM "{table_name}"'
        if filter_conditions:
            query += f" WHERE {filter_conditions};"
        else:
            query += ";"
        query = text(query)

        # Execute with bind parameters only when filters produced any.
        if filter_params:
            results = await connection.execute(query, filter_params)
        else:
            results = await connection.execute(query)
        return {result["data_id"]: result["status"] for result in results}
|
||||||
|
|
||||||
|
|
@ -290,9 +392,36 @@ class SQLAlchemyAdapter:
|
||||||
rows = result.mappings().all()
|
rows = result.mappings().all()
|
||||||
return rows
|
return rows
|
||||||
|
|
||||||
async def execute_query(self, query: str, parameters: Optional[dict] = None):
    """
    Securely executes a parameterized SELECT query only.

    Rejects any non-SELECT queries.

    :param query: SQL SELECT query (must start with the word 'SELECT',
        case-insensitive, and cannot contain ';', '--', '/*' or '*/',
        which blocks stacked statements and comment-based smuggling)
    :param parameters: dictionary of bind parameters for the query
    :return: list of dictionaries representing rows
    :raises ValueError: if the query is not a string, is not a SELECT,
        or contains a forbidden pattern
    """
    # Local import keeps the method self-contained.
    import re

    # Basic whitelist: Only allow SELECT queries. Disallow anything else.
    if not isinstance(query, str):
        raise ValueError("Query must be a string.")

    stripped = query.lstrip().lower()
    # \b requires SELECT to be a complete word, so e.g. "selection ..."
    # no longer slips past a plain startswith("select") check.
    if not re.match(r"select\b", stripped):
        raise ValueError("Only SELECT queries are permitted.")

    # Block semicolons (which could be used for stacked queries) and SQL
    # comments. `stripped` only trims/lower-cases the original, so any ';'
    # in the raw query is caught here -- the old separate
    # `query.count(';')` check was unreachable and has been removed.
    forbidden_patterns = [';', '--', '/*', '*/']
    for pattern in forbidden_patterns:
        if pattern in stripped:
            raise ValueError("Forbidden SQL pattern detected in query.")

    async with self.engine.begin() as connection:
        # Use SQLAlchemy parameterization for safety
        result = await connection.execute(text(query), parameters or {})
        return [dict(row) for row in result]
|
||||||
|
|
||||||
async def drop_tables(self):
|
async def drop_tables(self):
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue