optimizer.py
"""
"""
import torch
from torch.optim import Optimizer


## stolen from torch.optim.optimizer
class _RequiredParameter(object):
    """Singleton class representing a required parameter for an Optimizer."""

    def __repr__(self):
        return "<required parameter>"


required = _RequiredParameter()
####################################
class SGDLRD(Optimizer):
    r"""Implements learning rate dropout from the paper "Learning Rate Dropout"
    by Lin et al. (https://arxiv.org/abs/1912.00144).

    The original SGD implementation is taken from:
    https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        lr_dropout_rate (float, optional): Bernoulli parameter of the binary mask
            sampled for each update; each coordinate's update is retained with
            probability `lr_dropout_rate`
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)

    Example:
        >>> optimizer = SGDLRD(model.parameters(), lr=0.1, lr_dropout_rate=0.5,
        ...                    momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()
    """
    def __init__(
        self,
        params,
        lr=required,
        lr_dropout_rate=0.0,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
    ):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if lr_dropout_rate < 0.0:
            raise ValueError(
                "Invalid learning rate dropout parameter: {}".format(lr_dropout_rate)
            )
        elif lr_dropout_rate == 0.0:
            raise ValueError(
                "Learning rate dropout must be positive in order to retain some entries"
            )
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(
            lr=lr,
            lr_dropout_rate=lr_dropout_rate,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGDLRD, self).__init__(params, defaults)
    def __setstate__(self, state):
        super(SGDLRD, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault("nesterov", False)
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            lr_dropout_rate = group["lr_dropout_rate"]
            dampening = group["dampening"]
            nesterov = group["nesterov"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if weight_decay != 0:
                    d_p = d_p.add(p.data, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if "momentum_buffer" not in param_state:
                        buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
                    else:
                        buf = param_state["momentum_buffer"]
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                # Construct a random binary mask on the same device and dtype as
                # the update: each coordinate's update is retained with
                # probability `lr_dropout_rate` and dropped otherwise.
                mask = (torch.rand_like(d_p) < lr_dropout_rate).to(d_p.dtype)

                # Apply the mask, then take the masked SGD step.
                d_p.mul_(mask)
                p.data.add_(d_p, alpha=-group["lr"])

        return loss
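

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): fits a tiny linear
# model on random data with SGDLRD. The toy model, data, and hyperparameter
# values below are illustrative assumptions, not taken from the paper.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    # Hypothetical toy regression problem.
    model = torch.nn.Linear(10, 1)
    inputs = torch.randn(64, 10)
    targets = torch.randn(64, 1)
    loss_fn = torch.nn.MSELoss()

    # Retain each coordinate's update with probability 0.5 (illustrative value).
    optimizer = SGDLRD(model.parameters(), lr=0.1, lr_dropout_rate=0.5, momentum=0.9)

    for step_idx in range(100):
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()

    print("final loss:", loss.item())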