
Commit babb949

Support tf.nn.leaky_relu and fix advanced_activations (#514)

jiafatom authored Jun 4, 2020
1 parent 85e8132 commit babb949
Showing 3 changed files with 31 additions and 9 deletions.
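In practical terms, a model that uses tf.nn.leaky_relu, or wraps the Keras advanced-activation layers (LeakyReLU, PReLU, ReLU, Softmax, ELU) as activations, can now be converted. Below is a minimal sketch of the enabled workflow using the public keras2onnx.convert_keras API exercised in the tests further down; the model itself is hypothetical, and the final comparison assumes onnxruntime is installed:

    import numpy as np
    import tensorflow as tf
    import keras2onnx
    import onnxruntime

    # Hypothetical model whose Dense layers use tf.nn.leaky_relu,
    # the activation this commit adds to activation_map.
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(8, activation=tf.nn.leaky_relu, input_shape=[4]),
        tf.keras.layers.Dense(1),
    ])

    onnx_model = keras2onnx.convert_keras(model, model.name)

    # Optional sanity check: ONNX Runtime output should match Keras.
    x = np.random.rand(3, 4).astype(np.float32)
    sess = onnxruntime.InferenceSession(onnx_model.SerializeToString())
    onnx_out = sess.run(None, {sess.get_inputs()[0].name: x})[0]
    np.testing.assert_allclose(model.predict(x), onnx_out, rtol=1e-4, atol=1e-5)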
30 changes: 23 additions & 7 deletions keras2onnx/ke2onnx/activation.py
@@ -6,8 +6,8 @@
 import numpy as np
 import tensorflow as tf
 from ..proto import keras, is_tf_keras
-from ..common.onnx_ops import apply_elu, apply_hard_sigmoid, apply_relu, apply_relu_6, apply_sigmoid, apply_tanh, \
-    apply_softmax, apply_identity, apply_selu, apply_mul
+from ..common.onnx_ops import apply_elu, apply_hard_sigmoid, apply_leaky_relu, apply_relu, apply_relu_6, \
+    apply_tanh, apply_softmax, apply_identity, apply_selu, apply_mul, apply_prelu, apply_sigmoid
 from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE

 activation_get = keras.activations.get
@@ -21,6 +21,11 @@
 if not relu6 and hasattr(keras.applications.mobilenet, 'relu6'):
     relu6 = keras.applications.mobilenet.relu6

+
+def apply_leaky_relu_keras(scope, input_name, output_name, container, operator_name=None, alpha=0.2):
+    apply_leaky_relu(scope, input_name, output_name, container, operator_name, alpha)
+
+
 activation_map = {activation_get('sigmoid'): apply_sigmoid,
                   activation_get('softmax'): apply_softmax,
                   activation_get('linear'): apply_identity,
@@ -29,6 +34,7 @@
                   activation_get('selu'): apply_selu,
                   activation_get('tanh'): apply_tanh,
                   activation_get('hard_sigmoid'): apply_hard_sigmoid,
+                  tf.nn.leaky_relu: apply_leaky_relu_keras,
                   tf.nn.sigmoid: apply_sigmoid,
                   tf.nn.softmax: apply_softmax,
                   tf.nn.relu: apply_relu,
@@ -40,6 +46,7 @@
 if hasattr(tf.compat, 'v1'):
     activation_map.update({tf.compat.v1.nn.sigmoid: apply_sigmoid})
     activation_map.update({tf.compat.v1.nn.softmax: apply_softmax})
+    activation_map.update({tf.compat.v1.nn.leaky_relu: apply_leaky_relu_keras})
     activation_map.update({tf.compat.v1.nn.relu: apply_relu})
     activation_map.update({tf.compat.v1.nn.relu6: apply_relu_6})
     activation_map.update({tf.compat.v1.nn.elu: apply_elu})
@@ -51,29 +58,38 @@ def convert_keras_activation(scope, operator, container):
     input_name = operator.input_full_names[0]
     output_name = operator.output_full_names[0]
     activation = operator.raw_operator.activation
+    activation_type = type(activation)
     if activation in [activation_get('sigmoid'), keras.activations.sigmoid]:
         apply_sigmoid(scope, input_name, output_name, container)
     elif activation in [activation_get('tanh'), keras.activations.tanh]:
         apply_tanh(scope, input_name, output_name, container)
-    elif activation in [activation_get('relu'), keras.activations.relu]:
+    elif activation in [activation_get('relu'), keras.activations.relu] or \
+            (hasattr(keras.layers.advanced_activations, 'ReLU') and
+             activation_type == keras.layers.advanced_activations.ReLU):
         apply_relu(scope, input_name, output_name, container)
-    elif activation in [activation_get('softmax'), keras.activations.softmax]:
+    elif activation in [activation_get('softmax'), keras.activations.softmax] or \
+            activation_type == keras.layers.advanced_activations.Softmax:
         apply_softmax(scope, input_name, output_name, container, axis=-1)
-    elif activation in [activation_get('elu'), keras.activations.elu]:
+    elif activation in [activation_get('elu'), keras.activations.elu] or \
+            activation_type == keras.layers.advanced_activations.ELU:
         apply_elu(scope, input_name, output_name, container, alpha=1.0)
     elif activation in [activation_get('hard_sigmoid'), keras.activations.hard_sigmoid]:
         apply_hard_sigmoid(scope, input_name, output_name, container, alpha=0.2, beta=0.5)
     elif activation in [activation_get('linear'), keras.activations.linear]:
         apply_identity(scope, input_name, output_name, container)
     elif activation in [activation_get('selu'), keras.activations.selu]:
         apply_selu(scope, input_name, output_name, container, alpha=1.673263, gamma=1.050701)
-    elif activation in [relu6] or activation.__name__ == 'relu6':
+    elif activation_type == keras.layers.advanced_activations.LeakyReLU:
+        apply_leaky_relu(scope, input_name, output_name, container, alpha=activation.alpha.item(0))
+    elif activation_type == keras.layers.advanced_activations.PReLU:
+        apply_prelu(scope, input_name, output_name, container, slope=operator.raw_operator.get_weights()[0])
+    elif activation in [relu6] or (hasattr(activation, '__name__') and activation.__name__ == 'relu6'):
         # relu6(x) = min(relu(x), 6)
         np_type = TENSOR_TYPE_TO_NP_TYPE[operator.inputs[0].type.to_onnx_type().tensor_type.elem_type]
         zero_value = np.zeros(shape=(1,), dtype=np_type)
         apply_relu_6(scope, input_name, output_name, container,
                      zero_value=zero_value)
-    elif activation.__name__ in ['swish']:
+    elif hasattr(activation, '__name__') and activation.__name__ == 'swish':
         apply_sigmoid(scope, input_name, output_name + '_sig', container)
         apply_mul(scope, [input_name, output_name + '_sig'], output_name, container)
     else:
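For reference, tf.nn.leaky_relu and the new apply_leaky_relu_keras wrapper share the default alpha of 0.2, and the ONNX LeakyRelu operator they map to keeps x unchanged for non-negative inputs and scales negative inputs by alpha. A small NumPy sketch of that piecewise definition, for illustration only (not converter code):

    import numpy as np

    def leaky_relu_reference(x, alpha=0.2):
        # Piecewise form shared by tf.nn.leaky_relu and ONNX LeakyRelu:
        # y = x where x >= 0, and alpha * x elsewhere.
        return np.where(x >= 0, x, alpha * x)

    x = np.array([-2.0, -0.5, 0.0, 1.5], dtype=np.float32)
    print(leaky_relu_reference(x))  # approximately [-0.4, -0.1, 0.0, 1.5]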
2 changes: 1 addition & 1 deletion keras2onnx/ke2onnx/conv.py
@@ -197,7 +197,7 @@ def convert_keras_conv_core(scope, operator, container, is_transpose, n_dims, in

     # The construction of convolution is done. Now, we create an activation operator to apply the activation specified
     # in this Keras layer.
-    if op.activation.__name__ == 'swish':
+    if hasattr(op.activation, '__name__') and op.activation.__name__ == 'swish':
         apply_sigmoid(scope, transpose_output_name, transpose_output_name + '_sig', container)
         apply_mul(scope, [transpose_output_name, transpose_output_name + '_sig'], operator.outputs[0].full_name,
                   container)
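Both files handle swish the same way: since standard ONNX opsets have no dedicated Swish operator, the converter emits a Sigmoid node followed by an element-wise Mul, relying on the identity swish(x) = x * sigmoid(x). A quick NumPy check of that identity (a sketch, not converter code):

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    x = np.linspace(-3.0, 3.0, 7).astype(np.float32)

    # Decomposition the converter emits: Sigmoid output multiplied by the input.
    decomposed = x * sigmoid(x)

    # Closed form of swish for comparison.
    direct = x / (1.0 + np.exp(-x))

    np.testing.assert_allclose(decomposed, direct, rtol=1e-6)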
8 changes: 7 additions & 1 deletion tests/test_layers.py
@@ -1373,13 +1373,19 @@ def test_Softmax(advanced_activation_runner):
     advanced_activation_runner(layer, data)


+@pytest.mark.skipif(is_tensorflow_older_than('1.14.0') and is_tf_keras, reason='old tf version')
 def test_tf_nn_activation(runner):
-    for activation in [tf.nn.relu, 'relu', tf.nn.relu6, tf.nn.softmax]:
+    for activation in ['relu', tf.nn.relu, tf.nn.relu6, tf.nn.softmax, tf.nn.leaky_relu]:
         model = keras.Sequential([
             Dense(64, activation=activation, input_shape=[10]),
             Dense(64, activation=activation),
             Dense(1)
         ])
+        if is_tf_keras:
+            model.add(Activation(tf.keras.layers.LeakyReLU(alpha=0.2)))
+            model.add(Activation(tf.keras.layers.ReLU()))
+            model.add(tf.keras.layers.PReLU())
+            model.add(tf.keras.layers.LeakyReLU(alpha=0.5))
         x = np.random.rand(5, 10).astype(np.float32)
         expected = model.predict(x)
         onnx_model = keras2onnx.convert_keras(model, model.name)
