From babb94998fc120e9591fcab5072c694029ceb623 Mon Sep 17 00:00:00 2001
From: David Fan <30608893+jiafatom@users.noreply.github.com>
Date: Thu, 4 Jun 2020 15:37:52 -0700
Subject: [PATCH] Support tf.nn.leaky_relu and fix advanced_activations (#514)

---
 keras2onnx/ke2onnx/activation.py | 30 +++++++++++++++++++++++-------
 keras2onnx/ke2onnx/conv.py       |  2 +-
 tests/test_layers.py             |  8 +++++++-
 3 files changed, 31 insertions(+), 9 deletions(-)

diff --git a/keras2onnx/ke2onnx/activation.py b/keras2onnx/ke2onnx/activation.py
index 216b8ce1..d9369940 100644
--- a/keras2onnx/ke2onnx/activation.py
+++ b/keras2onnx/ke2onnx/activation.py
@@ -6,8 +6,8 @@
 import numpy as np
 import tensorflow as tf
 from ..proto import keras, is_tf_keras
-from ..common.onnx_ops import apply_elu, apply_hard_sigmoid, apply_relu, apply_relu_6, apply_sigmoid, apply_tanh, \
-    apply_softmax, apply_identity, apply_selu, apply_mul
+from ..common.onnx_ops import apply_elu, apply_hard_sigmoid, apply_leaky_relu, apply_relu, apply_relu_6, \
+    apply_tanh, apply_softmax, apply_identity, apply_selu, apply_mul, apply_prelu, apply_sigmoid
 from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
 
 activation_get = keras.activations.get
@@ -21,6 +21,11 @@
 if not relu6 and hasattr(keras.applications.mobilenet, 'relu6'):
     relu6 = keras.applications.mobilenet.relu6
 
+
+def apply_leaky_relu_keras(scope, input_name, output_name, container, operator_name=None, alpha=0.2):
+    apply_leaky_relu(scope, input_name, output_name, container, operator_name, alpha)
+
+
 activation_map = {activation_get('sigmoid'): apply_sigmoid,
                   activation_get('softmax'): apply_softmax,
                   activation_get('linear'): apply_identity,
@@ -29,6 +34,7 @@
                   activation_get('selu'): apply_selu,
                   activation_get('tanh'): apply_tanh,
                   activation_get('hard_sigmoid'): apply_hard_sigmoid,
+                  tf.nn.leaky_relu: apply_leaky_relu_keras,
                   tf.nn.sigmoid: apply_sigmoid,
                   tf.nn.softmax: apply_softmax,
                   tf.nn.relu: apply_relu,
@@ -40,6 +46,7 @@
 if hasattr(tf.compat, 'v1'):
     activation_map.update({tf.compat.v1.nn.sigmoid: apply_sigmoid})
     activation_map.update({tf.compat.v1.nn.softmax: apply_softmax})
+    activation_map.update({tf.compat.v1.nn.leaky_relu: apply_leaky_relu_keras})
     activation_map.update({tf.compat.v1.nn.relu: apply_relu})
     activation_map.update({tf.compat.v1.nn.relu6: apply_relu_6})
     activation_map.update({tf.compat.v1.nn.elu: apply_elu})
@@ -51,15 +58,20 @@ def convert_keras_activation(scope, operator, container):
     input_name = operator.input_full_names[0]
     output_name = operator.output_full_names[0]
     activation = operator.raw_operator.activation
+    activation_type = type(activation)
     if activation in [activation_get('sigmoid'), keras.activations.sigmoid]:
         apply_sigmoid(scope, input_name, output_name, container)
     elif activation in [activation_get('tanh'), keras.activations.tanh]:
         apply_tanh(scope, input_name, output_name, container)
-    elif activation in [activation_get('relu'), keras.activations.relu]:
+    elif activation in [activation_get('relu'), keras.activations.relu] or \
+            (hasattr(keras.layers.advanced_activations, 'ReLU') and
+             activation_type == keras.layers.advanced_activations.ReLU):
         apply_relu(scope, input_name, output_name, container)
-    elif activation in [activation_get('softmax'), keras.activations.softmax]:
+    elif activation in [activation_get('softmax'), keras.activations.softmax] or \
+            activation_type == keras.layers.advanced_activations.Softmax:
         apply_softmax(scope, input_name, output_name, container, axis=-1)
-    elif activation in [activation_get('elu'), keras.activations.elu]:
+    elif activation in [activation_get('elu'), keras.activations.elu] or \
+            activation_type == keras.layers.advanced_activations.ELU:
         apply_elu(scope, input_name, output_name, container, alpha=1.0)
     elif activation in [activation_get('hard_sigmoid'), keras.activations.hard_sigmoid]:
         apply_hard_sigmoid(scope, input_name, output_name, container, alpha=0.2, beta=0.5)
@@ -67,13 +79,17 @@ def convert_keras_activation(scope, operator, container):
         apply_identity(scope, input_name, output_name, container)
     elif activation in [activation_get('selu'), keras.activations.selu]:
         apply_selu(scope, input_name, output_name, container, alpha=1.673263, gamma=1.050701)
-    elif activation in [relu6] or activation.__name__ == 'relu6':
+    elif activation_type == keras.layers.advanced_activations.LeakyReLU:
+        apply_leaky_relu(scope, input_name, output_name, container, alpha=activation.alpha.item(0))
+    elif activation_type == keras.layers.advanced_activations.PReLU:
+        apply_prelu(scope, input_name, output_name, container, slope=operator.raw_operator.get_weights()[0])
+    elif activation in [relu6] or (hasattr(activation, '__name__') and activation.__name__ == 'relu6'):
         # relu6(x) = min(relu(x), 6)
         np_type = TENSOR_TYPE_TO_NP_TYPE[operator.inputs[0].type.to_onnx_type().tensor_type.elem_type]
         zero_value = np.zeros(shape=(1,), dtype=np_type)
         apply_relu_6(scope, input_name, output_name, container,
                      zero_value=zero_value)
-    elif activation.__name__ in ['swish']:
+    elif hasattr(activation, '__name__') and activation.__name__ == 'swish':
         apply_sigmoid(scope, input_name, output_name + '_sig', container)
         apply_mul(scope, [input_name, output_name + '_sig'], output_name, container)
     else:
diff --git a/keras2onnx/ke2onnx/conv.py b/keras2onnx/ke2onnx/conv.py
index 2b0f20a9..ec774310 100644
--- a/keras2onnx/ke2onnx/conv.py
+++ b/keras2onnx/ke2onnx/conv.py
@@ -197,7 +197,7 @@ def convert_keras_conv_core(scope, operator, container, is_transpose, n_dims, in
 
     # The construction of convolution is done. Now, we create an activation operator to apply the activation specified
     # in this Keras layer.
-    if op.activation.__name__ == 'swish':
+    if hasattr(op.activation, '__name__') and op.activation.__name__ == 'swish':
         apply_sigmoid(scope, transpose_output_name, transpose_output_name + '_sig', container)
         apply_mul(scope, [transpose_output_name, transpose_output_name + '_sig'], operator.outputs[0].full_name,
                   container)
diff --git a/tests/test_layers.py b/tests/test_layers.py
index 57c9e6fb..46505af0 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -1373,13 +1373,19 @@ def test_Softmax(advanced_activation_runner):
     advanced_activation_runner(layer, data)
 
 
+@pytest.mark.skipif(is_tensorflow_older_than('1.14.0') and is_tf_keras, reason='old tf version')
 def test_tf_nn_activation(runner):
-    for activation in [tf.nn.relu, 'relu', tf.nn.relu6, tf.nn.softmax]:
+    for activation in ['relu', tf.nn.relu, tf.nn.relu6, tf.nn.softmax, tf.nn.leaky_relu]:
         model = keras.Sequential([
             Dense(64, activation=activation, input_shape=[10]),
             Dense(64, activation=activation),
             Dense(1)
         ])
+        if is_tf_keras:
+            model.add(Activation(tf.keras.layers.LeakyReLU(alpha=0.2)))
+            model.add(Activation(tf.keras.layers.ReLU()))
+            model.add(tf.keras.layers.PReLU())
+            model.add(tf.keras.layers.LeakyReLU(alpha=0.5))
         x = np.random.rand(5, 10).astype(np.float32)
         expected = model.predict(x)
         onnx_model = keras2onnx.convert_keras(model, model.name)
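
Note (not part of the patch): a minimal standalone sketch of what the new mappings enable, for readers who want to try the conversion outside the test suite. It assumes keras2onnx and onnxruntime are installed and uses tf.keras; only keras2onnx.convert_keras appears in the patch itself, the rest (onnxruntime round-trip, layer sizes, tolerances) is illustrative.

    # Illustrative sketch only -- assumes a tf.keras + keras2onnx + onnxruntime setup.
    import numpy as np
    import tensorflow as tf
    import keras2onnx
    import onnxruntime

    # Exercise the cases this patch adds: tf.nn.leaky_relu as a Dense activation
    # (mapped via apply_leaky_relu_keras) and an advanced-activation layer wrapped
    # in Activation (handled by the new branches in convert_keras_activation).
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(16, activation=tf.nn.leaky_relu, input_shape=[8]),
        tf.keras.layers.Activation(tf.keras.layers.LeakyReLU(alpha=0.2)),
        tf.keras.layers.Dense(1),
    ])

    onnx_model = keras2onnx.convert_keras(model, model.name)

    # Compare ONNX Runtime output with the Keras prediction; tolerances are arbitrary.
    sess = onnxruntime.InferenceSession(onnx_model.SerializeToString())
    x = np.random.rand(4, 8).astype(np.float32)
    expected = model.predict(x)
    actual = sess.run(None, {sess.get_inputs()[0].name: x})[0]
    np.testing.assert_allclose(expected, actual, rtol=1e-4, atol=1e-5)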