forked from alibaba/TinyNeuralNetwork
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathidentity_pruner.py
48 lines (34 loc) · 1.65 KB
/
identity_pruner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from tinynn.prune import OneShotChannelPruner
from tinynn.util.util import get_logger
# Module-level logger; IdentityChannelPruner.register_mask reports its progress through it.
log = get_logger(__name__)
class IdentityChannelPruner(OneShotChannelPruner):
    """One-shot channel pruner whose mask registration passes ``None`` instead of importance data.

    ``__init__`` forces the configured ``sparsity`` to 1.0. ``register_mask``
    passes ``None`` as the first argument of every ``calc_prune_idx`` call and
    as the second argument of every modifier's ``register_mask`` call.
    """

    required_params = ('sparsity', 'metrics')

    # Set to True in __init__, before the base-class constructor runs.
    bn_compensation: bool
    # Declared here; not assigned anywhere in this file.
    exclude_ops: list

    def __init__(self, model, dummy_input, config=None):
        """Constructs a new IdentityChannelPruner

        Args:
            model: The model to be pruned
            dummy_input: A viable input to the model
            config (dict, optional): Configuration of the pruner. ``"metrics"``
                defaults to ``"l2_norm"`` when missing, and ``"sparsity"`` is
                always overridden with 1.0. The mapping is copied first, so the
                caller's dict is never modified. ``None`` behaves like an
                empty dict.

        Raises:
            TypeError: If ``config`` is a string (e.g. a path to a config
                file). The overrides above need an in-memory mapping.
            Exception: If a model without parameters or prunable ops is given,
                the exception will be thrown
        """
        if isinstance(config, str):
            # A path cannot take the item overrides below. The previous code
            # failed here as well, just with an opaque item-assignment error.
            raise TypeError(
                "IdentityChannelPruner expects `config` to be a dict or None, "
                f"got a string: {config!r}"
            )

        self.bn_compensation = True

        # Work on a copy so the overrides below do not leak into the caller's config.
        config = {} if config is None else dict(config)
        config.setdefault("metrics", "l2_norm")
        # NOTE(review): sparsity is unconditionally forced to 1.0. The reason is
        # not visible in this file; confirm against OneShotChannelPruner.
        config["sparsity"] = 1.0

        super().__init__(model, dummy_input, config)

    def register_mask(self):
        """Computes the mask for the parameters in the model and register them through the maskers

        For each subgraph in ``self.graph_modifier.sub_graphs`` not flagged
        ``skip``, ``calc_prune_idx`` is called with ``None`` as its first
        argument. After every subgraph has been processed, each modifier's
        ``register_mask`` is called, again with ``None`` in the second slot.
        That slot presumably holds importance values; TODO confirm.
        """
        log.info("Register a mask for each operator")

        for sub_graph in self.graph_modifier.sub_graphs.values():
            if sub_graph.skip:
                log.info(f"skip subgraph {sub_graph.center}")
                continue

            # NOTE(review): self.sparsity is passed as-is, not looked up per
            # subgraph. Confirm it is a scalar here and not a per-subgraph mapping.
            sub_graph.calc_prune_idx(None, self.sparsity, multiple=self.multiple)
            log.info(f"subgraph [{sub_graph.center}] compute over")

        # Masks are registered only once every subgraph has its prune indices.
        for m in self.graph_modifier.modifiers.values():
            m.register_mask(self.graph_modifier.modifiers, None, self.sparsity)