forked from HawkAaron/mxnet-transducer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrnnt_loss.cc
113 lines (94 loc) · 4.4 KB
/
rnnt_loss.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file rnnt_loss.cc
* \brief CPU implementation and operator registration for the RNN Transducer (RNN-T) loss.
* \author Mingkun Huang
*/
#include "./rnnt_loss-inl.h"
#include "./rnnt_include/detail/cpu_rnnt.h"
namespace mshadow {
template <typename DType>
void compute_rnnt_cost(const Tensor<cpu, 4, DType> acts, // BTUV
DType *costs, DType *grads, int *labels,
int *label_lengths, int *data_lengths,
void *workspace, int train, int blank_label) {
int minibatch = static_cast<int>(acts.size(0));
int maxT = static_cast<int>(acts.size(1));
int maxU = static_cast<int>(acts.size(2));
int alphabet_size = static_cast<int>(acts.size(3));
warp_rnnt::CpuRNNT<DType> rnnt(minibatch, maxT, maxU, alphabet_size, workspace, blank_label);
if (train) {
rnnt.cost_and_grad(acts.dptr_, grads, costs, labels, label_lengths, data_lengths);
} else {
rnnt.score_forward(acts.dptr_, costs, labels, label_lengths, data_lengths);
}
}
} // namespace mshadow
namespace mxnet {
namespace op {

/*!
 * \brief CPU specialization of the RNN-T loss operator factory.
 * \param param parsed operator parameters, forwarded to the operator's constructor.
 * \param dtype element type flag of the input; the CPU path ignores it and always
 *              constructs RNNTLossOp<cpu>.
 * \return a newly allocated operator. NOTE(review): ownership presumably passes to
 *         the MXNet executor that requested it — confirm.
 */
template <>
Operator *CreateOp<cpu>(RNNTLossParam param, int dtype) {
  return new RNNTLossOp<cpu>(param);
}

/*!
 * \brief Build the operator for a given context.
 *
 * Runs type and shape inference first, CHECK-failing if either reports failure.
 * It then hands param_ and the type of the first input (`data`) to
 * DO_BIND_DISPATCH, which presumably selects the CreateOp specialization for
 * the device of `ctx`.
 */
// DO_BIND_DISPATCH comes from operator_common.h
Operator *RNNTLossProp::CreateOperatorEx(Context ctx,
                                         std::vector<TShape> *in_shape,
                                         std::vector<int> *in_type) const {
  std::vector<TShape> out_shape, aux_shape;
  std::vector<int> out_type, aux_type;
  CHECK(InferType(in_type, &out_type, &aux_type));
  CHECK(InferShape(in_shape, &out_shape, &aux_shape));
  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
}

// Register the RNNTLossParam parameter struct; its fields are also exposed as
// operator arguments below via RNNTLossParam::__FIELDS__().
DMLC_REGISTER_PARAMETER(RNNTLossParam);

MXNET_REGISTER_OP_PROPERTY(_contrib_RNNTLoss, RNNTLossProp)
    .describe(R"code(RNN Transducer Loss.
The shapes of the inputs and outputs:
- **data**: `(batch_size, sequence_length, label_length + 1, alphabet_size)`
- **label**: `(batch_size, label_sequence_length)`
- **out**: `(batch_size)`
The `data` tensor consists of sequences of activation vectors (after applying softmax),
with i-th channel in the last dimension corresponding to i-th label
for i between 0 and alphabet_size-1 (i.e. always 0-indexed).
Alphabet size should include one additional value reserved for blank label.
When `blank_label` is ``"first"``, the ``0``-th channel is reserved for
activation of blank label, or otherwise if it is ``"last"``, ``(alphabet_size-1)``-th channel should be
reserved for blank label.
``label`` is an index matrix of integers. When `blank_label` is ``"first"``,
the value 0 is then reserved for blank label, and should not be passed in this matrix. Otherwise,
when `blank_label` is ``"last"``, the value `(alphabet_size-1)` is reserved for blank label.
``out`` is a list of RNNT loss values, one per example in the batch.
See *Sequence Transduction with Recurrent Neural Networks*, A. Graves, for more
information on the definition and the algorithm.
)code" ADD_FILELINE)
    .add_argument("data", "NDArray-or-Symbol", "Input data to the rnnt_loss op.")
    .add_argument("label", "NDArray-or-Symbol",
                  "Ground-truth labels for the loss.")
    .add_argument("data_lengths", "NDArray-or-Symbol",
                  "Lengths of data for each of the samples. Only required "
                  "when use_data_lengths is true.")
    .add_argument("label_lengths", "NDArray-or-Symbol",
                  "Lengths of labels for each of the samples. Only required "
                  "when use_label_lengths is true.")
    .add_arguments(RNNTLossParam::__FIELDS__());

// Also expose the operator under the snake_case name `_contrib_rnnt_loss`.
NNVM_REGISTER_OP(_contrib_RNNTLoss).add_alias("_contrib_rnnt_loss");

}  // namespace op
}  // namespace mxnet