
Commit

Update fuse_block.py
zhangbaijin authored May 30, 2024
1 parent 2770bda commit f23ffb0
Showing 1 changed file with 75 additions and 34 deletions.
109 changes: 75 additions & 34 deletions fuse_block.py
@@ -6,8 +6,7 @@
import numpy as np
import numbers
from einops import rearrange


import math
## Layer Norm

def to_3d(x):
@@ -89,27 +88,45 @@ def forward(self, x):
return x


##########################################################################
## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
def __init__(self, dim, num_heads, bias):
super(Attention, self).__init__()
self.num_heads = num_heads
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

# Key/value and query generation for the original input
self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias)
self.kv_dwconv = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=1, padding=1, groups=dim * 2, bias=bias)
self.q = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
self.q_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=bias)
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

def forward(self, x, y):
# Key/value and query generation for the depth map
self.depth_kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias)
self.depth_kv_dwconv = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=1, padding=1, groups=dim * 2, bias=bias)
self.depth_q = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
self.depth_q_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=bias)

self.project_out = nn.Conv2d(dim * 2, dim, kernel_size=1, bias=bias)

def forward(self, x, y, depth):
b, c, h, w = x.shape

kv = self.kv_dwconv(self.kv(x))
# Keys, values, and query from the original input
kv = self.kv_dwconv(self.kv(x))  # 1x1 projection followed by 3x3 depthwise convolution
k, v = kv.chunk(2, dim=1)
q = self.q_dwconv(self.q(y))

# Keys, values, and query from the depth map
depth_kv = self.depth_kv_dwconv(self.depth_kv(depth))
depth_k, depth_v = depth_kv.chunk(2, dim=1)
depth_q = self.depth_q_dwconv(self.depth_q(depth))

# Combine the original and depth-map information
k = torch.cat([k, depth_k], dim=1)
v = torch.cat([v, depth_v], dim=1)
q = torch.cat([q, depth_q], dim=1)
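# After concatenation, q, k, and v each carry 2*dim channels; project_out later maps dim*2 back to dim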

# The rest follows the original attention mechanism
q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
@@ -128,58 +145,81 @@ def forward(self, x, y):
return out
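
# A minimal usage sketch of the fused attention above (sizes are illustrative
# assumptions; this assumes the depth_* branches added in this commit are defined
# in __init__ and that all three inputs share the same channel count and resolution):
#
#   attn = Attention(dim=48, num_heads=2, bias=False)
#   x = torch.randn(1, 48, 64, 64)      # original features (keys/values)
#   y = torch.randn(1, 48, 64, 64)      # guidance features (queries)
#   depth = torch.randn(1, 48, 64, 64)  # depth-map features
#   out = attn(x, y, depth)             # -> (1, 48, 64, 64) after the dim*2 -> dim projection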


##########################################################################
class TransformerBlock(nn.Module):
def __init__(self, dim_2, dim, num_heads=2, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias'):
super(TransformerBlock, self).__init__()

self.conv1 = nn.Conv2d(dim_2, dim, (1, 1))
# self.conv2 = nn.Conv2d(dim, dim_2, (1, 1))
self.norm1 = LayerNorm(dim, LayerNorm_type)
self.attn = Attention(dim, num_heads, bias)
self.norm2 = LayerNorm(dim, LayerNorm_type)
self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

def forward(self, input_R, input_S):
# input_ch = input_R.size()[1]
input_S = F.interpolate(input_S, [input_R.shape[2], input_R.shape[3]])
input_S = self.conv1(input_S)
# input_S = F.interpolate(input_S, size=input_size, mode='bilinear', align_corners=True)
input_R = self.norm1(input_R)
input_S = self.norm1(input_S)
input_R = input_R + self.attn(input_R, input_S)
input_R = input_R + self.ffn(self.norm2(input_R))
# ##########################################################################
# ## Multi-DConv Head Transposed Self-Attention (MDTA)
# class Attention(nn.Module):
# def __init__(self, dim, num_heads, bias):
# super(Attention, self).__init__()
# self.num_heads = num_heads
# self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

# self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias)
# self.kv_dwconv = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=1, padding=1, groups=dim * 2, bias=bias)
# self.q = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
# self.q_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=bias)
# self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

# def forward(self, x, y):
# b, c, h, w = x.shape

# kv = self.kv_dwconv(self.kv(x))
# k, v = kv.chunk(2, dim=1)
# q = self.q_dwconv(self.q(y))

# q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
# k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
# v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

# q = torch.nn.functional.normalize(q, dim=-1)
# k = torch.nn.functional.normalize(k, dim=-1)

# attn = (q @ k.transpose(-2, -1)) * self.temperature
# attn = attn.softmax(dim=-1)

# out = (attn @ v)

# out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)

# out = self.project_out(out)
# return out

return input_R

##########################################################################
class TransformerBlock_1(nn.Module):
def __init__(self, dim_2, dim, dim_in, num_heads=2, ffn_expansion_factor=1, bias=False, LayerNorm_type='WithBias'):
def __init__(self, dim_2, dim, dim_in, num_heads=3, ffn_expansion_factor=1, bias=False, LayerNorm_type='WithBias'):
super(TransformerBlock_1, self).__init__()

self.conv1 = nn.Conv2d(dim_2, dim_in, (1, 1))
self.conv2 = nn.Conv2d(dim, dim_in, (1, 1))
self.conv3 = nn.Conv2d(dim_in, dim, (1, 1))
self.norm1 = LayerNorm(dim_in, LayerNorm_type)
self.attn = Attention(dim_in, num_heads, bias)


self.norm2 = LayerNorm(dim_in, LayerNorm_type)
self.ffn = FeedForward(dim_in, ffn_expansion_factor, bias)

def forward(self, input_R, input_S):
def forward(self, input_R, input_S, input_depth):
# input_ch = input_R.size()[1]
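# Resize input_S and input_depth to input_R's spatial resolution before the 1x1 projections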
input_S = F.interpolate(input_S, [input_R.shape[2], input_R.shape[3]])
input_depth = F.interpolate(input_depth, [input_R.shape[2], input_R.shape[3]])

input_S = self.conv1(input_S)
input_R = self.conv2(input_R)
# input_S = F.interpolate(input_S, size=input_size, mode='bilinear', align_corners=True)

input_depth = self.conv1(input_depth)

input_R = self.norm1(input_R)
input_S = self.norm1(input_S)
input_R = input_R + self.attn(input_R, input_S)
input_depth = self.norm1(input_depth)
input_R = input_R + self.attn(input_R, input_S, input_depth)
input_R = input_R + self.ffn(self.norm2(input_R))
input_R = self.conv3(input_R)

return input_R
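
# A minimal usage sketch of TransformerBlock_1 (channel and spatial sizes are
# illustrative assumptions; input_S and input_depth are resized to input_R's
# resolution inside the block, so their spatial sizes may differ from input_R's):
#
#   block = TransformerBlock_1(dim_2=3, dim=64, dim_in=48)
#   input_R = torch.randn(1, 64, 32, 32)        # image features (dim channels)
#   input_S = torch.randn(1, 3, 128, 128)       # guidance features (dim_2 channels)
#   input_depth = torch.randn(1, 3, 128, 128)   # depth map (dim_2 channels)
#   out = block(input_R, input_S, input_depth)  # -> (1, 64, 32, 32)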




def W(in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, bias=False),
@@ -258,3 +298,4 @@ def forward(self, key, query):

return self.fuse(torch.cat([key, out], dim=1))

