Run the ML_TRAIN routine on a training dataset to produce a trained machine learning model.
MySQL 9.0.0 introduces changes to how MySQL HeatWave AutoML stores models. See The Model Object Catalog Table. ML_TRAIN upgrades older models.
Before training models, make sure to review the following:
This topic has the following sections. Refer to the appropriate sections depending on the type of machine learning model you would like to train.
mysql> CALL sys.ML_TRAIN ('table_name', 'target_column_name', [options | NULL], model_handle);

options: {
    JSON_OBJECT("key","value"[,"key","value"] ...)
        "key","value": {
            ['task', {'classification'|'regression'|'forecasting'|'anomaly_detection'|'log_anomaly_detection'|'recommendation'|'topic_modeling'}|NULL]
            ['datetime_index', 'column']
            ['endogenous_variables', JSON_ARRAY('column'[,'column'] ...)]
            ['exogenous_variables', JSON_ARRAY('column'[,'column'] ...)]
            ['model_list', JSON_ARRAY('model'[,'model'] ...)]
            ['exclude_model_list', JSON_ARRAY('model'[,'model'] ...)]
            ['optimization_metric', 'metric']
            ['include_column_list', JSON_ARRAY('column'[,'column'] ...)]
            ['exclude_column_list', JSON_ARRAY('column'[,'column'] ...)]
            ['contamination', 'contamination factor']
            ['supervised_submodel_options', {'n_neighbors', N, 'min_labels', N}]
            ['experimental', {'semisupervised'}]
            ['ensemble_score', 'ensemble metric']
            ['users', 'users_column']
            ['items', 'items_column']
            ['notes', 'notes_text']
            ['feedback', {'explicit'|'implicit'}]
            ['feedback_threshold', 'threshold']
            ['item_metadata', JSON_OBJECT('table_name', 'database_name.table_name')]
            ['document_column', 'column_name']
            ['logad_options', JSON_OBJECT("key","value"[,"key","value"] ...)
                "key","value": {
                    ['additional_masking_regex', JSON_ARRAY('regular_expression'[,'regular_expression'] ...)]
                    ['window_size', N]
                    ['window_stride', N]
                    ['log_source_column', 'column']
                }
            ]
        }
}
Set the following parameters to train all machine learning models.
- table_name: The name of the table that contains the labeled training dataset. The table name must be valid and fully qualified, so it must include the database name: database_name.table_name. The table cannot exceed 10 GB, 100 million rows, or 1017 columns.
- target_column_name: The name of the target column containing ground truth values. MySQL HeatWave AutoML does not support a text target column. If training an unsupervised anomaly detection model (unlabeled data), set target_column_name to NULL. Forecasting does not require target_column_name, and it can be set to NULL.
- model_handle: A user-defined session variable that stores the machine learning model handle for the duration of the connection. User variables are written as @var_name. Any valid name for a user-defined variable is permitted, for example, @my_model.
  If you set a value to the model_handle variable before calling ML_TRAIN, that model handle is used for the model. A model handle must be unique in the model catalog. We recommend this method.
  If you don't set a value to the model_handle variable, MySQL HeatWave AutoML generates one. When ML_TRAIN finishes executing, retrieve the generated model handle by querying the session variable. See Model Handles to learn more.
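The table_name limits can be checked on the client before calling ML_TRAIN. The following Python sketch is a hypothetical pre-flight check, not part of MySQL HeatWave AutoML: TableStats and check_training_table are illustrative names, the statistics are assumed to come from your own queries (for example, against information_schema.TABLES), reading "10 GB" as 10 × 1024³ bytes is an assumption, and the simple split on "." assumes the database and table names contain no dots.

```python
from dataclasses import dataclass

# Documented limits for an ML_TRAIN training table.
MAX_BYTES = 10 * 1024**3   # assumption: "10 GB" read as 10 GiB
MAX_ROWS = 100_000_000
MAX_COLUMNS = 1017


@dataclass
class TableStats:
    """Hypothetical summary of a candidate training table."""
    qualified_name: str  # expected form: database_name.table_name
    size_bytes: int
    n_rows: int
    n_columns: int


def check_training_table(stats: TableStats) -> list[str]:
    """Return the documented limits the table violates (empty list if none)."""
    problems = []
    parts = stats.qualified_name.split(".")
    if len(parts) != 2 or not all(parts):
        problems.append("table_name must be fully qualified as database_name.table_name")
    if stats.size_bytes > MAX_BYTES:
        problems.append("table exceeds 10 GB")
    if stats.n_rows > MAX_ROWS:
        problems.append("table exceeds 100 million rows")
    if stats.n_columns > MAX_COLUMNS:
        problems.append("table exceeds 1017 columns")
    return problems


print(check_training_table(TableStats("census_data.census_train", 5_000_000, 30_000, 15)))  # → []
```

The check only reports problems rather than raising, so all violations surface in one pass.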
The following optional parameters apply to more than one type of machine learning task. They are specified as key-value pairs in JSON format. If an option is not specified, the default setting is used. If no options are specified, you can specify NULL in place of the JSON argument.
- task: Specifies the machine learning task.
  - classification: The default value if a task is not set. Use this task type to assign items to defined categories.
  - regression: Use this task type if the target column is a continuous numerical value. This task generates predictions based on the relationship between a dependent variable and one or more independent variables.
  - forecasting: Use this task type if you have a date-time column that requires a time series forecast. To use this task, you must set a target column, the date-time column (datetime_index), and endogenous variables (endogenous_variables).
  - anomaly_detection: Use this task type to detect unusual patterns in data.
  - log_anomaly_detection: Use this task type to detect unusual patterns in log data (MySQL 9.2.2 and later).
  - recommendation: Use this task type to generate recommendations for users and items.
  - topic_modeling: Use this task type to cluster word groups and similar expressions that best characterize the documents (MySQL 9.0.1-u1 and later).
- model_list: The type of model to be trained. If more than one model is specified, the best model type is selected from the list. See Model Types. This option cannot be used together with the exclude_model_list option.
- exclude_model_list: Model types that should not be trained. Specified model types are excluded from consideration during model selection. See Model Types. This option cannot be specified together with the model_list option.
- optimization_metric: The scoring metric to optimize for when training a machine learning model. The metric must be compatible with the task type and the target data. See Section 10.2.16, “Optimization and Scoring Metrics”. This option is not supported for anomaly_detection tasks. Instead, metrics for anomaly detection can only be used with the ML_SCORE routine.
- include_column_list: ML_TRAIN must include this list of columns.
  For classification, regression, anomaly_detection, and recommendation tasks, include_column_list ensures that ML_TRAIN does not drop these columns.
  For forecasting tasks, include_column_list can only include exogenous_variables. If include_column_list is included in the ML_TRAIN options for a forecasting task with at least one exogenous variable, ML_TRAIN only considers models that support exogenous_variables.
  All columns in include_column_list must be included in the training table.
- exclude_column_list: Feature columns of the training dataset to exclude from consideration when training a model. Columns that are excluded using exclude_column_list do not also need to be excluded from the dataset used for predictions. The exclude_column_list cannot contain any columns provided in endogenous_variables, exogenous_variables, or include_column_list.
- notes: Add notes to the model_metadata for your own reference.
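The cross-option rules above (model_list versus exclude_model_list, the columns exclude_column_list cannot contain, and the forecasting restriction on include_column_list) can be enforced on the client before a call. The following Python sketch is a hypothetical helper, not an AutoML API: build_options returns the options as a JSON string, which could be passed to ML_TRAIN with CAST('...' AS JSON), as the semi-supervised examples in this section do. It covers only the options shown and does not validate option values.

```python
import json


def build_options(task="classification", model_list=None, exclude_model_list=None,
                  include_column_list=None, exclude_column_list=None,
                  endogenous_variables=None, exogenous_variables=None, notes=None):
    """Hypothetical helper: return ML_TRAIN options as a JSON string after
    checking the documented cross-option rules."""
    if model_list and exclude_model_list:
        raise ValueError("model_list and exclude_model_list cannot be used together")

    # exclude_column_list must not name endogenous, exogenous, or included columns.
    protected = (set(endogenous_variables or []) | set(exogenous_variables or [])
                 | set(include_column_list or []))
    overlap = protected & set(exclude_column_list or [])
    if overlap:
        raise ValueError(f"exclude_column_list cannot contain {sorted(overlap)}")

    # For forecasting, include_column_list may only list exogenous variables.
    if task == "forecasting" and not set(include_column_list or []) <= set(exogenous_variables or []):
        raise ValueError("for forecasting, include_column_list can only include exogenous_variables")

    candidates = {
        "task": task,
        "model_list": model_list,
        "exclude_model_list": exclude_model_list,
        "include_column_list": include_column_list,
        "exclude_column_list": exclude_column_list,
        "endogenous_variables": endogenous_variables,
        "exogenous_variables": exogenous_variables,
        "notes": notes,
    }
    # Omit unset options so that ML_TRAIN applies its defaults.
    return json.dumps({key: value for key, value in candidates.items() if value is not None})


print(build_options(model_list=["XGBClassifier", "LGBMClassifier"]))
# → {"task": "classification", "model_list": ["XGBClassifier", "LGBMClassifier"]}
```

Leaving unset options out of the JSON, rather than sending nulls, keeps the server-side defaults in effect.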
Refer to the following model-specific parameters to train different types of machine learning models.
To train a classification model, set the task to classification. If the task is set to NULL, or if all training options are set to NULL, a classification model is trained by default.
- The following example sets the model handle before training, which is good practice. See Define Model Handle. The task is set to classification.

mysql> SET @census_model = 'census_manual';
mysql> CALL sys.ML_TRAIN('census_data.census_train', 'revenue', JSON_OBJECT('task', 'classification'), @census_model);

- The following example sets all options to NULL, so ML_TRAIN runs the classification task by default.

mysql> CALL sys.ML_TRAIN('census_data.census_train', 'revenue', NULL, @census_model);

- The following example specifies the regression task type.

mysql> CALL sys.ML_TRAIN('nyc_taxi.nyc_taxi_train', 'tip_amount', JSON_OBJECT('task', 'regression'), @nyc_taxi);
See the following to learn more about forecasting models:
To train a forecasting model, set the task to forecasting and set the following required parameters.
- datetime_index: The column name for a datetime column that acts as an index for the forecast variable. The column can have one of the supported datetime column types (DATETIME, TIMESTAMP, DATE, TIME, or YEAR) or be an auto-incrementing index.
  The forecast models SARIMAXForecaster, VARMAXForecaster, and DynFactorForecaster cannot back test, that is, forecast into training data, when using exogenous_variables. Therefore, the datetime_index of the predict table must not overlap with that of the training table. When exogenous_variables are used, the start date in the predict table must be the date immediately following the last date in the training table. For example, if the training table has a YEAR datetime_index that ends with year 2023, the predict table has to start with year 2024.
  The datetime_index for the predict table must not have missing dates after the last date in the training table. For example, if the training table has a YEAR datetime_index that ends with year 2023, the predict table has to start with year 2024. The predict table cannot start with year 2025 or 2030, because that would skip 1 and 6 years, respectively.
  When options do not include exogenous_variables, the datetime_index of the predict table can overlap with that of the training table. This supports back testing.
  The valid range of years for datetime_index dates is between 1678 and 2261. An error occurs if any part of the training table or predict table has dates outside this range. The last date in the training table plus the predict table length must still be inside the valid year range. For example, if the datetime_index in the training table has the YEAR data type, and the last date is year 2023, the predict table length must be less than 238 rows: 2261 minus 2023 equals 238 rows.
- endogenous_variables: The column or columns to be forecast.
  Univariate forecasting models support a single numeric column, specified as a JSON_ARRAY. This column must also be specified as the target_column_name, because that field is required, but it is not used in that location.
  Multivariate forecasting models support multiple numeric columns, specified as a JSON_ARRAY. One of these columns must also be specified as the target_column_name.
  endogenous_variables cannot be text.
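For a YEAR-typed datetime_index, the rules above reduce to a few integer comparisons. The following Python sketch models them with plain years; check_predict_years is a hypothetical helper, not how HeatWave AutoML validates a predict table, it ignores the other datetime types, and it applies the documented length rule literally (fewer than 2261 minus the last training year rows).

```python
MIN_YEAR, MAX_YEAR = 1678, 2261  # documented valid year range for datetime_index


def check_predict_years(train_years, predict_years, uses_exogenous):
    """Check a YEAR-typed predict table against the documented datetime_index rules.

    Hypothetical helper: returns a list of problems (empty if none are found).
    """
    train_years, predict_years = list(train_years), list(predict_years)
    problems = []
    if any(not MIN_YEAR <= y <= MAX_YEAR for y in train_years + predict_years):
        problems.append("dates outside the valid 1678-2261 range")

    last_train = max(train_years)
    # With exogenous_variables there is no back testing: start right after training ends.
    if uses_exogenous and min(predict_years) != last_train + 1:
        problems.append("predict table must start immediately after the last training date")

    # No missing dates after the last training date.
    future = sorted(y for y in predict_years if y > last_train)
    if future != list(range(last_train + 1, last_train + 1 + len(future))):
        problems.append("predict table has missing dates after the last training date")

    # Documented length rule, applied literally: fewer than MAX_YEAR - last_train rows.
    if len(predict_years) >= MAX_YEAR - last_train:
        problems.append("predict table is too long for the valid year range")
    return problems


print(check_predict_years(range(2000, 2024), range(2024, 2030), uses_exogenous=True))  # → []
# Starting at 2025 both starts late and skips 2024, so two problems are reported.
print(check_predict_years(range(2000, 2024), range(2025, 2030), uses_exogenous=True))
```

Without exogenous variables, an overlapping predict table such as 2020 to 2027 passes, which matches the back-testing note above.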
Set the following forecasting options as needed to train forecasting models.
- exogenous_variables: For forecasting tasks, the column or columns of independent, non-forecast, predictive variables, specified as a JSON_ARRAY. These optional variables are not forecast, but help to predict the future values of the forecast variables. These variables affect a model without being affected by it. For example, for sales forecasting these variables might be advertising expenditure, occurrence of promotional events, weather, or holidays.
  ML_TRAIN considers all supported models during the algorithm selection stage if options includes exogenous_variables, including models that do not support exogenous_variables. For example, if options includes univariate endogenous_variables with exogenous_variables, ML_TRAIN considers NaiveForecaster, ThetaForecaster, ExpSmoothForecaster, ETSForecaster, STLwESForecaster, STLwARIMAForecaster, and SARIMAXForecaster. ML_TRAIN ignores exogenous_variables if the model does not support them. Similarly, if options includes multivariate endogenous_variables with exogenous_variables, ML_TRAIN considers VARMAXForecaster and DynFactorForecaster.
  If options also includes include_column_list, ML_TRAIN only considers those models that support exogenous_variables.
- include_column_list: Can only include exogenous_variables. If include_column_list contains at least one of the exogenous_variables, ML_TRAIN only considers those models that support exogenous_variables.
- The following example specifies the forecasting task type and the additional required parameters, datetime_index and endogenous_variables.

mysql> CALL sys.ML_TRAIN('air_data.air_train', 'passengers', JSON_OBJECT('task', 'forecasting', 'datetime_index', 'Period', 'endogenous_variables', JSON_ARRAY('passengers')), @model);

- The following example specifies the OrbitForecaster forecasting model with exogenous variables.

mysql> CALL sys.ML_TRAIN('mlcorpus.opsd_germany_daily_train', NULL, JSON_OBJECT('task', 'forecasting', 'datetime_index', 'ddate', 'endogenous_variables', JSON_ARRAY('consumption'), 'exogenous_variables', JSON_ARRAY('wind', 'solar', 'wind_solar'), 'model_list', JSON_ARRAY('OrbitForecaster')), @model);

- The following example specifies the OrbitForecaster forecasting model without exogenous variables.

mysql> CALL sys.ML_TRAIN('air_data.air_train', 'passengers', JSON_OBJECT('task', 'forecasting', 'datetime_index', 'Period', 'endogenous_variables', JSON_ARRAY('passengers'), 'model_list', JSON_ARRAY('OrbitForecaster')), @model);
See the following to learn more about anomaly detection models:
To train an anomaly detection model, set the appropriate required parameters depending on the type of anomaly detection model to train.
- Set the task parameter to anomaly_detection for running anomaly detection on table data, or log_anomaly_detection for running anomaly detection on log data (MySQL 9.2.2 and later). If running an unsupervised model, the target_column_name parameter must be set to NULL.
- If running a semi-supervised model:
  - The target_column_name parameter must specify a column whose only allowed values are 0 (normal), 1 (anomalous), and NULL (unlabeled). All rows are used to train the unsupervised component, while the rows with a value other than NULL are used to train the supervised component.
  - The experimental option must be set to semisupervised.
- If running anomaly detection on log data (MySQL 9.2.2 and later), the input table can only have the following columns:
  - The column containing the logs.
  - If including logs from different sources, a column containing the source of each log. Identify this column with the log_source_column option.
  - If including labeled data, a column identifying the labeled log lines. See Semi-supervised Anomaly Detection to learn more.
  At least one column must act as the primary key to establish the temporal order of logs. If the primary key column (or columns) is not one of the previous required columns (log data, source of log, or label), you must use the exclude_column_list option when running ML_TRAIN to exclude all primary key columns that don't include required data. See Syntax Examples for Anomaly Detection Training to review relevant examples.
  If the input table has columns in addition to the ones permitted, you must use the exclude_column_list option to exclude the irrelevant columns.
Set the following options as needed for anomaly detection models:
- contamination: An estimate of the fraction of outliers in the training table. The contamination factor is calculated as: estimated number of rows with anomalies / total number of rows in the training table. The contamination value must be greater than 0 and less than 0.5. The default value is 0.01.
- model_list: You can select the Principal Component Analysis (PCA), Generalized Local Outlier Factor (GLOF), or Generalized kth Nearest Neighbors (GkNN) model. If no option is specified, the default model is GkNN. Selecting more than one model or an unsupported model produces an error.
To train a semi-supervised anomaly detection model (MySQL 9.0.1-u1 and later), set the following options:
- supervised_submodel_options: Allows you to set optional override parameters for the supervised model component. The only model supported is DistanceWeightedKNNClassifier. The following parameters are supported:
  - n_neighbors: Sets k, the number of closest neighbors checked for each unclassified point. The default value is 5, and the value must be an integer greater than 0.
  - min_labels: Sets the minimum number of labeled data points required to train the supervised component. If fewer labeled data points are provided during training of the model, ML_TRAIN fails. The default value is 20, and the value must be an integer greater than 0.
- ensemble_score: Specifies the metric to use to score the ensemble of unsupervised and supervised components. It identifies the optimal weight between the two components based on the metric. The supported metrics are accuracy, precision, recall, and f1. The default metric is f1.
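Before a semi-supervised run, it can be worth confirming that the target column meets the label rules and holds at least min_labels labeled rows, since ML_TRAIN fails otherwise. The following Python sketch is a hypothetical client-side check: count_supervised_labels is an illustrative name, the column values are assumed to have been fetched already, and None stands in for SQL NULL.

```python
def count_supervised_labels(target_values, min_labels=20):
    """Check a semi-supervised target column and count its labeled rows.

    Mirrors the documented rules: only 0 (normal), 1 (anomalous), and NULL
    (unlabeled) are allowed, and training fails if fewer than min_labels
    rows are labeled (default 20).
    """
    invalid = [v for v in target_values if v not in (0, 1, None)]
    if invalid:
        raise ValueError(f"target column may only contain 0, 1, or NULL; found {invalid[:5]}")
    n_labeled = sum(1 for v in target_values if v is not None)
    if n_labeled < min_labels:
        raise ValueError(f"only {n_labeled} labeled rows; min_labels is {min_labels}")
    return n_labeled


# 15 normal, 5 anomalous, and 100 unlabeled rows: exactly meets the default min_labels.
column = [0] * 15 + [1] * 5 + [None] * 100
print(count_supervised_labels(column))  # → 20
```

All 120 rows would still feed the unsupervised component; only the 20 labeled rows count toward min_labels.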
To train a model for anomaly detection on log data (MySQL 9.2.2 and later), set the following options:
- logad_options: A JSON_OBJECT that allows you to configure the following options.
  - additional_masking_regex: A JSON_ARRAY of regular expressions that mask additional log data. By default, the following parameters are automatically masked during training:
    - IP
    - DATETIME
    - TIME
    - HEX
    - IPPORT
    - OCID
  - window_size: Specifies the maximum number of log lines to be grouped for anomaly detection. The default value is 10.
  - window_stride: Specifies the stride value to use for segmenting log lines. For example, suppose there are logs A, B, C, D, and E, the window_size is 3, and the window_stride is 2. The first row has logs A, B, and C. The second row has logs C, D, and E. If this value is equal to window_size, the log segments do not overlap. The default value is 3.
  - log_source_column: Specifies the column name that contains the source identifier of the respective log lines. Log lines are grouped according to their respective source (for example, logs from multiple MySQL databases that are in the same table). By default, all log lines are assumed to be from the same source.
- Anomaly detection models don't support the following options during training:
  - exclude_model_list
  - optimization_metric
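The window_size and window_stride options describe a sliding window over log lines, and log_source_column keeps each window within a single source. The following Python sketch reproduces the A to E example from the window_stride description. It illustrates the idea and is not the HeatWave implementation: segment_logs and segment_by_source are hypothetical names, and emitting a shorter final window so that trailing lines are not dropped is an assumption based on window_size being a maximum.

```python
def segment_logs(lines, window_size=10, window_stride=3):
    """Split log lines into windows of up to window_size lines, starting a new
    window every window_stride lines (documented defaults: 10 and 3).

    Assumption: a final window may be shorter than window_size so that the
    last lines are not dropped; HeatWave's handling of the tail may differ.
    """
    if window_size < 1 or window_stride < 1:
        raise ValueError("window_size and window_stride must be positive")
    windows, start = [], 0
    while start < len(lines):
        windows.append(lines[start:start + window_size])
        if start + window_size >= len(lines):
            break  # this window already reaches the last line
        start += window_stride
    return windows


def segment_by_source(records, window_size=10, window_stride=3):
    """records: (source, line) pairs in temporal order. Windows never mix
    sources, mirroring the grouping described for log_source_column."""
    by_source = {}
    for source, line in records:
        by_source.setdefault(source, []).append(line)
    return {source: segment_logs(lines, window_size, window_stride)
            for source, lines in by_source.items()}


print(segment_logs(["A", "B", "C", "D", "E"], window_size=3, window_stride=2))
# → [['A', 'B', 'C'], ['C', 'D', 'E']]
```

Setting window_stride equal to window_size produces back-to-back windows with no shared lines, as the option description states.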
- The following example specifies the anomaly_detection task type and excludes the target column from training. After training, query the model type to confirm that the default GKNN model is selected.

mysql> CALL sys.ML_TRAIN('volcano_data.volcano_data_train', NULL, JSON_OBJECT('task', 'anomaly_detection', 'exclude_column_list', JSON_ARRAY('target')), @anomaly_model);
Query OK, 0 rows affected (10.1872 sec)

mysql> SELECT model_handle, model_type FROM ML_SCHEMA_admin.MODEL_CATALOG WHERE model_handle=@anomaly_model;
+---------------+------------+
| model_handle  | model_type |
+---------------+------------+
| anomaly_model | GKNN       |
+---------------+------------+
1 row in set (0.0428 sec)
- The following example specifies the anomaly_detection task with a contamination option. After training, query the model catalog metadata to check the value of the contamination option.

mysql> CALL sys.ML_TRAIN('volcano_data.volcano_data_train', NULL, JSON_OBJECT('task', 'anomaly_detection', 'contamination', 0.013, 'exclude_column_list', JSON_ARRAY('target')), @model2);
Query OK, 0 rows affected (11.6307 sec)

mysql> SELECT JSON_EXTRACT(model_metadata, '$.contamination') FROM ML_SCHEMA_user1.MODEL_CATALOG WHERE model_handle = @model2;
+-------------------------------------------------+
| JSON_EXTRACT(model_metadata, '$.contamination') |
+-------------------------------------------------+
| 0.0130000002682209                              |
+-------------------------------------------------+
1 row in set (0.0717 sec)

- The following example enables semi-supervised learning using all defaults. The target_column_name is set to target. The experimental option is set to semisupervised.

mysql> CALL sys.ML_TRAIN('mlcorpus.anomaly_train_with_partial_target', "target", CAST('{"task": "anomaly_detection", "experimental": {"semisupervised": {}}}' as JSON), @semisupervised_model);

- The following example enables semi-supervised learning with additional options.

mysql> CALL sys.ML_TRAIN('mlcorpus.anomaly_train_with_partial_target', "target", CAST('{"task": "anomaly_detection", "experimental": {"semisupervised": {"supervised_submodel_options": {"min_labels": 10, "n_neighbors": 3}, "ensemble_score": "recall"}}}' as JSON), @semisupervised_model_options);

Where:
  - The supervised_submodel_options parameter min_labels is set to 10.
  - The supervised_submodel_options parameter n_neighbors is set to 3.
  - The ensemble_score option is set to the recall metric.

- The following example selects the PCA (Principal Component Analysis) anomaly detection model. After training, query the model handle and model type to confirm the selected model.

mysql> CALL sys.ML_TRAIN('volcano_data.volcano_data_train', NULL, JSON_OBJECT('task', 'anomaly_detection', 'exclude_column_list', JSON_ARRAY('target'), 'model_list', JSON_ARRAY('PCA')), @model);
Query OK, 0 rows affected (4.8730 sec)

mysql> SELECT model_handle, model_type FROM ML_SCHEMA_admin.MODEL_CATALOG WHERE model_handle=@model;
+--------------+------------+
| model_handle | model_type |
+--------------+------------+
| anomaly_pca  | PCA        |
+--------------+------------+
1 row in set (0.0416 sec)
- The following example runs the log_anomaly_detection task with the available default values.

mysql> CALL sys.ML_TRAIN('mlcorpus.`log_anomaly_just_patterns`', NULL, JSON_OBJECT('task', 'log_anomaly_detection'), @logad_model);

- The following example runs the log_anomaly_detection task with the PCA anomaly detection model.

mysql> CALL sys.ML_TRAIN('mlcorpus.`log_anomaly_just_patterns`', NULL, JSON_OBJECT('task', 'log_anomaly_detection', 'model_list', JSON_ARRAY('PCA')), @logad_model);

- The following ML_TRAIN example excludes two primary key columns: primary_key_column1 and primary_key_column2. These columns must be excluded because they do not have one of the required items of data for training: the log data, the source of the log, or the label.

mysql> CALL sys.ML_TRAIN(
    'mlcorpus.log_anomaly_two_primary', NULL,
    JSON_OBJECT(
        'task', 'log_anomaly_detection',
        'logad_options', JSON_OBJECT('window_size', 2, 'window_stride', 1),
        'exclude_column_list', JSON_ARRAY('primary_key_column1', 'primary_key_column2')
    ),
    @log_anomaly_us
);
- The following example runs the log_anomaly_detection task and masks log data with the additional_masking_regex option. In addition to the default parameters that are automatically masked, email addresses from Yahoo, Hotmail, and Gmail are also masked. Because MySQL treats the backslash as an escape character in string literals (unless the NO_BACKSLASH_ESCAPES SQL mode is enabled), each backslash in the regular expression is doubled. The log_source_column option is also included, which specifies the column that identifies the respective source of the log line.

mysql> CALL sys.ML_TRAIN('mlcorpus.`log_anomaly_sourced`', NULL, JSON_OBJECT('task', 'log_anomaly_detection', 'logad_options', JSON_OBJECT('additional_masking_regex', JSON_ARRAY('(\\W|^)[\\w.\\-]{0,25}@(yahoo|hotmail|gmail)\\.com(\\W|$)'), 'log_source_column', 'source')), @log_anomaly_us);
- The following example enables semi-supervised learning for training the log data for anomaly detection. The window size is also set to 4, and the window stride is set to 1.

mysql> CALL sys.ML_TRAIN('mlcorpus.`log_anomaly_semi`', "label", JSON_OBJECT('task', 'log_anomaly_detection', 'logad_options', JSON_OBJECT('window_size', 4, 'window_stride', 1), "experimental", JSON_OBJECT("semisupervised", JSON_OBJECT("supervised_submodel_options", JSON_OBJECT("min_labels", 10)))), @log_anomaly_us);
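A masking expression can be tried out locally before it goes into additional_masking_regex. The following Python sketch applies the email pattern from the masking example in this section to sample log lines. It is only a rough preview: the regular expression engine HeatWave AutoML uses may differ in detail from Python's re module, the <EMAIL> placeholder is a hypothetical token rather than what HeatWave writes, and the pattern is a Python raw string, so its backslashes are not doubled.

```python
import re

# Email pattern from the additional_masking_regex example, as a Python raw string.
EMAIL_PATTERN = re.compile(r"(\W|^)[\w.\-]{0,25}@(yahoo|hotmail|gmail)\.com(\W|$)")


def mask_emails(line: str, token: str = "<EMAIL>") -> str:
    """Replace matching addresses with a placeholder, keeping the delimiter
    characters captured by groups 1 and 3 so that spacing is preserved."""
    return EMAIL_PATTERN.sub(lambda m: m.group(1) + token + m.group(3), line)


print(mask_emails("login by alice.smith@gmail.com failed"))  # → login by <EMAIL> failed
print(mask_emails("contact bob@example.com"))  # unchanged: example.com is not in the pattern
```

Keeping the captured delimiters matters because the pattern's (\W|^) and (\W|$) groups consume the characters around the address.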
See Recommendation Task Types to learn more about recommendation models.
To train a recommendation model, set the task to recommendation and set the following required parameters.
- users: Specifies the column name corresponding to the user IDs. Values in this column must have a string data type (for example, VARCHAR); otherwise, an error is generated during training. This must be a valid column name, and it must be different from the items column name.
- items: Specifies the column name corresponding to the item IDs. Values in this column must have a string data type (for example, VARCHAR); otherwise, an error is generated during training. This must be a valid column name, and it must be different from the users column name.
To train a recommendation model with explicit feedback, set feedback to explicit. If feedback is not set, the default value is explicit.
To train a recommendation model with implicit feedback, set feedback to implicit and set the following option as needed:
- feedback_threshold: The feedback threshold for a recommendation model that uses implicit feedback. It represents the threshold required to be considered positive feedback. For example, if numerical data records the number of times users interact with an item, you might set a threshold with a value of 3. This means users would need to interact with an item more than three times for the interaction to be considered positive feedback.
To train a content-based recommendation model, set feedback to implicit and set the following required parameters:
- item_metadata: Defines the table that has item descriptions. It is a JSON object that can have the table_name option as a key, which specifies the table that has item descriptions. This table must only have two columns: one corresponding to the item_id, and the other with a TEXT data type (TINYTEXT, TEXT, MEDIUMTEXT, or LONGTEXT) that has the description of the item.
  - table_name: To be used with the item_metadata option. It specifies the table that has item descriptions, as a string in fully qualified format (database_name.table_name).
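The shape required for the item_metadata table can be checked from its column definitions. The following Python sketch is a hypothetical helper: it assumes the column names and types have already been read (for example, the DATA_TYPE values in information_schema.COLUMNS, which are lowercase, hence the case-insensitive comparison), and naming the ID column item_id by default is an assumption taken from the description above.

```python
TEXT_TYPES = {"TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT"}


def check_item_metadata_table(columns, item_column="item_id"):
    """Check the documented shape of an item_metadata table.

    columns maps column name to MySQL data type. Returns the name of the
    description column, or raises ValueError if the shape does not match.
    """
    if len(columns) != 2:
        raise ValueError("the item_metadata table must have exactly two columns")
    if item_column not in columns:
        raise ValueError(f"missing item ID column {item_column!r}")
    description = next(name for name in columns if name != item_column)
    if columns[description].upper() not in TEXT_TYPES:
        raise ValueError("the description column must be TINYTEXT, TEXT, MEDIUMTEXT, or LONGTEXT")
    return description


print(check_item_metadata_table({"item_id": "varchar", "description": "text"}))  # → description
```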
- The following example defines a model handle, trains a recommendation model with no specified model type, and then queries the automatically selected model_type for the trained model from the model catalog.

mysql> SET @rec_model = 'rec_model';
mysql> CALL sys.ML_TRAIN('movielens_data.movielens_train', 'rating', JSON_OBJECT('task', 'recommendation', 'users', 'user_id', 'items', 'item_id'), @rec_model);
Query OK, 0 rows affected (14.4091 sec)

mysql> SELECT model_handle, model_type FROM ML_SCHEMA_admin.MODEL_CATALOG WHERE model_handle='rec_model';
+--------------+------------+
| model_handle | model_type |
+--------------+------------+
| rec_model    | SVDpp      |
+--------------+------------+
1 row in set (0.0395 sec)

- The following example specifies the SVD recommendation model type, and then queries the model_type for the trained model from the model catalog to confirm the selected model.

mysql> CALL sys.ML_TRAIN('movielens_data.movielens_train', 'rating', JSON_OBJECT('task', 'recommendation', 'users', 'user_id', 'items', 'item_id', 'model_list', JSON_ARRAY('SVD')), @rec_model);
Query OK, 0 rows affected (9.4139 sec)

mysql> SELECT model_type FROM ML_SCHEMA_admin.MODEL_CATALOG WHERE model_handle=@rec_model;
+------------+
| model_type |
+------------+
| SVD        |
+------------+
1 row in set (0.0485 sec)

- The following example specifies three models for the model_list option. From those three recommendation models, the SVDpp model is automatically selected for training.

mysql> CALL sys.ML_TRAIN('movielens_data.movielens_train', 'rating', JSON_OBJECT('task', 'recommendation', 'users', 'user_id', 'items', 'item_id', 'model_list', JSON_ARRAY('SVD', 'SVDpp', 'NMF')), @rec_model);
Query OK, 0 rows affected (13.8714 sec)

mysql> SELECT model_type FROM ML_SCHEMA_admin.MODEL_CATALOG WHERE model_handle=@rec_model;
+------------+
| model_type |
+------------+
| SVDpp      |
+------------+
1 row in set (0.0403 sec)

- The following example specifies the recommendation task with implicit feedback.

mysql> CALL sys.ML_TRAIN('mlcorpus.training_table', 'rating', JSON_OBJECT('task', 'recommendation', 'users', 'user_id', 'items', 'item_id', 'feedback', 'implicit'), @model);
Query OK, 0 rows affected (2 min 13.6415 sec)
- The following example trains a content-based recommendation model with the CTR model type by specifying a table with item descriptions (amazon_data.`amazon_item_descriptions`). The optimization metric hit_ratio_at_k is used. The model must use implicit feedback.

mysql> CALL sys.ML_TRAIN('amazon_data.`amazon_train`', 'rating', JSON_OBJECT('task', 'recommendation', 'model_list', JSON_ARRAY('CTR'), 'users', 'user_id', 'items', 'item_id', 'feedback', 'implicit', 'optimization_metric', 'hit_ratio_at_k', 'item_metadata', JSON_OBJECT('table_name', 'amazon_data.`amazon_item_descriptions`')), @rec_model);
To train a machine learning model with topic modeling, set the task to topic_modeling and set the following required parameter:
- document_column: Specifies the name of the column that contains the text to train on.

The following parameters are not supported for training machine learning models with topic modeling:
- model_list
- optimization_metric
- exclude_model_list
- exclude_column_list
- include_column_list

The following example runs the topic_modeling task with the required parameters defined.

mysql> CALL sys.ML_TRAIN('topic_modeling_data.text_types_train', NULL, JSON_OBJECT('task', 'topic_modeling', 'document_column', 'D0'), @topic_model);
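The per-task requirements and restrictions described in this section can be collected into a lookup table and checked before a call. The following Python sketch is a hypothetical client-side check, not an AutoML API: it encodes only the required and unsupported options this section states explicitly for forecasting, anomaly_detection, recommendation, and topic_modeling, and it does not check tasks or options that are not listed.

```python
# Options this section lists as unsupported, and as required, per task.
# Tasks not listed here are not checked.
UNSUPPORTED_OPTIONS = {
    "anomaly_detection": {"exclude_model_list", "optimization_metric"},
    "topic_modeling": {"model_list", "optimization_metric", "exclude_model_list",
                       "exclude_column_list", "include_column_list"},
}
REQUIRED_OPTIONS = {
    "forecasting": {"datetime_index", "endogenous_variables"},
    "recommendation": {"users", "items"},
    "topic_modeling": {"document_column"},
}


def check_task_options(options):
    """Return problems with an options dict (empty list if none are found)."""
    task = options.get("task") or "classification"  # a NULL task trains a classification model
    keys = set(options) - {"task"}
    problems = [f"{task} does not support {key}"
                for key in sorted(keys & UNSUPPORTED_OPTIONS.get(task, set()))]
    problems += [f"{task} requires {key}"
                 for key in sorted(REQUIRED_OPTIONS.get(task, set()) - keys)]
    return problems


print(check_task_options({"task": "topic_modeling", "document_column": "D0"}))  # → []
```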
The ML_TRAIN routine also runs the ML_EXPLAIN routine with the default Permutation Importance model for prediction explainers and model explainers. See Generate Model Explanations. To train other prediction explainers and model explainers, use the ML_EXPLAIN routine with the preferred explainer after ML_TRAIN.

ML_EXPLAIN does not support the anomaly_detection and recommendation tasks, and ML_TRAIN does not run ML_EXPLAIN for these tasks.
- The model_list option permits specifying the type of model to be trained. If more than one type of model is specified, the best model type is selected from the list. For a list of supported model types, see Model Types. This option cannot be used together with the exclude_model_list option.
  The following example trains either an XGBClassifier or an LGBMClassifier model.

mysql> CALL sys.ML_TRAIN('ml_data.iris_train', 'class', JSON_OBJECT('task','classification', 'model_list', JSON_ARRAY('XGBClassifier', 'LGBMClassifier')), @iris_model);
- The exclude_model_list option specifies types of models that should not be trained. Specified model types are excluded from consideration. For a list of model types you can specify, see Model Types. This option cannot be used together with the model_list option.
  The following example excludes the LogisticRegression and GaussianNB models.

mysql> CALL sys.ML_TRAIN('ml_data.iris_train', 'class', JSON_OBJECT('task','classification', 'exclude_model_list', JSON_ARRAY('LogisticRegression', 'GaussianNB')), @iris_model);

- The optimization_metric option specifies a scoring metric to optimize for. See Optimization and Scoring Metrics.
  The following example optimizes for the neg_log_loss metric.

mysql> CALL sys.ML_TRAIN('heatwaveml_bench.census_train', 'revenue', JSON_OBJECT('task','classification', 'optimization_metric', 'neg_log_loss'), @census_model);

- The exclude_column_list option specifies feature columns to exclude from consideration when training a model.
  The following example excludes the 'age' column from consideration when training a model for the census dataset.

mysql> CALL sys.ML_TRAIN('heatwaveml_bench.census_train', 'revenue', JSON_OBJECT('task','classification', 'exclude_column_list', JSON_ARRAY('age')), @census_model);

- The include_column_list option specifies feature columns that must be considered for training and should not be dropped.
  The following example specifies to consider the 'job' column when training a model for the census dataset.

mysql> CALL sys.ML_TRAIN('heatwaveml_bench.census_train', 'revenue', JSON_OBJECT('task','classification', 'include_column_list', JSON_ARRAY('job')), @census_model);
- The following example adds notes to the model_metadata. After training, query the model metadata to confirm the notes were successfully added. Optionally, use JSON_PRETTY to view the output in an easily readable format.

mysql> CALL sys.ML_TRAIN('bank_test.bank_train', 'y', JSON_OBJECT('task', 'classification', 'notes', 'bank marketing model'), @model);
Query OK, 0 rows affected (1 min 34.7958 sec)

mysql> SELECT JSON_PRETTY(model_metadata) FROM ML_SCHEMA_admin.MODEL_CATALOG WHERE model_handle=@model;
+-----------------------------+
| JSON_PRETTY(model_metadata) |
+-----------------------------+
| {
  "task": "classification",
  "notes": "bank marketing model",
  "chunks": 1,
  "format": "HWMLv2.0",
  "n_rows": 300,
  "status": "Ready",
  "options": {
    "task": "classification",
    "notes": "bank marketing model",
    "model_explainer": "permutation_importance",
    "prediction_explainer": "permutation_importance"
  },
  "n_columns": 16,
  "column_names": [
    "age",
    "job",
    "marital",
    "education",
    "default",
    "balance",
    "housing",
    "loan",
    "contact",
    "day",
    "month",
    "duration",
    "campaign",
    "pdays",
    "previous",
    "poutcome"
  ],
  "contamination": null,
  "model_quality": "high",
  "training_time": 70.57345581054688,
  "algorithm_name": "XGBClassifier",
  "training_score": -0.2614343762397766,
  "build_timestamp": 1746109286,
  "n_selected_rows": 240,
  "training_params": {
    "recommend": "ratings",
    "force_use_X": false,
    "recommend_k": 3,
    "remove_seen": true,
    "ranking_topk": 10,
    "lsa_components": 100,
    "ranking_threshold": 1,
    "feedback_threshold": 1
  },
  "train_table_name": "bank_test.bank_train",
  "model_explanation": {
    "permutation_importance": {
      "age": 0.0,
      "day": 0.0,
      "job": 0.0,
      "loan": 0.0,
      "month": 0.0,
      "pdays": 0.0,
      "balance": 0.0,
      "contact": 0.0175,
      "default": 0.0,
      "housing": 0.0,
      "marital": 0.0,
      "campaign": 0.0,
      "duration": 0.1524,
      "poutcome": 0.0278,
      "previous": 0.0,
      "education": 0.0
    }
  },
  "n_selected_columns": 3,
  "target_column_name": "y",
  "optimization_metric": "neg_log_loss",
  "selected_column_names": [
    "contact",
    "duration",
    "poutcome"
  ],
  "training_drift_metric": {
    "mean": 0.2961,
    "variance": 0.3273
  }
} |
+-----------------------------+
1 row in set (0.0403 sec)