#!/usr/bin/env python3
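"""Train the step size of a uniform quantizer by gradient descent.

floor() has zero gradient almost everywhere, so Floor substitutes a
surrogate gradient during backpropagation; several candidates are listed
in Floor.backward().
"""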
import math
import torch
import torch.nn as nn
import torch.optim as optim


class Floor(torch.autograd.Function):
    """floor() with a selectable surrogate gradient. The true gradient of
    floor() is zero almost everywhere, so a smooth stand-in is substituted
    in backward() to keep gradient descent moving."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return torch.floor(input)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        # Exactly one of the surrogate gradients below should be active.
        # zero (the correct gradient almost everywhere)
        # grad_input *= 0
        # linear (straight-through estimator)
        # grad_input *= 1
        # quadratic
        # grad_input *= 2 * (input - torch.floor(input))
        # cubic
        # grad_input *= 3 * (input - torch.floor(input)) ** 2
        # Fourier series expansion
        # grad_input *= 1 + 2 * torch.cos(2 * math.pi * input)
        # log (derivative of log(input + 1); blows up as input -> -1)
        grad_input *= 1.0 / (input + 1.0)
        return grad_input
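

# Example (illustrative): the surrogate replaces d(floor)/dx during
# backpropagation. With the log surrogate active:
#   x = torch.tensor([0.7], requires_grad=True)
#   y = Floor.apply(x)  # tensor([0.])
#   y.backward()
#   x.grad              # tensor([0.5882]) == 1 / (0.7 + 1.0)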


class Quantizer(nn.Module):
    def __init__(self):
        super(Quantizer, self).__init__()
        # Learnable quantization step size.
        self.delta = nn.Parameter(torch.tensor(1.0))
        self.floor = Floor.apply

    def forward(self, x):
        encoded = self.floor(x / self.delta)     # bin index
        decoded = self.delta * (encoded + 0.5)   # reconstruct at the bin midpoint
        return decoded
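
# Example: with delta = 1.0, an input of 0.3 encodes to floor(0.3 / 1.0) = 0
# and decodes to 1.0 * (0 + 0.5) = 0.5, the midpoint of the bin [0, 1).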


def main():
    batch_size = 128
    lr = 0.01
    epochs = 100000

    # Whether CUDA is available.
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else 'cpu')

    # Training: minimizing the reconstruction error of an identity target
    # drives the step size delta toward zero.
    model = Quantizer().to(device)
    model.train()
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)
    for epoch in range(1, epochs + 1):
        # Uniform samples in [-1, 1).
        data = 2 * torch.rand(batch_size, 1) - 1
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, data)
        loss.backward()
        optimizer.step()
        if epoch % 100 == 0:
            print('Epoch: {}, Loss: {:.6f}'.format(epoch, loss.item()))
            print(model.delta, model.delta.grad)
        # Stop once the step size has effectively collapsed.
        if abs(model.delta.item()) < 0.0001:
            print('Done {} epochs'.format(epoch))
            return


if __name__ == '__main__':
    main()