Adding support for XLA
Adding support for AMP (FP16)
Adding logging
ekuznetsov139 committed Jun 12, 2020
1 parent 0685d8b commit 4b0a3ca
Showing 5 changed files with 254 additions and 118 deletions.
12 changes: 12 additions & 0 deletions cond_xla.py
@@ -0,0 +1,12 @@
import tensorflow as tf

# Module-level switch; set it to 1 before the decorated functions are defined
# to compile them with XLA.
use_xla = 0

def conditional_xla():
  def decorator(func):
    # Wrap `func` in an XLA-compiled tf.function when the switch is on;
    # otherwise return it untouched.
    if use_xla == 1:
      return tf.function(experimental_compile=True,
                         experimental_relax_shapes=True)(func)
    else:
      return func
  return decorator
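For context, a minimal usage sketch (not part of this diff): the gelu and layer_norm changes in modeling.py below apply the decorator the same way. The scaled_tanh function is purely illustrative, and the flag must be set before the decorated function is defined, since the check happens at decoration time.

import tensorflow as tf
import cond_xla
cond_xla.use_xla = 1  # must happen before the decorated function is defined

from cond_xla import conditional_xla

@conditional_xla()
def scaled_tanh(x):
  # Compiled with XLA when use_xla == 1, left as a plain function otherwise.
  return 2.0 * tf.tanh(x)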

35 changes: 35 additions & 0 deletions fp16_utils.py
@@ -0,0 +1,35 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
import numpy as np


def float32_variable_storage_getter(getter, name, shape=None, dtype=None,
                                    initializer=None, regularizer=None,
                                    trainable=True,
                                    *args, **kwargs):
  """Custom variable getter that forces trainable variables to be stored in
  float32 precision and then casts them to the training precision.
  """
  storage_dtype = tf.float32 if trainable else dtype
  variable = getter(name, shape, dtype=storage_dtype,
                    initializer=initializer, regularizer=regularizer,
                    trainable=trainable,
                    *args, **kwargs)
  if trainable and dtype != tf.float32:
    # The stored variable stays float32; the graph sees a cast in the requested dtype.
    variable = tf.cast(variable, dtype)
  return variable
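A hedged sketch of how this getter is typically wired in (the scope name and projection layer below are illustrative, not part of this commit): variables are requested in float16 but created and stored in float32, so the optimizer updates full-precision master weights while the forward pass runs in half precision.

import tensorflow as tf
from fp16_utils import float32_variable_storage_getter

def fp16_projection(x_fp16):
  with tf.compat.v1.variable_scope(
      "proj_fp16", custom_getter=float32_variable_storage_getter):
    # get_variable asks for float16; the getter stores float32 and returns a cast.
    w = tf.compat.v1.get_variable(
        "kernel", shape=[x_fp16.shape[-1], 128], dtype=tf.float16)
    return tf.matmul(x_fp16, w)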

46 changes: 31 additions & 15 deletions modeling.py
@@ -26,9 +26,9 @@
import numpy as np
import six
import tensorflow as tf
tf.compat.v1.disable_resource_variables()
tf.compat.v1.disable_eager_execution()

from gpu_environment import get_custom_getter
from cond_xla import conditional_xla, use_xla

class BertConfig(object):
"""Configuration for `BertModel`."""
@@ -137,7 +137,8 @@ def __init__(self,
input_mask=None,
token_type_ids=None,
use_one_hot_embeddings=False,
scope=None):
scope=None,
compute_type=tf.float32):
"""Constructor for BertModel.
Args:
@@ -170,7 +171,7 @@ def __init__(self,
if token_type_ids is None:
token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)

with tf.compat.v1.variable_scope(scope, default_name="bert"):
with tf.compat.v1.variable_scope(scope, default_name="bert", custom_getter=get_custom_getter(compute_type)):
with tf.compat.v1.variable_scope("embeddings"):
# Perform embedding lookup on the word ids.
(self.embedding_output, self.embedding_table) = embedding_lookup(
@@ -205,7 +206,8 @@ def __init__(self,
# Run the stacked transformer.
# `sequence_output` shape = [batch_size, seq_length, hidden_size].
self.all_encoder_layers = transformer_model(
input_tensor=self.embedding_output,
input_tensor=tf.saturate_cast(self.embedding_output, compute_type) \
if self.embedding_output.dtype!=compute_type else self.embedding_output,
attention_mask=attention_mask,
hidden_size=config.hidden_size,
num_hidden_layers=config.num_hidden_layers,
@@ -262,7 +264,7 @@ def get_embedding_output(self):
def get_embedding_table(self):
return self.embedding_table
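A hedged sketch of how the new compute_type argument would be exercised (not part of this diff): the config values and placeholder shapes are illustrative, and get_custom_getter from gpu_environment is assumed to behave like the float32-storage getter added in fp16_utils.py above, so variables stay in float32 while the embeddings are saturate_cast to float16 before the transformer stack.

import tensorflow as tf
import modeling

config = modeling.BertConfig(vocab_size=30522)  # other fields keep their defaults
input_ids = tf.compat.v1.placeholder(tf.int32, shape=[8, 128])

model = modeling.BertModel(
    config=config,
    is_training=False,
    input_ids=input_ids,
    compute_type=tf.float16)
sequence_output = model.get_sequence_output()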


@conditional_xla()
def gelu(x):
"""Gaussian Error Linear Unit.
@@ -274,12 +276,14 @@ def gelu(x):
Returns:
`x` with the GELU activation applied.
"""
try:
return tf.nn.gelu(x)
except:
cdf = 0.5 * (1.0 + tf.tanh(
if not use_xla:
try:
return tf.nn.gelu(x)
except:
pass
cdf = 0.5 * (1.0 + tf.tanh(
(np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
return x * cdf
return x * cdf
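The cdf expression above is the standard tanh approximation of GELU; the built-in op only exists in newer TensorFlow releases, hence the try/except. An illustrative check (assumptions: a TF version where tf.nn.gelu is available; approximate=True selects the same tanh formula):

import numpy as np
import tensorflow as tf

x = tf.constant(np.linspace(-3.0, 3.0, 7), dtype=tf.float32)
approx = 0.5 * x * (1.0 + tf.tanh(
    np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))
builtin = tf.nn.gelu(x, approximate=True)  # matches `approx` to float32 precision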


def get_activation(activation_string):
@@ -366,10 +370,22 @@ def dropout(input_tensor, dropout_prob):

def layer_norm(input_tensor, name=None):
"""Run layer normalization on the last dimension of the tensor."""
# return tf.contrib.layers.layer_norm(
# inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)
return tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-12)(inputs=input_tensor)

  param_shape = input_tensor.shape[-1:]
  scale = tf.compat.v1.get_variable(name="scale", shape=param_shape,
                                    initializer=tf.ones_initializer())
  offset = tf.compat.v1.get_variable(name="offset", shape=param_shape,
                                     initializer=tf.zeros_initializer())
  @conditional_xla()
  def batch_norm(t, s, o):
    mean, variance = tf.nn.moments(t, axes=[-1], keepdims=True)
    return tf.nn.batch_normalization(
        t,
        mean,
        variance,
        offset=o,
        scale=s,
        variance_epsilon=1e-12)
  return batch_norm(input_tensor, scale, offset)
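The rewritten layer_norm takes per-position moments over the last axis and applies the learned scale/offset via tf.nn.batch_normalization, which with the same 1e-12 epsilon is mathematically the same normalization as the removed tf.keras.layers.LayerNormalization path. A stand-alone illustrative comparison, with fixed scale/offset standing in for the learned variables:

import tensorflow as tf

x = tf.random.normal([2, 4, 8])
mean, variance = tf.nn.moments(x, axes=[-1], keepdims=True)
manual = tf.nn.batch_normalization(
    x, mean, variance,
    offset=tf.zeros([8]), scale=tf.ones([8]),
    variance_epsilon=1e-12)
# A freshly initialized LayerNormalization (gamma=1, beta=0) yields the same values.
keras_ln = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-12)(x)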

def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
"""Runs layer normalization followed by dropout."""