import torch
from .unet_parts import *  # provides Up, OutConv
from .attention import ASPP


class UnetWithVGG16Attention(torch.nn.Module):
    """U-Net-style decoder on top of a VGG16 encoder with an ASPP bottleneck.

    The encoder's ``features`` stack (minus its final layer, the last
    max-pool) is run layer by layer. Activations at ``skip_indices`` are kept
    as skip connections for four ``Up`` stages. A single ``OutConv`` then
    produces ``3 * n_classes`` channels, which are split into three heads
    returned as ``(mask, iris, pupil)``.
    """

    # Indices into a plain (no batch-norm) VGG16 ``features`` stack whose
    # outputs feed the skip connections: the last ReLU of each of the first
    # four stages (64, 128, 256, 512 channels; see ``get_encoder``).
    # Earlier code noted (5, 12, 22, 32) for a batch-norm VGG16 variant.
    _DEFAULT_SKIP_INDICES = (3, 8, 15, 22)

    # (number of conv layers, output channels) per VGG16 stage; used by
    # ``get_encoder`` to walk the flat ``features`` list.
    _VGG16_STAGES = ((2, 64), (2, 128), (3, 256), (3, 512), (3, 512))

    def __init__(self, encoder, n_classes=1, bilinear=True, *,
                 skip_indices=_DEFAULT_SKIP_INDICES):
        """Build the network.

        Args:
            encoder: Module exposing a ``features`` sequence, presumably a
                torchvision VGG16. Its last layer is dropped.
            n_classes: Number of channels produced by each of the three heads.
            bilinear: Forwarded to every ``Up`` block.
            skip_indices: Indices into ``encoder.features`` whose outputs are
                used as skip connections. Exactly four distinct indices are
                required, all within the truncated feature stack.

        Raises:
            ValueError: If ``skip_indices`` is not four distinct, in-range
                layer indices.
        """
        super().__init__()
        self.features = encoder.features[:-1]  # drop last max-pooling
        self.n_classes = n_classes
        self.bilinear = bilinear
        # Unused by this class; kept so external code reading it still works.
        self.downs = []
        self.feature_len: int = len(self.features)

        skip_indices = frozenset(skip_indices)
        # forward() consumes exactly four skips (up1..up4), so validate early
        # instead of failing later with an IndexError.
        if len(skip_indices) != 4 or not all(
            0 <= idx < self.feature_len for idx in skip_indices
        ):
            raise ValueError(
                "skip_indices must be 4 distinct indices in "
                f"[0, {self.feature_len}), got {sorted(skip_indices)}"
            )
        self.skip_indices = skip_indices

        self.attention = ASPP(in_channel=512, depth=256)
        # NOTE(review): Up's first argument presumably equals
        # (upsampled input channels + skip channels). Verify 1536 against
        # ASPP's actual output width.
        self.up1 = Up(1536, 256, self.bilinear)
        self.up2 = Up(512, 128, self.bilinear)
        self.up3 = Up(256, 64, self.bilinear)
        self.up4 = Up(128, 32, self.bilinear)
        self.outc = OutConv(32, n_classes * 3)
        # Not used by forward(). Kept so existing checkpoints (state_dict keys)
        # still load.
        self.outc1 = OutConv(32, n_classes)
        self.outc2 = OutConv(32, n_classes)
        self.outc3 = OutConv(32, n_classes)

    def forward(self, x):
        """Run the encoder, ASPP bottleneck, decoder and output heads.

        Args:
            x: Input batch for the encoder. Presumably (N, 3, H, W); TODO
                confirm.

        Returns:
            Tuple ``(mask, iris, pupil)``. Each head has ``n_classes``
            channels along dim 1 and is taken from channel group 1, 2 and 0
            of ``outc``'s output respectively.
        """
        skips = []
        out = x
        for index, layer in enumerate(self.features):
            out = layer(out)
            if index in self.skip_indices:
                skips.append(out)

        out = self.attention(out)
        # Decode from the deepest skip connection to the shallowest.
        out = self.up1(out, skips[-1])
        out = self.up2(out, skips[-2])
        out = self.up3(out, skips[-3])
        out = self.up4(out, skips[-4])

        logits = self.outc(out)
        # outc emits 3 * n_classes channels. Split into three groups of
        # n_classes each; a split size of 1 would only be right for
        # n_classes == 1.
        pupil, mask, iris = torch.split(logits, self.n_classes, dim=1)
        return mask, iris, pupil

    def get_encoder(self, features, batch_norm=True):
        """Regroup a flat VGG16 ``features`` list into per-stage blocks.

        ``features`` is read as a plain (no batch-norm) VGG16 layout:
        conv/ReLU pairs per stage, with one pooling layer between stages.
        Only indices 0-29 are read. The model itself is not modified.

        Args:
            features: Indexable sequence of layers in the layout above.
            batch_norm: If true, a freshly initialised ``BatchNorm2d`` is
                inserted between every conv and its ReLU. If false, no
                batch-norm layers are added.

        Returns:
            ``(downs, pools)``:
                - ``downs``: five ``torch.nn.Sequential`` stages.
                - ``pools``: the four pooling layers that sit between those
                  stages.

            Both are plain Python lists, not ``ModuleList``. Layers inside
            them are not registered as submodules unless the caller wraps
            them.
        """
        downs = []
        pools = []
        pos = 0
        last_stage = len(self._VGG16_STAGES) - 1
        for stage, (n_convs, channels) in enumerate(self._VGG16_STAGES):
            layers = []
            for _ in range(n_convs):
                layers.append(features[pos])  # conv
                if batch_norm:
                    layers.append(torch.nn.BatchNorm2d(channels))
                layers.append(features[pos + 1])  # ReLU
                pos += 2
            downs.append(torch.nn.Sequential(*layers))
            # The final stage has no trailing pool in this layout.
            if stage < last_stage:
                pools.append(features[pos])
                pos += 1
        return downs, pools