The ML_TRAIN routine, when run on a training dataset, produces a trained machine learning (ML) model.
ML_TRAIN supports training of classification and regression models. A classification model predicts discrete values; a regression model predicts continuous values.
Training a model can take from a few minutes to a few hours, depending on the number of rows and columns in the dataset, the specified ML_TRAIN parameters, and the size of the HeatWave Cluster. HeatWave ML supports tables up to 10 GB in size with a maximum of 100 million rows and 900 columns.
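A client could check a candidate table against these limits before calling ML_TRAIN. The following is a minimal Python sketch. The function name and constants are hypothetical, and reading "10 GB" as 10 * 1024**3 bytes is an assumption. One possible source for the inputs is the DATA_LENGTH and TABLE_ROWS columns of information_schema.TABLES. Keep in mind that TABLE_ROWS is only an estimate for InnoDB tables.

```python
from __future__ import annotations

# Hypothetical client-side pre-check against the documented HeatWave ML limits.
# Assumption: "10 GB" is interpreted here as 10 * 1024**3 bytes.
MAX_TABLE_BYTES = 10 * 1024**3
MAX_ROWS = 100_000_000
MAX_COLUMNS = 900


def check_training_table_limits(size_bytes: int, rows: int, columns: int) -> list[str]:
    """Return a description of each exceeded limit; an empty list means the table fits."""
    problems: list[str] = []
    if size_bytes > MAX_TABLE_BYTES:
        problems.append(f"size {size_bytes} bytes exceeds {MAX_TABLE_BYTES} bytes")
    if rows > MAX_ROWS:
        problems.append(f"{rows} rows exceeds the {MAX_ROWS}-row maximum")
    if columns > MAX_COLUMNS:
        problems.append(f"{columns} columns exceeds the {MAX_COLUMNS}-column maximum")
    return problems


if __name__ == "__main__":
    # A small table passes every check.
    print(check_training_table_limits(2 * 1024**3, 50_000, 15))  # []
    # A table that is both too large and too wide reports two problems.
    for problem in check_training_table_limits(12 * 1024**3, 1_000, 950):
        print(problem)
```

Values exactly at a limit are treated as allowed, since the limits are stated as maximums.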
ML_TRAIN stores machine learning models in the MODEL_CATALOG table. See Section 3.7.1, “The Model Catalog”.
For ML_TRAIN option descriptions, see Section 3.8.1, “ML_TRAIN”.
The training dataset used with ML_TRAIN must reside in a table on the MySQL DB System. For an example training dataset, see Example Data.
The following example runs ML_TRAIN on the heatwaveml_bench.census_train training dataset:
CALL sys.ML_TRAIN('heatwaveml_bench.census_train', 'revenue',
JSON_OBJECT('task', 'classification'), @census_model);
Where:
- heatwaveml_bench.census_train is the fully qualified name of the table that contains the training dataset (schema_name.table_name).
- revenue is the name of the target column, which contains ground truth values.
- JSON_OBJECT('task', 'classification') specifies the machine learning task type. Supported types are classification (the default) and regression. NULL can be specified in place of the JSON_OBJECT if you intend to use the default classification task type. When using the regression task type, only a numeric target column is permitted.
- @census_model is the name of the user-defined session variable that stores the model handle for the duration of the connection. User variables are written as @var_name. Some of the examples in this guide use @census_model as the variable name. Any valid name for a user-defined variable is permitted (e.g., @my_model).
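The parts of the call described above can be assembled programmatically. The following Python sketch builds the CALL statement text only and does not connect to MySQL. The helper names are hypothetical. The checks mirror the rules listed above: a fully qualified table name, the classification or regression task (or NULL for the default), and a numeric target for regression. The caller must supply target_is_numeric, because the sketch cannot inspect column types. The quoting helper escapes only single quotes and is not a general-purpose SQL escaper.

```python
from __future__ import annotations

import re

# MySQL user variable names consist of alphanumeric characters, '.', '_', and '$'
# (quoted variable names are not handled in this sketch).
_USER_VAR = re.compile(r"^[A-Za-z0-9._$]+$")


def _quote(value: str) -> str:
    """Render a Python string as a single-quoted SQL literal (only ' is escaped)."""
    return "'" + value.replace("'", "''") + "'"


def build_ml_train_call(table: str, target: str, task: str | None,
                        var_name: str, target_is_numeric: bool = True) -> str:
    """Build a sys.ML_TRAIN CALL statement string (hypothetical helper; no DB access)."""
    schema, _, name = table.partition(".")
    if not schema or not name:
        raise ValueError("table must be fully qualified: schema_name.table_name")
    if task not in (None, "classification", "regression"):
        raise ValueError("task must be 'classification', 'regression', or None")
    if task == "regression" and not target_is_numeric:
        raise ValueError("the regression task requires a numeric target column")
    if not _USER_VAR.match(var_name):
        raise ValueError(f"invalid user variable name: {var_name!r}")
    # NULL selects the default task type (classification).
    options = "NULL" if task is None else f"JSON_OBJECT('task', {_quote(task)})"
    return f"CALL sys.ML_TRAIN({_quote(table)}, {_quote(target)}, {options}, @{var_name});"


if __name__ == "__main__":
    print(build_ml_train_call("heatwaveml_bench.census_train", "revenue",
                              "classification", "census_model"))
```

Run as a script, this prints the same statement as the census_train example: CALL sys.ML_TRAIN('heatwaveml_bench.census_train', 'revenue', JSON_OBJECT('task', 'classification'), @census_model);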
After ML_TRAIN trains a model, the model is stored in the user's model catalog. To retrieve the generated model handle, query the specified session variable; for example:
mysql> SELECT @census_model;
+--------------------------------------------------+
| @census_model |
+--------------------------------------------------+
| heatwaveml_bench.census_train_user1_1636729526 |
+--------------------------------------------------+
While using the same connection that executed ML_TRAIN, you can specify the session variable (e.g., @census_model) in place of the model handle in other HeatWave ML routines. However, the session variable data is lost when the current session is terminated. If you need to look up a model handle, you can do so by querying the model catalog table. See Section 3.7.7, “Model Handles”.
The quality and reliability of a trained model can be assessed using the ML_SCORE routine. For more information, see Section 3.7.5, “Scoring Models”. From MySQL 8.0.30, ML_TRAIN displays the following message if a trained model has a low score: "Model Has a low training score, expect low quality model explanations".
The ML_TRAIN routine provides advanced options you can use to influence model selection and training.
- The model_list option specifies the type of model to be trained. If more than one model type is specified, the best model type is selected from the list. For a list of supported model types, see Model Types. This option cannot be used together with the exclude_model_list option.
  The following example trains either an XGBClassifier or an LGBMClassifier model:
  CALL sys.ml_train('heatwaveml_bench.census_train', 'revenue', JSON_OBJECT('task','classification', 'model_list', JSON_ARRAY('XGBClassifier', 'LGBMClassifier')), @census_model);
- The exclude_model_list option specifies types of models that should not be trained. Specified model types are excluded from consideration. For a list of model types you can specify, see Model Types. This option cannot be used together with the model_list option.
  The following example excludes the LogisticRegression and GaussianNB models:
  CALL sys.ml_train('heatwaveml_bench.census_train', 'revenue', JSON_OBJECT('task','classification', 'exclude_model_list', JSON_ARRAY('LogisticRegression', 'GaussianNB')), @census_model);
- The optimization_metric option specifies a scoring metric to optimize for. For a list of supported metrics, see Scoring Metrics.
  The following example optimizes for the neg_log_loss metric:
  CALL sys.ml_train('heatwaveml_bench.census_train', 'revenue', JSON_OBJECT('task','classification', 'optimization_metric', 'neg_log_loss'), @census_model);
- The exclude_column_list option specifies feature columns to exclude from consideration when training a model.
  The following example excludes the 'age' column from consideration when training a model for the census dataset:
  CALL sys.ml_train('heatwaveml_bench.census_train', 'revenue', JSON_OBJECT('task','classification', 'exclude_column_list', JSON_ARRAY('age')), @census_model);
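The options in these examples all share one shape: a JSON_OBJECT of key/value pairs, where list-valued options are wrapped in JSON_ARRAY. The following Python sketch renders such an options dictionary as SQL text. The function names are hypothetical. It enforces only the documented rule that model_list and exclude_model_list cannot be combined. Model type and metric names are passed through unchecked; see Model Types and Scoring Metrics for valid values.

```python
from __future__ import annotations

from typing import Sequence, Union

# An option value is either a single string or a list/tuple of strings.
OptionValue = Union[str, Sequence[str]]


def _quote(value: str) -> str:
    """Render a Python string as a single-quoted SQL literal (only ' is escaped)."""
    return "'" + value.replace("'", "''") + "'"


def _render(value: OptionValue) -> str:
    """Strings become SQL literals; lists and tuples of strings become JSON_ARRAY(...)."""
    if isinstance(value, str):
        return _quote(value)
    if isinstance(value, (list, tuple)):
        return "JSON_ARRAY(" + ", ".join(_quote(v) for v in value) + ")"
    raise TypeError(f"unsupported option value: {value!r}")


def build_train_options(options: dict[str, OptionValue]) -> str:
    """Render an ML_TRAIN options dict as a JSON_OBJECT(...) expression (hypothetical helper)."""
    # Documented restriction: these two options are mutually exclusive.
    if "model_list" in options and "exclude_model_list" in options:
        raise ValueError("model_list and exclude_model_list cannot be used together")
    pairs = ", ".join(f"{_quote(k)}, {_render(v)}" for k, v in options.items())
    return f"JSON_OBJECT({pairs})"


if __name__ == "__main__":
    print(build_train_options({
        "task": "classification",
        "model_list": ["XGBClassifier", "LGBMClassifier"],
    }))
```

The script prints JSON_OBJECT('task', 'classification', 'model_list', JSON_ARRAY('XGBClassifier', 'LGBMClassifier')). This is the same expression as in the model_list example, apart from whitespace. Python dictionaries preserve insertion order, so keys appear in the order they were written.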