diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py index ade2a0821..259880797 100644 --- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py @@ -19,6 +19,33 @@ from ..ModelBase import Base logger = get_logger() +def _is_identifier(s: str) -> bool: + # Simple SQL identifier check: must be valid Python identifier and not pure digit + if not isinstance(s, str): + return False + if not s.isidentifier(): + return False + # Optionally, some SQL reserved words could be blacklisted here + reserved = { + 'select', 'table', 'from', 'where', 'insert', 'delete', 'update', 'schema', 'join', 'or', 'and', 'not', + 'create', 'drop', 'null', 'default', 'primary', 'key', 'if', 'exists', 'values', 'into' + } + if s.lower() in reserved: + return False + return True + +def _validate_sql_type(s: str) -> bool: + # A simple allow-list of basic SQL types to avoid injection via fake type names + # Adjust as needed to match your DB; should be strict. + allowed = { + 'INTEGER', 'BIGINT', 'SMALLINT', 'SERIAL', 'BIGSERIAL', + 'BOOLEAN', 'BOOL', 'TEXT', 'VARCHAR', 'CHAR', 'DATE', 'TIMESTAMP', + 'UUID', 'REAL', 'FLOAT', 'DOUBLE PRECISION', 'BLOB', 'BYTEA' + } + # Accept potential varchar/char length: VARCHAR(255) etc. + prefix = s.strip().split('(')[0].upper() + return prefix in allowed + class SQLAlchemyAdapter: def __init__(self, connection_string: str): self.db_path: str = None @@ -56,12 +83,25 @@ class SQLAlchemyAdapter: return datasets async def create_table(self, schema_name: str, table_name: str, table_config: list[dict]): - fields_query_parts = [f"{item['name']} {item['type']}" for item in table_config] + # Validate schema and table names as SQL identifiers + if not _is_identifier(schema_name): + raise ValueError(f"Invalid schema name: {schema_name!r}") + if not _is_identifier(table_name): + raise ValueError(f"Invalid table name: {table_name!r}") + + fields_query_parts = [] + for item in table_config: + col_name, col_type = item.get('name'), item.get('type') + if not _is_identifier(col_name): + raise ValueError(f"Invalid column name: {col_name!r}") + if not isinstance(col_type, str) or not _validate_sql_type(col_type): + raise ValueError(f"Invalid or potentially dangerous column type: {col_type!r}") + fields_query_parts.append(f'"{col_name}" {col_type.strip()}') async with self.engine.begin() as connection: - await connection.execute(text(f"CREATE SCHEMA IF NOT EXISTS {schema_name};")) + await connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}";')) await connection.execute( text( - f'CREATE TABLE IF NOT EXISTS {schema_name}."{table_name}" ({", ".join(fields_query_parts)});' + f'CREATE TABLE IF NOT EXISTS "{schema_name}"."{table_name}" ({", ".join(fields_query_parts)});' ) ) await connection.close() @@ -248,24 +288,86 @@ class SQLAlchemyAdapter: return table_names + def _validate_identifier(self, identifier: str) -> bool: + """ + Check if the provided identifier is a valid SQL identifier. + This is used to defend against SQL injection for table and column names. + """ + if not isinstance(identifier, str): + return False + # Disallow any double quotes, semicolons, whitespace, or dangerous chars. + if ( + not identifier + or not identifier.isidentifier() + or '"' in identifier + or ";" in identifier + or "--" in identifier + or " " in identifier + ): + return False + # Add further SQL-specific restrictions if needed + return True + async def get_data(self, table_name: str, filters: dict = None): + # Validate identifiers + if not self._validate_identifier(table_name): + raise ValueError("Invalid table name") + async with self.engine.begin() as connection: - query = f'SELECT * FROM "{table_name}"' + # Reflect the correct table object and columns + if self.engine.dialect.name == "sqlite": + await connection.run_sync(Base.metadata.reflect) + if table_name not in Base.metadata.tables: + raise ValueError(f"Table '{table_name}' does not exist in schema.") + table = Base.metadata.tables[table_name] + else: + # Assume public schema unless another schema is provided (for future-proofing) + schema_name = "public" + metadata = MetaData() + await connection.run_sync(metadata.reflect, schema=schema_name) + full_table_name = f"{schema_name}.{table_name}" + if full_table_name not in metadata.tables: + raise ValueError(f"Table '{table_name}' does not exist in schema '{schema_name}'.") + table = metadata.tables[full_table_name] + + valid_columns = set(column.name for column in table.columns) + + # Validate filter column names + filtered_columns = [] + filter_params = dict() + filter_conditions = "" if filters: - filter_conditions = " AND ".join( - [ - f"{key} IN ({', '.join([f':{key}{i}' for i in range(len(value))])})" - if isinstance(value, list) - else f"{key} = :{key}" - for key, value in filters.items() - ] - ) + for key, value in filters.items(): + if not self._validate_identifier(key): + raise ValueError(f"Invalid column identifier: {key}") + if key not in valid_columns: + raise ValueError(f"Column '{key}' does not exist in table '{table_name}'.") + filtered_columns.append(key) + + filter_items = [] + for key, value in filters.items(): + if isinstance(value, list): + # Assign unique keys to avoid collisions and support multi-valued IN + param_keys = [f"{key}{i}" for i in range(len(value))] + filter_items.append( + f'"{key}" IN ({", ".join([f":{p}" for p in param_keys])})' + ) + filter_params.update({param_keys[i]: value[i] for i in range(len(value))}) + else: + filter_items.append(f'"{key}" = :{key}') + filter_params[key] = value + filter_conditions = " AND ".join(filter_items) + + query = f'SELECT * FROM "{table_name}"' + if filter_conditions: query += f" WHERE {filter_conditions};" - query = text(query) - results = await connection.execute(query, filters) else: query += ";" - query = text(query) + query = text(query) + + if filter_params: + results = await connection.execute(query, filter_params) + else: results = await connection.execute(query) return {result["data_id"]: result["status"] for result in results} @@ -290,9 +392,36 @@ class SQLAlchemyAdapter: rows = result.mappings().all() return rows - async def execute_query(self, query): + async def execute_query(self, query: str, parameters: Optional[dict] = None): + """ + Securely executes a parameterized SELECT query only. + Rejects any non-SELECT queries. + + :param query: SQL SELECT query (must start with 'SELECT', case-insensitive, cannot contain ';', '--', or multiple statements) + :param parameters: dictionary of bind parameters for the query + :return: list of dictionaries representing rows + """ + # Basic whitelist: Only allow SELECT queries. Disallow anything else. + if not isinstance(query, str): + raise ValueError("Query must be a string.") + + stripped = query.lstrip().lower() + if not stripped.startswith("select"): + raise ValueError("Only SELECT queries are permitted.") + + # Block semicolons (which could be used for stacked queries) and SQL comments + forbidden_patterns = [';', '--', '/*', '*/'] + for pattern in forbidden_patterns: + if pattern in stripped: + raise ValueError("Forbidden SQL pattern detected in query.") + + # Only allow one statement + if query.count(';') > 0: + raise ValueError("Query must not contain semicolons.") + async with self.engine.begin() as connection: - result = await connection.execute(text(query)) + # Use SQLAlchemy parameterization for safety + result = await connection.execute(text(query), parameters or {}) return [dict(row) for row in result] async def drop_tables(self): @@ -435,4 +564,4 @@ class SQLAlchemyAdapter: f"Missing value in foreign key information. \nColumn value: {col}\nReference column value: {ref_col}\n" ) - return schema + return schema \ No newline at end of file