├── .gitignore ├── LICENSE ├── README.md ├── dataset ├── cifar.py └── randaugment.py ├── models ├── ema.py ├── resnext.py └── wideresnet.py ├── train.py └── utils ├── __init__.py └── misc.py /.gitignore: -------------------------------------------------------------------------------- 1 | results 2 | data 3 | .vscode 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | env/ 91 | venv/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # 
mypy 107 | .mypy_cache/ 108 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Jungdae Kim, Qing Yu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FixMatch 2 | This is an unofficial PyTorch implementation of [FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence](https://arxiv.org/abs/2001.07685). 3 | The official Tensorflow implementation is [here](https://github.com/google-research/fixmatch). 4 | 5 | This code is only available in FixMatch (RandAugment). 
6 | 7 | ## Results 8 | 9 | ### CIFAR10 10 | | #Labels | 40 | 250 | 4000 | 11 | |:---:|:---:|:---:|:---:| 12 | | Paper (RA) | 86.19 ± 3.37 | 94.93 ± 0.65 | 95.74 ± 0.05 | 13 | | This code | 93.60 | 95.31 | 95.77 | 14 | | Acc. curve | [link](https://tensorboard.dev/experiment/YcLQA52kQ1KZIgND8bGijw/) | [link](https://tensorboard.dev/experiment/GN36hbbRTDaBPy7z8alE1A/) | [link](https://tensorboard.dev/experiment/5flaQd1WQyS727hZ70ebbA/) | 15 | 16 | \* November 2020. Retested after fixing EMA issues. 17 | ### CIFAR100 18 | | #Labels | 400 | 2500 | 10000 | 19 | |:---:|:---:|:---:|:---:| 20 | | Paper (RA) | 51.15 ± 1.75 | 71.71 ± 0.11 | 77.40 ± 0.12 | 21 | | This code | 57.50 | 72.93 | 78.12 | 22 | | Acc. curve | [link](https://tensorboard.dev/experiment/y4Mmz3hRTQm6rHDlyeso4Q/) | [link](https://tensorboard.dev/experiment/mY3UExn5RpOanO1Hx1vOxg/) | [link](https://tensorboard.dev/experiment/EDb13xzJTWu5leEyVf2qfQ/) | 23 | 24 | \* Trained with the following options: `--amp --opt_level O2 --wdecay 0.001` 25 | 26 | ## Usage 27 | 28 | ### Train 29 | Train the model with 4000 labeled samples from the CIFAR-10 dataset: 30 | 31 | ``` 32 | python train.py --dataset cifar10 --num-labeled 4000 --arch wideresnet --batch-size 64 --lr 0.03 --expand-labels --seed 5 --out results/cifar10@4000.5 33 | ``` 34 | 35 | Train the model with 10000 labeled samples from the CIFAR-100 dataset using DistributedDataParallel: 36 | ``` 37 | python -m torch.distributed.launch --nproc_per_node 4 ./train.py --dataset cifar100 --num-labeled 10000 --arch wideresnet --batch-size 16 --lr 0.03 --wdecay 0.001 --expand-labels --seed 5 --out results/cifar100@10000 38 | ``` 39 | 40 | ### Monitoring training progress 41 | ``` 42 | tensorboard --logdir=YOUR_OUT_DIR 43 | ``` 44 | 45 | ## Requirements 46 | - python 3.6+ 47 | - torch 1.4 48 | - torchvision 0.5 49 | - tensorboard 50 | - numpy 51 | - tqdm 52 | - apex (optional) 53 | 54 | ## My other implementations 55 | - [Meta Pseudo Labels](https://github.com/kekmodel/MPL-pytorch) 56 | - [UDA for 
images](https://github.com/kekmodel/UDA-pytorch) 57 | 58 | 59 | ## References 60 | - [Official TensorFlow implementation of FixMatch](https://github.com/google-research/fixmatch) 61 | - [Unofficial PyTorch implementation of MixMatch](https://github.com/YU1ut/MixMatch-pytorch) 62 | - [Unofficial PyTorch Reimplementation of RandAugment](https://github.com/ildoonet/pytorch-randaugment) 63 | - [PyTorch image models](https://github.com/rwightman/pytorch-image-models) 64 | 65 | ## Citations 66 | ``` 67 | @misc{jd2020fixmatch, 68 | author = {Jungdae Kim}, 69 | title = {PyTorch implementation of FixMatch}, 70 | year = {2020}, 71 | publisher = {GitHub}, 72 | journal = {GitHub repository}, 73 | howpublished = {\url{https://github.com/kekmodel/FixMatch-pytorch}} 74 | } 75 | ``` 76 | -------------------------------------------------------------------------------- /dataset/cifar.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | import numpy as np 5 | from PIL import Image 6 | from torchvision import datasets 7 | from torchvision import transforms 8 | 9 | from .randaugment import RandAugmentMC 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | cifar10_mean = (0.4914, 0.4822, 0.4465) 14 | cifar10_std = (0.2471, 0.2435, 0.2616) 15 | cifar100_mean = (0.5071, 0.4867, 0.4408) 16 | cifar100_std = (0.2675, 0.2565, 0.2761) 17 | normal_mean = (0.5, 0.5, 0.5) 18 | normal_std = (0.5, 0.5, 0.5) 19 | 20 | 21 | def get_cifar10(args, root): 22 | transform_labeled = transforms.Compose([ 23 | transforms.RandomHorizontalFlip(), 24 | transforms.RandomCrop(size=32, 25 | padding=int(32*0.125), 26 | padding_mode='reflect'), 27 | transforms.ToTensor(), 28 | transforms.Normalize(mean=cifar10_mean, std=cifar10_std) 29 | ]) 30 | transform_val = transforms.Compose([ 31 | transforms.ToTensor(), 32 | transforms.Normalize(mean=cifar10_mean, std=cifar10_std) 33 | ]) 34 | base_dataset = datasets.CIFAR10(root, train=True, 
def x_u_split(args, labels):
    """Split dataset indices into a labeled subset and the unlabeled pool.

    Args:
        args: namespace providing ``num_labeled``, ``num_classes``,
            ``expand_labels``, ``batch_size`` and ``eval_step``.
        labels: sequence of integer class ids, one per sample.

    Returns:
        ``(labeled_idx, unlabeled_idx)`` as NumPy arrays. ``labeled_idx``
        holds ``num_labeled // num_classes`` indices drawn without
        replacement from every class, optionally repeated, then shuffled.
        ``unlabeled_idx`` is every index of the dataset.

    Raises:
        AssertionError: if the per-class draws do not add up to
            ``args.num_labeled`` (e.g. it is not a multiple of
            ``args.num_classes``).
    """
    targets = np.array(labels)
    per_class = args.num_labeled // args.num_classes
    # unlabeled data: all data (https://github.com/kekmodel/FixMatch-pytorch/issues/10)
    unlabeled_idx = np.array(range(len(targets)))

    picked = []
    for cls in range(args.num_classes):
        members = np.where(targets == cls)[0]
        picked.extend(np.random.choice(members, per_class, False))
    labeled_idx = np.array(picked)
    assert len(labeled_idx) == args.num_labeled

    needs_expansion = args.expand_labels or args.num_labeled < args.batch_size
    if needs_expansion:
        # Repeat the labeled indices until they cover
        # batch_size * eval_step samples.
        repeats = math.ceil(
            args.batch_size * args.eval_step / args.num_labeled)
        labeled_idx = np.concatenate([labeled_idx] * repeats)
    np.random.shuffle(labeled_idx)
    return labeled_idx, unlabeled_idx
class CIFAR100SSL(datasets.CIFAR100):
    """CIFAR-100 dataset restricted to a chosen subset of sample indices."""

    def __init__(self, root, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        """Load CIFAR-100 and keep only the samples listed in ``indexs``.

        Args:
            root: dataset directory, forwarded to the parent class.
            indexs: indices to keep, or ``None`` to keep every sample.
            train, transform, target_transform, download: forwarded
                unchanged to the parent constructor.
        """
        super().__init__(root,
                         train=train,
                         transform=transform,
                         target_transform=target_transform,
                         download=download)
        if indexs is None:
            return
        self.data = self.data[indexs]
        self.targets = np.array(self.targets)[indexs]

    def __getitem__(self, index):
        """Return ``(image, target)`` for ``index`` after the transforms."""
        raw, label = self.data[index], self.targets[index]
        image = Image.fromarray(raw)

        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label
def CutoutAbs(img, v, **kwarg):
    """Return a copy of ``img`` with a gray rectangle painted over it.

    A point is drawn uniformly over the image. The rectangle's top-left
    corner is that point shifted up and left by ``v / 2`` (clamped at 0),
    and its bottom-right corner is ``v`` pixels further (clamped to the
    image size).

    Args:
        img: image exposing ``size`` and ``copy()``; drawn on through
            ``PIL.ImageDraw``.
        v: side length of the rectangle in pixels.
        **kwarg: accepted and ignored so all ops share one call signature.
    """
    width, height = img.size
    # x is sampled before y so the NumPy RNG stream is consumed in the
    # same order on every call.
    centre_x = np.random.uniform(0, width)
    centre_y = np.random.uniform(0, height)
    half = v / 2.
    left = int(max(0, centre_x - half))
    top = int(max(0, centre_y - half))
    right = int(min(width, left + v))
    bottom = int(min(height, top + v))

    gray = (127, 127, 127)
    patched = img.copy()
    PIL.ImageDraw.Draw(patched).rectangle((left, top, right, bottom), gray)
    return patched
def SolarizeAdd(img, v, max_v, bias=0, threshold=128):
    """Shift all pixel values by a random-signed offset, then solarize.

    Args:
        img: input PIL image.
        v: augmentation magnitude, scaled by ``_int_parameter``.
        max_v: offset corresponding to the maximum magnitude.
        bias: constant added to the scaled offset.
        threshold: solarization threshold passed to
            ``PIL.ImageOps.solarize``.

    Returns:
        The solarized PIL image.
    """
    v = _int_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    # Widen to a signed integer type before adding so a negative offset
    # cannot wrap around in uint8. The builtin ``int`` is used because the
    # ``np.int`` alias was removed in NumPy 1.24.
    img_np = np.array(img).astype(int)
    img_np = img_np + v
    img_np = np.clip(img_np, 0, 255)
    img_np = img_np.astype(np.uint8)
    img = Image.fromarray(img_np)
    return PIL.ImageOps.solarize(img, threshold)
class RandAugmentMC(object):
    """RandAugment with a randomly sampled magnitude, followed by Cutout.

    Args:
        n: number of operations drawn (with replacement) per image.
        m: magnitude bound, ``1 <= m <= 10``.
    """

    def __init__(self, n, m):
        assert n >= 1
        assert 1 <= m <= 10
        self.n = n
        self.m = m
        self.augment_pool = fixmatch_augment_pool()

    def __call__(self, img):
        """Apply the sampled operations and a final cutout to ``img``."""
        ops = random.choices(self.augment_pool, k=self.n)
        for op, max_v, bias in ops:
            # randint's upper bound is exclusive, so magnitudes fall in
            # [1, m - 1]. For m == 1 that range is empty and randint raises
            # ValueError, although the constructor accepts m == 1; use the
            # only sensible magnitude instead.
            v = np.random.randint(1, self.m) if self.m > 1 else 1
            # Each sampled op is applied with probability 0.5.
            if random.random() < 0.5:
                img = op(img, v=v, max_v=max_v, bias=bias)
        # Always finish with a 16-pixel gray cutout.
        img = CutoutAbs(img, int(32*0.5))
        return img
class PSBatchNorm2d(nn.BatchNorm2d):
    """How Does BN Increase Collapsed Neural Network Filters? (https://arxiv.org/abs/2001.11216)

    Batch normalization whose output is shifted by a constant ``alpha``.

    The class was previously declared as ``class nn.BatchNorm2d(...)``. A
    dotted name is not a valid class name, so the module could not be
    imported. It is named ``PSBatchNorm2d`` to match the identical class
    in ``models/wideresnet.py``.

    Args:
        num_features: number of channels of the input.
        alpha: constant added to the normalized output.
        eps, momentum, affine, track_running_stats: forwarded unchanged to
            ``nn.BatchNorm2d``.
    """

    def __init__(self, num_features, alpha=0.1, eps=1e-05, momentum=0.1,
                 affine=True, track_running_stats=True):
        super().__init__(num_features, eps, momentum, affine,
                         track_running_stats)
        self.alpha = alpha

    def forward(self, x):
        """Return the batch-normalized input plus ``alpha``."""
        return super().forward(x) + self.alpha
class CifarResNeXt(nn.Module):
    """
    ResNeXt for CIFAR-sized inputs, as specified in
    https://arxiv.org/pdf/1611.05431.pdf
    """

    def __init__(self, cardinality, depth, num_classes,
                 base_width, widen_factor=4):
        """ Constructor
        Args:
            cardinality: number of convolution groups.
            depth: number of layers; each stage stacks
                ``(depth - 2) // 9`` bottlenecks.
            num_classes: number of output classes.
            base_width: base number of channels in each group.
            widen_factor: factor to adjust the channel dimensionality
        """
        super().__init__()
        self.cardinality = cardinality
        self.depth = depth
        self.block_depth = (self.depth - 2) // 9
        self.base_width = base_width
        self.widen_factor = widen_factor
        self.nlabels = num_classes
        self.output_size = 64
        self.stages = [64] + [base * self.widen_factor
                              for base in (64, 128, 256)]
        stem, width_1, width_2, width_3 = self.stages

        self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        self.bn_1 = nn.BatchNorm2d(64, momentum=0.001)
        self.act = mish
        self.stage_1 = self.block('stage_1', stem, width_1, 1)
        self.stage_2 = self.block('stage_2', width_1, width_2, 2)
        self.stage_3 = self.block('stage_3', width_2, width_3, 2)
        self.classifier = nn.Linear(width_3, num_classes)

        for module in self.modules():
            self._reset_weights(module)

    @staticmethod
    def _reset_weights(module):
        """Initialize one sub-module according to its layer type."""
        if isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight,
                                    mode='fan_out',
                                    nonlinearity='leaky_relu')
            return
        if isinstance(module, nn.BatchNorm2d):
            nn.init.constant_(module.weight, 1.0)
            nn.init.constant_(module.bias, 0.0)
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
            nn.init.constant_(module.bias, 0.0)

    def block(self, name, in_channels, out_channels, pool_stride=2):
        """ Stack ``self.block_depth`` bottleneck modules.
        Args:
            name: string name of the current block.
            in_channels: number of input channels
            out_channels: number of output channels
            pool_stride: stride of the first bottleneck of the block; the
                remaining bottlenecks use stride 1.
        Returns: a Module consisting of the sequential bottlenecks.
        """
        stage = nn.Sequential()
        for idx in range(self.block_depth):
            is_first = idx == 0
            # Only the first bottleneck changes the channel count and
            # applies the stride.
            stage.add_module(
                '%s_bottleneck_%d' % (name, idx),
                ResNeXtBottleneck(in_channels if is_first else out_channels,
                                  out_channels,
                                  pool_stride if is_first else 1,
                                  self.cardinality,
                                  self.base_width,
                                  self.widen_factor))
        return stage

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        x = self.act(self.bn_1.forward(self.conv_1_3x3.forward(x)))
        for stage in (self.stage_1, self.stage_2, self.stage_3):
            x = stage.forward(x)
        pooled = F.adaptive_avg_pool2d(x, 1)
        return self.classifier(pooled.view(-1, self.stages[3]))
class BasicBlock(nn.Module):
    """Pre-activation residual block: (BN, LeakyReLU, 3x3 conv) twice.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: stride of the first convolution and of the 1x1 shortcut.
        drop_rate: dropout probability applied before the second
            convolution; 0 disables dropout.
        activate_before_residual: when the channel counts differ, feed the
            activated input (instead of the raw input) to both the first
            convolution and the shortcut.
    """

    def __init__(self, in_planes, out_planes, stride, drop_rate=0.0,
                 activate_before_residual=False):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.001)
        self.relu1 = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.001)
        self.relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.drop_rate = drop_rate
        self.equalInOut = in_planes == out_planes
        # A 1x1 projection is only needed when the channel count changes.
        self.convShortcut = None if self.equalInOut else nn.Conv2d(
            in_planes, out_planes, kernel_size=1, stride=stride,
            padding=0, bias=False)
        self.activate_before_residual = activate_before_residual

    def forward(self, x):
        """Return ``shortcut(x) + residual(x)``."""
        activated = self.relu1(self.bn1(x))
        if self.equalInOut:
            conv_in = activated
        else:
            if self.activate_before_residual:
                x = activated
            # Without activate_before_residual the raw input feeds conv1
            # and `activated` is unused. bn1 must still run because in
            # training mode the call updates its running statistics.
            conv_in = x
        out = self.relu2(self.bn2(self.conv1(conv_in)))
        if self.drop_rate > 0:
            out = F.dropout(out, p=self.drop_rate, training=self.training)
        out = self.conv2(out)
        shortcut = x if self.equalInOut else self.convShortcut(x)
        return torch.add(shortcut, out)
class WideResNet(nn.Module):
    """Wide residual network: a stem conv, three block groups, a classifier.

    Args:
        num_classes: number of output logits.
        depth: total depth; must satisfy ``depth = 6 * n + 4`` so that each
            of the three groups holds ``n`` blocks.
        widen_factor: multiplier applied to the base widths 16/32/64.
        drop_rate: dropout probability forwarded to every block.
    """

    def __init__(self, num_classes, depth=28, widen_factor=2, drop_rate=0.0):
        super().__init__()
        channels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
        assert (depth - 4) % 6 == 0, 'depth must be of the form 6n + 4'
        # Floor division keeps the layer count an integer rather than
        # relying on the callee to coerce a float.
        n = (depth - 4) // 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, channels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(
            n, channels[0], channels[1], block, 1, drop_rate,
            activate_before_residual=True)
        # 2nd block
        self.block2 = NetworkBlock(
            n, channels[1], channels[2], block, 2, drop_rate)
        # 3rd block
        self.block3 = NetworkBlock(
            n, channels[2], channels[3], block, 2, drop_rate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(channels[3], momentum=0.001)
        self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.fc = nn.Linear(channels[3], num_classes)
        self.channels = channels[3]

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='leaky_relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0.0)

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(-1, self.channels)
        return self.fc(out)
def interleave(x, size):
    """Reorder the leading axis of ``x`` by a (groups, size) transpose.

    The leading axis is viewed as ``(-1, size)``, the two new axes are
    swapped, and the result is flattened back to the original rank. Trailing
    dimensions are left untouched. ``de_interleave`` with the same ``size``
    undoes this reordering.
    """
    trailing = list(x.shape)[1:]
    grouped = x.reshape([-1, size] + trailing)
    swapped = grouped.transpose(0, 1)
    return swapped.reshape([-1] + trailing)


def de_interleave(x, size):
    """Invert ``interleave``: view the leading axis as ``(size, -1)``, swap
    the two axes and flatten back, keeping trailing dimensions as they are.
    """
    trailing = list(x.shape)[1:]
    grouped = x.reshape([size, -1] + trailing)
    swapped = grouped.transpose(0, 1)
    return swapped.reshape([-1] + trailing)
93 | help='initial learning rate') 94 | parser.add_argument('--warmup', default=0, type=float, 95 | help='warmup epochs (unlabeled data based)') 96 | parser.add_argument('--wdecay', default=5e-4, type=float, 97 | help='weight decay') 98 | parser.add_argument('--nesterov', action='store_true', default=True, 99 | help='use nesterov momentum') 100 | parser.add_argument('--use-ema', action='store_true', default=True, 101 | help='use EMA model') 102 | parser.add_argument('--ema-decay', default=0.999, type=float, 103 | help='EMA decay rate') 104 | parser.add_argument('--mu', default=7, type=int, 105 | help='coefficient of unlabeled batch size') 106 | parser.add_argument('--lambda-u', default=1, type=float, 107 | help='coefficient of unlabeled loss') 108 | parser.add_argument('--T', default=1, type=float, 109 | help='pseudo label temperature') 110 | parser.add_argument('--threshold', default=0.95, type=float, 111 | help='pseudo label threshold') 112 | parser.add_argument('--out', default='result', 113 | help='directory to output the result') 114 | parser.add_argument('--resume', default='', type=str, 115 | help='path to latest checkpoint (default: none)') 116 | parser.add_argument('--seed', default=None, type=int, 117 | help="random seed") 118 | parser.add_argument("--amp", action="store_true", 119 | help="use 16-bit (mixed) precision through NVIDIA apex AMP") 120 | parser.add_argument("--opt_level", type=str, default="O1", 121 | help="apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 
122 | "See details at https://nvidia.github.io/apex/amp.html") 123 | parser.add_argument("--local_rank", type=int, default=-1, 124 | help="For distributed training: local_rank") 125 | parser.add_argument('--no-progress', action='store_true', 126 | help="don't use progress bar") 127 | 128 | args = parser.parse_args() 129 | global best_acc 130 | 131 | def create_model(args): 132 | if args.arch == 'wideresnet': 133 | import models.wideresnet as models 134 | model = models.build_wideresnet(depth=args.model_depth, 135 | widen_factor=args.model_width, 136 | dropout=0, 137 | num_classes=args.num_classes) 138 | elif args.arch == 'resnext': 139 | import models.resnext as models 140 | model = models.build_resnext(cardinality=args.model_cardinality, 141 | depth=args.model_depth, 142 | width=args.model_width, 143 | num_classes=args.num_classes) 144 | logger.info("Total params: {:.2f}M".format( 145 | sum(p.numel() for p in model.parameters())/1e6)) 146 | return model 147 | 148 | if args.local_rank == -1: 149 | device = torch.device('cuda', args.gpu_id) 150 | args.world_size = 1 151 | args.n_gpu = torch.cuda.device_count() 152 | else: 153 | torch.cuda.set_device(args.local_rank) 154 | device = torch.device('cuda', args.local_rank) 155 | torch.distributed.init_process_group(backend='nccl') 156 | args.world_size = torch.distributed.get_world_size() 157 | args.n_gpu = 1 158 | 159 | args.device = device 160 | 161 | logging.basicConfig( 162 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 163 | datefmt="%m/%d/%Y %H:%M:%S", 164 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 165 | 166 | logger.warning( 167 | f"Process rank: {args.local_rank}, " 168 | f"device: {args.device}, " 169 | f"n_gpu: {args.n_gpu}, " 170 | f"distributed training: {bool(args.local_rank != -1)}, " 171 | f"16-bits training: {args.amp}",) 172 | 173 | logger.info(dict(args._get_kwargs())) 174 | 175 | if args.seed is not None: 176 | set_seed(args) 177 | 178 | if args.local_rank 
in [-1, 0]: 179 | os.makedirs(args.out, exist_ok=True) 180 | args.writer = SummaryWriter(args.out) 181 | 182 | if args.dataset == 'cifar10': 183 | args.num_classes = 10 184 | if args.arch == 'wideresnet': 185 | args.model_depth = 28 186 | args.model_width = 2 187 | elif args.arch == 'resnext': 188 | args.model_cardinality = 4 189 | args.model_depth = 28 190 | args.model_width = 4 191 | 192 | elif args.dataset == 'cifar100': 193 | args.num_classes = 100 194 | if args.arch == 'wideresnet': 195 | args.model_depth = 28 196 | args.model_width = 8 197 | elif args.arch == 'resnext': 198 | args.model_cardinality = 8 199 | args.model_depth = 29 200 | args.model_width = 64 201 | 202 | if args.local_rank not in [-1, 0]: 203 | torch.distributed.barrier() 204 | 205 | labeled_dataset, unlabeled_dataset, test_dataset = DATASET_GETTERS[args.dataset]( 206 | args, './data') 207 | 208 | if args.local_rank == 0: 209 | torch.distributed.barrier() 210 | 211 | train_sampler = RandomSampler if args.local_rank == -1 else DistributedSampler 212 | 213 | labeled_trainloader = DataLoader( 214 | labeled_dataset, 215 | sampler=train_sampler(labeled_dataset), 216 | batch_size=args.batch_size, 217 | num_workers=args.num_workers, 218 | drop_last=True) 219 | 220 | unlabeled_trainloader = DataLoader( 221 | unlabeled_dataset, 222 | sampler=train_sampler(unlabeled_dataset), 223 | batch_size=args.batch_size*args.mu, 224 | num_workers=args.num_workers, 225 | drop_last=True) 226 | 227 | test_loader = DataLoader( 228 | test_dataset, 229 | sampler=SequentialSampler(test_dataset), 230 | batch_size=args.batch_size, 231 | num_workers=args.num_workers) 232 | 233 | if args.local_rank not in [-1, 0]: 234 | torch.distributed.barrier() 235 | 236 | model = create_model(args) 237 | 238 | if args.local_rank == 0: 239 | torch.distributed.barrier() 240 | 241 | model.to(args.device) 242 | 243 | no_decay = ['bias', 'bn'] 244 | grouped_parameters = [ 245 | {'params': [p for n, p in model.named_parameters() if not any( 246 
| nd in n for nd in no_decay)], 'weight_decay': args.wdecay}, 247 | {'params': [p for n, p in model.named_parameters() if any( 248 | nd in n for nd in no_decay)], 'weight_decay': 0.0} 249 | ] 250 | optimizer = optim.SGD(grouped_parameters, lr=args.lr, 251 | momentum=0.9, nesterov=args.nesterov) 252 | 253 | args.epochs = math.ceil(args.total_steps / args.eval_step) 254 | scheduler = get_cosine_schedule_with_warmup( 255 | optimizer, args.warmup, args.total_steps) 256 | 257 | if args.use_ema: 258 | from models.ema import ModelEMA 259 | ema_model = ModelEMA(args, model, args.ema_decay) 260 | 261 | args.start_epoch = 0 262 | 263 | if args.resume: 264 | logger.info("==> Resuming from checkpoint..") 265 | assert os.path.isfile( 266 | args.resume), "Error: no checkpoint directory found!" 267 | args.out = os.path.dirname(args.resume) 268 | checkpoint = torch.load(args.resume) 269 | best_acc = checkpoint['best_acc'] 270 | args.start_epoch = checkpoint['epoch'] 271 | model.load_state_dict(checkpoint['state_dict']) 272 | if args.use_ema: 273 | ema_model.ema.load_state_dict(checkpoint['ema_state_dict']) 274 | optimizer.load_state_dict(checkpoint['optimizer']) 275 | scheduler.load_state_dict(checkpoint['scheduler']) 276 | 277 | if args.amp: 278 | from apex import amp 279 | model, optimizer = amp.initialize( 280 | model, optimizer, opt_level=args.opt_level) 281 | 282 | if args.local_rank != -1: 283 | model = torch.nn.parallel.DistributedDataParallel( 284 | model, device_ids=[args.local_rank], 285 | output_device=args.local_rank, find_unused_parameters=True) 286 | 287 | logger.info("***** Running training *****") 288 | logger.info(f" Task = {args.dataset}@{args.num_labeled}") 289 | logger.info(f" Num Epochs = {args.epochs}") 290 | logger.info(f" Batch size per GPU = {args.batch_size}") 291 | logger.info( 292 | f" Total train batch size = {args.batch_size*args.world_size}") 293 | logger.info(f" Total optimization steps = {args.total_steps}") 294 | 295 | model.zero_grad() 296 | 
def _next_batch(loader, iterator, epoch, distributed):
    """Fetch the next batch, restarting ``loader`` once when it is exhausted.

    Returns ``(batch, iterator, epoch)``. ``iterator`` and ``epoch`` are the
    possibly renewed values, which the caller must keep for the next call.
    """
    try:
        return next(iterator), iterator, epoch
    except StopIteration:
        # Only exhaustion triggers a restart; real loading errors and
        # KeyboardInterrupt propagate instead of being retried silently.
        if distributed:
            epoch += 1
            loader.sampler.set_epoch(epoch)
        iterator = iter(loader)
        return next(iterator), iterator, epoch


def train(args, labeled_trainloader, unlabeled_trainloader, test_loader,
          model, optimizer, ema_model, scheduler):
    """Run the training loop, evaluating and checkpointing after every epoch.

    An "epoch" here is ``args.eval_step`` optimisation steps. Both training
    loaders are cycled independently and restarted whenever they run out.

    Args:
        args: run configuration. Fields read here: ``amp``, ``world_size``,
            ``start_epoch``, ``epochs``, ``eval_step``, ``no_progress``,
            ``local_rank``, ``mu``, ``device``, ``T``, ``threshold``,
            ``lambda_u``, ``use_ema``, ``writer`` and ``out``.
        labeled_trainloader: yields ``(inputs, targets)`` batches.
        unlabeled_trainloader: yields ``((inputs_u_w, inputs_u_s), _)``
            batches (presumably weak/strong views of the same images;
            confirm against the dataset code).
        test_loader: passed through to ``test``.
        model: network being optimised.
        optimizer: optimiser stepped once per batch.
        ema_model: EMA wrapper; only used when ``args.use_ema`` is true.
        scheduler: learning-rate scheduler stepped once per batch.

    Side effects:
        Updates the module-level ``best_acc``, writes TensorBoard scalars and
        saves a checkpoint per epoch on rank -1/0, and closes ``args.writer``
        at the end.
    """
    if args.amp:
        from apex import amp
    global best_acc
    test_accs = []
    end = time.time()

    distributed = args.world_size > 1
    labeled_epoch = 0
    unlabeled_epoch = 0
    if distributed:
        labeled_trainloader.sampler.set_epoch(labeled_epoch)
        unlabeled_trainloader.sampler.set_epoch(unlabeled_epoch)

    labeled_iter = iter(labeled_trainloader)
    unlabeled_iter = iter(unlabeled_trainloader)

    model.train()
    for epoch in range(args.start_epoch, args.epochs):
        # test() switches the module it evaluates to eval mode; without EMA
        # that module is `model` itself, so restore train mode every epoch.
        model.train()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        losses_x = AverageMeter()
        losses_u = AverageMeter()
        mask_probs = AverageMeter()
        if not args.no_progress:
            p_bar = tqdm(range(args.eval_step),
                         disable=args.local_rank not in [-1, 0])
        for batch_idx in range(args.eval_step):
            # Use the builtin next(): current DataLoader iterators have no
            # .next() method.
            (inputs_x, targets_x), labeled_iter, labeled_epoch = _next_batch(
                labeled_trainloader, labeled_iter, labeled_epoch, distributed)
            ((inputs_u_w, inputs_u_s), _), unlabeled_iter, unlabeled_epoch = \
                _next_batch(unlabeled_trainloader, unlabeled_iter,
                            unlabeled_epoch, distributed)

            data_time.update(time.time() - end)
            batch_size = inputs_x.shape[0]
            # One forward pass over labeled + both unlabeled views; the
            # concatenated batch is interleaved before the model and the
            # logits are de-interleaved afterwards.
            inputs = interleave(
                torch.cat((inputs_x, inputs_u_w, inputs_u_s)),
                2*args.mu+1).to(args.device)
            targets_x = targets_x.to(args.device)
            logits = model(inputs)
            logits = de_interleave(logits, 2*args.mu+1)
            logits_x = logits[:batch_size]
            logits_u_w, logits_u_s = logits[batch_size:].chunk(2)
            del logits

            # Supervised loss on the labeled part of the batch.
            Lx = F.cross_entropy(logits_x, targets_x, reduction='mean')

            # Pseudo-labels come from the detached `_w` logits; only
            # predictions at or above the confidence threshold contribute.
            pseudo_label = torch.softmax(logits_u_w.detach()/args.T, dim=-1)
            max_probs, targets_u = torch.max(pseudo_label, dim=-1)
            mask = max_probs.ge(args.threshold).float()

            Lu = (F.cross_entropy(logits_u_s, targets_u,
                                  reduction='none') * mask).mean()

            loss = Lx + args.lambda_u * Lu

            if args.amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            losses.update(loss.item())
            losses_x.update(Lx.item())
            losses_u.update(Lu.item())
            optimizer.step()
            scheduler.step()
            if args.use_ema:
                ema_model.update(model)
            model.zero_grad()

            batch_time.update(time.time() - end)
            end = time.time()
            mask_probs.update(mask.mean().item())
            if not args.no_progress:
                p_bar.set_description("Train Epoch: {epoch}/{epochs:4}. Iter: {batch:4}/{iter:4}. LR: {lr:.4f}. Data: {data:.3f}s. Batch: {bt:.3f}s. Loss: {loss:.4f}. Loss_x: {loss_x:.4f}. Loss_u: {loss_u:.4f}. Mask: {mask:.2f}. ".format(
                    epoch=epoch + 1,
                    epochs=args.epochs,
                    batch=batch_idx + 1,
                    iter=args.eval_step,
                    lr=scheduler.get_last_lr()[0],
                    data=data_time.avg,
                    bt=batch_time.avg,
                    loss=losses.avg,
                    loss_x=losses_x.avg,
                    loss_u=losses_u.avg,
                    mask=mask_probs.avg))
                p_bar.update()

        if not args.no_progress:
            p_bar.close()

        # Evaluate the EMA weights when they are enabled.
        if args.use_ema:
            test_model = ema_model.ema
        else:
            test_model = model

        # Only the main process evaluates, logs and writes checkpoints.
        if args.local_rank in [-1, 0]:
            test_loss, test_acc = test(args, test_loader, test_model, epoch)

            args.writer.add_scalar('train/1.train_loss', losses.avg, epoch)
            args.writer.add_scalar('train/2.train_loss_x', losses_x.avg, epoch)
            args.writer.add_scalar('train/3.train_loss_u', losses_u.avg, epoch)
            args.writer.add_scalar('train/4.mask', mask_probs.avg, epoch)
            args.writer.add_scalar('test/1.test_acc', test_acc, epoch)
            args.writer.add_scalar('test/2.test_loss', test_loss, epoch)

            is_best = test_acc > best_acc
            best_acc = max(test_acc, best_acc)

            # Unwrap a wrapper that exposes `.module` before saving.
            model_to_save = model.module if hasattr(model, "module") else model
            if args.use_ema:
                ema_to_save = ema_model.ema.module if hasattr(
                    ema_model.ema, "module") else ema_model.ema
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model_to_save.state_dict(),
                'ema_state_dict': ema_to_save.state_dict() if args.use_ema else None,
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            }, is_best, args.out)

            test_accs.append(test_acc)
            logger.info('Best top-1 acc: {:.2f}'.format(best_acc))
            logger.info('Mean top-1 acc: {:.2f}\n'.format(
                np.mean(test_accs[-20:])))

    if args.local_rank in [-1, 0]:
        args.writer.close()
456 | data_time = AverageMeter() 457 | losses = AverageMeter() 458 | top1 = AverageMeter() 459 | top5 = AverageMeter() 460 | end = time.time() 461 | 462 | if not args.no_progress: 463 | test_loader = tqdm(test_loader, 464 | disable=args.local_rank not in [-1, 0]) 465 | 466 | with torch.no_grad(): 467 | for batch_idx, (inputs, targets) in enumerate(test_loader): 468 | data_time.update(time.time() - end) 469 | model.eval() 470 | 471 | inputs = inputs.to(args.device) 472 | targets = targets.to(args.device) 473 | outputs = model(inputs) 474 | loss = F.cross_entropy(outputs, targets) 475 | 476 | prec1, prec5 = accuracy(outputs, targets, topk=(1, 5)) 477 | losses.update(loss.item(), inputs.shape[0]) 478 | top1.update(prec1.item(), inputs.shape[0]) 479 | top5.update(prec5.item(), inputs.shape[0]) 480 | batch_time.update(time.time() - end) 481 | end = time.time() 482 | if not args.no_progress: 483 | test_loader.set_description("Test Iter: {batch:4}/{iter:4}. Data: {data:.3f}s. Batch: {bt:.3f}s. Loss: {loss:.4f}. top1: {top1:.2f}. top5: {top5:.2f}. 
def get_mean_and_std(dataset):
    '''Return per-channel mean and std of dataset.

    The dataset is read one sample at a time through a DataLoader with four
    workers. For each of the first three channels the sample's mean and std
    are accumulated and finally divided by len(dataset), so the returned std
    is the average of per-sample stds. Returns a (mean, std) pair of
    3-element tensors.
    '''
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=4)

    channel_mean = torch.zeros(3)
    channel_std = torch.zeros(3)
    logger.info('==> Computing mean and std..')
    for images, _ in loader:
        for ch in range(3):
            plane = images[:, ch, :, :]
            channel_mean[ch] += plane.mean()
            channel_std[ch] += plane.std()
    num_samples = len(dataset)
    channel_mean.div_(num_samples)
    channel_std.div_(num_samples)
    return channel_mean, channel_std
class AverageMeter(object):
    """Track the latest value and the running weighted average of a metric.

    Attributes (all reset to 0 by ``reset``):
        val: most recent value passed to ``update``.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated ``n`` over all updates.
        avg: ``sum / count`` after the latest update.

    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count