├── LICENSE_ADAIN
├── LICENSE_CUTMIX
├── LICENSE_PUZZLEMIX
├── NOTICE_CUTMIX
├── README.md
├── function.py
├── makeStyleDistanceMatrix.py
├── makeStyleDistanceMatrix.sh
├── models
│   ├── decoder.pth.tar
│   └── vgg_normalised.pth
├── net_cutmix.py
├── net_mixup.py
├── net_styleDistance.py
├── pyramidnet.py
├── styleDistanceMatrix10
├── styleDistanceMatrix10.pt
├── styleDistanceMatrix100
├── styleDistanceMatrix100.pt
├── test.py
├── test.sh
├── train.py
└── utils.py
/LICENSE_ADAIN: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Naoto Inoue 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /LICENSE_CUTMIX: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019-present NAVER Corp. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE.
-------------------------------------------------------------------------------- /LICENSE_PUZZLEMIX: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jang-Hyun Kim, Wonho Choo, and Hyun Oh Song 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /NOTICE_CUTMIX: -------------------------------------------------------------------------------- 1 | CutMix 2 | Copyright (c) 2019-present NAVER Corp. 3 | 4 | This project contains subcomponents with separate copyright notices and license terms. 5 | Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses. 6 | 7 | ======================================================================= 8 | pytorch/vision from https://github.com/pytorch/vision 9 | ======================================================================= 10 | 11 | BSD 3-Clause License 12 | 13 | Copyright (c) Soumith Chintala 2016, 14 | All rights reserved. 15 | 16 | Redistribution and use in source and binary forms, with or without 17 | modification, are permitted provided that the following conditions are met: 18 | 19 | * Redistributions of source code must retain the above copyright notice, this 20 | list of conditions and the following disclaimer. 21 | 22 | * Redistributions in binary form must reproduce the above copyright notice, 23 | this list of conditions and the following disclaimer in the documentation 24 | and/or other materials provided with the distribution. 25 | 26 | * Neither the name of the copyright holder nor the names of its 27 | contributors may be used to endorse or promote products derived from 28 | this software without specific prior written permission. 29 | 30 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 31 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 32 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 33 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 34 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 35 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 36 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 37 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 38 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 39 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 40 | 41 | ======================================================================= 42 | dyhan0920/PyramidNet-PyTorch from https://github.com/dyhan0920/PyramidNet-PyTorch 43 | ======================================================================= 44 | 45 | MIT License 46 | 47 | Copyright (c) 2019 Dongyoon Han 48 | 49 | Permission is hereby granted, free of charge, to any person obtaining a copy 50 | of this software and associated documentation files (the "Software"), to deal 51 | in the Software without restriction, including without limitation the rights 52 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 53 | copies of the Software, and to permit persons to whom the Software is 54 | furnished to do so, subject to the following conditions: 55 | 56 | The above copyright notice and this permission notice shall be included in all 57 | copies or substantial portions of the Software. 58 | 59 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 60 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 61 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 62 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 63 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 64 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 65 | SOFTWARE. 66 | 67 | ======================================================================= 68 | eladhoffer/convNet.pytorch from https://github.com/eladhoffer/convNet.pytorch 69 | ======================================================================= 70 | 71 | MIT License 72 | 73 | Copyright (c) 2017 Elad Hoffer 74 | 75 | Permission is hereby granted, free of charge, to any person obtaining a copy 76 | of this software and associated documentation files (the "Software"), to deal 77 | in the Software without restriction, including without limitation the rights 78 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 79 | copies of the Software, and to permit persons to whom the Software is 80 | furnished to do so, subject to the following conditions: 81 | 82 | The above copyright notice and this permission notice shall be included in all 83 | copies or substantial portions of the Software. 84 | 85 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 86 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 87 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 88 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 89 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 90 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 91 | SOFTWARE. 
92 | 93 | ===== 94 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # StyleMix: Separating Content and Style for Enhanced Data Augmentation (CVPR 2021) 2 | 3 | This repository contains the official PyTorch implementation of our CVPR 2021 paper. 4 | - Minui Hong*, Jinwoo Choi* and Gunhee Kim. StyleMix: Separating Content and Style for Enhanced Data Augmentation. In CVPR, 2021. (* equal contribution) 5 | 6 | [[paper]](https://openaccess.thecvf.com/content/CVPR2021/papers/Hong_StyleMix_Separating_Content_and_Style_for_Enhanced_Data_Augmentation_CVPR_2021_paper.pdf)[[supp]](https://openaccess.thecvf.com/content/CVPR2021/supplemental/Hong_StyleMix_Separating_Content_CVPR_2021_supplemental.pdf) 7 | 8 | ## Reference 9 | 10 | If you use this code or refer to our paper, please cite it as follows: 11 | ```bibtex 12 | @InProceedings{hong2021stylemix, 13 | author = {Minui Hong and Jinwoo Choi and Gunhee Kim}, 14 | title = {StyleMix: Separating Content and Style for Enhanced Data Augmentation}, 15 | booktitle = {CVPR}, 16 | year = {2021} 17 | } 18 | ``` 19 | 20 | ## Usage 21 | 22 | 1) train.py : Code to train the model 23 | 2) train.sh : Script to run train.py 24 | 3) test.py : Code to evaluate CIFAR-10 and CIFAR-100 classification performance (with the fgsm option set to False) or to run an FGSM attack experiment (with the fgsm option set to True) on the PyramidNet-200 model. 25 | 4) test.sh : Script to run test.py 26 | 5) makeStyleDistanceMatrix.py : Code to generate styleDistanceMatrix10 and styleDistanceMatrix100 27 | 6) makeStyleDistanceMatrix.sh : Script to run makeStyleDistanceMatrix.py 28 | 7) models : The directory containing the pre-trained style transfer encoder and decoder networks. 29 | 30 | ## Train Model 31 | 32 | Modify the contents of train.sh according to each situation and run it. 33 | 34 | 1. StyleCutMix_Auto_Gamma + CIFAR-100 35 | ``` 36 | python train.py \ 37 | --net_type pyramidnet \ 38 | --dataset cifar100 \ 39 | --depth 200 \ 40 | --alpha 240 \ 41 | --batch_size 64 \ 42 | --lr 0.25 \ 43 | --expname PyraNet200 \ 44 | --epochs 300 \ 45 | --prob 0.5 \ 46 | --r 0.7 \ 47 | --delta 3.0 \ 48 | --method StyleCutMix_Auto_Gamma \ 49 | --save_dir /set/your/save/dir \ 50 | --data_dir /set/your/data/dir \ 51 | --no-verbose 52 | ``` 53 | 2. StyleCutMix_Auto_Gamma + CIFAR-10 54 | ``` 55 | python train.py \ 56 | --net_type pyramidnet \ 57 | --dataset cifar10 \ 58 | --depth 200 \ 59 | --alpha 240 \ 60 | --batch_size 64 \ 61 | --lr 0.25 \ 62 | --expname PyraNet200 \ 63 | --epochs 300 \ 64 | --prob 0.5 \ 65 | --r 0.7 \ 66 | --delta 1.0 \ 67 | --method StyleCutMix_Auto_Gamma \ 68 | --save_dir /set/your/save/dir \ 69 | --data_dir /set/your/data/dir \ 70 | --no-verbose 71 | ``` 72 | 3. StyleCutMix + CIFAR-100 73 | ``` 74 | python train.py \ 75 | --net_type pyramidnet \ 76 | --dataset cifar100 \ 77 | --depth 200 \ 78 | --alpha 240 \ 79 | --batch_size 64 \ 80 | --lr 0.25 \ 81 | --expname PyraNet200 \ 82 | --epochs 300 \ 83 | --prob 0.5 \ 84 | --r 0.7 \ 85 | --alpha2 0.8 \ 86 | --method StyleCutMix \ 87 | --save_dir /set/your/save/dir \ 88 | --data_dir /set/your/data/dir \ 89 | --no-verbose 90 | ``` 91 | 4. 
StyleCutMix + CIFAR-10 92 | ``` 93 | python train.py \ 94 | --net_type pyramidnet \ 95 | --dataset cifar10 \ 96 | --depth 200 \ 97 | --alpha 240 \ 98 | --batch_size 64 \ 99 | --lr 0.25 \ 100 | --expname PyraNet200 \ 101 | --epochs 300 \ 102 | --prob 0.5 \ 103 | --r 0.7 \ 104 | --alpha2 0.8 \ 105 | --method StyleCutMix \ 106 | --save_dir /set/your/save/dir \ 107 | --data_dir /set/your/data/dir \ 108 | --no-verbose 109 | ``` 110 | 5. StyleMix + CIFAR-100 111 | ``` 112 | python train.py \ 113 | --net_type pyramidnet \ 114 | --dataset cifar100 \ 115 | --depth 200 \ 116 | --alpha 240 \ 117 | --batch_size 64 \ 118 | --lr 0.25 \ 119 | --expname PyraNet200 \ 120 | --epochs 300 \ 121 | --alpha1 0.5 \ 122 | --prob 0.2 \ 123 | --r 0.7 \ 124 | --method StyleMix \ 125 | --save_dir /set/your/save/dir \ 126 | --data_dir /set/your/data/dir \ 127 | --no-verbose 128 | ``` 129 | 6. StyleMix + CIFAR-10 130 | ``` 131 | python train.py \ 132 | --net_type pyramidnet \ 133 | --dataset cifar10 \ 134 | --depth 200 \ 135 | --alpha 240 \ 136 | --batch_size 64 \ 137 | --lr 0.25 \ 138 | --expname PyraNet200 \ 139 | --epochs 300 \ 140 | --alpha1 0.5 \ 141 | --prob 0.2 \ 142 | --r 0.7 \ 143 | --method StyleMix \ 144 | --save_dir /set/your/save/dir \ 145 | --data_dir /set/your/data/dir \ 146 | --no-verbose 147 | ``` 148 | ## Test classification performance 149 | 150 | Modify the contents of test.sh according to each situation and run it. 151 | 152 | 1. CIFAR-10 153 | ``` 154 | test.sh : 155 | python test.py \ 156 | --net_type pyramidnet \ 157 | --dataset cifar10 \ 158 | --batch_size 128 \ 159 | --depth 200 \ 160 | --alpha 240 \ 161 | --fgsm False \ 162 | --data_dir /set/your/data/dir \ 163 | --pretrained /set/pretrained/model/dir 164 | ``` 165 | 2. CIFAR-100 166 | ``` 167 | test.sh : 168 | python test.py \ 169 | --net_type pyramidnet \ 170 | --dataset cifar100 \ 171 | --batch_size 128 \ 172 | --depth 200 \ 173 | --alpha 240 \ 174 | --fgsm False \ 175 | --data_dir /set/your/data/dir \ 176 | --pretrained /set/pretrained/model/dir 177 | ``` 178 | ## Test FGSM Attack 179 | 180 | Modify the contents of test.sh according to each situation and run it. 181 | 182 | 1. FGSM Attack on CIFAR-10 183 | ``` 184 | test.sh : 185 | python test.py \ 186 | --net_type pyramidnet \ 187 | --dataset cifar10 \ 188 | --batch_size 128 \ 189 | --depth 200 \ 190 | --alpha 240 \ 191 | --fgsm True \ 192 | --eps 1 \ 193 | --data_dir /set/your/data/dir \ 194 | --pretrained /set/pretrained/model/dir 195 | ``` 196 | (You can change eps to 1, 2, 4) 197 | 198 | 2. FGSM Attack on CIFAR-100 199 | ``` 200 | test.sh : 201 | python test.py \ 202 | --net_type pyramidnet \ 203 | --dataset cifar100 \ 204 | --batch_size 128 \ 205 | --depth 200 \ 206 | --alpha 240 \ 207 | --fgsm True \ 208 | --eps 1 \ 209 | --data_dir /set/your/data/dir \ 210 | --pretrained /set/pretrained/model/dir 211 | ``` 212 | (You can change eps to 1, 2, 4) 213 | 214 | ## Acknowledgments 215 | 216 | This code is based on the implementations for [CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features](https://github.com/clovaai/CutMix-PyTorch), [Puzzle Mix: Exploiting Saliency and Local Statistics for Optimal Mixup](https://github.com/snu-mllab/PuzzleMix), and [Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization](https://github.com/naoto0804/pytorch-AdaIN). 
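## Quick style-mix example (unofficial)

The snippet below is a minimal sketch, not part of the original training pipeline, showing how the pre-trained AdaIN encoder/decoder in `models/` can be loaded through `net_mixup` to produce a single style-mixed image. The random 224x224 tensors and the `rc`/`rs` values are placeholders for illustration; in this repository the actual inputs are CIFAR images upsampled to 224x224 (see makeStyleDistanceMatrix.py).

```python
import torch
import net_mixup

# Load the pre-trained VGG encoder and AdaIN decoder shipped in models/.
vgg = net_mixup.vgg
decoder = net_mixup.decoder
vgg.load_state_dict(torch.load('./models/vgg_normalised.pth'))
decoder.load_state_dict(torch.load('./models/decoder.pth.tar'))

encoder = net_mixup.Net_E(vgg).eval()          # frozen encoder up to relu4_1
mixer = net_mixup.Net_D(vgg, decoder).eval()   # decoder that mixes content/style

x1 = torch.rand(1, 3, 224, 224)  # placeholder; use real upsampled images
x2 = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    # rc / rs: how much of x1's content / style to keep (the rest comes from x2)
    mixed = mixer(encoder(x1), encoder(x2), rc=0.7, rs=0.3)
print(mixed.shape)  # torch.Size([1, 3, 224, 224])
```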
217 | -------------------------------------------------------------------------------- /function.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def calc_mean_std(feat, eps=1e-5): 5 | # eps is a small value added to the variance to avoid divide-by-zero. 6 | size = feat.size() 7 | assert (len(size) == 4) 8 | N, C = size[:2] 9 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 10 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 11 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 12 | return feat_mean, feat_std 13 | 14 | 15 | def adaptive_instance_normalization(content_feat, style_feat): 16 | assert (content_feat.size()[:2] == style_feat.size()[:2]) 17 | size = content_feat.size() 18 | style_mean, style_std = calc_mean_std(style_feat) 19 | content_mean, content_std = calc_mean_std(content_feat) 20 | 21 | normalized_feat = (content_feat - content_mean.expand( 22 | size)) / content_std.expand(size) 23 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 24 | 25 | 26 | def _calc_feat_flatten_mean_std(feat): 27 | # takes 3D feat (C, H, W), return mean and std of array within channels 28 | assert (feat.size()[0] == 3) 29 | assert (isinstance(feat, torch.FloatTensor)) 30 | feat_flatten = feat.view(3, -1) 31 | mean = feat_flatten.mean(dim=-1, keepdim=True) 32 | std = feat_flatten.std(dim=-1, keepdim=True) 33 | return feat_flatten, mean, std 34 | 35 | 36 | def _mat_sqrt(x): 37 | U, D, V = torch.svd(x) 38 | return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t()) 39 | 40 | 41 | def coral(source, target): 42 | # assume both source and target are 3D array (C, H, W) 43 | # Note: flatten -> f 44 | 45 | source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source) 46 | source_f_norm = (source_f - source_f_mean.expand_as( 47 | source_f)) / source_f_std.expand_as(source_f) 48 | source_f_cov_eye = \ 49 | torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3) 50 | 51 | target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target) 52 | target_f_norm = (target_f - target_f_mean.expand_as( 53 | target_f)) / target_f_std.expand_as(target_f) 54 | target_f_cov_eye = \ 55 | torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3) 56 | 57 | source_f_norm_transfer = torch.mm( 58 | _mat_sqrt(target_f_cov_eye), 59 | torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)), 60 | source_f_norm) 61 | ) 62 | 63 | source_f_transfer = source_f_norm_transfer * \ 64 | target_f_std.expand_as(source_f_norm) + \ 65 | target_f_mean.expand_as(source_f_norm) 66 | 67 | return source_f_transfer.view(source.size()) 68 | -------------------------------------------------------------------------------- /makeStyleDistanceMatrix.py: -------------------------------------------------------------------------------- 1 | # original code: https://github.com/dyhan0920/PyramidNet-PyTorch/blob/master/train.py 2 | 3 | import argparse 4 | import os 5 | import shutil 6 | import time 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.optim 13 | import torch.utils.data 14 | import torch.utils.data.distributed 15 | import torchvision.transforms as transforms 16 | import torchvision.datasets as datasets 17 | import torchvision.models as models 18 | import pyramidnet as PYRM 19 | import utils 20 | import numpy as np 21 | import torchvision.utils 22 | from torchvision.utils import save_image 23 | import warnings 24 | from matplotlib import pyplot as plt 
25 | import matplotlib.gridspec as gridspec 26 | 27 | import net_styleDistance 28 | from function import adaptive_instance_normalization, coral 29 | import torch.nn.functional as F 30 | from IPython import embed 31 | # Because part of the training data is truncated image 32 | from PIL import ImageFile 33 | ImageFile.LOAD_TRUNCATED_IMAGES = True 34 | 35 | warnings.filterwarnings("ignore") 36 | # Check 37 | model_names = sorted(name for name in models.__dict__ 38 | if name.islower() and not name.startswith("__") 39 | and callable(models.__dict__[name])) 40 | 41 | parser = argparse.ArgumentParser(description='Cutmix PyTorch CIFAR-10, CIFAR-100 and ImageNet-1k Training') 42 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 43 | help='number of data loading workers (default: 4)') 44 | parser.add_argument('-b', '--batch_size', default=128, type=int, 45 | metavar='N', help='mini-batch size (default: 256)') 46 | parser.add_argument('--dataset', dest='dataset', default='imagenet', type=str, 47 | help='dataset (options: cifar10, cifar100, and imagenet)') 48 | parser.add_argument('--vgg', type=str, default='./models/vgg_normalised.pth') 49 | parser.add_argument('--decoder', type=str, default='./models/decoder.pth.tar') 50 | parser.add_argument('--data_dir', type=str, default='/set/your/data/dir') 51 | 52 | 53 | def main(): 54 | global args, res, styleDistanceMatrix, numberofclass 55 | args = parser.parse_args() 56 | if args.dataset == 'cifar100': 57 | res = torch.zeros((3, 100, 1920)).cuda() 58 | styleDistanceMatrix = torch.zeros((100, 100)).cuda() 59 | elif args.dataset == 'cifar10': 60 | res = torch.zeros((3, 10, 1920)).cuda() 61 | styleDistanceMatrix = torch.zeros((10, 10)).cuda() 62 | else: 63 | raise Exception('unknown dataset: {}'.format(args.dataset)) 64 | 65 | global decoder, vgg, pretrained 66 | decoder = net_styleDistance.decoder 67 | vgg = net_styleDistance.vgg 68 | decoder.eval() 69 | vgg.eval() 70 | decoder.load_state_dict(torch.load(args.decoder)) 71 | vgg.load_state_dict(torch.load(args.vgg)) 72 | vgg = nn.Sequential(*list(vgg.children())[:31]) 73 | vgg.cuda() 74 | decoder.cuda() 75 | 76 | global network 77 | network = net_styleDistance.Net(vgg, decoder) 78 | network.eval() 79 | network = torch.nn.DataParallel(network).cuda() 80 | 81 | if args.dataset.startswith('cifar'): 82 | normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]], 83 | std=[x / 255.0 for x in [63.0, 62.1, 66.7]]) 84 | transform_train = transforms.Compose([ 85 | transforms.RandomCrop(32, padding=4), 86 | transforms.RandomHorizontalFlip(), 87 | transforms.ToTensor(), 88 | normalize, 89 | ]) 90 | 91 | if args.dataset == 'cifar100': 92 | train_loader = torch.utils.data.DataLoader( 93 | datasets.CIFAR100(args.data_dir+'/dataCifar100/', train=True, download=True, transform=transform_train), 94 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) 95 | numberofclass = 100 96 | elif args.dataset == 'cifar10': 97 | train_loader = torch.utils.data.DataLoader( 98 | datasets.CIFAR10(args.data_dir+'/dataCifar10/', train=True, download=True, transform=transform_train), 99 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) 100 | numberofclass = 10 101 | else: 102 | raise Exception('unknown dataset: {}'.format(args.dataset)) 103 | else: 104 | raise Exception('unknown dataset: {}'.format(args.dataset)) 105 | 106 | for i, (input, target) in enumerate(train_loader): 107 | # measure data loading time 108 | with 
torch.no_grad(): 109 | input = input.cuda() 110 | target = target.cuda() 111 | u = nn.Upsample(size=(224, 224), mode='bilinear') 112 | content = u(input) 113 | mean1, std1, mean2, std2, mean3, std3, mean4, std4 = network(content) 114 | sv = torch.cat((mean1, std1, mean2, std2, mean3, std3, mean4, std4), 1) 115 | sv = sv.view(content.shape[0], -1) 116 | res[0, target] += sv 117 | res[1, target] += 1 118 | res[2, target] = res[0, target] / res[1, target] 119 | print("Total : ",res[0]) 120 | print("Count : ",res[1]) 121 | print("Avg : ",res[2]) 122 | mse_loss = nn.MSELoss() 123 | for i in range(numberofclass): 124 | for j in range(numberofclass): 125 | styleDistanceMatrix[i, j] = mse_loss(res[2, i], res[2, j]) 126 | torch.save(styleDistanceMatrix, './styleDistanceMatrix'+str(numberofclass)+'.pt') 127 | np.savetxt('./styleDistanceMatrix'+str(numberofclass), styleDistanceMatrix.cpu().numpy(), fmt='%.10e', delimiter=',') 128 | 129 | if __name__ == '__main__': 130 | main() 131 | -------------------------------------------------------------------------------- /makeStyleDistanceMatrix.sh: -------------------------------------------------------------------------------- 1 | python makeStyleDistanceMatrix.py \ 2 | --dataset cifar10 \ 3 | --batch_size 256 \ 4 | --data_dir /set/your/data/dir 5 | -------------------------------------------------------------------------------- /models/decoder.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alsdml/StyleMix/e947c7d5b55ac55ddb7174a90daf64fecc452ab7/models/decoder.pth.tar -------------------------------------------------------------------------------- /models/vgg_normalised.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alsdml/StyleMix/e947c7d5b55ac55ddb7174a90daf64fecc452ab7/models/vgg_normalised.pth -------------------------------------------------------------------------------- /net_cutmix.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from function import adaptive_instance_normalization as adain 5 | from IPython import embed 6 | import numpy as np 7 | decoder = nn.Sequential( 8 | nn.ReflectionPad2d((1, 1, 1, 1)), 9 | nn.Conv2d(512, 256, (3, 3)), 10 | nn.ReLU(), 11 | nn.Upsample(scale_factor=2, mode='nearest'), 12 | nn.ReflectionPad2d((1, 1, 1, 1)), 13 | nn.Conv2d(256, 256, (3, 3)), 14 | nn.ReLU(), 15 | nn.ReflectionPad2d((1, 1, 1, 1)), 16 | nn.Conv2d(256, 256, (3, 3)), 17 | nn.ReLU(), 18 | nn.ReflectionPad2d((1, 1, 1, 1)), 19 | nn.Conv2d(256, 256, (3, 3)), 20 | nn.ReLU(), 21 | nn.ReflectionPad2d((1, 1, 1, 1)), 22 | nn.Conv2d(256, 128, (3, 3)), 23 | nn.ReLU(), 24 | nn.Upsample(scale_factor=2, mode='nearest'), 25 | nn.ReflectionPad2d((1, 1, 1, 1)), 26 | nn.Conv2d(128, 128, (3, 3)), 27 | nn.ReLU(), 28 | nn.ReflectionPad2d((1, 1, 1, 1)), 29 | nn.Conv2d(128, 64, (3, 3)), 30 | nn.ReLU(), 31 | nn.Upsample(scale_factor=2, mode='nearest'), 32 | nn.ReflectionPad2d((1, 1, 1, 1)), 33 | nn.Conv2d(64, 64, (3, 3)), 34 | nn.ReLU(), 35 | nn.ReflectionPad2d((1, 1, 1, 1)), 36 | nn.Conv2d(64, 3, (3, 3)), 37 | ) 38 | 39 | vgg = nn.Sequential( 40 | nn.Conv2d(3, 3, (1, 1)), 41 | nn.ReflectionPad2d((1, 1, 1, 1)), 42 | nn.Conv2d(3, 64, (3, 3)), 43 | nn.ReLU(), # relu1-1 44 | nn.ReflectionPad2d((1, 1, 1, 1)), 45 | nn.Conv2d(64, 64, (3, 3)), 46 | nn.ReLU(), # relu1-2 47 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 48 | 
nn.ReflectionPad2d((1, 1, 1, 1)), 49 | nn.Conv2d(64, 128, (3, 3)), 50 | nn.ReLU(), # relu2-1 51 | nn.ReflectionPad2d((1, 1, 1, 1)), 52 | nn.Conv2d(128, 128, (3, 3)), 53 | nn.ReLU(), # relu2-2 54 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 55 | nn.ReflectionPad2d((1, 1, 1, 1)), 56 | nn.Conv2d(128, 256, (3, 3)), 57 | nn.ReLU(), # relu3-1 58 | nn.ReflectionPad2d((1, 1, 1, 1)), 59 | nn.Conv2d(256, 256, (3, 3)), 60 | nn.ReLU(), # relu3-2 61 | nn.ReflectionPad2d((1, 1, 1, 1)), 62 | nn.Conv2d(256, 256, (3, 3)), 63 | nn.ReLU(), # relu3-3 64 | nn.ReflectionPad2d((1, 1, 1, 1)), 65 | nn.Conv2d(256, 256, (3, 3)), 66 | nn.ReLU(), # relu3-4 67 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 68 | nn.ReflectionPad2d((1, 1, 1, 1)), 69 | nn.Conv2d(256, 512, (3, 3)), 70 | nn.ReLU(), # relu4-1, this is the last layer used 71 | nn.ReflectionPad2d((1, 1, 1, 1)), 72 | nn.Conv2d(512, 512, (3, 3)), 73 | nn.ReLU(), # relu4-2 74 | nn.ReflectionPad2d((1, 1, 1, 1)), 75 | nn.Conv2d(512, 512, (3, 3)), 76 | nn.ReLU(), # relu4-3 77 | nn.ReflectionPad2d((1, 1, 1, 1)), 78 | nn.Conv2d(512, 512, (3, 3)), 79 | nn.ReLU(), # relu4-4 80 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 81 | nn.ReflectionPad2d((1, 1, 1, 1)), 82 | nn.Conv2d(512, 512, (3, 3)), 83 | nn.ReLU(), # relu5-1 84 | nn.ReflectionPad2d((1, 1, 1, 1)), 85 | nn.Conv2d(512, 512, (3, 3)), 86 | nn.ReLU(), # relu5-2 87 | nn.ReflectionPad2d((1, 1, 1, 1)), 88 | nn.Conv2d(512, 512, (3, 3)), 89 | nn.ReLU(), # relu5-3 90 | nn.ReflectionPad2d((1, 1, 1, 1)), 91 | nn.Conv2d(512, 512, (3, 3)), 92 | nn.ReLU() # relu5-4 93 | ) 94 | 95 | 96 | class Net_E(nn.Module): 97 | def __init__(self, encoder): 98 | super(Net_E, self).__init__() 99 | enc_layers = list(encoder.children()) 100 | self.enc_1 = nn.Sequential(*enc_layers[:4]) # input -> relu1_1 101 | self.enc_2 = nn.Sequential(*enc_layers[4:11]) # relu1_1 -> relu2_1 102 | self.enc_3 = nn.Sequential(*enc_layers[11:18]) # relu2_1 -> relu3_1 103 | self.enc_4 = nn.Sequential(*enc_layers[18:31]) # relu3_1 -> relu4_1 104 | 105 | # fix the encoder 106 | for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']: 107 | for param in getattr(self, name).parameters(): 108 | param.requires_grad = False 109 | 110 | # extract relu1_1, relu2_1, relu3_1, relu4_1 from input image 111 | def encode_with_intermediate(self, input): 112 | results = [input] 113 | for i in range(4): 114 | func = getattr(self, 'enc_{:d}'.format(i + 1)) 115 | results.append(func(results[-1])) 116 | return results[1:] 117 | 118 | # extract relu4_1 from input image 119 | def encode(self, input): 120 | for i in range(4): 121 | input = getattr(self, 'enc_{:d}'.format(i + 1))(input) 122 | return input 123 | 124 | def forward(self, x): 125 | x_feat = self.encode(x) 126 | return x_feat 127 | 128 | 129 | class Net_D(nn.Module): 130 | def __init__(self, encoder, decoder): 131 | super(Net_D, self).__init__() 132 | enc_layers = list(encoder.children()) 133 | self.enc_1 = nn.Sequential(*enc_layers[:4]) # input -> relu1_1 134 | self.enc_2 = nn.Sequential(*enc_layers[4:11]) # relu1_1 -> relu2_1 135 | self.enc_3 = nn.Sequential(*enc_layers[11:18]) # relu2_1 -> relu3_1 136 | self.enc_4 = nn.Sequential(*enc_layers[18:31]) # relu3_1 -> relu4_1 137 | self.decoder = decoder 138 | 139 | # fix the encoder 140 | for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']: 141 | for param in getattr(self, name).parameters(): 142 | param.requires_grad = False 143 | 144 | # extract relu1_1, relu2_1, relu3_1, relu4_1 from input image 145 | def encode_with_intermediate(self, input): 146 | 
results = [input] 147 | for i in range(4): 148 | func = getattr(self, 'enc_{:d}'.format(i + 1)) 149 | results.append(func(results[-1])) 150 | return results[1:] 151 | 152 | # extract relu4_1 from input image 153 | def encode(self, input): 154 | for i in range(4): 155 | input = getattr(self, 'enc_{:d}'.format(i + 1))(input) 156 | return input 157 | 158 | def forward(self, x1, x2, x1_feat, x2_feat, rs, gamma, bbx1, bby1, bbx2, bby2): 159 | if torch.is_tensor(gamma): 160 | gamma = gamma.view(-1,1,1,1) 161 | g11 = x1 162 | g22 = x2 163 | g12 = gamma*x1 + (1.-gamma)*self.decoder(adain(x1_feat, x2_feat)) 164 | g21 = gamma*x2 + (1.-gamma)*self.decoder(adain(x2_feat, x1_feat)) 165 | 166 | Rc = torch.zeros(1, 1, 224, 224).cuda() 167 | Rc[:,:,bbx1:bbx2,bby1:bby2] = 1 168 | Rs = rs * torch.ones(1, 1, 224, 224).cuda() 169 | zero = torch.zeros(1, 1, 224 ,224).cuda() 170 | T = torch.max(zero,Rc+Rs-1) 171 | 172 | output = T*g11 + (1.0-Rc-Rs+T)*g22 + (Rc-T)*g12 + (Rs-T)*g21 173 | return output 174 | 175 | 176 | -------------------------------------------------------------------------------- /net_mixup.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from function import adaptive_instance_normalization as adain 5 | from function import calc_mean_std 6 | from IPython import embed 7 | import numpy as np 8 | decoder = nn.Sequential( 9 | nn.ReflectionPad2d((1, 1, 1, 1)), 10 | nn.Conv2d(512, 256, (3, 3)), 11 | nn.ReLU(), 12 | nn.Upsample(scale_factor=2, mode='nearest'), 13 | nn.ReflectionPad2d((1, 1, 1, 1)), 14 | nn.Conv2d(256, 256, (3, 3)), 15 | nn.ReLU(), 16 | nn.ReflectionPad2d((1, 1, 1, 1)), 17 | nn.Conv2d(256, 256, (3, 3)), 18 | nn.ReLU(), 19 | nn.ReflectionPad2d((1, 1, 1, 1)), 20 | nn.Conv2d(256, 256, (3, 3)), 21 | nn.ReLU(), 22 | nn.ReflectionPad2d((1, 1, 1, 1)), 23 | nn.Conv2d(256, 128, (3, 3)), 24 | nn.ReLU(), 25 | nn.Upsample(scale_factor=2, mode='nearest'), 26 | nn.ReflectionPad2d((1, 1, 1, 1)), 27 | nn.Conv2d(128, 128, (3, 3)), 28 | nn.ReLU(), 29 | nn.ReflectionPad2d((1, 1, 1, 1)), 30 | nn.Conv2d(128, 64, (3, 3)), 31 | nn.ReLU(), 32 | nn.Upsample(scale_factor=2, mode='nearest'), 33 | nn.ReflectionPad2d((1, 1, 1, 1)), 34 | nn.Conv2d(64, 64, (3, 3)), 35 | nn.ReLU(), 36 | nn.ReflectionPad2d((1, 1, 1, 1)), 37 | nn.Conv2d(64, 3, (3, 3)), 38 | ) 39 | 40 | vgg = nn.Sequential( 41 | nn.Conv2d(3, 3, (1, 1)), 42 | nn.ReflectionPad2d((1, 1, 1, 1)), 43 | nn.Conv2d(3, 64, (3, 3)), 44 | nn.ReLU(), # relu1-1 45 | nn.ReflectionPad2d((1, 1, 1, 1)), 46 | nn.Conv2d(64, 64, (3, 3)), 47 | nn.ReLU(), # relu1-2 48 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 49 | nn.ReflectionPad2d((1, 1, 1, 1)), 50 | nn.Conv2d(64, 128, (3, 3)), 51 | nn.ReLU(), # relu2-1 52 | nn.ReflectionPad2d((1, 1, 1, 1)), 53 | nn.Conv2d(128, 128, (3, 3)), 54 | nn.ReLU(), # relu2-2 55 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 56 | nn.ReflectionPad2d((1, 1, 1, 1)), 57 | nn.Conv2d(128, 256, (3, 3)), 58 | nn.ReLU(), # relu3-1 59 | nn.ReflectionPad2d((1, 1, 1, 1)), 60 | nn.Conv2d(256, 256, (3, 3)), 61 | nn.ReLU(), # relu3-2 62 | nn.ReflectionPad2d((1, 1, 1, 1)), 63 | nn.Conv2d(256, 256, (3, 3)), 64 | nn.ReLU(), # relu3-3 65 | nn.ReflectionPad2d((1, 1, 1, 1)), 66 | nn.Conv2d(256, 256, (3, 3)), 67 | nn.ReLU(), # relu3-4 68 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 69 | nn.ReflectionPad2d((1, 1, 1, 1)), 70 | nn.Conv2d(256, 512, (3, 3)), 71 | nn.ReLU(), # relu4-1, this is the last layer used 72 | nn.ReflectionPad2d((1, 1, 1, 1)), 73 | 
nn.Conv2d(512, 512, (3, 3)), 74 | nn.ReLU(), # relu4-2 75 | nn.ReflectionPad2d((1, 1, 1, 1)), 76 | nn.Conv2d(512, 512, (3, 3)), 77 | nn.ReLU(), # relu4-3 78 | nn.ReflectionPad2d((1, 1, 1, 1)), 79 | nn.Conv2d(512, 512, (3, 3)), 80 | nn.ReLU(), # relu4-4 81 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 82 | nn.ReflectionPad2d((1, 1, 1, 1)), 83 | nn.Conv2d(512, 512, (3, 3)), 84 | nn.ReLU(), # relu5-1 85 | nn.ReflectionPad2d((1, 1, 1, 1)), 86 | nn.Conv2d(512, 512, (3, 3)), 87 | nn.ReLU(), # relu5-2 88 | nn.ReflectionPad2d((1, 1, 1, 1)), 89 | nn.Conv2d(512, 512, (3, 3)), 90 | nn.ReLU(), # relu5-3 91 | nn.ReflectionPad2d((1, 1, 1, 1)), 92 | nn.Conv2d(512, 512, (3, 3)), 93 | nn.ReLU() # relu5-4 94 | ) 95 | 96 | 97 | class Net_E(nn.Module): 98 | def __init__(self, encoder): 99 | super(Net_E, self).__init__() 100 | enc_layers = list(encoder.children()) 101 | self.enc_1 = nn.Sequential(*enc_layers[:4]) # input -> relu1_1 102 | self.enc_2 = nn.Sequential(*enc_layers[4:11]) # relu1_1 -> relu2_1 103 | self.enc_3 = nn.Sequential(*enc_layers[11:18]) # relu2_1 -> relu3_1 104 | self.enc_4 = nn.Sequential(*enc_layers[18:31]) # relu3_1 -> relu4_1 105 | 106 | # fix the encoder 107 | for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']: 108 | for param in getattr(self, name).parameters(): 109 | param.requires_grad = False 110 | 111 | # extract relu1_1, relu2_1, relu3_1, relu4_1 from input image 112 | def encode_with_intermediate(self, input): 113 | results = [input] 114 | for i in range(4): 115 | func = getattr(self, 'enc_{:d}'.format(i + 1)) 116 | results.append(func(results[-1])) 117 | return results[1:] 118 | 119 | # extract relu4_1 from input image 120 | def encode(self, input): 121 | for i in range(4): 122 | input = getattr(self, 'enc_{:d}'.format(i + 1))(input) 123 | return input 124 | 125 | def forward(self, x): 126 | x_feat = self.encode(x) 127 | return x_feat 128 | 129 | 130 | class Net_D(nn.Module): 131 | def __init__(self, encoder, decoder): 132 | super(Net_D, self).__init__() 133 | enc_layers = list(encoder.children()) 134 | self.enc_1 = nn.Sequential(*enc_layers[:4]) # input -> relu1_1 135 | self.enc_2 = nn.Sequential(*enc_layers[4:11]) # relu1_1 -> relu2_1 136 | self.enc_3 = nn.Sequential(*enc_layers[11:18]) # relu2_1 -> relu3_1 137 | self.enc_4 = nn.Sequential(*enc_layers[18:31]) # relu3_1 -> relu4_1 138 | self.decoder = decoder 139 | 140 | # fix the encoder 141 | for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']: 142 | for param in getattr(self, name).parameters(): 143 | param.requires_grad = False 144 | 145 | # extract relu1_1, relu2_1, relu3_1, relu4_1 from input image 146 | def encode_with_intermediate(self, input): 147 | results = [input] 148 | for i in range(4): 149 | func = getattr(self, 'enc_{:d}'.format(i + 1)) 150 | results.append(func(results[-1])) 151 | return results[1:] 152 | 153 | # extract relu4_1 from input image 154 | def encode(self, input): 155 | for i in range(4): 156 | input = getattr(self, 'enc_{:d}'.format(i + 1))(input) 157 | return input 158 | 159 | def forward(self, f11, f22, rc, rs): 160 | t = np.random.uniform(max(0, rc+rs-1), min(rc, rs), 1)[0] 161 | return self.decoder(t * f11 + (1.0-rc-rs+t) * f22 + (rc-t) * adain(f11, f22) + (rs-t) * adain(f22, f11)) 162 | -------------------------------------------------------------------------------- /net_styleDistance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from function import adaptive_instance_normalization as adain 5 | from 
function import calc_mean_std 6 | from IPython import embed 7 | import numpy as np 8 | decoder = nn.Sequential( 9 | nn.ReflectionPad2d((1, 1, 1, 1)), 10 | nn.Conv2d(512, 256, (3, 3)), 11 | nn.ReLU(), 12 | nn.Upsample(scale_factor=2, mode='nearest'), 13 | nn.ReflectionPad2d((1, 1, 1, 1)), 14 | nn.Conv2d(256, 256, (3, 3)), 15 | nn.ReLU(), 16 | nn.ReflectionPad2d((1, 1, 1, 1)), 17 | nn.Conv2d(256, 256, (3, 3)), 18 | nn.ReLU(), 19 | nn.ReflectionPad2d((1, 1, 1, 1)), 20 | nn.Conv2d(256, 256, (3, 3)), 21 | nn.ReLU(), 22 | nn.ReflectionPad2d((1, 1, 1, 1)), 23 | nn.Conv2d(256, 128, (3, 3)), 24 | nn.ReLU(), 25 | nn.Upsample(scale_factor=2, mode='nearest'), 26 | nn.ReflectionPad2d((1, 1, 1, 1)), 27 | nn.Conv2d(128, 128, (3, 3)), 28 | nn.ReLU(), 29 | nn.ReflectionPad2d((1, 1, 1, 1)), 30 | nn.Conv2d(128, 64, (3, 3)), 31 | nn.ReLU(), 32 | nn.Upsample(scale_factor=2, mode='nearest'), 33 | nn.ReflectionPad2d((1, 1, 1, 1)), 34 | nn.Conv2d(64, 64, (3, 3)), 35 | nn.ReLU(), 36 | nn.ReflectionPad2d((1, 1, 1, 1)), 37 | nn.Conv2d(64, 3, (3, 3)), 38 | ) 39 | 40 | vgg = nn.Sequential( 41 | nn.Conv2d(3, 3, (1, 1)), 42 | nn.ReflectionPad2d((1, 1, 1, 1)), 43 | nn.Conv2d(3, 64, (3, 3)), 44 | nn.ReLU(), # relu1-1 45 | nn.ReflectionPad2d((1, 1, 1, 1)), 46 | nn.Conv2d(64, 64, (3, 3)), 47 | nn.ReLU(), # relu1-2 48 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 49 | nn.ReflectionPad2d((1, 1, 1, 1)), 50 | nn.Conv2d(64, 128, (3, 3)), 51 | nn.ReLU(), # relu2-1 52 | nn.ReflectionPad2d((1, 1, 1, 1)), 53 | nn.Conv2d(128, 128, (3, 3)), 54 | nn.ReLU(), # relu2-2 55 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 56 | nn.ReflectionPad2d((1, 1, 1, 1)), 57 | nn.Conv2d(128, 256, (3, 3)), 58 | nn.ReLU(), # relu3-1 59 | nn.ReflectionPad2d((1, 1, 1, 1)), 60 | nn.Conv2d(256, 256, (3, 3)), 61 | nn.ReLU(), # relu3-2 62 | nn.ReflectionPad2d((1, 1, 1, 1)), 63 | nn.Conv2d(256, 256, (3, 3)), 64 | nn.ReLU(), # relu3-3 65 | nn.ReflectionPad2d((1, 1, 1, 1)), 66 | nn.Conv2d(256, 256, (3, 3)), 67 | nn.ReLU(), # relu3-4 68 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 69 | nn.ReflectionPad2d((1, 1, 1, 1)), 70 | nn.Conv2d(256, 512, (3, 3)), 71 | nn.ReLU(), # relu4-1, this is the last layer used 72 | nn.ReflectionPad2d((1, 1, 1, 1)), 73 | nn.Conv2d(512, 512, (3, 3)), 74 | nn.ReLU(), # relu4-2 75 | nn.ReflectionPad2d((1, 1, 1, 1)), 76 | nn.Conv2d(512, 512, (3, 3)), 77 | nn.ReLU(), # relu4-3 78 | nn.ReflectionPad2d((1, 1, 1, 1)), 79 | nn.Conv2d(512, 512, (3, 3)), 80 | nn.ReLU(), # relu4-4 81 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 82 | nn.ReflectionPad2d((1, 1, 1, 1)), 83 | nn.Conv2d(512, 512, (3, 3)), 84 | nn.ReLU(), # relu5-1 85 | nn.ReflectionPad2d((1, 1, 1, 1)), 86 | nn.Conv2d(512, 512, (3, 3)), 87 | nn.ReLU(), # relu5-2 88 | nn.ReflectionPad2d((1, 1, 1, 1)), 89 | nn.Conv2d(512, 512, (3, 3)), 90 | nn.ReLU(), # relu5-3 91 | nn.ReflectionPad2d((1, 1, 1, 1)), 92 | nn.Conv2d(512, 512, (3, 3)), 93 | nn.ReLU() # relu5-4 94 | ) 95 | 96 | class Net(nn.Module): 97 | def __init__(self, encoder, decoder): 98 | super(Net, self).__init__() 99 | enc_layers = list(encoder.children()) 100 | self.enc_1 = nn.Sequential(*enc_layers[:4]) # input -> relu1_1 101 | self.enc_2 = nn.Sequential(*enc_layers[4:11]) # relu1_1 -> relu2_1 102 | self.enc_3 = nn.Sequential(*enc_layers[11:18]) # relu2_1 -> relu3_1 103 | self.enc_4 = nn.Sequential(*enc_layers[18:31]) # relu3_1 -> relu4_1 104 | self.decoder = decoder 105 | self.mse_loss = nn.MSELoss() 106 | self.mse_loss_none = nn.MSELoss(reduction='none') 107 | 108 | # fix the encoder 109 | 
for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']: 110 | for param in getattr(self, name).parameters(): 111 | param.requires_grad = False 112 | 113 | # extract relu1_1, relu2_1, relu3_1, relu4_1 from input image 114 | def encode_with_intermediate(self, input): 115 | results = [input] 116 | for i in range(4): 117 | func = getattr(self, 'enc_{:d}'.format(i + 1)) 118 | results.append(func(results[-1])) 119 | return results[1:] 120 | 121 | # extract relu4_1 from input image 122 | def encode(self, input): 123 | for i in range(4): 124 | input = getattr(self, 'enc_{:d}'.format(i + 1))(input) 125 | return input 126 | 127 | def per_calc_content_loss(self, input, target): 128 | assert (input.size() == target.size()) 129 | assert (target.requires_grad is False) 130 | t = self.mse_loss_none(input, target) 131 | return torch.mean(t,dim=[1,2,3]) 132 | 133 | def per_calc_style_loss(self, input, target): 134 | assert (input.size() == target.size()) 135 | assert (target.requires_grad is False) 136 | input_mean, input_std = calc_mean_std(input) 137 | target_mean, target_std = calc_mean_std(target) 138 | t1 = self.mse_loss_none(input_mean, target_mean) 139 | t2 = self.mse_loss_none(input_std, target_std) 140 | return torch.mean(t1, dim=[1,2,3]) + torch.mean(t2, dim=[1,2,3]) 141 | 142 | def calc_content_loss(self, input, target): 143 | assert (input.size() == target.size()) 144 | assert (target.requires_grad is False) 145 | return self.mse_loss(input, target) 146 | 147 | def calc_style_loss(self, input, target): 148 | assert (input.size() == target.size()) 149 | assert (target.requires_grad is False) 150 | input_mean, input_std = calc_mean_std(input) 151 | target_mean, target_std = calc_mean_std(target) 152 | return self.mse_loss(input_mean, target_mean) + \ 153 | self.mse_loss(input_std, target_std) 154 | 155 | def forward(self, style): 156 | style_f1, style_f2, style_f3, style_f4 = self.encode_with_intermediate(style) 157 | mean1, std1 = calc_mean_std(style_f1) 158 | mean2, std2 = calc_mean_std(style_f2) 159 | mean3, std3 = calc_mean_std(style_f3) 160 | mean4, std4 = calc_mean_std(style_f4) 161 | return mean1, std1, mean2, std2, mean3, std3, mean4, std4 162 | 163 | -------------------------------------------------------------------------------- /pyramidnet.py: -------------------------------------------------------------------------------- 1 | # Original code: https://github.com/dyhan0920/PyramidNet-PyTorch/blob/master/PyramidNet.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | "3x3 convolution with padding" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | class BasicBlock(nn.Module): 14 | outchannel_ratio = 1 15 | 16 | def __init__(self, inplanes, planes, stride=1, downsample=None): 17 | super(BasicBlock, self).__init__() 18 | self.bn1 = nn.BatchNorm2d(inplanes) 19 | self.conv1 = conv3x3(inplanes, planes, stride) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.conv2 = conv3x3(planes, planes) 22 | self.bn3 = nn.BatchNorm2d(planes) 23 | self.relu = nn.ReLU(inplace=True) 24 | self.downsample = downsample 25 | self.stride = stride 26 | 27 | def forward(self, x): 28 | 29 | out = self.bn1(x) 30 | out = self.conv1(out) 31 | out = self.bn2(out) 32 | out = self.relu(out) 33 | out = self.conv2(out) 34 | out = self.bn3(out) 35 | if self.downsample is not None: 36 | shortcut = self.downsample(x) 37 | featuremap_size = shortcut.size()[2:4] 38 | else: 39 | shortcut = x 40 | 
featuremap_size = out.size()[2:4] 41 | 42 | batch_size = out.size()[0] 43 | residual_channel = out.size()[1] 44 | shortcut_channel = shortcut.size()[1] 45 | 46 | if residual_channel != shortcut_channel: 47 | padding = torch.autograd.Variable(torch.cuda.FloatTensor(batch_size, residual_channel - shortcut_channel, featuremap_size[0], featuremap_size[1]).fill_(0)) 48 | out += torch.cat((shortcut, padding), 1) 49 | else: 50 | out += shortcut 51 | 52 | return out 53 | 54 | 55 | class Bottleneck(nn.Module): 56 | outchannel_ratio = 4 57 | 58 | def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16): 59 | super(Bottleneck, self).__init__() 60 | self.bn1 = nn.BatchNorm2d(inplanes) 61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 62 | self.bn2 = nn.BatchNorm2d(planes) 63 | self.conv2 = nn.Conv2d(planes, (planes), kernel_size=3, stride=stride, padding=1, bias=False, groups=1) 64 | self.bn3 = nn.BatchNorm2d((planes)) 65 | self.conv3 = nn.Conv2d((planes), planes * Bottleneck.outchannel_ratio, kernel_size=1, bias=False) 66 | self.bn4 = nn.BatchNorm2d(planes * Bottleneck.outchannel_ratio) 67 | self.relu = nn.ReLU(inplace=True) 68 | 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | 74 | out = self.bn1(x) 75 | out = self.conv1(out) 76 | 77 | out = self.bn2(out) 78 | out = self.relu(out) 79 | out = self.conv2(out) 80 | 81 | out = self.bn3(out) 82 | out = self.relu(out) 83 | out = self.conv3(out) 84 | 85 | out = self.bn4(out) 86 | if self.downsample is not None: 87 | shortcut = self.downsample(x) 88 | featuremap_size = shortcut.size()[2:4] 89 | else: 90 | shortcut = x 91 | featuremap_size = out.size()[2:4] 92 | 93 | batch_size = out.size()[0] 94 | residual_channel = out.size()[1] 95 | shortcut_channel = shortcut.size()[1] 96 | 97 | if residual_channel != shortcut_channel: 98 | padding = torch.autograd.Variable(torch.cuda.FloatTensor(batch_size, residual_channel - shortcut_channel, featuremap_size[0], featuremap_size[1]).fill_(0)) 99 | out += torch.cat((shortcut, padding), 1) 100 | else: 101 | out += shortcut 102 | 103 | return out 104 | 105 | 106 | class PyramidNet(nn.Module): 107 | 108 | def __init__(self, dataset, depth, alpha, num_classes, bottleneck=False): 109 | super(PyramidNet, self).__init__() 110 | self.dataset = dataset 111 | if self.dataset.startswith('cifar'): 112 | self.inplanes = 16 113 | if bottleneck == True: 114 | n = int((depth - 2) / 9) 115 | block = Bottleneck 116 | else: 117 | n = int((depth - 2) / 6) 118 | block = BasicBlock 119 | 120 | self.addrate = alpha / (3*n*1.0) 121 | 122 | self.input_featuremap_dim = self.inplanes 123 | self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=3, stride=1, padding=1, bias=False) 124 | self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim) 125 | 126 | self.featuremap_dim = self.input_featuremap_dim 127 | self.layer1 = self.pyramidal_make_layer(block, n) 128 | self.layer2 = self.pyramidal_make_layer(block, n, stride=2) 129 | self.layer3 = self.pyramidal_make_layer(block, n, stride=2) 130 | 131 | self.final_featuremap_dim = self.input_featuremap_dim 132 | self.bn_final= nn.BatchNorm2d(self.final_featuremap_dim) 133 | self.relu_final = nn.ReLU(inplace=True) 134 | self.avgpool = nn.AvgPool2d(8) 135 | self.fc = nn.Linear(self.final_featuremap_dim, num_classes) 136 | 137 | elif dataset == 'imagenet': 138 | blocks ={18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck} 139 | layers ={18: [2, 2, 2, 2], 34: [3, 4, 
6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 200: [3, 24, 36, 3]} 140 | 141 | if layers.get(depth) is None: 142 | if bottleneck == True: 143 | blocks[depth] = Bottleneck 144 | temp_cfg = int((depth-2)/12) 145 | else: 146 | blocks[depth] = BasicBlock 147 | temp_cfg = int((depth-2)/8) 148 | 149 | layers[depth]= [temp_cfg, temp_cfg, temp_cfg, temp_cfg] 150 | print('=> the layer configuration for each stage is set to', layers[depth]) 151 | 152 | self.inplanes = 64 153 | self.addrate = alpha / (sum(layers[depth])*1.0) 154 | 155 | self.input_featuremap_dim = self.inplanes 156 | self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=7, stride=2, padding=3, bias=False) 157 | self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim) 158 | self.relu = nn.ReLU(inplace=True) 159 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 160 | 161 | self.featuremap_dim = self.input_featuremap_dim 162 | self.layer1 = self.pyramidal_make_layer(blocks[depth], layers[depth][0]) 163 | self.layer2 = self.pyramidal_make_layer(blocks[depth], layers[depth][1], stride=2) 164 | self.layer3 = self.pyramidal_make_layer(blocks[depth], layers[depth][2], stride=2) 165 | self.layer4 = self.pyramidal_make_layer(blocks[depth], layers[depth][3], stride=2) 166 | 167 | self.final_featuremap_dim = self.input_featuremap_dim 168 | self.bn_final= nn.BatchNorm2d(self.final_featuremap_dim) 169 | self.relu_final = nn.ReLU(inplace=True) 170 | self.avgpool = nn.AvgPool2d(7) 171 | self.fc = nn.Linear(self.final_featuremap_dim, num_classes) 172 | 173 | for m in self.modules(): 174 | if isinstance(m, nn.Conv2d): 175 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 176 | m.weight.data.normal_(0, math.sqrt(2. / n)) 177 | elif isinstance(m, nn.BatchNorm2d): 178 | m.weight.data.fill_(1) 179 | m.bias.data.zero_() 180 | 181 | def pyramidal_make_layer(self, block, block_depth, stride=1): 182 | downsample = None 183 | if stride != 1: # or self.inplanes != int(round(featuremap_dim_1st)) * block.outchannel_ratio: 184 | downsample = nn.AvgPool2d((2,2), stride = (2, 2), ceil_mode=True) 185 | 186 | layers = [] 187 | self.featuremap_dim = self.featuremap_dim + self.addrate 188 | layers.append(block(self.input_featuremap_dim, int(round(self.featuremap_dim)), stride, downsample)) 189 | for i in range(1, block_depth): 190 | temp_featuremap_dim = self.featuremap_dim + self.addrate 191 | layers.append(block(int(round(self.featuremap_dim)) * block.outchannel_ratio, int(round(temp_featuremap_dim)), 1)) 192 | self.featuremap_dim = temp_featuremap_dim 193 | self.input_featuremap_dim = int(round(self.featuremap_dim)) * block.outchannel_ratio 194 | 195 | return nn.Sequential(*layers) 196 | 197 | def forward(self, x): 198 | if self.dataset == 'cifar10' or self.dataset == 'cifar100': 199 | x = self.conv1(x) 200 | x = self.bn1(x) 201 | 202 | x = self.layer1(x) 203 | x = self.layer2(x) 204 | x = self.layer3(x) 205 | 206 | x = self.bn_final(x) 207 | x = self.relu_final(x) 208 | x = self.avgpool(x) 209 | x = x.view(x.size(0), -1) 210 | x = self.fc(x) 211 | 212 | elif self.dataset == 'imagenet': 213 | x = self.conv1(x) 214 | x = self.bn1(x) 215 | x = self.relu(x) 216 | x = self.maxpool(x) 217 | 218 | x = self.layer1(x) 219 | x = self.layer2(x) 220 | x = self.layer3(x) 221 | x = self.layer4(x) 222 | 223 | x = self.bn_final(x) 224 | x = self.relu_final(x) 225 | x = self.avgpool(x) 226 | x = x.view(x.size(0), -1) 227 | x = self.fc(x) 228 | 229 | return x 230 | 
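# Unofficial usage sketch, not part of the original pyramidnet.py: build the
# CIFAR PyramidNet exactly as train.sh / test.sh configure it (depth 200,
# alpha 240, bottleneck blocks) and run a dummy forward pass. A GPU is assumed,
# since the channel-padding inside the residual blocks above is allocated with
# torch.cuda.FloatTensor; the random input batch is purely illustrative.
if __name__ == '__main__':
    demo_model = PyramidNet(dataset='cifar100', depth=200, alpha=240,
                            num_classes=100, bottleneck=True).cuda()
    with torch.no_grad():
        demo_logits = demo_model(torch.randn(2, 3, 32, 32).cuda())
    print(demo_logits.shape)  # expected: torch.Size([2, 100])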
-------------------------------------------------------------------------------- /styleDistanceMatrix10: -------------------------------------------------------------------------------- 1 | 0.0000000000e+00,1.1152156591e+00,5.0094264746e-01,9.1899502277e-01,9.0263599157e-01,9.2004841566e-01,1.6020164490e+00,9.4485962391e-01,1.0029406846e-01,9.4869703054e-01 2 | 1.1152156591e+00,0.0000000000e+00,1.1159808636e+00,6.8440753222e-01,1.0445179939e+00,7.3641455173e-01,8.3409565687e-01,4.6992495656e-01,9.7281879187e-01,8.8051997125e-02 3 | 5.0094264746e-01,1.1159808636e+00,0.0000000000e+00,3.0495527387e-01,1.3483661413e-01,2.6760476828e-01,5.8059424162e-01,3.7537804246e-01,4.5420405269e-01,9.7760480642e-01 4 | 9.1899502277e-01,6.8440753222e-01,3.0495527387e-01,0.0000000000e+00,2.5351962447e-01,3.0504621565e-02,2.7863362432e-01,1.5136308968e-01,7.9841178656e-01,6.1002314091e-01 5 | 9.0263599157e-01,1.0445179939e+00,1.3483661413e-01,2.5351962447e-01,0.0000000000e+00,2.1668659151e-01,2.8311863542e-01,2.6937046647e-01,7.3826050758e-01,9.2354351282e-01 6 | 9.2004841566e-01,7.3641455173e-01,2.6760476828e-01,3.0504621565e-02,2.1668659151e-01,0.0000000000e+00,2.7801358700e-01,1.3462370634e-01,8.1246948242e-01,6.3699191809e-01 7 | 1.6020164490e+00,8.3409565687e-01,5.8059424162e-01,2.7863362432e-01,2.8311863542e-01,2.7801358700e-01,0.0000000000e+00,2.5108212233e-01,1.3980617523e+00,8.5034561157e-01 8 | 9.4485962391e-01,4.6992495656e-01,3.7537804246e-01,1.5136308968e-01,2.6937046647e-01,1.3462370634e-01,2.5108212233e-01,0.0000000000e+00,8.4996002913e-01,3.5369625688e-01 9 | 1.0029406846e-01,9.7281879187e-01,4.5420405269e-01,7.9841178656e-01,7.3826050758e-01,8.1246948242e-01,1.3980617523e+00,8.4996002913e-01,0.0000000000e+00,8.2651221752e-01 10 | 9.4869703054e-01,8.8051997125e-02,9.7760480642e-01,6.1002314091e-01,9.2354351282e-01,6.3699191809e-01,8.5034561157e-01,3.5369625688e-01,8.2651221752e-01,0.0000000000e+00 11 | -------------------------------------------------------------------------------- /styleDistanceMatrix10.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alsdml/StyleMix/e947c7d5b55ac55ddb7174a90daf64fecc452ab7/styleDistanceMatrix10.pt -------------------------------------------------------------------------------- /styleDistanceMatrix100.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alsdml/StyleMix/e947c7d5b55ac55ddb7174a90daf64fecc452ab7/styleDistanceMatrix100.pt -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # original code: https://github.com/dyhan0920/PyramidNet-PyTorch/blob/master/train.py 2 | 3 | import argparse 4 | import os 5 | import shutil 6 | import time 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.optim 13 | import torch.utils.data 14 | import torch.utils.data.distributed 15 | import torchvision.transforms as transforms 16 | import torchvision.datasets as datasets 17 | import torchvision.models as models 18 | import pyramidnet as PYRM 19 | from torch.autograd import Variable 20 | 21 | import warnings 22 | 23 | warnings.filterwarnings("ignore") 24 | 25 | model_names = sorted(name for name in models.__dict__ 26 | if name.islower() and not name.startswith("__") 27 | and callable(models.__dict__[name])) 28 | 29 | def str2bool(v): 
30 | if isinstance(v, bool): 31 | return v 32 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 33 | return True 34 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 35 | return False 36 | else: 37 | raise argparse.ArgumentTypeError('Boolean value expected.') 38 | 39 | parser = argparse.ArgumentParser(description='Cutmix PyTorch CIFAR-10, CIFAR-100 and ImageNet-1k Test') 40 | parser.add_argument('--net_type', default='pyramidnet', type=str, 41 | help='networktype: pyramidnet') 42 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 43 | help='number of data loading workers (default: 4)') 44 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 45 | help='number of total epochs to run') 46 | parser.add_argument('-b', '--batch_size', default=128, type=int, 47 | metavar='N', help='mini-batch size (default: 256)') 48 | parser.add_argument('--print-freq', '-p', default=1, type=int, 49 | metavar='N', help='print frequency (default: 10)') 50 | parser.add_argument('--depth', default=32, type=int, 51 | help='depth of the network (default: 32)') 52 | parser.add_argument('--no-bottleneck', dest='bottleneck', action='store_false', 53 | help='to use basicblock for CIFAR datasets (default: bottleneck)') 54 | parser.add_argument('--dataset', dest='dataset', default='imagenet', type=str, 55 | help='dataset (options: cifar10, cifar100, and imagenet)') 56 | parser.add_argument('--alpha', default=300, type=float, 57 | help='number of new channel increases per depth (default: 300)') 58 | parser.add_argument('--no-verbose', dest='verbose', action='store_false', 59 | help='to print the status at every iteration') 60 | parser.add_argument('--data_dir', default='/write/your/data/dir', type=str, metavar='PATH') 61 | parser.add_argument('--pretrained', default='/set/your/model', type=str, metavar='PATH') 62 | parser.add_argument('--fgsm', type=str2bool, default=False, help='true for fgsm') 63 | parser.add_argument('--eps', default=1, type=int, help='1, 2, 4') 64 | parser.set_defaults(bottleneck=True) 65 | parser.set_defaults(verbose=True) 66 | 67 | best_err1 = 100 68 | best_err5 = 100 69 | 70 | 71 | def main(): 72 | global args, best_err1, best_err5 73 | args = parser.parse_args() 74 | 75 | if args.dataset.startswith('cifar'): 76 | normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]], 77 | std=[x / 255.0 for x in [63.0, 62.1, 66.7]]) 78 | 79 | transform_train = transforms.Compose([ 80 | transforms.RandomCrop(32, padding=4), 81 | transforms.RandomHorizontalFlip(), 82 | transforms.ToTensor(), 83 | normalize, 84 | ]) 85 | 86 | transform_test = transforms.Compose([ 87 | transforms.ToTensor(), 88 | normalize 89 | ]) 90 | if args.dataset == 'cifar100': 91 | val_loader = torch.utils.data.DataLoader( 92 | datasets.CIFAR100(args.data_dir+'/dataCifar100/', train=False, transform=transform_test), 93 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) 94 | numberofclass = 100 95 | elif args.dataset == 'cifar10': 96 | val_loader = torch.utils.data.DataLoader( 97 | datasets.CIFAR10(args.data_dir+'/dataCifar10/', train=False, transform=transform_test), 98 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) 99 | numberofclass = 10 100 | else: 101 | raise Exception('unknown dataset: {}'.format(args.dataset)) 102 | else: 103 | raise Exception('unknown dataset: {}'.format(args.dataset)) 104 | 105 | print("=> creating model '{}'".format(args.net_type)) 106 | if args.net_type == 'pyramidnet': 107 | 
model = PYRM.PyramidNet(args.dataset, args.depth, args.alpha, numberofclass, 108 | args.bottleneck) 109 | else: 110 | raise Exception('unknown network architecture: {}'.format(args.net_type)) 111 | 112 | model = torch.nn.DataParallel(model).cuda() 113 | 114 | if os.path.isfile(args.pretrained): 115 | print("=> loading checkpoint '{}'".format(args.pretrained)) 116 | checkpoint = torch.load(args.pretrained) 117 | model.load_state_dict(checkpoint['state_dict']) 118 | print("=> loaded checkpoint '{}'".format(args.pretrained)) 119 | else: 120 | raise Exception("=> no checkpoint found at '{}'".format(args.pretrained)) 121 | 122 | print(model) 123 | print('the number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) 124 | 125 | # define loss function (criterion) and optimizer 126 | criterion = nn.CrossEntropyLoss().cuda() 127 | 128 | cudnn.benchmark = True 129 | mean = torch.tensor([x / 255 for x in [125.3, 123.0, 113.9]], dtype=torch.float32).view(1,3,1,1).cuda() 130 | std = torch.tensor([x / 255 for x in [63.0, 62.1, 66.7]], dtype=torch.float32).view(1,3,1,1).cuda() 131 | 132 | err1, err5, val_loss = validate(val_loader, model, args.fgsm, args.eps, mean, std) 133 | 134 | print('Accuracy (top-1 and 5 error):', err1, err5) 135 | 136 | def validate(val_loader, model, fgsm, eps, mean, std): 137 | '''evaluate trained model''' 138 | losses = AverageMeter() 139 | top1 = AverageMeter() 140 | top5 = AverageMeter() 141 | 142 | criterion = nn.CrossEntropyLoss().cuda() 143 | # switch to evaluate mode 144 | model.eval() 145 | 146 | for i, (input, target) in enumerate(val_loader): 147 | input = input.cuda() 148 | target = target.cuda() 149 | 150 | # check FGSM for adversarial training 151 | if fgsm: 152 | input_var = Variable(input, requires_grad=True) 153 | target_var = Variable(target) 154 | 155 | optimizer_input = torch.optim.SGD([input_var], lr=0.1) 156 | output = model(input_var) 157 | loss = criterion(output, target_var) 158 | optimizer_input.zero_grad() 159 | loss.backward() 160 | 161 | sign_data_grad = input_var.grad.sign() 162 | input = input * std + mean + eps / 255. 
* sign_data_grad 163 | input = torch.clamp(input, 0, 1) 164 | input = (input - mean)/std 165 | 166 | with torch.no_grad(): 167 | input_var = Variable(input) 168 | target_var = Variable(target) 169 | 170 | # compute output 171 | output = model(input_var) 172 | loss = criterion(output, target_var) 173 | 174 | # measure accuracy and record loss 175 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 176 | losses.update(loss.item(), input.size(0)) 177 | top1.update(prec1.item(), input.size(0)) 178 | top5.update(prec5.item(), input.size(0)) 179 | 180 | if fgsm: 181 | print('Attack (eps : {}) Prec@1 {top1.avg:.2f}'.format(eps, top1=top1)) 182 | print('Attack (eps : {}) Prec@5 {top5.avg:.2f}'.format(eps, top5=top5)) 183 | else: 184 | print(' **Test** Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Error@1 {error1:.3f} Loss: {losses.avg:.3f} '.format(top1=top1, top5=top5, error1=100-top1.avg, losses=losses)) 185 | return top1.avg, top5.avg, losses.avg 186 | 187 | 188 | class AverageMeter(object): 189 | """Computes and stores the average and current value""" 190 | 191 | def __init__(self): 192 | self.reset() 193 | 194 | def reset(self): 195 | self.val = 0 196 | self.avg = 0 197 | self.sum = 0 198 | self.count = 0 199 | 200 | def update(self, val, n=1): 201 | self.val = val 202 | self.sum += val * n 203 | self.count += n 204 | self.avg = self.sum / self.count 205 | 206 | 207 | def accuracy(output, target, topk=(1,)): 208 | """Computes the precision@k for the specified values of k""" 209 | maxk = max(topk) 210 | batch_size = target.size(0) 211 | 212 | _, pred = output.topk(maxk, 1, True, True) 213 | pred = pred.t() 214 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 215 | 216 | res = [] 217 | for k in topk: 218 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 219 | wrong_k = batch_size - correct_k 220 | res.append(wrong_k.mul_(100.0 / batch_size)) 221 | 222 | return res 223 | 224 | 225 | if __name__ == '__main__': 226 | main() 227 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | python test.py \ 2 | --net_type pyramidnet \ 3 | --dataset cifar10 \ 4 | --batch_size 256 \ 5 | --depth 200 \ 6 | --alpha 240 \ 7 | --fgsm False \ 8 | --eps 2 \ 9 | --data_dir /set/your/data/dir \ 10 | --pretrained /set/your/pretrained/model/dir/model_best.pth.tar 11 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim 10 | import torch.utils.data 11 | import torch.utils.data.distributed 12 | import torchvision.transforms as transforms 13 | import torchvision.datasets as datasets 14 | import torchvision.models as models 15 | import pyramidnet as PYRM 16 | import utils 17 | import numpy as np 18 | import torchvision.utils 19 | from torchvision.utils import save_image 20 | import warnings 21 | from matplotlib import pyplot as plt 22 | import matplotlib.gridspec as gridspec 23 | from function import calc_mean_std 24 | from torch.utils.tensorboard import SummaryWriter 25 | import net_cutmix 26 | import net_mixup 27 | from function import adaptive_instance_normalization, coral 28 | import torch.nn.functional as F 29 | from IPython import embed 30 | from PIL 
import ImageFile 31 | ImageFile.LOAD_TRUNCATED_IMAGES = True 32 | from datetime import datetime 33 | warnings.filterwarnings("ignore") 34 | model_names = sorted(name for name in models.__dict__ 35 | if name.islower() and not name.startswith("__") 36 | and callable(models.__dict__[name])) 37 | 38 | parser = argparse.ArgumentParser(description='StyleMix CIFAR-10, CIFAR-100 training code') 39 | parser.add_argument('--net_type', default='pyramidnet', type=str, 40 | help='networktype: pyramidnet') 41 | parser.add_argument('-j', '--workers', default=40, type=int, metavar='N', 42 | help='number of data loading workers (default: 4)') 43 | parser.add_argument('--epochs', default=300, type=int, metavar='N', 44 | help='number of total epochs to run') # 250 45 | parser.add_argument('-b', '--batch_size', default=256, type=int, 46 | metavar='N', help='mini-batch size (default: 256)') 47 | parser.add_argument('--lr', '--learning-rate', default=0.25, type=float, 48 | metavar='LR', help='initial learning rate') 49 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 50 | help='momentum') 51 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 52 | metavar='W', help='weight decay (default: 1e-4)') 53 | parser.add_argument('--print-freq', '-p', default=1, type=int, 54 | metavar='N', help='print frequency (default: 10)') 55 | parser.add_argument('--depth', default=18, type=int, 56 | help='depth of the network (default: 32)') 57 | parser.add_argument('--no-bottleneck', dest='bottleneck', action='store_false', 58 | help='to use basicblock for CIFAR datasets (default: bottleneck)') 59 | parser.add_argument('--dataset', dest='dataset', default='cifar100', type=str, 60 | help='dataset (options: cifar10, cifar100)') 61 | parser.add_argument('--no-verbose', dest='verbose', action='store_false', 62 | help='to print the status at every iteration') 63 | parser.add_argument('--alpha', default=200, type=float, 64 | help='number of new channel increases per depth (default: 200)') 65 | parser.add_argument('--expname', default='PyraNet200', type=str, 66 | help='name of experiment') 67 | parser.add_argument('--vgg', type=str, default='./models/vgg_normalised.pth') 68 | parser.add_argument('--decoder', type=str, default='./models/decoder.pth.tar') 69 | parser.add_argument('--prob', default=0.5, type=float) 70 | parser.add_argument('--r', default=0.7, type=float) 71 | parser.add_argument('--alpha1', default=1.0, type=float) 72 | parser.add_argument('--alpha2', default=1.0, type=float) 73 | parser.add_argument('--delta', default=3.0, type=float) 74 | parser.add_argument('--method', type=str, default='StyleCutMix_Auto_Gamma', help='StyleCutMix_Auto_Gamma, StyleCutMix, StyleMix') 75 | parser.add_argument('--save_dir', type=str, default='/write/your/save/dir') 76 | parser.add_argument('--data_dir', type=str, default='/write/your/data/dir') 77 | parser.set_defaults(bottleneck=True) 78 | parser.set_defaults(verbose=True) 79 | 80 | best_err1 = 100 81 | best_err5 = 100 82 | 83 | def main(): 84 | 85 | global args, best_err1, best_err5, styleDistanceMatrix, writer 86 | args = parser.parse_args() 87 | writer = SummaryWriter(args.save_dir+'/writer/'+args.method) 88 | if args.method == 'StyleCutMix_Auto_Gamma' : 89 | if args.dataset == 'cifar100': 90 | styleDistanceMatrix = torch.load('styleDistanceMatrix100.pt', map_location='cuda:0') 91 | elif args.dataset == 'cifar10': 92 | styleDistanceMatrix = torch.load('styleDistanceMatrix10.pt', map_location='cuda:0') 93 | else : 94 | raise 
Exception('unknown dataset: {}'.format(args.dataset)) 95 | styleDistanceMatrix = styleDistanceMatrix.cpu() 96 | ind = torch.arange(styleDistanceMatrix.shape[1]) 97 | styleDistanceMatrix[ind, ind] += 2 # Prevent diagonal lines from zero 98 | 99 | global decoder, vgg, pretrained, network_E, network_D 100 | if args.method.startswith('Style'): 101 | if args.method.startswith('StyleCutMix'): 102 | decoder = net_cutmix.decoder 103 | vgg = net_cutmix.vgg 104 | print("select network StyleCutMix") 105 | network_E = net_cutmix.Net_E(vgg) 106 | network_D = net_cutmix.Net_D(vgg, decoder) 107 | elif args.method == 'StyleMix': 108 | decoder = net_mixup.decoder 109 | vgg = net_mixup.vgg 110 | print("select network StyleMix") 111 | network_E = net_mixup.Net_E(vgg) 112 | network_D = net_mixup.Net_D(vgg, decoder) 113 | else : 114 | raise Exception('unknown method: {}'.format(args.method)) 115 | decoder.eval() 116 | vgg.eval() 117 | decoder.load_state_dict(torch.load(args.decoder)) 118 | vgg.load_state_dict(torch.load(args.vgg)) 119 | vgg = nn.Sequential(*list(vgg.children())[:31]) 120 | vgg.cuda() 121 | decoder.cuda() 122 | network_E.eval() 123 | network_D.eval() 124 | network_E = torch.nn.DataParallel(network_E).cuda() 125 | network_D = torch.nn.DataParallel(network_D).cuda() 126 | 127 | if args.dataset.startswith('cifar'): 128 | normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]], 129 | std=[x / 255.0 for x in [63.0, 62.1, 66.7]]) 130 | transform_train = transforms.Compose([ 131 | transforms.RandomCrop(32, padding=4), 132 | transforms.RandomHorizontalFlip(), 133 | transforms.ToTensor(), 134 | normalize, 135 | ]) 136 | transform_test = transforms.Compose([ 137 | transforms.ToTensor(), 138 | normalize 139 | ]) 140 | 141 | if args.dataset == 'cifar100': 142 | train_loader = torch.utils.data.DataLoader( 143 | datasets.CIFAR100(args.data_dir+'/dataCifar100/', train=True, download=True, transform=transform_train), 144 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) 145 | val_loader = torch.utils.data.DataLoader( 146 | datasets.CIFAR100(args.data_dir+'/dataCifar100/', train=False, transform=transform_test), 147 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) 148 | numberofclass = 100 149 | elif args.dataset == 'cifar10': 150 | train_loader = torch.utils.data.DataLoader( 151 | datasets.CIFAR10(args.data_dir+'/dataCifar10/', train=True, download=True, transform=transform_train), 152 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) 153 | val_loader = torch.utils.data.DataLoader( 154 | datasets.CIFAR10(args.data_dir+'/dataCifar10/', train=False, transform=transform_test), 155 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) 156 | numberofclass = 10 157 | else: 158 | raise Exception('unknown dataset: {}'.format(args.dataset)) 159 | else: 160 | raise Exception('unknown dataset: {}'.format(args.dataset)) 161 | 162 | print("=> creating model '{}'".format(args.net_type)) 163 | if args.net_type == 'pyramidnet': 164 | model = PYRM.PyramidNet(args.dataset, args.depth, args.alpha, numberofclass, 165 | args.bottleneck) 166 | else: 167 | raise Exception('unknown network architecture: {}'.format(args.net_type)) 168 | model = torch.nn.DataParallel(model).cuda() 169 | print(model) 170 | print('the number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) 171 | 172 | # define loss function (criterion) and optimizer 
173 | criterion = nn.CrossEntropyLoss().cuda() 174 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 175 | momentum=args.momentum, 176 | weight_decay=args.weight_decay, nesterov=True) 177 | cudnn.benchmark = True 178 | 179 | for epoch in range(0, args.epochs): 180 | adjust_learning_rate(optimizer, epoch) 181 | 182 | # train for one epoch 183 | train_loss = train(train_loader, model, criterion, optimizer, epoch) 184 | 185 | # evaluate on validation set 186 | err1, err5, val_loss = validate(val_loader, model, criterion, epoch) 187 | 188 | writer.add_scalar('train_loss', train_loss, epoch+1) 189 | writer.add_scalar('val_loss', val_loss, epoch+1) 190 | writer.add_scalar('err1', err1, epoch+1) 191 | writer.add_scalar('err5', err5, epoch+1) 192 | # remember best prec@1 and save checkpoint 193 | is_best = err1 <= best_err1 194 | best_err1 = min(err1, best_err1) 195 | if is_best: 196 | best_err5 = err5 197 | 198 | print('Current best accuracy (top-1 and 5 error):', best_err1, best_err5) 199 | save_checkpoint({ 200 | 'epoch': epoch, 201 | 'arch': args.net_type, 202 | 'state_dict': model.state_dict(), 203 | 'best_err1': best_err1, 204 | 'best_err5': best_err5, 205 | 'optimizer': optimizer.state_dict(), 206 | }, is_best, args.save_dir, args.dataset) 207 | 208 | print('Best accuracy (top-1 and 5 error):', best_err1, best_err5) 209 | 210 | 211 | def train(train_loader, model, criterion, optimizer, epoch): 212 | batch_time = AverageMeter() 213 | data_time = AverageMeter() 214 | losses = AverageMeter() 215 | top1 = AverageMeter() 216 | top5 = AverageMeter() 217 | 218 | # switch to train mode 219 | model.train() 220 | start = time.time() 221 | end = time.time() 222 | current_LR = get_learning_rate(optimizer)[0] 223 | print("current_LR : ",current_LR) 224 | for i, (input, target) in enumerate(train_loader): 225 | data_time.update(time.time() - end) 226 | input = input.cuda() 227 | target = target.cuda() 228 | prob = np.random.rand(1) 229 | if prob < args.prob: 230 | rand_index = torch.randperm(input.size()[0]).cuda() 231 | target_1 = target 232 | target_2 = target[rand_index] 233 | if args.method.startswith('StyleCutMix'): 234 | if args.method == 'StyleCutMix_Auto_Gamma' : 235 | styleDistance = styleDistanceMatrix[target_1, target_2] 236 | gamma = torch.tanh(styleDistance/args.delta) 237 | else : 238 | gamma = np.random.beta(args.alpha2, args.alpha2) 239 | 240 | u = nn.Upsample(size=(224, 224), mode='bilinear') 241 | x1 = u(input) 242 | x2 = x1[rand_index] 243 | rs = np.random.beta(args.alpha1, args.alpha1) 244 | M = torch.zeros(1,1,224,224).float() 245 | lam_temp = np.random.beta(args.alpha1, args.alpha1) 246 | bbx1, bby1, bbx2, bby2 = rand_bbox(M.size(), 1.-lam_temp) 247 | with torch.no_grad(): 248 | x1_feat = network_E(x1) 249 | mixImage = network_D(x1, x2, x1_feat, x1_feat[rand_index], rs, gamma, bbx1, bby1, bbx2, bby2) 250 | lam = ((bbx2 - bbx1)*(bby2-bby1)/(224.*224.)) 251 | uinv = nn.Upsample(size=(32,32), mode='bilinear') 252 | output = model(uinv(mixImage)) 253 | 254 | log_preds = F.log_softmax(output, dim=-1) # dimension [batch_size, numberofclass] 255 | a_loss = -log_preds[torch.arange(output.shape[0]),target_1] # cross-entropy for A 256 | b_loss = -log_preds[torch.arange(output.shape[0]),target_2] # cross-entropy for B 257 | if args.method == 'StyleCutMix_Auto_Gamma': 258 | gamma = gamma.cuda() 259 | lam_s = gamma * lam + (1.0 - gamma) * rs 260 | loss_c = a_loss * (lam) + b_loss * (1. - lam) 261 | loss_s = a_loss * (lam_s) + b_loss * (1. 
- lam_s) 262 | r = args.r 263 | loss = (r * loss_c + (1.0 - r) * loss_s).mean() 264 | elif args.method == 'StyleMix': 265 | u = nn.Upsample(size=(224, 224), mode='bilinear') 266 | x1 = u(input) 267 | x2 = x1[rand_index] 268 | rc = np.random.beta(args.alpha1, args.alpha1) 269 | rs = np.random.beta(args.alpha1, args.alpha1) 270 | with torch.no_grad(): 271 | x1_feat = network_E(x1) 272 | mixImage = network_D(x1_feat, x1_feat[rand_index], rc, rs) 273 | uinv = nn.Upsample(size=(32,32), mode='bilinear') 274 | output = model(uinv(mixImage)) 275 | 276 | loss_c = rc * criterion(output, target_1) + (1.0 - rc) * criterion(output, target_2) 277 | loss_s = rs * criterion(output, target_1) + (1.0 - rs) * criterion(output, target_2) 278 | r = args.r 279 | loss = r * loss_c + (1.0-r) * loss_s 280 | else: 281 | output = model(input) 282 | loss = criterion(output, target) 283 | # measure accuracy and record loss 284 | err1, err5 = accuracy(output.data, target, topk=(1, 5)) 285 | 286 | losses.update(loss.item(), input.size(0)) 287 | top1.update(err1.item(), input.size(0)) 288 | top5.update(err5.item(), input.size(0)) 289 | # compute gradient and do SGD step 290 | optimizer.zero_grad() 291 | loss.backward() 292 | optimizer.step() 293 | 294 | # measure elapsed time 295 | batch_time.update(time.time() - end) 296 | end = time.time() 297 | 298 | if i % args.print_freq == 0 and args.verbose == True: 299 | print('Epoch: [{0}/{1}][{2}/{3}]\t' 300 | 'LR: {LR:.6f}\t' 301 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 302 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 303 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 304 | 'Top 1-err {top1.val:.4f} ({top1.avg:.4f})\t' 305 | 'Top 5-err {top5.val:.4f} ({top5.avg:.4f})'.format( 306 | epoch, args.epochs, i, len(train_loader), LR=current_LR, batch_time=batch_time, 307 | data_time=data_time, loss=losses, top1=top1, top5=top5)) 308 | print("Time taken for 1 epoch : ",time.time()-start) 309 | print('* Epoch: [{0}/{1}]\t Top 1-err {top1.avg:.3f} Top 5-err {top5.avg:.3f}\t Train Loss {loss.avg:.3f}'.format( 310 | epoch, args.epochs, top1=top1, top5=top5, loss=losses)) 311 | 312 | return losses.avg 313 | 314 | 315 | def rand_bbox(size, lam): 316 | W = size[2] 317 | H = size[3] 318 | cut_rat = np.sqrt(1. 
- lam) 319 | cut_w = int(W * cut_rat) # built-in int (np.int was removed from NumPy) 320 | cut_h = int(H * cut_rat) 321 | 322 | # uniform 323 | cx = np.random.randint(W) 324 | cy = np.random.randint(H) 325 | 326 | bbx1 = np.clip(cx - cut_w // 2, 0, W) 327 | bby1 = np.clip(cy - cut_h // 2, 0, H) 328 | bbx2 = np.clip(cx + cut_w // 2, 0, W) 329 | bby2 = np.clip(cy + cut_h // 2, 0, H) 330 | 331 | return bbx1, bby1, bbx2, bby2 332 | 333 | 334 | def validate(val_loader, model, criterion, epoch): 335 | batch_time = AverageMeter() 336 | losses = AverageMeter() 337 | top1 = AverageMeter() 338 | top5 = AverageMeter() 339 | 340 | # switch to evaluate mode 341 | model.eval() 342 | 343 | end = time.time() 344 | for i, (input, target) in enumerate(val_loader): 345 | target = target.cuda() 346 | 347 | output = model(input) 348 | loss = criterion(output, target) 349 | 350 | # measure accuracy and record loss 351 | err1, err5 = accuracy(output.data, target, topk=(1, 5)) 352 | 353 | losses.update(loss.item(), input.size(0)) 354 | 355 | top1.update(err1.item(), input.size(0)) 356 | top5.update(err5.item(), input.size(0)) 357 | 358 | # measure elapsed time 359 | batch_time.update(time.time() - end) 360 | end = time.time() 361 | 362 | if i % args.print_freq == 0 and args.verbose == True: 363 | print('Test (on val set): [{0}/{1}][{2}/{3}]\t' 364 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 365 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 366 | 'Top 1-err {top1.val:.4f} ({top1.avg:.4f})\t' 367 | 'Top 5-err {top5.val:.4f} ({top5.avg:.4f})'.format( 368 | epoch, args.epochs, i, len(val_loader), batch_time=batch_time, loss=losses, 369 | top1=top1, top5=top5)) 370 | 371 | print('* Epoch: [{0}/{1}]\t Top 1-err {top1.avg:.3f} Top 5-err {top5.avg:.3f}\t Test Loss {loss.avg:.3f}'.format( 372 | epoch, args.epochs, top1=top1, top5=top5, loss=losses)) 373 | return top1.avg, top5.avg, losses.avg 374 | 375 | 376 | def save_checkpoint(state, is_best, save_dir, dataset, filename='checkpoint.pth.tar'): 377 | directory = save_dir+"/model/"+dataset+"/"+str(args.method)+"/%s/" % (args.expname) 378 | if not os.path.exists(directory): 379 | os.makedirs(directory) 380 | filename = directory + filename 381 | torch.save(state, filename) 382 | 383 | if is_best: 384 | shutil.copyfile(filename, save_dir+"/model/"+dataset+"/"+str(args.method)+'/%s/' % (args.expname) + 'model_best.pth.tar') 385 | 386 | 387 | class AverageMeter(object): 388 | """Computes and stores the average and current value""" 389 | 390 | def __init__(self): 391 | self.reset() 392 | 393 | def reset(self): 394 | self.val = 0 395 | self.avg = 0 396 | self.sum = 0 397 | self.count = 0 398 | 399 | def update(self, val, n=1): 400 | self.val = val 401 | self.sum += val * n 402 | self.count += n 403 | self.avg = self.sum / self.count 404 | 405 | 406 | def adjust_learning_rate(optimizer, epoch): 407 | """Sets the learning rate to the initial LR decayed by 10 at 50% and 75% of the total epochs""" 408 | lr = args.lr * (0.1 ** (epoch // (args.epochs * 0.5))) * (0.1 ** (epoch // (args.epochs * 0.75))) 409 | 410 | for param_group in optimizer.param_groups: 411 | param_group['lr'] = lr 412 | 413 | 414 | def get_learning_rate(optimizer): 415 | lr = [] 416 | for param_group in optimizer.param_groups: 417 | lr += [param_group['lr']] 418 | return lr 419 | 420 | 421 | def accuracy(output, target, topk=(1,)): 422 | """Computes the top-k error rate (%) for the specified values of k""" 423 | maxk = max(topk) 424 | batch_size = target.size(0) 425 | 426 | _, pred = output.topk(maxk, 1, True, True) 427 | pred = pred.t() 428 | correct = 
pred.eq(target.view(1, -1).expand_as(pred)) 429 | 430 | res = [] 431 | for k in topk: 432 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 433 | wrong_k = batch_size - correct_k 434 | res.append(wrong_k.mul_(100.0 / batch_size)) 435 | 436 | return res 437 | 438 | if __name__ == '__main__': 439 | main() 440 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | python train.py \ 2 | --net_type pyramidnet \ 3 | --dataset cifar100 \ 4 | --depth 200 \ 5 | --alpha 240 \ 6 | --batch_size 128 \ 7 | --lr 0.25 \ 8 | --expname PyraNet200 \ 9 | --epochs 300 \ 10 | --prob 0.5 \ 11 | --r 0.7 \ 12 | --delta 3.0 \ 13 | --method StyleCutMix_Auto_Gamma \ 14 | --save_dir /write/your/save/dir \ 15 | --data_dir /write/your/data/dir \ 16 | --no-verbose 17 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # original code: https://github.com/eladhoffer/convNet.pytorch/blob/master/preprocess.py 2 | 3 | import torch 4 | import random 5 | 6 | __all__ = ["Compose", "Lighting", "ColorJitter"] 7 | 8 | 9 | class Compose(object): 10 | """Composes several transforms together. 11 | 12 | Args: 13 | transforms (list of ``Transform`` objects): list of transforms to compose. 14 | 15 | Example: 16 | >>> transforms.Compose([ 17 | >>> transforms.CenterCrop(10), 18 | >>> transforms.ToTensor(), 19 | >>> ]) 20 | """ 21 | 22 | def __init__(self, transforms): 23 | self.transforms = transforms 24 | 25 | def __call__(self, img): 26 | for t in self.transforms: 27 | img = t(img) 28 | return img 29 | 30 | def __repr__(self): 31 | format_string = self.__class__.__name__ + '(' 32 | for t in self.transforms: 33 | format_string += '\n' 34 | format_string += ' {0}'.format(t) 35 | format_string += '\n)' 36 | return format_string 37 | 38 | 39 | class Lighting(object): 40 | """Lighting noise(AlexNet - style PCA - based noise)""" 41 | 42 | def __init__(self, alphastd, eigval, eigvec): 43 | self.alphastd = alphastd 44 | self.eigval = torch.Tensor(eigval) 45 | self.eigvec = torch.Tensor(eigvec) 46 | 47 | def __call__(self, img): 48 | if self.alphastd == 0: 49 | return img 50 | 51 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 52 | rgb = self.eigvec.type_as(img).clone() \ 53 | .mul(alpha.view(1, 3).expand(3, 3)) \ 54 | .mul(self.eigval.view(1, 3).expand(3, 3)) \ 55 | .sum(1).squeeze() 56 | 57 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 58 | 59 | 60 | class Grayscale(object): 61 | 62 | def __call__(self, img): 63 | gs = img.clone() 64 | gs[0].mul_(0.299).add_(0.587, gs[1]).add_(0.114, gs[2]) 65 | gs[1].copy_(gs[0]) 66 | gs[2].copy_(gs[0]) 67 | return gs 68 | 69 | 70 | class Saturation(object): 71 | 72 | def __init__(self, var): 73 | self.var = var 74 | 75 | def __call__(self, img): 76 | gs = Grayscale()(img) 77 | alpha = random.uniform(-self.var, self.var) 78 | return img.lerp(gs, alpha) 79 | 80 | 81 | class Brightness(object): 82 | 83 | def __init__(self, var): 84 | self.var = var 85 | 86 | def __call__(self, img): 87 | gs = img.new().resize_as_(img).zero_() 88 | alpha = random.uniform(-self.var, self.var) 89 | return img.lerp(gs, alpha) 90 | 91 | 92 | class Contrast(object): 93 | 94 | def __init__(self, var): 95 | self.var = var 96 | 97 | def __call__(self, img): 98 | gs = Grayscale()(img) 99 | gs.fill_(gs.mean()) 100 | alpha = random.uniform(-self.var, self.var) 101 | 
return img.lerp(gs, alpha) 102 | 103 | 104 | class ColorJitter(object): 105 | 106 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4): 107 | self.brightness = brightness 108 | self.contrast = contrast 109 | self.saturation = saturation 110 | 111 | def __call__(self, img): 112 | self.transforms = [] 113 | if self.brightness != 0: 114 | self.transforms.append(Brightness(self.brightness)) 115 | if self.contrast != 0: 116 | self.transforms.append(Contrast(self.contrast)) 117 | if self.saturation != 0: 118 | self.transforms.append(Saturation(self.saturation)) 119 | 120 | random.shuffle(self.transforms) 121 | transform = Compose(self.transforms) 122 | # print(transform) 123 | return transform(img) 124 | --------------------------------------------------------------------------------
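How the styleDistanceMatrix files are consumed: in train.py, the StyleCutMix_Auto_Gamma branch loads styleDistanceMatrix10.pt (or styleDistanceMatrix100.pt), adds 2 to the diagonal so that same-class pairs never produce a zero distance, looks up the entry for each (target_1, target_2) pair, and sets the style weight gamma = tanh(styleDistance / delta); the content weight lam comes from the sampled box, lam_s = gamma * lam + (1 - gamma) * rs, and the content and style cross-entropies are blended with r. The following is a minimal, self-contained sketch of that arithmetic, not repository code: the helper names load_style_distance_matrix and label_weights are invented here, and it reads the plain-text dump styleDistanceMatrix10 shown above rather than the .pt tensor that train.py actually loads.

import numpy as np
import torch

def load_style_distance_matrix(path='styleDistanceMatrix10'):
    # 10x10 symmetric matrix of pairwise style distances between CIFAR-10 classes;
    # train.py loads the equivalent tensor from 'styleDistanceMatrix10.pt' instead
    m = torch.from_numpy(np.loadtxt(path, delimiter=',')).float()
    ind = torch.arange(m.shape[1])
    m[ind, ind] += 2  # as in train.py: keep diagonal entries away from zero
    return m

def label_weights(style_distance, lam, rs, delta=3.0, r=0.7):
    # gamma near 1 (large style distance) ties the style ratio to the content ratio lam;
    # gamma near 0 lets the independently sampled style ratio rs dominate
    gamma = torch.tanh(style_distance / delta)
    lam_s = gamma * lam + (1.0 - gamma) * rs
    # effective weight on the first image's label; the second label gets the remainder
    return r * lam + (1.0 - r) * lam_s

if __name__ == '__main__':
    m = load_style_distance_matrix()
    d = m[0, 6]  # classes 0 and 6 are stylistically far apart (distance ~1.60)
    print(label_weights(d, lam=0.6, rs=0.3))  # ~0.55 on the first label, ~0.45 on the second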