├── .gitignore ├── README.md ├── dataset ├── cifar.py └── randaugment.py ├── models ├── ema.py ├── resnext.py └── wideresnet.py ├── train.py └── utils └── misc.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | *.npz 3 | *.jpg 4 | *.JPG 5 | *.jpeg 6 | *.JPEG 7 | *.png 8 | *.PNG 9 | *.webp 10 | *.WEBP 11 | *.gif 12 | *.GIF 13 | checkpoint/ 14 | data/ 15 | debug/ 16 | wandb/ 17 | 18 | # Byte-compiled / optimized / DLL files 19 | __pycache__/ 20 | *.py[cod] 21 | *$py.class 22 | 23 | # C extensions 24 | *.so 25 | 26 | # Distribution / packaging 27 | .Python 28 | build/ 29 | develop-eggs/ 30 | dist/ 31 | downloads/ 32 | eggs/ 33 | .eggs/ 34 | lib/ 35 | lib64/ 36 | parts/ 37 | sdist/ 38 | var/ 39 | wheels/ 40 | pip-wheel-metadata/ 41 | share/python-wheels/ 42 | *.egg-info/ 43 | .installed.cfg 44 | *.egg 45 | MANIFEST 46 | 47 | # PyInstaller 48 | # Usually these files are written by a python script from a template 49 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
50 | *.manifest 51 | *.spec 52 | 53 | # Installer logs 54 | pip-log.txt 55 | pip-delete-this-directory.txt 56 | 57 | # Unit test / coverage reports 58 | htmlcov/ 59 | .tox/ 60 | .nox/ 61 | .coverage 62 | .coverage.* 63 | .cache 64 | nosetests.xml 65 | coverage.xml 66 | *.cover 67 | *.py,cover 68 | .hypothesis/ 69 | .pytest_cache/ 70 | 71 | # Translations 72 | *.mo 73 | *.pot 74 | 75 | # Django stuff: 76 | *.log 77 | local_settings.py 78 | db.sqlite3 79 | db.sqlite3-journal 80 | 81 | # Flask stuff: 82 | instance/ 83 | .webassets-cache 84 | 85 | # Scrapy stuff: 86 | .scrapy 87 | 88 | # Sphinx documentation 89 | docs/_build/ 90 | 91 | # PyBuilder 92 | target/ 93 | 94 | # Jupyter Notebook 95 | .ipynb_checkpoints 96 | 97 | # IPython 98 | profile_default/ 99 | ipython_config.py 100 | 101 | # pyenv 102 | .python-version 103 | 104 | # pipenv 105 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 106 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 107 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 108 | # install all needed dependencies. 109 | #Pipfile.lock 110 | 111 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 112 | __pypackages__/ 113 | 114 | # Celery stuff 115 | celerybeat-schedule 116 | celerybeat.pid 117 | 118 | # SageMath parsed files 119 | *.sage.py 120 | 121 | # Environments 122 | .env 123 | .venv 124 | env/ 125 | venv/ 126 | ENV/ 127 | env.bak/ 128 | venv.bak/ 129 | 130 | # Spyder project settings 131 | .spyderproject 132 | .spyproject 133 | 134 | # Rope project settings 135 | .ropeproject 136 | 137 | # mkdocs documentation 138 | /site 139 | 140 | # mypy 141 | .mypy_cache/ 142 | .dmypy.json 143 | dmypy.json 144 | 145 | # Pyre type checker 146 | .pyre/ 147 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UDA-pytorch 2 | An unofficial PyTorch implementation of [Unsupervised Data Augmentation for Consistency Training](https://arxiv.org/abs/1904.12848) (UDA). 3 | The official Tensorflow implementation is [here](https://github.com/google-research/uda). 4 | 5 | This code covers only the UDA method for image classification. 6 | 7 | 8 | ## Results 9 | 10 | | | CIFAR-10-4K | SVHN-1K | 11 | |:---:|:---:|:---:| 12 | | Paper (WRN-28-2) | 95.68 ± 0.08 | 97.77 ± 0.07 | 13 | | This code (WRN-28-2) | - | - | 14 | | Acc. curve | - | - | 15 | 16 | \* This code has not been tested, but only part of [my FixMatch code](https://github.com/kekmodel/FixMatch-pytorch) that has been tested several times has been modified. 
17 | 18 | ## Requirements 19 | - python 3.6+ 20 | - torch 1.4 21 | - torchvision 0.5 22 | - tensorboard 23 | - numpy 24 | - tqdm 25 | - apex (optional) 26 | 27 | 28 | ## Citations 29 | ``` 30 | @article{xie2019unsupervised, 31 | title={Unsupervised Data Augmentation for Consistency Training}, 32 | author={Xie, Qizhe and Dai, Zihang and Hovy, Eduard and Luong, Minh-Thang and Le, Quoc V}, 33 | journal={arXiv preprint arXiv:1904.12848}, 34 | year={2019} 35 | } 36 | 37 | @article{cubuk2019randaugment, 38 | title={RandAugment: Practical data augmentation with no separate search}, 39 | author={Cubuk, Ekin D and Zoph, Barret and Shlens, Jonathon and Le, Quoc V}, 40 | journal={arXiv preprint arXiv:1909.13719}, 41 | year={2019} 42 | } 43 | ``` 44 | -------------------------------------------------------------------------------- /dataset/cifar.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | import numpy as np 5 | from PIL import Image 6 | from torchvision import datasets 7 | from torchvision import transforms 8 | 9 | from .randaugment import RandAugment 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | cifar10_mean = (0.4914, 0.4822, 0.4465) 14 | cifar10_std = (0.2471, 0.2435, 0.2616) 15 | cifar100_mean = (0.5071, 0.4867, 0.4408) 16 | cifar100_std = (0.2675, 0.2565, 0.2761) 17 | normal_mean = (0.5, 0.5, 0.5) 18 | normal_std = (0.5, 0.5, 0.5) 19 | 20 | 21 | def get_cifar10(args, root): 22 | transform_labeled = transforms.Compose([ 23 | transforms.RandomHorizontalFlip(), 24 | transforms.RandomCrop(size=32, 25 | padding=int(32*0.125), 26 | padding_mode='reflect'), 27 | transforms.ToTensor(), 28 | transforms.Normalize(mean=cifar10_mean, std=cifar10_std) 29 | ]) 30 | transform_val = transforms.Compose([ 31 | transforms.ToTensor(), 32 | transforms.Normalize(mean=cifar10_mean, std=cifar10_std) 33 | ]) 34 | base_dataset = datasets.CIFAR10(root, train=True, download=True) 35 | 36 | 
train_labeled_idxs, train_unlabeled_idxs = x_u_split( 37 | args, base_dataset.targets) 38 | 39 | train_labeled_dataset = CIFAR10SSL( 40 | root, train_labeled_idxs, train=True, 41 | transform=transform_labeled) 42 | 43 | train_unlabeled_dataset = CIFAR10SSL( 44 | root, train_unlabeled_idxs, train=True, 45 | transform=TransformUDA(mean=cifar10_mean, std=cifar10_std)) 46 | 47 | test_dataset = datasets.CIFAR10( 48 | root, train=False, transform=transform_val, download=False) 49 | 50 | return train_labeled_dataset, train_unlabeled_dataset, test_dataset 51 | 52 | 53 | def get_cifar100(args, root): 54 | 55 | transform_labeled = transforms.Compose([ 56 | transforms.RandomHorizontalFlip(), 57 | transforms.RandomCrop(size=32, 58 | padding=int(32*0.125), 59 | padding_mode='reflect'), 60 | transforms.ToTensor(), 61 | transforms.Normalize(mean=cifar100_mean, std=cifar100_std)]) 62 | 63 | transform_val = transforms.Compose([ 64 | transforms.ToTensor(), 65 | transforms.Normalize(mean=cifar100_mean, std=cifar100_std)]) 66 | 67 | base_dataset = datasets.CIFAR100( 68 | root, train=True, download=True) 69 | 70 | train_labeled_idxs, train_unlabeled_idxs = x_u_split( 71 | args, base_dataset.targets) 72 | 73 | train_labeled_dataset = CIFAR100SSL( 74 | root, train_labeled_idxs, train=True, 75 | transform=transform_labeled) 76 | 77 | train_unlabeled_dataset = CIFAR100SSL( 78 | root, train_unlabeled_idxs, train=True, 79 | transform=TransformUDA(mean=cifar100_mean, std=cifar100_std)) 80 | 81 | test_dataset = datasets.CIFAR100( 82 | root, train=False, transform=transform_val, download=False) 83 | 84 | return train_labeled_dataset, train_unlabeled_dataset, test_dataset 85 | 86 | 87 | def x_u_split(args, labels): 88 | label_per_class = args.num_labeled // args.num_classes 89 | labels = np.array(labels) 90 | labeled_idx = [] 91 | # unlabeled data: all data 92 | unlabeled_idx = np.array(range(len(labels))) 93 | for i in range(args.num_classes): 94 | idx = np.where(labels == i)[0] 95 | idx = 
np.random.choice(idx, label_per_class, False) 96 | labeled_idx.extend(idx) 97 | labeled_idx = np.array(labeled_idx) 98 | assert len(labeled_idx) == args.num_labeled 99 | 100 | if args.expand_labels or args.num_labeled < args.batch_size: 101 | num_expand_x = math.ceil( 102 | args.batch_size * args.eval_step / args.num_labeled) 103 | labeled_idx = np.hstack([labeled_idx for _ in range(num_expand_x)]) 104 | np.random.shuffle(labeled_idx) 105 | return labeled_idx, unlabeled_idx 106 | 107 | 108 | class TransformUDA(object): 109 | def __init__(self, mean, std): 110 | self.weak = transforms.Compose([ 111 | transforms.RandomHorizontalFlip(), 112 | transforms.RandomCrop(size=32, 113 | padding=int(32*0.125), 114 | padding_mode='reflect')]) 115 | self.strong = transforms.Compose([ 116 | transforms.RandomHorizontalFlip(), 117 | transforms.RandomCrop(size=32, 118 | padding=int(32*0.125), 119 | padding_mode='reflect'), 120 | RandAugment(n=2, m=10)]) 121 | self.normalize = transforms.Compose([ 122 | transforms.ToTensor(), 123 | transforms.Normalize(mean=mean, std=std)]) 124 | 125 | def __call__(self, x): 126 | weak = self.weak(x) 127 | strong = self.strong(x) 128 | return self.normalize(weak), self.normalize(strong) 129 | 130 | 131 | class CIFAR10SSL(datasets.CIFAR10): 132 | def __init__(self, root, indexs, train=True, 133 | transform=None, target_transform=None, 134 | download=False): 135 | super().__init__(root, train=train, 136 | transform=transform, 137 | target_transform=target_transform, 138 | download=download) 139 | if indexs is not None: 140 | self.data = self.data[indexs] 141 | self.targets = np.array(self.targets)[indexs] 142 | 143 | def __getitem__(self, index): 144 | img, target = self.data[index], self.targets[index] 145 | img = Image.fromarray(img) 146 | 147 | if self.transform is not None: 148 | img = self.transform(img) 149 | 150 | if self.target_transform is not None: 151 | target = self.target_transform(target) 152 | 153 | return img, target 154 | 155 | 156 | 
class CIFAR100SSL(datasets.CIFAR100): 157 | def __init__(self, root, indexs, train=True, 158 | transform=None, target_transform=None, 159 | download=False): 160 | super().__init__(root, train=train, 161 | transform=transform, 162 | target_transform=target_transform, 163 | download=download) 164 | if indexs is not None: 165 | self.data = self.data[indexs] 166 | self.targets = np.array(self.targets)[indexs] 167 | 168 | def __getitem__(self, index): 169 | img, target = self.data[index], self.targets[index] 170 | img = Image.fromarray(img) 171 | 172 | if self.transform is not None: 173 | img = self.transform(img) 174 | 175 | if self.target_transform is not None: 176 | target = self.target_transform(target) 177 | 178 | return img, target 179 | 180 | 181 | DATASET_GETTERS = {'cifar10': get_cifar10, 182 | 'cifar100': get_cifar100} 183 | -------------------------------------------------------------------------------- /dataset/randaugment.py: -------------------------------------------------------------------------------- 1 | # code in this file is adpated from 2 | # https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py 3 | # https://github.com/google-research/fixmatch/blob/master/third_party/auto_augment/augmentations.py 4 | # https://github.com/google-research/fixmatch/blob/master/libml/ctaugment.py 5 | import logging 6 | import random 7 | 8 | import numpy as np 9 | import PIL 10 | import PIL.ImageOps 11 | import PIL.ImageEnhance 12 | import PIL.ImageDraw 13 | from PIL import Image 14 | from audioop import bias 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | PARAMETER_MAX = 10 19 | RESAMPLE_MODE = None 20 | 21 | 22 | def AutoContrast(img, **kwarg): 23 | return PIL.ImageOps.autocontrast(img) 24 | 25 | 26 | def Brightness(img, v, max_v, bias=0): 27 | v = _float_parameter(v, max_v) + bias 28 | return PIL.ImageEnhance.Brightness(img).enhance(v) 29 | 30 | 31 | def Color(img, v, max_v, bias=0): 32 | v = _float_parameter(v, max_v) + 
bias 33 | return PIL.ImageEnhance.Color(img).enhance(v) 34 | 35 | 36 | def Contrast(img, v, max_v, bias=0): 37 | v = _float_parameter(v, max_v) + bias 38 | return PIL.ImageEnhance.Contrast(img).enhance(v) 39 | 40 | 41 | def Cutout(img, v, max_v, **kwarg): 42 | if v == 0: 43 | return img 44 | v = _float_parameter(v, max_v) 45 | v = int(v * min(img.size)) 46 | w, h = img.size 47 | x0 = np.random.uniform(0, w) 48 | y0 = np.random.uniform(0, h) 49 | x0 = int(max(0, x0 - v / 2.)) 50 | y0 = int(max(0, y0 - v / 2.)) 51 | x1 = int(min(w, x0 + v)) 52 | y1 = int(min(h, y0 + v)) 53 | xy = (x0, y0, x1, y1) 54 | # gray 55 | color = (127, 127, 127) 56 | img = img.copy() 57 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 58 | return img 59 | 60 | 61 | def CutoutConst(img, v, max_v, **kwarg): 62 | v = _int_parameter(v, max_v) 63 | w, h = img.size 64 | x0 = np.random.uniform(0, w) 65 | y0 = np.random.uniform(0, h) 66 | x0 = int(max(0, x0 - v / 2.)) 67 | y0 = int(max(0, y0 - v / 2.)) 68 | x1 = int(min(w, x0 + v)) 69 | y1 = int(min(h, y0 + v)) 70 | xy = (x0, y0, x1, y1) 71 | # gray 72 | color = (127, 127, 127) 73 | img = img.copy() 74 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 75 | return img 76 | 77 | 78 | def Equalize(img, **kwarg): 79 | return PIL.ImageOps.equalize(img) 80 | 81 | 82 | def Identity(img, **kwarg): 83 | return img 84 | 85 | 86 | def Invert(img, **kwarg): 87 | return PIL.ImageOps.invert(img) 88 | 89 | 90 | def Posterize(img, v, max_v, bias, **kwarg): 91 | v = _int_parameter(v, max_v) + bias 92 | return PIL.ImageOps.posterize(img, v) 93 | 94 | 95 | def Rotate(img, v, max_v, **kwarg): 96 | v = _float_parameter(v, max_v) 97 | if random.random() < 0.5: 98 | v = -v 99 | return img.rotate(v) 100 | 101 | 102 | def Sharpness(img, v, max_v, bias): 103 | v = _float_parameter(v, max_v) + bias 104 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 105 | 106 | 107 | def ShearX(img, v, max_v, **kwarg): 108 | v = _float_parameter(v, max_v) 109 | if random.random() < 0.5: 110 | v 
= -v 111 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0), RESAMPLE_MODE) 112 | 113 | 114 | def ShearY(img, v, max_v, **kwarg): 115 | v = _float_parameter(v, max_v) 116 | if random.random() < 0.5: 117 | v = -v 118 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0), RESAMPLE_MODE) 119 | 120 | 121 | def Solarize(img, v, max_v, **kwarg): 122 | v = _int_parameter(v, max_v) 123 | return PIL.ImageOps.solarize(img, 256 - v) 124 | 125 | 126 | def SolarizeAdd(img, v, max_v, threshold=128, **kwarg): 127 | v = _int_parameter(v, max_v) 128 | if random.random() < 0.5: 129 | v = -v 130 | img_np = np.array(img).astype(np.int) 131 | img_np = img_np + v 132 | img_np = np.clip(img_np, 0, 255) 133 | img_np = img_np.astype(np.uint8) 134 | img = Image.fromarray(img_np) 135 | return PIL.ImageOps.solarize(img, threshold) 136 | 137 | 138 | def TranslateX(img, v, max_v, **kwarg): 139 | v = _float_parameter(v, max_v) 140 | if random.random() < 0.5: 141 | v = -v 142 | v = int(v * img.size[0]) 143 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0), RESAMPLE_MODE) 144 | 145 | 146 | def TranslateY(img, v, max_v, **kwarg): 147 | v = _float_parameter(v, max_v) 148 | if random.random() < 0.5: 149 | v = -v 150 | v = int(v * img.size[1]) 151 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v), RESAMPLE_MODE) 152 | 153 | 154 | def TranslateXConst(img, v, max_v, **kwarg): 155 | v = _float_parameter(v, max_v) 156 | if random.random() > 0.5: 157 | v = -v 158 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0), RESAMPLE_MODE) 159 | 160 | 161 | def TranslateYConst(img, v, max_v, **kwarg): 162 | v = _float_parameter(v, max_v) 163 | if random.random() > 0.5: 164 | v = -v 165 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v), RESAMPLE_MODE) 166 | 167 | 168 | def _float_parameter(v, max_v): 169 | return float(v) * max_v / PARAMETER_MAX 170 | 171 | 172 | def _int_parameter(v, max_v): 173 | return 
int(v * max_v / PARAMETER_MAX) 174 | 175 | 176 | def rand_augment_pool(): 177 | augs = [(AutoContrast, None, None), 178 | (Brightness, 1.8, 0.1), 179 | (Color, 1.8, 0.1), 180 | (Contrast, 1.8, 0.1), 181 | (CutoutConst, 40, None), 182 | (Equalize, None, None), 183 | (Invert, None, None), 184 | (Posterize, 4, 0), 185 | (Rotate, 30, None), 186 | (Sharpness, 1.8, 0.1), 187 | (ShearX, 0.3, None), 188 | (ShearY, 0.3, None), 189 | (Solarize, 256, None), 190 | (TranslateXConst, 100, None), 191 | (TranslateYConst, 100, None), 192 | ] 193 | return augs 194 | 195 | 196 | class RandAugment(object): 197 | def __init__(self, n, m, resample_mode=PIL.Image.BILINEAR): 198 | assert n >= 1 199 | assert m >= 1 200 | global RESAMPLE_MODE 201 | RESAMPLE_MODE = resample_mode 202 | self.n = n 203 | self.m = m 204 | self.augment_pool = rand_augment_pool() 205 | 206 | def __call__(self, img): 207 | ops = random.choices(self.augment_pool, k=self.n) 208 | for op, max_v, bias in ops: 209 | prob = np.random.uniform(0.2, 0.8) 210 | if random.random() + prob >= 1: 211 | img = op(img, v=self.m, max_v=max_v, bias=bias) 212 | return img 213 | -------------------------------------------------------------------------------- /models/ema.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class ModelEMA(nn.Module): 8 | def __init__(self, model, decay=0.9999, device=None): 9 | super().__init__() 10 | # make a copy of the model for accumulating moving average of weights 11 | self.module = deepcopy(model) 12 | self.module.eval() 13 | self.decay = decay 14 | self.device = device # perform ema on different device from model if set 15 | if self.device is not None: 16 | self.module.to(device=device) 17 | 18 | def forward(self, input): 19 | return self.module(input) 20 | 21 | def _update(self, model, update_fn): 22 | with torch.no_grad(): 23 | for ema_v, model_v in 
zip(self.module.parameters(), model.parameters()): 24 | if self.device is not None: 25 | model_v = model_v.to(device=self.device) 26 | ema_v.copy_(update_fn(ema_v, model_v)) 27 | for ema_v, model_v in zip(self.module.buffers(), model.buffers()): 28 | if self.device is not None: 29 | model_v = model_v.to(device=self.device) 30 | ema_v.copy_(model_v) 31 | 32 | def update_parameters(self, model): 33 | self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) 34 | 35 | def state_dict(self): 36 | return self.module.state_dict() 37 | 38 | def load_state_dict(self, state_dict): 39 | self.module.load_state_dict(state_dict) -------------------------------------------------------------------------------- /models/resnext.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | def mish(x): 11 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function (https://arxiv.org/abs/1908.08681)""" 12 | return x * torch.tanh(F.softplus(x)) 13 | 14 | 15 | class PSBatchNorm2d(nn.BatchNorm2d): 16 | """How Does BN Increase Collapsed Neural Network Filters? 
(https://arxiv.org/abs/2001.11216)""" 17 | 18 | def __init__(self, num_features, alpha=0.1, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True): 19 | super().__init__(num_features, eps, momentum, affine, track_running_stats) 20 | self.alpha = alpha 21 | 22 | def forward(self, x): 23 | return super().forward(x) + self.alpha 24 | 25 | 26 | class ResNeXtBottleneck(nn.Module): 27 | """ 28 | RexNeXt bottleneck type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua) 29 | """ 30 | 31 | def __init__(self, in_channels, out_channels, stride, 32 | cardinality, base_width, widen_factor): 33 | """ Constructor 34 | Args: 35 | in_channels: input channel dimensionality 36 | out_channels: output channel dimensionality 37 | stride: conv stride. Replaces pooling layer. 38 | cardinality: num of convolution groups. 39 | base_width: base number of channels in each group. 40 | widen_factor: factor to reduce the input dimensionality before convolution. 41 | """ 42 | super().__init__() 43 | width_ratio = out_channels / (widen_factor * 64.) 
44 | D = cardinality * int(base_width * width_ratio) 45 | self.conv_reduce = nn.Conv2d( 46 | in_channels, D, kernel_size=1, stride=1, padding=0, bias=False) 47 | self.bn_reduce = nn.BatchNorm2d(D, momentum=0.001) 48 | self.conv_conv = nn.Conv2d(D, D, 49 | kernel_size=3, stride=stride, padding=1, 50 | groups=cardinality, bias=False) 51 | self.bn = nn.BatchNorm2d(D, momentum=0.001) 52 | self.act = mish 53 | self.conv_expand = nn.Conv2d( 54 | D, out_channels, kernel_size=1, stride=1, padding=0, bias=False) 55 | self.bn_expand = nn.BatchNorm2d(out_channels, momentum=0.001) 56 | 57 | self.shortcut = nn.Sequential() 58 | if in_channels != out_channels: 59 | self.shortcut.add_module('shortcut_conv', 60 | nn.Conv2d(in_channels, out_channels, 61 | kernel_size=1, 62 | stride=stride, 63 | padding=0, 64 | bias=False)) 65 | self.shortcut.add_module( 66 | 'shortcut_bn', nn.BatchNorm2d(out_channels, momentum=0.001)) 67 | 68 | def forward(self, x): 69 | bottleneck = self.conv_reduce.forward(x) 70 | bottleneck = self.act(self.bn_reduce.forward(bottleneck)) 71 | bottleneck = self.conv_conv.forward(bottleneck) 72 | bottleneck = self.act(self.bn.forward(bottleneck)) 73 | bottleneck = self.conv_expand.forward(bottleneck) 74 | bottleneck = self.bn_expand.forward(bottleneck) 75 | residual = self.shortcut.forward(x) 76 | return self.act(residual + bottleneck) 77 | 78 | 79 | class CifarResNeXt(nn.Module): 80 | """ 81 | ResNext optimized for the Cifar dataset, as specified in 82 | https://arxiv.org/pdf/1611.05431.pdf 83 | """ 84 | 85 | def __init__(self, cardinality, depth, num_classes, 86 | base_width, widen_factor=4): 87 | """ Constructor 88 | Args: 89 | cardinality: number of convolution groups. 90 | depth: number of layers. 91 | nlabels: number of classes 92 | base_width: base number of channels in each group. 
93 | widen_factor: factor to adjust the channel dimensionality 94 | """ 95 | super().__init__() 96 | self.cardinality = cardinality 97 | self.depth = depth 98 | self.block_depth = (self.depth - 2) // 9 99 | self.base_width = base_width 100 | self.widen_factor = widen_factor 101 | self.nlabels = num_classes 102 | self.output_size = 64 103 | self.stages = [64, 64 * self.widen_factor, 128 * 104 | self.widen_factor, 256 * self.widen_factor] 105 | 106 | self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False) 107 | self.bn_1 = nn.BatchNorm2d(64, momentum=0.001) 108 | self.act = mish 109 | self.stage_1 = self.block('stage_1', self.stages[0], self.stages[1], 1) 110 | self.stage_2 = self.block('stage_2', self.stages[1], self.stages[2], 2) 111 | self.stage_3 = self.block('stage_3', self.stages[2], self.stages[3], 2) 112 | self.classifier = nn.Linear(self.stages[3], num_classes) 113 | 114 | for m in self.modules(): 115 | if isinstance(m, nn.Conv2d): 116 | nn.init.kaiming_normal_(m.weight, 117 | mode='fan_out', 118 | nonlinearity='leaky_relu') 119 | elif isinstance(m, nn.BatchNorm2d): 120 | nn.init.constant_(m.weight, 1.0) 121 | nn.init.constant_(m.bias, 0.0) 122 | elif isinstance(m, nn.Linear): 123 | nn.init.xavier_normal_(m.weight) 124 | nn.init.constant_(m.bias, 0.0) 125 | 126 | def block(self, name, in_channels, out_channels, pool_stride=2): 127 | """ Stack n bottleneck modules where n is inferred from the depth of the network. 128 | Args: 129 | name: string name of the current block. 130 | in_channels: number of input channels 131 | out_channels: number of output channels 132 | pool_stride: factor to reduce the spatial dimensionality in the first bottleneck of the block. 133 | Returns: a Module consisting of n sequential bottlenecks. 
134 | """ 135 | block = nn.Sequential() 136 | for bottleneck in range(self.block_depth): 137 | name_ = '%s_bottleneck_%d' % (name, bottleneck) 138 | if bottleneck == 0: 139 | block.add_module(name_, ResNeXtBottleneck(in_channels, 140 | out_channels, 141 | pool_stride, 142 | self.cardinality, 143 | self.base_width, 144 | self.widen_factor)) 145 | else: 146 | block.add_module(name_, 147 | ResNeXtBottleneck(out_channels, 148 | out_channels, 149 | 1, 150 | self.cardinality, 151 | self.base_width, 152 | self.widen_factor)) 153 | return block 154 | 155 | def forward(self, x): 156 | x = self.conv_1_3x3.forward(x) 157 | x = self.act(self.bn_1.forward(x)) 158 | x = self.stage_1.forward(x) 159 | x = self.stage_2.forward(x) 160 | x = self.stage_3.forward(x) 161 | x = F.adaptive_avg_pool2d(x, 1) 162 | x = x.view(-1, self.stages[3]) 163 | return self.classifier(x) 164 | 165 | 166 | def build_resnext(cardinality, depth, width, num_classes): 167 | logger.info(f"Model: ResNeXt {depth+1}x{width}") 168 | return CifarResNeXt(cardinality=cardinality, 169 | depth=depth, 170 | base_width=width, 171 | num_classes=num_classes) 172 | -------------------------------------------------------------------------------- /models/wideresnet.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | def mish(x): 11 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function (https://arxiv.org/abs/1908.08681)""" 12 | return x * torch.tanh(F.softplus(x)) 13 | 14 | 15 | class PSBatchNorm2d(nn.BatchNorm2d): 16 | """How Does BN Increase Collapsed Neural Network Filters? 
(https://arxiv.org/abs/2001.11216)""" 17 | 18 | def __init__(self, num_features, alpha=0.1, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True): 19 | super().__init__(num_features, eps, momentum, affine, track_running_stats) 20 | self.alpha = alpha 21 | 22 | def forward(self, x): 23 | return super().forward(x) + self.alpha 24 | 25 | 26 | class BasicBlock(nn.Module): 27 | def __init__(self, in_planes, out_planes, stride, drop_rate=0.0, activate_before_residual=False): 28 | super(BasicBlock, self).__init__() 29 | self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.001) 30 | self.relu1 = nn.LeakyReLU(negative_slope=0.1, inplace=True) 31 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 32 | padding=1, bias=False) 33 | self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.001) 34 | self.relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=True) 35 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 36 | padding=1, bias=False) 37 | self.drop_rate = drop_rate 38 | self.equalInOut = (in_planes == out_planes) 39 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 40 | padding=0, bias=False) or None 41 | self.activate_before_residual = activate_before_residual 42 | 43 | def forward(self, x): 44 | if not self.equalInOut and self.activate_before_residual == True: 45 | x = self.relu1(self.bn1(x)) 46 | else: 47 | out = self.relu1(self.bn1(x)) 48 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 49 | if self.drop_rate > 0: 50 | out = F.dropout(out, p=self.drop_rate, training=self.training) 51 | out = self.conv2(out) 52 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 53 | 54 | 55 | class NetworkBlock(nn.Module): 56 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, drop_rate=0.0, activate_before_residual=False): 57 | super(NetworkBlock, self).__init__() 58 | self.layer = self._make_layer( 59 | block, 
in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual) 60 | 61 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual): 62 | layers = [] 63 | for i in range(int(nb_layers)): 64 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, 65 | i == 0 and stride or 1, drop_rate, activate_before_residual)) 66 | return nn.Sequential(*layers) 67 | 68 | def forward(self, x): 69 | return self.layer(x) 70 | 71 | 72 | class WideResNet(nn.Module): 73 | def __init__(self, num_classes, depth=28, widen_factor=2, drop_rate=0.0): 74 | super(WideResNet, self).__init__() 75 | channels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 76 | assert((depth - 4) % 6 == 0) 77 | n = (depth - 4) / 6 78 | block = BasicBlock 79 | # 1st conv before any network block 80 | self.conv1 = nn.Conv2d(3, channels[0], kernel_size=3, stride=1, 81 | padding=1, bias=False) 82 | # 1st block 83 | self.block1 = NetworkBlock( 84 | n, channels[0], channels[1], block, 1, drop_rate, activate_before_residual=True) 85 | # 2nd block 86 | self.block2 = NetworkBlock( 87 | n, channels[1], channels[2], block, 2, drop_rate) 88 | # 3rd block 89 | self.block3 = NetworkBlock( 90 | n, channels[2], channels[3], block, 2, drop_rate) 91 | # global average pooling and classifier 92 | self.bn1 = nn.BatchNorm2d(channels[3], momentum=0.001) 93 | self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 94 | self.fc = nn.Linear(channels[3], num_classes) 95 | self.channels = channels[3] 96 | 97 | for m in self.modules(): 98 | if isinstance(m, nn.Conv2d): 99 | nn.init.kaiming_normal_(m.weight, 100 | mode='fan_out', 101 | nonlinearity='leaky_relu') 102 | elif isinstance(m, nn.BatchNorm2d): 103 | nn.init.constant_(m.weight, 1.0) 104 | nn.init.constant_(m.bias, 0.0) 105 | elif isinstance(m, nn.Linear): 106 | nn.init.xavier_normal_(m.weight) 107 | nn.init.constant_(m.bias, 0.0) 108 | 109 | def forward(self, x): 110 | out = self.conv1(x) 111 | 
out = self.block1(out) 112 | out = self.block2(out) 113 | out = self.block3(out) 114 | out = self.relu(self.bn1(out)) 115 | out = F.adaptive_avg_pool2d(out, 1) 116 | out = out.view(-1, self.channels) 117 | return self.fc(out) 118 | 119 | 120 | def build_wideresnet(depth, widen_factor, dropout, num_classes): 121 | logger.info(f"Model: WideResNet {depth}x{widen_factor}") 122 | return WideResNet(depth=depth, 123 | widen_factor=widen_factor, 124 | drop_rate=dropout, 125 | num_classes=num_classes) 126 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import math 4 | import os 5 | import random 6 | import shutil 7 | import time 8 | from collections import OrderedDict 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | from torch.optim.lr_scheduler import LambdaLR 15 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 16 | from torch.utils.data.distributed import DistributedSampler 17 | from torch.utils.tensorboard import SummaryWriter 18 | from tqdm import tqdm 19 | 20 | from dataset.cifar import DATASET_GETTERS 21 | from utils import AverageMeter, accuracy 22 | 23 | logger = logging.getLogger(__name__) 24 | best_acc = 0 25 | 26 | 27 | def save_checkpoint(state, is_best, checkpoint, filename='checkpoint.pth.tar'): 28 | filepath = os.path.join(checkpoint, filename) 29 | torch.save(state, filepath) 30 | if is_best: 31 | shutil.copyfile(filepath, os.path.join(checkpoint, 32 | 'model_best.pth.tar')) 33 | 34 | 35 | def set_seed(args): 36 | random.seed(args.seed) 37 | np.random.seed(args.seed) 38 | torch.manual_seed(args.seed) 39 | torch.cuda.manual_seed_all(args.seed) 40 | 41 | 42 | def create_model(args): 43 | if args.arch == 'wideresnet': 44 | import models.wideresnet as models 45 | model = 
models.build_wideresnet(depth=args.model_depth, 46 | widen_factor=args.model_width, 47 | dropout=0, 48 | num_classes=args.num_classes) 49 | elif args.arch == 'resnext': 50 | import models.resnext as models 51 | model = models.build_resnext(cardinality=args.model_cardinality, 52 | depth=args.model_depth, 53 | width=args.model_width, 54 | num_classes=args.num_classes) 55 | logger.info("Total params: {:.2f}M".format(sum(p.numel() for p in model.parameters())/1e6)) 56 | return model 57 | 58 | 59 | def get_cosine_schedule_with_warmup(optimizer, 60 | num_warmup_steps, 61 | num_training_steps, 62 | num_cycles=7./16., 63 | last_epoch=-1): 64 | def _lr_lambda(current_step): 65 | if current_step < num_warmup_steps: 66 | return float(current_step) / float(max(1, num_warmup_steps)) 67 | no_progress = float(current_step - num_warmup_steps) / \ 68 | float(max(1, num_training_steps - num_warmup_steps)) 69 | return max(0., math.cos(math.pi * num_cycles * no_progress)) 70 | 71 | return LambdaLR(optimizer, _lr_lambda, last_epoch) 72 | 73 | 74 | def interleave(x, size): 75 | s = list(x.shape) 76 | return x.reshape([-1, size] + s[1:]).transpose(0, 1).reshape([-1] + s[1:]) 77 | 78 | 79 | def de_interleave(x, size): 80 | s = list(x.shape) 81 | return x.reshape([size, -1] + s[1:]).transpose(0, 1).reshape([-1] + s[1:]) 82 | 83 | 84 | def main(): 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument('--gpu-id', default='0', type=int, 87 | help='id for CUDA_VISIBLE_DEVICES') 88 | parser.add_argument('--num-workers', type=int, default=4, 89 | help='number of workers') 90 | parser.add_argument('--dataset', default='cifar10', type=str, 91 | choices=['cifar10', 'cifar100'], 92 | help='dataset name') 93 | parser.add_argument('--num-labeled', type=int, default=4000, 94 | help='number of labeled data') 95 | parser.add_argument("--expand-labels", action="store_true", 96 | help="expand labels to fit eval steps") 97 | parser.add_argument('--arch', default='wideresnet', type=str, 98 | 
def main():
    """Entry point: parse CLI options, build data loaders / model / optimizer,
    then hand everything to `train`."""
    parser = argparse.ArgumentParser()
    # --- CLI options ---------------------------------------------------
    parser.add_argument('--gpu-id', default='0', type=int,
                        help='id for CUDA_VISIBLE_DEVICES')
    parser.add_argument('--num-workers', type=int, default=4,
                        help='number of workers')
    parser.add_argument('--dataset', default='cifar10', type=str,
                        choices=['cifar10', 'cifar100'],
                        help='dataset name')
    parser.add_argument('--num-labeled', type=int, default=4000,
                        help='number of labeled data')
    parser.add_argument("--expand-labels", action="store_true",
                        help="expand labels to fit eval steps")
    parser.add_argument('--arch', default='wideresnet', type=str,
                        choices=['wideresnet', 'resnext'],
                        help='model architecture name')
    parser.add_argument('--total-steps', default=2**20, type=int,
                        help='number of total steps to run')
    parser.add_argument('--eval-step', default=1024, type=int,
                        help='number of eval steps to run')
    parser.add_argument('--start-epoch', default=0, type=int,
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--batch-size', default=64, type=int,
                        help='train batchsize')
    parser.add_argument('--lr', '--learning-rate', default=0.03, type=float,
                        help='initial learning rate')
    parser.add_argument('--warmup', default=0, type=float,
                        help='warmup epochs (unlabeled data based)')
    parser.add_argument('--wdecay', default=5e-4, type=float,
                        help='weight decay')
    # NOTE(review): store_true with default=True means these flags can never
    # be turned off from the CLI — presumably intentional, but confirm.
    parser.add_argument('--nesterov', action='store_true', default=True,
                        help='use nesterov momentum')
    parser.add_argument('--use-ema', action='store_true', default=True,
                        help='use EMA model')
    parser.add_argument('--ema-decay', default=0.999, type=float,
                        help='EMA decay rate')
    parser.add_argument('--mu', default=7, type=int,
                        help='coefficient of unlabeled batch size')
    parser.add_argument('--lambda-u', default=1, type=float,
                        help='coefficient of unlabeled loss')
    parser.add_argument('--T', default=0.4, type=float,
                        help='pseudo label temperature')
    parser.add_argument('--threshold', default=0.8, type=float,
                        help='pseudo label threshold')
    parser.add_argument('--out', default='result',
                        help='directory to output the result')
    parser.add_argument('--resume', default='', type=str,
                        help='path to checkpoint')
    parser.add_argument('--seed', default=None, type=int,
                        help="random seed")
    parser.add_argument("--amp", action="store_true",
                        help="use 16-bit (mixed) precision")
    parser.add_argument("--opt_level", type=str, default="O1",
                        help="apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank")

    args = parser.parse_args()
    global best_acc

    # --- Distributed / device setup -----------------------------------
    # local_rank == -1 means single-process mode; otherwise join the NCCL
    # process group and pin this process to its assigned GPU.
    if args.local_rank != -1:
        args.gpu_id = args.local_rank
        torch.distributed.init_process_group(backend='nccl')
        args.world_size = torch.distributed.get_world_size()
        args.n_gpu = 1
    else:
        args.world_size = 1
        args.n_gpu = torch.cuda.device_count()

    args.device = torch.device('cuda', args.gpu_id)

    # Only rank 0 (or the single process) logs at INFO; other ranks WARN.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)

    logger.warning(
        f"Process rank: {args.local_rank}, "
        f"device: {args.device}, "
        f"n_gpu: {args.n_gpu}, "
        f"distributed training: {bool(args.local_rank != -1)}, "
        f"16-bits training: {args.amp}",)

    logger.info(dict(args._get_kwargs()))

    if args.seed is not None:
        set_seed(args)

    # NOTE(review): `writer` is only bound on rank -1/0 but is passed to
    # train() unconditionally — non-zero ranks would hit a NameError; confirm
    # multi-rank runs guard this elsewhere.
    if args.local_rank in [-1, 0]:
        os.makedirs(args.out, exist_ok=True)
        writer = SummaryWriter(args.out)

    # --- Dataset-dependent model hyperparameters ----------------------
    if args.dataset == 'cifar10':
        args.num_classes = 10
        if args.arch == 'wideresnet':
            args.model_depth = 28
            args.model_width = 2
        elif args.arch == 'resnext':
            args.model_cardinality = 4
            args.model_depth = 28
            args.model_width = 4

    elif args.dataset == 'cifar100':
        args.num_classes = 100
        if args.arch == 'wideresnet':
            args.model_depth = 28
            args.model_width = 8
        elif args.arch == 'resnext':
            args.model_cardinality = 8
            args.model_depth = 29
            args.model_width = 64

    # Barrier pattern: non-zero ranks wait here so rank 0 downloads/prepares
    # the dataset exactly once; rank 0 releases them at the matching barrier.
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()

    labeled_dataset, unlabeled_dataset, test_dataset = DATASET_GETTERS[args.dataset](args, './data')

    if args.local_rank == 0:
        torch.distributed.barrier()

    train_sampler = RandomSampler if args.local_rank == -1 else DistributedSampler

    labeled_trainloader = DataLoader(
        labeled_dataset,
        sampler=train_sampler(labeled_dataset),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        drop_last=True, pin_memory=True)

    # Unlabeled batches are `mu` times larger than labeled ones (FixMatch).
    unlabeled_trainloader = DataLoader(
        unlabeled_dataset,
        sampler=train_sampler(unlabeled_dataset),
        batch_size=args.batch_size*args.mu,
        num_workers=args.num_workers,
        drop_last=True, pin_memory=True)

    test_loader = DataLoader(
        test_dataset,
        sampler=SequentialSampler(test_dataset),
        batch_size=args.batch_size,
        num_workers=args.num_workers)

    # Same barrier pattern for model construction (rank 0 first).
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()

    model = create_model(args)

    if args.local_rank == 0:
        torch.distributed.barrier()

    model.to(args.device)

    # Exclude biases and batch-norm parameters from weight decay.
    no_decay = ['bias', 'bn']
    grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(
            nd in n for nd in no_decay)], 'weight_decay': args.wdecay},
        {'params': [p for n, p in model.named_parameters() if any(
            nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = optim.SGD(grouped_parameters, lr=args.lr,
                          momentum=0.9, nesterov=args.nesterov)

    # One "epoch" here is `eval_step` optimizer steps, not a dataset pass.
    args.epochs = math.ceil(args.total_steps / args.eval_step)
    scheduler = get_cosine_schedule_with_warmup(optimizer, args.warmup, args.total_steps)

    # NOTE(review): `ema_model` is only bound when --use-ema is on, yet it is
    # passed to train() unconditionally below — confirm use_ema is always True.
    if args.use_ema:
        from models.ema import ModelEMA
        ema_model = ModelEMA(args, model, args.ema_decay)

    args.start_epoch = 0

    if args.resume:
        logger.info("==> Resuming from checkpoint..")
        assert os.path.isfile(args.resume), "Error: no checkpoint directory found!"
        # Resumed runs write back into the checkpoint's own directory.
        args.out = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        if args.use_ema:
            ema_model.load_state_dict(checkpoint['ema_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])

    # apex mixed precision must wrap model+optimizer before DDP wrapping.
    if args.amp:
        from apex import amp
        model, optimizer = amp.initialize(
            model, optimizer, opt_level=args.opt_level)

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank, find_unused_parameters=True)

    logger.info("***** Running training *****")
    logger.info(f" Task = {args.dataset}@{args.num_labeled}")
    logger.info(f" Num Epochs = {args.epochs}")
    logger.info(f" Batch size per GPU = {args.batch_size}")
    logger.info(f" Total train batch size = {args.batch_size*args.world_size}")
    logger.info(f" Total optimization steps = {args.total_steps}")

    model.zero_grad()
    train(args, labeled_trainloader, unlabeled_trainloader, test_loader,
          model, optimizer, ema_model, scheduler, writer)
def train(args, labeled_trainloader, unlabeled_trainloader, test_loader,
          model, optimizer, ema_model, scheduler, writer):
    """Run the FixMatch training loop.

    Each "epoch" is `args.eval_step` optimizer steps drawn from infinite
    cycles over the labeled and unlabeled loaders. Per step: supervised
    cross-entropy on labeled data (Lx) plus a consistency loss (Lu) that
    pushes strongly-augmented predictions toward sharpened weak-augmentation
    pseudo-labels, masked by a confidence threshold. After every epoch,
    rank -1/0 evaluates, logs to TensorBoard, and checkpoints.

    Fix vs. original: the loaders were advanced with the Python-2-era
    `iterator.next()` (removed in Python 3 / modern PyTorch loader iterators)
    inside a bare `except:` that swallowed the resulting AttributeError and
    even KeyboardInterrupt. Uses `next(iterator)` / `except StopIteration:`.
    """
    if args.amp:
        from apex import amp
    global best_acc
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_x = AverageMeter()
    losses_u = AverageMeter()
    mask_probs = AverageMeter()
    end = time.time()
    model.train()

    if args.world_size > 1:
        # DistributedSampler needs the epoch for a per-epoch shuffle.
        labeled_epoch = 0
        unlabeled_epoch = 0
        labeled_trainloader.sampler.set_epoch(labeled_epoch)
        unlabeled_trainloader.sampler.set_epoch(unlabeled_epoch)

    labeled_iter = iter(labeled_trainloader)
    unlabeled_iter = iter(unlabeled_trainloader)
    for epoch in range(args.start_epoch, args.epochs):
        p_bar = tqdm(range(args.eval_step), disable=args.local_rank not in [-1, 0])
        for batch_idx in range(args.eval_step):
            # Cycle the labeled loader indefinitely.
            try:
                inputs_x, targets_x = next(labeled_iter)
            except StopIteration:
                if args.world_size > 1:
                    labeled_epoch += 1
                    labeled_trainloader.sampler.set_epoch(labeled_epoch)
                labeled_iter = iter(labeled_trainloader)
                inputs_x, targets_x = next(labeled_iter)

            # Cycle the unlabeled loader; each item is a (weak, strong)
            # augmentation pair with its (ignored) label.
            try:
                (inputs_u_w, inputs_u_s), _ = next(unlabeled_iter)
            except StopIteration:
                if args.world_size > 1:
                    unlabeled_epoch += 1
                    unlabeled_trainloader.sampler.set_epoch(unlabeled_epoch)
                unlabeled_iter = iter(unlabeled_trainloader)
                (inputs_u_w, inputs_u_s), _ = next(unlabeled_iter)

            data_time.update(time.time() - end)
            batch_size = inputs_x.shape[0]
            # One forward pass over labeled + weak + strong batches.
            inputs = torch.cat((inputs_x, inputs_u_w, inputs_u_s)).to(args.device)
            targets_x = targets_x.to(args.device)
            logits = model(inputs)
            logits_x = logits[:batch_size]
            logits_u_w, logits_u_s = logits[batch_size:].chunk(2)
            del logits

            # Supervised loss on the labeled batch.
            Lx = F.cross_entropy(logits_x, targets_x, reduction='mean')

            # Sharpened (temperature T) pseudo-label from the weak view;
            # detached so no gradient flows through the pseudo-label.
            targets_u = torch.softmax(logits_u_w.detach()/args.T, dim=-1)
            max_probs, _ = torch.max(targets_u, dim=-1)
            # Keep only confident pseudo-labels.
            mask = max_probs.ge(args.threshold).float()

            # Soft cross-entropy of strong-view predictions vs pseudo-labels,
            # averaged over the whole unlabeled batch (masked entries are 0).
            Lu = (-(targets_u * torch.log_softmax(logits_u_s, dim=-1)).sum(dim=-1) * mask).mean()

            loss = Lx + args.lambda_u * Lu

            if args.amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            losses.update(loss.item())
            losses_x.update(Lx.item())
            losses_u.update(Lu.item())
            optimizer.step()
            scheduler.step()
            if args.use_ema:
                ema_model.update_parameters(model)
            model.zero_grad()

            batch_time.update(time.time() - end)
            end = time.time()
            mask_probs.update(mask.mean().item())
            p_bar.set_description("Train Epoch: {epoch}/{epochs:4}. Iter: {batch:4}/{iter:4}. LR: {lr:.4f}. Data: {data:.3f}s. Batch: {bt:.3f}s. Loss: {loss:.4f}. Loss_x: {loss_x:.4f}. Loss_u: {loss_u:.4f}. Mask: {mask:.2f}. ".format(
                epoch=epoch + 1,
                epochs=args.epochs,
                batch=batch_idx + 1,
                iter=args.eval_step,
                lr=scheduler.get_last_lr()[0],
                data=data_time.avg,
                bt=batch_time.avg,
                loss=losses.avg,
                loss_x=losses_x.avg,
                loss_u=losses_u.avg,
                mask=mask_probs.avg))
            p_bar.update()
        p_bar.close()

        # Evaluate with the EMA weights when enabled (the reported model).
        if args.use_ema:
            test_model = ema_model
        else:
            test_model = model

        if args.local_rank in [-1, 0]:
            test_loss, test_acc = test(args, test_loader, test_model, epoch)

            writer.add_scalar('train/1.train_loss', losses.avg, epoch)
            writer.add_scalar('train/2.train_loss_x', losses_x.avg, epoch)
            writer.add_scalar('train/3.train_loss_u', losses_u.avg, epoch)
            writer.add_scalar('train/4.mask', mask_probs.avg, epoch)
            writer.add_scalar('test/1.test_acc', test_acc, epoch)
            writer.add_scalar('test/2.test_loss', test_loss, epoch)

            is_best = test_acc > best_acc
            best_acc = max(test_acc, best_acc)

            # NOTE(review): under DDP this saves the wrapped model's
            # state_dict (keys prefixed 'module.') — confirm resume handles it.
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'ema_state_dict': ema_model.state_dict() if args.use_ema else None,
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            }, is_best, args.out)
            logger.info('Best top-1 acc: {:.2f}'.format(best_acc))

    if args.local_rank in [-1, 0]:
        writer.close()
def test(args, test_loader, model, epoch):
    """Evaluate `model` on the test set.

    Runs one pass under `torch.no_grad()`, tracking average loss and top-1 /
    top-3 accuracy, and returns `(losses.avg, top1.avg)`. A tqdm bar is shown
    only on rank -1/0; `epoch` is accepted for interface compatibility.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top3 = AverageMeter()
    end = time.time()
    model.eval()
    test_loader = tqdm(test_loader, disable=args.local_rank not in [-1, 0])
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            data_time.update(time.time() - end)

            inputs, targets = inputs.to(args.device), targets.to(args.device)
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, targets)

            # accuracy() returns percentages for topk=(1, 3).
            acc1, acc3 = accuracy(outputs, targets, topk=(1, 3))
            losses.update(loss.item())
            top1.update(acc1.item())
            top3.update(acc3.item())
            batch_time.update(time.time() - end)
            end = time.time()
            test_loader.set_description("Test Iter: {batch:4}/{iter:4}. Data: {data:.3f}s. Batch: {bt:.3f}s. Loss: {loss:.4f}. top1: {top1:.2f}. top3: {top3:.2f}. ".format(
                batch=batch_idx + 1,
                iter=len(test_loader),
                data=data_time.avg,
                bt=batch_time.avg,
                loss=losses.avg,
                top1=top1.avg,
                top3=top3.avg,
            ))
        test_loader.close()

    logger.info("top-1 acc: {:.2f}".format(top1.avg))
    logger.info("top-3 acc: {:.2f}".format(top3.avg))
    return losses.avg, top1.avg
import logging

import torch

logger = logging.getLogger(__name__)

__all__ = ['get_mean_and_std', 'accuracy', 'AverageMeter']


def get_mean_and_std(dataset):
    '''Compute the per-channel mean and std value of dataset.

    Iterates the dataset one sample at a time (batch_size=1) and averages the
    per-sample channel statistics.
    NOTE(review): the returned std is the mean of per-image stds, not the
    std over all pixels — confirm that is the intended normalization.
    '''
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=4)

    mean = torch.zeros(3)
    std = torch.zeros(3)
    logger.info('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    `output` holds class scores of shape (batch, classes); `target` holds the
    true class indices. Returns one percentage tensor per entry in `topk`.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Top-maxk predictions per sample, then compare against the target
    # broadcast across the k rows.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


class AverageMeter(object):
    """Tracks the most recent value and a running (weighted) average.

    Fix vs. original: `update(val, n)` added `val * n` to the sum but only 1
    to the count, so the average was wrong whenever n != 1. The count now
    advances by `n`, making `avg` a proper weighted mean (unchanged for the
    default n=1).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record `val` observed `n` times (e.g. a batch mean with batch size n)."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count