self-attention-3D.py
'''forked from https://github.com/openseg-group/OCNet.pytorch'''
import torch
import torch.nn as nn
from torch.nn import functional as F


class SelfAttentionBlock(nn.Module):
    '''
    Basic implementation of a self-attention (non-local) block for 3D feature maps.
    Input:
        N x C x D x H x W feature map (flattened internally to N x C x D*H*W)
    Parameters:
        in_channels    : number of channels of the input feature map
        key_channels   : number of channels after the key/query transform
        value_channels : number of channels after the value transform
        scale          : factor by which the input feature map is downsampled (to save memory)
    Return:
        N x out_channels x D x H x W
        position-aware context features (not concatenated or added to the input;
        when scale > 1 the spatial size stays at the pooled resolution).
    '''
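    # Shape summary of the attention computed in forward() (a reading aid,
    # with L = D*H*W after the optional pooling):
    #   value = f_value(x) -> N x L x value_channels
    #   query = f_query(x) -> N x L x key_channels
    #   key   = f_key(x)   -> N x key_channels x L
    #   sim   = softmax(query @ key / sqrt(key_channels)) -> N x L x L
    #   out   = W((sim @ value) reshaped back to N x value_channels x D x H x W)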
    def __init__(self, in_channels, key_channels=None, value_channels=None, out_channels=None, scale=1):
        super(SelfAttentionBlock, self).__init__()
        self.scale = scale
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.key_channels = key_channels
        self.value_channels = value_channels
        if out_channels is None:
            self.out_channels = in_channels
        if key_channels is None:
            self.key_channels = in_channels // 2
        if value_channels is None:
            self.value_channels = in_channels // 2
        # Optional spatial downsampling; MaxPool3d needs an int or a 3-tuple kernel
        self.pool = nn.MaxPool3d(kernel_size=(scale, scale, scale))
        # Key transform: 1x1x1 convolution followed by batch norm
        self.f_key = nn.Sequential(
            nn.Conv3d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0),
            nn.BatchNorm3d(self.key_channels),
        )
        # Query shares weights with the key transform
        self.f_query = self.f_key
        # Value transform and final output projection, both 1x1x1 convolutions
        self.f_value = nn.Conv3d(in_channels=self.in_channels, out_channels=self.value_channels,
                                 kernel_size=1, stride=1, padding=0)
        self.W = nn.Conv3d(in_channels=self.value_channels, out_channels=self.out_channels,
                           kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        batch_size, c, d, h, w = x.size()
        if self.scale > 1:
            x = self.pool(x)
        # Flatten spatial dimensions: value/query -> N x L x C', key -> N x C' x L
        value = self.f_value(x).view(batch_size, self.value_channels, -1)
        value = value.permute(0, 2, 1)
        query = self.f_query(x).view(batch_size, self.key_channels, -1)
        query = query.permute(0, 2, 1)
        key = self.f_key(x).view(batch_size, self.key_channels, -1)
        # Scaled dot-product attention over all spatial positions
        sim_map = torch.matmul(query, key)
        sim_map = (self.key_channels ** -.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)
        # Aggregate values, restore the (possibly pooled) spatial layout, and project
        context = torch.matmul(sim_map, value)
        context = context.permute(0, 2, 1).contiguous()
        context = context.view(batch_size, self.value_channels, *x.size()[2:])
        context = self.W(context)
        return context
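

# --- Illustrative usage sketch (not part of the upstream OCNet fork) ---
# A common way to use a non-local / self-attention block is as a residual
# refinement of an existing feature map: out = x + attention(x). The wrapper
# below is a minimal sketch of that pattern, assuming scale=1 so that the
# context has the same spatial size as the input; the name
# ResidualSelfAttention3D is hypothetical, not from this repository.
class ResidualSelfAttention3D(nn.Module):
    def __init__(self, channels):
        super(ResidualSelfAttention3D, self).__init__()
        # out_channels defaults to in_channels, so the residual sum is shape-safe
        self.attention = SelfAttentionBlock(in_channels=channels)

    def forward(self, x):
        # Add the position-aware context back onto the input features
        return x + self.attention(x)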


if __name__ == '__main__':
    img = torch.randn(2, 32, 8, 20, 20)
    net = SelfAttentionBlock(in_channels=32)
    out = net(img)
    print(out.size())  # torch.Size([2, 32, 8, 20, 20])
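    # Minimal extra check (assumption: scale=2): this fork does not upsample
    # the context back to the input resolution, so a scale > 1 returns the
    # pooled spatial size.
    net_ds = SelfAttentionBlock(in_channels=32, scale=2)
    print(net_ds(img).size())  # torch.Size([2, 32, 4, 10, 10])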