Add_SA_to_cycleGAN

I got sidetracked by the segmentation work plus the BME conference and forgot to write this up.

First, add the following branch in models/networks.py > define_G():

    elif netG == 'SA_resnet_1blocks':
        net = SA_ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=1)
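
With this branch in place, the new generator can be selected from the standard pytorch-CycleGAN-and-pix2pix command line (the dataset path below is just a placeholder):

    python train.py --dataroot ./datasets/your_dataset --model cycle_gan --netG SA_resnet_1blocks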

Then, still in models/networks.py, add a class SA_ResnetGenerator, copied from ResnetGenerator, with the following change to the block-construction loop.

Original:

for i in range(n_blocks):       # add ResNet blocks
    model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

Updated:

for i in range(n_blocks):       # add ResNet blocks
    model += [SA_ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

Then, in models/networks.py, add a class SA_ResnetBlock, copied from ResnetBlock, which uses the Self_Attn class.
First, create a Self_Attn instance in __init__:

class SA_ResnetBlock(nn.Module):
    """Define a Resnet block"""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Initialize the Resnet block

        A resnet block is a conv block with skip connections
        We construct a conv block with build_conv_block function,
        and implement skip connections in <forward> function.
        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
        """
        super(SA_ResnetBlock, self).__init__()
        self.self_attn = Self_Attn(dim, 'relu')  # create before build_conv_block, which appends self.self_attn
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
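
The Self_Attn class is not part of the stock networks.py, so it has to be added as well. A minimal version to copy in, following the SAGAN reference implementation and matching the printed module structure further below (1x1 convs, query/key channels reduced to in_dim // 8, learnable gamma starting at 0), is sketched here; torch and nn are already imported in networks.py:

class Self_Attn(nn.Module):
    """SAGAN-style self-attention layer"""

    def __init__(self, in_dim, activation):
        super(Self_Attn, self).__init__()
        self.activation = activation  # kept only to match the Self_Attn(dim, 'relu') call above; unused here

        # 1x1 convs project the input into query/key (reduced channels) and value
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)

        # learnable scale on the attention branch; starting at 0 makes the block an identity at init
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)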

Then modify the build_conv_block() method of SA_ResnetBlock.
As we all know, everything this resblock does boils down to:

choose a padding type (I strongly recommend ReplicationPad2d, though plain zero padding works too)
one convolution
Dropout(0.5) (off by default, so I won't add it either)
choose a padding type again (same as above)
one convolution

So we simply append the self-attention module after each of the two original convolutions and we are done. The snippet below shows the first append; a full sketch of the method follows it:

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
        conv_block += [self.self_attn]
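
For completeness, here is a sketch of the whole modified build_conv_block(), keeping the padding logic of the stock ResnetBlock; it matches the printed module structure further below, with self.self_attn appended after both conv blocks (both entries are the same shared module instance):

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        # first conv + norm + ReLU, followed by self-attention
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
        conv_block += [self.self_attn]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        # second conv + norm, followed by self-attention again
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
        conv_block += [self.self_attn]

        return nn.Sequential(*conv_block)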

Then, to avoid errors inside nn.Sequential, change Self_Attn's forward() so that it returns only the same-sized output feature map rather than the (out, attention) pair it originally returned:

    def forward(self, x):
        """
            inputs :
                x : input feature maps (B x C x W x H)
            returns :
                out : self-attention value + input feature, same shape as x
        """
        m_batchsize, C, width, height = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B x N x C'  (N = W*H)
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)  # B x C' x N
        energy = torch.bmm(proj_query, proj_key)  # B x N x N
        attention = self.softmax(energy)  # B x N x N
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)  # B x C x N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, width, height)

        out = self.gamma * out + x
        return out
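
A quick sanity check that the layer is now a drop-in module for nn.Sequential (shapes are arbitrary; this assumes the Self_Attn sketched above and can be run inside models/networks.py, where torch is already imported):

    x = torch.randn(2, 256, 64, 64)   # B x C x W x H
    sa = Self_Attn(256, 'relu')
    out = sa(x)                       # a single tensor, not an (out, attention) tuple
    assert out.shape == x.shape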

Finally, replace the original networks.py with the modified file.

Printing a generator with one SA_resblock now gives:

DataParallel(
  (module): SA_ResnetGenerator(
    (model): Sequential(
      (0): ReflectionPad2d((3, 3, 3, 3))
      (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1))
      (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (3): ReLU(inplace=True)
      (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (6): ReLU(inplace=True)
      (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (8): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (9): ReLU(inplace=True)
      (10): SA_ResnetBlock(
        (self_attn): Self_Attn(
          (query_conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1))
          (key_conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1))
          (value_conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
          (softmax): Softmax(dim=-1)
        )
        (conv_block): Sequential(
          (0): ReflectionPad2d((1, 1, 1, 1))
          (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
          (2): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (3): ReLU(inplace=True)
          (4): Self_Attn(
            (query_conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1))
            (key_conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1))
            (value_conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
            (softmax): Softmax(dim=-1)
          )
          (5): ReflectionPad2d((1, 1, 1, 1))
          (6): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
          (7): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (8): Self_Attn(
            (query_conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1))
            (key_conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1))
            (value_conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
            (softmax): Softmax(dim=-1)
          )
        )
      )
      (11): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
      (12): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (13): ReLU(inplace=True)
      (14): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
      (15): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (16): ReLU(inplace=True)
      (17): ReflectionPad2d((3, 3, 3, 3))
      (18): Conv2d(64, 3, kernel_size=(7, 7), stride=(1, 1))
      (19): Tanh()
    )
  )
)

PS: It really should not be crammed in like this; an accuracy gain that comes purely from added complexity proves nothing, in the same way that "DenseBlock performs better than ResBlock, therefore DenseBlock crushes ResBlock across the board" is a cheat.
A better idea is to replace the two convs inside the resblock with a single SA layer,
or to simply swap a whole resblock for one SA layer,
or, building on that, to change Self_Attn's value_conv() to a 3x3 conv and carry the resblock's padding scheme over to value_conv,
or, as in the original paper, to insert SA between the ordinary convolutions at the tail of the generator (256>128>64>3); see the sketch below.
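
A minimal sketch of that last option, assuming we modify the upsampling loop of SA_ResnetGenerator (copied from the stock ResnetGenerator) so that a Self_Attn follows each decoder stage; this only illustrates the idea and is not what was trained above:

for i in range(n_downsampling):       # decoder / upsampling part (256>128, 128>64)
    mult = 2 ** (n_downsampling - i)
    model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                 kernel_size=3, stride=2,
                                 padding=1, output_padding=1,
                                 bias=use_bias),
              norm_layer(int(ngf * mult / 2)),
              nn.ReLU(True),
              Self_Attn(int(ngf * mult / 2), 'relu')]  # self-attention between the decoder convs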
