
mxnet-transducer

A fast parallel implementation of the RNN Transducer loss (Graves 2013 joint network) for MXNet, with both CPU and GPU support.

A GPU version is now also available for the Graves 2012 additive network.
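For context, the loss consumes the output of the joint network. In the additive (Graves 2012) variant, the encoder output and the prediction-network output are broadcast-added and projected to vocabulary logits. A minimal sketch of that combination, with illustrative shapes and names that are not part of this repo:

import mxnet as mx

B, T, U, H, V = 2, 10, 5, 8, 6                    # batch, frames, labels+1, hidden, vocab

f = mx.nd.random.uniform(shape=(B, T, H))         # encoder (transcription net) output
g = mx.nd.random.uniform(shape=(B, U, H))         # prediction net output

# Broadcast-add across the time and label axes: (B, T, 1, H) + (B, 1, U, H)
joint = mx.nd.broadcast_add(f.expand_dims(axis=2), g.expand_dims(axis=1))

proj = mx.gluon.nn.Dense(V, flatten=False)        # per-position projection to the vocabulary
proj.initialize()
pred = proj(joint)                                # (B, T, U, V): the input to the loss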

Install and Test

First, clone MXNet and this repository:

git clone --recursive https://github.com/apache/incubator-mxnet
git clone https://github.com/HawkAaron/mxnet-transducer

Copy the operator files into the MXNet source tree:

cp -r mxnet-transducer/rnnt* incubator-mxnet/src/operator/contrib/

Then build MXNet from source, following the official installation instructions (a typical Linux build is sketched below):

https://mxnet.incubator.apache.org/install/index.html
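For example, a CUDA-enabled source build on Linux might look like this; the exact flags depend on your platform, so treat this as a sketch rather than the canonical procedure:

cd incubator-mxnet
make -j$(nproc) USE_BLAS=openblas USE_CUDA=1 USE_CUDNN=1
cd python && pip install -e .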

Finally, add the Python API below to /path/to/mxnet_root/mxnet/gluon/loss.py:

class RNNTLoss(Loss):
    def __init__(self, batch_first=True, blank_label=0, weight=None, **kwargs):
        # pred layout is (B, T, U, V) when batch_first, else (T, U, B, V)
        batch_axis = 0 if batch_first else 2
        super(RNNTLoss, self).__init__(weight, batch_axis, **kwargs)
        self.batch_first = batch_first
        self.blank_label = blank_label

    def hybrid_forward(self, F, pred, label, pred_lengths, label_lengths):
        if not self.batch_first:
            # (T, U, B, V) -> (B, T, U, V), the layout the operator expects
            pred = F.transpose(pred, (2, 0, 1, 3))

        # the contrib operator requires int32 labels and lengths
        loss = F.contrib.RNNTLoss(pred, label.astype('int32', False),
                                  pred_lengths.astype('int32', False),
                                  label_lengths.astype('int32', False),
                                  blank_label=self.blank_label)
        return loss
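With the class in place, a minimal usage sketch (random inputs and hypothetical shapes; assumes the rebuilt MXNet exposes the contrib operator):

import mxnet as mx
from mxnet.gluon.loss import RNNTLoss

B, T, U, V = 2, 10, 5, 6                          # batch, frames, labels+1, vocab (incl. blank)
loss_fn = RNNTLoss(batch_first=True, blank_label=0)

pred = mx.nd.random.uniform(shape=(B, T, U, V))   # joint network output
label = mx.nd.array([[1, 2, 3, 4], [2, 3, 1, 0]]) # padded label sequences
pred_lengths = mx.nd.array([T, T])                # valid frames per sequence
label_lengths = mx.nd.array([4, 3])               # valid labels per sequence

loss = loss_fn(pred, label, pred_lengths, label_lengths)
print(loss)                                       # one loss value per sequence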

Then, from this repository's root, run the test:

python test/test.py --mx

