"""[WIP] Auto Deeplab in PyTorch.

This is a PyTorch implementation of the architecture found by the
Hierarchical Neural Architecture Search introduced in the paper:
https://arxiv.org/abs/1901.02985v1

TODO:
- [x] Cell architecture
- [x] Module architecture
- [x] Evaluation in TensorBoard
- [x] Training with LFW (http://vis-www.cs.umass.edu/lfw/)
- [ ] Training with SURREAL

Original repository layout: README.md, layers.py, network.py.  The two Python
files are kept below as clearly separated sections of this single listing.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


# ---------------------------------------------------------------------------
# layers.py
# ---------------------------------------------------------------------------

def fixed_padding(inputs, kernel_size, dilation):
    """
    Zero-pad the last two dimensions of ``inputs`` by the amount implied by a
    (possibly dilated) convolution kernel.

    https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/backbone/xception.py
    :param inputs: tensor to pad
    :param kernel_size: size of the convolution kernel
    :param dilation: dilation rate of the convolution
    :return: the padded tensor
    """
    # A dilated kernel covers k + (k - 1) * (d - 1) input positions.
    kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
    pad_total = kernel_size_effective - 1
    pad_beg = pad_total // 2
    pad_end = pad_total - pad_beg
    return F.pad(inputs, [pad_beg, pad_end, pad_beg, pad_end])


# from https://github.com/quark0/darts/blob/master/cnn/operations.py
class DilConv(nn.Module):
    """ReLU -> depthwise (optionally dilated) conv -> 1x1 conv -> BatchNorm."""

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
        """
        :param C_in: number of input channels
        :param C_out: number of output channels
        :param kernel_size: kernel size of the depthwise convolution
        :param stride: stride of the depthwise convolution
        :param padding: padding of the depthwise convolution
        :param dilation: dilation of the depthwise convolution
        :param affine: passed to the BatchNorm2d layer
        """
        super(DilConv, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            # groups=C_in makes this a depthwise convolution.
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation,
                      groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        )

    def forward(self, x):
        """Apply the operation to ``x``."""
        return self.op(x)


class SepConv(nn.Module):
    """Two stacked (ReLU -> depthwise conv -> 1x1 conv -> BatchNorm) stages."""

    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        """
        :param C_in: number of input channels
        :param C_out: number of output channels
        :param kernel_size: kernel size of both depthwise convolutions
        :param stride: stride of the first depthwise convolution (the second uses 1)
        :param padding: padding of both depthwise convolutions
        :param affine: passed to the BatchNorm2d layers
        """
        super(SepConv, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_in, affine=affine),
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        )

    def forward(self, x):
        """Apply the operation to ``x``."""
        return self.op(x)


class Cell(nn.Module):
    """Cell combining the two previous hidden states through five two-input nodes."""

    def __init__(self, in_channels_h1, in_channels_h2, out_channels, dilation=1, activation=nn.ReLU6,
                 bn=nn.BatchNorm2d):
        """
        :param in_channels_h1: number of input channels in h-1
        :param in_channels_h2: number of input channels in h-2
        :param out_channels: number of output channels
        :param dilation: dilation rate of the atrous (DilConv) branches
        :param activation: stored on ``self.activation`` but not applied by this class
        :param bn: accepted for interface compatibility; not used by this class
        """
        super(Cell, self).__init__()
        self.in_ = in_channels_h1
        self.out_ = out_channels
        self.activation = activation

        # Bring h-2 to the channel count of h-1.  Both preprocessors also change
        # the spatial size by a factor of two (stride-2 reduce / stride-2
        # transposed conv).
        if in_channels_h1 > in_channels_h2:
            self.preprocess = FactorizedReduce(in_channels_h2, in_channels_h1)
        elif in_channels_h1 < in_channels_h2:
            # TODO: check this
            self.preprocess = nn.ConvTranspose2d(in_channels_h2, in_channels_h1, 3, stride=2, padding=1,
                                                 output_padding=1)
        else:
            self.preprocess = None

        # Padding must grow with the dilation, otherwise the atrous branches
        # shrink the feature map and can no longer be added to the separable
        # branches.  For dilation=1 these are the usual 1 (3x3) and 2 (5x5).
        pad_atr3 = dilation
        pad_atr5 = 2 * dilation

        # Top 1
        self.top1_atr5x5 = DilConv(in_channels_h1, in_channels_h1, 5, 1, pad_atr5, dilation)
        self.top1_sep3x3 = SepConv(in_channels_h1, in_channels_h1, 3, 1, 1)

        # Top 2
        self.top2_sep5x5_1 = SepConv(in_channels_h1, in_channels_h1, 5, 1, 2)
        self.top2_sep5x5_2 = SepConv(in_channels_h1, in_channels_h1, 5, 1, 2)

        # Middle
        self.middle_sep3x3_1 = SepConv(in_channels_h1, in_channels_h1, 3, 1, 1)
        self.middle_sep3x3_2 = SepConv(in_channels_h1, in_channels_h1, 3, 1, 1)

        # Bottom 1
        self.bottom1_atr3x3 = DilConv(in_channels_h1, in_channels_h1, 3, 1, pad_atr3, dilation)
        self.bottom1_sep3x3 = SepConv(in_channels_h1, in_channels_h1, 3, 1, 1)

        # Bottom 2
        self.bottom2_atr5x5 = DilConv(in_channels_h1, in_channels_h1, 5, 1, pad_atr5, dilation)
        self.bottom2_sep5x5 = SepConv(in_channels_h1, in_channels_h1, 5, 1, 2)

        # The five node outputs are concatenated, then projected with a 1x1 conv.
        self.concate_conv = nn.Conv2d(in_channels_h1 * 5, out_channels, 1)

    def forward(self, h_1, h_2):
        """
        :param h_1: hidden state h-1 (``in_channels_h1`` channels)
        :param h_2: hidden state h-2 (``in_channels_h2`` channels)
        :return: tensor with ``out_channels`` channels
        """
        if self.preprocess is not None:
            h_2 = self.preprocess(h_2)

        top1 = self.top1_atr5x5(h_2) + self.top1_sep3x3(h_1)
        bottom1 = self.bottom1_atr3x3(h_1) + self.bottom1_sep3x3(h_2)
        middle = self.middle_sep3x3_1(h_2) + self.middle_sep3x3_2(bottom1)

        top2 = self.top2_sep5x5_1(top1) + self.top2_sep5x5_2(middle)
        bottom2 = self.bottom2_atr5x5(top2) + self.bottom2_sep5x5(bottom1)

        concat = torch.cat([top1, top2, middle, bottom2, bottom1], dim=1)

        return self.concate_conv(concat)


class ASPP(nn.Module):
    """Parallel 1x1 / dilated 3x3 / image-pooling branches, concatenated and projected."""

    def __init__(self, in_channels, out_channels, paddings, dilations):
        """
        :param in_channels: number of input channels
        :param out_channels: number of channels produced by every branch and by the module
        :param paddings: three paddings, one per 3x3 branch
        :param dilations: three dilation rates, one per 3x3 branch
        """
        # TODO: depthwise separable conv
        super(ASPP, self).__init__()
        # BatchNorm width follows out_channels (it used to be hard-coded to 256).
        self.conv11 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False),
                                    nn.BatchNorm2d(out_channels))
        self.conv33_1 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3,
                                                padding=paddings[0], dilation=dilations[0], bias=False),
                                      nn.BatchNorm2d(out_channels))
        self.conv33_2 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3,
                                                padding=paddings[1], dilation=dilations[1], bias=False),
                                      nn.BatchNorm2d(out_channels))
        self.conv33_3 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3,
                                                padding=paddings[2], dilation=dilations[2], bias=False),
                                      nn.BatchNorm2d(out_channels))
        self.concate_conv = nn.Sequential(nn.Conv2d(out_channels * 5, out_channels, 1, bias=False),
                                          nn.BatchNorm2d(out_channels))

    def forward(self, x):
        """
        :param x: input feature map with ``in_channels`` channels
        :return: feature map with ``out_channels`` channels
        """
        spatial = tuple(x.size()[2:])

        conv11 = self.conv11(x)
        conv33_1 = self.conv33_1(x)
        conv33_2 = self.conv33_2(x)
        conv33_3 = self.conv33_3(x)

        # Image pool and upsample.  The pooled feature goes through the same
        # conv11 branch (shared weights), as in the original implementation.
        # NOTE(review): the pooled map is 1x1, so BatchNorm in training mode
        # presumably needs a batch size > 1 here - confirm.
        image_pool = F.avg_pool2d(x, kernel_size=spatial)
        image_pool = self.conv11(image_pool)
        upsample = F.interpolate(image_pool, size=spatial, mode='bilinear', align_corners=True)

        # concate
        concate = torch.cat([conv11, conv33_1, conv33_2, conv33_3, upsample], dim=1)

        return self.concate_conv(concate)


# Based on quark0/darts on github
class FactorizedReduce(nn.Module):
    """Stride-2 reduction built from two 1x1 convs whose outputs are concatenated."""

    def __init__(self, C_in, C_out, affine=True):
        """
        :param C_in: number of input channels
        :param C_out: number of output channels (must be even)
        :param affine: passed to the BatchNorm2d layer
        """
        super(FactorizedReduce, self).__init__()
        assert C_out % 2 == 0
        self.relu = nn.ReLU(inplace=False)
        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x):
        """
        :param x: input feature map with ``C_in`` channels
        :return: feature map with ``C_out`` channels
        """
        x = self.relu(x)
        # Pad one row/column at the end, then drop the first row/column: the
        # second path samples a one-pixel-shifted grid while keeping the same
        # output size as the first path.
        padded = F.pad(x, (0, 1, 0, 1), "constant", 0)
        path2 = self.conv_2(padded[:, :, 1:, 1:])
        out = torch.cat([self.conv_1(x), path2], dim=1)
        out = self.bn(out)
        return out


# ---------------------------------------------------------------------------
# network.py
# ---------------------------------------------------------------------------

class AutoDeeplab(nn.Module):
    """Stem, a chain of cells following ``layout``, ASPP head and optional upsampling."""

    def __init__(self, in_channels, out_channels, layout, cell=Cell, activation=nn.ReLU6, upsample_at_end=True,
                 device=None):
        """
        A general implementation of the network architecture presented in the Auto Deeplab paper
        :param in_channels: number of channels of the input image
        :param out_channels: number of channels of the output map
        :param layout: A list of integers representing the y coordinate of a cell in the diagram used in the paper
            (zero-indexed). Must start at 2 and change by at most 1 between consecutive entries.
        :param cell: The cell class to use. It is called as ``cell(channels, prev_channels, channels)``.
        :param activation: activation class instantiated after each stem convolution
        :param upsample_at_end: if True, upsample the output by ``2 ** layout[-1]``
        :param device: device the module is moved to; defaults to CUDA when available, else CPU
        """
        super(AutoDeeplab, self).__init__()
        self.upsample_at_end = upsample_at_end
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # A ModuleList (instead of a plain list) registers every cell, so their
        # parameters show up in parameters()/state_dict() and follow
        # .to()/.train()/.eval().
        self.cells = nn.ModuleList()

        self.initial_stem = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            activation()
        )

        self.cells.append(nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            activation()
        ))

        self.cells.append(nn.Sequential(
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            activation()
        ))

        prev_channels = 64
        channels = 128
        assert layout[0] == 2, "layout must start at depth 2"
        last_index = len(layout) - 1
        for i, depth in enumerate(layout):
            curr_cell = cell(channels, prev_channels, channels)
            prev_channels = channels
            layer = []
            # TODO: dilation?

            if i != last_index:
                next_depth = layout[i + 1]
                assert abs(depth - next_depth) <= 1, "consecutive layout depths may differ by at most 1"
                if next_depth > depth:
                    # Downsampling
                    layer.append(nn.Conv2d(channels, channels * 2, 3, stride=2, padding=1))
                    channels = channels * 2
                elif next_depth < depth:
                    # Upsampling (align_corners=False is the nn.Upsample default, spelled out)
                    layer.append(nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False))
                    layer.append(nn.Conv2d(channels, channels // 2, 1))
                    channels = channels // 2

            # The cell is held outside the Sequential as it needs two arguments, while Sequential only accepts one.
            # The pair stays indexable as [0] (cell) and [1] (rest of the layer).
            self.cells.append(nn.ModuleList([curr_cell, nn.Sequential(*layer)]))

        # Pool, then reduce channels to the desired value
        self.pool = nn.Sequential(
            ASPP(channels, 256, (6, 12, 18), (6, 12, 18)),
            nn.Conv2d(256, out_channels, 3, padding=1)
        )

        self.upsampler = nn.Upsample(scale_factor=2 ** layout[-1], mode="bilinear", align_corners=False)

        self.to(device)

    def forward(self, x):
        """
        :param x: input batch with ``in_channels`` channels, on the module's device
        :return: output map with ``out_channels`` channels
        """
        x = self.initial_stem(x)

        stem_a, stem_b, *blocks = self.cells

        # Run stem layers
        h_prev2 = stem_a(x)
        h_prev1 = stem_b(h_prev2)

        for curr_cell, rest in blocks:
            curr = curr_cell(h_prev1, h_prev2)  # Execute cell
            curr = rest(curr)  # Execute rest of the layer
            h_prev2, h_prev1 = h_prev1, curr

        x = self.pool(h_prev1)
        if self.upsample_at_end:
            x = self.upsampler(x)

        return x


if __name__ == '__main__':
    layout = [2, 2, 2, 2, 3, 4, 3, 4, 4, 5, 5, 4, 3]
    model = AutoDeeplab(3, 3, layout, Cell)
    print(model)
    print(model.cells)
    # Create the input on whatever device the model ended up on.
    x = torch.rand((2, 3, 128, 128), device=next(model.parameters()).device)
    model(x)