-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathfusion_biasgelu.py
62 lines (51 loc) · 2.26 KB
/
fusion_biasgelu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#--------------------------------------------------------------------------
from logging import getLogger
from onnx import helper
from onnx_model import OnnxModel
from fusion_base import Fusion
from fusion_utils import NumpyHelper
logger = getLogger(__name__)
class FusionBiasGelu(Fusion):
    """Fold the bias ``Add`` between a ``MatMul`` and a Gelu-type node into one fused op.

    ``MatMul -> Add(bias) -> Gelu`` becomes ``MatMul -> BiasGelu``, and
    ``MatMul -> Add(bias) -> FastGelu`` becomes ``MatMul -> FastGelu`` with the bias
    passed as its second input. The fused node is emitted in the ``com.microsoft``
    domain.
    """

    def __init__(self, model: OnnxModel, is_fastgelu: bool):
        """Configure the fusion for either Gelu or FastGelu nodes.

        Args:
            model: The model wrapper the fusion operates on.
            is_fastgelu: When true, configure the base class with ('FastGelu',
                'FastGelu', 'add bias'); otherwise with ('BiasGelu', 'Gelu').
                Presumably these are (fused op type, searched op type,
                description) -- confirm against ``Fusion.__init__``.
        """
        if is_fastgelu:
            super().__init__(model, 'FastGelu', 'FastGelu', 'add bias')
        else:
            super().__init__(model, 'BiasGelu', 'Gelu')

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        """Try to absorb the bias ``Add`` feeding ``node`` into a fused op.

        Args:
            node: The ``Gelu`` or ``FastGelu`` node under consideration.
            input_name_to_nodes: Tensor name -> consuming nodes (forwarded to the
                model's safety check).
            output_name_to_node: Tensor name -> producing node (forwarded to the
                model's safety check).

        Returns:
            None. On success ``node`` and the ``Add`` are queued in
            ``self.nodes_to_remove`` and the fused node in ``self.nodes_to_add``;
            on any mismatch the method returns without queuing anything.
        """
        gelu_op_type = node.op_type
        fuse_op_type = 'BiasGelu' if gelu_op_type == 'Gelu' else 'FastGelu'

        # Only single-input nodes are candidates (e.g. a FastGelu that already
        # carries a bias input is skipped).
        if len(node.input) != 1:
            return

        # The Add must be input 0 of the Gelu; the MatMul may feed either Add input.
        nodes = self.model.match_parent_path(node, ['Add', 'MatMul'], [0, None])
        if nodes is None:
            return
        add, matmul = nodes

        # The bias is the first Add input that is a constant initializer.
        bias_index = -1
        bias_weight = None
        for index, input_name in enumerate(add.input):
            initializer = self.model.get_initializer(input_name)
            if initializer is not None:
                bias_index = index
                bias_weight = NumpyHelper.to_array(initializer)
                break

        # bias should be one dimension
        if bias_weight is None or len(bias_weight.shape) != 1:
            return

        # Add broadcasts, but the fused kernels require the bias length to equal the
        # last dimension of the MatMul output. When the MatMul weight is a constant
        # of rank >= 2, its last dimension is that output dimension, so a mismatch
        # (e.g. a broadcast length-1 bias) would make the fused graph fail at runtime.
        # Only the weight's dims are read; its data is not converted.
        matmul_weight = self.model.get_initializer(matmul.input[1])
        if matmul_weight is not None:
            weight_dims = list(matmul_weight.dims)
            if len(weight_dims) >= 2 and weight_dims[-1] != bias_weight.shape[0]:
                return

        subgraph_nodes = [node, add]
        # Presumably verifies that nothing outside the subgraph consumes the Add's
        # output, since that tensor disappears after fusion -- confirm in OnnxModel.
        if not self.model.is_safe_to_fuse_nodes(subgraph_nodes, [node.output[0]], input_name_to_nodes,
                                                output_name_to_node):
            return

        self.nodes_to_remove.extend(subgraph_nodes)

        # The MatMul stays in the graph; its output and the bias feed the fused node,
        # which takes over the Gelu node's outputs.
        fused_node = helper.make_node(fuse_op_type,
                                      inputs=[matmul.output[0], add.input[bias_index]],
                                      outputs=node.output,
                                      name=self.model.create_node_name(fuse_op_type, gelu_op_type + "_AddBias_"))
        fused_node.domain = "com.microsoft"
        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[fused_node.name] = self.this_graph_name