

3.15.1 ML_TRAIN

Run the ML_TRAIN routine on a labeled training dataset to produce a trained machine learning model.

ML_TRAIN Syntax

MySQL 8.2.0 adds recommendation models that use implicit feedback to learn and recommend rankings for users and items.

mysql> CALL sys.ML_TRAIN ('table_name', 'target_column_name', [options], model_handle);
 
options: {
    JSON_OBJECT('key','value'[,'key','value'] ...)
        'key','value':
        |'task', {'classification'|'regression'|'forecasting'|'anomaly_detection'|'recommendation'}|NULL
        |'datetime_index', 'column'
        |'endogenous_variables', JSON_ARRAY('column'[,'column'] ...)
        |'exogenous_variables', JSON_ARRAY('column'[,'column'] ...)
        |'model_list', JSON_ARRAY('model'[,'model'] ...)
        |'exclude_model_list', JSON_ARRAY('model'[,'model'] ...)
        |'optimization_metric', 'metric'
        |'include_column_list', JSON_ARRAY('column'[,'column'] ...)
        |'exclude_column_list', JSON_ARRAY('column'[,'column'] ...)
        |'contamination',  'contamination factor'
        |'users', 'users_column'
        |'items', 'items_column'
        |'notes', 'notes_text'
        |'feedback', {'explicit'|'implicit'}
        |'feedback_threshold', 'threshold'
}
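The following call is a hedged sketch of the implicit-feedback options added in MySQL 8.2.0. It trains a recommendation model on a hypothetical table ml_data.user_item_ratings with hypothetical user_id, item_id, and rating columns. The threshold value 3 is chosen only for illustration.

```sql
-- Hypothetical table: ml_data.user_item_ratings(user_id, item_id, rating).
-- With implicit feedback, ratings at or above feedback_threshold (here 3) are
-- treated as positive feedback, and ratings below it as negative feedback.
CALL sys.ML_TRAIN('ml_data.user_item_ratings', 'rating',
     JSON_OBJECT('task', 'recommendation',
                 'users', 'user_id',
                 'items', 'item_id',
                 'feedback', 'implicit',
                 'feedback_threshold', 3),
     @implicit_model);
```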

MySQL 8.1.0 adds notes to the JSON options, the ExtraTreesClassifier classification model, and the ExtraTreesRegressor regression model. Forecasting does not require target_column_name, and it can be set to NULL.

mysql> CALL sys.ML_TRAIN ('table_name', 'target_column_name', [options], model_handle);
 
options: {
    JSON_OBJECT('key','value'[,'key','value'] ...)
        'key','value':
        |'task', {'classification'|'regression'|'forecasting'|'anomaly_detection'|'recommendation'}|NULL
        |'datetime_index', 'column'
        |'endogenous_variables', JSON_ARRAY('column'[,'column'] ...)
        |'exogenous_variables', JSON_ARRAY('column'[,'column'] ...)
        |'model_list', JSON_ARRAY('model'[,'model'] ...)
        |'exclude_model_list', JSON_ARRAY('model'[,'model'] ...)
        |'optimization_metric', 'metric'
        |'include_column_list', JSON_ARRAY('column'[,'column'] ...)
        |'exclude_column_list', JSON_ARRAY('column'[,'column'] ...)
        |'contamination',  'contamination factor'
        |'users', 'users_column'
        |'items', 'items_column'
        |'notes', 'notes_text'
}
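As of MySQL 8.1.0, a forecasting model can be trained with target_column_name set to NULL. The following minimal sketch assumes a hypothetical table ml_data.energy_train with a DATE column ddate and a numeric column consumption.

```sql
-- Hypothetical table: ml_data.energy_train(ddate DATE, consumption DOUBLE).
-- target_column_name is NULL; the column to forecast is named in endogenous_variables,
-- and ddate is the datetime_index.
CALL sys.ML_TRAIN('ml_data.energy_train', NULL,
     JSON_OBJECT('task', 'forecasting',
                 'datetime_index', 'ddate',
                 'endogenous_variables', JSON_ARRAY('consumption')),
     @forecast_model);
```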

MySQL 8.0.33 adds support for recommendation models.

mysql> CALL sys.ML_TRAIN ('table_name', 'target_column_name', [options], model_handle);
 
options: {
    JSON_OBJECT('key','value'[,'key','value'] ...)
        'key','value':
        |'task', {'classification'|'regression'|'forecasting'|'anomaly_detection'|'recommendation'}|NULL
        |'datetime_index', 'column'
        |'endogenous_variables', JSON_ARRAY('column'[,'column'] ...)
        |'exogenous_variables', JSON_ARRAY('column'[,'column'] ...)
        |'model_list', JSON_ARRAY('model'[,'model'] ...)
        |'exclude_model_list', JSON_ARRAY('model'[,'model'] ...)
        |'optimization_metric', 'metric'
        |'include_column_list', JSON_ARRAY('column'[,'column'] ...)
        |'exclude_column_list', JSON_ARRAY('column'[,'column'] ...)
        |'contamination',  'contamination factor'
        |'users', 'users_column'
        |'items', 'items_column'
}
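A recommendation model uses the users and items options to identify the user and item columns. The following sketch relies on the default explicit feedback and assumes a hypothetical table ml_data.user_item_ratings in which rating is the target column.

```sql
-- Hypothetical table: ml_data.user_item_ratings(user_id, item_id, rating).
-- users and items must name two different, valid columns; feedback is omitted,
-- so the default (explicit) applies.
CALL sys.ML_TRAIN('ml_data.user_item_ratings', 'rating',
     JSON_OBJECT('task', 'recommendation',
                 'users', 'user_id',
                 'items', 'item_id'),
     @recommendation_model);
```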

MySQL 8.0.32 adds support for multivariate endogenous forecasting models, exogenous forecasting models, and anomaly detection models.

mysql> CALL sys.ML_TRAIN ('table_name', 'target_column_name', [options], model_handle);
 
options: {
    JSON_OBJECT('key','value'[,'key','value'] ...)
        'key','value':
        |'task', {'classification'|'regression'|'forecasting'|'anomaly_detection'}|NULL
        |'datetime_index', 'column'
        |'endogenous_variables', JSON_ARRAY('column'[,'column'] ...)
        |'exogenous_variables', JSON_ARRAY('column'[,'column'] ...)
        |'model_list', JSON_ARRAY('model'[,'model'] ...)
        |'exclude_model_list', JSON_ARRAY('model'[,'model'] ...)
        |'optimization_metric', 'metric'
        |'include_column_list', JSON_ARRAY('column'[,'column'] ...)
        |'exclude_column_list', JSON_ARRAY('column'[,'column'] ...)
        |'contamination',  'contamination factor'  
}
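Anomaly detection, added in MySQL 8.0.32, trains on unlabeled data, so target_column_name is NULL. The following sketch assumes a hypothetical table ml_data.sensor_train. The contamination value 0.05 is illustrative and must lie strictly between 0 and 0.5.

```sql
-- Hypothetical table: ml_data.sensor_train (feature columns only, no label).
-- contamination is the expected share of anomalous rows; the default is 0.1.
CALL sys.ML_TRAIN('ml_data.sensor_train', NULL,
     JSON_OBJECT('task', 'anomaly_detection',
                 'contamination', 0.05),
     @anomaly_model);
```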

Before MySQL 8.0.32:

mysql> CALL sys.ML_TRAIN ('table_name', 'target_column_name', [options], model_handle);
 
options: {
    JSON_OBJECT('key','value'[,'key','value'] ...)
        'key','value':
        |'task', {'classification'|'regression'|'forecasting'}|NULL
        |'datetime_index', 'column'
        |'endogenous_variables', JSON_ARRAY('column')
        |'model_list', JSON_ARRAY('model'[,'model'] ...)
        |'exclude_model_list', JSON_ARRAY('model'[,'model'] ...)
        |'optimization_metric', 'metric'
        |'exclude_column_list', JSON_ARRAY('column'[,'column'] ...)
}
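Before MySQL 8.0.32, forecasting accepts a single endogenous column. Before MySQL 8.1.0, target_column_name is also required. As a result, the same column appears in both places. The following sketch uses a hypothetical table ml_data.energy_train with a DATE column ddate and a numeric column consumption.

```sql
-- Hypothetical table: ml_data.energy_train(ddate DATE, consumption DOUBLE).
-- The forecast column is passed as target_column_name and as the only
-- entry in endogenous_variables.
CALL sys.ML_TRAIN('ml_data.energy_train', 'consumption',
     JSON_OBJECT('task', 'forecasting',
                 'datetime_index', 'ddate',
                 'endogenous_variables', JSON_ARRAY('consumption')),
     @forecast_model);
```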
Note

The MySQL account that runs ML_TRAIN cannot have a period character (".") in its name; for example, a user named 'joesmith'@'%' is permitted to train a model, but a user named 'joe.smith'@'%' is not. For more information about this limitation, see Section 3.18, “HeatWave AutoML Limitations”.

The ML_TRAIN routine also runs the ML_EXPLAIN routine with the default Permutation Importance model for prediction explainers and model explainers. See Section 3.6, “Training Explainers”. To train other prediction explainers and model explainers, use the ML_EXPLAIN routine with the preferred explainer after ML_TRAIN. MySQL 8.0.31 does not run the ML_EXPLAIN routine after ML_TRAIN.

ML_EXPLAIN does not support the anomaly_detection and recommendation tasks, so ML_TRAIN does not run ML_EXPLAIN for these tasks.
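To replace the default Permutation Importance explainers after training, call ML_EXPLAIN on the same training table and model handle. The following call is a sketch only. The argument order and the option names model_explainer and prediction_explainer with the value shap are assumptions here. Check Section 3.6, “Training Explainers”, for the values that your MySQL version supports.

```sql
-- Assumed signature: ML_EXPLAIN(table_name, target_column_name, model_handle, options).
-- Retrains both explainers for @iris_model using SHAP instead of Permutation Importance.
CALL sys.ML_EXPLAIN('ml_data.iris_train', 'class', @iris_model,
     JSON_OBJECT('model_explainer', 'shap',
                 'prediction_explainer', 'shap'));
```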

ML_TRAIN parameters:

  • table_name: The name of the table that contains the labeled training dataset. The table name must be valid and fully qualified; that is, it must include the schema name, schema_name.table_name. The table cannot exceed 10 GB, 100 million rows, or 1017 columns. Before MySQL 8.0.29, the column limit was 900.

  • target_column_name: The name of the target column containing ground truth values.

    HeatWave AutoML does not support a text target column.

    Anomaly detection does not require labeled data, and target_column_name must be set to NULL.

    As of MySQL 8.1.0, forecasting does not require target_column_name, and it can be set to NULL.

  • model_handle: The name of a user-defined session variable that stores the machine learning model handle for the duration of the connection. User variables are written as @var_name. Some of the examples in this guide use @census_model as the variable name. Any valid name for a user-defined variable is permitted, for example @my_model.

    As of MySQL 8.0.31, if the model_handle variable was set to a value before calling ML_TRAIN, that model handle is used for the model. A model handle must be unique in the model catalog.

    Otherwise, HeatWave AutoML generates a model handle. When ML_TRAIN finishes executing, retrieve the generated model handle by querying the session variable. See Section 3.13.8, “Model Handles”.

  • options: Optional parameters specified as key-value pairs in JSON format. If an option is not specified, the default setting is used. To use the default settings for all options, specify NULL in place of the JSON argument.

    • task: Specifies the machine learning task. Permitted values are:

      • classification: The default. Use this task type if the target column contains discrete values.

      • regression: Use this task type if the target column contains continuous numerical values.

      • forecasting: Use this task type to produce a time series forecast of the target column, indexed by a date-time column. The datetime_index and endogenous_variables parameters are required with the forecasting task.

      • anomaly_detection: Use this task type to detect anomalies.

      • recommendation: Use this task type for recommendation models.

    • datetime_index: For forecasting tasks, the name of a datetime column that acts as an index for the forecast variable. The column can have one of the supported datetime column types (DATETIME, TIMESTAMP, DATE, TIME, or YEAR), or it can be an auto-incrementing index.

      The forecast models SARIMAXForecaster, VARMAXForecaster, and DynFactorForecaster cannot backtest (that is, forecast into the training data) when using exogenous_variables. Therefore, the datetime_index values in the predict table must not overlap those in the training table. When exogenous_variables are used, the first date in the predict table must immediately follow the last date in the training table. For example, if the training table has a YEAR datetime_index that ends with 2023, the predict table must start with 2024.

      The datetime_index for the predict table must not have missing dates after the last date in the training table. For example, if the training table has a YEAR datetime_index that ends with 2023, the predict table cannot start with 2025 or 2030, because that would skip 1 and 6 years, respectively.

      When options do not include exogenous_variables, the datetime_index of the predict table can overlap that of the training table. This supports backtesting.

      Dates in datetime_index must fall between the years 1678 and 2261. An error occurs if any date in the training table or predict table is outside this range. The last date in the training table plus the predict table length must also stay inside this range. For example, if the datetime_index in the training table has the YEAR data type and the last date is 2023, the predict table length must be less than 238 rows, because 2261 minus 2023 equals 238.

    • endogenous_variables: For forecasting tasks, the column or columns to be forecast.

      Univariate forecasting models support a single numeric column, specified as a JSON_ARRAY. Before MySQL 8.1.0, this column must also be specified as the target_column_name, because that parameter is required, although the value is not used there.

      Multivariate forecasting models support multiple numeric columns, specified as a JSON_ARRAY. One of these columns must also be specified as the target_column_name.

      The endogenous_variables columns cannot be text columns.

    • exogenous_variables: For forecasting tasks, the column or columns of independent, non-forecast, predictive variables, specified as a JSON_ARRAY. These optional variables are not forecast, but help to predict the future values of the forecast variables. These variables affect a model without being affected by it. For example, for sales forecasting these variables might be advertising expenditure, occurrence of promotional events, weather, or holidays.

      ML_TRAIN will consider all supported models during the algorithm selection stage if options includes exogenous_variables, including models that do not support exogenous_variables.

      For example, if options includes univariate endogenous_variables with exogenous_variables, then ML_TRAIN will consider NaiveForecaster, ThetaForecaster, ExpSmoothForecaster, ETSForecaster, STLwESForecaster, STLwARIMAForecaster, and SARIMAXForecaster. ML_TRAIN will ignore exogenous_variables if the model does not support them.

      Similarly, if options includes multivariate endogenous_variables with exogenous_variables, then ML_TRAIN will consider VARMAXForecaster and DynFactorForecaster.

      If options also includes include_column_list, ML_TRAIN considers only models that support exogenous_variables.

    • model_list: The type of model to be trained. If more than one model is specified, the best model type is selected from the list. See Section 3.15.11, “Model Types”.

      This option cannot be used together with the exclude_model_list option, and it is not supported for anomaly_detection tasks.

    • exclude_model_list: Model types that should not be trained. Specified model types are excluded from consideration during model selection. See Section 3.15.11, “Model Types”.

      This option cannot be specified together with the model_list option, and it is not supported for anomaly_detection tasks.

    • optimization_metric: The scoring metric to optimize for when training a machine learning model. The metric must be compatible with the task type and the target data. See Section 3.15.13, “Optimization and Scoring Metrics”.

      This is not supported for anomaly_detection tasks.

    • include_column_list: ML_TRAIN must include this list of columns.

      For classification, regression, anomaly_detection, and recommendation tasks, include_column_list ensures that ML_TRAIN does not drop these columns.

      For forecasting tasks, include_column_list can only include columns that are also listed in exogenous_variables. If include_column_list contains at least one of the exogenous_variables, ML_TRAIN considers only models that support exogenous_variables.

      All columns in include_column_list must be included in the training table.

    • exclude_column_list: Feature columns of the training dataset to exclude from consideration when training a model. Columns that are excluded using exclude_column_list do not need to be excluded from the dataset used for predictions.

      The exclude_column_list cannot contain any columns that are listed in endogenous_variables, exogenous_variables, or include_column_list.

    • contamination: The optional contamination factor for use with the anomaly_detection task. The value must satisfy 0 < contamination < 0.5. The default value is 0.1.

    • users: Specifies the column name corresponding to the user IDs.

      This must be a valid column name, and it must be different from the items column name.

    • items: Specifies the column name corresponding to the item IDs.

      This must be a valid column name, and it must be different from the users column name.

    • notes: Add notes to the model_metadata.

    • feedback: The type of feedback for a recommendation model: explicit (the default) or implicit.

    • feedback_threshold: The feedback threshold for a recommendation model with implicit feedback. All ratings at or above the feedback_threshold are treated as positive feedback, and all ratings below it are treated as negative feedback. The default value is 1.
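You can check some of the limits above with ordinary SQL before you call ML_TRAIN. The following sketch assumes hypothetical tables ml_data.energy_train and ml_data.energy_predict, each with a YEAR column yr that is used as the datetime_index. The first two queries return counts to compare against the table limits. The last three queries return 1 when the corresponding datetime_index rule for forecasting holds.

```sql
-- Column count of the training table (limit: 1017 columns).
SELECT COUNT(*) AS n_columns
  FROM information_schema.COLUMNS
 WHERE TABLE_SCHEMA = 'ml_data' AND TABLE_NAME = 'energy_train';

-- Row count of the training table (limit: 100 million rows).
SELECT COUNT(*) AS n_rows FROM ml_data.energy_train;

-- With exogenous_variables: the predict table starts with the year
-- immediately after the last training year.
SELECT (SELECT MIN(yr) FROM ml_data.energy_predict)
     = (SELECT MAX(yr) FROM ml_data.energy_train) + 1 AS starts_after_training;

-- The predict table has no missing years.
SELECT COUNT(DISTINCT yr) = MAX(yr) - MIN(yr) + 1 AS no_missing_years
  FROM ml_data.energy_predict;

-- Predict table length is less than 2261 minus the last training year.
SELECT (SELECT COUNT(*) FROM ml_data.energy_predict)
     < 2261 - (SELECT MAX(yr) FROM ml_data.energy_train) AS within_year_range;
```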

Syntax Examples

  • An ML_TRAIN example that uses the classification task option implicitly (classification is the default if not specified explicitly):

    mysql> CALL sys.ML_TRAIN('ml_data.iris_train', 'class', 
              NULL, @iris_model);
  • An ML_TRAIN example that specifies the classification task type explicitly, and sets a model handle instead of letting HeatWave AutoML generate one:

    mysql> SET @iris_model = 'iris_manual';
    mysql> CALL sys.ML_TRAIN('ml_data.iris_train', 'class', 
              JSON_OBJECT('task', 'classification'), 
              @iris_model);
  • An ML_TRAIN example that specifies the regression task type:

    mysql> CALL sys.ML_TRAIN('employee.salary_train', 'salary', 
              JSON_OBJECT('task', 'regression'), @salary_model);
  • An ML_TRAIN example that specifies the model_list option. This example trains either an XGBClassifier or LGBMClassifier model.

    mysql> CALL sys.ml_train('ml_data.iris_train', 'class', 
              JSON_OBJECT('task','classification', 
              'model_list', JSON_ARRAY('XGBClassifier', 'LGBMClassifier')), 
              @iris_model);
  • An ML_TRAIN example that specifies the exclude_model_list option. In this example, LogisticRegression and GaussianNB models are excluded from model selection.

    mysql> CALL sys.ml_train('ml_data.iris_train', 'class', 
              JSON_OBJECT('task','classification', 
              'exclude_model_list', JSON_ARRAY('LogisticRegression', 'GaussianNB')), 
              @iris_model);
  • An ML_TRAIN example that specifies the optimization_metric option.

    mysql> CALL sys.ml_train('ml_data.iris_train', 'class', 
              JSON_OBJECT('task','classification', 'optimization_metric', 'neg_log_loss'), 
              @iris_model);
  • An ML_TRAIN example that specifies the exclude_column_list option.

    mysql> CALL sys.ml_train('ml_data.iris_train', 'class', 
              JSON_OBJECT('task','classification', 
              'exclude_column_list', JSON_ARRAY('sepal length', 'petal length')), 
              @iris_model);
  • An ML_TRAIN example that adds notes to the model_metadata.

    mysql> CALL sys.ml_train('ml_data.iris_train', 'class', 
              JSON_OBJECT('task','classification', 'notes', 'testing2'), 
              @iris_model);
    Query OK, 0 rows affected (1 min 42.53 sec)
    
    mysql> SELECT model_metadata FROM ML_SCHEMA_root.MODEL_CATALOG 
              WHERE model_handle=@iris_model;
    +--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
    | model_metadata                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       |
    +--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
    | {"task": "classification", "notes": "testing2", "format": "HWMLv1.0", "n_rows": 105, "status": "Ready", "n_columns": 4, "column_names": "{\"0\": \"sepal length\", \"1\": \"sepal width\", \"2\": \"petal length\", \"3\": \"petal width\"}", "contamination": 0.0, "model_quality": "high", "training_time": 10.408954620361328, "algorithm_name": "RandomForestClassifier", "training_score": -0.08308402448892593, "build_timestamp": 0, "n_selected_rows": 84, "train_table_name": "mlcorpus_v5.iris_train", "training_columns": "[\"sepal length\", \"sepal width\", \"petal length\", \"petal width\"]", "model_explanation": "{\"permutation_importance\": {\"petal width\": 0.5926, \"sepal width\": 0.0, \"petal length\": 0.0423, \"sepal length\": 0.0}}", "n_selected_columns": 2, "target_column_name": "class", "optimization_metric": "neg_log_loss", "selected_column_names": "[\"petal length\", \"petal width\"]"} |
    +--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
    1 row in set (0.00 sec)
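The examples above do not cover forecasting with exogenous variables. The following sketch assumes a hypothetical table ml_data.sales_train with a DATE column sale_date, forecast columns units_sold and revenue, and predictor columns ad_spend and promo_flag. Because include_column_list names one of the exogenous_variables, ML_TRAIN considers only models that support exogenous_variables.

```sql
-- Hypothetical table: ml_data.sales_train(sale_date DATE, units_sold, revenue,
-- ad_spend, promo_flag), all numeric except sale_date.
-- Multivariate forecast: units_sold is one of the endogenous columns and is
-- also passed as target_column_name.
CALL sys.ML_TRAIN('ml_data.sales_train', 'units_sold',
     JSON_OBJECT('task', 'forecasting',
                 'datetime_index', 'sale_date',
                 'endogenous_variables', JSON_ARRAY('units_sold', 'revenue'),
                 'exogenous_variables', JSON_ARRAY('ad_spend', 'promo_flag'),
                 'include_column_list', JSON_ARRAY('ad_spend')),
     @sales_model);
```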

See also: