├── README.md
├── main.py
└── main.yaml


/README.md:
--------------------------------------------------------------------------------
# mixup.pytorch

An implementation of *mixup: Beyond Empirical Risk Minimization* by Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, and David Lopez-Paz.

```
python main.py model.mixup={true,false} model.input_only={true,false}
```

mixup can (mathematically) avoid mixing labels: drawing the mixing weight from `beta(alpha+1, alpha)` instead of `beta(alpha, alpha)` and mixing only the inputs yields the same objective in expectation (see the derivation below). Setting `model.input_only=true` tests this variant empirically. (*)

## Requirements

* PyTorch==1.6.0
* torchvision==0.7.0
* homura: `pip install -U git+https://github.com/moskomule/homura@v2020.08`

## Results

### CIFAR-10 on ResNet-20 (max test accuracy)

| ERM   | mixup | mixup (`input_only=true`) |
|:------|:------|:--------------------------|
| 0.923 | 0.932 | 0.931                     |

The results suggest that the alternative mixup strategy (*) is as effective as the original.
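
## Why `beta(alpha+1, alpha)`? (*)

A sketch of why the two formulations agree in expectation. With soft targets, cross entropy is linear in the target, so the mixup loss splits into two terms:

$$
\ell\big(f(\tilde{x}),\, \lambda y_i + (1-\lambda) y_j\big)
= \lambda\, \ell\big(f(\tilde{x}), y_i\big) + (1-\lambda)\, \ell\big(f(\tilde{x}), y_j\big),
\qquad \tilde{x} = \lambda x_i + (1-\lambda) x_j .
$$

Reweighting the $\mathrm{Beta}(\alpha, \alpha)$ density by $\lambda$ gives $\lambda^{\alpha}(1-\lambda)^{\alpha-1}$, which is the $\mathrm{Beta}(\alpha+1, \alpha)$ kernel, so for any $g$

$$
\mathbb{E}_{\lambda \sim \mathrm{Beta}(\alpha,\alpha)}\big[\lambda\, g(\lambda)\big]
= \tfrac{1}{2}\, \mathbb{E}_{\lambda \sim \mathrm{Beta}(\alpha+1,\alpha)}\big[g(\lambda)\big] .
$$

Applying this to the first term, and to the second after swapping $\lambda \leftrightarrow 1-\lambda$ and $i \leftrightarrow j$ (which leaves the expectation over random pairs unchanged), the two halves coincide, and the expected mixup loss equals the expected loss of mixing *inputs only* with $\lambda \sim \mathrm{Beta}(\alpha+1, \alpha)$ against the unmixed label $y_i$.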
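
The linearity step can also be checked numerically. A minimal, self-contained sketch (`soft_ce` here mirrors `naive_cross_entropy_loss` from `main.py`):

```
import torch
import torch.nn.functional as F

torch.manual_seed(0)


def soft_ce(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Cross entropy with soft targets, as in main.py.
    return -(logits.log_softmax(dim=-1) * target).sum(dim=-1).mean()


num_classes, gamma = 10, 0.7
logits = torch.randn(4, num_classes)
y_i = F.one_hot(torch.randint(num_classes, (4,)), num_classes).float()
y_j = F.one_hot(torch.randint(num_classes, (4,)), num_classes).float()

mixed = soft_ce(logits, gamma * y_i + (1 - gamma) * y_j)
split = gamma * soft_ce(logits, y_i) + (1 - gamma) * soft_ce(logits, y_j)
assert torch.allclose(mixed, split)  # cross entropy is linear in the target
```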
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from typing import Tuple

import hydra
import numpy as np
import torch
from homura import trainers, optim, lr_scheduler
from homura.metrics import accuracy
from homura.vision import DATASET_REGISTRY, MODEL_REGISTRY
from torch.nn import functional as F


def partial_mixup(input: torch.Tensor,
                  gamma: float,
                  indices: torch.Tensor
                  ) -> torch.Tensor:
    # Mix each sample with the sample at `indices`: gamma * x + (1 - gamma) * x[indices].
    if input.size(0) != indices.size(0):
        raise RuntimeError("Size mismatch!")
    perm_input = input[indices]
    return input.mul(gamma).add(perm_input, alpha=1 - gamma)


def mixup(input: torch.Tensor,
          target: torch.Tensor,
          gamma: float,
          ) -> Tuple[torch.Tensor, torch.Tensor]:
    # Standard mixup: mix inputs and (one-hot) targets with the same random pairing.
    indices = torch.randperm(input.size(0), device=input.device, dtype=torch.long)
    return partial_mixup(input, gamma, indices), partial_mixup(target, gamma, indices)


def naive_cross_entropy_loss(input: torch.Tensor,
                             target: torch.Tensor
                             ) -> torch.Tensor:
    # Cross entropy that accepts soft (mixed) targets.
    return -(input.log_softmax(dim=-1) * target).sum(dim=-1).mean()


class Trainer(trainers.SupervisedTrainer):
    def iteration(self,
                  data):
        input, target = data
        _target = target  # keep the hard labels for computing accuracy
        target = F.one_hot(target, self.num_classes)
        if self.is_train and self.cfg.mixup:
            if self.cfg.input_only:
                # Alternative strategy (*): mix inputs only, with gamma ~ Beta(alpha+1, alpha).
                input = partial_mixup(input, np.random.beta(self.cfg.alpha + 1, self.cfg.alpha),
                                      torch.randperm(input.size(0), device=input.device, dtype=torch.long))
            else:
                # Original mixup: mix inputs and targets, with gamma ~ Beta(alpha, alpha).
                input, target = mixup(input, target, np.random.beta(self.cfg.alpha, self.cfg.alpha))

        output = self.model(input)
        loss = self.loss_f(output, target)

        self.reporter.add('loss', loss.detach())
        self.reporter.add('accuracy', accuracy(output, _target))

        if self.is_train:
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()


@hydra.main("main.yaml")
def main(cfg):
    train_loader, test_loader, num_classes = DATASET_REGISTRY(cfg.data.name)(cfg.data.batch_size,
                                                                             return_num_classes=True,
                                                                             num_workers=4)
    model = MODEL_REGISTRY(cfg.model.name)(num_classes=num_classes)
    optimizer = optim.SGD(lr=1e-1, momentum=0.9, weight_decay=cfg.optim.weight_decay)
    # Cosine annealing schedule with warmup (homura).
    scheduler = lr_scheduler.CosineAnnealingWithWarmup(200, 4, 5)

    with Trainer(model, optimizer, naive_cross_entropy_loss, scheduler=scheduler, cfg=cfg.model,
                 num_classes=num_classes) as trainer:
        for _ in trainer.epoch_range(200):
            trainer.train(train_loader)
            trainer.test(test_loader)
        print(f"Max Test Accuracy={max(trainer.reporter.history('accuracy/test')):.4f}")


if __name__ == '__main__':
    main()


--------------------------------------------------------------------------------
/main.yaml:
--------------------------------------------------------------------------------
data:
  name: cifar10
  batch_size: 128

optim:
  epochs: 200
  weight_decay: 1e-4
  lr_decay: 0.1

model:
  name: resnet20
  alpha: 0.2
  mixup: true
  input_only: false

--------------------------------------------------------------------------------