This tutorial illustrates an end-to-end example of creating and using a predictive machine learning model with HeatWave AutoML. It steps through preparing data, using the ML_TRAIN routine to train a model, and using the ML_PREDICT_* and ML_EXPLAIN_* routines to generate predictions and explanations. The tutorial also demonstrates how to assess the quality of a model using the ML_SCORE routine, and how to view a model explanation to understand how the model works.
For an online workshop based on this tutorial, see Get started with MySQL HeatWave AutoML.
The tutorial uses the publicly available Iris Data Set from the UCI Machine Learning Repository.
Dua, D. and Graff, C. (2019). UCI Machine Learning Repository [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California, School of Information.
The Iris Data Set has the following data, where the sepal and petal features are used to predict the class label, which is the type of Iris plant:
- sepal length (cm)
- sepal width (cm)
- petal length (cm)
- petal width (cm)
- class. Possible values include:
  - Iris Setosa (stored as Iris-setosa)
  - Iris Versicolour (stored as Iris-versicolor)
  - Iris Virginica (stored as Iris-virginica)
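The feature names contain spaces and the labels are stored in hyphenated form, and both must match exactly when rows are passed to the AutoML routines. The following Python sketch models one row to make that explicit. It is illustrative only and not part of HeatWave; the IrisSample class and its features() method are hypothetical names. The features() dictionary has the same keys as the JSON_OBJECT row input used with ML_PREDICT_ROW later in this tutorial.

```python
from dataclasses import dataclass
from typing import Dict, Optional

# Feature names exactly as used for the table columns (note the spaces).
FEATURES = ("sepal length", "sepal width", "petal length", "petal width")

# Label values exactly as stored in the class column.
CLASSES = ("Iris-setosa", "Iris-versicolor", "Iris-virginica")


@dataclass(frozen=True)
class IrisSample:
    """One row of the Iris data (hypothetical helper, for illustration only)."""

    sepal_length: float
    sepal_width: float
    petal_length: float
    petal_width: float
    label: Optional[str] = None  # None models an unlabeled row

    def __post_init__(self) -> None:
        # Reject labels that would not match the stored class values.
        if self.label is not None and self.label not in CLASSES:
            raise ValueError(f"unknown class label: {self.label!r}")

    def features(self) -> Dict[str, float]:
        """Feature values keyed by column name."""
        values = (self.sepal_length, self.sepal_width,
                  self.petal_length, self.petal_width)
        return dict(zip(FEATURES, values))


sample = IrisSample(7.3, 2.9, 6.3, 1.8)
print(sample.features())
```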
Data is stored in the MySQL database in the following schema and tables:
- ml_data schema: The schema containing the training, test, and validation dataset tables.
- iris_train table: The training dataset (labeled). Includes feature columns (sepal length, sepal width, petal length, petal width) and a populated class target column with ground truth values.
- iris_test table: The test dataset. Includes feature columns (sepal length, sepal width, petal length, petal width). The statements below create it with a class column; before MySQL 8.0.32, drop that column so that the table is unlabeled.
- iris_validate table: The validation dataset (labeled). Includes feature columns (sepal length, sepal width, petal length, petal width) and a populated class target column with ground truth values.
Review the HeatWave Quickstart Requirements.
Create the example schema and tables on the MySQL DB System with the following statements:
mysql> CREATE SCHEMA ml_data;
USE ml_data;
CREATE TABLE `iris_train` (
`sepal length` float DEFAULT NULL,
`sepal width` float DEFAULT NULL,
`petal length` float DEFAULT NULL,
`petal width` float DEFAULT NULL,
`class` varchar(16) DEFAULT NULL);
INSERT INTO iris_train VALUES(6.4,2.8,5.6,2.2,'Iris-virginica');
INSERT INTO iris_train VALUES(5.0,2.3,3.3,1.0,'Iris-versicolor');
INSERT INTO iris_train VALUES(4.9,2.5,4.5,1.7,'Iris-virginica');
INSERT INTO iris_train VALUES(4.9,3.1,1.5,0.1,'Iris-setosa');
INSERT INTO iris_train VALUES(5.7,3.8,1.7,0.3,'Iris-setosa');
INSERT INTO iris_train VALUES(4.4,3.2,1.3,0.2,'Iris-setosa');
INSERT INTO iris_train VALUES(5.4,3.4,1.5,0.4,'Iris-setosa');
INSERT INTO iris_train VALUES(6.9,3.1,5.1,2.3,'Iris-virginica');
INSERT INTO iris_train VALUES(6.7,3.1,4.4,1.4,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.1,3.7,1.5,0.4,'Iris-setosa');
INSERT INTO iris_train VALUES(5.2,2.7,3.9,1.4,'Iris-versicolor');
INSERT INTO iris_train VALUES(6.9,3.1,4.9,1.5,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.8,4.0,1.2,0.2,'Iris-setosa');
INSERT INTO iris_train VALUES(5.4,3.9,1.7,0.4,'Iris-setosa');
INSERT INTO iris_train VALUES(7.7,3.8,6.7,2.2,'Iris-virginica');
INSERT INTO iris_train VALUES(6.3,3.3,4.7,1.6,'Iris-versicolor');
INSERT INTO iris_train VALUES(6.8,3.2,5.9,2.3,'Iris-virginica');
INSERT INTO iris_train VALUES(7.6,3.0,6.6,2.1,'Iris-virginica');
INSERT INTO iris_train VALUES(6.4,3.2,5.3,2.3,'Iris-virginica');
INSERT INTO iris_train VALUES(5.7,4.4,1.5,0.4,'Iris-setosa');
INSERT INTO iris_train VALUES(6.7,3.3,5.7,2.1,'Iris-virginica');
INSERT INTO iris_train VALUES(6.4,2.8,5.6,2.1,'Iris-virginica');
INSERT INTO iris_train VALUES(5.4,3.9,1.3,0.4,'Iris-setosa');
INSERT INTO iris_train VALUES(6.1,2.6,5.6,1.4,'Iris-virginica');
INSERT INTO iris_train VALUES(7.2,3.0,5.8,1.6,'Iris-virginica');
INSERT INTO iris_train VALUES(5.2,3.5,1.5,0.2,'Iris-setosa');
INSERT INTO iris_train VALUES(5.8,2.6,4.0,1.2,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.9,3.0,5.1,1.8,'Iris-virginica');
INSERT INTO iris_train VALUES(5.4,3.0,4.5,1.5,'Iris-versicolor');
INSERT INTO iris_train VALUES(6.7,3.0,5.0,1.7,'Iris-versicolor');
INSERT INTO iris_train VALUES(6.3,2.3,4.4,1.3,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.1,2.5,3.0,1.1,'Iris-versicolor');
INSERT INTO iris_train VALUES(6.4,3.2,4.5,1.5,'Iris-versicolor');
INSERT INTO iris_train VALUES(6.8,3.0,5.5,2.1,'Iris-virginica');
INSERT INTO iris_train VALUES(6.2,2.8,4.8,1.8,'Iris-virginica');
INSERT INTO iris_train VALUES(6.9,3.2,5.7,2.3,'Iris-virginica');
INSERT INTO iris_train VALUES(6.5,3.2,5.1,2.0,'Iris-virginica');
INSERT INTO iris_train VALUES(5.8,2.8,5.1,2.4,'Iris-virginica');
INSERT INTO iris_train VALUES(5.1,3.8,1.5,0.3,'Iris-setosa');
INSERT INTO iris_train VALUES(4.8,3.0,1.4,0.3,'Iris-setosa');
INSERT INTO iris_train VALUES(7.9,3.8,6.4,2.0,'Iris-virginica');
INSERT INTO iris_train VALUES(5.8,2.7,5.1,1.9,'Iris-virginica');
INSERT INTO iris_train VALUES(6.7,3.0,5.2,2.3,'Iris-virginica');
INSERT INTO iris_train VALUES(5.1,3.8,1.9,0.4,'Iris-setosa');
INSERT INTO iris_train VALUES(4.7,3.2,1.6,0.2,'Iris-setosa');
INSERT INTO iris_train VALUES(6.0,2.2,5.0,1.5,'Iris-virginica');
INSERT INTO iris_train VALUES(4.8,3.4,1.6,0.2,'Iris-setosa');
INSERT INTO iris_train VALUES(7.7,2.6,6.9,2.3,'Iris-virginica');
INSERT INTO iris_train VALUES(4.6,3.6,1.0,0.2,'Iris-setosa');
INSERT INTO iris_train VALUES(7.2,3.2,6.0,1.8,'Iris-virginica');
INSERT INTO iris_train VALUES(5.0,3.3,1.4,0.2,'Iris-setosa');
INSERT INTO iris_train VALUES(6.6,3.0,4.4,1.4,'Iris-versicolor');
INSERT INTO iris_train VALUES(6.1,2.8,4.0,1.3,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.0,3.2,1.2,0.2,'Iris-setosa');
INSERT INTO iris_train VALUES(7.0,3.2,4.7,1.4,'Iris-versicolor');
INSERT INTO iris_train VALUES(6.0,3.0,4.8,1.8,'Iris-virginica');
INSERT INTO iris_train VALUES(7.4,2.8,6.1,1.9,'Iris-virginica');
INSERT INTO iris_train VALUES(5.8,2.7,5.1,1.9,'Iris-virginica');
INSERT INTO iris_train VALUES(6.2,3.4,5.4,2.3,'Iris-virginica');
INSERT INTO iris_train VALUES(5.0,2.0,3.5,1.0,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.6,2.5,3.9,1.1,'Iris-versicolor');
INSERT INTO iris_train VALUES(6.7,3.1,5.6,2.4,'Iris-virginica');
INSERT INTO iris_train VALUES(6.3,2.5,5.0,1.9,'Iris-virginica');
INSERT INTO iris_train VALUES(6.4,3.1,5.5,1.8,'Iris-virginica');
INSERT INTO iris_train VALUES(6.2,2.2,4.5,1.5,'Iris-versicolor');
INSERT INTO iris_train VALUES(7.3,2.9,6.3,1.8,'Iris-virginica');
INSERT INTO iris_train VALUES(4.4,3.0,1.3,0.2,'Iris-setosa');
INSERT INTO iris_train VALUES(7.2,3.6,6.1,2.5,'Iris-virginica');
INSERT INTO iris_train VALUES(6.5,3.0,5.5,1.8,'Iris-virginica');
INSERT INTO iris_train VALUES(5.0,3.4,1.5,0.2,'Iris-setosa');
INSERT INTO iris_train VALUES(4.7,3.2,1.3,0.2,'Iris-setosa');
INSERT INTO iris_train VALUES(6.6,2.9,4.6,1.3,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.5,3.5,1.3,0.2,'Iris-setosa');
INSERT INTO iris_train VALUES(7.7,3.0,6.1,2.3,'Iris-virginica');
INSERT INTO iris_train VALUES(6.1,3.0,4.9,1.8,'Iris-virginica');
INSERT INTO iris_train VALUES(4.9,3.1,1.5,0.1,'Iris-setosa');
INSERT INTO iris_train VALUES(5.5,2.4,3.8,1.1,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.7,2.9,4.2,1.3,'Iris-versicolor');
INSERT INTO iris_train VALUES(6.0,2.9,4.5,1.5,'Iris-versicolor');
INSERT INTO iris_train VALUES(6.4,2.7,5.3,1.9,'Iris-virginica');
INSERT INTO iris_train VALUES(5.4,3.7,1.5,0.2,'Iris-setosa');
INSERT INTO iris_train VALUES(6.1,2.9,4.7,1.4,'Iris-versicolor');
INSERT INTO iris_train VALUES(6.5,2.8,4.6,1.5,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.6,2.7,4.2,1.3,'Iris-versicolor');
INSERT INTO iris_train VALUES(6.3,3.4,5.6,2.4,'Iris-virginica');
INSERT INTO iris_train VALUES(4.9,3.1,1.5,0.1,'Iris-setosa');
INSERT INTO iris_train VALUES(6.8,2.8,4.8,1.4,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.7,2.8,4.5,1.3,'Iris-versicolor');
INSERT INTO iris_train VALUES(6.0,2.7,5.1,1.6,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.0,3.5,1.3,0.3,'Iris-setosa');
INSERT INTO iris_train VALUES(6.5,3.0,5.2,2.0,'Iris-virginica');
INSERT INTO iris_train VALUES(6.1,2.8,4.7,1.2,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.1,3.5,1.4,0.3,'Iris-setosa');
INSERT INTO iris_train VALUES(4.6,3.1,1.5,0.2,'Iris-setosa');
INSERT INTO iris_train VALUES(6.5,3.0,5.8,2.2,'Iris-virginica');
INSERT INTO iris_train VALUES(4.6,3.4,1.4,0.3,'Iris-setosa');
INSERT INTO iris_train VALUES(4.6,3.2,1.4,0.2,'Iris-setosa');
INSERT INTO iris_train VALUES(7.7,2.8,6.7,2.0,'Iris-virginica');
INSERT INTO iris_train VALUES(5.9,3.2,4.8,1.8,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.1,3.8,1.6,0.2,'Iris-setosa');
INSERT INTO iris_train VALUES(4.9,3.0,1.4,0.2,'Iris-setosa');
INSERT INTO iris_train VALUES(4.9,2.4,3.3,1.0,'Iris-versicolor');
INSERT INTO iris_train VALUES(4.5,2.3,1.3,0.3,'Iris-setosa');
INSERT INTO iris_train VALUES(5.8,2.7,4.1,1.0,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.0,3.4,1.6,0.4,'Iris-setosa');
INSERT INTO iris_train VALUES(5.2,3.4,1.4,0.2,'Iris-setosa');
INSERT INTO iris_train VALUES(5.3,3.7,1.5,0.2,'Iris-setosa');
INSERT INTO iris_train VALUES(5.0,3.6,1.4,0.2,'Iris-setosa');
INSERT INTO iris_train VALUES(5.6,2.9,3.6,1.3,'Iris-versicolor');
INSERT INTO iris_train VALUES(4.8,3.1,1.6,0.2,'Iris-setosa');
INSERT INTO iris_train VALUES(6.3,2.7,4.9,1.8,'Iris-virginica');
INSERT INTO iris_train VALUES(5.7,2.8,4.1,1.3,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.0,3.0,1.6,0.2,'Iris-setosa');
INSERT INTO iris_train VALUES(6.3,3.3,6.0,2.5,'Iris-virginica');
INSERT INTO iris_train VALUES(5.0,3.5,1.6,0.6,'Iris-setosa');
INSERT INTO iris_train VALUES(5.5,2.6,4.4,1.2,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.7,3.0,4.2,1.2,'Iris-versicolor');
INSERT INTO iris_train VALUES(4.4,2.9,1.4,0.2,'Iris-setosa');
INSERT INTO iris_train VALUES(4.8,3.0,1.4,0.1,'Iris-setosa');
INSERT INTO iris_train VALUES(5.5,2.4,3.7,1.0,'Iris-versicolor');
CREATE TABLE `iris_test` LIKE `iris_train`;
INSERT INTO iris_test VALUES(5.9,3.0,4.2,1.5,'Iris-versicolor');
INSERT INTO iris_test VALUES(6.9,3.1,5.4,2.1,'Iris-virginica');
INSERT INTO iris_test VALUES(5.1,3.3,1.7,0.5,'Iris-setosa');
INSERT INTO iris_test VALUES(6.0,3.4,4.5,1.6,'Iris-versicolor');
INSERT INTO iris_test VALUES(5.5,2.5,4.0,1.3,'Iris-versicolor');
INSERT INTO iris_test VALUES(6.2,2.9,4.3,1.3,'Iris-versicolor');
INSERT INTO iris_test VALUES(5.5,4.2,1.4,0.2,'Iris-setosa');
INSERT INTO iris_test VALUES(6.3,2.8,5.1,1.5,'Iris-virginica');
INSERT INTO iris_test VALUES(5.6,3.0,4.1,1.3,'Iris-versicolor');
INSERT INTO iris_test VALUES(6.7,2.5,5.8,1.8,'Iris-virginica');
INSERT INTO iris_test VALUES(7.1,3.0,5.9,2.1,'Iris-virginica');
INSERT INTO iris_test VALUES(4.3,3.0,1.1,0.1,'Iris-setosa');
INSERT INTO iris_test VALUES(5.6,2.8,4.9,2.0,'Iris-virginica');
INSERT INTO iris_test VALUES(5.5,2.3,4.0,1.3,'Iris-versicolor');
INSERT INTO iris_test VALUES(6.0,2.2,4.0,1.0,'Iris-versicolor');
INSERT INTO iris_test VALUES(5.1,3.5,1.4,0.2,'Iris-setosa');
INSERT INTO iris_test VALUES(5.7,2.6,3.5,1.0,'Iris-versicolor');
INSERT INTO iris_test VALUES(4.8,3.4,1.9,0.2,'Iris-setosa');
INSERT INTO iris_test VALUES(5.1,3.4,1.5,0.2,'Iris-setosa');
INSERT INTO iris_test VALUES(5.7,2.5,5.0,2.0,'Iris-virginica');
INSERT INTO iris_test VALUES(5.4,3.4,1.7,0.2,'Iris-setosa');
INSERT INTO iris_test VALUES(5.6,3.0,4.5,1.5,'Iris-versicolor');
INSERT INTO iris_test VALUES(6.3,2.9,5.6,1.8,'Iris-virginica');
INSERT INTO iris_test VALUES(6.3,2.5,4.9,1.5,'Iris-versicolor');
INSERT INTO iris_test VALUES(5.8,2.7,3.9,1.2,'Iris-versicolor');
INSERT INTO iris_test VALUES(6.1,3.0,4.6,1.4,'Iris-versicolor');
INSERT INTO iris_test VALUES(5.2,4.1,1.5,0.1,'Iris-setosa');
INSERT INTO iris_test VALUES(6.7,3.1,4.7,1.5,'Iris-versicolor');
INSERT INTO iris_test VALUES(6.7,3.3,5.7,2.5,'Iris-virginica');
INSERT INTO iris_test VALUES(6.4,2.9,4.3,1.3,'Iris-versicolor');
CREATE TABLE `iris_validate` LIKE `iris_test`;
INSERT INTO `iris_validate` SELECT * FROM `iris_test`;
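After running the script, checking how many rows of each class each table received can catch copy-and-paste mistakes. The following minimal Python sketch assumes the statements above have been saved as text; label_counts is a hypothetical helper, and its regular expression only recognizes the exact single-row INSERT form used in this script. On the server side, SELECT class, COUNT(*) FROM ml_data.iris_train GROUP BY class returns the same information for one table.

```python
import re
from collections import Counter

# Matches statements of the form used in the script above, for example:
#   INSERT INTO iris_train VALUES(6.4,2.8,5.6,2.2,'Iris-virginica');
INSERT_RE = re.compile(
    r"INSERT INTO `?(\w+)`? VALUES\("
    r"([\d.]+),([\d.]+),([\d.]+),([\d.]+),'([\w-]+)'\);"
)


def label_counts(script: str) -> dict:
    """Hypothetical helper: {table name: Counter of class labels}."""
    counts = {}
    for match in INSERT_RE.finditer(script):
        table, label = match.group(1), match.group(6)
        counts.setdefault(table, Counter())[label] += 1
    return counts


script = """
INSERT INTO iris_train VALUES(6.4,2.8,5.6,2.2,'Iris-virginica');
INSERT INTO iris_train VALUES(6.9,3.1,5.1,2.3,'Iris-virginica');
INSERT INTO iris_test VALUES(6.9,3.1,5.4,2.1,'Iris-virginica');
"""
for table, counter in label_counts(script).items():
    print(table, dict(counter))
```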
Before MySQL 8.0.32, drop the class column from iris_test:
mysql> ALTER TABLE `iris_test` DROP COLUMN `class`;
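This tutorial contains several version notes (8.0.31, 8.0.32, and 8.2.0) that change how a step is performed. The following Python sketch summarizes them as a lookup from a server version string to the steps that differ from the default flow. It assumes that the string returned by SELECT VERSION() begins with major.minor.patch and ignores anything after the first hyphen; parse_version and version_specific_steps are hypothetical names.

```python
def parse_version(text: str) -> tuple:
    """'8.0.32-u1-cloud' -> (8, 0, 32); missing parts default to 0.

    Assumes the version string starts with major.minor.patch.
    """
    core = text.strip().split("-", 1)[0]
    parts = [int(p) for p in core.split(".")[:3]]
    parts += [0] * (3 - len(parts))
    return tuple(parts)


def version_specific_steps(version: str) -> list:
    """Tutorial steps that differ from the default flow for this version."""
    v = parse_version(version)
    steps = []
    if v < (8, 0, 32):
        steps.append("drop the class column from iris_test")
        steps.append("call ML_PREDICT_ROW and ML_PREDICT_TABLE without options")
    if v == (8, 0, 31):
        steps.append("run ML_EXPLAIN after ML_TRAIN, passing NULL options")
    if v < (8, 2, 0):
        steps.append("call ML_SCORE without the options argument")
    return steps


for version in ("8.0.31", "8.0.32-u1-cloud", "8.2.0"):
    print(version, version_specific_steps(version))
```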
Train the model with ML_TRAIN. Because this is a classification dataset, use the classification task to create a classification model:
mysql> CALL sys.ML_TRAIN('ml_data.iris_train', 'class',
JSON_OBJECT('task', 'classification'), @iris_model);
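When these steps are scripted from an application instead of the mysql client, the same call can be issued through a database driver. The following sketch assumes a PEP 249 (DB-API) cursor whose driver uses the %s parameter style, such as MySQL Connector/Python. The helper train_classifier is hypothetical, and passing the options as a JSON string wrapped in CAST(... AS JSON) is an assumed stand-in for JSON_OBJECT. The recording cursor at the end only stores the SQL it receives, so the sketch runs without a database.

```python
import json


def train_classifier(cursor, train_table: str, target_column: str) -> str:
    """Run ML_TRAIN for a classification task and return the model handle.

    Assumes a DB-API cursor using the %s paramstyle (hypothetical helper).
    """
    # Intended to match JSON_OBJECT('task', 'classification') in the SQL above.
    options = json.dumps({"task": "classification"})
    cursor.execute(
        "CALL sys.ML_TRAIN(%s, %s, CAST(%s AS JSON), @iris_model)",
        (train_table, target_column, options),
    )
    # The handle is returned through the @iris_model session variable.
    cursor.execute("SELECT @iris_model")
    (handle,) = cursor.fetchone()
    return handle


class RecordingCursor:
    """Stand-in for a real cursor: records SQL and returns a canned handle."""

    def __init__(self, handle: str):
        self.statements = []
        self._handle = handle

    def execute(self, sql, params=None):
        self.statements.append((sql, params))

    def fetchone(self):
        return (self._handle,)


cursor = RecordingCursor("ml_data.iris_train_user1_1648140791")
print(train_classifier(cursor, "ml_data.iris_train", "class"))
```

Because session variables belong to a connection, the SELECT @iris_model query must run on the same connection as the CALL.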
When the training operation finishes, the model handle is assigned to the @iris_model session variable, and the model is stored in the model catalog. View the entry in the model catalog with the following query. Replace user1 with your MySQL account name:
mysql> SELECT model_id, model_handle, train_table_name FROM ML_SCHEMA_user1.MODEL_CATALOG;
+----------+---------------------------------------+--------------------+
| model_id | model_handle | train_table_name |
+----------+---------------------------------------+--------------------+
| 1 | ml_data.iris_train_user1_1648140791 | ml_data.iris_train |
+----------+---------------------------------------+--------------------+
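The catalog schema name is derived from the account name, so the query can be generated for any user. The following Python sketch builds that query and picks the newest handle for a training table from the returned rows. It assumes, following the example above, that the schema name is ML_SCHEMA_ followed by the account name, and that a higher model_id means a more recently trained model. Neither assumption is stated in this tutorial. catalog_query and latest_handle are hypothetical helpers.

```python
def catalog_query(account_name: str) -> str:
    """Build the model catalog query for a MySQL account name.

    Assumes the schema name is ML_SCHEMA_ followed by the account name; the
    identifier is backtick-quoted so names with characters such as '-' still
    form a valid query.
    """
    schema = "ML_SCHEMA_" + account_name
    quoted = "`" + schema.replace("`", "``") + "`"
    return ("SELECT model_id, model_handle, train_table_name "
            f"FROM {quoted}.MODEL_CATALOG")


def latest_handle(rows, train_table: str):
    """Pick the handle with the highest model_id for one training table.

    rows are (model_id, model_handle, train_table_name) tuples. Treating a
    higher model_id as newer is an assumption.
    """
    matching = [row for row in rows if row[2] == train_table]
    if not matching:
        return None
    return max(matching, key=lambda row: row[0])[1]


print(catalog_query("user1"))
rows = [(1, "ml_data.iris_train_user1_1648140791", "ml_data.iris_train")]
print(latest_handle(rows, "ml_data.iris_train"))
```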
In MySQL 8.0.31, ML_TRAIN does not run the ML_EXPLAIN routine with the default Permutation Importance model explainer. For MySQL 8.0.31, run ML_EXPLAIN explicitly after ML_TRAIN, using NULL for the options:
mysql> CALL sys.ML_EXPLAIN('ml_data.iris_train', 'class',
'ml_data.iris_train_user1_1648140791', NULL);
Load the model into HeatWave AutoML using the ML_MODEL_LOAD routine:
mysql> CALL sys.ML_MODEL_LOAD(@iris_model, NULL);
A model must be loaded before you can use it. The model remains loaded until you unload it or the HeatWave Cluster is restarted.
You can make predictions on a single row of data or on a table of data.
- Make a prediction for a single row of data using the ML_PREDICT_ROW routine. In this example, data is assigned to a @row_input session variable, and the variable is passed to the routine. The model handle is passed using the @iris_model session variable:
mysql> SET @row_input = JSON_OBJECT(
         "sepal length", 7.3,
         "sepal width", 2.9,
         "petal length", 6.3,
         "petal width", 1.8);
mysql> SELECT sys.ML_PREDICT_ROW(@row_input, @iris_model, NULL);
+------------------------------------------------------------------------------+
| sys.ML_PREDICT_ROW('{"sepal length": 7.3, "sepal width": 2.9,                |
| "petal length": 6.3, "petal width": 1.8}', @iris_model, NULL)                |
+------------------------------------------------------------------------------+
| {"Prediction": "Iris-virginica", "ml_results": "{'predictions': {'class':    |
| 'Iris-virginica'}, 'probabilities': {'Iris-setosa': 0.0, 'Iris-versicolor':  |
| 0.13, 'Iris-virginica': 0.87}}", "petal width": 1.8, "sepal width": 2.9,     |
| "petal length": 6.3, "sepal length": 7.3}                                    |
+------------------------------------------------------------------------------+
1 row in set (1.12 sec)
Before MySQL 8.0.32, the ML_PREDICT_ROW routine does not include options, and the results do not include the ml_results field:
mysql> SELECT sys.ML_PREDICT_ROW(@row_input, @iris_model);
+------------------------------------------------------------------------------+
| sys.ML_PREDICT_ROW(@row_input, @iris_model)                                  |
+------------------------------------------------------------------------------+
| {"Prediction": "Iris-virginica", "petal width": 1.8, "sepal width": 2.9,     |
| "petal length": 6.3, "sepal length": 7.3}                                    |
+------------------------------------------------------------------------------+
Based on the feature inputs provided, the model predicts that the Iris plant is of the class Iris-virginica. The feature values used to make the prediction are also shown.
- Make predictions for a table of data using the ML_PREDICT_TABLE routine. The routine takes data from the iris_test table as input and writes the predictions to an iris_predictions output table:
mysql> CALL sys.ML_PREDICT_TABLE('ml_data.iris_test', @iris_model, 'ml_data.iris_predictions', NULL);
To view ML_PREDICT_TABLE results, query the output table; for example:
mysql> SELECT * FROM ml_data.iris_predictions LIMIT 5;
+-----+--------------+-------------+--------------+-------------+-----------------+-----------------+---------------------------------------------------------------------------------------------------------------------------------------+
| _id | sepal length | sepal width | petal length | petal width | class           | Prediction      | ml_results                                                                                                                            |
+-----+--------------+-------------+--------------+-------------+-----------------+-----------------+---------------------------------------------------------------------------------------------------------------------------------------+
|   1 |          7.3 |         2.9 |          6.3 |         1.8 | Iris-virginica  | Iris-virginica  | {'predictions': {'class': 'Iris-virginica'}, 'probabilities': {'Iris-setosa': 0.0, 'Iris-versicolor': 0.13, 'Iris-virginica': 0.87}}  |
|   2 |          6.1 |         2.9 |          4.7 |         1.4 | Iris-versicolor | Iris-versicolor | {'predictions': {'class': 'Iris-versicolor'}, 'probabilities': {'Iris-setosa': 0.0, 'Iris-versicolor': 1.0, 'Iris-virginica': 0.0}}   |
|   3 |          6.3 |         2.8 |          5.1 |         1.5 | Iris-virginica  | Iris-versicolor | {'predictions': {'class': 'Iris-versicolor'}, 'probabilities': {'Iris-setosa': 0.0, 'Iris-versicolor': 0.6, 'Iris-virginica': 0.4}}   |
|   4 |          6.3 |         3.3 |          4.7 |         1.6 | Iris-versicolor | Iris-versicolor | {'predictions': {'class': 'Iris-versicolor'}, 'probabilities': {'Iris-setosa': 0.0, 'Iris-versicolor': 0.99, 'Iris-virginica': 0.01}} |
|   5 |          6.1 |           3 |          4.9 |         1.8 | Iris-virginica  | Iris-virginica  | {'predictions': {'class': 'Iris-virginica'}, 'probabilities': {'Iris-setosa': 0.0, 'Iris-versicolor': 0.32, 'Iris-virginica': 0.68}}  |
+-----+--------------+-------------+--------------+-------------+-----------------+-----------------+---------------------------------------------------------------------------------------------------------------------------------------+
5 rows in set (0.00 sec)
Before MySQL 8.0.32, the ML_PREDICT_TABLE routine does not include options, and the results do not include the ml_results column:
mysql> CALL sys.ML_PREDICT_TABLE('ml_data.iris_test', @iris_model, 'ml_data.iris_predictions');
mysql> SELECT * FROM ml_data.iris_predictions LIMIT 3\G
*************************** 1. row ***************************
sepal length: 7.3
 sepal width: 2.9
petal length: 6.3
 petal width: 1.8
  Prediction: Iris-virginica
*************************** 2. row ***************************
sepal length: 6.1
 sepal width: 2.9
petal length: 4.7
 petal width: 1.4
  Prediction: Iris-versicolor
*************************** 3. row ***************************
sepal length: 6.3
 sepal width: 2.8
petal length: 5.1
 petal width: 1.5
  Prediction: Iris-virginica
The table shows the predictions and the feature column values used to make each prediction.
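The JSON returned by ML_PREDICT_ROW can be post-processed in an application. In the output shown above, the ml_results value is a string written with Python-style single quotes, so it is not itself valid JSON. The following sketch therefore parses it with Python's ast.literal_eval, which evaluates literals only. top_prediction is a hypothetical helper, and the sample input is the 8.0.32 row output above, shortened to the fields the function reads. For results from releases before 8.0.32, which lack ml_results, it falls back to the Prediction field.

```python
import ast
import json
from typing import Optional, Tuple


def top_prediction(result_json: str) -> Tuple[str, Optional[float]]:
    """Hypothetical helper: (predicted class, probability) from ML_PREDICT_ROW."""
    result = json.loads(result_json)
    if "ml_results" not in result:  # output from releases before 8.0.32
        return result["Prediction"], None
    # ml_results uses single quotes, so parse it as a Python literal.
    ml_results = ast.literal_eval(result["ml_results"])
    probabilities = ml_results["probabilities"]
    label = max(probabilities, key=probabilities.get)
    return label, probabilities[label]


output = (
    '{"Prediction": "Iris-virginica", "ml_results": '
    "\"{'predictions': {'class': 'Iris-virginica'}, "
    "'probabilities': {'Iris-setosa': 0.0, "
    "'Iris-versicolor': 0.13, 'Iris-virginica': 0.87}}\"}"
)
print(top_prediction(output))
```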
After creating predictions (either on a single row of data or on a table of data), you can generate explanations to understand how the predictions were made and review which features had the most influence on them.
- Generate an explanation for a prediction made on a row of data using the ML_EXPLAIN_ROW routine with the Permutation Importance prediction explainer:
mysql> SELECT sys.ML_EXPLAIN_ROW(JSON_OBJECT("sepal length", 7.3, "sepal width", 2.9, "petal length", 6.3, "petal width", 1.8), @iris_model, JSON_OBJECT('prediction_explainer', 'permutation_importance'));
+------------------------------------------------------------------------------+
| sys.ML_EXPLAIN_ROW(JSON_OBJECT("sepal length", 7.3, "sepal width", 2.9,      |
| "petal length", 6.3, "petal width", 1.8), @iris_model,                       |
| JSON_OBJECT('prediction_explainer', 'permutation_importance'))               |
+------------------------------------------------------------------------------+
| {"Notes": "petal width (1.8) had the largest impact towards predicting       |
| Iris-virginica", "Prediction": "Iris-virginica", "ml_results":               |
| "{'attributions': {'petal length': 0.57, 'petal width': 0.73},               |
| 'predictions': {'class': 'Iris-virginica'}, 'notes': 'petal width (1.8)      |
| had the largest impact towards predicting Iris-virginica'}",                 |
| "petal width": 1.8, "sepal width": 2.9, "petal length": 6.3,                 |
| "sepal length": 7.3, "petal width_attribution": 0.73,                        |
| "petal length_attribution": 0.57}                                            |
+------------------------------------------------------------------------------+
1 row in set (5.92 sec)
Before MySQL 8.0.32, the results do not include the ml_results field:
mysql> SELECT sys.ML_EXPLAIN_ROW(JSON_OBJECT("sepal length", 7.3, "sepal width", 2.9, "petal length", 6.3, "petal width", 1.8), @iris_model);
+------------------------------------------------------------------------------+
| sys.ML_EXPLAIN_ROW(JSON_OBJECT("sepal length", 7.3, "sepal width", 2.9,      |
| "petal length", 6.3, "petal width", 1.8), @iris_model)                       |
+------------------------------------------------------------------------------+
| {"Prediction": "Iris-virginica", "petal width": 1.8, "sepal width": 2.9,     |
| "petal length": 6.3, "sepal length": 7.3, "petal width_attribution": 0.73,   |
| "petal length_attribution": 0.57}                                            |
+------------------------------------------------------------------------------+
The attribution values show which features contributed most to the prediction: petal width (0.73) had the largest impact, followed by petal length (0.57). The sepal features are not listed because their attribution is 0, indicating that they did not contribute to the prediction.
- Generate explanations for predictions made for a table of data using the ML_EXPLAIN_TABLE routine with the Permutation Importance prediction explainer. Feature importance is presented as an attribution value ranging from -1 to 1. A positive value indicates that a feature contributed toward the prediction. A negative value indicates that the feature contributed toward one of the other possible predictions.
mysql> CALL sys.ML_EXPLAIN_TABLE('ml_data.iris_test', @iris_model, 'ml_data.iris_explanations', JSON_OBJECT('prediction_explainer', 'permutation_importance'));
To view ML_EXPLAIN_TABLE results, query the output table; for example:
mysql> SELECT * FROM ml_data.iris_explanations LIMIT 5\G
*************************** 1. row ***************************
                     _id: 1
            sepal length: 7.3
             sepal width: 2.9
            petal length: 6.3
             petal width: 1.8
                   class: Iris-virginica
              Prediction: Iris-virginica
                   Notes: petal width (1.7999999523162842) had the largest impact towards predicting Iris-virginica
petal length_attribution: 0.57
 petal width_attribution: 0.73
              ml_results: {'attributions': {'petal length': 0.57, 'petal width': 0.73}, 'predictions': {'class': 'Iris-virginica'}, 'notes': 'petal width (1.7999999523162842) had the largest impact towards predicting Iris-virginica'}
*************************** 2. row ***************************
                     _id: 2
            sepal length: 6.1
             sepal width: 2.9
            petal length: 4.7
             petal width: 1.4
                   class: Iris-versicolor
              Prediction: Iris-versicolor
                   Notes: petal width (1.399999976158142) had the largest impact towards predicting Iris-versicolor
petal length_attribution: 0.14
 petal width_attribution: 0.6
              ml_results: {'attributions': {'petal length': 0.14, 'petal width': 0.6}, 'predictions': {'class': 'Iris-versicolor'}, 'notes': 'petal width (1.399999976158142) had the largest impact towards predicting Iris-versicolor'}
*************************** 3. row ***************************
                     _id: 3
            sepal length: 6.3
             sepal width: 2.8
            petal length: 5.1
             petal width: 1.5
                   class: Iris-virginica
              Prediction: Iris-versicolor
                   Notes: petal width (1.5) had the largest impact towards predicting Iris-versicolor, whereas petal length (5.099999904632568) contributed the most against predicting Iris-versicolor
petal length_attribution: -0.25
 petal width_attribution: 0.31
              ml_results: {'attributions': {'petal length': -0.25, 'petal width': 0.31}, 'predictions': {'class': 'Iris-versicolor'}, 'notes': 'petal width (1.5) had the largest impact towards predicting Iris-versicolor, whereas petal length (5.099999904632568) contributed the most against predicting Iris-versicolor'}
*************************** 4. row ***************************
                     _id: 4
            sepal length: 6.3
             sepal width: 3.3
            petal length: 4.7
             petal width: 1.6
                   class: Iris-versicolor
              Prediction: Iris-versicolor
                   Notes: petal width (1.600000023841858) had the largest impact towards predicting Iris-versicolor
petal length_attribution: 0.14
 petal width_attribution: 0.58
              ml_results: {'attributions': {'petal length': 0.14, 'petal width': 0.58}, 'predictions': {'class': 'Iris-versicolor'}, 'notes': 'petal width (1.600000023841858) had the largest impact towards predicting Iris-versicolor'}
*************************** 5. row ***************************
                     _id: 5
            sepal length: 6.1
             sepal width: 3
            petal length: 4.9
             petal width: 1.8
                   class: Iris-virginica
              Prediction: Iris-virginica
                   Notes: petal width (1.7999999523162842) had the largest impact towards predicting Iris-virginica
petal length_attribution: 0.38
 petal width_attribution: 0.61
              ml_results: {'attributions': {'petal length': 0.38, 'petal width': 0.61}, 'predictions': {'class': 'Iris-virginica'}, 'notes': 'petal width (1.7999999523162842) had the largest impact towards predicting Iris-virginica'}
5 rows in set (0.00 sec)
Before MySQL 8.0.32, the output table does not include the ml_results column:
mysql> SELECT * FROM ml_data.iris_explanations LIMIT 3\G
*************************** 1. row ***************************
            sepal length: 7.3
             sepal width: 2.9
            petal length: 6.3
             petal width: 1.8
              Prediction: Iris-virginica
petal length_attribution: 0.57
 petal width_attribution: 0.73
*************************** 2. row ***************************
            sepal length: 6.1
             sepal width: 2.9
            petal length: 4.7
             petal width: 1.4
              Prediction: Iris-versicolor
petal length_attribution: 0.14
 petal width_attribution: 0.6
*************************** 3. row ***************************
            sepal length: 6.3
             sepal width: 2.8
            petal length: 5.1
             petal width: 1.5
              Prediction: Iris-virginica
petal length_attribution: -0.25
 petal width_attribution: 0.31
3 rows in set (0.0006 sec)
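The attribution sign convention described for ML_EXPLAIN_TABLE can also be applied in application code: a positive value pushes toward the predicted class, and a negative value pushes toward another class. The following minimal Python sketch ranks the attributions in one ml_results value by magnitude and labels their direction. rank_attributions is a hypothetical helper. The input is row 3 of the explanations output above with the notes field omitted, and it is parsed with ast.literal_eval because the value uses single quotes.

```python
import ast


def rank_attributions(ml_results: str):
    """Hypothetical helper: (feature, attribution, effect), largest |value| first."""
    parsed = ast.literal_eval(ml_results)
    predicted = parsed["predictions"]["class"]
    ranked = sorted(parsed["attributions"].items(),
                    key=lambda item: abs(item[1]), reverse=True)
    result = []
    for feature, value in ranked:
        if value > 0:
            effect = f"towards {predicted}"
        elif value < 0:
            effect = f"against {predicted}"
        else:
            effect = "no contribution"
        result.append((feature, value, effect))
    return result


row3 = ("{'attributions': {'petal length': -0.25, 'petal width': 0.31}, "
        "'predictions': {'class': 'Iris-versicolor'}}")
for feature, value, effect in rank_attributions(row3):
    print(f"{feature}: {value:+.2f} ({effect})")
```

The ranking reproduces the Notes text for that row: petal width had the largest impact towards Iris-versicolor, and petal length contributed against it.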
Score the model with ML_SCORE to assess the reliability of the model. This example uses the balanced_accuracy metric, which is one of the many scoring metrics that HeatWave AutoML supports:
mysql> CALL sys.ML_SCORE('ml_data.iris_validate', 'class', @iris_model, 'balanced_accuracy',
@score, NULL);
Before MySQL 8.2.0, ML_SCORE has no options parameter, so omit the trailing NULL argument:
mysql> CALL sys.ML_SCORE('ml_data.iris_validate', 'class', @iris_model,
'balanced_accuracy', @score);
To retrieve the computed score, query the @score session variable:
mysql> SELECT @score;
+--------------------+
| @score |
+--------------------+
| 0.9583333134651184 |
+--------------------+
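Balanced accuracy is commonly defined as the average, over all classes, of per-class recall (the fraction of each class's rows that the model labels correctly). This keeps a majority class from dominating the score. The following Python sketch implements that standard definition. This tutorial does not confirm that ML_SCORE computes balanced_accuracy exactly this way. The sample input is the class and Prediction columns of the five iris_predictions rows shown earlier, not the iris_validate table that ML_SCORE scored.

```python
from collections import defaultdict


def balanced_accuracy(y_true, y_pred) -> float:
    """Mean per-class recall over the classes present in y_true.

    Standard definition; assumed, not confirmed, to match ML_SCORE.
    """
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if not y_true:
        raise ValueError("at least one row is required")
    total = defaultdict(int)    # rows per actual class
    correct = defaultdict(int)  # correctly predicted rows per actual class
    for actual, predicted in zip(y_true, y_pred):
        total[actual] += 1
        if actual == predicted:
            correct[actual] += 1
    return sum(correct[c] / total[c] for c in total) / len(total)


actual = ["Iris-virginica", "Iris-versicolor", "Iris-virginica",
          "Iris-versicolor", "Iris-virginica"]
predicted = ["Iris-virginica", "Iris-versicolor", "Iris-versicolor",
             "Iris-versicolor", "Iris-virginica"]
print(round(balanced_accuracy(actual, predicted), 4))  # 0.8333
```

On these five rows, plain accuracy is 4/5 = 0.8, while balanced accuracy is (2/3 + 2/2) / 2, about 0.833. Each class contributes equally, and the single error lowers only the Iris-virginica recall.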
Unload the model using ML_MODEL_UNLOAD:
mysql> CALL sys.ML_MODEL_UNLOAD(@iris_model);
To avoid consuming too much space, it is good practice to unload a model when you are finished using it.
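In application code, a context manager is one way to make sure that a model loaded with ML_MODEL_LOAD is always unloaded again, even when a step in between raises an error. The following minimal Python sketch assumes a PEP 249 cursor with the %s parameter style (for example, MySQL Connector/Python). loaded_model is a hypothetical helper, and the recording cursor only stands in for a real connection.

```python
from contextlib import contextmanager


@contextmanager
def loaded_model(cursor, model_handle: str):
    """Load a model for the duration of a with-block, then unload it.

    Hypothetical helper; assumes a DB-API cursor using the %s paramstyle.
    """
    cursor.execute("CALL sys.ML_MODEL_LOAD(%s, NULL)", (model_handle,))
    try:
        yield model_handle
    finally:
        # Runs on normal exit and when the with-block raises.
        cursor.execute("CALL sys.ML_MODEL_UNLOAD(%s)", (model_handle,))


class RecordingCursor:
    """Stand-in for a real cursor that only records the SQL it receives."""

    def __init__(self):
        self.statements = []

    def execute(self, sql, params=None):
        self.statements.append((sql, params))


cursor = RecordingCursor()
with loaded_model(cursor, "ml_data.iris_train_user1_1648140791") as handle:
    cursor.execute("SELECT sys.ML_PREDICT_ROW(@row_input, %s, NULL)", (handle,))
for sql, params in cursor.statements:
    print(sql, params)
```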