├── LICENSE
├── README.md
├── backbone
│   ├── __init__.py
│   ├── activation.py
│   ├── microconfig.py
│   └── micronet.py
├── main.py
├── scripts
│   ├── eval_micronet_m0.sh
│   ├── eval_micronet_m1.sh
│   ├── eval_micronet_m2.sh
│   ├── eval_micronet_m3.sh
│   ├── train_micronet_m0_2gpu.sh
│   ├── train_micronet_m0_4gpu.sh
│   ├── train_micronet_m1_2gpu.sh
│   ├── train_micronet_m1_4gpu.sh
│   ├── train_micronet_m2_2gpu.sh
│   ├── train_micronet_m2_4gpu.sh
│   └── train_micronet_m3_4gpu.sh
└── utils
    ├── __init__.py
    ├── dataloaders.py
    ├── defaults.py
    ├── eval.py
    ├── imagenet.py
    ├── larc.py
    ├── logger.py
    ├── misc.py
    └── visualize.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 The Regents of the University of California.
4 | All Rights Reserved.
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MicroNet: Improving Image Recognition with Extremely Low FLOPs (ICCV 2021)
2 | A [PyTorch](http://pytorch.org/) implementation of [MicroNet](https://arxiv.org/abs/2108.05894).
3 | If you use this code in your research, please consider citing:
4 | >@article{li2021micronet,
5 |   title={MicroNet: Improving Image Recognition with Extremely Low FLOPs},
6 |   author={Li, Yunsheng and Chen, Yinpeng and Dai, Xiyang and Chen, Dongdong and Liu, Mengchen and Yuan, Lu and Liu, Zicheng and Zhang, Lei and Vasconcelos, Nuno},
7 |   journal={arXiv preprint arXiv:2108.05894},
8 |   year={2021}
9 | }
10 | ## Requirements
11 |
12 | - Linux or macOS with Python ≥ 3.6.
13 | - *Anaconda3*, *PyTorch ≥ 1.5* with matched [torchvision](https://github.com/pytorch/vision/)
14 |
15 | ## Models
16 | Model | #Param | MAdds | Top-1 | download
17 | --- |:---:|:---:|:---:|:---:
18 | MicroNet-M3 | 2.6M | 21M | 62.5 | [model](http://www.svcl.ucsd.edu/projects/micronet/assets/micronet-m3.pth)
19 | MicroNet-M2 | 2.4M | 12M | 59.4 | [model](http://www.svcl.ucsd.edu/projects/micronet/assets/micronet-m2.pth)
20 | MicroNet-M1 | 1.8M | 6M | 51.4 | [model](http://www.svcl.ucsd.edu/projects/micronet/assets/micronet-m1.pth)
21 | MicroNet-M0 | 1.0M | 4M | 46.6 | [model](http://www.svcl.ucsd.edu/projects/micronet/assets/micronet-m0.pth)
22 |
23 | ## Evaluate MicroNet on ImageNet
24 |
25 | Download the pretrained MicroNet M0–M3 models using the links above. The scripts used for evaluation can be found [here](scripts).
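The pretrained checkpoints can also be loaded directly in Python. The snippet below is an illustrative sketch rather than part of this repository: it assumes you run it from the repository root, that `utils/defaults.py` defines defaults for every config key listed (the overrides mirror `scripts/eval_micronet_m0.sh`), and that `micronet-m0.pth` is a placeholder path to a downloaded checkpoint.

```python
from collections import OrderedDict

import torch

from backbone import MicroNet
from utils import cfg

# Mirror the config overrides of scripts/eval_micronet_m0.sh; values are passed
# as strings, exactly as yacs receives them from the command line.
cfg.merge_from_list([
    'MODEL.MICRONETS.BLOCK', 'DYMicroBlock',
    'MODEL.MICRONETS.NET_CONFIG', 'msnx_dy6_exp4_4M_221',
    'MODEL.MICRONETS.STEM_CH', '4',
    'MODEL.MICRONETS.STEM_GROUPS', '2,2',
    'MODEL.MICRONETS.STEM_DILATION', '1',
    'MODEL.MICRONETS.STEM_MODE', 'spatialsepsf',
    'MODEL.MICRONETS.OUT_CH', '640',
    'MODEL.MICRONETS.DEPTHSEP', 'True',
    'MODEL.MICRONETS.POINTWISE', 'group',
    'MODEL.MICRONETS.DROPOUT', '0.05',
    'MODEL.ACTIVATION.MODULE', 'DYShiftMax',
    'MODEL.ACTIVATION.ACT_MAX', '2.0',
    'MODEL.ACTIVATION.LINEARSE_BIAS', 'False',
    'MODEL.ACTIVATION.INIT_A_BLOCK3', '1.0,0.0',
    'MODEL.ACTIVATION.INIT_A', '1.0,1.0',
    'MODEL.ACTIVATION.INIT_B', '0.0,0.0',
    'MODEL.ACTIVATION.REDUCTION', '8',
    'MODEL.MICRONETS.SHUFFLE', 'True',
])

model = MicroNet(cfg, input_size=224, num_classes=1000)
ckpt = torch.load('micronet-m0.pth', map_location='cpu')  # placeholder path
state = ckpt.get('state_dict', ckpt) if isinstance(ckpt, dict) else ckpt
# Drop the optional 'module.' prefix left behind by DataParallel training.
state = OrderedDict((k[7:] if k.startswith('module.') else k, v)
                    for k, v in state.items())
model.load_state_dict(state)
model.eval()
```

This mirrors the evaluation path in `main.py`, minus the `DataParallel` wrapper and the data loaders.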
For example, to evaluate MicroNet-M3, use the following command.
26 |
27 | ```
28 | sh scripts/eval_micronet_m3.sh /path/to/imagenet /path/to/output /path/to/pretrained_model
29 | ```
30 |
31 | ## Train MicroNet on ImageNet
32 |
33 | The scripts used for training MicroNet M0–M3 can be found [here](scripts) and can be run as follows (choose a 2-GPU or a 4-GPU script depending on the resources available to you).
34 | ```
35 | sh scripts/train_micronet_m3_4gpu.sh /path/to/imagenet /path/to/output
36 | ```
--------------------------------------------------------------------------------
/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | from .micronet import *
--------------------------------------------------------------------------------
/backbone/activation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | def _make_divisible(v, divisor, min_value=None):
6 |     """
7 |     This function is taken from the original tf repo.
8 |     It ensures that all layers have a channel number that is divisible by 8
9 |     It can be seen here:
10 |     https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
11 |     :param v:
12 |     :param divisor:
13 |     :param min_value:
14 |     :return:
15 |     """
16 |     if min_value is None:
17 |         min_value = divisor
18 |     new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
19 |     # Make sure that round down does not go down by more than 10%.
20 |     if new_v < 0.9 * v:
21 |         new_v += divisor
22 |     return new_v
23 |
24 |
25 | ########################################################################
26 | # sigmoid and tanh
27 | ########################################################################
28 | # h_sigmoid (x: [-3, 3], y: [0, h_max])
29 | class h_sigmoid(nn.Module):
30 |     def __init__(self, inplace=True, h_max=1):
31 |         super(h_sigmoid, self).__init__()
32 |         self.relu = nn.ReLU6(inplace=inplace)
33 |         self.h_max = h_max / 6
34 |
35 |     def forward(self, x):
36 |         return self.relu(x + 3) * self.h_max
37 |
38 | # h_tanh (x: [-3, 3], y: [-h_max, h_max])
39 | class h_tanh(nn.Module):
40 |     def __init__(self, inplace=True, h_max=1):
41 |         super(h_tanh, self).__init__()
42 |         self.relu = nn.ReLU6(inplace=inplace)
43 |         self.h_max = h_max
44 |
45 |     def forward(self, x):
46 |         return self.relu(x + 3) * self.h_max / 3 - self.h_max
47 |
48 |
49 | ########################################################################
50 | # wrap functions
51 | ########################################################################
52 |
53 | def get_act_layer(inp, oup, mode='SE1', act_relu=True, act_max=2, act_bias=True, init_a=[1.0, 0.0], reduction=4, init_b=[0.0, 0.0], g=None, act='relu', expansion=True):
54 |     layer = None
55 |     if mode == 'SE1':
56 |         layer = nn.Sequential(
57 |             SELayer(inp, oup, reduction=reduction),
58 |             nn.ReLU6(inplace=True) if act_relu else nn.Sequential()
59 |         )
60 |     elif mode == 'SE0':
61 |         layer = nn.Sequential(
62 |             SELayer(inp, oup, reduction=reduction),
63 |         )
64 |     elif mode == 'NA':
65 |         layer = nn.ReLU6(inplace=True) if act_relu else nn.Sequential()
66 |     elif mode == 'LeakyReLU':
67 |         layer = nn.LeakyReLU(inplace=True) if act_relu else nn.Sequential()
68 |     elif mode == 'RReLU':
69 |         layer = nn.RReLU(inplace=True) if act_relu else nn.Sequential()
70 |     elif mode == 'PReLU':
71 |         layer = nn.PReLU() if act_relu else nn.Sequential()
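    # DYShiftMax (dynamic Shift-Max) fuses each channel x_i with a channel
    # taken from a group-circularly-shifted copy of the input, with fusion
    # coefficients predicted from the globally pooled features. With
    # act_relu=True it uses two branches and computes, per channel,
    #     out = max(a1 * x + b1 * x_shift, a2 * x + b2 * x_shift)
    # (self.exp == 4 coefficient maps); with act_relu=False a single affine
    # fuse a1 * x + b1 * x_shift is used (self.exp == 2). See
    # DYShiftMax.forward below.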
72 |     elif mode == 'DYShiftMax':
73 |         layer = DYShiftMax(inp, oup, act_max=act_max, act_relu=act_relu, init_a=init_a, reduction=reduction, init_b=init_b, g=g, expansion=expansion)
74 |     return layer
75 |
76 | ########################################################################
77 | # dynamic activation layers (SE, DYShiftMax, etc)
78 | ########################################################################
79 |
80 | class SELayer(nn.Module):
81 |     def __init__(self, inp, oup, reduction=4):
82 |         super(SELayer, self).__init__()
83 |         self.oup = oup
84 |         self.avg_pool = nn.AdaptiveAvgPool2d(1)
85 |
86 |         # determine squeeze
87 |         squeeze = get_squeeze_channels(inp, reduction)
88 |         print('reduction: {}, squeeze: {}/{}'.format(reduction, inp, squeeze))
89 |
90 |
91 |         self.fc = nn.Sequential(
92 |             nn.Linear(inp, squeeze),
93 |             nn.ReLU(inplace=True),
94 |             nn.Linear(squeeze, oup),
95 |             h_sigmoid()
96 |         )
97 |
98 |     def forward(self, x):
99 |         if isinstance(x, list):
100 |             x_in = x[0]
101 |             x_out = x[1]
102 |         else:
103 |             x_in = x
104 |             x_out = x
105 |         b, c, _, _ = x_in.size()
106 |         y = self.avg_pool(x_in).view(b, c)
107 |         y = self.fc(y).view(b, self.oup, 1, 1)
108 |         return x_out * y
109 |
110 | class DYShiftMax(nn.Module):
111 |     def __init__(self, inp, oup, reduction=4, act_max=1.0, act_relu=True, init_a=[0.0, 0.0], init_b=[0.0, 0.0], relu_before_pool=False, g=None, expansion=False):
112 |         super(DYShiftMax, self).__init__()
113 |         self.oup = oup
114 |         self.act_max = act_max * 2
115 |         self.act_relu = act_relu
116 |         self.avg_pool = nn.Sequential(
117 |             nn.ReLU(inplace=True) if relu_before_pool else nn.Sequential(),
118 |             nn.AdaptiveAvgPool2d(1)
119 |         )
120 |
121 |         self.exp = 4 if act_relu else 2
122 |         self.init_a = init_a
123 |         self.init_b = init_b
124 |
125 |         # determine squeeze
126 |         squeeze = _make_divisible(inp // reduction, 4)
127 |         if squeeze < 4:
128 |             squeeze = 4
129 |         print('reduction: {}, squeeze: {}/{}'.format(reduction, inp, squeeze))
130 |         print('init-a: {}, init-b: {}'.format(init_a, init_b))
131 |
132 |         self.fc = nn.Sequential(
133 |             nn.Linear(inp, squeeze),
134 |             nn.ReLU(inplace=True),
135 |             nn.Linear(squeeze, oup*self.exp),
136 |             h_sigmoid()
137 |         )
138 |         if g is None:
139 |             g = (1, 1)  # default: single shuffle group; g is expected to be a pair like (g1, g2), since g[1] is read below
140 |         self.g = g[1]
141 |         if self.g != 1 and expansion:
142 |             self.g = inp // self.g
143 |         print('group shuffle: {}, divide group: {}'.format(self.g, expansion))
144 |         self.gc = inp // self.g
145 |         # build the channel index for the group shift: rotate the g groups by
146 |         # one, then rotate the channels within each group by one
147 |         index = torch.Tensor(range(inp)).view(1, inp, 1, 1)
148 |         index = index.view(1, self.g, self.gc, 1, 1)
149 |         indexgs = torch.split(index, [1, self.g-1], dim=1)
150 |         indexgs = torch.cat((indexgs[1], indexgs[0]), dim=1)
151 |         indexs = torch.split(indexgs, [1, self.gc-1], dim=2)
152 |         indexs = torch.cat((indexs[1], indexs[0]), dim=2)
153 |         self.index = indexs.view(inp).type(torch.LongTensor)
154 |         self.expansion = expansion
155 |
156 |     def forward(self, x):
157 |         x_in = x
158 |         x_out = x
159 |
160 |         b, c, _, _ = x_in.size()
161 |         y = self.avg_pool(x_in).view(b, c)
162 |         y = self.fc(y).view(b, self.oup*self.exp, 1, 1)
163 |         y = (y - 0.5) * self.act_max
164 |
165 |         n2, c2, h2, w2 = x_out.size()
166 |         x2 = x_out[:, self.index, :, :]
167 |
168 |         if self.exp == 4:
169 |             a1, b1, a2, b2 = torch.split(y, self.oup, dim=1)
170 |
171 |             a1 = a1 + self.init_a[0]
172 |             a2 = a2 + self.init_a[1]
173 |
174 |             b1 = b1 + self.init_b[0]
175 |             b2 = b2 + self.init_b[1]
176 |
177 |             z1 = x_out * a1 + x2 * b1
178 |             z2 = x_out * a2 + x2 * b2
179 |
180 |             out = torch.max(z1, z2)
181 |
182 |         elif self.exp == 2:
183 |             a1, b1 = torch.split(y, self.oup, dim=1) 182
| a1 = a1 + self.init_a[0] 183 | b1 = b1 + self.init_b[0] 184 | out = x_out * a1 + x2 * b1 185 | 186 | return out 187 | 188 | def get_squeeze_channels(inp, reduction): 189 | if reduction == 4: 190 | squeeze = inp // reduction 191 | else: 192 | squeeze = _make_divisible(inp // reduction, 4) 193 | return squeeze 194 | -------------------------------------------------------------------------------- /backbone/microconfig.py: -------------------------------------------------------------------------------- 1 | msnx_dy6_exp4_4M_221_cfgs = [ 2 | #s, n, c, ks, c1, c2, g1, g2, c3, g3, g4,y1,y2,y3,r 3 | [2, 1, 8, 3, 2, 2, 0, 4, 8, 2, 2, 2, 0, 1, 1], #6->12(0, 0)->24 ->8(4,2)->8 4 | [2, 1, 12, 3, 2, 2, 0, 8, 12, 4, 4, 2, 2, 1, 1], #8->16(0, 0)->32 ->16(4,4)->12 5 | [2, 1, 16, 5, 2, 2, 0, 12, 16, 4, 4, 2, 2, 1, 1], #16->32(0, 0)->64 ->16(8,2)->16 6 | [1, 1, 32, 5, 1, 4, 4, 4, 32, 4, 4, 2, 2, 1, 1], #16->16(2,8)->96 ->32(8,4)->32 7 | [2, 1, 64, 5, 1, 4, 8, 8, 64, 8, 8, 2, 2, 1, 1], #32->32(2,16)->192 ->64(12,4)->64 8 | [1, 1, 96, 3, 1, 4, 8, 8, 96, 8, 8, 2, 2, 1, 2], #64->64(3,16)->384 ->96(16,6)->96 9 | [1, 1, 384, 3, 1, 4, 12, 12, 0, 0, 0, 2, 2, 1, 2], #96->96(4,24)->384 10 | ] 11 | msnx_dy6_exp6_6M_221_cfgs = [ 12 | #s, n, c, ks, c1, c2, g1, g2, c3, g3, g4 13 | [2, 1, 8, 3, 2, 2, 0, 6, 8, 2, 2, 2, 0, 1, 1], #6->12(0, 0)->24 ->8(4,2)->8 14 | [2, 1, 16, 3, 2, 2, 0, 8, 16, 4, 4, 2, 2, 1, 1], #8->16(0, 0)->32 ->16(4,4)->16 15 | [2, 1, 16, 5, 2, 2, 0, 16, 16, 4, 4, 2, 2, 1, 1], #16->32(0, 0)->64 ->16(8,2)->16 16 | [1, 1, 32, 5, 1, 6, 4, 4, 32, 4, 4, 2, 2, 1, 1], #16->16(2,8)->96 ->32(8,4)->32 17 | [2, 1, 64, 5, 1, 6, 8, 8, 64, 8, 8, 2, 2, 1, 1], #32->32(2,16)->192 ->64(12,4)->64 18 | [1, 1, 96, 3, 1, 6, 8, 8, 96, 8, 8, 2, 2, 1, 2], #64->64(3,16)->384 ->96(16,6)->96 19 | [1, 1, 576, 3, 1, 6, 12, 12, 0, 0, 0, 2, 2, 1, 2], #96->96(4,24)->576 20 | ] 21 | msnx_dy9_exp6_12M_221_cfgs = [ 22 | #s, n, c, ks, c1, c2, g1, g2, c3, g3, g4 23 | [2, 1, 12, 3, 2, 2, 0, 8, 12, 4, 4, 2, 0, 1, 1], #8->16(0, 0)->32 ->12(4,3)->12 24 | [2, 1, 16, 3, 2, 2, 0, 12, 16, 4, 4, 2, 2, 1, 1], #12->24(0,0)->48 ->16(8, 2)->16 25 | [1, 1, 24, 3, 2, 2, 0, 16, 24, 4, 4, 2, 2, 1, 1], #16->16(0, 0)->64 ->24(8,3)->24 26 | [2, 1, 32, 5, 1, 6, 6, 6, 32, 4, 4, 2, 2, 1, 1], #24->24(2, 12)->144 ->32(16,2)->32 27 | [1, 1, 32, 5, 1, 6, 8, 8, 32, 4, 4, 2, 2, 1, 2], #32->32(2,16)->192 ->32(16,2)->32 28 | [1, 1, 64, 5, 1, 6, 8, 8, 64, 8, 8, 2, 2, 1, 2], #32->32(2,16)->192 ->64(12,4)->64 29 | [2, 1, 96, 5, 1, 6, 8, 8, 96, 8, 8, 2, 2, 1, 2], #64->64(4,12)->384 ->96(16,5)->96 30 | [1, 1, 128, 3, 1, 6, 12, 12, 128, 8, 8, 2, 2, 1, 2], #96->96(5,16)->576->128(16,8)->128 31 | [1, 1, 768, 3, 1, 6, 16, 16, 0, 0, 0, 2, 2, 1, 2], #128->128(4,32)->768 32 | ] 33 | msnx_dy12_exp6_20M_020_cfgs = [ 34 | #s, n, c, ks, c1, c2, g1, g2, c3, g3, g4 35 | [2, 1, 16, 3, 2, 2, 0, 12, 16, 4, 4, 0, 2, 0, 1], #12->24(0, 0)->48 ->16(8,2)->16 36 | [2, 1, 24, 3, 2, 2, 0, 16, 24, 4, 4, 0, 2, 0, 1], #16->32(0, 0)->64 ->24(8,3)->24 37 | [1, 1, 24, 3, 2, 2, 0, 24, 24, 4, 4, 0, 2, 0, 1], #24->48(0, 0)->96 ->24(8,3)->24 38 | [2, 1, 32, 5, 1, 6, 6, 6, 32, 4, 4, 0, 2, 0, 1], #24->24(2,12)->144 ->32(16,2)->32 39 | [1, 1, 32, 5, 1, 6, 8, 8, 32, 4, 4, 0, 2, 0, 2], #32->32(2,16)->192 ->32(16,2)->32 40 | [1, 1, 64, 5, 1, 6, 8, 8, 48, 8, 8, 0, 2, 0, 2], #32->32(2,16)->192 ->48(12,4)->64 41 | [1, 1, 80, 5, 1, 6, 8, 8, 80, 8, 8, 0, 2, 0, 2], #48->48(3,16)->288 ->80(16,5)->80 42 | [1, 1, 80, 5, 1, 6, 10, 10, 80, 8, 8, 0, 2, 0, 2], #80->80(4,20)->480->80(20,4)->80 43 | [2, 1, 120, 5, 1, 6, 10, 10, 
120, 10, 10, 0, 2, 0, 2], #80->80(4,20)->480->128(16,8)->120 44 | [1, 1, 120, 5, 1, 6, 12, 12, 120, 10, 10, 0, 2, 0, 2], #120->128(4,32)->720->128(32,4)->120 45 | [1, 1, 144, 3, 1, 6, 12, 12, 144, 12, 12, 0, 2, 0, 2], #120->128(4,32)->720->160(32,5)->144 46 | [1, 1, 864, 3, 1, 6, 12, 12, 0, 0, 0, 0, 2, 0, 2], #144->144(5,32)->864 47 | ] 48 | 49 | def get_micronet_config(mode): 50 | return eval(mode+'_cfgs') 51 | -------------------------------------------------------------------------------- /backbone/micronet.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import backbone.activation as activation 7 | import backbone.microconfig as microcfg 8 | 9 | 10 | import math 11 | import pdb 12 | 13 | __all__ = ['MicroNet', 'micronet'] 14 | 15 | TAU = 20 16 | #####################################################################3 17 | # part 1: functions 18 | #####################################################################3 19 | 20 | def _make_divisible(v, divisor, min_value=None): 21 | """ 22 | This function is taken from the original tf repo. 23 | It ensures that all layers have a channel number that is divisible by 8 24 | It can be seen here: 25 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 26 | :param v: 27 | :param divisor: 28 | :param min_value: 29 | :return: 30 | """ 31 | if min_value is None: 32 | min_value = divisor 33 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 34 | # Make sure that round down does not go down by more than 10%. 35 | if new_v < 0.9 * v: 36 | new_v += divisor 37 | return new_v 38 | 39 | class h_sigmoid(nn.Module): 40 | def __init__(self, inplace=True): 41 | super(h_sigmoid, self).__init__() 42 | self.relu = nn.ReLU6(inplace=inplace) 43 | 44 | def forward(self, x): 45 | return self.relu(x + 3) / 6 46 | 47 | 48 | class h_swish(nn.Module): 49 | def __init__(self, inplace=True): 50 | super(h_swish, self).__init__() 51 | self.sigmoid = h_sigmoid(inplace=inplace) 52 | 53 | def forward(self, x): 54 | return x * self.sigmoid(x) 55 | 56 | def conv_3x3_bn(inp, oup, stride, dilation=1): 57 | return nn.Sequential( 58 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False, dilation=dilation), 59 | nn.BatchNorm2d(oup), 60 | nn.ReLU6(inplace=True) 61 | ) 62 | 63 | def conv_1x1_bn(inp, oup): 64 | return nn.Sequential( 65 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 66 | nn.BatchNorm2d(oup), 67 | nn.ReLU6(inplace=True) 68 | ) 69 | 70 | def gcd(a, b): 71 | a, b = (a, b) if a >= b else (b, a) 72 | while b: 73 | a, b = b, a%b 74 | return a 75 | 76 | #####################################################################3 77 | # part 2: modules 78 | #####################################################################3 79 | 80 | class MaxGroupPooling(nn.Module): 81 | def __init__(self, channel_per_group=2): 82 | super(MaxGroupPooling, self).__init__() 83 | self.channel_per_group = channel_per_group 84 | 85 | def forward(self, x): 86 | if self.channel_per_group == 1: 87 | return x 88 | # max op 89 | b, c, h, w = x.size() 90 | 91 | # reshape 92 | y = x.view(b, c // self.channel_per_group, -1, h, w) 93 | out, _ = torch.max(y, dim=2) 94 | return out 95 | 96 | class SwishLinear(nn.Module): 97 | def __init__(self, inp, oup): 98 | super(SwishLinear, self).__init__() 99 | self.linear = nn.Sequential( 100 | nn.Linear(inp, oup), 101 | nn.BatchNorm1d(oup), 102 | h_swish() 103 | ) 104 | 105 | def forward(self, x): 106 | 
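        # x is a flattened (batch, features) tensor; SwishLinear applies
        # Linear -> BatchNorm1d -> h_swish (see __init__ above)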
return self.linear(x) 107 | 108 | class StemLayer(nn.Module): 109 | def __init__(self, inp, oup, stride, dilation=1, mode='default', groups=(4,4)): 110 | super(StemLayer, self).__init__() 111 | 112 | self.exp = 1 if mode == 'default' else 2 113 | g1, g2 = groups 114 | if mode == 'default': 115 | self.stem = nn.Sequential( 116 | nn.Conv2d(inp, oup*self.exp, 3, stride, 1, bias=False, dilation=dilation), 117 | nn.BatchNorm2d(oup*self.exp), 118 | nn.ReLU6(inplace=True) if self.exp == 1 else MaxGroupPooling(self.exp) 119 | ) 120 | elif mode == 'spatialsepsf': 121 | self.stem = nn.Sequential( 122 | SpatialSepConvSF(inp, groups, 3, stride), 123 | MaxGroupPooling(2) if g1*g2==2*oup else nn.ReLU6(inplace=True) 124 | ) 125 | else: 126 | raise ValueError('Undefined stem layer') 127 | 128 | def forward(self, x): 129 | out = self.stem(x) 130 | return out 131 | 132 | class GroupConv(nn.Module): 133 | def __init__(self, inp, oup, groups=2): 134 | super(GroupConv, self).__init__() 135 | self.inp = inp 136 | self.oup = oup 137 | self.groups = groups 138 | print ('inp: %d, oup:%d, g:%d' %(inp, oup, self.groups[0])) 139 | self.conv = nn.Sequential( 140 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False, groups=self.groups[0]), 141 | nn.BatchNorm2d(oup) 142 | ) 143 | 144 | def forward(self, x): 145 | x = self.conv(x) 146 | return x 147 | 148 | class ChannelShuffle(nn.Module): 149 | def __init__(self, groups): 150 | super(ChannelShuffle, self).__init__() 151 | self.groups = groups 152 | 153 | def forward(self, x): 154 | b, c, h, w = x.size() 155 | 156 | channels_per_group = c // self.groups 157 | 158 | # reshape 159 | x = x.view(b, self.groups, channels_per_group, h, w) 160 | 161 | x = torch.transpose(x, 1, 2).contiguous() 162 | out = x.view(b, -1, h, w) 163 | 164 | return out 165 | 166 | class ChannelShuffle2(nn.Module): 167 | def __init__(self, groups): 168 | super(ChannelShuffle2, self).__init__() 169 | self.groups = groups 170 | 171 | def forward(self, x): 172 | b, c, h, w = x.size() 173 | 174 | channels_per_group = c // self.groups 175 | 176 | # reshape 177 | x = x.view(b, self.groups, channels_per_group, h, w) 178 | 179 | x = torch.transpose(x, 1, 2).contiguous() 180 | out = x.view(b, -1, h, w) 181 | 182 | return out 183 | 184 | ######################################################################3 185 | # part 3: new block 186 | #####################################################################3 187 | 188 | class SpatialSepConvSF(nn.Module): 189 | def __init__(self, inp, oups, kernel_size, stride): 190 | super(SpatialSepConvSF, self).__init__() 191 | 192 | oup1, oup2 = oups 193 | self.conv = nn.Sequential( 194 | nn.Conv2d(inp, oup1, 195 | (kernel_size, 1), 196 | (stride, 1), 197 | (kernel_size//2, 0), 198 | bias=False, groups=1 199 | ), 200 | nn.BatchNorm2d(oup1), 201 | nn.Conv2d(oup1, oup1*oup2, 202 | (1, kernel_size), 203 | (1, stride), 204 | (0, kernel_size//2), 205 | bias=False, groups=oup1 206 | ), 207 | nn.BatchNorm2d(oup1*oup2), 208 | ChannelShuffle(oup1), 209 | ) 210 | 211 | def forward(self, x): 212 | out = self.conv(x) 213 | return out 214 | 215 | class DepthConv(nn.Module): 216 | def __init__(self, inp, oup, kernel_size, stride): 217 | super(DepthConv, self).__init__() 218 | self.conv = nn.Sequential( 219 | nn.Conv2d(inp, oup, kernel_size, stride, kernel_size//2, bias=False, groups=inp), 220 | nn.BatchNorm2d(oup) 221 | ) 222 | 223 | def forward(self, x): 224 | out = self.conv(x) 225 | return out 226 | 227 | class DepthSpatialSepConv(nn.Module): 228 | def __init__(self, inp, expand, kernel_size, 
stride): 229 | super(DepthSpatialSepConv, self).__init__() 230 | 231 | exp1, exp2 = expand 232 | 233 | hidden_dim = inp*exp1 234 | oup = inp*exp1*exp2 235 | 236 | self.conv = nn.Sequential( 237 | nn.Conv2d(inp, inp*exp1, 238 | (kernel_size, 1), 239 | (stride, 1), 240 | (kernel_size//2, 0), 241 | bias=False, groups=inp 242 | ), 243 | nn.BatchNorm2d(inp*exp1), 244 | nn.Conv2d(hidden_dim, oup, 245 | (1, kernel_size), 246 | (1, stride), 247 | (0, kernel_size//2), 248 | bias=False, groups=hidden_dim 249 | ), 250 | nn.BatchNorm2d(oup) 251 | ) 252 | 253 | def forward(self, x): 254 | out = self.conv(x) 255 | return out 256 | 257 | def get_pointwise_conv(mode, inp, oup, hiddendim, groups): 258 | 259 | if mode == 'group': 260 | return GroupConv(inp, oup, groups) 261 | elif mode == '1x1': 262 | return nn.Sequential( 263 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 264 | nn.BatchNorm2d(oup) 265 | ) 266 | else: 267 | return None 268 | 269 | class DYMicroBlock(nn.Module): 270 | def __init__(self, inp, oup, kernel_size=3, stride=1, ch_exp=(2, 2), ch_per_group=4, groups_1x1=(1, 1), depthsep=True, shuffle=False, pointwise='fft', activation_cfg=None): 271 | super(DYMicroBlock, self).__init__() 272 | 273 | print(activation_cfg.dy) 274 | 275 | self.identity = stride == 1 and inp == oup 276 | 277 | y1, y2, y3 = activation_cfg.dy 278 | act = activation_cfg.MODULE 279 | act_max = activation_cfg.ACT_MAX 280 | act_bias = activation_cfg.LINEARSE_BIAS 281 | act_reduction = activation_cfg.REDUCTION * activation_cfg.ratio 282 | init_a = activation_cfg.INIT_A 283 | init_b = activation_cfg.INIT_B 284 | init_ab3 = activation_cfg.INIT_A_BLOCK3 285 | 286 | t1 = ch_exp 287 | gs1 = ch_per_group 288 | hidden_fft, g1, g2 = groups_1x1 289 | 290 | hidden_dim1 = inp * t1[0] 291 | hidden_dim2 = inp * t1[0] * t1[1] 292 | 293 | if gs1[0] == 0: 294 | self.layers = nn.Sequential( 295 | DepthSpatialSepConv(inp, t1, kernel_size, stride), 296 | activation.get_act_layer( 297 | hidden_dim2, 298 | hidden_dim2, 299 | mode=act, 300 | act_max=act_max, 301 | act_relu=True if y2 == 2 else False, 302 | act_bias=act_bias, 303 | init_a=init_a, 304 | reduction=act_reduction, 305 | init_b=init_b, 306 | g = gs1, 307 | expansion = False 308 | ) if y2 > 0 else nn.ReLU6(inplace=True), 309 | ChannelShuffle(gs1[1]) if shuffle else nn.Sequential(), 310 | ChannelShuffle2(hidden_dim2//2) if shuffle and y2 !=0 else nn.Sequential(), 311 | get_pointwise_conv(pointwise, hidden_dim2, oup, hidden_fft, (g1, g2)), 312 | activation.get_act_layer( 313 | oup, 314 | oup, 315 | mode=act, 316 | act_max=act_max, 317 | act_relu=False, 318 | act_bias=act_bias, 319 | init_a=[init_ab3[0], 0.0], 320 | reduction=act_reduction//2, 321 | init_b=[init_ab3[1], 0.0], 322 | g = (g1, g2), 323 | expansion = False 324 | ) if y3 > 0 else nn.Sequential(), 325 | ChannelShuffle(g2) if shuffle else nn.Sequential(), 326 | ChannelShuffle2(oup//2) if shuffle and oup%2 == 0 and y3!=0 else nn.Sequential(), 327 | ) 328 | elif g2 == 0: 329 | self.layers = nn.Sequential( 330 | get_pointwise_conv(pointwise, inp, hidden_dim2, hidden_dim1, gs1), 331 | activation.get_act_layer( 332 | hidden_dim2, 333 | hidden_dim2, 334 | mode=act, 335 | act_max=act_max, 336 | act_relu=False, 337 | act_bias=act_bias, 338 | init_a=[init_ab3[0], 0.0], 339 | reduction=act_reduction, 340 | init_b=[init_ab3[1], 0.0], 341 | g = gs1, 342 | expansion = False 343 | ) if y3 > 0 else nn.Sequential(), 344 | 345 | ) 346 | 347 | else: 348 | self.layers = nn.Sequential( 349 | get_pointwise_conv(pointwise, inp, hidden_dim2, hidden_dim1, 
gs1), 350 | activation.get_act_layer( 351 | hidden_dim2, 352 | hidden_dim2, 353 | mode=act, 354 | act_max=act_max, 355 | act_relu=True if y1 == 2 else False, 356 | act_bias=act_bias, 357 | init_a=init_a, 358 | reduction=act_reduction, 359 | init_b=init_b, 360 | g = gs1, 361 | expansion = False 362 | ) if y1 > 0 else nn.ReLU6(inplace=True), 363 | ChannelShuffle(gs1[1]) if shuffle else nn.Sequential(), 364 | DepthSpatialSepConv(hidden_dim2, (1, 1), kernel_size, stride) if depthsep else 365 | DepthConv(hidden_dim2, hidden_dim2, kernel_size, stride), 366 | nn.Sequential(), 367 | activation.get_act_layer( 368 | hidden_dim2, 369 | hidden_dim2, 370 | mode=act, 371 | act_max=act_max, 372 | act_relu=True if y2 == 2 else False, 373 | act_bias=act_bias, 374 | init_a=init_a, 375 | reduction=act_reduction, 376 | init_b=init_b, 377 | g = gs1, 378 | expansion = True 379 | ) if y2 > 0 else nn.ReLU6(inplace=True), 380 | ChannelShuffle2(hidden_dim2//4) if shuffle and y1!=0 and y2 !=0 else nn.Sequential() if y1==0 and y2==0 else ChannelShuffle2(hidden_dim2//2), 381 | get_pointwise_conv(pointwise, hidden_dim2, oup, hidden_fft, (g1, g2)), #FFTConv 382 | activation.get_act_layer( 383 | oup, 384 | oup, 385 | mode=act, 386 | act_max=act_max, 387 | act_relu=False, 388 | act_bias=act_bias, 389 | init_a=[init_ab3[0], 0.0], 390 | reduction=act_reduction//2 if oup < hidden_dim2 else act_reduction, 391 | init_b=[init_ab3[1], 0.0], 392 | g = (g1, g2), 393 | expansion = False 394 | ) if y3 > 0 else nn.Sequential(), 395 | ChannelShuffle(g2) if shuffle else nn.Sequential(), 396 | ChannelShuffle2(oup//2) if shuffle and y3!=0 else nn.Sequential(), 397 | ) 398 | 399 | def forward(self, x): 400 | identity = x 401 | out = self.layers(x) 402 | 403 | if self.identity: 404 | out = out + identity 405 | 406 | return out 407 | 408 | ########################################################################### 409 | 410 | class MicroNet(nn.Module): 411 | def __init__(self, cfg, input_size=224, num_classes=1000, teacher=False): 412 | super(MicroNet, self).__init__() 413 | 414 | mode = cfg.MODEL.MICRONETS.NET_CONFIG 415 | self.cfgs = microcfg.get_micronet_config(mode) 416 | 417 | block = eval(cfg.MODEL.MICRONETS.BLOCK) 418 | stem_mode = cfg.MODEL.MICRONETS.STEM_MODE 419 | stem_ch = cfg.MODEL.MICRONETS.STEM_CH 420 | stem_dilation = cfg.MODEL.MICRONETS.STEM_DILATION 421 | stem_groups = cfg.MODEL.MICRONETS.STEM_GROUPS 422 | out_ch = cfg.MODEL.MICRONETS.OUT_CH 423 | depthsep = cfg.MODEL.MICRONETS.DEPTHSEP 424 | shuffle = cfg.MODEL.MICRONETS.SHUFFLE 425 | pointwise = cfg.MODEL.MICRONETS.POINTWISE 426 | dropout_rate = cfg.MODEL.MICRONETS.DROPOUT 427 | 428 | act_max = cfg.MODEL.ACTIVATION.ACT_MAX 429 | act_bias = cfg.MODEL.ACTIVATION.LINEARSE_BIAS 430 | activation_cfg= cfg.MODEL.ACTIVATION 431 | 432 | # building first layer 433 | assert input_size % 32 == 0 434 | input_channel = stem_ch 435 | layers = [StemLayer( 436 | 3, input_channel, 437 | stride=2, 438 | dilation=stem_dilation, 439 | mode=stem_mode, 440 | groups=stem_groups 441 | )] 442 | 443 | for idx, val in enumerate(self.cfgs): 444 | s, n, c, ks, c1, c2, g1, g2, c3, g3, g4, y1, y2, y3, r = val 445 | 446 | t1 = (c1, c2) 447 | gs1 = (g1, g2) 448 | gs2 = (c3, g3, g4) 449 | activation_cfg.dy = [y1, y2, y3] 450 | activation_cfg.ratio = r 451 | 452 | output_channel = c 453 | layers.append(block(input_channel, output_channel, 454 | kernel_size=ks, 455 | stride=s, 456 | ch_exp=t1, 457 | ch_per_group=gs1, 458 | groups_1x1=gs2, 459 | depthsep = depthsep, 460 | shuffle = shuffle, 461 | pointwise = 
pointwise, 462 | activation_cfg=activation_cfg, 463 | )) 464 | input_channel = output_channel 465 | for i in range(1, n): 466 | layers.append(block(input_channel, output_channel, 467 | kernel_size=ks, 468 | stride=1, 469 | ch_exp=t1, 470 | ch_per_group=gs1, 471 | groups_1x1=gs2, 472 | depthsep = depthsep, 473 | shuffle = shuffle, 474 | pointwise = pointwise, 475 | activation_cfg=activation_cfg, 476 | )) 477 | input_channel = output_channel 478 | self.features = nn.Sequential(*layers) 479 | 480 | 481 | self.avgpool = nn.Sequential( 482 | nn.ReLU6(inplace=True), 483 | nn.AdaptiveAvgPool2d((1, 1)), 484 | h_swish() 485 | ) 486 | 487 | # building last several layers 488 | output_channel = out_ch 489 | 490 | self.classifier = nn.Sequential( 491 | SwishLinear(input_channel, output_channel), 492 | nn.Dropout(dropout_rate), 493 | SwishLinear(output_channel, num_classes) 494 | ) 495 | self._initialize_weights() 496 | 497 | def forward(self, x): 498 | x = self.features(x) 499 | x = self.avgpool(x) 500 | 501 | x = x.view(x.size(0), -1) 502 | x = self.classifier(x) 503 | return x 504 | 505 | def _initialize_weights(self): 506 | for m in self.modules(): 507 | if isinstance(m, nn.Conv2d): 508 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 509 | m.weight.data.normal_(0, math.sqrt(2. / n)) 510 | if m.bias is not None: 511 | m.bias.data.zero_() 512 | elif isinstance(m, nn.BatchNorm2d): 513 | m.weight.data.fill_(1) 514 | m.bias.data.zero_() 515 | elif isinstance(m, nn.Linear): 516 | n = m.weight.size(1) 517 | m.weight.data.normal_(0, 0.01) 518 | if m.bias is not None: 519 | m.bias.data.zero_() 520 | 521 | def micronet(**kwargs): 522 | """ 523 | Constructs a MicroNet model 524 | """ 525 | return MicroNet(**kwargs) 526 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import time 3 | import argparse 4 | import os 5 | import sys 6 | import random 7 | import shutil 8 | import time 9 | import warnings 10 | import numpy as np 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.parallel 15 | import torch.nn.functional as F 16 | import torch.backends.cudnn as cudnn 17 | import torch.distributed as dist 18 | import torch.optim 19 | import torch.utils.data 20 | import torch.utils.data.distributed 21 | import torchvision.transforms as transforms 22 | import torchvision.datasets as datasets 23 | import torchvision.models as models 24 | 25 | from utils import Logger, AverageMeter, accuracy, mkdir_p, savefig, cfg, larc 26 | from utils.dataloaders import * 27 | from utils.imagenet import ImageNet 28 | from backbone import * 29 | from tensorboardX import SummaryWriter 30 | 31 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 32 | parser.add_argument('-d', '--data', metavar='DIR', 33 | help='path to dataset') 34 | parser.add_argument('--data-backend', metavar='BACKEND', default='pytorch', 35 | choices=DATA_BACKEND_CHOICES) 36 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 37 | help='model architecture') 38 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 39 | help='number of data loading workers (default: 4)') 40 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 41 | help='number of total epochs to run') 42 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 43 | help='manual epoch number (useful on restarts)') 
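# The shipped training scripts scale the initial learning rate linearly with
# the total batch size: --lr 0.1 with -b 256 (2 GPUs) and --lr 0.2 with
# -b 512 (4 GPUs); see scripts/train_micronet_m*.sh.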
44 | parser.add_argument('-b', '--batch-size', default=256, type=int, 45 | metavar='N', 46 | help='mini-batch size (default: 256), this is the total ' 47 | 'batch size of all GPUs on the current node when ' 48 | 'using Data Parallel or Distributed Data Parallel') 49 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 50 | metavar='LR', help='initial learning rate', dest='lr') 51 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 52 | help='momentum') 53 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 54 | metavar='W', help='weight decay (default: 1e-4)', 55 | dest='weight_decay') 56 | parser.add_argument('-p', '--print-freq', default=10, type=int, 57 | metavar='N', help='print frequency (default: 10)') 58 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 59 | help='path to latest checkpoint (default: none)') 60 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 61 | help='evaluate model on validation set') 62 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 63 | help='use pre-trained model') 64 | parser.add_argument('--world-size', default=-1, type=int, 65 | help='number of nodes for distributed training') 66 | parser.add_argument('--rank', default=-1, type=int, 67 | help='node rank for distributed training') 68 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 69 | help='url used to set up distributed training') 70 | parser.add_argument('--dist-backend', default='nccl', type=str, 71 | help='distributed backend') 72 | parser.add_argument('--seed', default=None, type=int, 73 | help='seed for initializing training. ') 74 | 75 | parser.add_argument('--lr-decay', type=str, default='step', 76 | help='mode for learning rate decay') 77 | parser.add_argument('--step', type=int, default=30, 78 | help='interval for learning rate decay in step mode') 79 | parser.add_argument('--schedule', type=int, nargs='+', default=[30, 60, 90], 80 | help='decrease learning rate at these epochs.') 81 | parser.add_argument('--gamma', type=float, default=0.1, 82 | help='LR is multiplied by gamma on schedule.') 83 | parser.add_argument('--warmup', action='store_true', 84 | help='set lower initial learning rate to warm up the training') 85 | 86 | parser.add_argument('-c', '--checkpoint', default='checkpoints', type=str, metavar='PATH', 87 | help='path to save checkpoint (default: checkpoints)') 88 | 89 | parser.add_argument('--input-size', type=int, default=224, help='MobileNet model input resolution') 90 | parser.add_argument('--weight', default='', type=str, metavar='WEIGHT', 91 | help='path to pretrained weight (default: none)') 92 | parser.add_argument('--label-smoothing', type=float, default=0.1, help='label smoothing') 93 | parser.add_argument('--mixup', type=float, default=0.0, help='mixup or not') 94 | parser.add_argument( 95 | "opts", 96 | help="Modify config options using the command-line", 97 | default=None, 98 | nargs=argparse.REMAINDER, 99 | ) 100 | 101 | args = parser.parse_args() 102 | cfg.merge_from_list(args.opts) 103 | best_prec1 = 0 104 | 105 | def print_options(save_path, opt): 106 | message = '' 107 | message += '----------------- Options ---------------\n' 108 | for k, v in sorted(vars(opt).items()): 109 | comment = '' 110 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 111 | message += '----------------- End -------------------' 112 | #print(message) 113 | 114 | # save to the disk 115 | file_name = 
os.path.join(save_path, 'options.txt') 116 | with open(file_name, 'wt') as opt_file: 117 | opt_file.write(message) 118 | opt_file.write('\n') 119 | 120 | def main(): 121 | global args, best_prec1 122 | 123 | if args.seed is not None: 124 | random.seed(args.seed) 125 | torch.manual_seed(args.seed) 126 | cudnn.deterministic = True 127 | warnings.warn('You have chosen to seed training. ' 128 | 'This will turn on the CUDNN deterministic setting, ' 129 | 'which can slow down your training considerably! ' 130 | 'You may see unexpected behavior when restarting ' 131 | 'from checkpoints.') 132 | 133 | args.distributed = args.world_size > 1 134 | 135 | if args.distributed: 136 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 137 | world_size=args.world_size) 138 | 139 | # create model 140 | print("=> creating model '{}'".format(args.arch)) 141 | model = eval(args.arch)(cfg) 142 | 143 | if not args.distributed: 144 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 145 | model.features = torch.nn.DataParallel(model.features) 146 | model.cuda() 147 | else: 148 | model = torch.nn.DataParallel(model).cuda() 149 | else: 150 | model.cuda() 151 | model = torch.nn.parallel.DistributedDataParallel(model) 152 | 153 | if args.label_smoothing > 0: 154 | # using Label Smoothing 155 | criterion = LabelSmoothingLoss(smoothing=args.label_smoothing) 156 | else: 157 | criterion = nn.CrossEntropyLoss() 158 | 159 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 160 | momentum=args.momentum, 161 | weight_decay=args.weight_decay) 162 | # optionally resume from a checkpoint 163 | title = 'ImageNet-' + args.arch 164 | if not os.path.isdir(args.checkpoint): 165 | mkdir_p(args.checkpoint) 166 | 167 | arch = open(os.path.join(args.checkpoint, 'arch.txt'), 'w') 168 | print (os.path.join(args.checkpoint, 'arch.txt')) 169 | print (model,file=arch) 170 | arch.close() 171 | 172 | if os.path.exists(os.path.join('micronet/checkpoints', args.checkpoint.split('/')[-1], 'checkpoint.pth.tar')): 173 | args.resume = os.path.join('micronet/checkpoints', args.checkpoint.split('/')[-1], 'checkpoint.pth.tar') 174 | log_source = os.path.join('micronet/checkpoints', args.checkpoint.split('/')[-1], 'log.txt') 175 | log_target = os.path.join(args.checkpoint, 'log.txt') 176 | os.system('cp %s %s' %(log_source, log_target)) 177 | 178 | if args.resume: 179 | if os.path.isfile(args.resume): 180 | print("=> loading checkpoint '{}'".format(args.resume)) 181 | checkpoint = torch.load(args.resume) 182 | args.start_epoch = checkpoint['epoch'] 183 | best_prec1 = checkpoint['best_prec1'] 184 | model.load_state_dict(checkpoint['state_dict']) 185 | optimizer.load_state_dict(checkpoint['optimizer']) 186 | print("=> loaded checkpoint '{}' (epoch {})" 187 | .format(args.resume, checkpoint['epoch'])) 188 | logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True) 189 | else: 190 | print("=> no checkpoint found at '{}'".format(args.resume)) 191 | else: 192 | logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title) 193 | logger.set_names(['Epoch', 'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.', 'Valid Acc5.']) 194 | print_options(args.checkpoint, args) 195 | 196 | cudnn.benchmark = True 197 | 198 | # Data loading code 199 | if args.data_backend == 'pytorch': 200 | get_train_loader = get_pytorch_train_loader 201 | get_val_loader = get_pytorch_val_loader 202 | elif args.data_backend == 'dali-gpu': 203 | get_train_loader = 
get_dali_train_loader(dali_cpu=False) 204 | get_val_loader = get_dali_val_loader() 205 | elif args.data_backend == 'dali-cpu': 206 | get_train_loader = get_dali_train_loader(dali_cpu=True) 207 | get_val_loader = get_dali_val_loader() 208 | 209 | if not args.evaluate: 210 | train_loader, train_loader_len = get_train_loader(args.data, args.batch_size, workers=args.workers, input_size=args.input_size) 211 | val_loader, val_loader_len = get_val_loader(args.data, args.batch_size, workers=args.workers, input_size=args.input_size) 212 | if args.evaluate: 213 | from collections import OrderedDict 214 | if os.path.isfile(args.weight): 215 | print("=> loading pretrained weight '{}'".format(args.weight)) 216 | source_state = torch.load(args.weight) 217 | if 'state_dict' in source_state: 218 | source_state = source_state['state_dict'] 219 | target_state = OrderedDict() 220 | for k, v in source_state.items(): 221 | if k[:7] != 'module.': 222 | k = 'module.' + k 223 | target_state[k] = v 224 | model.load_state_dict(target_state) 225 | else: 226 | print("=> no weight found at '{}'".format(args.weight)) 227 | 228 | validate(val_loader, val_loader_len, model, criterion) 229 | return 230 | 231 | # visualization 232 | writer = SummaryWriter(os.path.join(args.checkpoint, 'logs')) 233 | 234 | for epoch in range(args.start_epoch, args.epochs): 235 | if args.distributed: 236 | train_sampler.set_epoch(epoch) 237 | 238 | print('\nEpoch: [%d | %d]' % (epoch + 1, args.epochs)) 239 | 240 | # train for one epoch 241 | train_loss, train_acc = train(train_loader, train_loader_len, model, criterion, optimizer, epoch) 242 | 243 | # evaluate on validation set 244 | 245 | lr = optimizer.param_groups[0]['lr'] 246 | val_loss, prec1, prec5 = validate(val_loader, val_loader_len, model, criterion) 247 | # append logger file 248 | logger.append([epoch+1, lr, train_loss, val_loss, train_acc, prec1, prec5]) 249 | 250 | # tensorboardX 251 | writer.add_scalar('learning rate', lr, epoch + 1) 252 | writer.add_scalars('loss', {'train loss': train_loss, 'validation loss': val_loss}, epoch + 1) 253 | writer.add_scalars('accuracy', {'train accuracy': train_acc, 'validation accuracy': prec1}, epoch + 1) 254 | 255 | is_best = prec1 > best_prec1 256 | best_prec1 = max(prec1, best_prec1) 257 | save_checkpoint({ 258 | 'epoch': epoch + 1, 259 | 'arch': args.arch, 260 | 'state_dict': model.state_dict(), 261 | 'best_prec1': best_prec1, 262 | 'optimizer' : optimizer.state_dict(), 263 | }, is_best, checkpoint=args.checkpoint) 264 | 265 | logger.close() 266 | writer.close() 267 | 268 | def train(train_loader, train_loader_len, model, criterion, optimizer, epoch, max_alpha_epoch=300): 269 | 270 | batch_time = AverageMeter() 271 | data_time = AverageMeter() 272 | losses = AverageMeter() 273 | top1 = AverageMeter() 274 | top5 = AverageMeter() 275 | # switch to train mode 276 | model.train() 277 | 278 | end = time.time() 279 | if epoch < 100: 280 | mixup_alpha = args.mixup * float(epoch) / 100 281 | else: 282 | mixup_alpha = args.mixup 283 | for i, (input, target) in enumerate(train_loader): 284 | adjust_learning_rate(optimizer, epoch, i, train_loader_len) 285 | 286 | # measure data loading time 287 | data_time.update(time.time() - end) 288 | input = input.cuda() 289 | target = target.cuda() 290 | # compute output 291 | if args.mixup != 0: 292 | # using mixup 293 | input, label_a, label_b, lam = mixup_data(input, target, mixup_alpha) 294 | output = model(input) 295 | loss = mixup_criterion(criterion, output, label_a, label_b, lam) 296 | acc1_a, acc5_a = 
accuracy(output, label_a, topk=(1, 5)) 297 | acc1_b, acc5_b = accuracy(output, label_b, topk=(1, 5)) 298 | # measure accuracy and record loss 299 | prec1 = lam * acc1_a + (1 - lam) * acc1_b 300 | prec5 = lam * acc5_a + (1 - lam) * acc5_b 301 | else: 302 | # normal forward 303 | output = model(input) 304 | loss = criterion(output, target) 305 | # measure accuracy and record loss 306 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 307 | losses.update(loss.item(), input.size(0)) 308 | top1.update(prec1.item(), input.size(0)) 309 | top5.update(prec5.item(), input.size(0)) 310 | 311 | # compute gradient and do SGD step 312 | optimizer.zero_grad() 313 | loss.backward() 314 | 315 | optimizer.step() 316 | 317 | # measure elapsed time 318 | batch_time.update(time.time() - end) 319 | end = time.time() 320 | if i % 100 == 0: 321 | print('Epoch: [{0}/{1}][{2}/{3}]\t' 322 | 'LR: {4}\t' 323 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 324 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 325 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 326 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t' 327 | 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 328 | epoch+1, args.epochs, i, train_loader_len, optimizer.param_groups[0]['lr'], batch_time=batch_time, 329 | data_time=data_time, loss=losses, top1=top1, top5=top5)) 330 | return (losses.avg, top1.avg) 331 | 332 | def validate(val_loader, val_loader_len, model, criterion): 333 | 334 | batch_time = AverageMeter() 335 | data_time = AverageMeter() 336 | losses = AverageMeter() 337 | top1 = AverageMeter() 338 | top5 = AverageMeter() 339 | 340 | # switch to evaluate mode 341 | model.eval() 342 | 343 | end = time.time() 344 | for i, (input, target) in enumerate(val_loader): 345 | # measure data loading time 346 | data_time.update(time.time() - end) 347 | 348 | input = input.cuda() 349 | target = target.cuda() 350 | with torch.no_grad(): 351 | # compute output 352 | output = model(input) 353 | if type(output) is tuple: 354 | output = output[0] 355 | loss = criterion(output, target) 356 | 357 | # measure accuracy and record loss 358 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 359 | losses.update(loss.item(), input.size(0)) 360 | top1.update(prec1.item(), input.size(0)) 361 | top5.update(prec5.item(), input.size(0)) 362 | 363 | # measure elapsed time 364 | batch_time.update(time.time() - end) 365 | end = time.time() 366 | if i % 100 == 0: 367 | print('Test: [{0}/{1}]\t' 368 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 369 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 370 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t' 371 | 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 372 | i, val_loader_len, batch_time=batch_time, loss=losses, 373 | top1=top1, top5=top5)) 374 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1, top5=top5)) 375 | 376 | return (losses.avg, top1.avg, top5.avg) 377 | 378 | def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'): 379 | filepath = os.path.join(checkpoint, filename) 380 | torch.save(state, filepath) 381 | if is_best: 382 | shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar')) 383 | 384 | 385 | from math import cos, pi 386 | def adjust_learning_rate(optimizer, epoch, iteration, num_iter): 387 | lr = optimizer.param_groups[0]['lr'] 388 | 389 | warmup_epoch = 5 if args.warmup else 0 390 | warmup_iter = warmup_epoch * num_iter 391 | current_iter = iteration + epoch * num_iter 392 | max_iter = args.epochs * num_iter 393 | 394 | if args.lr_decay == 
'step': 395 | lr = args.lr * (args.gamma ** ((current_iter - warmup_iter) // (max_iter - warmup_iter))) 396 | elif args.lr_decay == 'cos': 397 | lr = args.lr * (1 + cos(pi * (current_iter - warmup_iter) / (max_iter - warmup_iter))) / 2 398 | elif args.lr_decay == 'linear': 399 | lr = args.lr * (1 - (current_iter - warmup_iter) / (max_iter - warmup_iter)) 400 | elif args.lr_decay == 'schedule': 401 | count = sum([1 for s in args.schedule if s <= epoch]) 402 | lr = args.lr * pow(args.gamma, count) 403 | else: 404 | raise ValueError('Unknown lr mode {}'.format(args.lr_decay)) 405 | 406 | if epoch < warmup_epoch: 407 | lr = args.lr * current_iter / warmup_iter 408 | 409 | 410 | for param_group in optimizer.param_groups: 411 | param_group['lr'] = lr 412 | 413 | def mixup_data(x, y, alpha): 414 | ''' 415 | Returns mixed inputs, pairs of targets, and lambda 416 | ''' 417 | if alpha > 0: 418 | lam = np.random.beta(alpha, alpha) 419 | else: 420 | lam = 1 421 | 422 | batch_size = x.size()[0] 423 | index = torch.randperm(batch_size).to(x.device) 424 | 425 | mixed_x = lam * x + (1 - lam) * x[index, :] 426 | y_a, y_b = y, y[index] 427 | return mixed_x, y_a, y_b, lam 428 | 429 | def mixup_criterion(criterion, pred, y_a, y_b, lam): 430 | return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) 431 | 432 | class LabelSmoothingLoss(nn.Module): 433 | 434 | def __init__(self, smoothing=0.0): 435 | super(LabelSmoothingLoss, self).__init__() 436 | self.smoothing = smoothing 437 | 438 | def forward(self, input, target): 439 | log_prob = input.log_softmax(dim=-1) 440 | weight = input.new_ones(input.size()) * \ 441 | self.smoothing / (input.size(-1) - 1.) 442 | weight.scatter_(-1, target.unsqueeze(-1), (1. - self.smoothing)) 443 | loss = (-weight * log_prob).sum(dim=-1).mean() 444 | return loss 445 | 446 | 447 | 448 | if __name__ == '__main__': 449 | main() 450 | -------------------------------------------------------------------------------- /scripts/eval_micronet_m0.sh: -------------------------------------------------------------------------------- 1 | export DATA_PATH=$1/imagenet 2 | export OUTPUT_PATH=$2/micronet-m0-eval 3 | export WEIGHT_PATH=$3 4 | 5 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --arch MicroNet -d $DATA_PATH -c $OUTPUT_PATH -j 48 --input-size 224 -b 512 -e --weight $WEIGHT_PATH \ 6 | MODEL.MICRONETS.BLOCK DYMicroBlock \ 7 | MODEL.MICRONETS.NET_CONFIG msnx_dy6_exp4_4M_221 \ 8 | MODEL.MICRONETS.STEM_CH 4 \ 9 | MODEL.MICRONETS.STEM_GROUPS 2,2 \ 10 | MODEL.MICRONETS.STEM_DILATION 1 \ 11 | MODEL.MICRONETS.STEM_MODE spatialsepsf \ 12 | MODEL.MICRONETS.OUT_CH 640 \ 13 | MODEL.MICRONETS.DEPTHSEP True \ 14 | MODEL.MICRONETS.POINTWISE group \ 15 | MODEL.MICRONETS.DROPOUT 0.05 \ 16 | MODEL.ACTIVATION.MODULE DYShiftMax \ 17 | MODEL.ACTIVATION.ACT_MAX 2.0 \ 18 | MODEL.ACTIVATION.LINEARSE_BIAS False \ 19 | MODEL.ACTIVATION.INIT_A_BLOCK3 1.0,0.0 \ 20 | MODEL.ACTIVATION.INIT_A 1.0,1.0 \ 21 | MODEL.ACTIVATION.INIT_B 0.0,0.0 \ 22 | MODEL.ACTIVATION.REDUCTION 8 \ 23 | MODEL.MICRONETS.SHUFFLE True \ 24 | -------------------------------------------------------------------------------- /scripts/eval_micronet_m1.sh: -------------------------------------------------------------------------------- 1 | export DATA_PATH=$1/imagenet 2 | export OUTPUT_PATH=$2/micronet-m1-eval 3 | export WEIGHT_PATH=$3 4 | 5 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --arch MicroNet -d $DATA_PATH -c $OUTPUT_PATH -j 48 --input-size 224 -b 512 -e --weight $WEIGHT_PATH \ 6 | MODEL.MICRONETS.BLOCK DYMicroBlock \ 7 | 
MODEL.MICRONETS.NET_CONFIG msnx_dy6_exp6_6M_221 \ 8 | MODEL.MICRONETS.STEM_CH 6 \ 9 | MODEL.MICRONETS.STEM_GROUPS 3,2 \ 10 | MODEL.MICRONETS.STEM_DILATION 1 \ 11 | MODEL.MICRONETS.STEM_MODE spatialsepsf \ 12 | MODEL.MICRONETS.OUT_CH 960 \ 13 | MODEL.MICRONETS.DEPTHSEP True \ 14 | MODEL.MICRONETS.POINTWISE group \ 15 | MODEL.MICRONETS.DROPOUT 0.05 \ 16 | MODEL.ACTIVATION.MODULE DYShiftMax \ 17 | MODEL.ACTIVATION.ACT_MAX 2.0 \ 18 | MODEL.ACTIVATION.LINEARSE_BIAS False \ 19 | MODEL.ACTIVATION.INIT_A_BLOCK3 1.0,0.0 \ 20 | MODEL.ACTIVATION.INIT_A 1.0,1.0 \ 21 | MODEL.ACTIVATION.INIT_B 0.0,0.0 \ 22 | MODEL.ACTIVATION.REDUCTION 8 \ 23 | MODEL.MICRONETS.SHUFFLE True \ 24 | -------------------------------------------------------------------------------- /scripts/eval_micronet_m2.sh: -------------------------------------------------------------------------------- 1 | export DATA_PATH=$1/imagenet 2 | export OUTPUT_PATH=$2/micronet-m2-eval 3 | export WEIGHT_PATH=$3 4 | 5 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --arch MicroNet -d $DATA_PATH -c $OUTPUT_PATH -j 48 --input-size 224 -b 512 -e --weight $WEIGHT_PATH \ 6 | MODEL.MICRONETS.BLOCK DYMicroBlock \ 7 | MODEL.MICRONETS.NET_CONFIG msnx_dy9_exp6_12M_221 \ 8 | MODEL.MICRONETS.STEM_CH 8 \ 9 | MODEL.MICRONETS.STEM_GROUPS 4,2 \ 10 | MODEL.MICRONETS.STEM_DILATION 1 \ 11 | MODEL.MICRONETS.STEM_MODE spatialsepsf \ 12 | MODEL.MICRONETS.OUT_CH 1024 \ 13 | MODEL.MICRONETS.DEPTHSEP True \ 14 | MODEL.MICRONETS.POINTWISE group \ 15 | MODEL.MICRONETS.DROPOUT 0.1 \ 16 | MODEL.ACTIVATION.MODULE DYShiftMax \ 17 | MODEL.ACTIVATION.ACT_MAX 2.0 \ 18 | MODEL.ACTIVATION.LINEARSE_BIAS False \ 19 | MODEL.ACTIVATION.INIT_A_BLOCK3 1.0,0.0 \ 20 | MODEL.ACTIVATION.INIT_A 1.0,1.0 \ 21 | MODEL.ACTIVATION.INIT_B 0.0,0.0 \ 22 | MODEL.ACTIVATION.REDUCTION 8 \ 23 | MODEL.MICRONETS.SHUFFLE True \ 24 | -------------------------------------------------------------------------------- /scripts/eval_micronet_m3.sh: -------------------------------------------------------------------------------- 1 | export DATA_PATH=$1/imagenet 2 | export OUTPUT_PATH=$2/micronet-m3-eval 3 | export WEIGHT_PATH=$3 4 | 5 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --arch MicroNet -d $DATA_PATH -c $OUTPUT_PATH -j 48 --input-size 224 -b 512 -e --weight $WEIGHT_PATH \ 6 | MODEL.MICRONETS.BLOCK DYMicroBlock \ 7 | MODEL.MICRONETS.NET_CONFIG msnx_dy12_exp6_20M_020 \ 8 | MODEL.MICRONETS.STEM_CH 12 \ 9 | MODEL.MICRONETS.STEM_GROUPS 4,3\ 10 | MODEL.MICRONETS.STEM_DILATION 1 \ 11 | MODEL.MICRONETS.STEM_MODE spatialsepsf \ 12 | MODEL.MICRONETS.OUT_CH 1024 \ 13 | MODEL.MICRONETS.DEPTHSEP True \ 14 | MODEL.MICRONETS.POINTWISE group \ 15 | MODEL.MICRONETS.DROPOUT 0.1 \ 16 | MODEL.ACTIVATION.MODULE DYShiftMax \ 17 | MODEL.ACTIVATION.ACT_MAX 2.0 \ 18 | MODEL.ACTIVATION.LINEARSE_BIAS False \ 19 | MODEL.ACTIVATION.INIT_A_BLOCK3 1.0,0.0 \ 20 | MODEL.ACTIVATION.INIT_A 1.0,0.5 \ 21 | MODEL.ACTIVATION.INIT_B 0.0,0.5 \ 22 | MODEL.ACTIVATION.REDUCTION 8 \ 23 | MODEL.MICRONETS.SHUFFLE True \ 24 | 25 | -------------------------------------------------------------------------------- /scripts/train_micronet_m0_2gpu.sh: -------------------------------------------------------------------------------- 1 | export DATA_PATH=$1/imagenet 2 | export OUTPUT_PATH=$2/micronet-m0-2gpu 3 | 4 | CUDA_VISIBLE_DEVICES=0,1 python main.py --arch MicroNet -d $DATA_PATH --epochs 600 --lr-decay cos --lr 0.1 --wd 3e-5 \ 5 | -c $OUTPUT_PATH -j 48 --input-size 224 --label-smoothing 0.0 -b 256 \ 6 | MODEL.MICRONETS.BLOCK DYMicroBlock \ 7 | 
MODEL.MICRONETS.NET_CONFIG msnx_dy6_exp4_4M_221 \
8 | MODEL.MICRONETS.STEM_CH 4 \
9 | MODEL.MICRONETS.STEM_GROUPS 2,2 \
10 | MODEL.MICRONETS.STEM_DILATION 1 \
11 | MODEL.MICRONETS.STEM_MODE spatialsepsf \
12 | MODEL.MICRONETS.OUT_CH 640 \
13 | MODEL.MICRONETS.DEPTHSEP True \
14 | MODEL.MICRONETS.POINTWISE group \
15 | MODEL.MICRONETS.DROPOUT 0.05 \
16 | MODEL.ACTIVATION.MODULE DYShiftMax \
17 | MODEL.ACTIVATION.ACT_MAX 2.0 \
18 | MODEL.ACTIVATION.LINEARSE_BIAS False \
19 | MODEL.ACTIVATION.INIT_A_BLOCK3 1.0,0.0 \
20 | MODEL.ACTIVATION.INIT_A 1.0,1.0 \
21 | MODEL.ACTIVATION.INIT_B 0.0,0.0 \
22 | MODEL.ACTIVATION.REDUCTION 8 \
23 | MODEL.MICRONETS.SHUFFLE True \
24 |
--------------------------------------------------------------------------------
/scripts/train_micronet_m0_4gpu.sh:
--------------------------------------------------------------------------------
1 | export DATA_PATH=$1/imagenet
2 | export OUTPUT_PATH=$2/micronet-m0-4gpu
3 |
4 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --arch MicroNet -d $DATA_PATH --epochs 600 --lr-decay cos --lr 0.2 --wd 3e-5 \
5 | -c $OUTPUT_PATH -j 48 --input-size 224 --label-smoothing 0.0 -b 512 \
6 | MODEL.MICRONETS.BLOCK DYMicroBlock \
7 | MODEL.MICRONETS.NET_CONFIG msnx_dy6_exp4_4M_221 \
8 | MODEL.MICRONETS.STEM_CH 4 \
9 | MODEL.MICRONETS.STEM_GROUPS 2,2 \
10 | MODEL.MICRONETS.STEM_DILATION 1 \
11 | MODEL.MICRONETS.STEM_MODE spatialsepsf \
12 | MODEL.MICRONETS.OUT_CH 640 \
13 | MODEL.MICRONETS.DEPTHSEP True \
14 | MODEL.MICRONETS.POINTWISE group \
15 | MODEL.MICRONETS.DROPOUT 0.05 \
16 | MODEL.ACTIVATION.MODULE DYShiftMax \
17 | MODEL.ACTIVATION.ACT_MAX 2.0 \
18 | MODEL.ACTIVATION.LINEARSE_BIAS False \
19 | MODEL.ACTIVATION.INIT_A_BLOCK3 1.0,0.0 \
20 | MODEL.ACTIVATION.INIT_A 1.0,1.0 \
21 | MODEL.ACTIVATION.INIT_B 0.0,0.0 \
22 | MODEL.ACTIVATION.REDUCTION 8 \
23 | MODEL.MICRONETS.SHUFFLE True \
24 |
--------------------------------------------------------------------------------
/scripts/train_micronet_m1_2gpu.sh:
--------------------------------------------------------------------------------
1 | export DATA_PATH=$1/imagenet
2 | export OUTPUT_PATH=$2/micronet-m1-2gpu
3 |
4 | CUDA_VISIBLE_DEVICES=0,1 python main.py --arch MicroNet -d $DATA_PATH --epochs 600 --lr-decay cos --lr 0.1 --wd 3e-5 \
5 | -c $OUTPUT_PATH -j 48 --input-size 224 --label-smoothing 0.0 -b 256 \
6 | MODEL.MICRONETS.BLOCK DYMicroBlock \
7 | MODEL.MICRONETS.NET_CONFIG msnx_dy6_exp6_6M_221 \
8 | MODEL.MICRONETS.STEM_CH 6 \
9 | MODEL.MICRONETS.STEM_GROUPS 3,2 \
10 | MODEL.MICRONETS.STEM_DILATION 1 \
11 | MODEL.MICRONETS.STEM_MODE spatialsepsf \
12 | MODEL.MICRONETS.OUT_CH 960 \
13 | MODEL.MICRONETS.DEPTHSEP True \
14 | MODEL.MICRONETS.POINTWISE group \
15 | MODEL.MICRONETS.DROPOUT 0.05 \
16 | MODEL.ACTIVATION.MODULE DYShiftMax \
17 | MODEL.ACTIVATION.ACT_MAX 2.0 \
18 | MODEL.ACTIVATION.LINEARSE_BIAS False \
19 | MODEL.ACTIVATION.INIT_A_BLOCK3 1.0,0.0 \
20 | MODEL.ACTIVATION.INIT_A 1.0,1.0 \
21 | MODEL.ACTIVATION.INIT_B 0.0,0.0 \
22 | MODEL.ACTIVATION.REDUCTION 8 \
23 | MODEL.MICRONETS.SHUFFLE True \
24 |
--------------------------------------------------------------------------------
/scripts/train_micronet_m1_4gpu.sh:
--------------------------------------------------------------------------------
1 | export DATA_PATH=$1/imagenet
2 | export OUTPUT_PATH=$2/micronet-m1-4gpu
3 |
4 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --arch MicroNet -d $DATA_PATH --epochs 600 --lr-decay cos --lr 0.2 --wd 3e-5 \
5 | -c $OUTPUT_PATH -j 48 --input-size 224 --label-smoothing 0.0 -b 512 \
6 |
MODEL.MICRONETS.BLOCK DYMicroBlock \ 7 | MODEL.MICRONETS.NET_CONFIG msnx_dy6_exp6_6M_221 \ 8 | MODEL.MICRONETS.STEM_CH 6 \ 9 | MODEL.MICRONETS.STEM_GROUPS 3,2 \ 10 | MODEL.MICRONETS.STEM_DILATION 1 \ 11 | MODEL.MICRONETS.STEM_MODE spatialsepsf \ 12 | MODEL.MICRONETS.OUT_CH 960 \ 13 | MODEL.MICRONETS.DEPTHSEP True \ 14 | MODEL.MICRONETS.POINTWISE group \ 15 | MODEL.MICRONETS.DROPOUT 0.05 \ 16 | MODEL.ACTIVATION.MODULE DYShiftMax \ 17 | MODEL.ACTIVATION.ACT_MAX 2.0 \ 18 | MODEL.ACTIVATION.LINEARSE_BIAS False \ 19 | MODEL.ACTIVATION.INIT_A_BLOCK3 1.0,0.0 \ 20 | MODEL.ACTIVATION.INIT_A 1.0,1.0 \ 21 | MODEL.ACTIVATION.INIT_B 0.0,0.0 \ 22 | MODEL.ACTIVATION.REDUCTION 8 \ 23 | MODEL.MICRONETS.SHUFFLE True \ 24 | -------------------------------------------------------------------------------- /scripts/train_micronet_m2_2gpu.sh: -------------------------------------------------------------------------------- 1 | export DATA_PATH=$1/imagenet 2 | export OUTPUT_PATH=$2/micronet-m2-2gpu 3 | 4 | CUDA_VISIBLE_DEVICES=0,1 python main.py --arch MicroNet -d $DATA_PATH --epochs 600 --lr-decay cos --lr 0.1 --wd 3e-5 \ 5 | -c $OUTPUT_PATH -j 48 --input-size 224 --label-smoothing 0.0 -b 256 \ 6 | MODEL.MICRONETS.BLOCK DYMicroBlock \ 7 | MODEL.MICRONETS.NET_CONFIG msnx_dy9_exp6_12M_221 \ 8 | MODEL.MICRONETS.STEM_CH 8 \ 9 | MODEL.MICRONETS.STEM_GROUPS 4,2 \ 10 | MODEL.MICRONETS.STEM_DILATION 1 \ 11 | MODEL.MICRONETS.STEM_MODE spatialsepsf \ 12 | MODEL.MICRONETS.OUT_CH 1024 \ 13 | MODEL.MICRONETS.DEPTHSEP True \ 14 | MODEL.MICRONETS.POINTWISE group \ 15 | MODEL.MICRONETS.DROPOUT 0.1 \ 16 | MODEL.ACTIVATION.MODULE DYShiftMax \ 17 | MODEL.ACTIVATION.ACT_MAX 2.0 \ 18 | MODEL.ACTIVATION.LINEARSE_BIAS False \ 19 | MODEL.ACTIVATION.INIT_A_BLOCK3 1.0,0.0 \ 20 | MODEL.ACTIVATION.INIT_A 1.0,1.0 \ 21 | MODEL.ACTIVATION.INIT_B 0.0,0.0 \ 22 | MODEL.ACTIVATION.REDUCTION 8 \ 23 | MODEL.MICRONETS.SHUFFLE True \ 24 | -------------------------------------------------------------------------------- /scripts/train_micronet_m2_4gpu.sh: -------------------------------------------------------------------------------- 1 | export DATA_PATH=$1/imagenet 2 | export OUTPUT_PATH=$2/micronet-m2-4gpu 3 | 4 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --arch MicroNet -d $DATA_PATH --epochs 600 --lr-decay cos --lr 0.2 --wd 3e-5 \ 5 | -c $OUTPUT_PATH -j 48 --input-size 224 --label-smoothing 0.0 -b 512 \ 6 | MODEL.MICRONETS.BLOCK DYMicroBlock \ 7 | MODEL.MICRONETS.NET_CONFIG msnx_dy9_exp6_12M_221 \ 8 | MODEL.MICRONETS.STEM_CH 8 \ 9 | MODEL.MICRONETS.STEM_GROUPS 4,2 \ 10 | MODEL.MICRONETS.STEM_DILATION 1 \ 11 | MODEL.MICRONETS.STEM_MODE spatialsepsf \ 12 | MODEL.MICRONETS.OUT_CH 1024 \ 13 | MODEL.MICRONETS.DEPTHSEP True \ 14 | MODEL.MICRONETS.POINTWISE group \ 15 | MODEL.MICRONETS.DROPOUT 0.1 \ 16 | MODEL.ACTIVATION.MODULE DYShiftMax \ 17 | MODEL.ACTIVATION.ACT_MAX 2.0 \ 18 | MODEL.ACTIVATION.LINEARSE_BIAS False \ 19 | MODEL.ACTIVATION.INIT_A_BLOCK3 1.0,0.0 \ 20 | MODEL.ACTIVATION.INIT_A 1.0,1.0 \ 21 | MODEL.ACTIVATION.INIT_B 0.0,0.0 \ 22 | MODEL.ACTIVATION.REDUCTION 8 \ 23 | MODEL.MICRONETS.SHUFFLE True \ 24 | -------------------------------------------------------------------------------- /scripts/train_micronet_m3_4gpu.sh: -------------------------------------------------------------------------------- 1 | export DATA_PATH=$1/imagenet 2 | export OUTPUT_PATH=$2/micronet-m3-4gpu 3 | 4 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --arch MicroNet -d $DATA_PATH --epochs 600 --lr-decay cos --lr 0.2 --wd 4e-5 \ 5 | -c $OUTPUT_PATH -j 48 --input-size 
224 -b 512 --warmup --label-smoothing 0.0 \ 6 | MODEL.MICRONETS.BLOCK DYMicroBlock \ 7 | MODEL.MICRONETS.NET_CONFIG msnx_dy12_exp6_20M_020 \ 8 | MODEL.MICRONETS.STEM_CH 12 \ 9 | MODEL.MICRONETS.STEM_GROUPS 4,3 \ 10 | MODEL.MICRONETS.STEM_DILATION 1 \ 11 | MODEL.MICRONETS.STEM_MODE spatialsepsf \ 12 | MODEL.MICRONETS.OUT_CH 1024 \ 13 | MODEL.MICRONETS.DEPTHSEP True \ 14 | MODEL.MICRONETS.POINTWISE group \ 15 | MODEL.MICRONETS.DROPOUT 0.1 \ 16 | MODEL.ACTIVATION.MODULE DYShiftMax \ 17 | MODEL.ACTIVATION.ACT_MAX 2.0 \ 18 | MODEL.ACTIVATION.LINEARSE_BIAS False \ 19 | MODEL.ACTIVATION.INIT_A_BLOCK3 1.0,0.0 \ 20 | MODEL.ACTIVATION.INIT_A 1.0,0.5 \ 21 | MODEL.ACTIVATION.INIT_B 0.0,0.5 \ 22 | MODEL.ACTIVATION.REDUCTION 8 \ 23 | MODEL.MICRONETS.SHUFFLE True \ 24 | 25 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Useful utils 2 | """ 3 | from .misc import * 4 | from .logger import * 5 | from .visualize import * 6 | from .eval import * 7 | from .defaults import _C as cfg 8 | from .larc import * 9 | # progress bar 10 | import os, sys 11 | #sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) 12 | #from progress.bar import Bar as Bar 13 | -------------------------------------------------------------------------------- /utils/dataloaders.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import torchvision.datasets as datasets 5 | import torchvision.transforms as transforms 6 | 7 | DATA_BACKEND_CHOICES = ['pytorch'] 8 | try: 9 | from nvidia.dali.plugin.pytorch import DALIClassificationIterator 10 | from nvidia.dali.pipeline import Pipeline 11 | import nvidia.dali.ops as ops 12 | import nvidia.dali.types as types 13 | DATA_BACKEND_CHOICES.append('dali-gpu') 14 | DATA_BACKEND_CHOICES.append('dali-cpu') 15 | except ImportError: 16 | print("Please install DALI from https://www.github.com/NVIDIA/DALI to run this example.") 17 | 18 | 19 | def fast_collate(batch): 20 | imgs = [img[0] for img in batch] 21 | targets = torch.tensor([target[1] for target in batch], dtype=torch.int64) 22 | w = imgs[0].size[0] 23 | h = imgs[0].size[1] 24 | tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8 ) 25 | for i, img in enumerate(imgs): 26 | nump_array = np.asarray(img, dtype=np.uint8) 27 | # grayscale images lack a channel axis; add one before the HWC->CHW roll below 28 | if nump_array.ndim < 3: 29 | nump_array = np.expand_dims(nump_array, axis=-1) 30 | nump_array = np.rollaxis(nump_array, 2) 31 | 32 | tensor[i] += torch.from_numpy(nump_array) 33 | 34 | return tensor, targets 35 | 36 | 37 | class PrefetchedWrapper(object): 38 | def prefetched_loader(loader): # unused reference implementation; __iter__ below inlines the same logic 39 | mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1) 40 | std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1) 41 | 42 | stream = torch.cuda.Stream() 43 | first = True 44 | 45 | for next_input, next_target in loader: 46 | with torch.cuda.stream(stream): 47 | #next_input = next_input.cuda(async=True) 48 | #next_target = next_target.cuda(async=True) 49 | next_input = next_input.cuda() 50 | next_target = next_target.cuda() 51 | next_input = next_input.float() 52 | next_input = next_input.sub_(mean).div_(std) 53 | 54 | if not first: 55 | yield input, target 56 | else: 57 | first = False 58 | 59 | torch.cuda.current_stream().wait_stream(stream) 60 | input = next_input 61 | target = next_target 62 |
63 | yield input, target 64 | 65 | def __init__(self, dataloader): 66 | self.dataloader = dataloader 67 | self.epoch = 0 68 | 69 | def __iter__(self): 70 | if (self.dataloader.sampler is not None and 71 | isinstance(self.dataloader.sampler, 72 | torch.utils.data.distributed.DistributedSampler)): 73 | 74 | self.dataloader.sampler.set_epoch(self.epoch) 75 | self.epoch += 1 76 | 77 | mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1) 78 | std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1) 79 | 80 | stream = torch.cuda.Stream() 81 | first = True 82 | 83 | for next_input, next_target in self.dataloader: 84 | with torch.cuda.stream(stream): 85 | #next_input = next_input.cuda(async=True) 86 | #next_target = next_target.cuda(async=True) 87 | next_input = next_input.cuda() 88 | next_target = next_target.cuda() 89 | next_input = next_input.float() 90 | next_input = next_input.sub_(mean).div_(std) 91 | 92 | if not first: 93 | yield input, target 94 | else: 95 | first = False 96 | 97 | torch.cuda.current_stream().wait_stream(stream) 98 | input = next_input 99 | target = next_target 100 | 101 | yield input, target 102 | #return PrefetchedWrapper.prefetched_loader(self.dataloader) 103 | 104 | def get_pytorch_train_loader(data_path, batch_size, workers=5, _worker_init_fn=None, input_size=224): 105 | traindir = os.path.join(data_path, 'train') 106 | train_dataset = datasets.ImageFolder( 107 | traindir, 108 | transforms.Compose([ 109 | transforms.RandomResizedCrop(input_size), 110 | transforms.RandomHorizontalFlip(), 111 | ])) 112 | 113 | if torch.distributed.is_initialized(): 114 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 115 | else: 116 | train_sampler = None 117 | 118 | train_loader = torch.utils.data.DataLoader( 119 | train_dataset, batch_size=batch_size, shuffle=(train_sampler is None), 120 | num_workers=workers, worker_init_fn=_worker_init_fn, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate) 121 | 122 | return PrefetchedWrapper(train_loader), len(train_loader) 123 | 124 | def get_pytorch_val_loader(data_path, batch_size, workers=5, _worker_init_fn=None, input_size=224): 125 | valdir = os.path.join(data_path, 'val') 126 | val_dataset = datasets.ImageFolder( 127 | valdir, transforms.Compose([ 128 | transforms.Resize(int(input_size / 0.875)), 129 | transforms.CenterCrop(input_size), 130 | ])) 131 | 132 | if torch.distributed.is_initialized(): 133 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset) 134 | else: 135 | val_sampler = None 136 | 137 | val_loader = torch.utils.data.DataLoader( 138 | val_dataset, 139 | sampler=val_sampler, 140 | batch_size=batch_size, shuffle=False, 141 | num_workers=workers, worker_init_fn=_worker_init_fn, pin_memory=True, 142 | collate_fn=fast_collate) 143 | 144 | return PrefetchedWrapper(val_loader), len(val_loader) 145 | -------------------------------------------------------------------------------- /utils/defaults.py: -------------------------------------------------------------------------------- 1 | import os 2 | from yacs.config import CfgNode as CN 3 | 4 | 5 | # ----------------------------------------------------------------------------- 6 | # Config definition 7 | # ----------------------------------------------------------------------------- 8 | 9 | _C = CN() 10 | 11 | _C.MODEL = CN() 12 | _C.MODEL.DEVICE = "cuda" 13 | _C.MODEL.DEBUG = False # add debug flag 14 | 15 | # If the WEIGHT starts with a catalog://, like :R-50, the 
code will look for 16 | # the path in paths_catalog. Else, it will use it as the specified absolute 17 | # path 18 | _C.MODEL.WEIGHT = "" 19 | 20 | _C.MODEL.ACTIVATION = CN() 21 | _C.MODEL.ACTIVATION.MODULE = "MaxLUConv" # old for mbnetm2 "MaxLUConv" 22 | _C.MODEL.ACTIVATION.ACT_MAX = 1.0 23 | _C.MODEL.ACTIVATION.LAST_SE_OUP = False #use se-oup for the last 1x1 conv 24 | _C.MODEL.ACTIVATION.LINEARSE_BIAS = True 25 | _C.MODEL.ACTIVATION.INIT_A_BLOCK3 = [1.0, 0.0] 26 | _C.MODEL.ACTIVATION.INIT_A = [1.0, 0.0] 27 | _C.MODEL.ACTIVATION.INIT_B = [0.0, 0.0] 28 | _C.MODEL.ACTIVATION.REDUCTION = 4 29 | _C.MODEL.ACTIVATION.FC = False 30 | _C.MODEL.ACTIVATION.ACT = 'relu' 31 | 32 | _C.MODEL.MICRONETS = CN() 33 | _C.MODEL.MICRONETS.NET_CONFIG = "d12_3322_192_k5" 34 | _C.MODEL.MICRONETS.STEM_CH = 16 35 | _C.MODEL.MICRONETS.STEM_DILATION = 1 36 | _C.MODEL.MICRONETS.STEM_GROUPS = [4, 8] 37 | _C.MODEL.MICRONETS.STEM_MODE = "default" # default/max2 38 | _C.MODEL.MICRONETS.BLOCK = "MicroBlock1" 39 | _C.MODEL.MICRONETS.POINTWISE = 'group' #fft/1x1/shuffle 40 | _C.MODEL.MICRONETS.DEPTHSEP = True # YUNSHENG ADD FOR MUTUAL LEARNING 41 | _C.MODEL.MICRONETS.SHUFFLE = False 42 | _C.MODEL.MICRONETS.OUT_CH = 1024 43 | _C.MODEL.MICRONETS.DROPOUT = 0.0 44 | 45 | 46 | # ---------------------------------------------------------------------------- # 47 | # Misc options 48 | # ---------------------------------------------------------------------------- # 49 | _C.OUTPUT_DIR = "." 50 | 51 | _C.PATHS_CATALOG = os.path.join(os.path.dirname(__file__), "paths_catalog.py") 52 | -------------------------------------------------------------------------------- /utils/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import torch 3 | 4 | __all__ = ['accuracy'] 5 | 6 | def accuracy(output, target, topk=(1,)): 7 | """Computes the precision@k for the specified values of k""" 8 | with torch.no_grad(): 9 | maxk = max(topk) 10 | batch_size = target.size(0) 11 | 12 | _, pred = output.topk(maxk, 1, True, True) 13 | pred = pred.t() 14 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 15 | 16 | res = [] 17 | for k in topk: 18 | correct_k = correct[:k].reshape(-1).float().sum(0) # reshape, not view: correct[:k] can be non-contiguous after t() 19 | res.append(correct_k.mul_(100.0 / batch_size)) 20 | return res 21 | -------------------------------------------------------------------------------- /utils/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | from PIL import Image 5 | from torchvision.datasets import DatasetFolder 6 | 7 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') 8 | 9 | def has_file_allowed_extension(filename, extensions): 10 | return filename.lower().endswith(extensions) 11 | 12 | def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None): 13 | images = [] 14 | dir = os.path.expanduser(dir) 15 | if not ((extensions is None) ^ (is_valid_file is None)): 16 | raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") 17 | if extensions is not None: 18 | def is_valid_file(x): 19 | return has_file_allowed_extension(x, extensions) 20 | for target in sorted(class_to_idx.keys()): 21 | d = os.path.join(dir, target) 22 | if not os.path.isdir(d): 23 | continue 24 | for root, _, fnames in sorted(os.walk(d)): 25 | for fname in sorted(fnames): 26 | path = os.path.join(root, fname) 27 | if is_valid_file(path): 28 |
item = (path.replace(dir, ''), class_to_idx[target]) 29 | images.append(item) 30 | 31 | return images 32 | 33 | def pil_loader(path, retry=5): 34 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 35 | ri = 0 36 | while ri < retry: [... lines 37-95 are truncated in this dump ...] 96 | if sys.version_info >= (3, 5): 97 | # Faster and available in Python 3.5 and above 98 | classes = [d.name for d in os.scandir(dir) if d.is_dir()] 99 | else: 100 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 101 | classes.sort() 102 | class_to_idx = {classes[i]: i for i in range(len(classes))} 103 | return classes, class_to_idx 104 | 105 | def __getitem__(self, index): 106 | """ 107 | Args: 108 | index (int): Index 109 | 110 | Returns: 111 | tuple: (sample, target) where target is class_index of the target class. 112 | """ 113 | path, target = self.samples[index] 114 | sample = self.loader(self.root + '/' + path) 115 | if self.transform is not None: 116 | sample = self.transform(sample) 117 | if self.target_transform is not None: 118 | target = self.target_transform(target) 119 | 120 | return sample, target 121 | 122 | def __len__(self): 123 | return len(self.samples) 124 | 125 | -------------------------------------------------------------------------------- /utils/larc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn.parameter import Parameter 4 | 5 | class LARC(object): 6 | """ 7 | :class:`LARC` is a pytorch implementation of both the scaling and clipping variants of LARC, 8 | in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive 9 | local learning rate for each individual parameter. The algorithm is designed to improve 10 | convergence of large batch training. 11 | 12 | See https://arxiv.org/abs/1708.03888 for calculation of the local learning rate. 13 | 14 | In practice it modifies the gradients of parameters as a proxy for modifying the learning rate 15 | of the parameters. This design allows it to be used as a wrapper around any torch.optim Optimizer. 16 | 17 | ``` 18 | model = ... 19 | optim = torch.optim.Adam(model.parameters(), lr=...) 20 | optim = LARC(optim) 21 | ``` 22 | 23 | It can even be used in conjunction with apex.fp16_utils.FP16_Optimizer. 24 | 25 | ``` 26 | model = ... 27 | optim = torch.optim.Adam(model.parameters(), lr=...) 28 | optim = LARC(optim) 29 | optim = apex.fp16_utils.FP16_Optimizer(optim) 30 | ``` 31 | 32 | Args: 33 | optimizer: Pytorch optimizer to wrap and modify learning rate for. 34 | trust_coefficient: Trust coefficient for calculating the lr. See https://arxiv.org/abs/1708.03888 35 | clip: Decides between clipping or scaling mode of LARC. If `clip=True` the learning rate is set to `min(optimizer_lr, local_lr)` for each parameter. If `clip=False` the learning rate is set to `local_lr*optimizer_lr`.
36 | eps: epsilon kludge to help with numerical stability while calculating adaptive_lr 37 | """ 38 | 39 | def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8): 40 | self.optim = optimizer 41 | self.trust_coefficient = trust_coefficient 42 | self.eps = eps 43 | self.clip = clip 44 | 45 | def __getstate__(self): 46 | return self.optim.__getstate__() 47 | 48 | def __setstate__(self, state): 49 | self.optim.__setstate__(state) 50 | 51 | @property 52 | def state(self): 53 | return self.optim.state 54 | 55 | def __repr__(self): 56 | return self.optim.__repr__() 57 | 58 | @property 59 | def param_groups(self): 60 | return self.optim.param_groups 61 | 62 | @param_groups.setter 63 | def param_groups(self, value): 64 | self.optim.param_groups = value 65 | 66 | def state_dict(self): 67 | return self.optim.state_dict() 68 | 69 | def load_state_dict(self, state_dict): 70 | self.optim.load_state_dict(state_dict) 71 | 72 | def zero_grad(self): 73 | self.optim.zero_grad() 74 | 75 | def add_param_group(self, param_group): 76 | self.optim.add_param_group(param_group) 77 | 78 | def step(self): 79 | with torch.no_grad(): 80 | weight_decays = [] 81 | for group in self.optim.param_groups: 82 | # absorb weight decay control from optimizer 83 | weight_decay = group['weight_decay'] if 'weight_decay' in group else 0 84 | weight_decays.append(weight_decay) 85 | 86 | if 'exclude_larc' in group and group['exclude_larc']: 87 | continue 88 | 89 | group['weight_decay'] = 0 90 | for p in group['params']: 91 | if p.grad is None: 92 | continue 93 | param_norm = torch.norm(p.data) 94 | grad_norm = torch.norm(p.grad.data) 95 | 96 | if param_norm != 0 and grad_norm != 0: 97 | # calculate adaptive lr + weight decay 98 | adaptive_lr = self.trust_coefficient * (param_norm) / (grad_norm + param_norm * weight_decay + self.eps) 99 | 100 | # clip learning rate for LARC 101 | if self.clip: 102 | # calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)` 103 | adaptive_lr = min(adaptive_lr/group['lr'], 1) 104 | 105 | p.grad.data += weight_decay * p.data 106 | p.grad.data *= adaptive_lr 107 | 108 | self.optim.step() 109 | # return weight decay control to optimizer 110 | for i, group in enumerate(self.optim.param_groups): 111 | group['weight_decay'] = weight_decays[i] 112 | 113 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # A simple torch style logger 2 | # (C) Wei YANG 2017 3 | from __future__ import absolute_import 4 | import os 5 | import sys 6 | import numpy as np 7 | 8 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 9 | 10 | def savefig(fname, dpi=None): 11 | dpi = 150 if dpi is None else dpi 12 | 13 | def plot_overlap(logger, names=None): 14 | names = logger.names if names is None else names 15 | numbers = logger.numbers 16 | for _, name in enumerate(names): 17 | x = np.arange(len(numbers[name])) 18 | return [logger.title + '(' + name + ')' for name in names] 19 | 20 | class Logger(object): 21 | '''Save training process to log file with simple plot function.''' 22 | def __init__(self, fpath, title=None, resume=False): 23 | self.file = None 24 | self.resume = resume 25 | self.title = '' if title is None else title 26 | if fpath is not None: 27 | if resume: 28 | self.file = open(fpath, 'r') 29 | name = self.file.readline() 30 | self.names = name.rstrip().split('\t') 31 | self.numbers = {} 32 | for _, name in enumerate(self.names): 33 |
self.numbers[name] = [] 34 | 35 | for numbers in self.file: 36 | numbers = numbers.rstrip().split('\t') 37 | for i in range(0, len(numbers)): 38 | self.numbers[self.names[i]].append(float(numbers[i])) # parse logged values back to float, not str 39 | self.file.close() 40 | self.file = open(fpath, 'a') 41 | else: 42 | self.file = open(fpath, 'w') 43 | 44 | def set_names(self, names): 45 | if self.resume: 46 | pass 47 | # initialize numbers as empty list 48 | self.numbers = {} 49 | self.names = names 50 | for _, name in enumerate(self.names): 51 | self.file.write(name) 52 | self.file.write('\t') 53 | self.numbers[name] = [] 54 | self.file.write('\n') 55 | self.file.flush() 56 | 57 | 58 | def append(self, numbers): 59 | assert len(self.names) == len(numbers), 'Numbers do not match names' 60 | for index, num in enumerate(numbers): 61 | self.file.write("{0:.6f}".format(num)) 62 | self.file.write('\t') 63 | self.numbers[self.names[index]].append(num) 64 | self.file.write('\n') 65 | self.file.flush() 66 | 67 | def plot(self, names=None): 68 | names = self.names if names is None else names 69 | numbers = self.numbers 70 | for _, name in enumerate(names): 71 | x = np.arange(len(numbers[name])) 72 | 73 | def close(self): 74 | if self.file is not None: 75 | self.file.close() 76 | 77 | class LoggerMonitor(object): 78 | '''Load and visualize multiple logs.''' 79 | def __init__(self, paths): 80 | '''paths is a dictionary of {name: filepath} pairs''' 81 | self.loggers = [] 82 | for title, path in paths.items(): 83 | logger = Logger(path, title=title, resume=True) 84 | self.loggers.append(logger) 85 | 86 | def plot(self, names=None): 87 | legend_text = [] 88 | for logger in self.loggers: 89 | legend_text += plot_overlap(logger, names) 90 | if __name__ == '__main__': 91 | # # Example 92 | # logger = Logger('test.txt') 93 | # logger.set_names(['Train loss', 'Valid loss','Test loss']) 94 | 95 | # length = 100 96 | # t = np.arange(length) 97 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 98 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 99 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 100 | 101 | # for i in range(0, length): 102 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) 103 | # logger.plot() 104 | 105 | # Example: logger monitor 106 | paths = { 107 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', 108 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt', 109 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', 110 | } 111 | 112 | field = ['Valid Acc.'] 113 | 114 | monitor = LoggerMonitor(paths) 115 | monitor.plot(names=field) 116 | savefig('test.eps') 117 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress.
5 | ''' 6 | import errno 7 | import os 8 | import sys 9 | import time 10 | import math 11 | import torch # needed below for torch.zeros and torch.utils.data 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | from torch.autograd import Variable 15 | 16 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter'] 17 | 18 | 19 | def get_mean_and_std(dataset): 20 | '''Compute the mean and std value of dataset.''' 21 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 22 | 23 | mean = torch.zeros(3) 24 | std = torch.zeros(3) 25 | print('==> Computing mean and std..') 26 | for inputs, targets in dataloader: 27 | for i in range(3): 28 | mean[i] += inputs[:,i,:,:].mean() 29 | std[i] += inputs[:,i,:,:].std() 30 | mean.div_(len(dataset)) 31 | std.div_(len(dataset)) 32 | return mean, std 33 | 34 | def init_params(net): 35 | '''Init layer parameters.''' 36 | for m in net.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | init.kaiming_normal_(m.weight, mode='fan_out') 39 | if m.bias is not None: 40 | init.constant_(m.bias, 0) 41 | elif isinstance(m, nn.BatchNorm2d): 42 | init.constant_(m.weight, 1) 43 | init.constant_(m.bias, 0) 44 | elif isinstance(m, nn.Linear): 45 | init.normal_(m.weight, std=1e-3) 46 | if m.bias is not None: 47 | init.constant_(m.bias, 0) 48 | 49 | def mkdir_p(path): 50 | '''make dir if not exist''' 51 | try: 52 | os.makedirs(path) 53 | except OSError as exc: # Python >2.5 54 | if exc.errno == errno.EEXIST and os.path.isdir(path): 55 | pass 56 | else: 57 | raise 58 | 59 | class AverageMeter(object): 60 | """Computes and stores the average and current value 61 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 62 | """ 63 | def __init__(self): 64 | self.reset() 65 | 66 | def reset(self): 67 | self.val = 0 68 | self.avg = 0 69 | self.sum = 0 70 | self.count = 0 71 | 72 | def update(self, val, n=1): 73 | self.val = val 74 | self.sum += val * n 75 | self.count += n 76 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | import numpy as np 6 | from .misc import * 7 | 8 | __all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single'] 9 | 10 | # functions to show an image 11 | def make_image(img, mean=(0,0,0), std=(1,1,1)): 12 | for i in range(0, 3): 13 | img[i] = img[i] * std[i] + mean[i] # unnormalize 14 | npimg = img.numpy() 15 | return np.transpose(npimg, (1, 2, 0)) 16 | 17 | def gauss(x,a,b,c): 18 | return torch.exp(-torch.pow(torch.add(x,-b),2).div(2*c*c)).mul(a) 19 | 20 | def colorize(x): 21 | ''' Converts a one-channel grayscale image to a color heatmap image ''' 22 | if x.dim() == 2: 23 | x = torch.unsqueeze(x, 0) # unsqueeze returns a view; it has no out= argument 24 | if x.dim() == 3: 25 | cl = torch.zeros([3, x.size(1), x.size(2)]) 26 | cl[0] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) 27 | cl[1] = gauss(x,1,.5,.3) 28 | cl[2] = gauss(x,1,.2,.3) 29 | cl[cl.gt(1)] = 1 30 | elif x.dim() == 4: 31 | cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)]) 32 | cl[:,0,:,:] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) 33 | cl[:,1,:,:] = gauss(x,1,.5,.3) 34 | cl[:,2,:,:] = gauss(x,1,.2,.3) 35 | return cl 36 | 37 | def show_batch(images, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 38 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 39 | 40 | 41 | def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
42 | im_size = images.size(2) 43 | 44 | # save for adding mask 45 | im_data = images.clone() 46 | for i in range(0, 3): 47 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize 48 | 49 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 50 | 51 | # for b in range(mask.size(0)): 52 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) 53 | mask_size = mask.size(2) 54 | # print('Max %f Min %f' % (mask.max(), mask.min())) 55 | mask = nn.functional.interpolate(mask, scale_factor=im_size/mask_size) # 'upsampling' was undefined in this module 56 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) 57 | # for c in range(3): 58 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] 59 | 60 | # print(mask.size()) 61 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) 62 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) 63 | 64 | def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 65 | im_size = images.size(2) 66 | 67 | # save for adding mask 68 | im_data = images.clone() 69 | for i in range(0, 3): 70 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize 71 | 72 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 73 | 74 | for i in range(len(masklist)): 75 | mask = masklist[i].data.cpu() 76 | # for b in range(mask.size(0)): 77 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) 78 | mask_size = mask.size(2) 79 | # print('Max %f Min %f' % (mask.max(), mask.min())) 80 | mask = nn.functional.interpolate(mask, scale_factor=im_size/mask_size) # 'upsampling' was undefined in this module 81 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) 82 | # for c in range(3): 83 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] 84 | 85 | # print(mask.size()) 86 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) 87 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) 88 | 89 | 90 | 91 | # x = torch.zeros(1, 3, 3) 92 | # out = colorize(x) 93 | # out_im = make_image(out) 94 | --------------------------------------------------------------------------------
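The trailing `KEY VALUE` pairs in the training scripts above override the yacs defaults declared in `utils/defaults.py`. A minimal sketch of that mechanism, assuming `main.py` forwards the pairs through yacs' standard `merge_from_list` (the `opts` list below is copied from the M0 script purely for illustration; run from the repository root with yacs installed):

```
# illustrative sketch -- not a file from this repository
from utils.defaults import _C as cfg   # the CfgNode declared in utils/defaults.py

# trailing KEY VALUE arguments, as in scripts/train_micronet_m0_4gpu.sh
opts = ['MODEL.MICRONETS.STEM_CH', '4',
        'MODEL.MICRONETS.OUT_CH', '640',
        'MODEL.ACTIVATION.MODULE', 'DYShiftMax']

cfg.merge_from_list(opts)   # yacs coerces each string to the type of its declared default
assert cfg.MODEL.MICRONETS.STEM_CH == 4
cfg.freeze()                # make the config read-only before the model is built
```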
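Similarly, a toy end-to-end sketch of the `LARC` wrapper and the `accuracy` helper (the linear model, tensor shapes, and hyperparameters are invented for illustration; only the `LARC` and `accuracy` APIs come from the `utils/` sources above):

```
# illustrative sketch -- not a file from this repository
import torch
import torch.nn as nn
from utils.larc import LARC
from utils.eval import accuracy

model = nn.Linear(32, 10)                              # toy stand-in for MicroNet
base = torch.optim.SGD(model.parameters(), lr=0.2, weight_decay=3e-5)
optim = LARC(base, trust_coefficient=0.02, clip=True)  # wraps any torch.optim optimizer

x = torch.randn(8, 32)
target = torch.randint(0, 10, (8,))

output = model(x)
loss = nn.functional.cross_entropy(output, target)
optim.zero_grad()
loss.backward()
# in clip mode, each grad is rescaled by min(trust * ||w|| / (||g|| + wd * ||w|| + eps) / lr, 1)
optim.step()

top1, top5 = accuracy(output, target, topk=(1, 5))
print('top-1 %.2f%%  top-5 %.2f%%' % (top1.item(), top5.item()))
```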