// frozen_linear_folding.cpp
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/fold_linear_bn.h>
#include <torch/csrc/jit/passes/frozen_linear_folding.h>
#include <torch/csrc/jit/passes/utils/optimization_utils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/ones_like.h>
#include <ATen/ops/zeros_like.h>
#endif

namespace torch::jit {

namespace {

using Tensor = at::Tensor;

bool supportedLinearNode(Node* n) {
  return n->kind() == aten::linear;
}

bool FoldFrozenLinearBatchnorm(Block* b) {
bool graph_modified = false;
for (Node* n : b->nodes()) {
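    // Recurse into sub-blocks first (e.g. the bodies of prim::If and
    // prim::Loop), so nested linear+batch_norm pairs are folded too.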
for (Block* block : n->blocks()) {
graph_modified |= FoldFrozenLinearBatchnorm(block);
}
if (n->kind() == aten::batch_norm &&
supportedLinearNode(n->inputs().at(0)->node())) {
auto linear = n->inputs().at(0)->node();
auto bn = n;
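      // Folding bakes the BN statistics into the linear weight/bias, so all
      // parameters involved must already be constants (i.e. the module must
      // have been frozen); otherwise there is nothing safe to fold.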
if (nonConstantParameters(linear) || nonConstantParameters(bn)) {
continue;
}
auto bn_rm_ivalue = bn->namedInput("running_mean");
auto bn_rv_ivalue = bn->namedInput("running_var");
      // Check that running_mean and running_var have values; if they are
      // None (track_running_stats=False), skip the folding path.
if (bn_rm_ivalue->type() == NoneType::get() &&
bn_rv_ivalue->type() == NoneType::get()) {
continue;
}
auto bn_rm = constant_as<Tensor>(bn->namedInput("running_mean")).value();
auto bn_rv = constant_as<Tensor>(bn->namedInput("running_var")).value();
auto bn_eps = constant_as<double>(bn->namedInput("eps")).value();
auto linear_w = constant_as<Tensor>(linear->namedInput("weight")).value();
int64_t linear_out_features = linear_w.size(0);
int64_t bn_num_features = bn_rm.size(0);
      // Linear-BN fusion must preserve the shapes of the linear weight/bias.
      // For that, the channel dim of bn needs to be broadcastable with the
      // last dim of linear's output: bn operates over the channel dim of an
      // (N, C, ...) input, while linear operates over the last dim of a
      // (*, H_in) input. So the number of features in bn and the number of
      // output features from linear must satisfy one of:
      //   1. they are equal, or
      //   2. the number of features in bn is 1.
      // Otherwise, skip the folding path.
if (!(linear_out_features == bn_num_features || bn_num_features == 1)) {
continue;
}
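      // E.g. linear with out_features == 8 folds with a bn over 8 features
      // or over 1 feature (broadcast), but not with a bn over 4 features.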
// implementation taken from torch/nn/utils/fusion.py
Tensor linear_b;
if (linear->namedInput("bias")->type() == NoneType::get()) {
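        // No explicit bias: synthesize a zero bias for the folded result.
        // For reduced-precision CUDA weights paired with float BN stats
        // (as produced by mixed-precision workflows), create the bias in
        // the weight's dtype so the fused bias matches the linear weight.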
at::ScalarType bias_dtype = bn_rm.scalar_type();
at::ScalarType weight_dtype = linear_w.scalar_type();
at::DeviceType weight_device = linear_w.device().type();
if (weight_device == at::kCUDA &&
(weight_dtype == at::kHalf || weight_dtype == at::kBFloat16) &&
bias_dtype == at::kFloat) {
bias_dtype = weight_dtype;
}
linear_b = at::zeros_like(bn_rm, at::TensorOptions().dtype(bias_dtype));
} else {
linear_b = constant_as<Tensor>(linear->namedInput("bias")).value();
}
Tensor bn_w;
if (bn->namedInput("weight")->type() == NoneType::get()) {
bn_w = at::ones_like(bn_rm);
} else {
bn_w = constant_as<Tensor>(bn->namedInput("weight")).value();
}
Tensor bn_b;
      if (bn->namedInput("bias")->type() == NoneType::get()) {
bn_b = at::zeros_like(bn_rm);
} else {
bn_b = constant_as<Tensor>(bn->namedInput("bias")).value();
}
LinearBNParameters params;
params.linear_w = linear_w;
params.linear_b = linear_b;
params.bn_rm = bn_rm;
params.bn_rv = bn_rv;
params.bn_eps = bn_eps;
params.bn_w = bn_w;
params.bn_b = bn_b;
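      // computeUpdatedLinearWeightAndBias mirrors the fusion math from
      // torch/nn/utils/fusion.py, roughly:
      //   scale   = bn_w / sqrt(bn_rv + bn_eps)
      //   w_fused = linear_w * scale.unsqueeze(-1)
      //   b_fused = (linear_b - bn_rm) * scale + bn_b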
std::tuple<Tensor, Tensor> out =
computeUpdatedLinearWeightAndBias(params);
WithInsertPoint guard(linear);
auto fused_linear_w = b->owningGraph()->insertConstant(std::get<0>(out));
auto fused_linear_b = b->owningGraph()->insertConstant(std::get<1>(out));
auto linear_w_value = linear->namedInput("weight");
auto linear_b_value = linear->namedInput("bias");
fused_linear_w->setDebugName(linear_w_value->debugName() + "_fused_bn");
fused_linear_b->setDebugName(linear_b_value->debugName() + "_fused_bn");
linear->replaceInputWith(linear_w_value, fused_linear_w);
linear->replaceInputWith(linear_b_value, fused_linear_b);
bn->output()->replaceAllUsesWith(linear->output());
graph_modified = true;
}
}
return graph_modified;
}
} // namespace

bool FoldFrozenLinearBatchnorm(std::shared_ptr<Graph>& graph) {
bool graph_modified = FoldFrozenLinearBatchnorm(graph->block());
EliminateDeadCode(graph);
return graph_modified;
}
} // namespace torch::jit
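
// ---------------------------------------------------------------------------
// Usage sketch (not part of the pass itself): this pass expects a frozen
// graph, i.e. one where module parameters have been inlined as constants.
// A minimal driver, assuming the usual TorchScript freezing workflow; the
// file name "model_with_linear_bn.pt" is illustrative:
//
//   #include <torch/script.h>
//   #include <torch/csrc/jit/passes/frozen_linear_folding.h>
//
//   torch::jit::Module module = torch::jit::load("model_with_linear_bn.pt");
//   module.eval();
//   torch::jit::Module frozen = torch::jit::freeze(module);
//   std::shared_ptr<torch::jit::Graph> graph =
//       frozen.get_method("forward").graph();
//   bool changed = torch::jit::FoldFrozenLinearBatchnorm(graph);
// ---------------------------------------------------------------------------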