Run the ML_TRAIN routine on a labeled training dataset to produce a trained machine learning model.

MySQL 9.0.0 introduces support for large models, which changes how HeatWave AutoML stores models; see Section 3.14.1, “The Model Catalog”. ML_TRAIN upgrades older models.
mysql> CALL sys.ML_TRAIN ('table_name', 'target_column_name', [options], model_handle);
options: {
JSON_OBJECT("key","value"[,"key","value"] ...)
"key","value": {
['task', {'classification'|'regression'|'forecasting'|'anomaly_detection'|'recommendation'}|NULL]
['datetime_index', 'column']
['endogenous_variables', JSON_ARRAY('column'[,'column'] ...)]
['exogenous_variables', JSON_ARRAY('column'[,'column'] ...)]
['model_list', JSON_ARRAY('model'[,'model'] ...)]
['exclude_model_list', JSON_ARRAY('model'[,'model'] ...)]
['optimization_metric', 'metric']
['include_column_list', JSON_ARRAY('column'[,'column'] ...)]
['exclude_column_list', JSON_ARRAY('column'[,'column'] ...)]
['contamination', 'contamination factor']
['users', 'users_column']
['items', 'items_column']
['notes', 'notes_text']
['feedback', {'explicit'|'implicit'}]
['feedback_threshold', 'threshold']
['item_metadata', JSON_OBJECT('table_name', 'schema_name.table_name')]
}
}

The MySQL account that runs ML_TRAIN cannot have a period character (".") in its name. For example, a user named 'joesmith'@'%' is permitted to train a model, but a user named 'joe.smith'@'%' is not. For more information about this limitation, see Section 3.19, “HeatWave AutoML Limitations”.

The ML_TRAIN routine also runs the ML_EXPLAIN routine with the default Permutation Importance model for prediction explainers and model explainers. See Section 3.6, “Training Explainers”. To train other prediction explainers and model explainers, run the ML_EXPLAIN routine with the preferred explainer after ML_TRAIN. ML_EXPLAIN does not support the anomaly_detection and recommendation tasks, and ML_TRAIN does not run ML_EXPLAIN for those tasks.

ML_TRAIN parameters:

table_name: The name of the table that contains the labeled training dataset. The table name must be valid and fully qualified; that is, it must include the schema name, as in schema_name.table_name. The table cannot exceed 10 GB, 100 million rows, or 1017 columns.
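The following Python sketch shows how the table_name limits above can be checked on the client before calling ML_TRAIN. It is hypothetical. The function name is illustrative. The row count, column count and table size in bytes are assumed to be collected separately. Treating 10 GB as 10 * 1024**3 bytes is also an assumption.

```python
# Hypothetical client-side pre-check for the ML_TRAIN table_name argument.
# Limits come from the table_name description; 10 GB is assumed to mean
# 10 * 1024**3 bytes.
MAX_BYTES = 10 * 1024**3
MAX_ROWS = 100_000_000
MAX_COLUMNS = 1017


def check_training_table(table_name: str, n_rows: int, n_columns: int, size_bytes: int) -> list:
    """Return a list of problems; an empty list means these checks pass."""
    problems = []
    schema, sep, table = table_name.partition(".")
    if not (schema and sep and table):
        problems.append("table_name must be fully qualified as schema_name.table_name")
    if size_bytes > MAX_BYTES:
        problems.append("table exceeds 10 GB")
    if n_rows > MAX_ROWS:
        problems.append("table exceeds 100 million rows")
    if n_columns > MAX_COLUMNS:
        problems.append("table exceeds 1017 columns")
    return problems


print(check_training_table("ml_data.iris_train", 105, 5, 4096))  # []
print(check_training_table("iris_train", 105, 5, 4096))
```

The helper returns every problem it finds, not only the first, so one call reports all violated limits.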

target_column_name: The name of the target column containing ground truth values. HeatWave AutoML does not support a text target column.

Anomaly detection does not require labeled data, and target_column_name must be set to NULL. Forecasting does not require target_column_name, and it can be set to NULL.
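The following minimal sketch applies the target_column_name rules above, with Python None standing for SQL NULL. The function name is hypothetical. Requiring a target column for every task other than anomaly_detection and forecasting is an assumption, because the text singles out only those two.

```python
from typing import Optional


def target_argument(task: str, target: Optional[str]) -> Optional[str]:
    """Return the value to pass as target_column_name; None stands for SQL NULL."""
    if task == "anomaly_detection":
        # Anomaly detection uses unlabeled data, so the argument must be NULL.
        if target is not None:
            raise ValueError("anomaly_detection requires target_column_name to be NULL")
        return None
    if task == "forecasting":
        # Forecasting does not require a target column; NULL is allowed.
        return target
    # Assumption: every other task needs a ground-truth column.
    if target is None:
        raise ValueError(f"task {task!r} needs a target column")
    return target


print(target_argument("classification", "class"))  # class
print(target_argument("anomaly_detection", None))  # None
```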

model_handle: The name of a user-defined session variable that stores the machine learning model handle for the duration of the connection. User variables are written as @var_name. Some of the examples in this guide use @census_model as the variable name. Any valid name for a user-defined variable is permitted, for example, @my_model.

If the model_handle variable was set to a value before calling ML_TRAIN, that model handle is used for the model; it must be unique in the model catalog. Otherwise, HeatWave AutoML generates a model handle. When ML_TRAIN finishes executing, retrieve the generated model handle by querying the session variable. See Section 3.14.8, “Model Handles”.

options: Optional parameters specified as key-value pairs in JSON format. If an option is not specified, the default setting is used. If no options are specified, you can specify NULL in place of the JSON argument.
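Application code that builds ML_TRAIN calls has to produce either a JSON_OBJECT(...) expression or NULL for the options argument. The following Python sketch renders a dict in that shape. The helper names are hypothetical. String escaping assumes the default sql_mode, in which backslash is an escape character. The sketch only makes the shape of the argument concrete. In production code, prefer your client library's parameter placeholders over concatenating SQL text.

```python
def sql_literal(value) -> str:
    """Render one option value as a MySQL expression."""
    if value is None:
        return "NULL"
    if isinstance(value, bool):
        # No boolean options appear in the ML_TRAIN syntax, so reject them.
        raise TypeError("boolean option values are not handled in this sketch")
    if isinstance(value, (int, float)):
        return repr(value)
    if isinstance(value, str):
        # Escape backslashes and double single quotes (default sql_mode assumed).
        escaped = value.replace("\\", "\\\\").replace("'", "''")
        return "'" + escaped + "'"
    if isinstance(value, (list, tuple)):
        return "JSON_ARRAY(" + ", ".join(sql_literal(v) for v in value) + ")"
    if isinstance(value, dict):
        return options_sql(value)
    raise TypeError(f"unsupported option value: {value!r}")


def options_sql(options) -> str:
    """Render an options dict as JSON_OBJECT(...), or NULL when there are none."""
    if not options:
        return "NULL"
    pairs = ", ".join(f"{sql_literal(k)}, {sql_literal(v)}" for k, v in options.items())
    return f"JSON_OBJECT({pairs})"


opts = {"task": "classification", "model_list": ["XGBClassifier", "LGBMClassifier"]}
print("CALL sys.ML_TRAIN('ml_data.iris_train', 'class', " + options_sql(opts) + ", @iris_model);")
```

Passing None or an empty dict yields NULL, which matches the rule that NULL can replace the JSON argument when no options are set.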

task: Specifies the machine learning task. Permitted values are:

classification: The default. Use this task type if the target is a discrete value.
regression: Use this task type if the target column is a continuous numerical value.
forecasting: Use this task type if you have a date-time column that requires a time series forecast. To use this task, you must set a target column, the date-time column (datetime_index), and endogenous variables (endogenous_variables).
anomaly_detection: Use this task type to detect anomalies.
recommendation: Use this task type for recommendation models.
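The following minimal sketch resolves the effective task from an options dict. A missing or NULL task falls back to classification. For forecasting, the sketch checks that the datetime_index and endogenous_variables options are present. The function name is hypothetical, and the sketch checks only these option keys, not the target column.

```python
TASKS = ("classification", "regression", "forecasting", "anomaly_detection", "recommendation")


def resolve_task(options) -> str:
    """Return the effective task; a missing or NULL task means classification."""
    opts = options or {}
    task = opts.get("task") or "classification"
    if task not in TASKS:
        raise ValueError(f"unsupported task: {task!r}")
    if task == "forecasting":
        # The forecasting task also needs a datetime_index and endogenous_variables.
        missing = [key for key in ("datetime_index", "endogenous_variables") if key not in opts]
        if missing:
            raise ValueError("forecasting requires: " + ", ".join(missing))
    return task


print(resolve_task(None))  # classification
```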

datetime_index: For forecasting tasks, the name of a datetime column that acts as an index for the forecast variable. The column can be one of the supported datetime column types, DATETIME, TIMESTAMP, DATE, TIME, and YEAR, or an auto-incrementing index.

The forecast models SARIMAXForecaster, VARMAXForecaster, and DynFactorForecaster cannot back test, that is, forecast into training data, when using exogenous_variables. Therefore, when exogenous_variables are used, the datetime_index of the predict table must not overlap with that of the training table, and the start date in the predict table must be the date immediately following the last date in the training table. For example, if the training table has a YEAR datetime_index that ends with year 2023, the predict table must start with year 2024.

The datetime_index for the predict table must not have missing dates after the last date in the training table. For example, if the training table has a YEAR datetime_index that ends with year 2023, the predict table cannot start with year 2025 or 2030, because that would skip 1 and 6 years, respectively.

When options do not include exogenous_variables, the predict table can overlap the datetime_index with the training table. This supports back testing.

The valid range of years for datetime_index dates is 1678 to 2261. An error occurs if any part of the training table or predict table has dates outside this range. The last date in the training table plus the predict table length must also stay inside the valid year range. For example, if the datetime_index in the training table has the YEAR data type and the last date is year 2023, the predict table cannot have more than 238 rows: 2261 minus 2023 equals 238.
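The following minimal sketch applies the datetime_index rules above to a YEAR-typed index. It assumes the training and predict years are available as Python integers. The function names are hypothetical. The sketch covers only the rules stated above: the 1678 to 2261 range, no overlap when exogenous_variables are used, and no gaps after the last training year.

```python
MIN_YEAR, MAX_YEAR = 1678, 2261


def max_predict_rows(last_train_year: int) -> int:
    """Most yearly predict rows that still end inside the valid year range."""
    return MAX_YEAR - last_train_year


def check_predict_years(train_years, predict_years, uses_exogenous: bool) -> None:
    """Raise ValueError if a YEAR datetime_index breaks the rules above."""
    train_years, predict_years = list(train_years), list(predict_years)
    if any(y < MIN_YEAR or y > MAX_YEAR for y in train_years + predict_years):
        raise ValueError(f"datetime_index years must be between {MIN_YEAR} and {MAX_YEAR}")
    last = max(train_years)
    if uses_exogenous and any(y <= last for y in predict_years):
        # Models that use exogenous_variables cannot back test into training data.
        raise ValueError("predict table must not overlap the training table")
    # Years after the last training year must continue with no gaps.
    future = sorted(y for y in predict_years if y > last)
    if future != list(range(last + 1, last + 1 + len(future))):
        raise ValueError(f"predict table must continue from {last + 1} without gaps")


print(max_predict_rows(2023))  # 238
check_predict_years([2021, 2022, 2023], [2024, 2025], uses_exogenous=True)
```

A predict table of 238 yearly rows starting in 2024 ends in 2261 and passes. One more row would reach 2262 and fail the range check.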

endogenous_variables: For forecasting tasks, the column or columns to be forecast.

Univariate forecasting models support a single numeric column, specified as a JSON_ARRAY. This column must also be specified as the target_column_name, because that field is required, but it is not used in that location.

Multivariate forecasting models support multiple numeric columns, specified as a JSON_ARRAY. One of these columns must also be specified as the target_column_name.

endogenous_variables cannot be text.
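The following minimal sketch checks endogenous_variables and classifies the forecast as univariate or multivariate. It assumes the caller supplies each column's type name, for example as reported in INFORMATION_SCHEMA.COLUMNS.DATA_TYPE. The function name and the column names in the usage line are hypothetical. A NULL target is let through because forecasting permits one.

```python
# Numeric type names as reported in INFORMATION_SCHEMA.COLUMNS.DATA_TYPE.
NUMERIC_TYPES = {"tinyint", "smallint", "mediumint", "int", "bigint",
                 "decimal", "float", "double"}


def forecasting_kind(endogenous, target, data_types) -> str:
    """Return 'univariate' or 'multivariate' after checking the columns."""
    if not endogenous:
        raise ValueError("endogenous_variables must name at least one column")
    bad = [c for c in endogenous if data_types.get(c, "").lower() not in NUMERIC_TYPES]
    if bad:
        raise ValueError("endogenous_variables must be numeric: " + ", ".join(bad))
    # The target column must be one of the endogenous columns when it is given.
    if target is not None and target not in endogenous:
        raise ValueError("target_column_name must be one of endogenous_variables")
    return "univariate" if len(endogenous) == 1 else "multivariate"


print(forecasting_kind(["sales", "returns"], "sales", {"sales": "decimal", "returns": "int"}))
```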

exogenous_variables: For forecasting tasks, the column or columns of independent, non-forecast, predictive variables, specified as a JSON_ARRAY. These optional variables are not forecast, but help to predict the future values of the forecast variables. These variables affect a model without being affected by it. For example, for sales forecasting these variables might be advertising expenditure, occurrence of promotional events, weather, or holidays.

ML_TRAIN will consider all supported models during the algorithm selection stage if options includes exogenous_variables, including models that do not support exogenous_variables.

For example, if options includes univariate endogenous_variables with exogenous_variables, then ML_TRAIN will consider NaiveForecaster, ThetaForecaster, ExpSmoothForecaster, ETSForecaster, STLwESForecaster, STLwARIMAForecaster, and SARIMAXForecaster. ML_TRAIN will ignore exogenous_variables if the model does not support them.

Similarly, if options includes multivariate endogenous_variables with exogenous_variables, then ML_TRAIN will consider VARMAXForecaster and DynFactorForecaster.

If options also includes include_column_list, this will force ML_TRAIN to only consider those models that support exogenous_variables.
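The following sketch mirrors the model-selection rules described above for options that include exogenous_variables. The model lists are copied from the text. The set of models that accept exogenous_variables is an assumption: it contains the three models that the datetime_index description says cannot back test with exogenous_variables. The sketch uses the narrower include_column_list condition stated in that option's description: the list must contain at least one exogenous column. The function name is hypothetical.

```python
UNIVARIATE_MODELS = ["NaiveForecaster", "ThetaForecaster", "ExpSmoothForecaster",
                     "ETSForecaster", "STLwESForecaster", "STLwARIMAForecaster",
                     "SARIMAXForecaster"]
MULTIVARIATE_MODELS = ["VARMAXForecaster", "DynFactorForecaster"]
# Assumption: these are the models that actually use exogenous_variables.
EXOGENOUS_MODELS = {"SARIMAXForecaster", "VARMAXForecaster", "DynFactorForecaster"}


def candidate_models(n_endogenous, exogenous, include_column_list=None):
    """Models considered when options include exogenous_variables."""
    if not exogenous:
        raise ValueError("this sketch only covers options that include exogenous_variables")
    models = UNIVARIATE_MODELS if n_endogenous == 1 else MULTIVARIATE_MODELS
    if set(include_column_list or []) & set(exogenous):
        # An include_column_list with an exogenous column restricts the search.
        models = [m for m in models if m in EXOGENOUS_MODELS]
    return list(models)


print(candidate_models(1, ["advertising"], include_column_list=["advertising"]))
```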

model_list: The type of model to be trained. If more than one model is specified, the best model type is selected from the list. See Section 3.16.13, “Model Types”.

This option cannot be used together with the exclude_model_list option. Before MySQL 8.4.0, model_list is not supported for anomaly_detection tasks. As of MySQL 8.4.0, you can select the Principal Component Analysis (PCA) model and Generalized Local Outlier Factor (GLOF) model for anomaly_detection tasks. The default anomaly_detection model is the Generalized kth Nearest Neighbors (GkNN) model.

exclude_model_list: Model types that should not be trained. Specified model types are excluded from consideration during model selection. See Section 3.16.13, “Model Types”.

This option cannot be specified together with the model_list option, and it is not supported for anomaly_detection tasks.

optimization_metric: The scoring metric to optimize for when training a machine learning model. The metric must be compatible with the task type and the target data. See Section 3.16.14, “Optimization and Scoring Metrics”.

This is not supported for anomaly_detection tasks.
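The following minimal sketch collects the model_list, exclude_model_list and optimization_metric restrictions above into one check. The function name is hypothetical. Passing the server version as a tuple, with 9.0.0 as the default, is an assumption made to express the 8.4.0 cut-off.

```python
def check_model_options(task, options, server_version=(9, 0, 0)) -> None:
    """Raise ValueError for option combinations the text says are not allowed."""
    opts = options or {}
    if "model_list" in opts and "exclude_model_list" in opts:
        raise ValueError("model_list and exclude_model_list cannot be used together")
    if task == "anomaly_detection":
        # model_list is accepted for anomaly_detection only as of MySQL 8.4.0.
        if "model_list" in opts and tuple(server_version) < (8, 4, 0):
            raise ValueError("model_list for anomaly_detection needs MySQL 8.4.0 or later")
        for key in ("exclude_model_list", "optimization_metric"):
            if key in opts:
                raise ValueError(f"{key} is not supported for anomaly_detection")


check_model_options("classification", {"model_list": ["XGBClassifier", "LGBMClassifier"]})
```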

include_column_list: ML_TRAIN must include this list of columns.

For classification, regression, anomaly_detection, and recommendation tasks, include_column_list ensures that ML_TRAIN will not drop these columns.

For forecasting tasks, include_column_list can only include exogenous_variables. If include_column_list contains at least one of the exogenous_variables, this will force ML_TRAIN to only consider those models that support exogenous_variables.

All columns in include_column_list must be included in the training table.
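The following minimal sketch checks the include_column_list rules above. The function name is hypothetical. The training-table columns are assumed to be known to the caller. The iris column names in the usage line come from the model_metadata output in the examples.

```python
def check_include_columns(task, include, training_columns, exogenous=None) -> None:
    """Raise ValueError if include_column_list breaks the rules above."""
    include = list(include or [])
    missing = [c for c in include if c not in training_columns]
    if missing:
        raise ValueError("not in the training table: " + ", ".join(missing))
    if task == "forecasting":
        # For forecasting, only exogenous_variables may be listed.
        extra = [c for c in include if c not in (exogenous or [])]
        if extra:
            raise ValueError("forecasting include_column_list may only list "
                             "exogenous_variables: " + ", ".join(extra))


iris_columns = ["sepal length", "sepal width", "petal length", "petal width"]
check_include_columns("classification", ["petal width"], iris_columns)
```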

exclude_column_list: Feature columns of the training dataset to exclude from consideration when training a model. Columns that are excluded using exclude_column_list do not need to be excluded from the dataset used for predictions.

The exclude_column_list cannot contain any columns provided in endogenous_variables, exogenous_variables, or include_column_list.

contamination: The optional contamination factor for use with the anomaly_detection task, where 0 < contamination < 0.5. The default value is 0.1.
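The following minimal sketch covers the exclude_column_list and contamination rules above. Both function names are hypothetical. resolve_contamination applies the 0.1 default when no value is given and enforces the open interval 0 < contamination < 0.5.

```python
def check_exclude_columns(exclude, endogenous=None, exogenous=None, include=None) -> None:
    """exclude_column_list may not name endogenous, exogenous, or included columns."""
    protected = set(endogenous or []) | set(exogenous or []) | set(include or [])
    clash = sorted(set(exclude or []) & protected)
    if clash:
        raise ValueError("exclude_column_list cannot contain: " + ", ".join(clash))


def resolve_contamination(value=None) -> float:
    """Return the contamination factor, applying the 0.1 default."""
    if value is None:
        return 0.1
    if not 0 < value < 0.5:
        raise ValueError("contamination must satisfy 0 < contamination < 0.5")
    return value


print(resolve_contamination())  # 0.1
```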

users: Specifies the column name corresponding to the user IDs. This must be a valid column name, and it must be different from the items column name.

items: Specifies the column name corresponding to the item IDs. This must be a valid column name, and it must be different from the users column name.

notes: Add notes to the model_metadata.

feedback: The type of feedback for a recommendation model: explicit (the default) or implicit.

feedback_threshold: The feedback threshold for a recommendation model with implicit feedback. All ratings at or above the feedback_threshold are treated as positive feedback. All ratings below the feedback_threshold are treated as negative feedback. The default value is 1.

item_metadata: When training a content-based recommendation model, this option specifies the table that has item descriptions. This table must have two columns: one specifying the item_id and the other with a text description of the item. You must include the table_name option as a JSON object and provide it in a fully qualified format as schema_name.table_name.
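The following minimal sketch covers the recommendation options above. It checks that users and items name different columns, that feedback is explicit or implicit, and that item_metadata carries a fully qualified table_name. It also applies feedback_threshold to a list of ratings. The function names are hypothetical. The sketch models only the threshold rule stated above, not how HeatWave AutoML uses the labels internally.

```python
def check_recommendation_options(opts) -> str:
    """Validate recommendation options and return the feedback type."""
    users, items = opts.get("users"), opts.get("items")
    if users is not None and users == items:
        raise ValueError("users and items must name different columns")
    feedback = opts.get("feedback", "explicit")
    if feedback not in ("explicit", "implicit"):
        raise ValueError("feedback must be 'explicit' or 'implicit'")
    meta = opts.get("item_metadata")
    if meta is not None:
        schema, sep, table = str(meta.get("table_name", "")).partition(".")
        if not (schema and sep and table):
            raise ValueError("item_metadata table_name must be schema_name.table_name")
    return feedback


def implicit_feedback(ratings, threshold=1):
    """True (positive) for ratings at or above the threshold, False (negative) below it."""
    return [rating >= threshold for rating in ratings]


print(implicit_feedback([0, 1, 4]))  # [False, True, True]
print(implicit_feedback([2, 3, 5], threshold=3))  # [False, True, True]
```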

An ML_TRAIN example that uses the classification task option implicitly (classification is the default if not specified explicitly):

mysql> CALL sys.ML_TRAIN('ml_data.iris_train', 'class', NULL, @iris_model);

An ML_TRAIN example that specifies the classification task type explicitly, and sets a model handle instead of letting HeatWave AutoML generate one:

mysql> SET @iris_model = 'iris_manual';
mysql> CALL sys.ML_TRAIN('ml_data.iris_train', 'class', JSON_OBJECT('task', 'classification'), @iris_model);

An ML_TRAIN example that specifies the regression task type:

mysql> CALL sys.ML_TRAIN('employee.salary_train', 'salary', JSON_OBJECT('task', 'regression'), @salary_model);

An ML_TRAIN example that specifies the model_list option. This example trains either an XGBClassifier or an LGBMClassifier model:

mysql> CALL sys.ML_TRAIN('ml_data.iris_train', 'class', JSON_OBJECT('task','classification', 'model_list', JSON_ARRAY('XGBClassifier', 'LGBMClassifier')), @iris_model);

An ML_TRAIN example that specifies the exclude_model_list option. In this example, LogisticRegression and GaussianNB models are excluded from model selection:

mysql> CALL sys.ML_TRAIN('ml_data.iris_train', 'class', JSON_OBJECT('task','classification', 'exclude_model_list', JSON_ARRAY('LogisticRegression', 'GaussianNB')), @iris_model);

An ML_TRAIN example that specifies the optimization_metric option:

mysql> CALL sys.ML_TRAIN('ml_data.iris_train', 'class', JSON_OBJECT('task','classification', 'optimization_metric', 'neg_log_loss'), @iris_model);

An ML_TRAIN example that specifies the exclude_column_list option:

mysql> CALL sys.ML_TRAIN('ml_data.iris_train', 'class', JSON_OBJECT('task','classification', 'exclude_column_list', JSON_ARRAY('sepal length', 'petal length')), @iris_model);

An ML_TRAIN example that adds notes to the model_metadata:

mysql> CALL sys.ML_TRAIN('ml_data.iris_train', 'class', JSON_OBJECT('task','classification', 'notes', 'testing2'), @iris_model);
Query OK, 0 rows affected (1 min 42.53 sec)

mysql> SELECT model_metadata FROM ML_SCHEMA_root.MODEL_CATALOG WHERE model_handle=@iris_model;
+----------------+
| model_metadata |
+----------------+
| {"task": "classification", "notes": "testing2", "format": "HWMLv1.0", "n_rows": 105, "status": "Ready", "n_columns": 4, "column_names": "{\"0\": \"sepal length\", \"1\": \"sepal width\", \"2\": \"petal length\", \"3\": \"petal width\"}", "contamination": 0.0, "model_quality": "high", "training_time": 10.408954620361328, "algorithm_name": "RandomForestClassifier", "training_score": -0.08308402448892593, "build_timestamp": 0, "n_selected_rows": 84, "train_table_name": "mlcorpus_v5.iris_train", "training_columns": "[\"sepal length\", \"sepal width\", \"petal length\", \"petal width\"]", "model_explanation": "{\"permutation_importance\": {\"petal width\": 0.5926, \"sepal width\": 0.0, \"petal length\": 0.0423, \"sepal length\": 0.0}}", "n_selected_columns": 2, "target_column_name": "class", "optimization_metric": "neg_log_loss", "selected_column_names": "[\"petal length\", \"petal width\"]"} |
+----------------+
1 row in set (0.00 sec)
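In the model_metadata output above, several fields hold JSON text, a string with escaped quotes, rather than nested JSON values. Examples are column_names, training_columns, model_explanation and selected_column_names. The following Python sketch shows one way to decode them on the client, assuming the metadata has been fetched as text. The function name is hypothetical. The sample value is an excerpt of the output above.

```python
import json

# Fields that, in the output above, contain string-encoded JSON.
NESTED_FIELDS = ("column_names", "training_columns", "model_explanation",
                 "selected_column_names")


def decode_model_metadata(raw: str) -> dict:
    """Parse model_metadata text and decode its string-encoded JSON fields."""
    meta = json.loads(raw)
    for field in NESTED_FIELDS:
        if isinstance(meta.get(field), str):
            meta[field] = json.loads(meta[field])
    return meta


# An excerpt of the model_metadata value shown above.
raw = (r'{"algorithm_name": "RandomForestClassifier", '
       r'"model_explanation": "{\"permutation_importance\": {\"petal width\": 0.5926, '
       r'\"sepal width\": 0.0, \"petal length\": 0.0423, \"sepal length\": 0.0}}", '
       r'"selected_column_names": "[\"petal length\", \"petal width\"]"}')

meta = decode_model_metadata(raw)
importance = meta["model_explanation"]["permutation_importance"]
print(meta["selected_column_names"])  # ['petal length', 'petal width']
print(max(importance, key=importance.get))  # petal width
```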

An ML_TRAIN example that stores the model in the model_object_catalog:

mysql> CALL sys.ML_TRAIN('mlcorpus.iris_train', 'class', NULL, @iris_model);
Query OK, 0 rows affected (32.18 sec)

mysql> SELECT model_object, model_object_size FROM ML_SCHEMA_user1.MODEL_CATALOG WHERE model_handle=@iris_model;
+--------------+-------------------+
| model_object | model_object_size |
+--------------+-------------------+
| NULL         |            346866 |
+--------------+-------------------+
1 row in set (0.00 sec)

mysql> SELECT model_metadata->>'$.format', model_metadata->>'$.chunks' FROM ML_SCHEMA_user1.MODEL_CATALOG WHERE model_handle=@iris_model;
+-----------------------------+-----------------------------+
| model_metadata->>'$.format' | model_metadata->>'$.chunks' |
+-----------------------------+-----------------------------+
| HWMLv2.0                    | 1                           |
+-----------------------------+-----------------------------+
1 row in set (0.00 sec)

mysql> SELECT chunk_id, length(model_object) FROM ML_SCHEMA_user1.model_object_catalog WHERE model_handle=@iris_model;
+----------+----------------------+
| chunk_id | length(model_object) |
+----------+----------------------+
|        1 |               346866 |
+----------+----------------------+
1 row in set (0.00 sec)
See also: