├── datasets └── .gitadd ├── imgs ├── C10_acc.png ├── C100_acc.png └── SVHN_acc.png ├── requirements.txt ├── setup.sh ├── LICENSE ├── utils.py ├── README.md ├── .gitignore ├── main.py ├── dataloader.py ├── mlp_mixer.py └── train.py /datasets/.gitadd: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imgs/C10_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omihub777/MLP-Mixer-CIFAR/HEAD/imgs/C10_acc.png -------------------------------------------------------------------------------- /imgs/C100_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omihub777/MLP-Mixer-CIFAR/HEAD/imgs/C100_acc.png -------------------------------------------------------------------------------- /imgs/SVHN_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omihub777/MLP-Mixer-CIFAR/HEAD/imgs/SVHN_acc.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.10.0 2 | torchvision 3 | numpy 4 | wandb 5 | einops 6 | torchsummary 7 | scipy -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | pip install --upgrade pip 4 | pip install -r requirements.txt 5 | pip install git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git 6 | git clone https://github.com/DeepVoltaire/AutoAugment.git -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | 
Copyright (c) 2021 Omiita 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def get_model(args): 4 | model = None 5 | if args.model=='mlp_mixer': 6 | from mlp_mixer import MLPMixer 7 | model = MLPMixer( 8 | in_channels=3, 9 | img_size=args.size, 10 | hidden_size=args.hidden_size, 11 | patch_size = args.patch_size, 12 | hidden_c = args.hidden_c, 13 | hidden_s = args.hidden_s, 14 | num_layers = args.num_layers, 15 | num_classes=args.num_classes, 16 | drop_p=args.drop_p, 17 | off_act=args.off_act, 18 | is_cls_token=args.is_cls_token 19 | ) 20 | else: 21 | raise ValueError(f"No such model: {args.model}") 22 | 23 | return model.to(args.device) 24 | 25 | def rand_bbox(size, lam): 26 | W = size[2] 27 | H = size[3] 28 | cut_rat = np.sqrt(1. - lam) 29 | cut_w = int(W * cut_rat) 30 | cut_h = int(H * cut_rat) 31 | 32 | # uniform 33 | cx = np.random.randint(W) 34 | cy = np.random.randint(H) 35 | 36 | bbx1 = np.clip(cx - cut_w // 2, 0, W) 37 | bby1 = np.clip(cy - cut_h // 2, 0, H) 38 | bbx2 = np.clip(cx + cut_w // 2, 0, W) 39 | bby2 = np.clip(cy + cut_h // 2, 0, H) 40 | 41 | return bbx1, bby1, bbx2, bby2 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MLP-Mixer-CIFAR 2 | PyTorch implementation of **Mixer-nano** (**0.67M** parameters, vs. 18M for the original Mixer-S/16), reaching **90.83% acc.** on CIFAR-10 when trained from scratch. 3 | 4 | ## 1. Prerequisites 5 | * Python 3.9.6 6 | * PyTorch 1.10.0 7 | * [Weights and Biases](https://wandb.ai/site) account for logging experiments. 
8 | 9 | ## 2. Quick Start 10 | ```shell 11 | $git clone https://github.com/omihub777/MLP-Mixer-CIFAR.git 12 | $cd MLP-Mixer-CIFAR 13 | $bash setup.sh 14 | $python main.py --dataset c10 --model mlp_mixer --autoaugment --cutmix-prob 0.5 15 | ``` 16 | 17 | ## 3. Results 18 | |Dataset|Acc.(%)|Time(hh:mm:ss)|Steps| 19 | |:--:|:--:|:--:|:--:| 20 | |CIFAR-10|**90.83%**|3:34:31|117.3k| 21 | |CIFAR-100|**67.51%**|3:35:26|117.3k| 22 | |SVHN|**97.63%**|5:23:26|171.9k| 23 | * Number of Parameters: 0.67M 24 | * Device: P100 (single GPU) 25 | 26 | ### 3.1 CIFAR-10 27 | * Accuracy 28 | 29 | ![Validation Acc. on CIFAR-10](imgs/C10_acc.png) 30 | 31 | ### 3.2 CIFAR-100 32 | * Accuracy 33 | 34 | ![Validation Acc. on CIFAR-100](imgs/C100_acc.png) 35 | 36 | 37 | ### 3.3 SVHN 38 | * Accuracy 39 | 40 | ![Validation Acc. on SVHN](imgs/SVHN_acc.png) 41 | 42 | 43 | ## 4. Experiment Settings 44 | 45 | |Param|Value| 46 | |:--|:--:| 47 | |Adam beta1|0.9| 48 | |Adam beta2|0.99| 49 | |AutoAugment|True| 50 | |Batch Size|128| 51 | |CutMix prob.|0.5| 52 | |CutMix beta|1.0| 53 | |Dropout|0.0| 54 | |Epoch|300| 55 | |Hidden_C|512| 56 | |Hidden_S|64| 57 | |Hidden|128| 58 | |(Init LR, Last LR)|(1e-3, 1e-6)| 59 | |Label Smoothing|0.1| 60 | |Layers|8| 61 | |LR Scheduler|Cosine| 62 | |Optimizer|Adam| 63 | |Random Seed|3407| 64 | |Weight Decay|5e-5| 65 | |Warmup|5 epochs| 66 | 67 | ## 5. 
Resources 68 | * [MLP-Mixer: An all-MLP Architecture for Vision, Tolstikhin, I., (2021)](https://arxiv.org/abs/2105.01601) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | datasets/* 2 | wandb/ 3 | !datasets/.gitadd 4 | AutoAugment/ 5 | sam.py 6 | .DS_Store 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | 138 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import wandb 5 | wandb.login() 6 | 7 | from dataloader import get_dataloaders 8 | from utils import get_model 9 | from train import Trainer 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--dataset', required=True, choices=['c10', 'c100', 'svhn']) 13 | parser.add_argument('--model', required=True, choices=['mlp_mixer']) 14 | parser.add_argument('--batch-size', type=int, default=128) 15 | parser.add_argument('--eval-batch-size', type=int, default=1024) 16 | parser.add_argument('--num-workers', type=int, default=4) 17 | parser.add_argument('--seed', type=int, default=3407) 18 | parser.add_argument('--epochs', type=int, default=300) 19 | # parser.add_argument('--precision', type=int, default=16) 20 | 21 | parser.add_argument('--patch-size', type=int, default=4) 22 | parser.add_argument('--hidden-size', type=int, default=128) 23 | 
parser.add_argument('--hidden-c', type=int, default=512) 24 | parser.add_argument('--hidden-s', type=int, default=64) 25 | parser.add_argument('--num-layers', type=int, default=8) 26 | parser.add_argument('--drop-p', type=float, default=0.) 27 | parser.add_argument('--off-act', action='store_true', help='Disable activation function') 28 | parser.add_argument('--is-cls-token', action='store_true', help='Introduce a class token.') 29 | 30 | parser.add_argument('--lr', type=float, default=1e-3) 31 | parser.add_argument('--min-lr', type=float, default=1e-6) 32 | parser.add_argument('--momentum', type=float, default=0.9) 33 | parser.add_argument('--optimizer', default='adam', choices=['adam', 'sgd']) 34 | parser.add_argument('--scheduler', default='cosine', choices=['step', 'cosine']) 35 | parser.add_argument('--beta1', type=float, default=0.9) 36 | parser.add_argument('--beta2', type=float, default=0.99) 37 | parser.add_argument('--weight-decay', type=float, default=5e-5) 38 | parser.add_argument('--off-nesterov', action='store_true') 39 | parser.add_argument('--label-smoothing', type=float, default=0.1) 40 | parser.add_argument('--gamma', type=float, default=0.1) 41 | parser.add_argument('--warmup-epoch', type=int, default=5) 42 | parser.add_argument('--autoaugment', action='store_true') 43 | parser.add_argument('--clip-grad', type=float, default=0, help="0 means disabling clip-grad") 44 | parser.add_argument('--cutmix-beta', type=float, default=1.0) 45 | parser.add_argument('--cutmix-prob', type=float, default=0.) 
46 | 47 | args = parser.parse_args() 48 | args.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 49 | args.nesterov = not args.off_nesterov 50 | torch.random.manual_seed(args.seed) 51 | 52 | experiment_name = f"{args.model}_{args.dataset}_{args.optimizer}_{args.scheduler}" 53 | if args.autoaugment: 54 | experiment_name += "_aa" 55 | if args.clip_grad: 56 | experiment_name += f"_cg{args.clip_grad}" 57 | if args.off_act: 58 | experiment_name += f"_noact" 59 | if args.cutmix_prob>0.: 60 | experiment_name += f'_cm' 61 | if args.is_cls_token: 62 | experiment_name += f"_cls" 63 | 64 | 65 | if __name__=='__main__': 66 | with wandb.init(project='mlp_mixer', config=args, name=experiment_name): 67 | train_dl, test_dl = get_dataloaders(args) 68 | model = get_model(args) 69 | trainer = Trainer(model, args) 70 | trainer.fit(train_dl, test_dl) -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./AutoAugment/') 3 | 4 | import torch 5 | import torchvision 6 | import torchvision.transforms as transforms 7 | from AutoAugment.autoaugment import CIFAR10Policy, SVHNPolicy 8 | 9 | 10 | def get_dataloaders(args): 11 | train_transform, test_transform = get_transform(args) 12 | 13 | if args.dataset == "c10": 14 | train_ds = torchvision.datasets.CIFAR10('./datasets', train=True, transform=train_transform, download=True) 15 | test_ds = torchvision.datasets.CIFAR10('./datasets', train=False, transform=test_transform, download=True) 16 | args.num_classes = 10 17 | elif args.dataset == "c100": 18 | train_ds = torchvision.datasets.CIFAR100('./datasets', train=True, transform=train_transform, download=True) 19 | test_ds = torchvision.datasets.CIFAR100('./datasets', train=False, transform=test_transform, download=True) 20 | args.num_classes = 100 21 | elif args.dataset == "svhn": 22 | train_ds = 
torchvision.datasets.SVHN('./datasets', split='train', transform=train_transform, download=True) 23 | test_ds = torchvision.datasets.SVHN('./datasets', split='test', transform=test_transform, download=True) 24 | args.num_classes = 10 25 | else: 26 | raise ValueError(f"No such dataset: {args.dataset}") 27 | 28 | train_dl = torch.utils.data.DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True) 29 | test_dl = torch.utils.data.DataLoader(test_ds, batch_size=args.eval_batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True) 30 | 31 | return train_dl, test_dl 32 | 33 | def get_transform(args): 34 | if args.dataset in ["c10", "c100", 'svhn']: 35 | args.padding=4 36 | args.size = 32 37 | if args.dataset=="c10": 38 | args.mean, args.std = [0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616] 39 | elif args.dataset=="c100": 40 | args.mean, args.std = [0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761] 41 | elif args.dataset=="svhn": 42 | args.mean, args.std = [0.4377, 0.4438, 0.4728], [0.1980, 0.2010, 0.1970] 43 | else: 44 | args.padding=28 45 | args.size = 224 46 | args.mean, args.std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 47 | train_transform_list = [transforms.RandomCrop(size=(args.size,args.size), padding=args.padding)] 48 | if args.dataset!="svhn":  # horizontal flips would corrupt SVHN digits 49 | train_transform_list.append(transforms.RandomHorizontalFlip()) 50 | 51 | if args.autoaugment: 52 | if args.dataset == 'c10' or args.dataset=='c100': 53 | train_transform_list.append(CIFAR10Policy()) 54 | elif args.dataset == 'svhn': 55 | train_transform_list.append(SVHNPolicy()) 56 | else: 57 | print(f"No AutoAugment for {args.dataset}") 58 | 59 | 60 | train_transform = transforms.Compose( 61 | train_transform_list+[ 62 | transforms.ToTensor(), 63 | transforms.Normalize( 64 | mean=args.mean, 65 | std = args.std 66 | ) 67 | ] 68 | ) 69 | test_transform = transforms.Compose([ 70 | transforms.ToTensor(), 71 | 
transforms.Normalize( 72 | mean=args.mean, 73 | std = args.std 74 | ) 75 | ]) 76 | 77 | return train_transform, test_transform 78 | -------------------------------------------------------------------------------- /mlp_mixer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops.layers.torch import Rearrange 5 | import torchsummary 6 | 7 | 8 | class MLPMixer(nn.Module): 9 | def __init__(self,in_channels=3,img_size=32, patch_size=4, hidden_size=512, hidden_s=256, hidden_c=2048, num_layers=8, num_classes=10, drop_p=0., off_act=False, is_cls_token=False): 10 | super(MLPMixer, self).__init__() 11 | num_patches = img_size // patch_size * img_size // patch_size 12 | # (b, c, h, w) -> (b, d, h//p, w//p) -> (b, h//p*w//p, d) 13 | self.is_cls_token = is_cls_token 14 | 15 | self.patch_emb = nn.Sequential( 16 | nn.Conv2d(in_channels, hidden_size ,kernel_size=patch_size, stride=patch_size), 17 | Rearrange('b d h w -> b (h w) d') 18 | ) 19 | 20 | if self.is_cls_token: 21 | self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_size)) 22 | num_patches += 1 23 | 24 | 25 | self.mixer_layers = nn.Sequential( 26 | *[ 27 | MixerLayer(num_patches, hidden_size, hidden_s, hidden_c, drop_p, off_act) 28 | for _ in range(num_layers) 29 | ] 30 | ) 31 | self.ln = nn.LayerNorm(hidden_size) 32 | 33 | self.clf = nn.Linear(hidden_size, num_classes) 34 | 35 | 36 | def forward(self, x): 37 | out = self.patch_emb(x) 38 | if self.is_cls_token: 39 | out = torch.cat([self.cls_token.repeat(out.size(0),1,1), out], dim=1) 40 | out = self.mixer_layers(out) 41 | out = self.ln(out) 42 | out = out[:, 0] if self.is_cls_token else out.mean(dim=1) 43 | out = self.clf(out) 44 | return out 45 | 46 | 47 | class MixerLayer(nn.Module): 48 | def __init__(self, num_patches, hidden_size, hidden_s, hidden_c, drop_p, off_act): 49 | super(MixerLayer, self).__init__() 50 | self.mlp1 = 
MLP1(num_patches, hidden_s, hidden_size, drop_p, off_act) 51 | self.mlp2 = MLP2(hidden_size, hidden_c, drop_p, off_act) 52 | def forward(self, x): 53 | out = self.mlp1(x) 54 | out = self.mlp2(out) 55 | return out 56 | 57 | class MLP1(nn.Module): 58 | def __init__(self, num_patches, hidden_s, hidden_size, drop_p, off_act): 59 | super(MLP1, self).__init__() 60 | self.ln = nn.LayerNorm(hidden_size) 61 | self.fc1 = nn.Conv1d(num_patches, hidden_s, kernel_size=1) 62 | self.do1 = nn.Dropout(p=drop_p) 63 | self.fc2 = nn.Conv1d(hidden_s, num_patches, kernel_size=1) 64 | self.do2 = nn.Dropout(p=drop_p) 65 | self.act = F.gelu if not off_act else lambda x:x 66 | def forward(self, x): 67 | out = self.do1(self.act(self.fc1(self.ln(x)))) 68 | out = self.do2(self.fc2(out)) 69 | return out+x 70 | 71 | class MLP2(nn.Module): 72 | def __init__(self, hidden_size, hidden_c, drop_p, off_act): 73 | super(MLP2, self).__init__() 74 | self.ln = nn.LayerNorm(hidden_size) 75 | self.fc1 = nn.Linear(hidden_size, hidden_c) 76 | self.do1 = nn.Dropout(p=drop_p) 77 | self.fc2 = nn.Linear(hidden_c, hidden_size) 78 | self.do2 = nn.Dropout(p=drop_p) 79 | self.act = F.gelu if not off_act else lambda x:x 80 | def forward(self, x): 81 | out = self.do1(self.act(self.fc1(self.ln(x)))) 82 | out = self.do2(self.fc2(out)) 83 | return out+x 84 | 85 | if __name__ == '__main__': 86 | net = MLPMixer( 87 | in_channels=3, 88 | img_size=32, 89 | patch_size=4, 90 | hidden_size=128, 91 | hidden_s=64, 92 | hidden_c=512, 93 | num_layers=8, 94 | num_classes=10, 95 | drop_p=0., 96 | off_act=False, 97 | is_cls_token=True 98 | ) 99 | torchsummary.summary(net, (3,32,32)) 100 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | import wandb 6 | import warmup_scheduler 7 | import numpy as np 8 
| 9 | from utils import rand_bbox 10 | 11 | 12 | class Trainer(object): 13 | def __init__(self, model, args): 14 | wandb.config.update(args) 15 | self.device = args.device 16 | self.clip_grad = args.clip_grad 17 | self.cutmix_beta = args.cutmix_beta 18 | self.cutmix_prob = args.cutmix_prob 19 | self.model = model 20 | if args.optimizer=='sgd': 21 | self.optimizer = optim.SGD(self.model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.nesterov) 22 | elif args.optimizer=='adam': 23 | self.optimizer = optim.Adam(self.model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2), weight_decay=args.weight_decay) 24 | else: 25 | raise ValueError(f"No such optimizer: {args.optimizer}") 26 | 27 | if args.scheduler=='step': 28 | self.base_scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[args.epochs//2, 3*args.epochs//4], gamma=args.gamma) 29 | elif args.scheduler=='cosine': 30 | self.base_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=args.epochs, eta_min=args.min_lr) 31 | else: 32 | raise ValueError(f"No such scheduler: {args.scheduler}") 33 | 34 | 35 | if args.warmup_epoch: 36 | self.scheduler = warmup_scheduler.GradualWarmupScheduler(self.optimizer, multiplier=1., total_epoch=args.warmup_epoch, after_scheduler=self.base_scheduler) 37 | else: 38 | self.scheduler = self.base_scheduler 39 | self.scaler = torch.cuda.amp.GradScaler() 40 | 41 | self.epochs = args.epochs 42 | self.criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) 43 | 44 | self.num_steps = 0 45 | self.epoch_loss, self.epoch_corr, self.epoch_acc = 0., 0., 0. 
46 | 47 | def _train_one_step(self, batch): 48 | self.model.train() 49 | img, label = batch 50 | self.num_steps += 1 51 | img, label = img.to(self.device), label.to(self.device) 52 | 53 | self.optimizer.zero_grad() 54 | r = np.random.rand(1) 55 | if self.cutmix_beta > 0 and r < self.cutmix_prob: 56 | # generate mixed sample 57 | lam = np.random.beta(self.cutmix_beta, self.cutmix_beta) 58 | rand_index = torch.randperm(img.size(0)).to(self.device) 59 | target_a = label 60 | target_b = label[rand_index] 61 | bbx1, bby1, bbx2, bby2 = rand_bbox(img.size(), lam) 62 | img[:, :, bbx1:bbx2, bby1:bby2] = img[rand_index, :, bbx1:bbx2, bby1:bby2] 63 | # adjust lambda to exactly match pixel ratio 64 | lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (img.size()[-1] * img.size()[-2])) 65 | # compute output 66 | with torch.cuda.amp.autocast(): 67 | out = self.model(img) 68 | loss = self.criterion(out, target_a) * lam + self.criterion(out, target_b) * (1. - lam) 69 | else: 70 | # compute output 71 | with torch.cuda.amp.autocast(): 72 | out = self.model(img) 73 | loss = self.criterion(out, label) 74 | 75 | self.scaler.scale(loss).backward() 76 | if self.clip_grad: 77 | nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad) 78 | self.scaler.step(self.optimizer) 79 | self.scaler.update() 80 | 81 | acc = out.argmax(dim=-1).eq(label).sum(-1)/img.size(0) 82 | wandb.log({ 83 | 'loss':loss, 84 | 'acc':acc 85 | }, step=self.num_steps) 86 | 87 | 88 | # @torch.no_grad 89 | def _test_one_step(self, batch): 90 | self.model.eval() 91 | img, label = batch 92 | img, label = img.to(self.device), label.to(self.device) 93 | 94 | with torch.no_grad(): 95 | out = self.model(img) 96 | loss = self.criterion(out, label) 97 | 98 | self.epoch_loss += loss * img.size(0) 99 | self.epoch_corr += out.argmax(dim=-1).eq(label).sum(-1) 100 | 101 | 102 | def fit(self, train_dl, test_dl): 103 | for epoch in range(1, self.epochs+1): 104 | for batch in train_dl: 105 | self._train_one_step(batch) 106 | 
wandb.log({ 107 | 'epoch': epoch, 108 | # 'lr': self.scheduler.get_last_lr(), 109 | 'lr':self.optimizer.param_groups[0]["lr"] 110 | }, step=self.num_steps 111 | ) 112 | self.scheduler.step() 113 | 114 | 115 | num_imgs = 0. 116 | self.epoch_loss, self.epoch_corr, self.epoch_acc = 0., 0., 0. 117 | for batch in test_dl: 118 | self._test_one_step(batch) 119 | num_imgs += batch[0].size(0) 120 | self.epoch_loss /= num_imgs 121 | self.epoch_acc = self.epoch_corr / num_imgs 122 | wandb.log({ 123 | 'val_loss': self.epoch_loss, 124 | 'val_acc': self.epoch_acc 125 | }, step=self.num_steps 126 | ) 127 | --------------------------------------------------------------------------------
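A note on the CutMix pieces above: `utils.rand_bbox` samples a box covering roughly `1 - lam` of the image area, and `Trainer._train_one_step` then re-derives the loss weight from the actual (clipped) box. The geometry is easy to lose in the flattened listing, so here is a minimal pure-Python sketch; the explicit `(cx, cy)` centre parameter and the `adjusted_lam` helper are hypothetical additions for determinism and illustration (the repo's version draws the centre with `np.random.randint` and inlines the adjustment):

```python
import math

def rand_bbox(size, lam, cx, cy):
    """Sketch of utils.rand_bbox with an explicit (hypothetical) centre.

    size is the batch shape (B, C, W, H); the cut box covers ~(1 - lam)
    of the image area before clipping to the image bounds.
    """
    W, H = size[2], size[3]
    cut_rat = math.sqrt(1.0 - lam)          # side ratio so the box area is ~(1 - lam)
    cut_w, cut_h = int(W * cut_rat), int(H * cut_rat)
    # clip each corner into [0, W] x [0, H], as np.clip does in the repo
    bbx1 = min(max(cx - cut_w // 2, 0), W)
    bby1 = min(max(cy - cut_h // 2, 0), H)
    bbx2 = min(max(cx + cut_w // 2, 0), W)
    bby2 = min(max(cy + cut_h // 2, 0), H)
    return bbx1, bby1, bbx2, bby2

def adjusted_lam(size, box):
    """Loss weight re-derived from the clipped box, as in _train_one_step."""
    bbx1, bby1, bbx2, bby2 = box
    return 1 - (bbx2 - bbx1) * (bby2 - bby1) / (size[2] * size[3])
```

With `lam = 0.75` and a centred box on a 32x32 batch, the box comes out as `(8, 8, 24, 24)`, and because nothing is clipped the re-derived weight is exactly 0.75 again; near an image border the clipped box is smaller, and the re-adjustment is what keeps the two loss terms weighted by true pixel ratio.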