-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathPaddedUnfold.lua
66 lines (52 loc) · 2.25 KB
/
PaddedUnfold.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
local PaddedUnfold, parent = torch.class('PaddedUnfold', 'nn.Module')

--- Create the module.
-- @param shiftDim dimension of the input along which windows are taken
-- @param shiftRange number of edge entries copied onto each side of the input
--   along shiftDim; updateOutput produces 2*shiftRange+1 windows
function PaddedUnfold:__init(shiftDim, shiftRange)
   parent.__init(self)
   self.shiftDim, self.shiftRange = shiftDim, shiftRange
   -- Scratch tensors reused across calls: the padded copy of the input built
   -- by updateOutput, and the gradient buffer filled by updateGradInput.
   self.padded, self.gradPadded = torch.Tensor(), torch.Tensor()
end
--- Forward pass.
-- Builds self.padded by concatenating, along shiftDim, a copy of the input's
-- first shiftRange entries, the whole input, and a copy of its last
-- shiftRange entries. Note the edge segments are repeated as-is (this is not
-- zero, mirror, or wrap-around padding). It then takes every window of length
-- `width` (the input size along shiftDim) with stride 1, which yields
-- 2*shiftRange+1 windows.
-- The returned tensor is a view into self.padded (no copy) with dimensions 3
-- and 4 swapped relative to the raw Tensor:unfold layout.
-- NOTE(review): the hard-coded transpose(3, 4) needs the unfolded tensor to
-- have at least 4 dimensions (input:dim() >= 3); confirm the intended layout.
-- @param input tensor to unfold; its size along shiftDim must be >= shiftRange
-- @return self.output
function PaddedUnfold:updateOutput(input)
   local dim, range = self.shiftDim, self.shiftRange
   local width = input:size(dim)
   if range > width then
      error(string.format(
         'PaddedUnfold: shiftRange (%s) exceeds input size %s along dimension %s',
         tostring(range), tostring(width), tostring(dim)), 2)
   end

   local leftPad = input:narrow(dim, 1, range)
   local rightPad = input:narrow(dim, width - range + 1, range)
   torch.cat(self.padded, {leftPad, input, rightPad}, dim)

   -- unfold replaces shiftDim (now width + 2*range long) with the number of
   -- windows and appends a trailing dimension of length `width` holding the
   -- position inside each window.
   local unfolded = self.padded:unfold(dim, width, 1)
   assert(unfolded:size(dim) == range * 2 + 1)
   self.output = unfolded:transpose(3, 4)
   return self.output
end
--- Backward pass: gradient w.r.t. the input.
-- Mirrors updateOutput. Window i of the unfolded output covers
-- padded[i .. i+width-1] along shiftDim, so its gradient slice is accumulated
-- into that range of a zeroed buffer shaped like self.padded. The gradients
-- that landed on the two copied edge segments are then added back onto the
-- input positions they were copied from (the first and last shiftRange
-- entries along shiftDim).
-- Reads self.padded, so it expects updateOutput to have run on this input.
-- @param input the tensor that was passed to updateOutput
-- @param gradOutput gradient w.r.t. the output, in updateOutput's layout
-- @return self.gradInput, resized to input's size
function PaddedUnfold:updateGradInput(input, gradOutput)
   local dim, range = self.shiftDim, self.shiftRange
   local width = input:size(dim)
   local nDim = input:dim()

   -- Undo the transpose(3, 4) applied at the end of updateOutput, recovering
   -- the raw unfold layout: shiftDim indexes the window, and a trailing
   -- dimension of length `width` indexes the position inside the window.
   local unfolded = gradOutput:transpose(3, 4)

   self.gradPadded:resizeAs(self.padded):zero()
   for i = 1, 2 * range + 1 do
      -- select() drops the window dimension but leaves the within-window
      -- dimension last. Bubble it back to position shiftDim so the slice
      -- lines up with padded:narrow(shiftDim, i, width). For a 4-D input
      -- with shiftDim == 3 this is exactly one transpose(3, 4). The original
      -- hard-coded that and misaligned the gradient for any other shiftDim.
      local slice = unfolded:select(dim, i)
      for d = nDim, dim + 1, -1 do
         slice = slice:transpose(d, d - 1)
      end
      self.gradPadded:narrow(dim, i, width):add(slice)
   end

   local paddedWidth = self.gradPadded:size(dim)
   local gradLeftPad = self.gradPadded:narrow(dim, 1, range)
   local gradRightPad = self.gradPadded:narrow(dim, paddedWidth - range + 1, range)
   local gradCenter = self.gradPadded:narrow(dim, range + 1, width)

   self.gradInput:resize(input:size()):copy(gradCenter)
   -- The left pad was a copy of the first `range` entries, the right pad of
   -- the last `range` entries; route their gradients back there.
   self.gradInput:narrow(dim, 1, range):add(gradLeftPad)
   self.gradInput:narrow(dim, width - range + 1, range):add(gradRightPad)
   return self.gradInput
end
--- Release this module's scratch buffers, then defer to nn.Module.
-- Both padding buffers are reset with Tensor:set(); output/gradInput
-- handling is left to the parent implementation.
-- @return the result of parent.clearState(self)
function PaddedUnfold:clearState()
   local padded, gradPadded = self.padded, self.gradPadded
   padded:set()
   gradPadded:set()
   return parent.clearState(self)
end