Skip to content

Commit

Permalink
replace maxpool with blurred maxpool (ProGamerGov#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
genekogan committed Dec 11, 2020
1 parent a0e5e36 commit 6fcc807
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 7 deletions.
14 changes: 10 additions & 4 deletions example.ipynb

Large diffs are not rendered by default.

51 changes: 48 additions & 3 deletions model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import math
import torch
import torch.utils.cpp_extension
import torch.nn as nn
Expand Down Expand Up @@ -138,7 +139,9 @@ def __init__(self, params, dtype, multidevice, backward_device, verbose=True):
r += 1

if isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d):
net.add_module(str(len(net)), layer)
#net.add_module(str(len(net)), layer)
net.add_module(str(len(net)), maxpool2d_blurred_layer)


self.net = net
log(self.net, self.verbose)
Expand Down Expand Up @@ -311,7 +314,8 @@ def __setup_layer_masks__(self):
self.content_masks[c] = torch.mean(content_mask.type(self.dtype), axis=1)[0]

for L, layer in enumerate(self.net):
if (isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d)):
#if (isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d)):
if (isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d) or isinstance(layer, MaxPool2d)):
if self.content_masks != None:
for k in range(self.num_styles):
h, w = self.content_masks[k].shape
Expand Down Expand Up @@ -756,6 +760,48 @@ def forward(self, input, source_tensor):
return self.match(input, source_tensor)



######################################################
# Blurred MaxPool2D (see https://github.com/ProGamerGov/neural-style-pt/issues/71)

class GaussianBlur(nn.Module):

def __init__(self, kernel_size=6, sigma = math.pi / 2):
super().__init__()
if type(sigma) is not list and type(sigma) is not tuple:
kernel_size = [kernel_size] * 2
if type(sigma) is not list and type(sigma) is not tuple:
sigma = [sigma] * 2

kernel = 1
meshgrid_tensor = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])

for size, std, mgrid in zip(kernel_size, sigma, meshgrid_tensor):
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
torch.exp(-((mgrid - ((size - 1) / 2)) / std) ** 2 / 2)
self.kernel = (kernel / torch.sum(kernel)).view(1, 1, *kernel.size()).cuda()

def forward(self, x):
assert x.dim() == 4
groups = x.size(1)
weight = self.kernel.repeat(groups, * [1] * (self.kernel.dim() - 1))
x = torch.nn.functional.pad(x, (3,2,3,2), mode='reflect') # No idea if this is a good idea for keeping input the same size
x = torch.nn.functional.conv2d(x, weight=weight, groups=groups)
return x

blur_input = GaussianBlur(6, sigma = 0.25)

class MaxPool2d(torch.nn.MaxPool2d):
def forward(self, x):
x = blur_input(x)
x = x.unfold(2, self.kernel_size, self.stride).unfold(3, self.kernel_size, self.stride)
x = x.contiguous().view(x.size()[:4] + (-1,))
pool, _ = torch.max(x, dim=-1)
return pool

maxpool2d_blurred_layer = MaxPool2d(kernel_size=2, stride=2)



######################################################
# TV regularization
Expand Down Expand Up @@ -835,4 +881,3 @@ def setup_cpu():
setup_cpu()
dtype, backward_device = torch.FloatTensor, "cpu"
return dtype, multidevice, backward_device

0 comments on commit 6fcc807

Please sign in to comment.