HeatWave User Guide

3.14.2.2 How to Import ONNX Models

Follow these steps to import a model in ONNX format into the model catalog:

  1. Serialize the ONNX model to a string and encode it in Base64, using the Python onnx and base64 modules. The following example reads the model file iris.onnx and writes the encoded model to iris_base64.onnx:

    $> python -c "import onnx; import base64; open('iris_base64.onnx', 'wb').write(base64.b64encode(onnx.load('iris.onnx').SerializeToString()))"
  2. Connect to the MySQL DB System for the HeatWave Cluster as a client, and create a temporary table to upload the model. For example:

    mysql> CREATE TEMPORARY TABLE onnx_temp (onnx_string LONGTEXT);
  3. Use a LOAD DATA INFILE statement to load the Base64-encoded file into the temporary table. For example:

    mysql> LOAD DATA INFILE 'iris_base64.onnx' 
              INTO TABLE onnx_temp 
              CHARACTER SET binary 
              FIELDS TERMINATED BY '\t' 
              LINES TERMINATED BY '\r' (onnx_string);
  4. Select the uploaded model from the temporary table into a session variable. For example:

    mysql> SELECT onnx_string FROM onnx_temp INTO @onnx_encode;
  5. Call the ML_MODEL_IMPORT routine to import the ONNX model into the model catalog. For example:

    mysql> CALL sys.ML_MODEL_IMPORT(@onnx_encode, NULL, 'iris_onnx');

    In this example, the model handle is iris_onnx, and the optional model metadata is omitted and set to NULL. For details of the supported metadata for imported ONNX models, see ML_MODEL_IMPORT and Section 3.14.1.3, “Model Metadata”.
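Steps 1 and 3 can also be sketched as a standalone Python script. This is a minimal sketch, not the documented procedure. It assumes the .onnx file already contains a serialized ONNX ModelProto, which is what onnx.save writes, so it Base64-encodes the raw file bytes in place of onnx.load(...).SerializeToString() and does not need the onnx package. The helper names encode_onnx_file and loads_as_single_row are hypothetical. The second helper checks the property that the LOAD DATA statement in step 3 relies on. base64.b64encode never inserts line breaks, and its output alphabet (A-Z, a-z, 0-9, +, / and = padding) contains neither the tab field terminator nor the carriage-return line terminator, so the file should load as one onnx_string value in one row.

```python
import base64
from pathlib import Path


def encode_onnx_file(src: str, dst: str) -> bytes:
    """Base64-encode a serialized ONNX model file and write the result to dst.

    Assumption: src already holds serialized ModelProto bytes, so its raw
    contents stand in for onnx.load(src).SerializeToString().
    """
    payload = base64.b64encode(Path(src).read_bytes())  # b64encode adds no line breaks
    Path(dst).write_bytes(payload)
    return payload


def loads_as_single_row(payload: bytes) -> bool:
    """Check that the payload contains neither LOAD DATA terminator.

    With a tab field terminator and a carriage-return line terminator, a
    payload free of both characters is read as a single value in a single row.
    """
    return b"\t" not in payload and b"\r" not in payload
```

For example, encode_onnx_file('iris.onnx', 'iris_base64.onnx') produces the same kind of file as the one-liner in step 1, and loads_as_single_row applied to its return value should be True.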

After import, all the HeatWave AutoML routines can be used with the ONNX model. It is added to the model catalog and can be managed in the same ways as a model created by HeatWave AutoML.

ONNX Import Examples
  • A classification task example:

    mysql> SET @model := 'sklearn_pipeline_classification_3_onnx';
    Query OK, 0 rows affected (0.0003 sec)
    
    mysql> SET @model_metadata := JSON_OBJECT('task','classification',
              'onnx_outputs_info',JSON_OBJECT('predictions_name','label','prediction_probabilities_name', 'probabilities'));
    Query OK, 0 rows affected (0.0003 sec)
    
    mysql> CALL sys.ML_MODEL_IMPORT(@onnx_encode_sklearn_pipeline_classification_3,
              @model_metadata, @model);
    Query OK, 0 rows affected (1.2438 sec)
    
    mysql> CALL sys.ML_MODEL_LOAD(@model, NULL);
    Query OK, 0 rows affected (0.5372 sec)
    
    mysql> CALL sys.ML_PREDICT_TABLE('mlcorpus.classification_3_predict', @model, 
              'mlcorpus.predictions', NULL);
    Query OK, 0 rows affected (0.8743 sec)
    
    mysql> SELECT * FROM mlcorpus.predictions;
    +-------------------+----+----+-----+------------+----------------------------------------------------------------------------------------------------------+
    | _4aad19ca6e_pk_id | f1 | f2 | f3  | Prediction | ml_results                                                                                               |
    +-------------------+----+----+-----+------------+----------------------------------------------------------------------------------------------------------+
    |                 1 | a  | 20 | 1.2 |          0 | {"predictions": {"prediction": 0}, "probabilities": {"0": 0.5099999904632568, "1": 0.49000000953674316}} |
    |                 2 | b  | 21 | 3.6 |          1 | {"predictions": {"prediction": 1}, "probabilities": {"0": 0.3199999928474426, "1": 0.6800000071525574}}  |
    |                 3 | c  | 19 | 7.8 |          1 | {"predictions": {"prediction": 1}, "probabilities": {"0": 0.3199999928474426, "1": 0.6800000071525574}}  |
    |                 4 | d  | 18 |   9 |          0 | {"predictions": {"prediction": 0}, "probabilities": {"0": 0.5199999809265137, "1": 0.47999998927116394}} |
    |                 5 | e  | 17 | 3.6 |          1 | {"predictions": {"prediction": 1}, "probabilities": {"0": 0.3199999928474426, "1": 0.6800000071525574}}  |
    +-------------------+----+----+-----+------------+----------------------------------------------------------------------------------------------------------+
    5 rows in set (0.0005 sec)
    
    mysql> CALL sys.ML_SCORE('mlcorpus.classification_3_table','target', @model,
              'accuracy', @score, NULL);
    Query OK, 0 rows affected (0.9573 sec)
    
    mysql> SELECT @score;
    +--------+
    | @score |
    +--------+
    |      1 |
    +--------+
    1 row in set (0.0003 sec)
    
    mysql> CALL sys.ML_EXPLAIN('mlcorpus.classification_3_table', 'target', @model, 
              JSON_OBJECT('model_explainer', 'shap', 'prediction_explainer', 'shap'));
    Query OK, 0 rows affected (10.1771 sec)
    
    mysql> SELECT model_explanation FROM ML_SCHEMA_root.MODEL_CATALOG
              WHERE model_handle=@model;
    +------------------------------------------------------+
    | model_explanation                                    |
    +------------------------------------------------------+
    | {"shap": {"f1": 0.0928, "f2": 0.0007, "f3": 0.0039}} |
    +------------------------------------------------------+
    1 row in set (0.0005 sec)
    
    mysql> CALL sys.ML_EXPLAIN_TABLE('mlcorpus.classification_3_predict', @model, 
              'mlcorpus.explanations_shap', JSON_OBJECT('prediction_explainer', 'shap'));
    Query OK, 0 rows affected (7.6577 sec)
    
    mysql> SELECT * FROM mlcorpus.explanations_shap;
    +-------------------+----+----+-----+------------+----------------+-----------------+----------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
    | _4aad19ca6e_pk_id | f1 | f2 | f3  | Prediction | f1_attribution | f2_attribution  | f3_attribution | ml_results                                                                                                                                                                     |
    +-------------------+----+----+-----+------------+----------------+-----------------+----------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
    |                 1 | a  | 20 | 1.2 |          0 |       0.116909 |     0.000591494 |    -0.00524929 | {"predictions": {"prediction": 0}, "attributions": {"f1_attribution": 0.11690924863020577, "f2_attribution": 0.0005914936463038284, "f3_attribution": -0.005249293645222988}}  |
    |                 2 | b  | 21 | 3.6 |          1 |      0.0772133 |     -0.00110559 |     0.00219658 | {"predictions": {"prediction": 1}, "attributions": {"f1_attribution": 0.07721325159072877, "f2_attribution": -0.0011055856943130368, "f3_attribution": 0.002196577191352772}}  |
    |                 3 | c  | 19 | 7.8 |          1 |      0.0781372 | 0.0000000913938 |    -0.00324671 | {"predictions": {"prediction": 1}, "attributions": {"f1_attribution": 0.07813718219598137, "f2_attribution": 9.139378859268632e-08, "f3_attribution": -0.0032467077175776238}} |
    |                 4 | d  | 18 |   9 |          0 |       0.115209 |    -0.000592354 |     0.00639341 | {"predictions": {"prediction": 0}, "attributions": {"f1_attribution": 0.11520911753177646, "f2_attribution": -0.0005923539400101152, "f3_attribution": 0.006393408775329595}}  |
    |                 5 | e  | 17 | 3.6 |          1 |      0.0767679 |      0.00110463 |     0.00219425 | {"predictions": {"prediction": 1}, "attributions": {"f1_attribution": 0.0767679293950399, "f2_attribution": 0.0011046340068181504, "f3_attribution": 0.002194248636563534}}    |
    +-------------------+----+----+-----+------------+----------------+-----------------+----------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
    5 rows in set (0.0005 sec)
  • A regression task example:

    mysql> SET @model := 'sklearn_pipeline_regression_2_onnx';
    Query OK, 0 rows affected (0.0003 sec)
    
    mysql> SET @model_metadata := JSON_OBJECT('task','regression',
              'onnx_outputs_info',JSON_OBJECT('predictions_name','variable'));
    Query OK, 0 rows affected (0.0003 sec)
    
    mysql> CALL sys.ML_MODEL_IMPORT(@onnx_encode_sklearn_pipeline_regression_2, 
              @model_metadata, @model);
    Query OK, 0 rows affected (1.0652 sec)
    
    mysql> CALL sys.ML_MODEL_LOAD(@model, NULL);
    Query OK, 0 rows affected (0.5141 sec)
    
    mysql> CALL sys.ML_PREDICT_TABLE('mlcorpus.regression_2_table', 
              @model, 'mlcorpus.predictions', NULL);
    Query OK, 0 rows affected (0.8902 sec)
    
    mysql> SELECT * FROM mlcorpus.predictions;
    +-------------------+----+----+-----+--------+------------+-----------------------------------------------------+
    | _4aad19ca6e_pk_id | f1 | f2 | f3  | target | Prediction | ml_results                                          |
    +-------------------+----+----+-----+--------+------------+-----------------------------------------------------+
    |                 1 | a  | 20 | 1.2 |   22.4 |     22.262 | {"predictions": {"prediction": 22.262039184570312}} |
    |                 2 | b  | 21 | 3.6 |   32.9 |    32.4861 | {"predictions": {"prediction": 32.486114501953125}} |
    |                 3 | c  | 19 | 7.8 |   56.8 |    56.2482 | {"predictions": {"prediction": 56.24815368652344}}  |
    |                 4 | d  | 18 |   9 |   31.8 |       31.8 | {"predictions": {"prediction": 31.80000114440918}}  |
    |                 5 | e  | 17 | 3.6 |   56.4 |    55.9861 | {"predictions": {"prediction": 55.986114501953125}} |
    +-------------------+----+----+-----+--------+------------+-----------------------------------------------------+
    5 rows in set (0.0005 sec)
    
    mysql> CALL sys.ML_SCORE('mlcorpus.regression_2_table','target', @model,
              'r2', @score, NULL);
    Query OK, 0 rows affected (0.8688 sec)
    
    mysql> SELECT @score;
    +--------------------+
    | @score             |
    +--------------------+
    | 0.9993192553520203 |
    +--------------------+
    1 row in set (0.0003 sec)
    
    mysql> CALL sys.ML_EXPLAIN('mlcorpus.regression_2_table', 'target', @model, 
              JSON_OBJECT('model_explainer', 'partial_dependence', 
              'columns_to_explain', JSON_ARRAY('f1'), 'prediction_explainer', 'shap'));
    Query OK, 0 rows affected (9.9860 sec)
    
    mysql> SELECT model_explanation FROM ML_SCHEMA_root.MODEL_CATALOG 
              WHERE model_handle=@model;
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
    | model_explanation                                                                                                                                                                                                                                                                                                                                                                                                                                  |
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
    | {"partial_dependence": {"f1": {"0": "a", "1": "b", "2": "c", "3": "d", "4": "e"}, "Mean": {"0": 28.996999740600582, "1": 42.09299850463867, "2": 54.59299850463867, "3": 24.51300048828125, "4": 48.58700180053711}, "Median": {"0": 30.653, "1": 43.749, "2": 56.248, "3": 26.169, "4": 50.242}, "Standard Deviation": {"0": 7.046000003814697, "1": 7.046000003814697, "2": 7.046000003814697, "3": 7.046000003814697, "4": 7.046000003814697}}} |
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
    1 row in set (0.0005 sec)
    
    mysql> CALL sys.ML_EXPLAIN_TABLE('mlcorpus.regression_2_predict', @model, 
              'mlcorpus.explanations', JSON_OBJECT('prediction_explainer', 'shap'));
    Query OK, 0 rows affected (8.2625 sec)
    
    mysql> SELECT * FROM mlcorpus.explanations;
    +-------------------+----+----+-----+------------+----------------+----------------+----------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
    | _4aad19ca6e_pk_id | f1 | f2 | f3  | Prediction | f1_attribution | f2_attribution | f3_attribution | ml_results                                                                                                                                                                               |
    +-------------------+----+----+-----+------------+----------------+----------------+----------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
    |                 1 | a  | 20 | 1.2 |     22.262 |       -10.7595 |       -4.25162 |       -2.48331 | {"predictions": {"prediction": 22.262039184570312}, "attributions": {"f1_attribution": -10.759506797790523, "f2_attribution": -4.251623916625977, "f3_attribution": -2.483314704895024}} |
    |                 2 | b  | 21 | 3.6 |    32.4861 |        2.33657 |       -8.50325 |        -1.1037 | {"predictions": {"prediction": 32.486114501953125}, "attributions": {"f1_attribution": 2.336572837829592, "f2_attribution": -8.50324745178223, "f3_attribution": -1.1036954879760748}}   |
    |                 3 | c  | 19 | 7.8 |    56.2482 |        14.8361 |              0 |        1.65554 | {"predictions": {"prediction": 56.24815368652344}, "attributions": {"f1_attribution": 14.83612575531006, "f2_attribution": 0.0, "f3_attribution": 1.6555433273315412}}                   |
    |                 4 | d  | 18 |   9 |       31.8 |       -15.2433 |        4.25162 |        3.03516 | {"predictions": {"prediction": 31.80000114440918}, "attributions": {"f1_attribution": -15.243269538879392, "f2_attribution": 4.251623725891111, "f3_attribution": 3.0351623535156236}}   |
    |                 5 | e  | 17 | 3.6 |    55.9861 |        8.83008 |        8.50325 |        -1.1037 | {"predictions": {"prediction": 55.986114501953125}, "attributions": {"f1_attribution": 8.830077743530275, "f2_attribution": 8.50324764251709, "f3_attribution": -1.1036954879760756}}    |
    +-------------------+----+----+-----+------------+----------------+----------------+----------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
    5 rows in set (0.0006 sec)
  • An example with the model metadata, and therefore the task, set to NULL:

    mysql> SET @model := 'tensorflow_recsys_onnx';
    
    mysql> CALL sys.ML_MODEL_IMPORT(@onnx_encode_tensorflow_recsys, NULL, @model);
    Query OK, 0 rows affected (1.0037 sec)
    
    mysql> CALL sys.ML_MODEL_LOAD(@model, NULL);
    Query OK, 0 rows affected (0.5116 sec)
    
    mysql> CALL sys.ML_PREDICT_TABLE('mlcorpus.recsys_predict', @model, 
              'mlcorpus.predictions', NULL);
    Query OK, 0 rows affected (0.8271 sec)
    
    mysql> SELECT * FROM mlcorpus.predictions;
    +-------------------+---------+-------------+--------------------------+-----------------------------------------------------------+
    | _4aad19ca6e_pk_id | user_id | movie_title | Prediction               | ml_results                                                |
    +-------------------+---------+-------------+--------------------------+-----------------------------------------------------------+
    |                 1 | a       | A           | {"output_1": ["0.7558"]} | {"predictions": {"prediction": {"output_1": ["0.7558"]}}} |
    |                 2 | b       | B           | {"output_1": ["1.0443"]} | {"predictions": {"prediction": {"output_1": ["1.0443"]}}} |
    |                 3 | c       | A           | {"output_1": ["0.8483"]} | {"predictions": {"prediction": {"output_1": ["0.8483"]}}} |
    |                 4 | d       | B           | {"output_1": ["1.2986"]} | {"predictions": {"prediction": {"output_1": ["1.2986"]}}} |
    |                 5 | e       | C           | {"output_1": ["1.1568"]} | {"predictions": {"prediction": {"output_1": ["1.1568"]}}} |
    +-------------------+---------+-------------+--------------------------+-----------------------------------------------------------+
    5 rows in set (0.0005 sec)
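In the classification example, each ml_results value pairs the reported prediction with per-class probabilities. As a client-side sanity check, not a HeatWave feature, the Python sketch below parses the five ml_results values printed for mlcorpus.predictions in that example. It confirms that each reported prediction is the class with the highest probability. The function name most_probable_class is hypothetical.

```python
import json

# ml_results values copied from the classification example's
# mlcorpus.predictions output.
ML_RESULTS = [
    '{"predictions": {"prediction": 0}, "probabilities": {"0": 0.5099999904632568, "1": 0.49000000953674316}}',
    '{"predictions": {"prediction": 1}, "probabilities": {"0": 0.3199999928474426, "1": 0.6800000071525574}}',
    '{"predictions": {"prediction": 1}, "probabilities": {"0": 0.3199999928474426, "1": 0.6800000071525574}}',
    '{"predictions": {"prediction": 0}, "probabilities": {"0": 0.5199999809265137, "1": 0.47999998927116394}}',
    '{"predictions": {"prediction": 1}, "probabilities": {"0": 0.3199999928474426, "1": 0.6800000071525574}}',
]


def most_probable_class(ml_results: str) -> int:
    """Return the class label with the highest probability in one ml_results value."""
    probabilities = json.loads(ml_results)["probabilities"]
    # Keys are class labels stored as JSON strings ("0", "1").
    return int(max(probabilities, key=probabilities.get))


for row in ML_RESULTS:
    reported = json.loads(row)["predictions"]["prediction"]
    # The reported prediction should be the most probable class.
    assert most_probable_class(row) == reported
```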
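In the regression example, ML_PREDICT_TABLE and ML_SCORE both run against mlcorpus.regression_2_table. The r2 score can therefore be recomputed from the five rows of the predictions output. The Python sketch below does this and assumes that ML_SCORE's r2 metric is the standard coefficient of determination, R² = 1 - SS_res / SS_tot. It uses the full-precision predictions from ml_results rather than the rounded Prediction column. The result agrees with the @score value shown in that example (0.9993192553520203) to well within 1e-6. The function name r2_score is a local illustration, not a HeatWave routine.

```python
# Target values and full-precision predictions copied from the regression
# example's mlcorpus.predictions output.
TARGETS = [22.4, 32.9, 56.8, 31.8, 56.4]
PREDICTIONS = [22.262039184570312, 32.486114501953125, 56.24815368652344,
               31.80000114440918, 55.986114501953125]


def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))  # residual sum of squares
    ss_tot = sum((t - mean) ** 2 for t in y_true)                # total sum of squares
    return 1.0 - ss_res / ss_tot


print(r2_score(TARGETS, PREDICTIONS))  # about 0.99932
```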
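SHAP attributions satisfy a local accuracy (additivity) property: for each row, the prediction equals a base value plus the sum of that row's feature attributions. The ML_EXPLAIN_TABLE output in the regression example does not report the base value. It can, however, be recovered as the prediction minus the sum of the attributions. The Python sketch below is a client-side check rather than a HeatWave API, and implied_base_value is a hypothetical name. It applies this to the five rows of mlcorpus.explanations. Every row implies the same base value, about 39.7565, which in this example also coincides with the mean of the five predictions.

```python
# (prediction, [f1, f2, f3 attributions]) copied from the ml_results column
# of mlcorpus.explanations in the regression example.
ROWS = [
    (22.262039184570312, [-10.759506797790523, -4.251623916625977, -2.483314704895024]),
    (32.486114501953125, [2.336572837829592, -8.50324745178223, -1.1036954879760748]),
    (56.24815368652344, [14.83612575531006, 0.0, 1.6555433273315412]),
    (31.80000114440918, [-15.243269538879392, 4.251623725891111, 3.0351623535156236]),
    (55.986114501953125, [8.830077743530275, 8.50324764251709, -1.1036954879760756]),
]


def implied_base_value(prediction, attributions):
    """SHAP local accuracy: prediction = base value + sum(attributions)."""
    return prediction - sum(attributions)


bases = [implied_base_value(p, a) for p, a in ROWS]
mean_prediction = sum(p for p, _ in ROWS) / len(ROWS)
print([round(b, 4) for b in bases], round(mean_prediction, 4))  # each value is about 39.7565
```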
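When the task is NULL, as in the last example, the Prediction column does not hold a single label or value. Instead, it holds a JSON document keyed by the model's output name (output_1 in that example), with the values rendered as strings. The Python sketch below turns such a document into numbers on the client side. It assumes that the named output is a flat list of numeric strings, as in the rows shown. The function name output_values and its default output_name are hypothetical.

```python
import json
from typing import List


def output_values(prediction: str, output_name: str = "output_1") -> List[float]:
    """Convert one raw Prediction document, e.g. {"output_1": ["0.7558"]}, to floats.

    Assumes the named output is a flat list of numeric strings.
    """
    return [float(value) for value in json.loads(prediction)[output_name]]
```

For example, output_values('{"output_1": ["1.2986"]}') returns [1.2986].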