├── .gitignore
├── Dockerfile
├── README.md
├── requirements.txt
├── setup_conda.py
├── test.py
├── train.py
└── util
    ├── __init__.py
    ├── caltech101_prepare.sh
    ├── counter.sh
    ├── dataloader.py
    ├── functions.py
    ├── optimizer.py
    └── scheduler.py

/.gitignore:
--------------------------------------------------------------------------------
util/__pycache__
logs/
model/
example_images/
101_ObjectCategories.tar.gz
upload.sh

--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
FROM nvidia/cuda:10.0-cudnn7-devel

ARG PYTHON_VERSION=3.7
RUN apt-get update && apt-get install -y --no-install-recommends \
        build-essential \
        cmake \
        git \
        curl \
        vim \
        ca-certificates \
        libjpeg-dev \
        libpng-dev &&\
    rm -rf /var/lib/apt/lists/*

RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
    chmod +x ~/miniconda.sh && \
    ~/miniconda.sh -b -p /opt/conda && \
    rm ~/miniconda.sh && \
    /opt/conda/bin/conda install -y python=$PYTHON_VERSION numpy pyyaml scipy ipython mkl mkl-include cython typing && \
    /opt/conda/bin/conda install -y pytorch torchvision cudatoolkit=10.0 -c pytorch && \
    /opt/conda/bin/conda install -y -c intel ipp-devel && \
    /opt/conda/bin/conda install -y -c conda-forge libjpeg-turbo && \
    /opt/conda/bin/conda clean -ya
ENV PATH /opt/conda/bin:$PATH

RUN pip install \
        adabound \
        cnn_finetune \
        logzero \
        munch \
        pretrainedmodels \
        protobuf \
        scikit-learn \
        tensorboardX &&\
    pip uninstall -y pillow &&\
    CC="cc -mavx2" pip install -U --force-reinstall pillow-simd

RUN git clone https://github.com/pytorch/accimage.git /accimage
COPY setup_conda.py /accimage
RUN cd /accimage && \
    python setup_conda.py install --user

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pytorch-finetuner

## Setup

```
$ git clone https://github.com/knjcode/pytorch-finetuner
$ cd pytorch-finetuner
$ pip install -r requirements.txt
```

## Example usage

### (Prerequisites) Arrange images into their respective directories

The training data directory (`images/train`), validation data directory (`images/valid`), and test data directory (`images/test`) should each contain one subdirectory per image class.

For example, arrange training, validation, and test data as follows.

```
images/
  train/
    airplanes/
      airplane001.jpg
      airplane002.jpg
      ...
    watch/
      watch001.jpg
      watch002.jpg
      ...
  valid/
    airplanes/
      airplane101.jpg
      airplane102.jpg
      ...
    watch/
      watch101.jpg
      watch102.jpg
      ...
  test/
    airplanes/
      airplane201.jpg
      airplane202.jpg
      ...
    watch/
      watch201.jpg
      watch202.jpg
      ...
```

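The bundled loaders follow torchvision's `ImageFolder` directory convention, so class labels are derived from the subdirectory names. If you want a quick sanity check of a layout before training, a snippet like the following works; this is only an illustrative sketch and assumes `torchvision` is installed and the `images/` layout above exists:

```
from torchvision.datasets import ImageFolder

# ImageFolder expects one subdirectory per class and derives
# the class list from the subdirectory names.
for split in ('train', 'valid', 'test'):
    dataset = ImageFolder('images/' + split)
    print(split, len(dataset), 'images', dataset.classes)
```
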
### Prepare example images using bundled shell script

Download the [Caltech 101] dataset and split part of it into the `example_images` directory.

```
$ util/caltech101_prepare.sh
```

- `example_images/train` is the training set, with 60 images per class
- `example_images/valid` is the validation set, with 20 images per class
- `example_images/test` is the test set, with 20 images per class

```
$ util/counter.sh example_images/train
example_images/train contains 10 directories
Faces 60
Leopards 60
Motorbikes 60
airplanes 60
bonsai 60
car_side 60
chandelier 60
hawksbill 60
ketch 60
watch 60
```

With this dataset you can immediately try fine-tuning with pytorch-finetuner.

```
$ ./train.py example_images --model resnet50 --epochs 30 --lr-step-epochs 10,20
```


### train

```
$ ./train.py example_images --epochs 5
```

```
$ ./train.py example_images --epochs 5
Running script with args: Namespace(base_lr=0.0125, batch_size=128, cosine_annealing_eta_min=1e-05, cosine_annealing_mult=2, cosine_annealing_t_max=None, cuda=True, cutout=False, cutout_holes=None, cutout_length=None, data='example_images', disp_batches=1, epochs=3, image_dump=False, input_size=224, jitter_brightness=0.1, jitter_contrast=0.1, jitter_hue=0.05, jitter_saturation=0.1, log_dir='logs', lr_factor=0.1, lr_patience=None, lr_step_epochs=[30, 60, 90], mixup=False, mixup_alpha=None, model='resnet18', model_dir='model', momentum=0.9, no_cuda=False, optimizer='sgd', prefix='20181226184319', random_erasing=False, random_erasing_p=None, random_erasing_r1=None, random_erasing_r2=None, random_erasing_sh=None, random_erasing_sl=None, random_horizontal_flip=0.5, random_resized_crop_ratio=[0.75, 1.3333333333333333], random_resized_crop_scale=[0.08, 1.0], random_rotate_degree=3.0, random_vertical_flip=0.0, resume=None, rgb_mean=[0.485, 0.456, 0.406], rgb_std=[0.229, 0.224, 0.225], ricap=False, ricap_beta=None, ricap_with_line=False, save_best_and_last=False, save_best_only=False, scale_size=256, scratch=False, seed=None, start_epoch=0, warmup_epochs=5, wd=0.0001, workers=22)
scale_size: 256 input_size: 224
rgb_mean: [0.0, 0.0, 0.0]
rgb_std: [1.0, 1.0, 1.0]
number of train dataset: 600
number of validation dataset: 200
number of classes: 10
=> using pre-trained model 'resnet18'
=> using GPU
=> using optimizer: sgd
=> using MultiStepLR scheduler
=> model and logs prefix: 20181226184319
=> log dir: logs
=> model dir: model
=> tensorboardX log dir: logs/20181226184319-tensorboardX
Epoch[0] Batch[1] [128/600 (20%)] speed: 15.22 samples/sec accuracy: 0.0781250000 loss: 2.5856597424
Epoch[0] Batch[2] [256/600 (40%)] speed: 30.00 samples/sec accuracy: 0.0820312500 loss: 2.5729613304
Epoch[0] Batch[3] [384/600 (60%)] speed: 44.36 samples/sec accuracy: 0.0885416642 loss: 2.5417649746
Epoch[0] Batch[4] [512/600 (80%)] speed: 58.31 samples/sec accuracy: 0.0996093750 loss: 2.5194582939
Epoch[0] Batch[5] [600/600 (100%)] speed: 60.46 samples/sec accuracy: 0.1183238626 loss: 2.4788246155
Epoch[0] Train-accuracy: 0.11832386255264282
Epoch[0] Train-loss: 2.4788246154785156
Epoch[0] learning-rate: 0.0020833333333333333
Epoch[0] Validation-accuracy: 0.2777777910232544
Epoch[0] Validation-loss: 2.0763139724731445
Epoch[0] Time cost: 13.761763095855713 [sec]
=> Saved checkpoint to "model/20181226184319-resnet18-0001.model"
=> Saved checkpoint to "model/20181226184319-resnet18-best.model"
Epoch[1] Batch[1] [128/600 (20%)] speed: 42.31 samples/sec accuracy: 0.2578125000 loss: 2.2109849453
Epoch[1] Batch[2] [256/600 (40%)] speed: 81.27 samples/sec accuracy: 0.3007812500 loss: 2.1127164364
...
(snip)
...
Epoch[4] Batch[1] [128/600 (20%)] speed: 42.00 samples/sec accuracy: 0.9531250000 loss: 0.2027403414
Epoch[4] Batch[2] [256/600 (40%)] speed: 80.71 samples/sec accuracy: 0.9648437500 loss: 0.1552993804
Epoch[4] Batch[3] [384/600 (60%)] speed: 116.50 samples/sec accuracy: 0.9661458135 loss: 0.1483899951
Epoch[4] Batch[4] [512/600 (80%)] speed: 149.72 samples/sec accuracy: 0.9726562500 loss: 0.1362725645
Epoch[4] Batch[5] [600/600 (100%)] speed: 170.88 samples/sec accuracy: 0.9690340757 loss: 0.1381170452
Epoch[4] Train-accuracy: 0.9690340757369995
Epoch[4] Train-loss: 0.13811704516410828
Epoch[4] learning-rate: 0.010416666666666668
Epoch[4] Validation-accuracy: 1.0
Epoch[4] Validation-loss: 0.016005855053663254
Epoch[4] Time cost: 5.863810300827026 [sec]
=> Saved checkpoint to "model/20181226184319-resnet18-0005.model"
=> Saved checkpoint to "model/20181226184319-resnet18-best.model"
```

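The saved `.model` files are ordinary PyTorch checkpoints, so they can also be loaded outside of `test.py`. A minimal sketch follows; the `arch` and `model` keys match what `train.py` reads back when resuming, while the filename and class count below are simply the values from the example above:

```
import torch
from cnn_finetune import make_model

# Load a checkpoint written by train.py and rebuild the network.
checkpoint = torch.load('model/20181226184319-resnet18-best.model', map_location='cpu')
model = make_model(checkpoint['arch'], num_classes=10, pretrained=False,
                   input_size=(224, 224))
model.load_state_dict(checkpoint['model'])
model.eval()
```
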
"model/20181226184319-resnet18-best.model" 121 | Epoch[1] Batch[1] [128/600 (20%)] speed: 42.31 samples/sec accuracy: 0.2578125000 loss: 2.2109849453 122 | Epoch[1] Batch[2] [256/600 (40%)] speed: 81.27 samples/sec accuracy: 0.3007812500 loss: 2.1127164364 123 | ... 124 | (snip) 125 | ... 126 | Epoch[4] Batch[1] [128/600 (20%)] speed: 42.00 samples/sec accuracy: 0.9531250000 loss: 0.2027403414 127 | Epoch[4] Batch[2] [256/600 (40%)] speed: 80.71 samples/sec accuracy: 0.9648437500 loss: 0.1552993804 128 | Epoch[4] Batch[3] [384/600 (60%)] speed: 116.50 samples/sec accuracy: 0.9661458135 loss: 0.1483899951 129 | Epoch[4] Batch[4] [512/600 (80%)] speed: 149.72 samples/sec accuracy: 0.9726562500 loss: 0.1362725645 130 | Epoch[4] Batch[5] [600/600 (100%)] speed: 170.88 samples/sec accuracy: 0.9690340757 loss: 0.1381170452 131 | Epoch[4] Train-accuracy: 0.9690340757369995 132 | Epoch[4] Train-loss: 0.13811704516410828 133 | Epoch[4] learning-rate: 0.010416666666666668 134 | Epoch[4] Validation-accuracy: 1.0 135 | Epoch[4] Validation-loss: 0.016005855053663254 136 | Epoch[4] Time cost: 5.863810300827026 [sec] 137 | => Saved checkpoint to "model/20181226184319-resnet18-0005.model" 138 | => Saved checkpoint to "model/20181226184319-resnet18-best.model" 139 | ``` 140 | 141 | ### test 142 | 143 | ``` 144 | $ ./test.py example_images -m model/20181226184319-resnet18-best.model 145 | ``` 146 | 147 | ``` 148 | $ ./test.py example_images -m model/20181226184319-resnet18-best.model 149 | Running script with args: Namespace(batch_size=128, cuda=True, data='example_images', input_size=None, log_dir='logs', model='model/20181226184319-resnet18-best.model', no_cuda=False, num_classes=None, prefix='20190117043959', print_cr=False, rgb_mean='0,0,0', rgb_std='1,1,1', scale_size=None, seed=None, topk=3, tta=False, tta_custom_seven_crop=False, tta_custom_six_crop=False, tta_custom_ten_crop=False, tta_custom_twenty_crop=False, tta_ten_crop=False, workers=22) 150 | => loading saved checkpoint 'model/20181226184319-resnet18-best.model' 151 | scale_size: 256 input_size: 224 152 | rgb_mean: [0.0, 0.0, 0.0] 153 | rgb_std: [1.0, 1.0, 1.0] 154 | number of test dataset: 200 155 | number of classes: 10 156 | Test: 100%|██████████| 2/2 [00:06<00:00, 4.37s/it, loss=0.0231, accuracy=99.6] 157 | => Saved test results to "logs/20190117043959-test-results.log" 158 | => Saved classification report to "logs/20190117043959-test-classification_report.log" 159 | model: model/20181226184319-resnet18-best.model 160 | Test-loss: 0.023076653480529785 161 | Test-accuracy: 0.995 (199/200) 162 | => Saved test log to "logs/20190117043959-test.log" 163 | ``` 164 | 165 | ### Calculate RGB mean and std of dataset 166 | 167 | ``` 168 | $ ./train.py example_images/train --calc-rgb-mean-and-std --batch-size 600 169 | => Calculate rgb mean and std (dir: example_images/train images: 600 batch-size: 600) 170 | Calc rgb mean/std: 100%|████████████████████████| 1/1 [00:04<00:00, 4.19s/it] 171 | => processed: 600 images 172 | => calculated rgb mean: [0.5068701 0.50441307 0.47790593] 173 | => calculated rgb std: [0.28773245 0.27445307 0.29044855] 174 | Please use following command options when train and test: 175 | --rgb-mean 0.507,0.504,0.478 --rgb-std 0.288,0.274,0.290 176 | ``` 177 | 178 | ## With Data Augmentation 179 | 180 | ### Random Rotation (Enabled by default) 181 | 182 | ``` 183 | ./train.py example_images --random-rotate-degree 3.0 184 | ``` 185 | 186 | ### Random Resized Crop (Enabled by default) 187 | 188 | ``` 189 | ./train.py 
example_images \ 190 | --random-resized-crop-scale 0.08,1.0 \ 191 | --random-resized-crop-ratio 0.75,1.3333333333333333 192 | ``` 193 | 194 | ### Random Horizontal Flip (Enabled by default) 195 | 196 | ``` 197 | ./train.py example_images --random-horizontal-flip 0.5 198 | ``` 199 | 200 | ### Random Vertical Flip 201 | 202 | ``` 203 | ./train.py example_images --random-vertical-flip 0.5 204 | ``` 205 | 206 | ### Color Jitter (Enabled by default) 207 | 208 | ``` 209 | ./train.py example_images \ 210 | --jitter-brightness 0.10 \ 211 | --jitter-contrast 0.10 \ 212 | --jitter-saturation 0.10 \ 213 | --jitter-hue 0.05 214 | ``` 215 | 216 | ### Normalize 217 | 218 | ``` 219 | ./train.py example_images \ 220 | --rgb-mean 0.485,0.456,0.406 \ 221 | --rgb-std 0.229,0.224,0.225 222 | ``` 223 | 224 | ### Cutout, Random Erasing, mixup, RICAP, ICAP, CutMix 225 | 226 | ``` 227 | # Cutout 228 | $ ./train.py example_images --cutout 229 | 230 | # Random Erasing 231 | $ ./train.py example_images --random-erasing 232 | 233 | # mixup 234 | $ ./train.py example_images --mixup 235 | 236 | # RICAP 237 | $ ./train.py example_images --ricap 238 | 239 | # ICAP 240 | $ ./train.py example_images --icap 241 | 242 | # CutMix 243 | $ ./train.py example_images --cutmix 244 | 245 | # mixup + Cutout 246 | $ ./train.py example_images --mixup --cutout 247 | 248 | # mixup + Random Erasing 249 | $ ./train.py example_images --mixup --random-erasing 250 | 251 | # RICAP + Cutout 252 | $ ./train.py example_images --ricap --cutout 253 | 254 | # RICAP + Random Erasing 255 | $ ./train.py example_images --ricap --random-erasing 256 | ``` 257 | 258 | - Cutout: [Improved Regularization of Convolutional Neural Networks with Cutout] 259 | - Random Erasing: [Random Erasing Data Augmentation] 260 | - mixup: [mixup: Beyond Empirical Risk Minimization] 261 | - RICAP: [Data Augmentation using Random Image Cropping and Patching for Deep CNNs] 262 | - CutMix: [CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features] 263 | 264 | 265 | ## Usage 266 | 267 | ``` 268 | usage: train.py [-h] [--model ARCH] [--from-scratch] [--epochs EPOCHS] 269 | [--batch-size BATCH_SIZE] [--val-batch-size VAL_BATCH_SIZE] 270 | [-j WORKERS] [--prefix PREFIX] [--log-dir LOG_DIR] 271 | [--model-dir MODEL_DIR] [--resume MODEL] 272 | [--start-epoch START_EPOCH] [--disp-batches DISP_BATCHES] 273 | [--save-best-only] [--save-best-and-last] [--drop-last] 274 | [--base-lr BASE_LR] [--lr-factor LR_FACTOR] 275 | [--lr-step-epochs LR_STEP_EPOCHS] [--lr-patience LR_PATIENCE] 276 | [--cosine-annealing-t-max COSINE_ANNEALING_T_MAX] 277 | [--cosine-annealing-mult COSINE_ANNEALING_MULT] 278 | [--cosine-annealing-eta-min COSINE_ANNEALING_ETA_MIN] 279 | [--final-lr FINAL_LR] [--optimizer OPTIMIZER] 280 | [--momentum MOMENTUM] [--wd WD] 281 | [--warmup-epochs WARMUP_EPOCHS] [--scale-size SCALE_SIZE] 282 | [--input-size INPUT_SIZE] [--rgb-mean RGB_MEAN] 283 | [--rgb-std RGB_STD] 284 | [--random-resized-crop-scale RANDOM_RESIZED_CROP_SCALE] 285 | [--random-resized-crop-ratio RANDOM_RESIZED_CROP_RATIO] 286 | [--random-horizontal-flip RANDOM_HORIZONTAL_FLIP] 287 | [--random-vertical-flip RANDOM_VERTICAL_FLIP] 288 | [--jitter-brightness JITTER_BRIGHTNESS] 289 | [--jitter-contrast JITTER_CONTRAST] 290 | [--jitter-saturation JITTER_SATURATION] 291 | [--jitter-hue JITTER_HUE] 292 | [--random-rotate-degree RANDOM_ROTATE_DEGREE] 293 | [--interpolation {BILINEAR,BICUBIC,NEAREST}] [--image-dump] 294 | [--calc-rgb-mean-and-std] [--no-cuda] [--seed SEED] 295 | 
[--warm_restart_next WARM_RESTART_NEXT] 296 | [--warm_restart_current WARM_RESTART_CURRENT] [--cutout] 297 | [--cutout-holes CUTOUT_HOLES] [--cutout-length CUTOUT_LENGTH] 298 | [--random-erasing] [--random-erasing-p RANDOM_ERASING_P] 299 | [--random-erasing-sl RANDOM_ERASING_SL] 300 | [--random-erasing-sh RANDOM_ERASING_SH] 301 | [--random-erasing-r1 RANDOM_ERASING_R1] 302 | [--random-erasing-r2 RANDOM_ERASING_R2] [--mixup] 303 | [--mixup-alpha MIXUP_ALPHA] [--ricap] 304 | [--ricap-beta RICAP_BETA] [--ricap-with-line] [--icap] 305 | [--icap-beta ICAP_BETA] [--cutmix] [--cutmix-beta CUTMIX_BETA] 306 | [--cutmix-prob CUTMIX_PROB] 307 | DIR 308 | 309 | train 310 | 311 | positional arguments: 312 | DIR path to dataset (train and validation) 313 | 314 | optional arguments: 315 | -h, --help show this help message and exit 316 | --model ARCH, -m ARCH 317 | specify model architecture (default: resnet18) 318 | --from-scratch do not use pre-trained weights (default: False) 319 | --epochs EPOCHS number of total epochs to run (default: 30) 320 | --batch-size BATCH_SIZE, -b BATCH_SIZE 321 | the batch size (default: 128) 322 | --val-batch-size VAL_BATCH_SIZE 323 | the validation batch size (default: 256) 324 | -j WORKERS, --workers WORKERS 325 | number of data loading workers (default: 80% of the 326 | number of cores) 327 | --prefix PREFIX prefix of model and logs (default: auto) 328 | --log-dir LOG_DIR log directory (default: logs) 329 | --model-dir MODEL_DIR 330 | model saving dir (default: model) 331 | --resume MODEL path to saved model (default: None) 332 | --start-epoch START_EPOCH 333 | manual epoch number (default: 0) 334 | --disp-batches DISP_BATCHES 335 | show progress for every n batches (default: auto) 336 | --save-best-only save only the latest best model according to the 337 | validation accuracy (default: False) 338 | --save-best-and-last save last and latest best model according to the 339 | validation accuracy (default: False) 340 | --drop-last drop the last incomplete batch, if the dataset size is 341 | not divisible by the batch size. 
                        (default: False)
  --base-lr BASE_LR     initial learning rate (default: 0.001)
  --lr-factor LR_FACTOR
                        the ratio to reduce lr on each step (default: 0.1)
  --lr-step-epochs LR_STEP_EPOCHS
                        the epochs to reduce the lr (default: 10,20)
  --lr-patience LR_PATIENCE
                        enable ReduceLROnPlateau lr scheduler with specified
                        patience (default: None)
  --cosine-annealing-t-max COSINE_ANNEALING_T_MAX
                        enable CosineAnnealingLR scheduler with specified
                        T_max (default: None)
  --cosine-annealing-mult COSINE_ANNEALING_MULT
                        T_mult of CosineAnnealingLR scheduler
  --cosine-annealing-eta-min COSINE_ANNEALING_ETA_MIN
                        Minimum learning rate of CosineAnnealingLR scheduler
  --final-lr FINAL_LR   final_lr of AdaBound optimizer
  --optimizer OPTIMIZER
                        the optimizer type (default: sgd)
  --momentum MOMENTUM   momentum (default: 0.9)
  --wd WD               weight decay (default: 1e-04)
  --warmup-epochs WARMUP_EPOCHS
                        number of warmup epochs (default: 5)
  --scale-size SCALE_SIZE
                        scale size (default: auto)
  --input-size INPUT_SIZE
                        input size (default: auto)
  --rgb-mean RGB_MEAN   RGB mean (default: 0,0,0)
  --rgb-std RGB_STD     RGB std (default: 1,1,1)
  --random-resized-crop-scale RANDOM_RESIZED_CROP_SCALE
                        range of size of the origin size cropped (default:
                        0.08,1.0)
  --random-resized-crop-ratio RANDOM_RESIZED_CROP_RATIO
                        range of aspect ratio of the origin aspect ratio
                        cropped (default: 0.75,1.3333333333333333)
  --random-horizontal-flip RANDOM_HORIZONTAL_FLIP
                        probability of the image being flipped (default: 0.5)
  --random-vertical-flip RANDOM_VERTICAL_FLIP
                        probability of the image being flipped (default: 0.0)
  --jitter-brightness JITTER_BRIGHTNESS
                        jitter brightness of data augmentation (default: 0.10)
  --jitter-contrast JITTER_CONTRAST
                        jitter contrast of data augmentation (default: 0.10)
  --jitter-saturation JITTER_SATURATION
                        jitter saturation of data augmentation (default: 0.10)
  --jitter-hue JITTER_HUE
                        jitter hue of data augmentation (default: 0.05)
  --random-rotate-degree RANDOM_ROTATE_DEGREE
                        rotate degree of data augmentation (default: 3.0)
  --interpolation {BILINEAR,BICUBIC,NEAREST}
                        interpolation. (default: BILINEAR)
  --image-dump          dump batch images and exit (default: False)
  --calc-rgb-mean-and-std
                        calculate rgb mean and std of train images and exit
                        (default: False)
  --no-cuda             disables CUDA training (default: False)
  --seed SEED           seed for initializing training. (default: None)
  --warm_restart_next WARM_RESTART_NEXT
                        next warm restart epoch (default: None)
  --warm_restart_current WARM_RESTART_CURRENT
                        current warm restart epoch (default: None)
  --cutout              apply cutout (default: False)
  --cutout-holes CUTOUT_HOLES
                        number of holes to cut out from image (default: 1)
  --cutout-length CUTOUT_LENGTH
                        length of the holes (default: 64)
  --random-erasing      apply random erasing (default: False)
  --random-erasing-p RANDOM_ERASING_P
                        random erasing p (default: 0.5)
  --random-erasing-sl RANDOM_ERASING_SL
                        random erasing sl (default: 0.02)
  --random-erasing-sh RANDOM_ERASING_SH
                        random erasing sh (default: 0.4)
  --random-erasing-r1 RANDOM_ERASING_R1
                        random erasing r1 (default: 0.3)
  --random-erasing-r2 RANDOM_ERASING_R2
                        random erasing r2 (default: 3.3333333333333335)
  --mixup               apply mixup (default: False)
  --mixup-alpha MIXUP_ALPHA
                        mixup alpha (default: 0.2)
  --ricap               apply RICAP (default: False)
  --ricap-beta RICAP_BETA
                        RICAP beta (default: 0.3)
  --ricap-with-line     RICAP with boundary line (default: False)
  --icap                apply ICAP (default: False)
  --icap-beta ICAP_BETA
                        ICAP beta (default: 0.3)
  --cutmix              apply CutMix (default: False)
  --cutmix-beta CUTMIX_BETA
                        CutMix beta (default: 1.0)
  --cutmix-prob CUTMIX_PROB
                        CutMix probability (default: 1.0)
```


[Caltech 101]: http://www.vision.caltech.edu/Image_Datasets/Caltech101/
[Improved Regularization of Convolutional Neural Networks with Cutout]: https://arxiv.org/abs/1708.04552
[Random Erasing Data Augmentation]: https://arxiv.org/abs/1708.04896
[mixup: Beyond Empirical Risk Minimization]: https://arxiv.org/pdf/1710.09412.pdf
[Data Augmentation using Random Image Cropping and Patching for Deep CNNs]: https://arxiv.org/abs/1811.09030
[CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features]: https://arxiv.org/pdf/1905.04899.pdf

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
adabound==0.0.5
cnn-finetune==0.5.3
logzero==1.5.0
munch==2.3.2
numpy==1.16.4
Pillow==6.0.0
pretrainedmodels==0.7.4
protobuf==3.8.0
scikit-learn==0.21.2
scipy==1.3.0
six==1.12.0
tensorboardX==1.7
torch==1.1.0
torchvision==0.3.0
tqdm==4.32.1

--------------------------------------------------------------------------------
/setup_conda.py:
--------------------------------------------------------------------------------
from distutils.core import setup, Extension
import os
from os.path import expanduser, exists, join, realpath, basename, dirname
import sys


def find_path(candidates):
    for path in candidates:
        if exists(path):
            return path


def find_path_with_members(candidates, required_names_we):
    """
    Args:
        candidates (list): possible directories
        required_names_we (list) : candidate directory must contain
            filenames in this list (extensions are ignored)
    """
    for path in candidates:
        if exists(path):
            members_we = [basename(p).split('.')[0] for p in os.listdir(path)]
            if all(want in members_we for want in required_names_we):
                return path

anaconda_dir = '/opt/conda/'

ipp_root = 
find_path([ 29 | expanduser('~/intel/ipp'), 30 | expanduser('/opt/intel/ipp'), 31 | anaconda_dir 32 | ]) 33 | if ipp_root is None: 34 | raise Exception('Cannot find path to Intel IPP') 35 | else: 36 | ipp_lib_dir = find_path_with_members( 37 | candidates=[ 38 | join(ipp_root, 'lib'), 39 | # join(ipp_root, 'lib', 'ia32'), 40 | join(ipp_root, 'lib', 'intel64'), 41 | ], 42 | required_names_we=['libippi', 'libipps'] 43 | ) 44 | if ipp_lib_dir is None: 45 | raise Exception('Cannot find path to Intel IPP') 46 | 47 | # Ensure that the image and signal processing lib are in the libdir 48 | ipp = { 49 | 'include_dir': join(ipp_root, 'include'), 50 | 'lib_dir': ipp_lib_dir, 51 | } 52 | 53 | join(ipp_root, 'lib', 'ia32', 'libippi') 54 | 55 | 56 | # jpeg_turbo_root = '/usr/local/opt/jpeg-turbo' 57 | jpeg_turbo_root = anaconda_dir 58 | if exists(jpeg_turbo_root): 59 | jpeg_turbo = { 60 | # 'lib_dir': join('/usr/local/opt/jpeg-turbo', 'lib'), 61 | # 'include_dir': join('/usr/local/opt/jpeg-turbo', 'include'), 62 | 'lib_dir': join(jpeg_turbo_root, 'lib'), 63 | 'include_dir': join(jpeg_turbo_root, 'include'), 64 | } 65 | else: 66 | jpeg_turbo_header = find_path([ 67 | '/usr/include/jpeglib.h' 68 | ]) 69 | jpeg_turbo_lib = find_path([ 70 | '/usr/lib/x86_64-linux-gnu/libjpeg.so', 71 | '/usr/lib/i386-linux-gnu/libjpeg.so', 72 | ]) 73 | 74 | # We can use the system libjpeg if its version is at least 8 75 | jpeg_version_info = basename(realpath(jpeg_turbo_lib)).split('.')[2:] 76 | jpeg_version_major = int(jpeg_version_info[0]) 77 | if jpeg_version_major < 8: 78 | raise Exception('Cannot find LibJpegTurbo') 79 | 80 | if jpeg_turbo_header is None or jpeg_turbo_lib is None: 81 | raise Exception('Cannot find LibJpegTurbo') 82 | 83 | jpeg_turbo = { 84 | 'lib_dir': dirname(jpeg_turbo_lib), 85 | 'include_dir': dirname(jpeg_turbo_header), 86 | } 87 | 88 | accimage = Extension( 89 | 'accimage', 90 | include_dirs=[ 91 | jpeg_turbo['include_dir'], 92 | ipp['include_dir'] 93 | ], 94 | libraries=['jpeg', 'ippi', 'ipps'], 95 | library_dirs=[ 96 | jpeg_turbo['lib_dir'], 97 | ipp['lib_dir'] 98 | ], 99 | sources=[ 100 | 'accimagemodule.c', 101 | 'jpegloader.c', 102 | 'imageops.c' 103 | ]) 104 | 105 | setup(name='accimage', 106 | version='0.1', 107 | description='Accelerated image loader and preprocessor for Torch', 108 | author='Marat Dukhan', 109 | author_email='maratek@gmail.com', 110 | ext_modules=[accimage]) 111 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import argparse 5 | import datetime 6 | import logging 7 | import os 8 | import signal 9 | import warnings 10 | 11 | import logzero 12 | import torch 13 | import torch.nn as nn 14 | import torch.backends.cudnn as cudnn 15 | import torch.utils.data.distributed 16 | 17 | from multiprocessing import cpu_count 18 | from PIL import Image, ImageFile 19 | from sklearn.metrics import classification_report 20 | from torchvision import transforms 21 | from tqdm import tqdm 22 | from logzero import logger 23 | 24 | from util.dataloader import ImageFolderWithPaths 25 | from util.functions import accuracy, load_checkpoint, load_model_from_checkpoint, Metric, CustomTenCrop, CustomTwentyCrop, CustomSixCrop, CustomSevenCrop 26 | 27 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning) 28 | signal.signal(signal.SIGINT, signal.default_int_handler) 29 | 
ImageFile.LOAD_TRUNCATED_IMAGES = True 30 | 31 | parser = argparse.ArgumentParser(description='test') 32 | parser.add_argument('data', metavar='DIR', help='path to dataset') 33 | parser.add_argument('--prefix', default='auto', 34 | help="prefix of model and logs (default: auto)") 35 | parser.add_argument('--log-dir', default='logs', 36 | help='log directory (default: logs)') 37 | parser.add_argument('--model', '-m', type=str, 38 | help='model file to test') 39 | parser.add_argument('-j', '--workers', type=int, default=None, 40 | help='number of data loading workers (default: 80%% of the number of cores)') 41 | 42 | parser.add_argument('-b', '--batch-size', type=int, default=128, help='the batch size') 43 | parser.add_argument('--topk', type=int, default=3, 44 | help='report the top-k accuracy (default: 3)') 45 | parser.add_argument('--print-cr', action='store_true', default=False, 46 | help='print classification report (default: False)') 47 | 48 | 49 | # Test Time Augmentation 50 | parser.add_argument('--tta', action='store_true', default=False, 51 | help='test time augmentation (use FiveCrop)') 52 | parser.add_argument('--tta-ten-crop', action='store_true', default=False, 53 | help='test time augmentation (use TenCrop)') 54 | parser.add_argument('--tta-custom-six-crop', action='store_true', default=False, 55 | help='test time augmentation (use CustomSixCrop)') 56 | parser.add_argument('--tta-custom-seven-crop', action='store_true', default=False, 57 | help='test time augmentation (use CustomSevenCrop)') 58 | parser.add_argument('--tta-custom-ten-crop', action='store_true', default=False, 59 | help='test time augmentation (use CustomTenCrop)') 60 | parser.add_argument('--tta-custom-twenty-crop', action='store_true', default=False, 61 | help='test time augmentation (use CustomTwentyCrop)') 62 | 63 | # data preprocess 64 | parser.add_argument('--scale-size', type=int, default=None, 65 | help='scale size (default: auto)') 66 | parser.add_argument('--input-size', type=int, default=None, 67 | help='input size (default: auto)') 68 | parser.add_argument('--rgb-mean', type=str, default=None, 69 | help='RGB mean (default: auto)') 70 | parser.add_argument('--rgb-std', type=str, default=None, 71 | help='RGB std (default: auto)') 72 | parser.add_argument('--interpolation', type=str, default=None, 73 | choices=[None, 'BILINEAR', 'BICUBIC', 'NEAREST'], 74 | help='interpolation. (default: auto)') 75 | 76 | # misc 77 | parser.add_argument('--no-cuda', action='store_true', default=False, 78 | help='disables CUDA training') 79 | parser.add_argument('--seed', default=None, type=int, 80 | help='seed for initializing training. 
') 81 | 82 | 83 | def main(): 84 | global args 85 | args = parser.parse_args() 86 | args.cuda = not args.no_cuda and torch.cuda.is_available() 87 | 88 | if args.prefix == 'auto': 89 | args.prefix = datetime.datetime.now().strftime('%Y%m%d%H%M%S') 90 | 91 | formatter = logging.Formatter('%(message)s') 92 | logzero.formatter(formatter) 93 | 94 | if not os.path.exists(args.log_dir): 95 | os.makedirs(args.log_dir, exist_ok=True) 96 | 97 | log_filename = "{}-test.log".format(args.prefix) 98 | log_filepath = os.path.join(args.log_dir, log_filename) 99 | logzero.logfile(log_filepath) 100 | 101 | if args.workers is None: 102 | args.workers = max(1, int(0.8 * cpu_count())) 103 | elif args.workers == -1: 104 | args.workers = cpu_count() 105 | 106 | cudnn.benchmark = True 107 | 108 | logger.info('Running script with args: {}'.format(str(args))) 109 | 110 | checkpoint = load_checkpoint(args, args.model) 111 | logger.info("=> loaded the model (epoch {})".format(checkpoint['epoch'])) 112 | model_arch = checkpoint['arch'] 113 | model_args = checkpoint['args'] 114 | 115 | if args.scale_size: 116 | scale_size = args.scale_size 117 | else: 118 | scale_size = model_args.scale_size 119 | if args.input_size: 120 | input_size = args.input_size 121 | else: 122 | input_size = model_args.input_size 123 | 124 | if args.rgb_mean: 125 | rgb_mean = args.rgb_mean 126 | rgb_mean = [float(mean) for mean in rgb_mean.split(',')] 127 | else: 128 | rgb_mean = model_args.rgb_mean 129 | if args.rgb_std: 130 | rgb_std = args.rgb_std 131 | rgb_std = [float(std) for std in rgb_std.split(',')] 132 | else: 133 | rgb_std = model_args.rgb_std 134 | 135 | if args.interpolation: 136 | interpolation = args.interpolation 137 | else: 138 | try: 139 | interpolation = model_args.interpolation 140 | except AttributeError: 141 | interpolation = 'BILINEAR' 142 | 143 | logger.info("scale_size: {} input_size: {}".format(scale_size, input_size)) 144 | logger.info("rgb_mean: {}".format(rgb_mean)) 145 | logger.info("rgb_std: {}".format(rgb_std)) 146 | logger.info("interpolation: {}".format(interpolation)) 147 | 148 | interpolation = getattr(Image, interpolation, 2) 149 | 150 | # Data augmentation and normalization for test 151 | data_transforms = { 152 | 'test': transforms.Compose([ 153 | transforms.Resize(scale_size, interpolation=interpolation), 154 | transforms.CenterCrop(input_size), 155 | transforms.ToTensor(), 156 | transforms.Normalize(rgb_mean, rgb_std) 157 | ]), 158 | 'test_FiveCrop': transforms.Compose([ 159 | transforms.Resize(scale_size, interpolation=interpolation), 160 | transforms.FiveCrop(input_size), 161 | transforms.Lambda(lambda crops: torch.stack( 162 | [transforms.ToTensor()(crop) for crop in crops])), 163 | transforms.Lambda(lambda crops: torch.stack( 164 | [transforms.Normalize(rgb_mean, rgb_std)(crop) for crop in crops])) 165 | ]), 166 | 'test_TenCrop': transforms.Compose([ 167 | transforms.Resize(scale_size, interpolation=interpolation), 168 | transforms.TenCrop(input_size), 169 | transforms.Lambda(lambda crops: torch.stack( 170 | [transforms.ToTensor()(crop) for crop in crops])), 171 | transforms.Lambda(lambda crops: torch.stack( 172 | [transforms.Normalize(rgb_mean, rgb_std)(crop) for crop in crops])) 173 | ]), 174 | 'test_CustomSixCrop': transforms.Compose([ 175 | transforms.Resize(scale_size, interpolation=interpolation), 176 | CustomSixCrop(input_size), 177 | transforms.Lambda(lambda crops: torch.stack( 178 | [transforms.ToTensor()(crop) for crop in crops])), 179 | transforms.Lambda(lambda crops: torch.stack( 180 | 
[transforms.Normalize(rgb_mean, rgb_std)(crop) for crop in crops])) 181 | ]), 182 | 'test_CustomSevenCrop': transforms.Compose([ 183 | transforms.Resize(scale_size, interpolation=interpolation), 184 | CustomSevenCrop(input_size), 185 | transforms.Lambda(lambda crops: torch.stack( 186 | [transforms.ToTensor()(crop) for crop in crops])), 187 | transforms.Lambda(lambda crops: torch.stack( 188 | [transforms.Normalize(rgb_mean, rgb_std)(crop) for crop in crops])) 189 | ]), 190 | 'test_CustomTenCrop': transforms.Compose([ 191 | transforms.Resize(scale_size, interpolation=interpolation), 192 | CustomTenCrop(input_size), 193 | transforms.Lambda(lambda crops: torch.stack( 194 | [transforms.ToTensor()(crop) for crop in crops])), 195 | transforms.Lambda(lambda crops: torch.stack( 196 | [transforms.Normalize(rgb_mean, rgb_std)(crop) for crop in crops])) 197 | ]), 198 | 'test_CustomTwentyCrop': transforms.Compose([ 199 | transforms.Resize(scale_size, interpolation=interpolation), 200 | CustomTwentyCrop(input_size), 201 | transforms.Lambda(lambda crops: torch.stack( 202 | [transforms.ToTensor()(crop) for crop in crops])), 203 | transforms.Lambda(lambda crops: torch.stack( 204 | [transforms.Normalize(rgb_mean, rgb_std)(crop) for crop in crops])) 205 | ]) 206 | 207 | } 208 | 209 | tfms = 'test' 210 | if args.tta: 211 | tfms = 'test_FiveCrop' 212 | batch_size = args.batch_size // 5 213 | elif args.tta_ten_crop: 214 | tfms = 'test_TenCrop' 215 | batch_size = args.batch_size // 10 216 | elif args.tta_custom_six_crop: 217 | tfms = 'test_CustomSixCrop' 218 | batch_size = args.batch_size // 6 219 | elif args.tta_custom_seven_crop: 220 | tfms = 'test_CustomSevenCrop' 221 | batch_size = args.batch_size // 7 222 | elif args.tta_custom_ten_crop: 223 | tfms = 'test_CustomTenCrop' 224 | batch_size = args.batch_size // 10 225 | elif args.tta_custom_twenty_crop: 226 | tfms = 'test_CustomTwentyCrop' 227 | batch_size = args.batch_size // 20 228 | else: 229 | batch_size = args.batch_size 230 | 231 | image_datasets = { 232 | 'test': ImageFolderWithPaths(os.path.join(args.data, 'test'), data_transforms[tfms]) 233 | } 234 | 235 | test_num_classes = len(image_datasets['test'].classes) 236 | test_class_names = image_datasets['test'].classes 237 | 238 | kwargs = {'num_workers': args.workers, 'pin_memory': True} if args.cuda else {} 239 | test_loader = torch.utils.data.DataLoader( 240 | image_datasets['test'], batch_size=batch_size, shuffle=False, **kwargs) 241 | 242 | logger.info("number of test dataset: {}".format(len(test_loader.dataset))) 243 | logger.info("number of classes: {}".format(len(test_class_names))) 244 | 245 | model, num_classes, class_names = load_model_from_checkpoint(args, checkpoint, test_num_classes, test_class_names) 246 | 247 | if args.topk > num_classes: 248 | logger.warn('--topk must be less than or equal to the class number of the model') 249 | args.topk = num_classes 250 | logger.warn('--topk set to {}'.format(num_classes)) 251 | 252 | # check test and train class names 253 | do_report = True 254 | if test_num_classes != num_classes: 255 | logger.info("The number of classes for train and test is different.") 256 | logger.info("Skip accuracy report.") 257 | do_report = False 258 | 259 | test(args, model_arch, model, test_loader, class_names, do_report) 260 | logger.info("=> Saved test log to \"{}\"".format(log_filepath)) 261 | 262 | 263 | def test(args, model_arch, model, test_loader, class_names, do_report): 264 | model.eval() 265 | test_accuracy = Metric('test_accuracy') 266 | test_loss = 
Metric('test_loss') 267 | 268 | pred = [] 269 | Y = [] 270 | correct_num = 0 271 | 272 | filepath = '{}-test-results.log'.format(args.prefix) 273 | savepath = os.path.join(args.log_dir, filepath) 274 | f = open(savepath, 'w') 275 | 276 | softmax = torch.nn.Softmax(dim=1) 277 | criterion = nn.CrossEntropyLoss() 278 | 279 | with tqdm(total=len(test_loader), desc='Test') as t: 280 | with torch.no_grad(): 281 | for data, target, paths in test_loader: 282 | if args.cuda: 283 | data = data.cuda(non_blocking=True) 284 | target = target.cuda(non_blocking=True) 285 | if args.tta or args.tta_ten_crop or \ 286 | args.tta_custom_ten_crop or args.tta_custom_twenty_crop or \ 287 | args.tta_custom_six_crop or args.tta_custom_seven_crop: 288 | bs, ncrops, c, h, w = data.size() 289 | output = model(data.view(-1, c, h, w)) 290 | output = output.view(bs, ncrops, -1).mean(1) 291 | else: 292 | output = model(data) 293 | 294 | if do_report: 295 | pred += [int(l.argmax()) for l in output] 296 | Y += [int(l) for l in target] 297 | 298 | for path, y, preds in zip(paths, target, softmax(output)): 299 | probabilities, labels = preds.topk(args.topk) 300 | preds_text = '' 301 | for i in range(args.topk): 302 | preds_text += " {} {}".format(labels[i], probabilities[i]) 303 | f.write("{} {}{}\n".format(path, int(y), preds_text)) 304 | 305 | if str(y.item()) == str(labels[0].item()): 306 | correct_num += 1 307 | 308 | if do_report: 309 | test_accuracy.update(accuracy(output, target)) 310 | test_loss.update(criterion(output, target)) 311 | t.set_postfix({'loss': test_loss.avg.item(), 312 | 'accuracy': 100. * test_accuracy.avg.item()}) 313 | t.update(1) 314 | 315 | f.close() 316 | logger.info("=> Saved test results to \"{}\"".format(savepath)) 317 | 318 | if do_report: 319 | 320 | cr_filepath = '{}-test-classification_report.log'.format(args.prefix) 321 | cr_savepath = os.path.join(args.log_dir, cr_filepath) 322 | 323 | cr = classification_report(Y, pred, target_names=class_names) 324 | if args.print_cr: 325 | print(cr) 326 | with open(cr_savepath, 'w') as crf: 327 | crf.write(cr) 328 | logger.info("=> Saved classification report to \"{}\"".format(cr_savepath)) 329 | 330 | logger.info("model: {}".format(args.model)) 331 | logger.info("Test-loss: {}".format(test_loss.avg)) 332 | logger.info("Test-accuracy: {} ({}/{})".format((correct_num / len(test_loader.dataset)), correct_num, len(test_loader.dataset))) 333 | 334 | 335 | if __name__ == '__main__': 336 | main() 337 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import argparse 5 | import logging 6 | import math 7 | import os 8 | import sys 9 | import time 10 | 11 | import logzero 12 | import numpy as np 13 | import tensorboardX 14 | import torch 15 | import torch.nn as nn 16 | 17 | import torch.backends.cudnn as cudnn 18 | 19 | from cnn_finetune import make_model 20 | from PIL import ImageFile 21 | from torchvision.utils import save_image 22 | from logzero import logger 23 | 24 | from util.dataloader import get_dataloader, CutoutForBatchImages, RandomErasingForBatchImages 25 | from util.functions import check_args, get_lr, print_batch, report, report_lr, save_model, accuracy, Metric, rand_bbox 26 | from util.optimizer import get_optimizer 27 | from util.scheduler import get_cosine_annealing_lr_scheduler, get_multi_step_lr_scheduler, get_reduce_lr_on_plateau_scheduler 28 | 29 | import signal 30 
import warnings

signal.signal(signal.SIGINT, signal.default_int_handler)
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)

ImageFile.LOAD_TRUNCATED_IMAGES = True

# argparse
parser = argparse.ArgumentParser(description='train')
parser.add_argument('data', metavar='DIR', help='path to dataset (train and validation)')

# model architecture
parser.add_argument('--model', '-m', metavar='ARCH', default='resnet18',
                    help='specify model architecture (default: resnet18)')
parser.add_argument('--from-scratch', dest='scratch', action='store_true',
                    help='do not use pre-trained weights (default: False)')

# epochs, batch size, etc.
parser.add_argument('--epochs', type=int, default=30, help='number of total epochs to run (default: 30)')
parser.add_argument('--batch-size', '-b', type=int, default=128, help='the batch size (default: 128)')
parser.add_argument('--val-batch-size', type=int, default=256, help='the validation batch size (default: 256)')
parser.add_argument('-j', '--workers', type=int, default=None,
                    help='number of data loading workers (default: 80%% of the number of cores)')
parser.add_argument('--prefix', default='auto',
                    help="prefix of model and logs (default: auto)")
parser.add_argument('--log-dir', default='logs',
                    help='log directory (default: logs)')
parser.add_argument('--model-dir', default='model',
                    help='model saving dir (default: model)')
parser.add_argument('--resume', default=None, type=str, metavar='MODEL',
                    help='path to saved model (default: None)')
parser.add_argument('--start-epoch', default=0, type=int,
                    help='manual epoch number (default: 0)')
parser.add_argument('--disp-batches', type=int, default=0,
                    help='show progress for every n batches (default: auto)')
parser.add_argument('--save-best-only', action='store_true', default=False,
                    help='save only the latest best model according to the validation accuracy (default: False)')
parser.add_argument('--save-best-and-last', action='store_true', default=False,
                    help='save last and latest best model according to the validation accuracy (default: False)')
parser.add_argument('--drop-last', action='store_true', default=False,
                    help='drop the last incomplete batch, if the dataset size is not divisible by the batch size. (default: False)')

# optimizer, lr, etc
parser.add_argument('--base-lr', type=float, default=0.001,
                    help='initial learning rate (default: 0.001)')
parser.add_argument('--lr-factor', type=float, default=0.1,
                    help='the ratio to reduce lr on each step (default: 0.1)')
parser.add_argument('--lr-step-epochs', type=str, default='10,20',
                    help='the epochs to reduce the lr (default: 10,20)')
parser.add_argument('--lr-patience', type=int, default=None,
                    help='enable ReduceLROnPlateau lr scheduler with specified patience (default: None)')
parser.add_argument('--cosine-annealing-t-max', type=int, default=None,
                    help='enable CosineAnnealingLR scheduler with specified T_max (default: None)')
parser.add_argument('--cosine-annealing-mult', type=int, default=2,
                    help='T_mult of CosineAnnealingLR scheduler')
parser.add_argument('--cosine-annealing-eta-min', type=float, default=1e-05,
                    help='Minimum learning rate of CosineAnnealingLR scheduler')
parser.add_argument('--final-lr', type=float, default=0.1,
                    help='final_lr of AdaBound optimizer')
parser.add_argument('--optimizer', type=str, default='sgd',
                    help='the optimizer type (default: sgd)')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='momentum (default: 0.9)')
parser.add_argument('--wd', type=float, default=1e-04,
                    help='weight decay (default: 1e-04)')
parser.add_argument('--warmup-epochs', type=float, default=5,
                    help='number of warmup epochs (default: 5)')

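# Note: while epoch < --warmup-epochs, get_warmup_lr_adj() (defined near the
# bottom of this file) returns a factor that ramps the learning rate up
# roughly linearly over the warmup period.
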
# data preprocess and augmentation settings
parser.add_argument('--scale-size', type=int, default=None,
                    help='scale size (default: auto)')
parser.add_argument('--input-size', type=int, default=None,
                    help='input size (default: auto)')
parser.add_argument('--rgb-mean', type=str, default='0,0,0',
                    help='RGB mean (default: 0,0,0)')
parser.add_argument('--rgb-std', type=str, default='1,1,1',
                    help='RGB std (default: 1,1,1)')
parser.add_argument('--random-resized-crop-scale', type=str, default='0.08,1.0',
                    help='range of size of the origin size cropped (default: 0.08,1.0)')
parser.add_argument('--random-resized-crop-ratio', type=str, default='0.75,1.3333333333333333',
                    help='range of aspect ratio of the origin aspect ratio cropped (default: 0.75,1.3333333333333333)')
parser.add_argument('--random-horizontal-flip', type=float, default=0.5,
                    help='probability of the image being flipped (default: 0.5)')
parser.add_argument('--random-vertical-flip', type=float, default=0.0,
                    help='probability of the image being flipped (default: 0.0)')
parser.add_argument('--jitter-brightness', type=float, default=0.10,
                    help='jitter brightness of data augmentation (default: 0.10)')
parser.add_argument('--jitter-contrast', type=float, default=0.10,
                    help='jitter contrast of data augmentation (default: 0.10)')
parser.add_argument('--jitter-saturation', type=float, default=0.10,
                    help='jitter saturation of data augmentation (default: 0.10)')
parser.add_argument('--jitter-hue', type=float, default=0.05,
                    help='jitter hue of data augmentation (default: 0.05)')
parser.add_argument('--random-rotate-degree', type=float, default=3.0,
                    help='rotate degree of data augmentation (default: 3.0)')
parser.add_argument('--interpolation', type=str, default='BILINEAR',
                    choices=['BILINEAR', 'BICUBIC', 'NEAREST'],
                    help='interpolation. (default: BILINEAR)')

parser.add_argument('--image-dump', action='store_true', default=False,
                    help='dump batch images and exit (default: False)')
parser.add_argument('--calc-rgb-mean-and-std', action='store_true', default=False,
                    help='calculate rgb mean and std of train images and exit (default: False)')

# misc
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training (default: False)')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. (default: None)')
parser.add_argument('--warm_restart_next', type=int, default=None,
                    help='next warm restart epoch (default: None)')
parser.add_argument('--warm_restart_current', type=int, default=None,
                    help='current warm restart epoch (default: None)')

# cutout
parser.add_argument('--cutout', action='store_true', default=False,
                    help='apply cutout (default: False)')
parser.add_argument('--cutout-holes', type=int, default=1,
                    help='number of holes to cut out from image (default: 1)')
parser.add_argument('--cutout-length', type=int, default=64,
                    help='length of the holes (default: 64)')

# random erasing
parser.add_argument('--random-erasing', action='store_true', default=False,
                    help='apply random erasing (default: False)')
parser.add_argument('--random-erasing-p', type=float, default=0.5,
                    help='random erasing p (default: 0.5)')
parser.add_argument('--random-erasing-sl', type=float, default=0.02,
                    help='random erasing sl (default: 0.02)')
parser.add_argument('--random-erasing-sh', type=float, default=0.4,
                    help='random erasing sh (default: 0.4)')
parser.add_argument('--random-erasing-r1', type=float, default=0.3,
                    help='random erasing r1 (default: 0.3)')
parser.add_argument('--random-erasing-r2', type=float, default=1/0.3,
                    help='random erasing r2 (default: 3.3333333333333335)')

# mixup
parser.add_argument('--mixup', action='store_true', default=False,
                    help='apply mixup (default: False)')
parser.add_argument('--mixup-alpha', type=float, default=0.2,
                    help='mixup alpha (default: 0.2)')

# ricap
parser.add_argument('--ricap', action='store_true', default=False,
                    help='apply RICAP (default: False)')
parser.add_argument('--ricap-beta', type=float, default=0.3,
                    help='RICAP beta (default: 0.3)')
parser.add_argument('--ricap-with-line', action='store_true', default=False,
                    help='RICAP with boundary line (default: False)')

# icap
parser.add_argument('--icap', action='store_true', default=False,
                    help='apply ICAP (default: False)')
parser.add_argument('--icap-beta', type=float, default=0.3,
                    help='ICAP beta (default: 0.3)')

# cutmix
parser.add_argument('--cutmix', action='store_true', default=False,
                    help='apply CutMix (default: False)')
parser.add_argument('--cutmix-beta', type=float, default=1.0,
                    help='CutMix beta (default: 1.0)')
parser.add_argument('--cutmix-prob', type=float, default=1.0,
                    help='CutMix probability (default: 1.0)')


best_acc1 = 0


def main():
    global args, best_acc1
    args = parser.parse_args()

    args = check_args(args)

    formatter = logging.Formatter('%(message)s')
    logzero.formatter(formatter)

    if not 
os.path.exists(args.log_dir): 208 | os.makedirs(args.log_dir, exist_ok=True) 209 | 210 | log_filename = "{}-train.log".format(args.prefix) 211 | logzero.logfile(os.path.join(args.log_dir, log_filename)) 212 | 213 | # calc rgb_mean and rgb_std 214 | if args.calc_rgb_mean_and_std: 215 | calc_rgb_mean_and_std(args, logger) 216 | 217 | # setup dataset 218 | train_loader, train_num_classes, train_class_names, valid_loader, _valid_num_classes, _valid_class_names \ 219 | = get_dataloader(args, args.scale_size, args.input_size) 220 | 221 | if args.disp_batches == 0: 222 | target = len(train_loader) // 10 223 | args.disp_batches = target - target % 5 224 | if args.disp_batches < 5: 225 | args.disp_batches = 1 226 | 227 | logger.info('Running script with args: {}'.format(str(args))) 228 | logger.info("scale_size: {} input_size: {}".format(args.scale_size, args.input_size)) 229 | logger.info("rgb_mean: {}".format(args.rgb_mean)) 230 | logger.info("rgb_std: {}".format(args.rgb_std)) 231 | logger.info("number of train dataset: {}".format(len(train_loader.dataset))) 232 | logger.info("number of validation dataset: {}".format(len(valid_loader.dataset))) 233 | logger.info("number of classes: {}".format(len(train_class_names))) 234 | 235 | if args.mixup: 236 | logger.info("Using mixup: alpha:{}".format(args.mixup_alpha)) 237 | if args.ricap: 238 | logger.info("Using RICAP: beta:{}".format(args.ricap_beta)) 239 | if args.icap: 240 | logger.info("Using ICAP: beta:{}".format(args.icap_beta)) 241 | if args.cutmix: 242 | logger.info("Using CutMix: prob:{} beta:{}".format(args.cutmix_prob, args.cutmix_beta)) 243 | if args.cutout: 244 | logger.info("Using cutout: holes:{} length:{}".format(args.cutout_holes, args.cutout_length)) 245 | if args.random_erasing: 246 | logger.info("Using Random Erasing: p:{} s_l:{} s_h:{} r1:{} r2:{}".format( 247 | args.random_erasing_p, args.random_erasing_sl, args.random_erasing_sh, 248 | args.random_erasing_r1, args.random_erasing_r2)) 249 | 250 | device = torch.device("cuda" if args.cuda else "cpu") 251 | 252 | # create model 253 | if args.resume: 254 | # resume from a checkpoint 255 | if os.path.isfile(args.resume): 256 | logger.info("=> loading saved checkpoint '{}'".format(args.resume)) 257 | checkpoint = torch.load(args.resume, map_location=device) 258 | args.model = checkpoint['arch'] 259 | base_model = make_model(args.model, 260 | num_classes=train_num_classes, 261 | pretrained=False, 262 | input_size=(args.input_size, args.input_size)) 263 | base_model.load_state_dict(checkpoint['model']) 264 | args.start_epoch = checkpoint['epoch'] 265 | best_acc1 = float(checkpoint['acc1']) 266 | logger.info("=> loaded checkpoint '{}' (epoch {})" 267 | .format(args.resume, checkpoint['epoch'])) 268 | else: 269 | logger.error("=> no checkpoint found at '{}'".format(args.resume)) 270 | sys.exit(1) 271 | else: 272 | if args.scratch: 273 | # train from scratch 274 | logger.info("=> creating model '{}' (train from scratch)".format(args.model)) 275 | base_model = make_model(args.model, 276 | num_classes=train_num_classes, 277 | pretrained=False, 278 | input_size=(args.input_size, args.input_size)) 279 | else: 280 | # fine-tuning 281 | logger.info("=> using pre-trained model '{}'".format(args.model)) 282 | base_model = make_model(args.model, 283 | num_classes=train_num_classes, 284 | pretrained=True, 285 | input_size=(args.input_size, args.input_size)) 286 | 287 | if args.cuda: 288 | logger.info("=> using GPU") 289 | model = nn.DataParallel(base_model) 290 | model.to(device) 291 | else: 292 | 
logger.info("=> using CPU") 293 | model = base_model 294 | 295 | # define loss function (criterion) and optimizer 296 | criterion = nn.CrossEntropyLoss() 297 | optimizer = get_optimizer(args, model) 298 | logger.info('=> using optimizer: {}'.format(args.optimizer)) 299 | if args.resume: 300 | optimizer.load_state_dict(checkpoint['optimizer']) 301 | logger.info("=> restore optimizer state from checkpoint") 302 | 303 | # create scheduler 304 | if args.lr_patience: 305 | scheduler = get_reduce_lr_on_plateau_scheduler(args, optimizer) 306 | logger.info("=> using ReduceLROnPlateau scheduler") 307 | elif args.cosine_annealing_t_max: 308 | scheduler = get_cosine_annealing_lr_scheduler(args, optimizer, args.cosine_annealing_t_max, len(train_loader)) 309 | logger.info("=> using CosineAnnealingLR scheduler") 310 | else: 311 | scheduler = get_multi_step_lr_scheduler(args, optimizer, args.lr_step_epochs, args.lr_factor) 312 | logger.info("=> using MultiStepLR scheduler") 313 | if args.resume: 314 | scheduler.load_state_dict(checkpoint['scheduler']) 315 | logger.info("=> restore lr scheduler state from checkpoint") 316 | 317 | logger.info("=> model and logs prefix: {}".format(args.prefix)) 318 | logger.info("=> log dir: {}".format(args.log_dir)) 319 | logger.info("=> model dir: {}".format(args.model_dir)) 320 | tensorboradX_log_dir = os.path.join(args.log_dir, "{}-tensorboardX".format(args.prefix)) 321 | log_writer = tensorboardX.SummaryWriter(tensorboradX_log_dir) 322 | logger.info("=> tensorboardX log dir: {}".format(tensorboradX_log_dir)) 323 | 324 | if args.cuda: 325 | cudnn.benchmark = True 326 | 327 | if args.lr_patience: # ReduceLROnPlateau 328 | scheduler.step(float('inf')) 329 | elif not args.cosine_annealing_t_max: # MultiStepLR 330 | scheduler.step() 331 | 332 | # for CosineAnnealingLR 333 | if args.resume: 334 | args.warm_restart_next = checkpoint['args'].warm_restart_next 335 | args.warm_restart_current = checkpoint['args'].warm_restart_current 336 | else: 337 | if args.cosine_annealing_t_max: # CosineAnnealingLR 338 | args.warm_restart_next = args.cosine_annealing_t_max + args.warmup_epochs 339 | args.warm_restart_current = args.warmup_epochs 340 | 341 | for epoch in range(args.start_epoch, args.epochs): 342 | start = time.time() 343 | 344 | # CosineAnnealingLR warm restart 345 | if args.cosine_annealing_t_max and (epoch % args.warm_restart_next == 0) and epoch != 0: 346 | current_span = args.warm_restart_next - args.warm_restart_current 347 | next_span = current_span * args.cosine_annealing_mult 348 | args.warm_restart_current = args.warm_restart_next 349 | args.warm_restart_next = args.warm_restart_next + next_span 350 | scheduler = get_cosine_annealing_lr_scheduler(args, optimizer, next_span, len(train_loader)) 351 | 352 | if args.mixup: 353 | train(args, 'mixup', train_loader, model, device, criterion, optimizer, scheduler, epoch, logger, log_writer) 354 | elif args.ricap: 355 | train(args, 'ricap', train_loader, model, device, criterion, optimizer, scheduler, epoch, logger, log_writer) 356 | elif args.icap: 357 | train(args, 'icap', train_loader, model, device, criterion, optimizer, scheduler, epoch, logger, log_writer) 358 | elif args.cutmix: 359 | train(args, 'cutmix', train_loader, model, device, criterion, optimizer, scheduler, epoch, logger, log_writer) 360 | else: 361 | train(args, 'normal', train_loader, model, device, criterion, optimizer, scheduler, epoch, logger, log_writer) 362 | 363 | report_lr(epoch, 'x_learning_rate', get_lr(optimizer), logger, log_writer) 364 | 365 | 
        acc1 = valid(args, valid_loader, model, device, criterion, optimizer, scheduler, epoch, logger, log_writer)

        elapsed_time = time.time() - start
        logger.info("Epoch[{}] Time cost: {} [sec]".format(epoch, elapsed_time))

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        save_model(args, base_model, optimizer, scheduler, is_best, train_num_classes, train_class_names, epoch, acc1, logger)


def train(args, train_mode, train_loader, model, device, criterion, optimizer, scheduler, epoch, logger, log_writer):
    total_size = 0
    data_size = len(train_loader.dataset)
    train_loss = Metric('train_loss')
    train_accuracy = Metric('train_accuracy')
    model.train()

    start = time.time()

    if train_mode in ['mixup', 'ricap', 'icap']:
        if args.cutout:
            batch_cutout = CutoutForBatchImages(n_holes=args.cutout_holes, length=args.cutout_length)
        if args.random_erasing:
            batch_random_erasing = RandomErasingForBatchImages(p=args.random_erasing_p,
                                                               sl=args.random_erasing_sl,
                                                               sh=args.random_erasing_sh,
                                                               r1=args.random_erasing_r1,
                                                               r2=args.random_erasing_r2)

    for batch_idx, (data, target, _paths) in enumerate(train_loader):
        adjust_learning_rate(args, epoch, batch_idx, train_loader, optimizer, scheduler, logger)

        if train_mode == 'mixup':
            alpha = args.mixup_alpha
            if alpha > 0:
                lam = np.random.beta(alpha, alpha)
            else:
                lam = 1
            index = torch.randperm(data.size(0))
            mixed_data = lam * data + (1 - lam) * data[index, :]
            target_a, target_b = target, target[index]

            if args.cutout:
                mixed_data = batch_cutout(mixed_data)
            if args.random_erasing:
                mixed_data = batch_random_erasing(mixed_data)

        elif train_mode == 'ricap':
            beta = args.ricap_beta
            I_x, I_y = data.size()[2:]
            w = int(np.round(I_x * np.random.beta(beta, beta)))
            h = int(np.round(I_y * np.random.beta(beta, beta)))
            w_ = [w, I_x - w, w, I_x - w]
            h_ = [h, h, I_y - h, I_y - h]
            cropped_images = {}
            c_ = {}
            W_ = {}
            if args.cuda:
                data = data.cuda(non_blocking=True)
            for k in range(4):
                index = torch.randperm(data.size(0))
                x_k = np.random.randint(0, I_x - w_[k] + 1)
                y_k = np.random.randint(0, I_y - h_[k] + 1)
                cropped_images[k] = data[index][:, :, x_k:x_k + w_[k], y_k:y_k + h_[k]]
                if args.cuda:
                    c_[k] = target[index].cuda(non_blocking=True)
                else:
                    c_[k] = target[index]
                W_[k] = w_[k] * h_[k] / (I_x * I_y)
            patched_images = torch.cat(
                (torch.cat((cropped_images[0], cropped_images[1]), 2),
                 torch.cat((cropped_images[2], cropped_images[3]), 2)), 3)
            # draw lines
            if args.ricap_with_line:
                patched_images[:, :, w-1:w+1, :] = 0.
                patched_images[:, :, :, h-1:h+1] = 0.

            if args.cutout:
                patched_images = batch_cutout(patched_images)
            if args.random_erasing:
                patched_images = batch_random_erasing(patched_images)

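        # ICAP (below) is a variant of RICAP: it also stitches four images into
        # a 2x2 patchwork and weights their labels by patch area via W_, but the
        # four regions tile a single (w, h) grid instead of being re-cropped at
        # random offsets per image as in RICAP above.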
443 | 444 | if args.cutout: 445 | patched_images = batch_cutout(patched_images) 446 | if args.random_erasing: 447 | patched_images = batch_random_erasing(patched_images) 448 | 449 | elif train_mode == 'icap': 450 | beta = args.icap_beta 451 | I_x, I_y = data.size()[2:] 452 | w = int(np.round(I_x * np.random.beta(beta, beta))) 453 | h = int(np.round(I_y * np.random.beta(beta, beta))) 454 | h_from = [0, 0, h, h] 455 | h_to = [h, h, I_y, I_y] 456 | w_from = [0, w, 0, w] 457 | w_to = [w, I_x, w, I_x] 458 | cropped_images = {} 459 | c_ = {} 460 | W_ = {} 461 | 462 | if args.cuda: 463 | data = data.cuda(non_blocking=True) 464 | for k in range(4): 465 | index = torch.randperm(data.size(0)) 466 | cropped_images[k] = data[index][:, :, h_from[k]:h_to[k], w_from[k]:w_to[k]] 467 | if args.cuda: 468 | c_[k] = target[index].cuda(non_blocking=True) 469 | else: 470 | c_[k] = target[index] 471 | W_[k] = (h_to[k] - h_from[k]) * (w_to[k] - w_from[k]) / (I_x * I_y) 472 | 473 | patched_images = torch.cat( 474 | (torch.cat((cropped_images[0], cropped_images[2]), 2), 475 | torch.cat((cropped_images[1], cropped_images[3]), 2)), 3) 476 | 477 | if args.cutout: 478 | patched_images = batch_cutout(patched_images) 479 | if args.random_erasing: 480 | patched_images = batch_random_erasing(patched_images) 481 | 482 | elif train_mode == 'cutmix': 483 | p = args.cutmix_prob 484 | beta = args.cutmix_beta 485 | r = np.random.rand(1) 486 | if beta > 0 and r < p: 487 | # generate mixed sample 488 | lam = np.random.beta(beta, beta) 489 | index = torch.randperm(data.size(0)) 490 | bbx1, bby1, bbx2, bby2 = rand_bbox(data.size(), lam) 491 | data[:, :, bbx1:bbx2, bby1:bby2] = data[index, :, bbx1:bbx2, bby1:bby2] 492 | mixed_data = data 493 | target_a, target_b = target, target[index] 494 | else: 495 | # fall back to normal train mode (note: this sticks for the remaining batches of the epoch) 496 | train_mode = 'normal' 497 | 498 | 499 | # normal train mode applies Cutout or Random Erasing inside dataloader 500 | 501 | if args.image_dump: 502 | if train_mode in ['mixup', 'cutmix']: 503 | save_image(mixed_data, './samples.jpg') 504 | elif train_mode in ['ricap', 'icap']: 505 | save_image(patched_images, './samples.jpg') 506 | else: 507 | save_image(data, './samples.jpg') 508 | logger.info("image saved! at ./samples.jpg")
509 | sys.exit(0) 510 | 511 | if args.cuda: 512 | if train_mode in ['mixup', 'cutmix']: 513 | mixed_data = mixed_data.cuda(non_blocking=True) 514 | target_a = target_a.cuda(non_blocking=True) 515 | target_b = target_b.cuda(non_blocking=True) 516 | elif train_mode in ['ricap', 'icap']: 517 | patched_images = patched_images.cuda(non_blocking=True) 518 | else: # vanilla train 519 | data = data.cuda(non_blocking=True) 520 | target = target.cuda(non_blocking=True) 521 | 522 | optimizer.zero_grad() 523 | 524 | # output and loss 525 | if train_mode in ['mixup', 'cutmix']: 526 | output = model(mixed_data) 527 | loss = lam * criterion(output, target_a) + (1 - lam) * criterion(output, target_b) 528 | elif train_mode in ['ricap', 'icap']: 529 | output = model(patched_images) 530 | loss = sum(W_[k] * criterion(output, c_[k]) for k in range(4)) 531 | else: # vanilla train 532 | output = model(data) 533 | loss = criterion(output, target) 534 | 535 | loss.backward() 536 | optimizer.step() 537 | 538 | if train_mode in ['mixup', 'cutmix']: 539 | train_accuracy.update(lam * accuracy(output, target_a) + (1 - lam) * accuracy(output, target_b)) 540 | elif train_mode in ['ricap', 'icap']: 541 | train_accuracy.update(sum([W_[k] * accuracy(output, c_[k]) for k in range(4)])) 542 | else: # vanilla train 543 | train_accuracy.update(accuracy(output, target)) 544 | 545 | train_loss.update(loss) 546 | 547 | total_size += data.size(0) 548 | if (batch_idx + 1) % args.disp_batches == 0: 549 | prop = 100. * (batch_idx+1) / len(train_loader) 550 | elapsed_time = time.time() - start 551 | speed = total_size / elapsed_time 552 | print_batch(batch_idx+1, epoch, total_size, data_size, prop, speed, train_accuracy.avg, train_loss.avg, logger) 553 | 554 | report(epoch, 'Train', 'train/loss', train_loss.avg, 'train/accuracy', train_accuracy.avg, logger, log_writer) 555 | 556 | 557 | def valid(args, valid_loader, model, device, criterion, optimizer, scheduler, epoch, logger, log_writer): 558 | valid_loss = Metric('valid_loss') 559 | valid_accuracy = Metric('valid_accuracy') 560 | model.eval() 561 | 562 | with torch.no_grad(): 563 | for (data, target, _paths) in valid_loader: 564 | if args.cuda: 565 | data = data.cuda(non_blocking=True) 566 | target = target.cuda(non_blocking=True) 567 | output = model(data) 568 | loss = criterion(output, target) 569 | 570 | valid_accuracy.update(accuracy(output, target)) 571 | valid_loss.update(loss) 572 | 573 | report(epoch, 'Validation', 'val/loss', valid_loss.avg, 'val/accuracy', valid_accuracy.avg, logger, log_writer) 574 | 575 | if args.lr_patience: # ReduceLROnPlateau 576 | scheduler.step(valid_loss.avg) 577 | elif not args.cosine_annealing_t_max: # MultiStepLR 578 | scheduler.step() 579 | 580 | return valid_accuracy.avg 581 | 582 | 583 | def get_warmup_lr_adj(args, epoch, batch_idx, train_loader, optimizer, logger): 584 | lr_adj = 1. 585 | if epoch < args.warmup_epochs: 586 | epoch = epoch * len(train_loader) 587 | epoch += float(batch_idx + 1) 588 | lr_adj = epoch / (len(train_loader) * (args.warmup_epochs + 1)) 589 | return lr_adj 590 | 591 | 592 | def adjust_learning_rate(args, epoch, batch_idx, train_loader, optimizer, scheduler, logger): 593 | lr_adj = 1.
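# lr_adj multiplies args.base_lr; it stays 1.0 except during warmup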
594 | if epoch < args.warmup_epochs: 595 | lr_adj = get_warmup_lr_adj(args, epoch, batch_idx, train_loader, optimizer, logger) 596 | for param_group in optimizer.param_groups: 597 | param_group['lr'] = args.base_lr * lr_adj 598 | elif epoch == args.warmup_epochs and batch_idx == 0: 599 | for param_group in optimizer.param_groups: 600 | param_group['lr'] = args.base_lr 601 | if args.cosine_annealing_t_max: 602 | scheduler.step() 603 | else: 604 | if args.cosine_annealing_t_max: 605 | scheduler.step() 606 | 607 | 608 | def calc_rgb_mean_and_std(args, logger): 609 | from util.dataloader import get_image_datasets_for_rgb_mean_and_std 610 | from util.functions import IncrementalVariance 611 | from tqdm import tqdm 612 | 613 | image_datasets = get_image_datasets_for_rgb_mean_and_std(args, args.scale_size, args.input_size) 614 | logger.info("=> Calculate rgb mean and std (dir: {} images: {} batch-size: {})".format(args.data, len(image_datasets), args.batch_size)) 615 | 616 | if args.batch_size < len(image_datasets): 617 | logger.info("To calculate more accurate values, please specify as large a batch size as possible.") 618 | 619 | kwargs = {'num_workers': args.workers} 620 | train_loader = torch.utils.data.DataLoader( 621 | image_datasets, batch_size=args.batch_size, shuffle=False, **kwargs) 622 | 623 | iv = IncrementalVariance() 624 | processed = 0 625 | with tqdm(total=len(train_loader), desc="Calc rgb mean/std") as t: 626 | for data, _target in train_loader: 627 | numpy_images = data.numpy() 628 | batch_mean = np.mean(numpy_images, axis=(0, 2, 3)) 629 | batch_var = np.var(numpy_images, axis=(0, 2, 3)) 630 | iv.update(batch_mean, len(numpy_images), batch_var) 631 | processed += len(numpy_images) 632 | t.update(1) 633 | 634 | logger.info("=> processed: {} images".format(processed)) 635 | logger.info("=> calculated rgb mean: {}".format(iv.average)) 636 | logger.info("=> calculated rgb std: {}".format(iv.std)) 637 | 638 | np.set_printoptions(formatter={'float': '{:0.3f}'.format}) 639 | rgb_mean_option = np.array2string(iv.average, separator=',').replace('[', '').replace(']', '') 640 | rgb_std_option = np.array2string(iv.std, separator=',').replace('[', '').replace(']', '') 641 | logger.info("Please use the following command options for train and test:") 642 | logger.info(" --rgb-mean {} --rgb-std {}".format(rgb_mean_option, rgb_std_option)) 643 | 644 | sys.exit(0) 645 | 646 | 647 | if __name__ == '__main__': 648 | main() 649 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/knjcode/pytorch-finetuner/e464c2bef1c42c7fccbcd48b3d7cd6239eb9b83d/util/__init__.py -------------------------------------------------------------------------------- /util/caltech101_prepare.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This script downloads the Caltech 101 dataset 4 | # (http://www.vision.caltech.edu/Image_Datasets/Caltech101/) and splits part of it 5 | # into the example_images directory. 6 | 7 | set -u 8 | shopt -s expand_aliases 9 | 10 | PWD=$(pwd) 11 | 12 | # number of images per class for each split 13 | TRAIN_NUM=60 14 | VALID_NUM=20 15 | TEST_NUM=20 16 | 17 | # target classes (10 classes) 18 | CLASSES="airplanes Motorbikes Faces watch Leopards bonsai car_side ketch chandelier hawksbill" 19 | 20 | if [ ! -e "$PWD/101_ObjectCategories.tar.gz" ]; then
-e "$PWD/101_ObjectCategories.tar.gz" ]; then 21 | if which wget > /dev/null 2>&1; then 22 | wget -O "$PWD/101_ObjectCategories.tar.gz" http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz 23 | else 24 | if which curl > /dev/null 2>&1; then 25 | curl http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz > "$PWD/101_ObjectCategories.tar.gz" 26 | else 27 | echo 'wget and curl commands not found.' 1>&2 28 | exit 1 29 | fi 30 | fi 31 | fi 32 | 33 | if which shuf > /dev/null; then 34 | alias shuffle='shuf' 35 | else 36 | if which gshuf > /dev/null; then 37 | alias shuffle='gshuf' 38 | else 39 | echo 'shuf or gshuf command not found.' 1>&2 40 | exit 1 41 | fi 42 | fi 43 | 44 | # check example_images 45 | if [ -e "$PWD/example_images" ]; then 46 | echo './example_images directory already exits. Remove it and retry.' 1>&2 47 | exit 1 48 | fi 49 | 50 | # split into train, validation and test set 51 | tar -xf "$PWD/101_ObjectCategories.tar.gz" 52 | TRAIN_DIR="$PWD/example_images/train" 53 | VALID_DIR="$PWD/example_images/valid" 54 | TEST_DIR="$PWD/example_images/test" 55 | mkdir -p "$TRAIN_DIR" "$VALID_DIR" "$TEST_DIR" 56 | for i in ${PWD}/101_ObjectCategories/*; do 57 | c=$(basename "$i") 58 | if echo "$CLASSES" | grep -q "$c" 59 | then 60 | echo "processing $c" 61 | mkdir -p "$TRAIN_DIR/$c" "$VALID_DIR/$c" "$TEST_DIR/$c" 62 | for j in $(find "$i" -name '*.jpg' | shuffle | head -n "$TRAIN_NUM"); do 63 | mv "$j" "$TRAIN_DIR/$c/" 64 | done 65 | for j in $(find "$i" -name '*.jpg' | shuffle | head -n "$VALID_NUM"); do 66 | mv "$j" "$VALID_DIR/$c/" 67 | done 68 | for j in $(find "$i" -name '*.jpg' | shuffle | head -n "$TEST_NUM"); do 69 | mv "$j" "$TEST_DIR/$c/" 70 | done 71 | fi 72 | done 73 | 74 | # touch .gitignore 75 | touch "$TRAIN_DIR/.gitkeep" "$VALID_DIR/.gitkeep" "$TEST_DIR/.gitkeep" 76 | 77 | # clean 78 | rm -rf "$PWD/101_ObjectCategories/" 79 | -------------------------------------------------------------------------------- /util/counter.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Count the number of files/directories in each subdirectory 4 | # 5 | # Usage: 6 | # $ util/counter.sh [options] 7 | # 8 | # Options: 9 | # -r Sort in descending order of the number of files 10 | # 11 | 12 | shopt -s expand_aliases 13 | 14 | usage_exit() { 15 | echo "Usage: util/counter.sh [options] " 1>&2 16 | echo "Options:" 1>&2 17 | echo " -r Sort in descending order of the number of files" 1>&2 18 | exit 1 19 | } 20 | 21 | count_files() { 22 | local DIR="$1" 23 | find "$DIR" -maxdepth 1 -mindepth 1 -type d | while read -r subdir;do 24 | echo "$(basename "$subdir"),$(echo $(ls "$subdir" | wc -l))"; 25 | done 26 | } 27 | 28 | count_directories() { 29 | local DIR="$1" 30 | echo $(find ${DIR}/* -maxdepth 0 -type d | wc -l) 31 | } 32 | 33 | # Reference: http://qiita.com/b4b4r07/items/dcd6be0bb9c9185475bb 34 | declare -i argc=0 35 | declare -a argv=() 36 | while (( $# > 0 )) 37 | do 38 | case "$1" in 39 | -*) 40 | if [[ "$1" =~ 'r' ]]; then 41 | REVERSE=1 42 | fi 43 | if [[ "$1" =~ 'h' ]]; then 44 | usage_exit 45 | fi 46 | shift 47 | ;; 48 | *) 49 | ((++argc)) 50 | argv=("${argv[@]}" "$1") 51 | shift 52 | ;; 53 | esac 54 | done 55 | 56 | if [ $argc -lt 1 ]; then 57 | usage_exit 58 | fi 59 | 60 | DIR="${argv[0]}" 61 | 62 | FILES=$(count_files "$DIR") 63 | CLASSES=$(count_directories "$DIR") 64 | RESULT=$(echo "$FILES" | sort -n -t',' -k2 | tr ',' ' ') 65 | 66 | if which tac >/dev/null; 
67 | alias tac='tac' 68 | else 69 | alias tac='tail -r' 70 | fi 71 | 72 | if [[ "$REVERSE" = 1 ]]; then 73 | RESULT=$(echo "$RESULT" | tac) 74 | fi 75 | 76 | echo "$DIR contains $CLASSES directories" 77 | echo "$RESULT" | column -t 78 | -------------------------------------------------------------------------------- /util/dataloader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import math 5 | import os 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from torchvision import datasets, transforms 11 | from PIL import Image 12 | 13 | 14 | # generate train and validation image dataset 15 | def get_image_datasets(args, scale_size, input_size): 16 | interpolation = getattr(Image, args.interpolation, 2) # falls back to PIL's BILINEAR (2) 17 | 18 | data_transforms = { 19 | 'train': transforms.Compose([ 20 | transforms.RandomRotation(args.random_rotate_degree, 21 | resample=interpolation), 22 | transforms.RandomResizedCrop(input_size, 23 | scale=args.random_resized_crop_scale, 24 | ratio=args.random_resized_crop_ratio, 25 | interpolation=interpolation), 26 | transforms.RandomHorizontalFlip(args.random_horizontal_flip), 27 | transforms.RandomVerticalFlip(args.random_vertical_flip), 28 | transforms.ColorJitter( 29 | brightness=args.jitter_brightness, 30 | contrast=args.jitter_contrast, 31 | saturation=args.jitter_saturation, 32 | hue=args.jitter_hue 33 | ), 34 | transforms.ToTensor(), 35 | transforms.Normalize(args.rgb_mean, args.rgb_std) 36 | ]), 37 | 'valid': transforms.Compose([ 38 | transforms.Resize(scale_size, interpolation=interpolation), 39 | transforms.CenterCrop(input_size), 40 | transforms.ToTensor(), 41 | transforms.Normalize(args.rgb_mean, args.rgb_std) 42 | ]) 43 | } 44 | 45 | # use Cutout 46 | if args.cutout: 47 | if args.mixup or args.ricap: 48 | # When using mixup or ricap, cutout is applied after batch creation for learning 49 | pass 50 | else: 51 | data_transforms['train'] = transforms.Compose([ 52 | transforms.RandomRotation(args.random_rotate_degree, 53 | resample=interpolation), 54 | transforms.RandomResizedCrop(input_size, 55 | scale=args.random_resized_crop_scale, 56 | ratio=args.random_resized_crop_ratio, 57 | interpolation=interpolation), 58 | transforms.RandomHorizontalFlip(args.random_horizontal_flip), 59 | transforms.RandomVerticalFlip(args.random_vertical_flip), 60 | transforms.ColorJitter( 61 | brightness=args.jitter_brightness, 62 | contrast=args.jitter_contrast, 63 | saturation=args.jitter_saturation, 64 | hue=args.jitter_hue 65 | ), 66 | transforms.ToTensor(), 67 | transforms.Normalize(args.rgb_mean, args.rgb_std), 68 | Cutout(n_holes=args.cutout_holes, length=args.cutout_length) 69 | ]) 70 | else: 71 | args.cutout_holes = None 72 | args.cutout_length = None 73 | 74 | # use Random erasing 75 | if args.random_erasing: 76 | if args.mixup or args.ricap: 77 | # When using mixup or ricap, random erasing is applied after batch creation for learning 78 | pass 79 | else: 80 | data_transforms['train'] = transforms.Compose([ 81 | transforms.RandomRotation(args.random_rotate_degree, 82 | resample=interpolation), 83 | transforms.RandomResizedCrop(input_size, 84 | scale=args.random_resized_crop_scale, 85 | ratio=args.random_resized_crop_ratio, 86 | interpolation=interpolation), 87 | transforms.RandomHorizontalFlip(args.random_horizontal_flip), 88 | transforms.RandomVerticalFlip(args.random_vertical_flip), 89 | transforms.ColorJitter( 90 | brightness=args.jitter_brightness, 91 | contrast=args.jitter_contrast, 92 | 
saturation=args.jitter_saturation, 93 | hue=args.jitter_hue 94 | ), 95 | transforms.ToTensor(), 96 | transforms.Normalize(args.rgb_mean, args.rgb_std), 97 | RandomErasing(p=args.random_erasing_p, 98 | sl=args.random_erasing_sl, 99 | sh=args.random_erasing_sh, 100 | r1=args.random_erasing_r1, 101 | r2=args.random_erasing_r2) 102 | ]) 103 | else: 104 | args.random_erasing_p = None 105 | args.random_erasing_sl = None 106 | args.random_erasing_sh = None 107 | args.random_erasing_r1 = None 108 | args.random_erasing_r2 = None 109 | 110 | image_datasets = { 111 | x: ImageFolderWithPaths(os.path.join(args.data, x), data_transforms[x]) 112 | for x in ['train', 'valid'] 113 | } 114 | 115 | return image_datasets 116 | 117 | 118 | def get_image_datasets_for_rgb_mean_and_std(args, scale_size, input_size): 119 | transform = transforms.Compose([ 120 | transforms.Resize(scale_size), 121 | transforms.CenterCrop(input_size), 122 | transforms.ToTensor() 123 | ]) 124 | image_datasets = datasets.ImageFolder(args.data, transform=transform) 125 | return image_datasets 126 | 127 | 128 | # generate train and validation dataloaders 129 | def get_dataloader(args, scale_size, input_size): 130 | image_datasets = get_image_datasets(args, scale_size, input_size) 131 | 132 | train_num_classes = len(image_datasets['train'].classes) 133 | val_num_classes = len(image_datasets['valid'].classes) 134 | assert train_num_classes == val_num_classes, 'The number of classes for train and validation is different' 135 | 136 | train_class_names = image_datasets['train'].classes 137 | val_class_names = image_datasets['valid'].classes 138 | 139 | kwargs = {'num_workers': args.workers, 'pin_memory': True} if args.cuda else {} 140 | train_loader = torch.utils.data.DataLoader( 141 | image_datasets['train'], batch_size=args.batch_size, shuffle=True, drop_last=args.drop_last, **kwargs) 142 | val_loader = torch.utils.data.DataLoader( 143 | image_datasets['valid'], batch_size=args.val_batch_size, shuffle=False, **kwargs) 144 | 145 | return train_loader, train_num_classes, train_class_names, \ 146 | val_loader, val_num_classes, val_class_names 147 | 148 | 149 | # https://arxiv.org/pdf/1708.04552.pdf 150 | # modified from https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py 151 | class Cutout(object): 152 | """Randomly mask out one or more patches from an image. 153 | Args: 154 | n_holes (int): Number of patches to cut out of each image. 155 | length (int): The length (in pixels) of each square patch. 156 | """ 157 | def __init__(self, n_holes, length): 158 | self.n_holes = n_holes 159 | self.length = length 160 | 161 | def __call__(self, img): 162 | """ 163 | Args: 164 | img (Tensor): Tensor image of size (C, H, W). 165 | Returns: 166 | Tensor: Image with n_holes of dimension length x length cut out of it. 167 | """ 168 | h = img.size(1) 169 | w = img.size(2) 170 | 171 | mask_value = img.mean() 172 | 173 | for n in range(self.n_holes): 174 | top = np.random.randint(0 - self.length // 2, h) 175 | left = np.random.randint(0 - self.length // 2, w) 176 | bottom = top + self.length 177 | right = left + self.length 178 | 179 | top = 0 if top < 0 else top 180 | left = 0 if left < 0 else left 181 | 182 | img[:, top:bottom, left:right].fill_(mask_value) 183 | 184 | return img 185 | 186 | 187 | # https://arxiv.org/pdf/1708.04552.pdf 188 | # modified from https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py 189 | class CutoutForBatchImages(object): 190 | """Randomly mask out one or more patches from an image. 
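Batched variant of Cutout that operates on (N, C, H, W) tensors.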
191 | Args: 192 | n_holes (int): Number of patches to cut out of each image. 193 | length (int): The length (in pixels) of each square patch. 194 | """ 195 | def __init__(self, n_holes, length): 196 | self.n_holes = n_holes 197 | self.length = length 198 | 199 | def __call__(self, img): 200 | """ 201 | Args: 202 | batch images (Tensor): Tensor images of size (N, C, H, W). 203 | Returns: 204 | Tensor: Images with n_holes of dimension length x length cut out of it. 205 | """ 206 | h = img.size(2) 207 | w = img.size(3) 208 | 209 | mask_value = img.mean() 210 | 211 | for i in range(img.size(0)): 212 | for n in range(self.n_holes): 213 | top = np.random.randint(0 - self.length // 2, h) 214 | left = np.random.randint(0 - self.length // 2, w) 215 | bottom = top + self.length 216 | right = left + self.length 217 | 218 | top = 0 if top < 0 else top 219 | left = 0 if left < 0 else left 220 | 221 | img[i, :, top:bottom, left:right].fill_(mask_value) 222 | 223 | return img 224 | 225 | 226 | # https://arxiv.org/pdf/1708.04896.pdf 227 | # modified from https://github.com/zhunzhong07/Random-Erasing/blob/master/transforms.py 228 | class RandomErasing(object): 229 | ''' 230 | Class that performs Random Erasing in Random Erasing Data Augmentation by Zhong et al. 231 | ------------------------------------------------------------------------------------- 232 | p: The probability that the operation will be performed. 233 | sl: min erasing area 234 | sh: max erasing area 235 | r1: min aspect ratio 236 | r2: max aspect ratio 237 | ------------------------------------------------------------------------------------- 238 | ''' 239 | def __init__(self, p=0.5, sl=0.02, sh=0.4, r1=0.3, r2=1/0.3): 240 | self.p = p 241 | self.sl = sl 242 | self.sh = sh 243 | self.r1 = r1 244 | self.r2 = r2 245 | 246 | def __call__(self, img): 247 | """ 248 | Args: 249 | img (Tensor): Tensor image of size (C, H, W). 250 | Returns: 251 | Tensor: Image with Random erasing. 252 | """ 253 | if np.random.uniform(0, 1) > self.p: 254 | return img 255 | 256 | area = img.size()[1] * img.size()[2] 257 | for _attempt in range(100): 258 | target_area = np.random.uniform(self.sl, self.sh) * area 259 | aspect_ratio = np.random.uniform(self.r1, self.r2) 260 | 261 | h = int(round(math.sqrt(target_area * aspect_ratio))) 262 | w = int(round(math.sqrt(target_area / aspect_ratio))) 263 | 264 | if w < img.size()[2] and h < img.size()[1]: 265 | x1 = np.random.randint(0, img.size()[1] - h) 266 | y1 = np.random.randint(0, img.size()[2] - w) 267 | img[:, x1:x1+h, y1:y1+w] = torch.from_numpy(np.random.rand(3, h, w)) 268 | return img 269 | 270 | return img 271 | 272 | 273 | # https://arxiv.org/pdf/1708.04896.pdf 274 | # modified from https://github.com/zhunzhong07/Random-Erasing/blob/master/transforms.py 275 | class RandomErasingForBatchImages(object): 276 | ''' 277 | Class that performs Random Erasing in Random Erasing Data Augmentation by Zhong et al. 278 | ------------------------------------------------------------------------------------- 279 | p: The probability that the operation will be performed. 
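(checked independently for each image in the batch)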
280 | sl: min erasing area 281 | sh: max erasing area 282 | r1: min aspect ratio 283 | r2: max aspect ratio 284 | ------------------------------------------------------------------------------------- 285 | ''' 286 | def __init__(self, p=0.5, sl=0.02, sh=0.4, r1=0.3, r2=1/0.3): 287 | self.p = p 288 | self.sl = sl 289 | self.sh = sh 290 | self.r1 = r1 291 | self.r2 = r2 292 | 293 | def __call__(self, img): 294 | """ 295 | Args: 296 | batch images (Tensor): Tensor images of size (N, C, H, W). 297 | Returns: 298 | Tensor: Images with Random erasing. 299 | """ 300 | area = img.size()[2] * img.size()[3] 301 | for i in range(img.size(0)): 302 | 303 | if np.random.uniform(0, 1) > self.p: 304 | continue 305 | 306 | for _attempt in range(100): 307 | target_area = np.random.uniform(self.sl, self.sh) * area 308 | aspect_ratio = np.random.uniform(self.r1, self.r2) 309 | 310 | h = int(round(math.sqrt(target_area * aspect_ratio))) 311 | w = int(round(math.sqrt(target_area / aspect_ratio))) 312 | 313 | if w < img.size()[3] and h < img.size()[2]: 314 | x1 = np.random.randint(0, img.size()[2] - h) 315 | y1 = np.random.randint(0, img.size()[3] - w) 316 | img[i, :, x1:x1+h, y1:y1+w] = torch.from_numpy(np.random.rand(1, 3, h, w)) 317 | break 318 | 319 | return img 320 | 321 | 322 | # taken from https://gist.github.com/andrewjong/6b02ff237533b3b2c554701fb53d5c4d 323 | class ImageFolderWithPaths(datasets.ImageFolder): 324 | """Custom dataset that includes image file paths. Extends 325 | torchvision.datasets.ImageFolder 326 | """ 327 | 328 | # override the __getitem__ method. this is the method dataloader calls 329 | def __getitem__(self, index): 330 | # this is what ImageFolder normally returns 331 | original_tuple = super(ImageFolderWithPaths, self).__getitem__(index) 332 | # the image file path 333 | path = self.imgs[index][0] 334 | # make a new tuple that includes original and the path 335 | tuple_with_path = (original_tuple + (path,)) 336 | return tuple_with_path 337 | -------------------------------------------------------------------------------- /util/functions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import datetime 5 | import numbers 6 | import os 7 | import shutil 8 | import sys 9 | import warnings 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | import torch.backends.cudnn as cudnn 15 | 16 | from cnn_finetune import make_model 17 | 18 | from multiprocessing import cpu_count 19 | from torchvision.transforms.functional import center_crop, hflip, vflip, resize 20 | 21 | 22 | class IncrementalVariance(object): 23 | def __init__(self, avg=0, count=0, var=0): 24 | self.avg = avg 25 | self.count = count 26 | self.var = var 27 | 28 | def update(self, avg, count, var): 29 | delta = self.avg - avg 30 | m_a = self.var * (self.count - 1) 31 | m_b = var * (count - 1) 32 | M2 = m_a + m_b + delta ** 2 * self.count * count / (self.count + count) 33 | self.var = M2 / (self.count + count - 1) 34 | self.avg = (self.avg * self.count + avg * count) / (self.count + count) 35 | self.count = self.count + count 36 | 37 | @property 38 | def average(self): 39 | return self.avg 40 | 41 | @property 42 | def variance(self): 43 | return self.var 44 | 45 | @property 46 | def std(self): 47 | return np.sqrt(self.var) 48 | 49 | 50 | class Metric(object): 51 | def __init__(self, name): 52 | self.name = name 53 | self.sum = torch.tensor(0.) 54 | self.n = torch.tensor(0.) 
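# running sum and update count; the avg property returns sum / n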
55 | 56 | def update(self, val): 57 | self.sum += val.item() 58 | self.n += 1 59 | 60 | @property 61 | def avg(self): 62 | return self.sum / self.n 63 | 64 | 65 | def accuracy(output, target): 66 | pred = output.max(1, keepdim=True)[1] 67 | return pred.eq(target.view_as(pred)).cpu().float().mean() 68 | 69 | 70 | def print_batch(batch, epoch, current_num, total_num, ratio, speed, average_acc, average_loss, logger): 71 | logger.info('Epoch[{}] Batch[{}] [{}/{} ({:.0f}%)]\tspeed: {:.2f} samples/sec\taccuracy: {:.10f}\tloss: {:.10f}'.format( 72 | epoch, batch, current_num, total_num, ratio, speed, average_acc, average_loss)) 73 | 74 | 75 | def report(epoch, phase, loss_name, loss_avg, acc_name, acc_avg, logger, log_writer): 76 | logger.info("Epoch[{}] {}-accuracy: {}".format(epoch, phase, acc_avg)) 77 | logger.info("Epoch[{}] {}-loss: {}".format(epoch, phase, loss_avg)) 78 | if log_writer: 79 | log_writer.add_scalar(loss_name, loss_avg, epoch) 80 | log_writer.add_scalar(acc_name, acc_avg, epoch) 81 | 82 | 83 | def report_lr(epoch, lr_name, lr, logger, log_writer): 84 | logger.info("Epoch[{}] learning-rate: {}".format(epoch, lr)) 85 | if log_writer: 86 | log_writer.add_scalar(lr_name, lr, epoch) 87 | 88 | 89 | def get_lr(optimizer): 90 | for param_group in optimizer.param_groups: 91 | return param_group['lr'] 92 | 93 | 94 | def save_model(args, base_model, optimizer, scheduler, is_best, num_classes, class_names, epoch, acc1, logger): 95 | filepath = '{}-{}-{:04}.model'.format(args.prefix, args.model, epoch+1) 96 | savepath = os.path.join(args.model_dir, filepath) 97 | state = { 98 | 'model': base_model.state_dict(), 99 | 'optimizer': optimizer.state_dict(), 100 | 'scheduler': scheduler.state_dict(), 101 | 'arch': args.model, 102 | 'num_classes': num_classes, 103 | 'class_names': class_names, 104 | 'args': args, 105 | 'epoch': epoch + 1, 106 | 'acc1': float(acc1) 107 | } 108 | os.makedirs(args.model_dir, exist_ok=True) 109 | 110 | if not (args.save_best_only or args.save_best_and_last): 111 | torch.save(state, savepath) 112 | logger.info("=> Saved checkpoint to \"{}\"".format(savepath)) 113 | 114 | if is_best: 115 | filepath = '{}-{}-best.model'.format(args.prefix, args.model) 116 | bestpath = os.path.join(args.model_dir, filepath) 117 | if args.save_best_only or args.save_best_and_last: 118 | torch.save(state, bestpath) 119 | else: 120 | shutil.copyfile(savepath, bestpath) 121 | logger.info("=> Saved checkpoint to \"{}\"".format(bestpath)) 122 | 123 | if (args.epochs - 1 == epoch) and args.save_best_and_last: 124 | torch.save(state, savepath) 125 | logger.info("=> Saved checkpoint to \"{}\"".format(savepath)) 126 | 127 | 128 | def load_checkpoint(args, model_path): 129 | device = torch.device("cuda" if args.cuda else "cpu") 130 | print("=> loading saved checkpoint '{}'".format(model_path)) 131 | checkpoint = torch.load(model_path, map_location=device) 132 | return checkpoint 133 | 134 | 135 | def load_model_from_checkpoint(args, checkpoint, test_num_classes, test_class_names): 136 | device = torch.device("cuda" if args.cuda else "cpu") 137 | model_arch = checkpoint['arch'] 138 | num_classes = checkpoint.get('num_classes', 0) 139 | if num_classes == 0: 140 | num_classes = test_num_classes 141 | base_model = make_model(model_arch, num_classes=num_classes, pretrained=False) 142 | base_model.load_state_dict(checkpoint['model']) 143 | class_names = checkpoint.get('class_names', []) 144 | if len(class_names) == 0: 145 | class_names = test_class_names 146 | 147 | if args.cuda: 148 | model = 
nn.DataParallel(base_model) 149 | else: 150 | model = base_model 151 | model.to(device) 152 | 153 | return model, num_classes, class_names 154 | 155 | 156 | def check_args(args): 157 | args.cuda = not args.no_cuda and torch.cuda.is_available() 158 | 159 | if args.mixup and args.ricap: 160 | warnings.warn('Only one of --mixup and --ricap can be activated.') 161 | sys.exit(1) 162 | 163 | if args.cutout and args.random_erasing: 164 | warnings.warn('Only one of --cutout and --random-erasing can be activated.') 165 | sys.exit(1) 166 | 167 | try: 168 | args.lr_step_epochs = [int(epoch) for epoch in args.lr_step_epochs.split(',')] 169 | except ValueError: 170 | warnings.warn('invalid --lr-step-epochs') 171 | sys.exit(1) 172 | 173 | try: 174 | args.random_resized_crop_scale = [float(scale) for scale in args.random_resized_crop_scale.split(',')] 175 | if len(args.random_resized_crop_scale) != 2: 176 | raise ValueError 177 | except ValueError: 178 | warnings.warn('invalid --random-resized-crop-scale') 179 | sys.exit(1) 180 | 181 | try: 182 | args.random_resized_crop_ratio = [float(ratio) for ratio in args.random_resized_crop_ratio.split(',')] 183 | if len(args.random_resized_crop_ratio) != 2: 184 | raise ValueError 185 | except ValueError: 186 | warnings.warn('invalid --random-resized-crop-ratio') 187 | sys.exit(1) 188 | 189 | if args.prefix == 'auto': 190 | args.prefix = datetime.datetime.now().strftime('%Y%m%d%H%M%S') 191 | 192 | if args.workers is None: 193 | args.workers = max(1, int(0.8 * cpu_count())) 194 | elif args.workers == -1: 195 | args.workers = cpu_count() 196 | 197 | if args.seed is not None: 198 | np.random.seed(args.seed) 199 | torch.manual_seed(args.seed) 200 | cudnn.deterministic = True 201 | warnings.warn('You have chosen to seed training. ' 202 | 'This will turn on the CUDNN deterministic setting, ' 203 | 'which can slow down your training considerably! ' 204 | 'You may see unexpected behavior when restarting ' 205 | 'from checkpoints.') 206 | 207 | args.rgb_mean = [float(mean) for mean in args.rgb_mean.split(',')] 208 | args.rgb_std = [float(std) for std in args.rgb_std.split(',')] 209 | 210 | if args.model == 'pnasnet5large': 211 | scale_size = 352 212 | input_size = 331 213 | elif 'inception' in args.model: 214 | scale_size = 320 215 | input_size = 299 216 | elif 'xception' in args.model: 217 | scale_size = 320 218 | input_size = 299 219 | else: 220 | scale_size = 256 221 | input_size = 224 222 | 223 | if args.scale_size: 224 | scale_size = args.scale_size 225 | else: 226 | args.scale_size = scale_size 227 | if args.input_size: 228 | input_size = args.input_size 229 | else: 230 | args.input_size = input_size 231 | 232 | if not args.cutout: 233 | args.cutout_holes = None 234 | args.cutout_length = None 235 | 236 | if not args.random_erasing: 237 | args.random_erasing_p = None 238 | args.random_erasing_r1 = None 239 | args.random_erasing_r2 = None 240 | args.random_erasing_sh = None 241 | args.random_erasing_sl = None 242 | 243 | if not args.mixup: 244 | args.mixup_alpha = None 245 | 246 | if not args.ricap: 247 | args.ricap_beta = None 248 | args.ricap_with_line = False 249 | 250 | return args 251 | 252 | 253 | def custom_six_crop(img, size): 254 | """Crop the given PIL Image into custom six crops. 255 | .. Note:: 256 | This transform returns a tuple of images and there may be a 257 | mismatch in the number of inputs and targets your ``Dataset`` returns. 258 | Args: 259 | size (sequence or int): Desired output size of the crop. 
If size is an 260 | int instead of sequence like (h, w), a square crop (size, size) is 261 | made. 262 | Returns: 263 | tuple: tuple (tl, tr, bl, br, center, full) 264 | """ 265 | if isinstance(size, numbers.Number): 266 | size = (int(size), int(size)) 267 | else: 268 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 269 | 270 | w, h = img.size 271 | crop_h, crop_w = size 272 | if crop_w > w or crop_h > h: 273 | raise ValueError("Requested crop size {} is bigger than input size {}".format(size, 274 | (h, w))) 275 | tl = img.crop((0, 0, crop_w, crop_h)) 276 | tr = img.crop((w - crop_w, 0, w, crop_h)) 277 | bl = img.crop((0, h - crop_h, crop_w, h)) 278 | br = img.crop((w - crop_w, h - crop_h, w, h)) 279 | center = center_crop(img, (crop_h, crop_w)) 280 | full = resize(img, (crop_h, crop_w)) 281 | return (tl, tr, bl, br, center, full) 282 | 283 | 284 | def custom_seven_crop(img, size): 285 | """Crop the given PIL Image into custom seven crops. 286 | .. Note:: 287 | This transform returns a tuple of images and there may be a 288 | mismatch in the number of inputs and targets your ``Dataset`` returns. 289 | Args: 290 | size (sequence or int): Desired output size of the crop. If size is an 291 | int instead of sequence like (h, w), a square crop (size, size) is 292 | made. 293 | Returns: 294 | tuple: tuple (tl, tr, bl, br, center, semi_full, full) 295 | """ 296 | if isinstance(size, numbers.Number): 297 | size = (int(size), int(size)) 298 | else: 299 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 300 | 301 | w, h = img.size 302 | crop_h, crop_w = size 303 | if crop_w > w or crop_h > h: 304 | raise ValueError("Requested crop size {} is bigger than input size {}".format(size, 305 | (h, w))) 306 | shift_w = int(round(w - crop_w) / 4.) 307 | shift_h = int(round(h - crop_h) / 4.) 308 | 309 | tl = img.crop((0, 0, crop_w, crop_h)) 310 | tr = img.crop((w - crop_w, 0, w, crop_h)) 311 | bl = img.crop((0, h - crop_h, crop_w, h)) 312 | br = img.crop((w - crop_w, h - crop_h, w, h)) 313 | center = center_crop(img, (crop_h, crop_w)) 314 | semi_full = resize(img.crop((shift_w, shift_h, w - shift_w, h - shift_h)), (crop_h, crop_w)) 315 | full = resize(img, (crop_h, crop_w)) 316 | return (tl, tr, bl, br, center, semi_full, full) 317 | 318 | 319 | def custom_ten_crop(img, size): 320 | """Crop the given PIL Image into custom ten crops. 321 | .. Note:: 322 | This transform returns a tuple of images and there may be a 323 | mismatch in the number of inputs and targets your ``Dataset`` returns. 324 | Args: 325 | size (sequence or int): Desired output size of the crop. If size is an 326 | int instead of sequence like (h, w), a square crop (size, size) is 327 | made. 328 | Returns: 329 | tuple: tuple (tl, tr, bl, br, center, tl2, tr2, bl2, br2, full) 330 | """ 331 | if isinstance(size, numbers.Number): 332 | size = (int(size), int(size)) 333 | else: 334 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 335 | 336 | w, h = img.size 337 | crop_h, crop_w = size 338 | if crop_w > w or crop_h > h: 339 | raise ValueError("Requested crop size {} is bigger than input size {}".format(size, 340 | (h, w))) 341 | shift_w = int(round(w - crop_w) / 4.) 342 | shift_h = int(round(h - crop_h) / 4.) 
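# quarter of the remaining margin, used to offset the shifted crops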
343 | 344 | tl = img.crop((0, 0, crop_w, crop_h)) 345 | tr = img.crop((w - crop_w, 0, w, crop_h)) 346 | bl = img.crop((0, h - crop_h, crop_w, h)) 347 | br = img.crop((w - crop_w, h - crop_h, w, h)) 348 | center = center_crop(img, (crop_h, crop_w)) 349 | tl2 = img.crop((shift_w, shift_h, crop_w + shift_w, crop_h + shift_h)) # + + 350 | tr2 = img.crop((w - crop_w - shift_w, shift_h, w - shift_w, crop_h + shift_h)) # - + 351 | bl2 = img.crop((shift_w, h - crop_h - shift_h, crop_w + shift_w, h - shift_h)) # + - 352 | br2 = img.crop((w - crop_w - shift_w, h - crop_h - shift_h, w - shift_w, h - shift_h)) # - - 353 | full = resize(img, (crop_h, crop_w)) 354 | return (tl, tr, bl, br, center, tl2, tr2, bl2, br2, full) 355 | 356 | 357 | def custom_twenty_crop(img, size, vertical_flip=False): 358 | r"""Crop the given PIL Image into custom twenty crops. 359 | .. Note:: 360 | This transform returns a tuple of images and there may be a 361 | mismatch in the number of inputs and targets your ``Dataset`` returns. 362 | Args: 363 | size (sequence or int): Desired output size of the crop. If size is an 364 | int instead of sequence like (h, w), a square crop (size, size) is 365 | made. 366 | vertical_flip (bool): Use vertical flipping instead of horizontal 367 | Returns: 368 | tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip) 369 | Corresponding top left, top right, bottom left, bottom right and center crop 370 | and same for the flipped image. 371 | """ 372 | if isinstance(size, numbers.Number): 373 | size = (int(size), int(size)) 374 | else: 375 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 376 | 377 | first_ten = custom_ten_crop(img, size) 378 | 379 | if vertical_flip: 380 | img = vflip(img) 381 | else: 382 | img = hflip(img) 383 | 384 | second_ten = custom_ten_crop(img, size) 385 | return first_ten + second_ten 386 | 387 | 388 | def rand_bbox(size, lam): 389 | W = size[2] 390 | H = size[3] 391 | cut_rat = np.sqrt(1. - lam) # the cut box covers a (1 - lam) fraction of the image area 392 | cut_w = int(W * cut_rat) 393 | cut_h = int(H * cut_rat) 394 | 395 | # uniform 396 | cx = np.random.randint(W) 397 | cy = np.random.randint(H) 398 | 399 | bbx1 = np.clip(cx - cut_w // 2, 0, W) 400 | bby1 = np.clip(cy - cut_h // 2, 0, H) 401 | bbx2 = np.clip(cx + cut_w // 2, 0, W) 402 | bby2 = np.clip(cy + cut_h // 2, 0, H) 403 | 404 | return bbx1, bby1, bbx2, bby2 405 | 406 | 407 | class CustomSixCrop(object): 408 | def __init__(self, size): 409 | self.size = size 410 | if isinstance(size, numbers.Number): 411 | self.size = (int(size), int(size)) 412 | else: 413 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 414 | self.size = size 415 | 416 | def __call__(self, img): 417 | return custom_six_crop(img, self.size) 418 | 419 | def __repr__(self): 420 | return self.__class__.__name__ + '(size={0})'.format(self.size) 421 | 422 | 423 | class CustomSevenCrop(object): 424 | def __init__(self, size): 425 | self.size = size 426 | if isinstance(size, numbers.Number): 427 | self.size = (int(size), int(size)) 428 | else: 429 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 
430 | self.size = size 431 | 432 | def __call__(self, img): 433 | return custom_seven_crop(img, self.size) 434 | 435 | def __repr__(self): 436 | return self.__class__.__name__ + '(size={0})'.format(self.size) 437 | 438 | 439 | class CustomTenCrop(object): 440 | def __init__(self, size): 441 | self.size = size 442 | if isinstance(size, numbers.Number): 443 | self.size = (int(size), int(size)) 444 | else: 445 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 446 | self.size = size 447 | 448 | def __call__(self, img): 449 | return custom_ten_crop(img, self.size) 450 | 451 | def __repr__(self): 452 | return self.__class__.__name__ + '(size={0})'.format(self.size) 453 | 454 | 455 | class CustomTwentyCrop(object): 456 | def __init__(self, size): 457 | self.size = size 458 | if isinstance(size, numbers.Number): 459 | self.size = (int(size), int(size)) 460 | else: 461 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 462 | self.size = size 463 | 464 | def __call__(self, img): 465 | return custom_twenty_crop(img, self.size) 466 | 467 | def __repr__(self): 468 | return self.__class__.__name__ + '(size={0})'.format(self.size) 469 | -------------------------------------------------------------------------------- /util/optimizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import torch.optim as optim 5 | import adabound 6 | 7 | 8 | def get_optimizer(args, model): 9 | if args.optimizer == 'sgd': 10 | optimizer = optim.SGD(model.parameters(), lr=args.base_lr, momentum=args.momentum, weight_decay=args.wd) 11 | elif args.optimizer == 'nag': 12 | optimizer = optim.SGD(model.parameters(), lr=args.base_lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True) 13 | elif args.optimizer == 'adam': 14 | optimizer = optim.Adam(model.parameters(), lr=args.base_lr, weight_decay=args.wd) 15 | elif args.optimizer == 'amsgrad': 16 | optimizer = optim.Adam(model.parameters(), lr=args.base_lr, weight_decay=args.wd, amsgrad=True) 17 | elif args.optimizer == 'rmsprop': 18 | optimizer = optim.RMSprop(model.parameters(), lr=args.base_lr, momentum=args.momentum, weight_decay=args.wd) 19 | elif args.optimizer == 'adabound': 20 | optimizer = adabound.AdaBound(model.parameters(), lr=args.base_lr, weight_decay=args.wd, final_lr=args.final_lr) 21 | elif args.optimizer == 'amsbound': 22 | optimizer = adabound.AdaBound(model.parameters(), lr=args.base_lr, weight_decay=args.wd, final_lr=args.final_lr, amsbound=True) 23 | else: 24 | raise ValueError('unknown optimizer: {}'.format(args.optimizer)) 25 | 26 | return optimizer 27 | -------------------------------------------------------------------------------- /util/scheduler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import torch.optim as optim 5 | 6 | 7 | def get_cosine_annealing_lr_scheduler(args, optimizer, max_epoch, iteration): 8 | return optim.lr_scheduler.CosineAnnealingLR(optimizer, max_epoch * iteration, eta_min=args.cosine_annealing_eta_min) 9 | 10 | 11 | def get_multi_step_lr_scheduler(args, optimizer, lr_step_epochs, lr_factor): 12 | return optim.lr_scheduler.MultiStepLR(optimizer, lr_step_epochs, lr_factor) 13 | 14 | 15 | def get_reduce_lr_on_plateau_scheduler(args, optimizer): 16 | return optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=args.lr_factor, patience=args.lr_patience, verbose=True) 17 | 
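18 | # Scheduler selection sketch: this mirrors the dispatch already shown in train.py and is 19 | # included here for reference only (not part of this module's public API). 20 | # 21 | # if args.lr_patience: 22 | # scheduler = get_reduce_lr_on_plateau_scheduler(args, optimizer) # stepped once per epoch with the validation loss 23 | # elif args.cosine_annealing_t_max: 24 | # scheduler = get_cosine_annealing_lr_scheduler(args, optimizer, args.cosine_annealing_t_max, len(train_loader)) 25 | # # T_max is max_epoch * iteration, so this scheduler is stepped once per batch 26 | # else: 27 | # scheduler = get_multi_step_lr_scheduler(args, optimizer, args.lr_step_epochs, args.lr_factor) # stepped once per epoch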
--------------------------------------------------------------------------------