For model_metadata, see Section 3.14.1.3, “Model Metadata”. The model metadata includes onnx_inputs_info and onnx_outputs_info.
onnx_inputs_info includes data_types_map. For the default value, see Section 3.14.1.3, “Model Metadata”.
onnx_outputs_info includes predictions_name, prediction_probabilities_name, and labels_map.
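To illustrate how these keys nest, the following Python sketch builds a model_metadata document and serializes it with json.dumps. It assumes the metadata is supplied as a JSON object; the output names label and probabilities are hypothetical placeholders for whatever names the ONNX model actually uses.

```python
import json

model_metadata = {
    "onnx_inputs_info": {
        # Treat ONNX inputs of type tensor(float) as numpy float64.
        "data_types_map": {"tensor(float)": "float64"},
    },
    "onnx_outputs_info": {
        # Hypothetical output names; use the names defined by your ONNX model.
        "predictions_name": "label",
        "prediction_probabilities_name": "probabilities",
    },
}

# Serialize the metadata as a JSON document (assumed input format).
metadata_json = json.dumps(model_metadata)
print(metadata_json)
```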
Use data_types_map to map the data type of each ONNX model input to a numpy data type. For example, to convert inputs of type tensor(float) to float64:
data_types_map = {"tensor(float)": "float64"}
To find a data type, HeatWave AutoML first checks the user-supplied data_types_map and then the default data_types_map. HeatWave AutoML supports the following numpy data types:
Table 3.1 Supported numpy data types
String: str_, unicode_
Integer: int8, int16, int32, int64, int_, uint16, uint32, uint64, byte, ubyte, short, ushort, intc, uintc, uint, longlong, ulonglong, intp, uintp
Floating point: float16, float32, float64, half, single, longfloat, double, longdouble
Boolean: bool_
Date and time: datetime64
Complex: complex_, complex64, complex128, complex256, csingle, cdouble, clongdouble
The use of any other numpy data type will cause an error.
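The lookup order and the type restriction can be sketched in Python as follows. This is an illustration under stated assumptions, not the HeatWave AutoML implementation: resolve_data_type is a hypothetical helper, the default map is passed in as a parameter because the real defaults are documented in Section 3.14.1.3, and raising KeyError when neither map has an entry is an assumed behavior.

```python
# The numpy data types listed in Table 3.1; any other type is rejected.
SUPPORTED_NUMPY_TYPES = {
    "str_", "unicode_",
    "int8", "int16", "int32", "int64", "int_", "uint16", "uint32", "uint64",
    "byte", "ubyte", "short", "ushort", "intc", "uintc", "uint",
    "longlong", "ulonglong", "intp", "uintp",
    "float16", "float32", "float64", "half", "single", "longfloat",
    "double", "longdouble",
    "bool_", "datetime64",
    "complex_", "complex64", "complex128", "complex256",
    "csingle", "cdouble", "clongdouble",
}


def resolve_data_type(onnx_type, user_map, default_map):
    """Return the numpy data type name to use for an ONNX input type.

    The user-supplied map is checked first, then the default map.
    """
    for mapping in (user_map, default_map):
        if onnx_type in mapping:
            numpy_type = mapping[onnx_type]
            # Only the numpy types in Table 3.1 are accepted.
            if numpy_type not in SUPPORTED_NUMPY_TYPES:
                raise ValueError(f"unsupported numpy data type: {numpy_type}")
            return numpy_type
    # Assumption: an ONNX type found in neither map is treated as an error.
    raise KeyError(f"no data type mapping for {onnx_type}")


# Hypothetical default map, for illustration only.
default_map = {"tensor(float)": "float32"}
user_map = {"tensor(float)": "float64"}

print(resolve_data_type("tensor(float)", user_map, default_map))  # float64
print(resolve_data_type("tensor(float)", {}, default_map))  # float32
```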
Use predictions_name to specify which ONNX model output contains the predictions. Use prediction_probabilities_name to specify which ONNX model output contains the prediction probabilities. Use a labels_map to map prediction probabilities to predictions, known as labels.
For regression tasks, predictions_name is optional if the ONNX model generates only one output, and required if it generates more than one output. Do not provide prediction_probabilities_name for regression tasks; doing so causes an error.
For classification tasks, use predictions_name, prediction_probabilities_name, or both. Failing to provide at least one causes an error. The model explainers SHAP, Fast SHAP, and Partial Dependence require prediction_probabilities_name.
Only use a labels_map with classification tasks. A labels_map requires prediction_probabilities_name. Using a labels_map with any other task, together with predictions_name, or without prediction_probabilities_name causes an error.
For example, applying the following labels_map to the output named by prediction_probabilities_name produces these labels:
prediction_probabilities_name = array([[0.35, 0.50, 0.15],
                                       [0.10, 0.20, 0.70],
                                       [0.90, 0.05, 0.05],
                                       [0.55, 0.05, 0.40]], dtype=float32)
labels_map = {0:'Iris-virginica', 1:'Iris-versicolor', 2:'Iris-setosa'}
labels = ['Iris-versicolor', 'Iris-setosa', 'Iris-virginica', 'Iris-virginica']
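The labels in this example are consistent with taking, for each row, the column index with the highest probability and looking that index up in labels_map. The following stdlib-only Python sketch reproduces them; to_labels is a hypothetical helper, and the argmax rule is inferred from the example rather than taken from the HeatWave AutoML implementation.

```python
# Output values named by prediction_probabilities_name:
# one row per input row, one column per class.
prediction_probabilities = [
    [0.35, 0.50, 0.15],
    [0.10, 0.20, 0.70],
    [0.90, 0.05, 0.05],
    [0.55, 0.05, 0.40],
]
labels_map = {0: "Iris-virginica", 1: "Iris-versicolor", 2: "Iris-setosa"}


def to_labels(probabilities, labels_map):
    """Map each row to the label of its highest-probability column."""
    labels = []
    for row in probabilities:
        # Index of the largest value in the row (first one wins on ties).
        best = max(range(len(row)), key=row.__getitem__)
        labels.append(labels_map[best])
    return labels


print(to_labels(prediction_probabilities, labels_map))
# ['Iris-versicolor', 'Iris-setosa', 'Iris-virginica', 'Iris-virginica']
```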
When the task is NULL, do not provide predictions_name or prediction_probabilities_name; doing so causes an error.
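The rules for predictions_name, prediction_probabilities_name, and labels_map can be collected into a single check. The sketch below is a hypothetical pre-flight helper, check_outputs_info, that encodes only the rules stated in this section. The task strings "regression" and "classification", the use of None for a NULL task, and the num_model_outputs parameter are assumptions made for illustration.

```python
def check_outputs_info(task, outputs_info, num_model_outputs):
    """Raise ValueError if onnx_outputs_info breaks the rules in this section.

    task: "regression", "classification", or None for a NULL task (assumed encoding).
    outputs_info: dict that may hold predictions_name,
        prediction_probabilities_name, and labels_map.
    num_model_outputs: number of outputs the ONNX model generates.
    """
    has_pred = "predictions_name" in outputs_info
    has_prob = "prediction_probabilities_name" in outputs_info
    has_labels = "labels_map" in outputs_info

    if task is None:
        if has_pred or has_prob:
            raise ValueError("NULL task: omit predictions_name and prediction_probabilities_name")
    elif task == "regression":
        if has_prob:
            raise ValueError("regression: prediction_probabilities_name is not allowed")
        if num_model_outputs > 1 and not has_pred:
            raise ValueError("regression with several outputs: predictions_name is required")
    elif task == "classification":
        if not (has_pred or has_prob):
            raise ValueError("classification: provide predictions_name, prediction_probabilities_name, or both")

    if has_labels:
        if task != "classification":
            raise ValueError("labels_map is only valid for classification")
        if has_pred or not has_prob:
            raise ValueError("labels_map requires prediction_probabilities_name and no predictions_name")


# Passes: classification with probabilities and a labels_map.
check_outputs_info(
    "classification",
    {"prediction_probabilities_name": "probabilities",
     "labels_map": {0: "Iris-virginica", 1: "Iris-versicolor", 2: "Iris-setosa"}},
    2,
)

# Fails: regression must not name a probabilities output.
try:
    check_outputs_info("regression", {"prediction_probabilities_name": "probabilities"}, 1)
except ValueError as err:
    print(err)  # regression: prediction_probabilities_name is not allowed
```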
For ONNX models with four-dimensional inputs, typically models trained on image data, HeatWave AutoML reshapes the input data to a suitable shape and adds a note about the reshaping to the ml_results column. For example:
mysql> CALL sys.ML_PREDICT_TABLE('mlcorpus_v5.mnist_test_temp', @model,
'mlcorpus_v5.`mnist_predictions`', NULL);
Query OK, 0 rows affected (20.6296 sec)
mysql> SELECT ml_results FROM mnist_predictions;
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| ml_results |
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| {'predictions': {'prediction': 7}, 'Notes': 'Input data is reshaped into (1, 28, 28).', 'probabilities': {0: -552.7100219726562, 1: 138.27000427246094, 2: 2178.510009765625, 3: 2319.860107421875, 4: -3466.5400390625, 5: -1778.3499755859375, 6: -6441.83984375, 7: 8062.9599609375, 8: -1860.2099609375, 9: 1034.239990234375}} |
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
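The ml_results value above is written in Python literal syntax (single quotes and integer keys), so ast.literal_eval can parse it where json.loads would fail. The sketch below reads the example result and shows that the predicted class, 7, is the one with the highest score; these values are unnormalized scores rather than probabilities between 0 and 1. It also includes a hypothetical reshape helper that turns a flat row into nested lists of shape (1, 28, 28), assuming each input row holds 784 pixel values. It only illustrates what the note describes and is not the HeatWave AutoML reshaping logic.

```python
import ast

# The ml_results value from the example above, copied verbatim.
ml_results = (
    "{'predictions': {'prediction': 7}, "
    "'Notes': 'Input data is reshaped into (1, 28, 28).', "
    "'probabilities': {0: -552.7100219726562, 1: 138.27000427246094, "
    "2: 2178.510009765625, 3: 2319.860107421875, 4: -3466.5400390625, "
    "5: -1778.3499755859375, 6: -6441.83984375, 7: 8062.9599609375, "
    "8: -1860.2099609375, 9: 1034.239990234375}}"
)

result = ast.literal_eval(ml_results)
scores = result["probabilities"]
# Class whose score is highest.
best_class = max(scores, key=scores.get)
print(result["Notes"])  # Input data is reshaped into (1, 28, 28).
print(best_class, result["predictions"]["prediction"])  # 7 7


def reshape(values, shape):
    """Row-major reshape of a flat list into nested lists of the given shape."""
    if not shape or any(dim <= 0 for dim in shape):
        raise ValueError("shape must be a non-empty tuple of positive sizes")
    size = 1
    for dim in shape:
        size *= dim
    if len(values) != size:
        raise ValueError(f"cannot reshape {len(values)} values into {shape}")
    if len(shape) == 1:
        return list(values)
    step = size // shape[0]
    return [reshape(values[i * step:(i + 1) * step], shape[1:])
            for i in range(shape[0])]


# Hypothetical input row: 784 pixel values for one 28 x 28 image.
pixels = [0] * 784
image = reshape(pixels, (1, 28, 28))
print(len(image), len(image[0]), len(image[0][0]))  # 1 28 28
```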