3.4 Training a Model

The ML_TRAIN routine, when run on a training dataset, produces a trained machine learning (ML) model.

ML_TRAIN supports training of classification, regression, and forecasting models. A classification model predicts discrete values. A regression model predicts continuous values. A forecasting model produces time series forecasts from date and time data.

Training a model can take from a few minutes to a few hours. The time depends on the number of rows and columns in the dataset, the specified ML_TRAIN parameters, and the size of the HeatWave Cluster. HeatWave ML supports tables up to 10 GB in size, with a maximum of 100 million rows and 900 columns.
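Before calling ML_TRAIN, you can check a training table against these limits by querying the standard information_schema tables. This sketch uses the census_train table from the example below. For InnoDB tables, TABLE_ROWS and DATA_LENGTH are estimates. HeatWave ML may also measure table size differently, so treat the size as a rough indicator only.

```sql
-- Approximate row count and on-disk size (data + indexes) in GB.
-- For InnoDB, TABLE_ROWS is an estimate; use SELECT COUNT(*) for an exact count.
SELECT TABLE_ROWS,
       ROUND((DATA_LENGTH + INDEX_LENGTH) / POWER(1024, 3), 2) AS size_gb
FROM information_schema.TABLES
WHERE TABLE_SCHEMA = 'heatwaveml_bench'
  AND TABLE_NAME = 'census_train';

-- Number of columns in the table (the limit is 900).
SELECT COUNT(*) AS column_count
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA = 'heatwaveml_bench'
  AND TABLE_NAME = 'census_train';
```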

ML_TRAIN stores machine learning models in the MODEL_CATALOG table. See Section 3.9.1, “The Model Catalog”.

For ML_TRAIN option descriptions, see Section 3.10.1, “ML_TRAIN”.

The training dataset used with ML_TRAIN must reside in a table on the MySQL DB System. For an example training dataset, see Example Data.

The following example runs ML_TRAIN on the heatwaveml_bench.census_train training dataset:

CALL sys.ML_TRAIN('heatwaveml_bench.census_train', 'revenue', 
JSON_OBJECT('task', 'classification'), @census_model);

Where:

  • heatwaveml_bench.census_train is the fully qualified name of the table that contains the training dataset (schema_name.table_name).

  • revenue is the name of the target column, which contains ground truth values.

  • JSON_OBJECT('task', 'classification') specifies the machine learning task type. Supported types are classification (the default), regression, and forecasting.

    NULL can be specified in place of the JSON_OBJECT if you intend to use the default classification task type.

    When using the regression task type, only a numeric target column is permitted.

    When using the forecasting task type, you add further key-value pairs to specify the index column and the column to forecast. See Section 3.8, “Forecasts” for instructions.

  • @census_model is the name of the user-defined session variable that stores the model handle for the duration of the connection. User variables are written as @var_name. Some of the examples in this guide use @census_model as the variable name. Any valid name for a user-defined variable is permitted (e.g., @my_model).
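The following sketch illustrates the task-type rules described above.

- The first call passes NULL instead of an options object, which selects the default classification task type.
- The second call shows the regression form. It uses a hypothetical table, mydb.house_prices_train, with a numeric target column named price. Both names are illustrative assumptions.

```sql
-- NULL options: the default task type (classification) is used.
CALL sys.ML_TRAIN('heatwaveml_bench.census_train', 'revenue', NULL, @census_model);

-- Regression: the target column must be numeric.
-- mydb.house_prices_train and price are hypothetical names.
CALL sys.ML_TRAIN('mydb.house_prices_train', 'price',
     JSON_OBJECT('task', 'regression'), @price_model);
```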

After ML_TRAIN trains a model, the model is stored in the user's model catalog. To retrieve the generated model handle, query the specified session variable; for example:

mysql> SELECT @census_model;
+--------------------------------------------------+
| @census_model                                    |
+--------------------------------------------------+
| heatwaveml_bench.census_train_user1_1636729526   |
+--------------------------------------------------+

Tip

While using the same connection used to execute ML_TRAIN, you can specify the session variable (e.g., @census_model) in place of the model handle in other HeatWave ML routines, but the session variable data is lost when the current session is terminated. If you need to look up a model handle, you can do so by querying the model catalog table. See Section 3.9.8, “Model Handles”.
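The following sketch shows both approaches from the tip.

- ML_MODEL_LOAD stands in for any HeatWave ML routine that accepts a model handle.
- The catalog query assumes the current user is user1, so the model catalog is in a schema named ML_SCHEMA_user1. Check Section 3.9.1 for the catalog location on your system.

```sql
-- Same connection as ML_TRAIN: the session variable stands in for the handle.
CALL sys.ML_MODEL_LOAD(@census_model, NULL);

-- New connection: the variable is gone, so look up the handle in the catalog.
-- Assumes the current user is user1 (catalog schema ML_SCHEMA_user1).
SELECT model_handle, train_table_name
FROM ML_SCHEMA_user1.MODEL_CATALOG;

-- Copy the handle into a session variable for use in later calls.
SET @census_model = 'heatwaveml_bench.census_train_user1_1636729526';
```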

The quality and reliability of a trained model can be assessed using the ML_SCORE routine. For more information, see Section 3.9.6, “Scoring Models”. From MySQL 8.0.30, ML_TRAIN displays the following message if a trained model has a low score: Model Has a low training score, expect low quality model explanations.
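The following is a minimal scoring sketch under two assumptions:

- There is a held-out table named heatwaveml_bench.census_validate with the same columns as the training table. This table name is illustrative and is not part of the example above.
- ML_SCORE takes five arguments: table, target column, model handle, metric, and output variable. This form is used in MySQL 8.0 releases. Check Section 3.9.6, “Scoring Models” for the exact signature in your version.

```sql
-- Load the trained model before scoring it.
CALL sys.ML_MODEL_LOAD(@census_model, NULL);

-- Score against labeled rows the model was not trained on.
-- census_validate is an assumed, illustrative table name.
CALL sys.ML_SCORE('heatwaveml_bench.census_validate', 'revenue',
     @census_model, 'accuracy', @score);

SELECT @score;
```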

Advanced ML_TRAIN Options

The ML_TRAIN routine provides advanced options you can use to influence model selection and training.

  • The model_list option specifies the types of model to train. If more than one model type is specified, the best model type is selected from the list. For a list of supported model types, see Model Types. This option cannot be used together with the exclude_model_list option.

    The following example trains either an XGBClassifier or LGBMClassifier model.

    CALL sys.ml_train('heatwaveml_bench.census_train', 'revenue', 
    JSON_OBJECT('task','classification', 'model_list', 
    JSON_ARRAY('XGBClassifier', 'LGBMClassifier')), @census_model);

  • The exclude_model_list option specifies model types to exclude from consideration during training. For a list of model types you can specify, see Model Types. This option cannot be used together with the model_list option.

    The following example excludes the LogisticRegression and GaussianNB models.

    CALL sys.ml_train('heatwaveml_bench.census_train', 'revenue', 
    JSON_OBJECT('task','classification', 
    'exclude_model_list', JSON_ARRAY('LogisticRegression', 'GaussianNB')), 
    @census_model);

  • The optimization_metric option specifies a scoring metric to optimize for. For a list of supported metrics, see Scoring Metrics.

    The following example optimizes for the neg_log_loss metric.

    CALL sys.ml_train('heatwaveml_bench.census_train', 'revenue', 
    JSON_OBJECT('task','classification', 'optimization_metric', 'neg_log_loss'), 
    @census_model);

  • The exclude_column_list option specifies feature columns to exclude from consideration when training a model.

    The following example excludes the 'age' column from consideration when training a model for the census dataset.

    CALL sys.ml_train('heatwaveml_bench.census_train', 'revenue', 
    JSON_OBJECT('task','classification', 'exclude_column_list', JSON_ARRAY('age')), 
    @census_model);
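The advanced options are keys in the same JSON object, so one call can combine several of them. The examples above do not show this, so the following sketch assumes ML_TRAIN accepts the combination. It restricts training to two model types, optimizes for neg_log_loss, and excludes the age column. Remember that model_list and exclude_model_list cannot be used together.

```sql
-- model_list: candidate model types; the best one is selected.
-- optimization_metric: metric used to compare the candidates.
-- exclude_column_list: feature columns left out of training.
CALL sys.ML_TRAIN('heatwaveml_bench.census_train', 'revenue',
     JSON_OBJECT('task', 'classification',
                 'model_list', JSON_ARRAY('XGBClassifier', 'LGBMClassifier'),
                 'optimization_metric', 'neg_log_loss',
                 'exclude_column_list', JSON_ARRAY('age')),
     @census_model);
```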