
3.14.11 Data Drift Detection

HeatWave AutoML includes data drift detection for classification and regression models.

Machine learning typically assumes that the training data and the test data are similar. Over time, that similarity can decrease. This is known as data drift.

HeatWave AutoML monitors data drift with the following additions to the model catalog and to the ML_PREDICT_ROW and ML_PREDICT_TABLE routines:

  • The model_metadata column in the model catalog includes the training_drift_metric JSON object literal, which contains the mean and variance numeric values. See Section 3.14.1.3, “Model Metadata”.

  • The ML_PREDICT_ROW and ML_PREDICT_TABLE options parameter includes the additional_details boolean value.

  • The ML_PREDICT_ROW and ML_PREDICT_TABLE ml_results column includes the drift JSON object literal, which contains the metric numeric value and the attribution_percent JSON object literal. attribution_percent records up to three features with the highest attribution percentage values for each result.

To use data drift detection, follow this process:

  1. During training, the ML_TRAIN routine records the training_drift_metric. Once training is complete, review the mean and variance values.

    mean and variance indicate the quality of the trained drift detector, and both values should be low. mean is more important, and if it is greater than 1.0, then drift evaluation for the test results might not be reliable.

  2. Set the additional_details option to true for ML_PREDICT_ROW and ML_PREDICT_TABLE to record drift in ml_results.

  3. Run ML_PREDICT_ROW or ML_PREDICT_TABLE, and review drift in ml_results.

    metric indicates the similarity between training and test data. A low value indicates that the data is similar. A value greater than 1.0 indicates data drift, and the prediction results will be questionable.

    attribution_percent identifies up to three features that contribute the most to data drift for each result. The higher the percentage value, the greater the contribution.
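
The review steps above can be sketched as two queries. This is a minimal sketch under stated assumptions, not a definitive procedure: it assumes that @model still holds the model handle, that the model catalog is in ML_SCHEMA_root as in the examples in this section, and that predictions were written to a table named mlcorpus_v5.diamonds_predictions_experiment_results. The 1.0 thresholds come from the guidance above, and the column aliases are illustrative only.

```sql
-- Step 1: review the trained drift detector (assumes @model holds the model handle).
SELECT
  JSON_EXTRACT(model_metadata, '$.training_drift_metric.mean')     AS drift_mean,
  JSON_EXTRACT(model_metadata, '$.training_drift_metric.variance') AS drift_variance,
  -- 1 when the mean exceeds 1.0, in which case drift evaluation might not be reliable
  JSON_EXTRACT(model_metadata, '$.training_drift_metric.mean') > 1.0 AS mean_above_threshold
FROM ML_SCHEMA_root.MODEL_CATALOG
WHERE model_handle = @model;

-- Step 3: list predictions whose drift metric exceeds 1.0, highest first.
-- The table name is an assumption borrowed from the examples in this section.
SELECT
  JSON_EXTRACT(ml_results, '$.predictions')               AS predictions,
  JSON_EXTRACT(ml_results, '$.drift.metric')              AS drift_metric,
  JSON_EXTRACT(ml_results, '$.drift.attribution_percent') AS attribution_percent
FROM mlcorpus_v5.diamonds_predictions_experiment_results
WHERE JSON_EXTRACT(ml_results, '$.drift.metric') > 1.0
ORDER BY CAST(JSON_EXTRACT(ml_results, '$.drift.metric') AS DECIMAL(10,2)) DESC;
```

The second query only returns rows if the predictions were generated with additional_details set to true, because drift is not recorded in ml_results otherwise.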

Syntax Examples

  • The model_metadata includes the training_drift_metric JSON object literal.

    mysql> CALL sys.ML_TRAIN('mlcorpus_v5.`titanic_train`', 'survived', NULL, @model);
    Query OK, 0 rows affected (1 min 11.7002 sec)
    
    mysql> SELECT JSON_PRETTY(model_metadata) FROM ML_SCHEMA_root.MODEL_CATALOG WHERE model_handle=@model;
    +-----------------------------------------------------+
    | JSON_PRETTY(model_metadata)                         |
    +-----------------------------------------------------+
    | {
      "task": "classification",
      "notes": null,
      "format": "HWMLv1.0",
      "n_rows": 916,
      "status": "Ready",
      "options": {
        "model_explainer": "permutation_importance",
        "prediction_explainer": "permutation_importance"
      },
      "n_columns": 13,
      "column_names": [
        "pclass",
        "name",
        "sex",
        "age",
        "sibsp",
        "parch",
        "ticket",
        "fare",
        "cabin",
        "embarked",
        "boat",
        "body",
        "home.dest"
      ],
      "contamination": null,
      "model_quality": "high",
      "training_time": 57.53120040893555,
      "algorithm_name": "XGBClassifier",
      "training_score": -0.07736892998218536,
      "build_timestamp": 1699468966,
      "n_selected_rows": 732,
      "training_params": {
        "recommend": "ratings",
        "force_use_X": false,
        "recommend_k": 3
      },
      "train_table_name": "mlcorpus_v5.titanic_train",
      "model_explanation": {
        "permutation_importance": {
          "age": 0.0,
          "sex": 0.0,
          "boat": 0.4445,
          "body": 0.0,
          "fare": 0.0,
          "name": 0.0,
          "cabin": 0.0,
          "parch": 0.0,
          "sibsp": 0.0,
          "pclass": 0.0,
          "ticket": 0.0,
          "embarked": 0.0,
          "home.dest": 0.0
        }
      },
      "n_selected_columns": 2,
      "target_column_name": "survived",
      "optimization_metric": "neg_log_loss",
      "selected_column_names": [
        "boat",
        "sex"
      ],
      "training_drift_metric": {
        "mean": 0.278,
        "variance": 0.2356
      }
    } |
    +-----------------------------------------------------+
    1 row in set (0.0004 sec)

  • An ML_PREDICT_TABLE example with additional_details set to true.

    mysql> CALL sys.ML_TRAIN('mlcorpus_v5.`diamonds_train`', 'price', JSON_OBJECT('task','regression'), @model);
    Query OK, 0 rows affected (7 min 47.9567 sec)
    
    mysql> CALL sys.ML_MODEL_LOAD(@model, NULL);
    Query OK, 0 rows affected (0.7665 sec)
    
    mysql> DROP TABLE IF EXISTS diamonds_predictions_experiment_results;
    Query OK, 0 rows affected (0.0106 sec)
    
    mysql> CALL sys.ML_PREDICT_TABLE('mlcorpus_v5.`diamonds_test`', @model, 'mlcorpus_v5.`diamonds_predictions_experiment_results`', JSON_OBJECT('additional_details', true));
    Query OK, 0 rows affected (28.5353 sec)
    
    mysql> SELECT ml_results FROM diamonds_predictions_experiment_results 
              WHERE JSON_EXTRACT(ml_results, '$.drift.metric') > 0.5 
              LIMIT 10;
    +-------------------------------------------------------------------------------------------------------------------------------------------------+
    | ml_results                                                                                                                                      |
    +-------------------------------------------------------------------------------------------------------------------------------------------------+
    | {"predictions": {"price": 4769.22265625}, "drift": {"metric": 0.69, "attribution_percent": {"cut": 100.0, "carat": 0.0, "clarity": 0.0}}}       |
    | {"predictions": {"price": 2610.075439453125}, "drift": {"metric": 0.57, "attribution_percent": {"color": 91.25, "cut": 8.75, "carat": 0.0}}}    |
    | {"predictions": {"price": 2725.368896484375}, "drift": {"metric": 0.54, "attribution_percent": {"cut": 100.0, "carat": 0.0, "clarity": 0.0}}}   |
    | {"predictions": {"price": 7102.55224609375}, "drift": {"metric": 2.49, "attribution_percent": {"z": 64.53, "y": 16.86, "x": 11.58}}}            |
    | {"predictions": {"price": 3622.7236328125}, "drift": {"metric": 0.55, "attribution_percent": {"color": 81.2, "cut": 18.8, "carat": 0.0}}}       |
    | {"predictions": {"price": 3879.93701171875}, "drift": {"metric": 2.24, "attribution_percent": {"z": 70.23, "y": 15.57, "x": 9.89}}}             |
    | {"predictions": {"price": 566.2338256835938}, "drift": {"metric": 0.67, "attribution_percent": {"color": 96.65, "cut": 3.35, "carat": 0.0}}}    |
    | {"predictions": {"price": 2495.825439453125}, "drift": {"metric": 0.64, "attribution_percent": {"cut": 100.0, "carat": 0.0, "clarity": 0.0}}}   |
    | {"predictions": {"price": 421.9180603027344}, "drift": {"metric": 0.58, "attribution_percent": {"color": 100.0, "carat": 0.0, "clarity": 0.0}}} |
    | {"predictions": {"price": 325.4655456542969}, "drift": {"metric": 0.53, "attribution_percent": {"color": 100.0, "carat": 0.0, "clarity": 0.0}}} |
    +-------------------------------------------------------------------------------------------------------------------------------------------------+
    10 rows in set (0.0048 sec)

  • An ML_PREDICT_ROW example with additional_details set to true.

    mysql> SELECT JSON_OBJECT('carat', `diamonds_test`.`carat`, 'cut', `diamonds_test`.`cut`, 'color', `diamonds_test`.`color`, 'clarity', `diamonds_test`.`clarity`, 'depth', `diamonds_test`.`depth`, '_table', `diamonds_test`.`_table`, 'x', `diamonds_test`.`x`, 'y', `diamonds_test`.`y`, 'z', `diamonds_test`.`z`) 
              AS obj
              FROM `diamonds_test`
              WHERE JSON_UNQUOTE(JSON_EXTRACT(JSON_OBJECT('carat', `diamonds_test`.`carat`, 'cut', `diamonds_test`.`cut`, 'color', `diamonds_test`.`color`, 'clarity', `diamonds_test`.`clarity`, 'depth', `diamonds_test`.`depth`, '_table', `diamonds_test`.`_table`, 'x', `diamonds_test`.`x`, 'y', `diamonds_test`.`y`, 'z', `diamonds_test`.`z`), '$.carat')) > 3
              LIMIT 1
              INTO @row;
    Query OK, 1 row affected (0.0033 sec)
    
    mysql> SELECT sys.ML_PREDICT_ROW(@row, @model, JSON_OBJECT('additional_details', TRUE)) FROM `diamonds_test` LIMIT 1;
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
    | sys.ML_PREDICT_ROW(@row, @model, JSON_OBJECT('additional_details', TRUE))                                                                                                                                                                                                                                                                                            |
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
    | {"x": 9.1000003815, "y": 8.970000267, "z": 5.6700000763, "cut": "Premium", "carat": 3.0099999905, "color": "I", "depth": 62.7000007629, "_table": 58.0, "clarity": "I1", "Prediction": 10211.1376953125, "ml_results": {"drift": {"metric": 0.19, "attribution_percent": {"y": 32.25, "carat": 67.75, "clarity": 0.0}}, "predictions": {"price": 10211.1376953125}}} |
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
    1 row in set (0.1670 sec)
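
The examples above return full JSON results. As a hedged follow-up sketch, the queries below summarize drift for a whole predictions table and extract only the drift object from a single-row prediction. They assume the table, @row, and @model from the preceding examples still exist, and that ML_PREDICT_ROW returns a value that JSON_EXTRACT accepts, with the '$.ml_results.drift' path taken from the output shown above. The column aliases are illustrative.

```sql
-- Summarize drift across a predictions table (table name taken from the example above).
SELECT
  COUNT(*) AS total_rows,
  -- The comparison yields 1 or 0 per row, so SUM counts rows above the 1.0 threshold
  SUM(JSON_EXTRACT(ml_results, '$.drift.metric') > 1.0) AS rows_with_drift,
  MAX(CAST(JSON_EXTRACT(ml_results, '$.drift.metric') AS DECIMAL(10,2))) AS max_drift_metric
FROM mlcorpus_v5.diamonds_predictions_experiment_results;

-- Extract only the drift object from a single-row prediction.
SELECT JSON_EXTRACT(
         sys.ML_PREDICT_ROW(@row, @model, JSON_OBJECT('additional_details', TRUE)),
         '$.ml_results.drift'
       ) AS drift;
```

A large rows_with_drift count relative to total_rows suggests that the test data has moved away from the training data, which is the situation this section describes as data drift.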