# casic_iris_recognize/model/attention.py
import torch
from torch import nn
import torch.nn.functional as F


class ASPP(nn.Module):
    """ASPP-style attention block: parallel atrous convolutions plus a global
    average pooling branch, fused into a sigmoid mask that reweights the input."""

    def __init__(self, in_channel=512, depth=256):
        super(ASPP, self).__init__()
        # Global average pooling branch: AdaptiveAvgPool2d((1, 1)) collapses the
        # spatial dims to 1x1 (equivalent to torch.mean over H and W with keepdim=True)
        self.mean = nn.AdaptiveAvgPool2d((1, 1))
        self.conv = nn.Conv2d(in_channel, depth, 1, 1)
        # 1x1 convolution branch: kernel=1, stride=1, no padding
        self.atrous_block1 = nn.Conv2d(in_channel, depth, 1, 1)
        self.atrous_block6 = nn.Conv2d(in_channel, depth, 3, 1, padding=6, dilation=6)
        self.atrous_block12 = nn.Conv2d(in_channel, depth, 3, 1, padding=12, dilation=12)
        self.atrous_block18 = nn.Conv2d(in_channel, depth, 3, 1, padding=18, dilation=18)

        self.conv_1x1_output = nn.Conv2d(depth * 5, depth, 1, 1)

        self.conv_3x3_output1 = nn.Conv2d(in_channels= depth * 5, out_channels= 256, kernel_size=3, stride=1, padding=1)
        self.conv_3x3_output2 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1)
        # F.sigmoid is deprecated; torch.sigmoid is the current API
        self.sigmoid = torch.sigmoid

    def forward(self, x):
        size = x.shape[2:]

        # Image-level (global pooling) branch, upsampled back to the input size
        image_features = self.mean(x)
        image_features = self.conv(image_features)
        # F.upsample is deprecated; F.interpolate is the current API
        image_features = F.interpolate(image_features, size=size, mode='bilinear', align_corners=False)

        # Atrous (dilated) convolution branches with increasing receptive fields
        atrous_block1 = self.atrous_block1(x)
        atrous_block6 = self.atrous_block6(x)
        atrous_block12 = self.atrous_block12(x)
        atrous_block18 = self.atrous_block18(x)

        # Concatenate all five branches along the channel dimension: depth * 5 channels
        concat = torch.cat([image_features, atrous_block1, atrous_block6,
                            atrous_block12, atrous_block18], dim=1)

        # 3x3 conv: depth * 5 -> 256 channels
        out_conv = self.conv_3x3_output1(concat)
        # 3x3 conv: 256 -> 512 channels, so the mask matches the input depth (512 by default)
        out_conv = self.conv_3x3_output2(out_conv)
        # Sigmoid produces the attention mask M
        M = self.sigmoid(out_conv)
        # Element-wise product: reweight the input features with M
        out = M * x
        # Concatenate the attended features with the original input
        net = torch.cat([x, out], dim=1)

        # net = self.conv_1x1_output(torch.cat([image_features, atrous_block1, atrous_block6,
        #                                       atrous_block12, atrous_block18], dim=1))
        return net
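

# A minimal usage sketch (not part of the original file): it assumes a feature
# map with in_channel=512 channels; since forward() concatenates the attended
# features with the input, the output has 2 * in_channel channels.
if __name__ == '__main__':
    aspp = ASPP(in_channel=512, depth=256)
    dummy = torch.randn(2, 512, 16, 32)  # (batch, channels, H, W)
    out = aspp(dummy)
    print(out.shape)                     # expected: torch.Size([2, 1024, 16, 32])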