├── README.md ├── count_params_flops.py ├── main.py ├── model.txt ├── models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── __init__.cpython-36.pyc │ ├── mixnet.cpython-35.pyc │ ├── mixnet.cpython-36.pyc │ ├── mixnet_builder.cpython-35.pyc │ ├── mixnet_builder.cpython-36.pyc │ ├── utils.cpython-35.pyc │ └── utils.cpython-36.pyc ├── mixnet.py ├── mixnet_builder.py └── utils.py ├── scripts.sh ├── train_cifar.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Mixconv_pytorch 2 | 3 | This repo is a PyTorch implementation of the paper from Google: [MixConv: Mixed Depthwise Convolutional Kernels](https://arxiv.org/pdf/1907.09595.pdf) 4 | 5 | This code mimics the implementation from the official repo in TensorFlow (https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet) 6 | 7 | ### Dependencies 8 | Python 3.5+ 9 | [PyTorch v1.0.0+](http://pytorch.org/) 10 | 11 | ### How to use 12 | `python train_cifar.py --lr 0.016 --batch-size 256 -a mixnet-s --dtype cifar100 --optim adam --scheduler exp --epochs 650` 13 | 14 | ## Reproduce and Results 15 | # CIFAR 100 16 | | **Network** | **Top 1** | **#Params** | **#Flops** | 17 | | ----------- | ------------ | ------------------|------------| 18 | | Mixnet-S | in progress | 2.7M (*this code*)| 3.2M (*this code*) | 19 | | Mixnet-M | in progress | 3.6M (*this code*)| 4.4M (*this code*) | 20 | | Mixnet-L | in progress | 5.8M (*this code*)| Bug issue (solved soon)| 21 | 22 | # ImageNet 23 | | **Network** | **Top 1** | **#Params** | **#Flops** | 24 | | ----------- | ------------ | ------------------|------------| 25 | | Mixnet-S | in progress | 4.1M (*this code*)| 259M (*this code*) | 26 | | Mixnet-M | in progress | 5.0M (*this code*)| 360M (*this code*) | 27 | | Mixnet-L | in progress | 7.3M (*this code*)| 580M (*this code*) | 28 | 29 | ### Discussion 30 | Currently, the accuracy is very low compared with the numbers reported in the 
paper. Scientific, rigorous, and helpful feedback on how to train MixConv properly in PyTorch is welcome. 31 | -------------------------------------------------------------------------------- /count_params_flops.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import glob 5 | import numpy as np 6 | import torch 7 | # import utils 8 | import logging 9 | import argparse 10 | import torch.nn as nn 11 | import torch.utils 12 | import torch.nn.functional as F 13 | import torchvision.datasets as dset 14 | import torch.backends.cudnn as cudnn 15 | import copy 16 | 17 | from models import MixNet 18 | from models import mixnet_builder 19 | from models import utils 20 | 21 | # from thop import profile 22 | dtype = 'imagenet' # 'imagenet' 23 | arch = 'mixnet-m' 24 | input_size = 32 25 | num_classes = 100 26 | batch_size = 2 27 | 28 | if dtype == 'imagenet': 29 | input_size = 224 30 | num_classes = 1000 31 | batch_size = 1 32 | blocks_args, global_params = mixnet_builder.get_model_params(arch) 33 | model = MixNet(input_size=input_size, num_classes=num_classes, blocks_args=blocks_args, global_params=global_params) 34 | input = torch.randn(batch_size, 3, input_size, input_size) 35 | # flops, params = profile(model, inputs=(input, ),) 36 | out = model(input) 37 | print('params= %fMB'%(model._num_params/1e6)) 38 | print('flops: %fM'%(model._num_flops/batch_size/1e6)) 39 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import glob 5 | import numpy as np 6 | import torch 7 | import utils 8 | import logging 9 | import argparse 10 | import torch.nn as nn 11 | import torch.utils 12 | import torch.nn.functional as F 13 | import torchvision.datasets as dset 14 | import torch.backends.cudnn as cudnn 15 | import copy 16 | 17 | 18 | 
-------------------------------------------------------------------------------- /model.txt: -------------------------------------------------------------------------------- 1 | Experiment dir : eval-EXP-cifar10-20190815-183831 2 | 08/15 06:38:31 PM => load data 'cifar10' 3 | Files already downloaded and verified 4 | Files already downloaded and verified 5 | 08/15 06:38:32 PM update lrs: '[150, 250, 350]' 6 | 08/15 06:38:32 PM => creating model 'mixnet-s' 7 | MixNet( 8 | (_relu): ReLU() 9 | (_mix_blocks): Sequential( 10 | (0): MixnetBlock( 11 | (_act_fn): ReLU() 12 | (_depthwise_conv): MDConv( 13 | (_convs): ModuleList( 14 | (0): Conv2dSamePadding(16, 16, kernel_size=(3, 3), stride=(1, 1), groups=16, bias=False) 15 | ) 16 | ) 17 | (_bn1): BatchNorm2d(16, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 18 | (_project_conv): GroupedConv2D( 19 | (_convs): ModuleList( 20 | (0): Conv2dSamePadding(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False) 21 | ) 22 | ) 23 | (_bn2): BatchNorm2d(16, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 24 | ) 25 | (1): MixnetBlock( 26 | (_act_fn): ReLU() 27 | (_expand_conv): GroupedConv2D( 28 | (_convs): ModuleList( 29 | (0): Conv2dSamePadding(8, 48, kernel_size=(1, 1), stride=(1, 1), bias=False) 30 | (1): Conv2dSamePadding(8, 48, kernel_size=(1, 1), stride=(1, 1), bias=False) 31 | ) 32 | ) 33 | (_bn0): BatchNorm2d(96, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 34 | (_depthwise_conv): MDConv( 35 | (_convs): ModuleList( 36 | (0): Conv2dSamePadding(96, 96, kernel_size=(3, 3), stride=(2, 2), groups=96, bias=False) 37 | ) 38 | ) 39 | (_bn1): BatchNorm2d(96, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 40 | (_project_conv): GroupedConv2D( 41 | (_convs): ModuleList( 42 | (0): Conv2dSamePadding(48, 12, kernel_size=(1, 1), stride=(1, 1), bias=False) 43 | (1): Conv2dSamePadding(48, 12, kernel_size=(1, 1), stride=(1, 1), bias=False) 44 | ) 45 | ) 46 | (_bn2): 
BatchNorm2d(24, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 47 | ) 48 | (2): MixnetBlock( 49 | (_act_fn): ReLU() 50 | (_expand_conv): GroupedConv2D( 51 | (_convs): ModuleList( 52 | (0): Conv2dSamePadding(12, 36, kernel_size=(1, 1), stride=(1, 1), bias=False) 53 | (1): Conv2dSamePadding(12, 36, kernel_size=(1, 1), stride=(1, 1), bias=False) 54 | ) 55 | ) 56 | (_bn0): BatchNorm2d(72, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 57 | (_depthwise_conv): MDConv( 58 | (_convs): ModuleList( 59 | (0): Conv2dSamePadding(72, 72, kernel_size=(3, 3), stride=(1, 1), groups=72, bias=False) 60 | ) 61 | ) 62 | (_bn1): BatchNorm2d(72, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 63 | (_project_conv): GroupedConv2D( 64 | (_convs): ModuleList( 65 | (0): Conv2dSamePadding(36, 12, kernel_size=(1, 1), stride=(1, 1), bias=False) 66 | (1): Conv2dSamePadding(36, 12, kernel_size=(1, 1), stride=(1, 1), bias=False) 67 | ) 68 | ) 69 | (_bn2): BatchNorm2d(24, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 70 | ) 71 | (3): MixnetBlock( 72 | (_act_fn): Swish( 73 | (sig): Sigmoid() 74 | ) 75 | (_expand_conv): GroupedConv2D( 76 | (_convs): ModuleList( 77 | (0): Conv2dSamePadding(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False) 78 | ) 79 | ) 80 | (_bn0): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 81 | (_depthwise_conv): MDConv( 82 | (_convs): ModuleList( 83 | (0): Conv2dSamePadding(48, 48, kernel_size=(3, 3), stride=(2, 2), groups=48, bias=False) 84 | (1): Conv2dSamePadding(48, 48, kernel_size=(5, 5), stride=(2, 2), groups=48, bias=False) 85 | (2): Conv2dSamePadding(48, 48, kernel_size=(7, 7), stride=(2, 2), groups=48, bias=False) 86 | ) 87 | ) 88 | (_bn1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 89 | (_se_reduce): GroupedConv2D( 90 | (_convs): ModuleList( 91 | (0): Conv2dSamePadding(144, 12, kernel_size=(1, 1), stride=(1, 1), 
bias=False) 92 | ) 93 | ) 94 | (_se_expand): GroupedConv2D( 95 | (_convs): ModuleList( 96 | (0): Conv2dSamePadding(12, 144, kernel_size=(1, 1), stride=(1, 1), bias=False) 97 | ) 98 | ) 99 | (sigmoid): Sigmoid() 100 | (_project_conv): GroupedConv2D( 101 | (_convs): ModuleList( 102 | (0): Conv2dSamePadding(144, 40, kernel_size=(1, 1), stride=(1, 1), bias=False) 103 | ) 104 | ) 105 | (_bn2): BatchNorm2d(40, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 106 | ) 107 | (4): MixnetBlock( 108 | (_act_fn): Swish( 109 | (sig): Sigmoid() 110 | ) 111 | (_expand_conv): GroupedConv2D( 112 | (_convs): ModuleList( 113 | (0): Conv2dSamePadding(20, 120, kernel_size=(1, 1), stride=(1, 1), bias=False) 114 | (1): Conv2dSamePadding(20, 120, kernel_size=(1, 1), stride=(1, 1), bias=False) 115 | ) 116 | ) 117 | (_bn0): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 118 | (_depthwise_conv): MDConv( 119 | (_convs): ModuleList( 120 | (0): Conv2dSamePadding(120, 120, kernel_size=(3, 3), stride=(1, 1), groups=120, bias=False) 121 | (1): Conv2dSamePadding(120, 120, kernel_size=(5, 5), stride=(1, 1), groups=120, bias=False) 122 | ) 123 | ) 124 | (_bn1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 125 | (_se_reduce): GroupedConv2D( 126 | (_convs): ModuleList( 127 | (0): Conv2dSamePadding(240, 20, kernel_size=(1, 1), stride=(1, 1), bias=False) 128 | ) 129 | ) 130 | (_se_expand): GroupedConv2D( 131 | (_convs): ModuleList( 132 | (0): Conv2dSamePadding(20, 240, kernel_size=(1, 1), stride=(1, 1), bias=False) 133 | ) 134 | ) 135 | (sigmoid): Sigmoid() 136 | (_project_conv): GroupedConv2D( 137 | (_convs): ModuleList( 138 | (0): Conv2dSamePadding(120, 20, kernel_size=(1, 1), stride=(1, 1), bias=False) 139 | (1): Conv2dSamePadding(120, 20, kernel_size=(1, 1), stride=(1, 1), bias=False) 140 | ) 141 | ) 142 | (_bn2): BatchNorm2d(40, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 143 | ) 144 | (5): 
MixnetBlock( 145 | (_act_fn): Swish( 146 | (sig): Sigmoid() 147 | ) 148 | (_expand_conv): GroupedConv2D( 149 | (_convs): ModuleList( 150 | (0): Conv2dSamePadding(20, 120, kernel_size=(1, 1), stride=(1, 1), bias=False) 151 | (1): Conv2dSamePadding(20, 120, kernel_size=(1, 1), stride=(1, 1), bias=False) 152 | ) 153 | ) 154 | (_bn0): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 155 | (_depthwise_conv): MDConv( 156 | (_convs): ModuleList( 157 | (0): Conv2dSamePadding(120, 120, kernel_size=(3, 3), stride=(1, 1), groups=120, bias=False) 158 | (1): Conv2dSamePadding(120, 120, kernel_size=(5, 5), stride=(1, 1), groups=120, bias=False) 159 | ) 160 | ) 161 | (_bn1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 162 | (_se_reduce): GroupedConv2D( 163 | (_convs): ModuleList( 164 | (0): Conv2dSamePadding(240, 20, kernel_size=(1, 1), stride=(1, 1), bias=False) 165 | ) 166 | ) 167 | (_se_expand): GroupedConv2D( 168 | (_convs): ModuleList( 169 | (0): Conv2dSamePadding(20, 240, kernel_size=(1, 1), stride=(1, 1), bias=False) 170 | ) 171 | ) 172 | (sigmoid): Sigmoid() 173 | (_project_conv): GroupedConv2D( 174 | (_convs): ModuleList( 175 | (0): Conv2dSamePadding(120, 20, kernel_size=(1, 1), stride=(1, 1), bias=False) 176 | (1): Conv2dSamePadding(120, 20, kernel_size=(1, 1), stride=(1, 1), bias=False) 177 | ) 178 | ) 179 | (_bn2): BatchNorm2d(40, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 180 | ) 181 | (6): MixnetBlock( 182 | (_act_fn): Swish( 183 | (sig): Sigmoid() 184 | ) 185 | (_expand_conv): GroupedConv2D( 186 | (_convs): ModuleList( 187 | (0): Conv2dSamePadding(20, 120, kernel_size=(1, 1), stride=(1, 1), bias=False) 188 | (1): Conv2dSamePadding(20, 120, kernel_size=(1, 1), stride=(1, 1), bias=False) 189 | ) 190 | ) 191 | (_bn0): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 192 | (_depthwise_conv): MDConv( 193 | (_convs): ModuleList( 194 | (0): 
Conv2dSamePadding(120, 120, kernel_size=(3, 3), stride=(1, 1), groups=120, bias=False) 195 | (1): Conv2dSamePadding(120, 120, kernel_size=(5, 5), stride=(1, 1), groups=120, bias=False) 196 | ) 197 | ) 198 | (_bn1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 199 | (_se_reduce): GroupedConv2D( 200 | (_convs): ModuleList( 201 | (0): Conv2dSamePadding(240, 20, kernel_size=(1, 1), stride=(1, 1), bias=False) 202 | ) 203 | ) 204 | (_se_expand): GroupedConv2D( 205 | (_convs): ModuleList( 206 | (0): Conv2dSamePadding(20, 240, kernel_size=(1, 1), stride=(1, 1), bias=False) 207 | ) 208 | ) 209 | (sigmoid): Sigmoid() 210 | (_project_conv): GroupedConv2D( 211 | (_convs): ModuleList( 212 | (0): Conv2dSamePadding(120, 20, kernel_size=(1, 1), stride=(1, 1), bias=False) 213 | (1): Conv2dSamePadding(120, 20, kernel_size=(1, 1), stride=(1, 1), bias=False) 214 | ) 215 | ) 216 | (_bn2): BatchNorm2d(40, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 217 | ) 218 | (7): MixnetBlock( 219 | (_act_fn): Swish( 220 | (sig): Sigmoid() 221 | ) 222 | (_expand_conv): GroupedConv2D( 223 | (_convs): ModuleList( 224 | (0): Conv2dSamePadding(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False) 225 | ) 226 | ) 227 | (_bn0): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 228 | (_depthwise_conv): MDConv( 229 | (_convs): ModuleList( 230 | (0): Conv2dSamePadding(80, 80, kernel_size=(3, 3), stride=(2, 2), groups=80, bias=False) 231 | (1): Conv2dSamePadding(80, 80, kernel_size=(5, 5), stride=(2, 2), groups=80, bias=False) 232 | (2): Conv2dSamePadding(80, 80, kernel_size=(7, 7), stride=(2, 2), groups=80, bias=False) 233 | ) 234 | ) 235 | (_bn1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 236 | (_se_reduce): GroupedConv2D( 237 | (_convs): ModuleList( 238 | (0): Conv2dSamePadding(240, 10, kernel_size=(1, 1), stride=(1, 1), bias=False) 239 | ) 240 | ) 241 | (_se_expand): 
GroupedConv2D( 242 | (_convs): ModuleList( 243 | (0): Conv2dSamePadding(10, 240, kernel_size=(1, 1), stride=(1, 1), bias=False) 244 | ) 245 | ) 246 | (sigmoid): Sigmoid() 247 | (_project_conv): GroupedConv2D( 248 | (_convs): ModuleList( 249 | (0): Conv2dSamePadding(120, 40, kernel_size=(1, 1), stride=(1, 1), bias=False) 250 | (1): Conv2dSamePadding(120, 40, kernel_size=(1, 1), stride=(1, 1), bias=False) 251 | ) 252 | ) 253 | (_bn2): BatchNorm2d(80, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 254 | ) 255 | (8): MixnetBlock( 256 | (_act_fn): Swish( 257 | (sig): Sigmoid() 258 | ) 259 | (_expand_conv): GroupedConv2D( 260 | (_convs): ModuleList( 261 | (0): Conv2dSamePadding(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False) 262 | ) 263 | ) 264 | (_bn0): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 265 | (_depthwise_conv): MDConv( 266 | (_convs): ModuleList( 267 | (0): Conv2dSamePadding(240, 240, kernel_size=(3, 3), stride=(1, 1), groups=240, bias=False) 268 | (1): Conv2dSamePadding(240, 240, kernel_size=(5, 5), stride=(1, 1), groups=240, bias=False) 269 | ) 270 | ) 271 | (_bn1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 272 | (_se_reduce): GroupedConv2D( 273 | (_convs): ModuleList( 274 | (0): Conv2dSamePadding(480, 20, kernel_size=(1, 1), stride=(1, 1), bias=False) 275 | ) 276 | ) 277 | (_se_expand): GroupedConv2D( 278 | (_convs): ModuleList( 279 | (0): Conv2dSamePadding(20, 480, kernel_size=(1, 1), stride=(1, 1), bias=False) 280 | ) 281 | ) 282 | (sigmoid): Sigmoid() 283 | (_project_conv): GroupedConv2D( 284 | (_convs): ModuleList( 285 | (0): Conv2dSamePadding(240, 40, kernel_size=(1, 1), stride=(1, 1), bias=False) 286 | (1): Conv2dSamePadding(240, 40, kernel_size=(1, 1), stride=(1, 1), bias=False) 287 | ) 288 | ) 289 | (_bn2): BatchNorm2d(80, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 290 | ) 291 | (9): MixnetBlock( 292 | (_act_fn): Swish( 293 | 
(sig): Sigmoid() 294 | ) 295 | (_expand_conv): GroupedConv2D( 296 | (_convs): ModuleList( 297 | (0): Conv2dSamePadding(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False) 298 | ) 299 | ) 300 | (_bn0): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 301 | (_depthwise_conv): MDConv( 302 | (_convs): ModuleList( 303 | (0): Conv2dSamePadding(240, 240, kernel_size=(3, 3), stride=(1, 1), groups=240, bias=False) 304 | (1): Conv2dSamePadding(240, 240, kernel_size=(5, 5), stride=(1, 1), groups=240, bias=False) 305 | ) 306 | ) 307 | (_bn1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 308 | (_se_reduce): GroupedConv2D( 309 | (_convs): ModuleList( 310 | (0): Conv2dSamePadding(480, 20, kernel_size=(1, 1), stride=(1, 1), bias=False) 311 | ) 312 | ) 313 | (_se_expand): GroupedConv2D( 314 | (_convs): ModuleList( 315 | (0): Conv2dSamePadding(20, 480, kernel_size=(1, 1), stride=(1, 1), bias=False) 316 | ) 317 | ) 318 | (sigmoid): Sigmoid() 319 | (_project_conv): GroupedConv2D( 320 | (_convs): ModuleList( 321 | (0): Conv2dSamePadding(240, 40, kernel_size=(1, 1), stride=(1, 1), bias=False) 322 | (1): Conv2dSamePadding(240, 40, kernel_size=(1, 1), stride=(1, 1), bias=False) 323 | ) 324 | ) 325 | (_bn2): BatchNorm2d(80, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 326 | ) 327 | (10): MixnetBlock( 328 | (_act_fn): Swish( 329 | (sig): Sigmoid() 330 | ) 331 | (_expand_conv): GroupedConv2D( 332 | (_convs): ModuleList( 333 | (0): Conv2dSamePadding(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False) 334 | (1): Conv2dSamePadding(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False) 335 | ) 336 | ) 337 | (_bn0): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 338 | (_depthwise_conv): MDConv( 339 | (_convs): ModuleList( 340 | (0): Conv2dSamePadding(160, 160, kernel_size=(3, 3), stride=(1, 1), groups=160, bias=False) 341 | (1): Conv2dSamePadding(160, 160, 
kernel_size=(5, 5), stride=(1, 1), groups=160, bias=False) 342 | (2): Conv2dSamePadding(160, 160, kernel_size=(7, 7), stride=(1, 1), groups=160, bias=False) 343 | ) 344 | ) 345 | (_bn1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 346 | (_se_reduce): GroupedConv2D( 347 | (_convs): ModuleList( 348 | (0): Conv2dSamePadding(480, 40, kernel_size=(1, 1), stride=(1, 1), bias=False) 349 | ) 350 | ) 351 | (_se_expand): GroupedConv2D( 352 | (_convs): ModuleList( 353 | (0): Conv2dSamePadding(40, 480, kernel_size=(1, 1), stride=(1, 1), bias=False) 354 | ) 355 | ) 356 | (sigmoid): Sigmoid() 357 | (_project_conv): GroupedConv2D( 358 | (_convs): ModuleList( 359 | (0): Conv2dSamePadding(240, 60, kernel_size=(1, 1), stride=(1, 1), bias=False) 360 | (1): Conv2dSamePadding(240, 60, kernel_size=(1, 1), stride=(1, 1), bias=False) 361 | ) 362 | ) 363 | (_bn2): BatchNorm2d(120, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 364 | ) 365 | (11): MixnetBlock( 366 | (_act_fn): Swish( 367 | (sig): Sigmoid() 368 | ) 369 | (_expand_conv): GroupedConv2D( 370 | (_convs): ModuleList( 371 | (0): Conv2dSamePadding(60, 180, kernel_size=(1, 1), stride=(1, 1), bias=False) 372 | (1): Conv2dSamePadding(60, 180, kernel_size=(1, 1), stride=(1, 1), bias=False) 373 | ) 374 | ) 375 | (_bn0): BatchNorm2d(360, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 376 | (_depthwise_conv): MDConv( 377 | (_convs): ModuleList( 378 | (0): Conv2dSamePadding(90, 90, kernel_size=(3, 3), stride=(1, 1), groups=90, bias=False) 379 | (1): Conv2dSamePadding(90, 90, kernel_size=(5, 5), stride=(1, 1), groups=90, bias=False) 380 | (2): Conv2dSamePadding(90, 90, kernel_size=(7, 7), stride=(1, 1), groups=90, bias=False) 381 | (3): Conv2dSamePadding(90, 90, kernel_size=(9, 9), stride=(1, 1), groups=90, bias=False) 382 | ) 383 | ) 384 | (_bn1): BatchNorm2d(360, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 385 | (_se_reduce): GroupedConv2D( 386 
| (_convs): ModuleList( 387 | (0): Conv2dSamePadding(360, 60, kernel_size=(1, 1), stride=(1, 1), bias=False) 388 | ) 389 | ) 390 | (_se_expand): GroupedConv2D( 391 | (_convs): ModuleList( 392 | (0): Conv2dSamePadding(60, 360, kernel_size=(1, 1), stride=(1, 1), bias=False) 393 | ) 394 | ) 395 | (sigmoid): Sigmoid() 396 | (_project_conv): GroupedConv2D( 397 | (_convs): ModuleList( 398 | (0): Conv2dSamePadding(180, 60, kernel_size=(1, 1), stride=(1, 1), bias=False) 399 | (1): Conv2dSamePadding(180, 60, kernel_size=(1, 1), stride=(1, 1), bias=False) 400 | ) 401 | ) 402 | (_bn2): BatchNorm2d(120, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 403 | ) 404 | (12): MixnetBlock( 405 | (_act_fn): Swish( 406 | (sig): Sigmoid() 407 | ) 408 | (_expand_conv): GroupedConv2D( 409 | (_convs): ModuleList( 410 | (0): Conv2dSamePadding(60, 180, kernel_size=(1, 1), stride=(1, 1), bias=False) 411 | (1): Conv2dSamePadding(60, 180, kernel_size=(1, 1), stride=(1, 1), bias=False) 412 | ) 413 | ) 414 | (_bn0): BatchNorm2d(360, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 415 | (_depthwise_conv): MDConv( 416 | (_convs): ModuleList( 417 | (0): Conv2dSamePadding(90, 90, kernel_size=(3, 3), stride=(1, 1), groups=90, bias=False) 418 | (1): Conv2dSamePadding(90, 90, kernel_size=(5, 5), stride=(1, 1), groups=90, bias=False) 419 | (2): Conv2dSamePadding(90, 90, kernel_size=(7, 7), stride=(1, 1), groups=90, bias=False) 420 | (3): Conv2dSamePadding(90, 90, kernel_size=(9, 9), stride=(1, 1), groups=90, bias=False) 421 | ) 422 | ) 423 | (_bn1): BatchNorm2d(360, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 424 | (_se_reduce): GroupedConv2D( 425 | (_convs): ModuleList( 426 | (0): Conv2dSamePadding(360, 60, kernel_size=(1, 1), stride=(1, 1), bias=False) 427 | ) 428 | ) 429 | (_se_expand): GroupedConv2D( 430 | (_convs): ModuleList( 431 | (0): Conv2dSamePadding(60, 360, kernel_size=(1, 1), stride=(1, 1), bias=False) 432 | ) 433 | ) 434 | (sigmoid): 
Sigmoid() 435 | (_project_conv): GroupedConv2D( 436 | (_convs): ModuleList( 437 | (0): Conv2dSamePadding(180, 60, kernel_size=(1, 1), stride=(1, 1), bias=False) 438 | (1): Conv2dSamePadding(180, 60, kernel_size=(1, 1), stride=(1, 1), bias=False) 439 | ) 440 | ) 441 | (_bn2): BatchNorm2d(120, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 442 | ) 443 | (13): MixnetBlock( 444 | (_act_fn): Swish( 445 | (sig): Sigmoid() 446 | ) 447 | (_expand_conv): GroupedConv2D( 448 | (_convs): ModuleList( 449 | (0): Conv2dSamePadding(120, 720, kernel_size=(1, 1), stride=(1, 1), bias=False) 450 | ) 451 | ) 452 | (_bn0): BatchNorm2d(720, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 453 | (_depthwise_conv): MDConv( 454 | (_convs): ModuleList( 455 | (0): Conv2dSamePadding(144, 144, kernel_size=(3, 3), stride=(2, 2), groups=144, bias=False) 456 | (1): Conv2dSamePadding(144, 144, kernel_size=(5, 5), stride=(2, 2), groups=144, bias=False) 457 | (2): Conv2dSamePadding(144, 144, kernel_size=(7, 7), stride=(2, 2), groups=144, bias=False) 458 | (3): Conv2dSamePadding(144, 144, kernel_size=(9, 9), stride=(2, 2), groups=144, bias=False) 459 | (4): Conv2dSamePadding(144, 144, kernel_size=(11, 11), stride=(2, 2), groups=144, bias=False) 460 | ) 461 | ) 462 | (_bn1): BatchNorm2d(720, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 463 | (_se_reduce): GroupedConv2D( 464 | (_convs): ModuleList( 465 | (0): Conv2dSamePadding(720, 60, kernel_size=(1, 1), stride=(1, 1), bias=False) 466 | ) 467 | ) 468 | (_se_expand): GroupedConv2D( 469 | (_convs): ModuleList( 470 | (0): Conv2dSamePadding(60, 720, kernel_size=(1, 1), stride=(1, 1), bias=False) 471 | ) 472 | ) 473 | (sigmoid): Sigmoid() 474 | (_project_conv): GroupedConv2D( 475 | (_convs): ModuleList( 476 | (0): Conv2dSamePadding(720, 200, kernel_size=(1, 1), stride=(1, 1), bias=False) 477 | ) 478 | ) 479 | (_bn2): BatchNorm2d(200, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 480 | 
) 481 | (14): MixnetBlock( 482 | (_act_fn): Swish( 483 | (sig): Sigmoid() 484 | ) 485 | (_expand_conv): GroupedConv2D( 486 | (_convs): ModuleList( 487 | (0): Conv2dSamePadding(200, 1200, kernel_size=(1, 1), stride=(1, 1), bias=False) 488 | ) 489 | ) 490 | (_bn0): BatchNorm2d(1200, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 491 | (_depthwise_conv): MDConv( 492 | (_convs): ModuleList( 493 | (0): Conv2dSamePadding(300, 300, kernel_size=(3, 3), stride=(1, 1), groups=300, bias=False) 494 | (1): Conv2dSamePadding(300, 300, kernel_size=(5, 5), stride=(1, 1), groups=300, bias=False) 495 | (2): Conv2dSamePadding(300, 300, kernel_size=(7, 7), stride=(1, 1), groups=300, bias=False) 496 | (3): Conv2dSamePadding(300, 300, kernel_size=(9, 9), stride=(1, 1), groups=300, bias=False) 497 | ) 498 | ) 499 | (_bn1): BatchNorm2d(1200, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 500 | (_se_reduce): GroupedConv2D( 501 | (_convs): ModuleList( 502 | (0): Conv2dSamePadding(1200, 100, kernel_size=(1, 1), stride=(1, 1), bias=False) 503 | ) 504 | ) 505 | (_se_expand): GroupedConv2D( 506 | (_convs): ModuleList( 507 | (0): Conv2dSamePadding(100, 1200, kernel_size=(1, 1), stride=(1, 1), bias=False) 508 | ) 509 | ) 510 | (sigmoid): Sigmoid() 511 | (_project_conv): GroupedConv2D( 512 | (_convs): ModuleList( 513 | (0): Conv2dSamePadding(600, 100, kernel_size=(1, 1), stride=(1, 1), bias=False) 514 | (1): Conv2dSamePadding(600, 100, kernel_size=(1, 1), stride=(1, 1), bias=False) 515 | ) 516 | ) 517 | (_bn2): BatchNorm2d(200, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 518 | ) 519 | (15): MixnetBlock( 520 | (_act_fn): Swish( 521 | (sig): Sigmoid() 522 | ) 523 | (_expand_conv): GroupedConv2D( 524 | (_convs): ModuleList( 525 | (0): Conv2dSamePadding(200, 1200, kernel_size=(1, 1), stride=(1, 1), bias=False) 526 | ) 527 | ) 528 | (_bn0): BatchNorm2d(1200, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 529 | (_depthwise_conv): 
MDConv( 530 | (_convs): ModuleList( 531 | (0): Conv2dSamePadding(300, 300, kernel_size=(3, 3), stride=(1, 1), groups=300, bias=False) 532 | (1): Conv2dSamePadding(300, 300, kernel_size=(5, 5), stride=(1, 1), groups=300, bias=False) 533 | (2): Conv2dSamePadding(300, 300, kernel_size=(7, 7), stride=(1, 1), groups=300, bias=False) 534 | (3): Conv2dSamePadding(300, 300, kernel_size=(9, 9), stride=(1, 1), groups=300, bias=False) 535 | ) 536 | ) 537 | (_bn1): BatchNorm2d(1200, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 538 | (_se_reduce): GroupedConv2D( 539 | (_convs): ModuleList( 540 | (0): Conv2dSamePadding(1200, 100, kernel_size=(1, 1), stride=(1, 1), bias=False) 541 | ) 542 | ) 543 | (_se_expand): GroupedConv2D( 544 | (_convs): ModuleList( 545 | (0): Conv2dSamePadding(100, 1200, kernel_size=(1, 1), stride=(1, 1), bias=False) 546 | ) 547 | ) 548 | (sigmoid): Sigmoid() 549 | (_project_conv): GroupedConv2D( 550 | (_convs): ModuleList( 551 | (0): Conv2dSamePadding(600, 100, kernel_size=(1, 1), stride=(1, 1), bias=False) 552 | (1): Conv2dSamePadding(600, 100, kernel_size=(1, 1), stride=(1, 1), bias=False) 553 | ) 554 | ) 555 | (_bn2): BatchNorm2d(200, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 556 | ) 557 | ) 558 | (_conv_stem): GroupedConv2D( 559 | (_convs): ModuleList( 560 | (0): Conv2dSamePadding(3, 16, kernel_size=(3, 3), stride=(2, 2), bias=False) 561 | ) 562 | ) 563 | (_bn0): BatchNorm2d(16, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 564 | (_conv_head): GroupedConv2D( 565 | (_convs): ModuleList( 566 | (0): Conv2dSamePadding(200, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False) 567 | ) 568 | ) 569 | (_bn1): BatchNorm2d(1536, eps=0.001, momentum=0.99, affine=True, track_running_stats=True) 570 | (avgpool): AvgPool2d(kernel_size=1, stride=1, padding=0) 571 | (classifier): Linear(in_features=1536, out_features=10, bias=True) 572 | (dropout): Dropout(p=0.2) 573 | ) 574 | 
-------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mixnet import MixNet 2 | from .mixnet_builder import * 3 | from utils import * -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haithanhp/mixconv_pytorch/cc306a60682b13d5ee881e1a8f16821c8bf7e3b6/models/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haithanhp/mixconv_pytorch/cc306a60682b13d5ee881e1a8f16821c8bf7e3b6/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/mixnet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haithanhp/mixconv_pytorch/cc306a60682b13d5ee881e1a8f16821c8bf7e3b6/models/__pycache__/mixnet.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/mixnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haithanhp/mixconv_pytorch/cc306a60682b13d5ee881e1a8f16821c8bf7e3b6/models/__pycache__/mixnet.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/mixnet_builder.cpython-35.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/haithanhp/mixconv_pytorch/cc306a60682b13d5ee881e1a8f16821c8bf7e3b6/models/__pycache__/mixnet_builder.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/mixnet_builder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haithanhp/mixconv_pytorch/cc306a60682b13d5ee881e1a8f16821c8bf7e3b6/models/__pycache__/mixnet_builder.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/utils.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haithanhp/mixconv_pytorch/cc306a60682b13d5ee881e1a8f16821c8bf7e3b6/models/__pycache__/utils.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haithanhp/mixconv_pytorch/cc306a60682b13d5ee881e1a8f16821c8bf7e3b6/models/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /models/mixnet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | # from utils import swish 7 | from .utils import * 8 | 9 | 10 | # from utils import Conv2dSamePadding 11 | 12 | NON_LINEARITY = { 13 | 'ReLU': nn.ReLU(), 14 | 'Swish': Swish(), 15 | } 16 | 17 | class GroupedConv2D(nn.Module): 18 | def __init__(self, in_filters, out_filters, kernel_size, strides=1): 19 | super(GroupedConv2D, self).__init__() 20 | self._groups = len(kernel_size) 21 | self._convs = nn.ModuleList() 22 | self._num_params = 0 23 | self._num_flops = 0 24 | splits = 
splitFilters(out_filters, self._groups) 25 | inp_splits = splitFilters(in_filters, self._groups) 26 | for i in range(self._groups): 27 | # in_filters = in_filters if i==0 else splits[i] 28 | in_filters = inp_splits[i] 29 | self._convs.append( 30 | Conv2dSamePadding(in_channels=in_filters, 31 | out_channels=splits[i], 32 | groups=1, 33 | kernel_size=kernel_size[i], 34 | stride=strides, 35 | bias=False) 36 | ) 37 | self._num_params += self._convs[i]._num_params 38 | 39 | def forward(self, x): 40 | if len(self._convs)==1: 41 | x = self._convs[0](x) 42 | self._num_flops += self._convs[0]._num_flops 43 | return x 44 | filters = x.size(1) 45 | splits = splitFilters(filters, len(self._convs)) 46 | x_splits = torch.split(x, splits, dim=1) 47 | x_outs = [c(x) for x, c in zip(x_splits, self._convs)] 48 | for c in self._convs: 49 | self._num_flops += c._num_flops 50 | 51 | x = torch.cat(x_outs, dim=1) 52 | return x 53 | 54 | 55 | class MDConv(nn.Module): 56 | def __init__(self, filters, kernel_size, strides=1, dilated=False): 57 | super(MDConv, self).__init__() 58 | self._dilated = dilated 59 | self._convs = nn.ModuleList() 60 | self._groups = len(kernel_size) 61 | self._num_params = 0 62 | self._num_flops = 0 63 | splits = splitFilters(filters, self._groups) 64 | for i in range(self._groups): 65 | self._convs.append( 66 | Conv2dSamePadding(in_channels=splits[i], 67 | out_channels=splits[i], 68 | groups=splits[i], 69 | kernel_size=kernel_size[i], 70 | stride=strides, 71 | bias=False) 72 | ) 73 | self._num_params += self._convs[i]._num_params 74 | 75 | def forward(self, x): 76 | if self._groups == 1: 77 | x = self._convs[0](x) 78 | self._num_flops += self._convs[0]._num_flops 79 | return x 80 | 81 | filters = x.size(1) 82 | splits = splitFilters(filters, len(self._convs)) 83 | x_splits = torch.split(x, splits, dim=1) 84 | x_outs = [c(x) for x, c in zip(x_splits, self._convs)] 85 | for c in self._convs: 86 | self._num_flops += c._num_flops 87 | 88 | x = torch.cat(x_outs, dim=1) 
def __init__(self, block_args, global_params):
    """Create the layers of one MixNet block.

    Layers, in the order they are built: optional expansion conv + BN
    (only when ``expand_ratio != 1``), mixed depthwise conv + BN, optional
    squeeze-and-excite pair, projection conv + BN.  ``self._num_params``
    accumulates the parameter counts reported by the sub-layers plus
    ``4 * channels`` for each batch norm.

    Args:
        block_args: per-block settings.  Fields read here: input_filters,
            output_filters, expand_ratio, expand_ksize, dw_ksize,
            project_ksize, strides, se_ratio, swish, dilated.
        global_params: model-wide settings.  Fields read here:
            bn_momentum, bn_eps, data_format.
    """
    super().__init__()
    self._block_args = block_args
    self._bn_momentum = global_params.bn_momentum
    self._bn_eps = global_params.bn_eps
    self._data_format = global_params.data_format
    # data_format is stored but not consulted: spatial dims are fixed.
    self._spatial_dims = (2, 3)
    self._has_se = (self._block_args.se_ratio is not None) and (
        self._block_args.se_ratio > 0) and (self._block_args.se_ratio <= 1)
    non_linear = 'Swish' if self._block_args.swish else 'ReLU'
    self._act_fn = NON_LINEARITY[non_linear]
    self._num_params = 0
    self._num_flops = 0

    # The builders supply bn_momentum as a decay (0.99 = keep 99% of the
    # running statistic).  nn.BatchNorm2d's ``momentum`` is the weight of
    # the *new* batch statistic, so the value has to be complemented.
    torch_momentum = self._bn_momentum
    if torch_momentum is not None:
        torch_momentum = 1.0 - torch_momentum

    def _batch_norm(channels):
        return nn.BatchNorm2d(num_features=channels,
                              momentum=torch_momentum,
                              eps=self._bn_eps)

    inp = self._block_args.input_filters
    filters = self._block_args.input_filters * self._block_args.expand_ratio

    if self._block_args.expand_ratio != 1:
        # Expansion component.
        # NOTE(review): _bn0 is kept inside this branch because forward()
        # only uses it together with _expand_conv -- confirm against the
        # original layout.
        self._expand_conv = GroupedConv2D(inp,
                                          filters,
                                          self._block_args.expand_ksize)
        self._num_params += self._expand_conv._num_params

        self._bn0 = _batch_norm(filters)
        self._num_params += 4 * filters

    # Depth-wise component.
    self._depthwise_conv = MDConv(filters, self._block_args.dw_ksize,
                                  self._block_args.strides[0],
                                  dilated=self._block_args.dilated)
    self._num_params += self._depthwise_conv._num_params

    self._bn1 = _batch_norm(filters)
    self._num_params += 4 * filters

    # Squeeze-and-excite component; the bottleneck width is derived from
    # the block's input filters, not from the expanded width.
    if self._has_se:
        num_reduced_filters = max(
            1, int(self._block_args.input_filters * self._block_args.se_ratio))
        self._se_reduce = GroupedConv2D(filters,
                                        num_reduced_filters,
                                        [1])
        self._num_params += self._se_reduce._num_params
        self._se_expand = GroupedConv2D(num_reduced_filters,
                                        filters,
                                        [1])
        self._num_params += self._se_expand._num_params
        self.sigmoid = nn.Sigmoid()

    # Output projection.
    inp = filters
    filters = self._block_args.output_filters
    self._project_conv = GroupedConv2D(inp,
                                       filters,
                                       self._block_args.project_ksize)
    self._num_params += self._project_conv._num_params
    self._bn2 = _batch_norm(filters)
    self._num_params += 4 * filters
def __init__(self, input_size=224, num_classes=1000, blocks_args=None, global_params=None):
    """Assemble the MixNet stem, block stack, head and classifier.

    Args:
        input_size: side length used to size the final average pool
            (kernel ``input_size // 32``).
        num_classes: number of classifier outputs.
        blocks_args: list of BlockArgs, one entry per stage; each stage is
            repeated ``num_repeat`` times.
        global_params: GlobalParams; bn_momentum, bn_eps, stem_size,
            feature_size and dropout_rate are read here, and the whole
            tuple is passed on to ``roundFilters`` and ``MixnetBlock``.

    Raises:
        ValueError: if ``blocks_args`` is not a list.
    """
    super(MixNet, self).__init__()
    if not isinstance(blocks_args, list):
        raise ValueError('blocks_args should be a list.')
    self._global_params = global_params
    self._blocks_args = blocks_args
    self._relu = nn.ReLU()
    self._num_params = 0
    self._num_flops = 0

    blocks = []
    for block_args in self._blocks_args:
        assert block_args.num_repeat > 0
        block_args = block_args._replace(
            input_filters=roundFilters(block_args.input_filters, self._global_params),
            output_filters=roundFilters(block_args.output_filters, self._global_params)
        )
        blocks.append(MixnetBlock(block_args, self._global_params))
        self._num_params += blocks[-1]._num_params

        if block_args.num_repeat > 1:
            # Repeats keep the channel width and never downsample again.
            block_args = block_args._replace(
                input_filters=block_args.output_filters, strides=[1, 1])
        for _ in range(block_args.num_repeat - 1):
            blocks.append(MixnetBlock(block_args, self._global_params))
            self._num_params += blocks[-1]._num_params

    self._mix_blocks = nn.Sequential(*blocks)
    self._bn_momentum = global_params.bn_momentum
    self._bn_eps = global_params.bn_eps

    # The builders supply bn_momentum as a decay (0.99 = keep 99% of the
    # running statistic).  nn.BatchNorm2d's ``momentum`` is the weight of
    # the *new* batch statistic, so the value has to be complemented.
    torch_momentum = self._bn_momentum
    if torch_momentum is not None:
        torch_momentum = 1.0 - torch_momentum

    # Stem component
    filters = roundFilters(self._global_params.stem_size, self._global_params)
    self._conv_stem = GroupedConv2D(3,
                                    filters,
                                    [3],
                                    2)
    self._num_params += self._conv_stem._num_params
    self._bn0 = nn.BatchNorm2d(num_features=filters,
                               momentum=torch_momentum,
                               eps=self._bn_eps)
    self._num_params += 4 * filters

    # Head component.  The last block emits the *rounded* filter count, so
    # the head must be sized from the rounded value too; using the raw
    # value breaks every model built with a depth multiplier.
    feature_size = self._global_params.feature_size
    head_inputs = roundFilters(self._blocks_args[-1].output_filters,
                               self._global_params)
    self._conv_head = GroupedConv2D(head_inputs,
                                    feature_size,
                                    [1],
                                    1)
    self._num_params += self._conv_head._num_params
    self._bn1 = nn.BatchNorm2d(num_features=feature_size,
                               momentum=torch_momentum,
                               eps=self._bn_eps)
    self._num_params += 4 * feature_size

    self.avgpool = nn.AvgPool2d(input_size // 32, stride=1)
    self.classifier = nn.Linear(feature_size, num_classes)

    if self._global_params.dropout_rate > 0:
        self.dropout = nn.Dropout(self._global_params.dropout_rate)
    else:
        self.dropout = None
    self._num_params += feature_size * num_classes

    self._initialize_weights()
| t = x[0] 327 | nelements = t.numel() 328 | self._num_flops += nelements 329 | 330 | # print('do mix blocks') 331 | x = self._mix_blocks(x) 332 | # print('do conv head') 333 | for block in self._mix_blocks: 334 | self._num_flops += block._num_flops 335 | 336 | x = self._conv_head(x) 337 | self._num_flops += self._conv_head._num_flops 338 | 339 | x = self._bn1(x) 340 | t = x[0] 341 | nelements = t.numel() 342 | self._num_flops += nelements 343 | 344 | # print('do avg pooling') 345 | t = x.clone() 346 | x = self.avgpool(x) 347 | total_add = torch.prod(torch.Tensor([self.avgpool.kernel_size])) 348 | total_div = 1 349 | kernel_ops = total_add + total_div 350 | num_elements = t.numel() 351 | self._num_flops += kernel_ops * num_elements 352 | 353 | # print('do dropout') 354 | if self.dropout: 355 | x = self.dropout(x) 356 | x = x.view(x.size(0), -1) 357 | t = x.clone() 358 | x = self.classifier(x) 359 | total_mul = self.classifier.in_features 360 | total_add = self.classifier.in_features - 1 361 | num_elements = x.numel() 362 | self._num_flops += (total_mul + total_add) * num_elements 363 | return x 364 | 365 | def _initialize_weights(self): 366 | for m in self.modules(): 367 | # if isinstance(m, nn.Conv2d): 368 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 369 | # m.weight.data.normal_(0, math.sqrt(2.0 / n)) 370 | # if m.bias is not None: 371 | # m.bias.data.zero_() 372 | if isinstance(m, nn.BatchNorm2d): 373 | m.weight.data.fill_(1) 374 | m.bias.data.zero_() 375 | elif isinstance(m, nn.Linear): 376 | n = m.weight.size(1) 377 | init_range = 1.0 / np.sqrt(n) 378 | # m.weight.data.normal_(0, 0.01) 379 | m.weight.data.uniform_(init_range, init_range) 380 | m.bias.data.zero_() -------------------------------------------------------------------------------- /models/mixnet_builder.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | import re 3 | 4 | class MixnetDecoder(object): 5 | """A class of 
def _decode_block_string(self, block_string):
    """Turn one block notation string into a BlockArgs instance.

    The string is a '_'-separated list of tokens, for example
    ``r2_k3.5_a1_p1_s22_e6_i32_o16_se0.25_noskip``.  A token consisting of
    a prefix followed by a value that starts with a digit is stored as an
    option:

        r  - number of repeats
        k  - depthwise kernel sizes, '.'-separated
        a  - expansion kernel sizes, '.'-separated
        p  - projection kernel sizes, '.'-separated
        s  - strides, exactly two digits
        e  - expansion ratio
        i  - input filters
        o  - output filters
        se - squeeze/excitation ratio (optional)

    The flags ``noskip``, ``sw`` and ``dilated`` are detected by substring
    search on the whole string.

    Args:
        block_string: a string, a string representation of block arguments.

    Returns:
        A BlockArgs instance.
    Raises:
        ValueError: if the strides option is not correctly specified.
    """
    assert isinstance(block_string, str)

    parsed = {}
    for token in block_string.split('_'):
        pieces = re.split(r'(\d.*)', token)
        if len(pieces) < 2:
            # Flag tokens contain no digit and therefore do not split.
            continue
        parsed[pieces[0]] = pieces[1]

    stride_spec = parsed.get('s')
    if stride_spec is None or len(stride_spec) != 2:
        raise ValueError('Strides options should be a pair of integers.')

    def _ksizes(spec):
        sizes = []
        for part in spec.split('.'):
            sizes.append(int(part))
        return sizes

    # Values are extracted in the same order as the keyword arguments
    # below, so a malformed string fails on the same option as before.
    expand_ksize = _ksizes(parsed['a'])
    dw_ksize = _ksizes(parsed['k'])
    project_ksize = _ksizes(parsed['p'])
    num_repeat = int(parsed['r'])
    input_filters = int(parsed['i'])
    output_filters = int(parsed['o'])
    expand_ratio = int(parsed['e'])
    id_skip = 'noskip' not in block_string
    se_ratio = None
    if 'se' in parsed:
        se_ratio = float(parsed['se'])
    strides = [int(stride_spec[0]), int(stride_spec[1])]
    use_swish = 'sw' in block_string
    is_dilated = 'dilated' in block_string

    return BlockArgs(
        expand_ksize=expand_ksize,
        dw_ksize=dw_ksize,
        project_ksize=project_ksize,
        num_repeat=num_repeat,
        input_filters=input_filters,
        output_filters=output_filters,
        expand_ratio=expand_ratio,
        id_skip=id_skip,
        se_ratio=se_ratio,
        strides=strides,
        swish=use_swish,
        dilated=is_dilated)
def mixnet_s(depth_multiplier=None):
    """Creates mixnet-s model.

    Args:
        depth_multiplier: multiplier to number of filters per layer.

    Returns:
        blocks_args: a list of BlocksArgs for internal Mixnet blocks.
        global_params: GlobalParams, global parameters for the model.
    """
    # One notation string per stage; see MixnetDecoder for the format.
    stage_strings = [
        'r1_k3_a1_p1_s11_e1_i16_o16',
        'r1_k3_a1.1_p1.1_s22_e6_i16_o24',
        'r1_k3_a1.1_p1.1_s11_e3_i24_o24',

        'r1_k3.5.7_a1_p1_s22_e6_i24_o40_se0.5_sw',
        'r3_k3.5_a1.1_p1.1_s11_e6_i40_o40_se0.5_sw',

        'r1_k3.5.7_a1_p1.1_s22_e6_i40_o80_se0.25_sw',
        'r2_k3.5_a1_p1.1_s11_e6_i80_o80_se0.25_sw',

        'r1_k3.5.7_a1.1_p1.1_s11_e6_i80_o120_se0.5_sw',
        'r2_k3.5.7.9_a1.1_p1.1_s11_e3_i120_o120_se0.5_sw',

        'r1_k3.5.7.9.11_a1_p1_s22_e6_i120_o200_se0.5_sw',
        'r2_k3.5.7.9_a1_p1.1_s11_e6_i200_o200_se0.5_sw',
    ]

    settings = dict(
        bn_momentum=0.99,
        bn_eps=1e-3,
        dropout_rate=0.2,
        data_format='channels_last',
        num_classes=1000,
        depth_multiplier=depth_multiplier,
        depth_divisor=8,
        min_depth=None,
        stem_size=16,
        feature_size=1536,
    )
    global_params = GlobalParams(**settings)

    decoded_blocks = MixnetDecoder().decode(stage_strings)
    return decoded_blocks, global_params
def get_model_params(model_name):
    """Get the block args and global params for a given model.

    The name is matched case-insensitively and underscores are treated as
    hyphens, so ``'mixnet_s'`` and ``'mixnet-s'`` select the same model.

    Args:
        model_name: one of 'mixnet-s', 'mixnet-m', 'mixnet-l'.

    Returns:
        A ``(blocks_args, global_params)`` tuple from the matching builder.

    Raises:
        NotImplementedError: if the name matches no pre-defined model.
    """
    key = model_name
    if isinstance(key, str):
        key = key.lower().replace('_', '-')

    if key == 'mixnet-s':
        blocks_args, global_params = mixnet_s()
    elif key == 'mixnet-m':
        blocks_args, global_params = mixnet_m()
    elif key == 'mixnet-l':
        blocks_args, global_params = mixnet_l()
    else:
        raise NotImplementedError('model name is not pre-defined: %s' % model_name)

    return blocks_args, global_params
class Swish(nn.Module):
    """Swish activation as a module: ``x * sigmoid(x)``."""

    def __init__(self):
        super().__init__()
        self.sig = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled element-wise by its own sigmoid."""
        gate = self.sig(x)
        return x * gate
class Conv2dSamePadding(nn.Conv2d):
    """2D convolution that pads its input like TensorFlow's 'SAME' mode.

    The layer is created with zero built-in padding; ``forward`` pads the
    input so that the output spatial size is ``ceil(input / stride)``.

    Bookkeeping attributes:
        _num_params: weight count, ``kh * kw * in * out // groups``
            (the bias is not included).
        _num_flops: multiply count of the most recent forward pass for a
            single sample, refreshed on every call.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True, is_expand=False, is_reduce=False, is_project=False):
        super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
        self.is_expand = is_expand
        self.is_reduce = is_reduce
        self.is_project = is_project
        self._num_flops = 0
        # Use the normalised (kh, kw) pair so tuple kernel sizes work too.
        kh, kw = self.kernel_size
        self._num_params = kh * kw * in_channels * out_channels // groups
        self._init_weights()

    def _init_weights(self):
        """Draw weights from N(0, sqrt(2 / (kh * kw * out_channels))); zero the bias."""
        n = self.kernel_size[0] * self.kernel_size[1] * self.out_channels
        self.weight.data.normal_(0, math.sqrt(2.0 / n))
        if self.bias is not None:
            self.bias.data.zero_()

    def count_flops(self, x):
        """Store in ``_num_flops`` the multiply count for input ``x``.

        ``x`` is the tensor actually handed to the convolution, i.e. after
        'SAME' padding has been applied.  The output size follows the
        standard convolution formula, including dilation.
        """
        spatial = []
        for axis in (0, 1):
            # Span covered by a dilated kernel along this axis.
            span = self.dilation[axis] * (self.kernel_size[axis] - 1) + 1
            padded = x.size()[2 + axis] + 2 * self.padding[axis]
            spatial.append((padded - span) // self.stride[axis] + 1)
        out_h, out_w = spatial
        self._num_flops = (self.in_channels * self.out_channels
                           * self.kernel_size[0] * self.kernel_size[1]
                           * out_h * out_w // self.groups)

    def forward(self, x):
        """Pad ``x`` for 'SAME' output size, record FLOPs, then convolve."""
        ih, iw = x.size()[-2:]
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh - 1) * sh + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * sw + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            # Odd padding puts the extra row/column at the bottom/right.
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])

        self.count_flops(x)

        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
def splitFilters(channels, num_groups):
    """Divide ``channels`` into ``num_groups`` integer group sizes.

    Every group gets ``channels // num_groups``; whatever is left over is
    added to the first group, so the sizes always sum to ``channels``.
    """
    sizes = []
    for _ in range(num_groups):
        sizes.append(channels // num_groups)
    leftover = channels - sum(sizes)
    sizes[0] = sizes[0] + leftover
    return sizes
default=4, type=int, metavar='N', 32 | help='number of data loading workers (default: 4)') 33 | parser.add_argument('--epochs', default=120, type=int, metavar='N', 34 | help='number of total epochs to run') 35 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 36 | help='manual epoch number (useful on restarts)') 37 | parser.add_argument('--report_freq', type=float, default=50, help='report frequency') 38 | parser.add_argument('-b', '--batch-size', default=256, type=int, 39 | metavar='N', 40 | help='mini-batch size (default: 256), this is the total ' 41 | 'batch size of all GPUs on the current node when ' 42 | 'using Data Parallel or Distributed Data Parallel') 43 | parser.add_argument('--s', type=float, default=0.0001, 44 | help='scale sparse rate (default: 0.0001)') 45 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 46 | metavar='LR', help='initial learning rate', dest='lr') 47 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 48 | help='momentum') 49 | parser.add_argument('--cutout_length', type=int, default=16, help='cutout length') 50 | parser.add_argument('--eps', type=float, default=0.001) 51 | parser.add_argument('--wd', '--weight-decay', default=1e-5, type=float, 52 | metavar='W', help='weight decay (default: 1e-5)', 53 | dest='weight_decay') 54 | parser.add_argument('-p', '--print-freq', default=10, type=int, 55 | metavar='N', help='print frequency (default: 10)') 56 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 57 | help='path to latest checkpoint (default: none)') 58 | parser.add_argument('--save', type=str, default='EXP', help='experiment name') 59 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 60 | help='evaluate model on validation set') 61 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 62 | help='use pre-trained model') 63 | parser.add_argument('--world-size', default=-1, type=int, 64 | 
def get_scheduler(optim, sche_type, step_size, t_max):
    """Build the learning-rate scheduler named by ``sche_type``.

    Args:
        optim: the optimizer to schedule.
        sche_type: "exp" for a StepLR with decay factor 0.97 every
            ``step_size`` steps, "cosine" for CosineAnnealingLR over
            ``t_max`` steps.
        step_size: period of the step decay (used by "exp" only).
        t_max: annealing length (used by "cosine" only).

    Returns:
        The scheduler, or None for any other ``sche_type``.
    """
    if sche_type == "exp":
        return StepLR(optim, step_size=step_size, gamma=0.97)
    if sche_type != "cosine":
        return None
    return CosineAnnealingLR(optim, T_max=t_max)
def main_worker(gpu, ngpus_per_node, args):
    """Train and evaluate a MixNet on CIFAR-10 or CIFAR-100.

    Steps: optionally join the distributed process group, build the data
    loaders for ``args.dtype``, create the model for ``args.arch``, place
    it on the GPU(s), then run ``args.epochs`` epochs of train + validation
    under a cosine learning-rate schedule, saving the weights to
    ``<args.save>/weights.pt`` after every epoch.

    Args:
        gpu: GPU index of this worker (used to derive the global rank when
            multiprocessing-distributed training is on).
        ngpus_per_node: number of GPUs on this node.
        args: parsed command-line arguments; ``rank``, ``batch_size`` and
            ``workers`` may be rewritten in place.
    """
    if args.distributed:
        # The module never imports torch.distributed at file level, so it
        # is brought into scope here, where it is needed.
        import torch.distributed as dist

        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    logging.info(("=> load data '{}'".format(args.dtype)))
    if args.dtype == 'cifar10':
        train_transform, valid_transform = utils._data_transforms_cifar10(args, cutout=True)
        train_data = dset.CIFAR10(root=args.tmp_data_dir, train=True, download=True, transform=train_transform)
        valid_data = dset.CIFAR10(root=args.tmp_data_dir, train=False, download=True, transform=valid_transform)
        num_classes = 10
        update_lrs = [150, 250, 350]
    elif args.dtype == 'cifar100':
        train_transform, valid_transform = utils._data_transforms_cifar100(args, cutout=True)
        train_data = dset.CIFAR100(root=args.tmp_data_dir, train=True, download=True, transform=train_transform)
        valid_data = dset.CIFAR100(root=args.tmp_data_dir, train=False, download=True, transform=valid_transform)
        num_classes = 100
        update_lrs = [40, 80, 160, 300]
    else:
        logging.info('no data type available')
        sys.exit(1)
    # update_lrs is only reported; the step schedule that used it is disabled.
    logging.info("update lrs: '{}'".format(update_lrs))

    # NOTE(review): the loaders are built before the distributed branch
    # below rescales args.batch_size / args.workers, so that rescaling does
    # not reach them -- confirm whether this is intended.
    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=args.workers)

    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=args.workers)

    logging.info(("=> creating model '{}'".format(args.arch)))
    blocks_args, global_params = mixnet_builder.get_model_params(args.arch)
    model = MixNet(input_size=32, num_classes=num_classes, blocks_args=blocks_args, global_params=global_params)
    logging.info("args = %s", args)
    logging.info("param size = %fMB", model._num_params / 1e6)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith(('alexnet', 'vgg')):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model)
            model = model.cuda()

    criterion = nn.CrossEntropyLoss().cuda()

    if args.optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)
    elif args.optim == 'rmsprop':
        optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum, eps=args.eps, weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    cudnn.benchmark = True

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs))
    best_acc = 0.0
    for epoch in range(args.epochs):
        # NOTE(review): the scheduler is stepped before the epoch's
        # optimizer steps; newer PyTorch releases expect the opposite
        # order -- confirm against the targeted PyTorch version.
        scheduler.step()
        logging.info('Epoch: %d lr %e', epoch, scheduler.get_lr()[0])
        start_time = time.time()
        train_acc, _ = train(train_queue, model, criterion, optimizer)
        logging.info('Train_acc: %f', train_acc)
        valid_acc, _ = test(valid_queue, model, criterion)
        if valid_acc > best_acc:
            best_acc = valid_acc
        logging.info('Valid_acc: %f', valid_acc)
        duration = time.time() - start_time
        print('Epoch time: %ds.' % duration)
        utils.save(model, os.path.join(args.save, 'weights.pt'))

    # The best accuracy was tracked but never reported before.
    logging.info('Best valid_acc: %f', best_acc)
class AverageMeter(object):
    """Track the latest value and the running weighted average of a metric.

    Args:
        fmt: Format spec, including the leading colon (e.g. ``':.3f'``),
            applied to ``val`` and ``avg`` when the meter is rendered with
            ``str()``.
        name: Optional label printed in front of the values by ``str()``.
    """

    def __init__(self, fmt=':f', name=''):
        # ``__str__`` formats with ``{name}``, so the attribute must always
        # exist; without it ``str(meter)`` raises ``KeyError``.
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear every accumulated statistic."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against a zero total weight (e.g. an empty batch, n == 0).
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    """Print a batch counter followed by the state of several meters.

    Args:
        num_batches: Total number of batches; used to size the counter.
        *meters: Objects whose ``str()`` is appended to each printed line.
        prefix: Text placed in front of the batch counter.
    """

    def __init__(self, num_batches, *meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def print(self, batch):
        """Write one tab-separated progress line for batch index ``batch``."""
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the running index to the width of the total, e.g. "[  7/100]".
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each ``k`` in ``topk``.

    Args:
        output: Score tensor indexed as ``(sample, class)``.
        target: Tensor holding one class index per sample.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        list: One scalar tensor per entry of ``topk``, in the same order.
    """
    # Never request more predictions than there are classes.
    maxk = min(max(topk), output.size(1))
    batch_size = target.size(0)

    with torch.no_grad():
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # ``correct`` derives from a transposed tensor, so a slice of it
            # is not guaranteed contiguous; ``view`` raises in that case
            # while ``reshape`` copies when required.
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
    return res
class Cutout(object):
    """Zero out one randomly placed square patch of an image tensor.

    Args:
        length: Side length of the patch; the patch is clipped where it
            extends past the image border.
    """

    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        """Mask ``img`` in place and return it.

        ``img`` is indexed as ``(channel, row, column)``; the same patch is
        blanked in every channel.
        """
        height, width = img.size(1), img.size(2)
        half = self.length // 2

        # Row is drawn before column so the random stream is consumed in a
        # fixed order.
        centre_y = np.random.randint(height)
        centre_x = np.random.randint(width)

        top = min(max(centre_y - half, 0), height)
        bottom = min(max(centre_y + half, 0), height)
        left = min(max(centre_x - half, 0), width)
        right = min(max(centre_x + half, 0), width)

        mask = torch.ones(height, width, dtype=torch.float32)
        mask[top:bottom, left:right] = 0.
        img *= mask.expand_as(img)
        return img


def _data_transforms_cifar10(args, cutout=False):
    """Build the ``(train, valid)`` transform pipelines for CIFAR-10.

    Training applies a padded random crop and a random horizontal flip
    before tensor conversion and normalisation; validation only converts
    and normalises.  When ``cutout`` is true a ``Cutout`` of side
    ``args.cutout_length`` is appended to the training pipeline.
    """
    CIFAR_MEAN = [0.4914, 0.4822, 0.4465]
    CIFAR_STD = [0.2023, 0.1994, 0.2010]

    train_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ]
    if cutout:
        train_steps.append(Cutout(args.cutout_length))
    train_transform = transforms.Compose(train_steps)

    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    return train_transform, valid_transform


def _data_transforms_cifar100(args, cutout=False):
    """Build the ``(train, valid)`` transform pipelines for CIFAR-100.

    Same structure as the CIFAR-10 pipelines, with per-channel statistics
    given on the 0-255 scale and rescaled to 0-1.  The normalisation
    transform is shared by both pipelines.
    """
    channel_mean = [value / 255.0 for value in [125.3, 123.0, 113.9]]
    channel_std = [value / 255.0 for value in [63.0, 62.1, 66.7]]
    normalize = transforms.Normalize(mean=channel_mean, std=channel_std)

    train_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]
    if cutout:
        train_steps.append(Cutout(args.cutout_length))
    train_transform = transforms.Compose(train_steps)

    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    return train_transform, valid_transform
def count_parameters_in_MB(model):
    """Return the number of parameters of ``model`` in millions.

    Parameters whose qualified name contains ``"auxiliary"`` are excluded.
    """
    # The builtin ``sum`` is used because passing a generator to ``np.sum``
    # is deprecated.
    total = sum(v.numel() for name, v in model.named_parameters()
                if "auxiliary" not in name)
    return total / 1e6


def save_checkpoint(state, is_best, save):
    """Write ``state`` to ``<save>/checkpoint.pth.tar``.

    When ``is_best`` is true the file is additionally copied to
    ``<save>/model_best.pth.tar``.
    """
    filename = os.path.join(save, 'checkpoint.pth.tar')
    torch.save(state, filename)
    if is_best:
        best_filename = os.path.join(save, 'model_best.pth.tar')
        shutil.copyfile(filename, best_filename)


def save(model, model_path):
    """Serialise the ``state_dict`` of ``model`` to ``model_path``."""
    torch.save(model.state_dict(), model_path)


def load(model, model_path):
    """Load a ``state_dict`` stored at ``model_path`` into ``model``."""
    model.load_state_dict(torch.load(model_path))


def drop_path(x, drop_prob):
    """Randomly zero whole samples of ``x`` in place and rescale survivors.

    Each index along the first dimension is kept with probability
    ``1 - drop_prob``; kept samples are divided by that probability.

    Args:
        x: Tensor whose first dimension indexes samples; modified in place.
        drop_prob: Probability of dropping a sample.  Values ``<= 0`` leave
            ``x`` untouched.

    Returns:
        The same tensor object ``x``.
    """
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        if keep_prob <= 0.:
            # Everything is dropped; avoid dividing by zero (0 * inf = nan).
            return x.zero_()
        # Create the mask on the device of ``x`` instead of hard-coding a
        # CUDA tensor, and give it one broadcast axis per trailing
        # dimension so it works for any rank.
        mask_shape = (x.size(0),) + (1,) * (x.dim() - 1)
        mask = torch.empty(mask_shape, dtype=torch.float32,
                           device=x.device).bernoulli_(keep_prob)
        x.div_(keep_prob)
        x.mul_(mask)
    return x


def create_exp_dir(path, scripts_to_save=None):
    """Create the experiment directory and optionally snapshot scripts.

    Args:
        path: Directory to create; missing parents are created too and an
            existing directory is reused.
        scripts_to_save: Optional iterable of file paths copied into
            ``<path>/scripts``.
    """
    # ``makedirs(..., exist_ok=True)`` supports nested paths and makes
    # reruns with the same directory safe.
    os.makedirs(path, exist_ok=True)
    print('Experiment dir : {}'.format(path))

    if scripts_to_save is not None:
        scripts_dir = os.path.join(path, 'scripts')
        os.makedirs(scripts_dir, exist_ok=True)
        for script in scripts_to_save:
            dst_file = os.path.join(scripts_dir, os.path.basename(script))
            shutil.copyfile(script, dst_file)