# casic_iris_recognize/model/unet_vgg16_multitask_attention.py
# (removed GitHub page-scrape residue: "Newer"/"Older" nav links and the
#  author/date/size line, which are not Python and broke the module.)
import torch
from torch import nn

from .unet_parts import *
from .attention import ASPP


class UnetWithVGG16Attention(torch.nn.Module):
    """U-Net-style multitask segmentation network with a VGG16 encoder.

    The encoder is the ``features`` stack of a (pretrained) torchvision VGG16
    with its final max-pooling removed.  Four intermediate feature maps are
    kept as skip connections, an ASPP attention module bridges encoder and
    decoder, and a single output convolution predicts three masks at once
    (pupil, iris mask, iris boundary).
    """

    def __init__(self, encoder, n_classes=1, bilinear=True):
        """
        Args:
            encoder: a VGG16 model; only ``encoder.features`` is used.
            n_classes: number of channels per output mask.
            bilinear: use bilinear upsampling (vs. transposed conv) in ``Up``.
        """
        super(UnetWithVGG16Attention, self).__init__()

        self.features = encoder.features[:-1]  # drop the last max-pooling
        self.n_classes = n_classes
        self.bilinear = bilinear
        # NOTE(review): unused legacy attribute; kept so external code that
        # might read it keeps working.
        self.downs = []

        self.feature_len: int = len(self.features)

        # Attention bridge on the deepest (512-channel) encoder output.
        self.attention = ASPP(in_channel=512, depth=256)

        # Decoder; each Up's in-channels = previous output + skip connection.
        self.up1 = Up(1536, 256, self.bilinear)
        self.up2 = Up(512, 128, self.bilinear)
        self.up3 = Up(256, 64, self.bilinear)
        self.up4 = Up(128, 32, self.bilinear)
        # Single head predicting all three masks stacked along channels.
        self.outc = OutConv(32, n_classes * 3)
        # Unused alternative per-task heads; kept so existing checkpoints load.
        self.outc1 = OutConv(32, n_classes)
        self.outc2 = OutConv(32, n_classes)
        self.outc3 = OutConv(32, n_classes)

    def forward(self, x):
        """Encode, attend, decode; return ``(mask, iris, pupil)`` logits.

        Each returned tensor has ``n_classes`` channels.
        """
        # Indices of the ReLU outputs (in a VGG16 *without* batch norm) kept
        # as skip connections; the batch-norm variant would use [5,12,22,32].
        concat_index = [3, 8, 15, 22]
        skips = []

        o = x
        for i, layer in enumerate(self.features):
            o = layer(o)
            if i in concat_index:
                skips.append(o)

        o = self.attention(o)

        x = self.up1(o, skips[-1])
        x = self.up2(x, skips[-2])
        x = self.up3(x, skips[-3])
        x = self.up4(x, skips[-4])

        mask_iris_pupil = self.outc(x)

        # FIX: split into three task-specific chunks of ``n_classes`` channels
        # each.  The original hard-coded a chunk size of 1 and indexed the
        # first three chunks, which is only correct when n_classes == 1.
        pupil, mask, iris = torch.split(
            mask_iris_pupil, self.n_classes, dim=1)

        return mask, iris, pupil

    def get_encoder(self, features, batch_norm=True):
        """Regroup a VGG16 ``features`` sequence into per-stage blocks.

        Args:
            features: layer sequence of a VGG16 *without* batch norm
                (Conv2d/ReLU pairs separated by MaxPool2d layers).
            batch_norm: insert a fresh ``nn.BatchNorm2d`` after every
                convolution.  (FIX: the original ignored this flag and
                always inserted the batch-norm layers.)

        Returns:
            ``(downs, pools)`` where ``downs`` is a list of five
            ``nn.Sequential`` stage blocks and ``pools`` holds the four
            pooling layers that follow the first four stages.
        """
        # (start index in ``features``, conv count, out channels) per stage.
        stage_specs = [
            (0, 2, 64),
            (5, 2, 128),
            (10, 3, 256),
            (17, 3, 512),
            (24, 3, 512),
        ]

        downs = []
        pools = []
        for stage, (start, n_convs, channels) in enumerate(stage_specs):
            layers = []
            for j in range(n_convs):
                layers.append(features[start + 2 * j])      # Conv2d
                if batch_norm:
                    layers.append(nn.BatchNorm2d(channels))
                layers.append(features[start + 2 * j + 1])  # ReLU
            downs.append(nn.Sequential(*layers))
            # A pooling layer follows each of the first four stages only.
            if stage < len(stage_specs) - 1:
                pools.append(features[start + 2 * n_convs])

        return downs, pools