6.5.4.1 Generate Predictions for a Row of Data

ML_PREDICT_ROW generates predictions for one or more rows of data specified in JSON format. You invoke the routine with the SELECT statement.

Prepare to Generate a Row Prediction

Before running ML_PREDICT_ROW, you must train and then load the model you want to use.

  1. The following example trains a model on the dataset using the classification machine learning task.

    mysql> CALL sys.ML_TRAIN('census_data.census_train', 'revenue', JSON_OBJECT('task', 'classification'), @census_model);
  2. The following example loads the trained model.

    mysql> CALL sys.ML_MODEL_LOAD(@census_model, NULL);

For more information about training and loading models, see Train a Model and Load a Model.

After training and loading the model, you can generate predictions on one or more rows of data. For parameter and option descriptions, see ML_PREDICT_ROW.

Input Data to Generate a Row Prediction

One way to generate predictions on row data is to manually enter the row data into a session variable, and then generate a prediction by specifying the session variable.

  1. Define values for each column to predict. The column names must match the feature column names in the training dataset.

    mysql> SET @variable = JSON_OBJECT("column_name", value, "column_name", value, ...);

    The following example creates the @row_input session variable and stores the row of data to predict in it.

    mysql> SET @row_input = JSON_OBJECT( 
              "age", 25, 
              "workclass", "Private", 
              "fnlwgt", 226802, 
              "education", "11th", 
              "education-num", 7, 
              "marital-status", "Never-married", 
              "occupation", "Machine-op-inspct", 
              "relationship", "Own-child", 
              "race", "Black", 
              "sex", "Male", 
              "capital-gain", 0, 
              "capital-loss", 0, 
              "hours-per-week", 40, 
              "native-country", "United-States");
  2. Run ML_PREDICT_ROW and specify the session variable set previously. Optionally, use \G to display information in an easily readable format.

    mysql> SELECT sys.ML_PREDICT_ROW(@variable, model_handle, options);

    Replace variable, model_handle, and options with your own values. For example:

    mysql> SELECT sys.ML_PREDICT_ROW(@row_input, @census_model, NULL)\G
    *************************** 1. row ***************************
    sys.ML_PREDICT_ROW(@row_input, @census_model, NULL): 
    {
        "age": 25,
        "sex": "Male",
        "race": "Black",
        "fnlwgt": 226802,
        "education": "11th",
        "workclass": "Private",
        "Prediction": "<=50K",
        "ml_results": {
            "predictions": {
                "revenue": "<=50K"
            },
            "probabilities": {
                ">50K": 0.0032,
                "<=50K": 0.9968
            }
        },
        "occupation": "Machine-op-inspct",
        "capital-gain": 0,
        "capital-loss": 0,
        "relationship": "Own-child",
        "education-num": 7,
        "hours-per-week": 40,
        "marital-status": "Never-married",
        "native-country": "United-States"
    }
    1 row in set (2.2218 sec)

    Where:

    • @row_input is a session variable containing a row of unlabeled data. The data is specified in JSON key-value format. The column names must match the feature column names in the training dataset.

    • @census_model is the session variable that contains the model handle. Learn more about Model Handles.

    • NULL specifies that no options are set for the routine.

    The prediction on the data is that the revenue is <=50K with a probability of 99.7%.
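
To use the prediction in later statements rather than reading it from the \G output, one option is to capture the returned JSON and pull out individual fields with MySQL's JSON functions. The following is a minimal sketch, assuming the result has the nested ml_results layout shown in the example output above. The @row_output session variable and the column aliases are hypothetical names introduced here for illustration.

```sql
-- Capture the JSON document returned by ML_PREDICT_ROW.
-- @row_output is a hypothetical variable name used only in this sketch.
SELECT sys.ML_PREDICT_ROW(@row_input, @census_model, NULL) INTO @row_output;

-- Extract the predicted label and the class probabilities.
-- Keys that contain characters such as < or > must be double-quoted in the JSON path.
SELECT
    JSON_UNQUOTE(JSON_EXTRACT(@row_output, '$.ml_results.predictions.revenue')) AS predicted_revenue,
    JSON_EXTRACT(@row_output, '$.ml_results.probabilities."<=50K"') AS prob_le_50k,
    JSON_EXTRACT(@row_output, '$.ml_results.probabilities.">50K"') AS prob_gt_50k;
```

If the result matches the example output above, predicted_revenue would hold <=50K and prob_le_50k would hold 0.9968.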

Generate Predictions on One or More Rows of Data

Another way to generate predictions is to create a JSON_OBJECT with specified columns and labels, and then generate predictions on one or more rows of data in the table.

mysql> SELECT sys.ML_PREDICT_ROW(JSON_OBJECT("output_col_name", schema.`input_col_name`, "output_col_name", schema.`input_col_name`, ...), model_handle, options) 
        FROM input_table_name LIMIT N;

The following example specifies the table and columns to use for the prediction and assigns an output label to each table-column pair. NULL specifies that no options are set. LIMIT 2 restricts the prediction to the first two rows returned from the table. Optionally, use \G to display information in an easily readable format.

mysql> SELECT sys.ML_PREDICT_ROW(JSON_OBJECT(
	"age", census_train.`age`,
	"workclass", census_train.`workclass`,
	"fnlwgt", census_train.`fnlwgt`,
	"education", census_train.`education`,
	"education-num", census_train.`education-num`,
	"marital-status", census_train.`marital-status`,
	"occupation", census_train.`occupation`,
	"relationship", census_train.`relationship`,
	"race", census_train.`race`,
	"sex", census_train.`sex`,
	"capital-gain", census_train.`capital-gain`,
	"capital-loss", census_train.`capital-loss`,
	"hours-per-week", census_train.`hours-per-week`,
	"native-country", census_train.`native-country`),
	@census_model, NULL) FROM census_data.census_train LIMIT 2\G
*************************** 1. row ***************************
sys.ML_PREDICT_ROW(JSON_OBJECT(
"age", census_train.`age`,
"workclass", census_train.`workclass`,
"fnlwgt", census_train.`fnlwgt`,
"education", census_train.`education`,
"education-num", census_train.`education-num`,
"marital-status", census_train.`marita: {
                                            "age": 62,
                                            "sex": "Female",
                                            "race": "White",
                                            "fnlwgt": 123582,
                                            "education": "10th",
                                            "workclass": "Private",
                                            "Prediction": "<=50K",
                                            "ml_results": {
                                                "predictions": {
                                                    "revenue": "<=50K"
                                                },
                                                "probabilities": {
                                                    ">50K": 0.0106,
                                                    "<=50K": 0.9894
                                                }
                                            },
                                            "occupation": "Other-service",
                                            "capital-gain": 0,
                                            "capital-loss": 0,
                                            "relationship": "Unmarried",
                                            "education-num": 6,
                                            "hours-per-week": 40,
                                            "marital-status": "Divorced",
                                            "native-country": "United-States"
                                        }
*************************** 2. row ***************************
sys.ML_PREDICT_ROW(JSON_OBJECT(
"age", census_train.`age`,
"workclass", census_train.`workclass`,
"fnlwgt", census_train.`fnlwgt`,
"education", census_train.`education`,
"education-num", census_train.`education-num`,
"marital-status", census_train.`marita: {
                                            "age": 32,
                                            "sex": "Female",
                                            "race": "White",
                                            "fnlwgt": 174215,
                                            "education": "Bachelors",
                                            "workclass": "Federal-gov",
                                            "Prediction": "<=50K",
                                            "ml_results": {
                                                "predictions": {
                                                    "revenue": "<=50K"
                                                },
                                                "probabilities": {
                                                    ">50K": 0.3249,
                                                    "<=50K": 0.6751
                                                }
                                            },
                                            "occupation": "Exec-managerial",
                                            "capital-gain": 0,
                                            "capital-loss": 0,
                                            "relationship": "Not-in-family",
                                            "education-num": 13,
                                            "hours-per-week": 60,
                                            "marital-status": "Never-married",
                                            "native-country": "United-States"
                                        }
2 rows in set (9.6548 sec)

The output contains revenue predictions for the two rows of data.
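
The column header in the output above appears cut off because, without an alias, the full expression text is used as the column name. One possible variation, sketched here, adds a column alias so that the header stays short. The alias name prediction is an arbitrary choice for illustration; the query is otherwise the same as the example above.

```sql
-- Same query as the example above, with a column alias for a shorter result header.
-- "prediction" is a hypothetical alias chosen for this sketch.
SELECT sys.ML_PREDICT_ROW(JSON_OBJECT(
    "age", census_train.`age`,
    "workclass", census_train.`workclass`,
    "fnlwgt", census_train.`fnlwgt`,
    "education", census_train.`education`,
    "education-num", census_train.`education-num`,
    "marital-status", census_train.`marital-status`,
    "occupation", census_train.`occupation`,
    "relationship", census_train.`relationship`,
    "race", census_train.`race`,
    "sex", census_train.`sex`,
    "capital-gain", census_train.`capital-gain`,
    "capital-loss", census_train.`capital-loss`,
    "hours-per-week", census_train.`hours-per-week`,
    "native-country", census_train.`native-country`),
    @census_model, NULL) AS prediction
FROM census_data.census_train LIMIT 2\G  -- \G is specific to the mysql client; use ; elsewhere
```

With the alias in place, each row in the \G output is labeled prediction instead of the truncated expression text.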

What's Next