├── .gitignore ├── .idea ├── .gitignore ├── deployment.xml ├── fixmatch_jian.iml ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── LICENSE ├── README.md ├── datasets ├── __init__.py ├── cifar.py ├── randaugment.py ├── sampler.py └── transform.py ├── label_guessor.py ├── lr_scheduler.py ├── models ├── __init__.py ├── ema.py └── model.py ├── requirements.txt ├── run.sh ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | parts/ 18 | sdist/ 19 | var/ 20 | wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .coverage 40 | .coverage.* 41 | .cache 42 | nosetests.xml 43 | coverage.xml 44 | *.cover 45 | .hypothesis/ 46 | .pytest_cache/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | db.sqlite3 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # Environments 83 | .env 84 | .venv 85 | env/ 86 | venv/ 87 | ENV/ 88 | env.bak/ 89 | venv.bak/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # mkdocs documentation 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ 103 | 104 | 105 | ## Coin: 106 | data 107 | dataset/ 108 | data/ 109 | play.py 110 | preprocess_data.py 111 | res/ 112 | labels.py 113 | adj.md 114 | bayesian_search.py 115 | evaluate.py.old 116 | *png 117 | infer.py 118 | 119 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /workspace.xml -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 15 | -------------------------------------------------------------------------------- /.idea/fixmatch_jian.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | 13 | 15 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 CoinCheung 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | ## FixMatch 3 | 4 | The unofficial reimplementation of [fixmatch](https://arxiv.org/abs/2001.07685) with RandomAugment. 5 | 6 | ## Overview 7 | 8 | 9 | |repo|using EMA model to evaluate|using EMA model to train|update parameters|update buffer| 10 | |:---|:---:|:---:|:---:|:---:| 11 | |ours| ✓|-| ✓|-| 12 | |mdiephuis| ✓| ✓| ✓|-| 13 | |kekmodel| ✓|-|-| ✓| 14 | 15 | 16 | 2020-03-30_18:07:08.log : annotation decay and add classifier.bias 17 | 18 | 2020-03-31_09:51:38.log : add interleave and run model once 19 | 20 | ## Dependencies 21 | 22 | - python 3.6 23 | - pytorch 1.3.1 24 | - torchvision 0.2.1 25 | 26 | The other packages and versions are listed in ```requirements.txt```. 27 | You can install them by ```pip install -r requirements.txt```. 28 | 29 | 30 | ## Dataset 31 | download cifar-10 dataset: 32 | ``` 33 | $ mkdir -p dataset && cd data 34 | $ wget -c http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz 35 | $ tar -xzvf cifar-10-python.tar.gz 36 | ``` 37 | 38 | download cifar-100 dataset: 39 | ``` 40 | $ mkdir -p dataset && cd data 41 | $ wget -c http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz 42 | $ tar -xzvf cifar-100-python.tar.gz 43 | ``` 44 | 45 | ## Train the model 46 | 47 | To train the model on CIFAR10 with 40 labeled samples, you can run the script: 48 | ``` 49 | $ CUDA_VISIBLE_DEVICES='0' python train.py --dataset CIFAR10 --n-labeled 40 50 | ``` 51 | To train the model on CIFAR100 with 400 labeled samples, you can run the script: 52 | ``` 53 | $ CUDA_VISIBLE_DEVICES='0' python train.py --dataset CIFAR100 --n-labeled 400 54 | ``` 55 | 56 | 57 | ## Results 58 | 59 | 60 | ### CIFAR10 61 | | #Labels | 40 | 250 | 4000 | 62 | |:---|:---:|:---:|:---:| 63 | |Paper (RA) | 86.19 ± 3.37 | 94.93 ± 0.65 | 95.74 ± 0.05 | 64 | |ours| 89.63(85.65) | 93.0832 |94.7154| 65 | 66 | ### CIFAR100 67 | 68 | | #Labels | 400 | 2500 | 10000 | 69 | |:---|:---:|:---:|:---:| 70 | |Paper (RA) | 51.15 ± 1.75 | 71.71 ± 0.11 | 77.40 ± 0.12 | 71 | |ours | 53.74 | 67.3169 | 73.26 | 72 | 73 | 74 | ### References 75 | - https://github.com/CoinCheung/fixmatch 76 | - https://github.com/kekmodel/FixMatch-pytorch 77 | - official implement https://github.com/google-research/fixmatch 78 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/valencebond/FixMatch_pytorch/3b7fc36ca0c71754a59c9c78465d681f259ee174/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/cifar.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import pickle 3 | import numpy as np 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | from datasets import transform as T 8 | 9 | from datasets.randaugment import RandomAugment 10 | from datasets.sampler import RandomSampler, BatchSampler 11 | 12 | 13 | def load_data_train(L=250, dataset='CIFAR10', dspth='./data'): 14 | if dataset == 'CIFAR10': 15 | datalist = [ 16 | osp.join(dspth, 'cifar-10-batches-py', 'data_batch_{}'.format(i + 1)) 17 | for i in range(5) 18 | ] 19 | n_class = 10 20 | assert L in [40, 250, 4000] 21 | elif dataset == 'CIFAR100': 22 | datalist = [ 23 | osp.join(dspth, 'cifar-100-python', 'train')] 24 | n_class = 100 25 | assert L in [400, 2500, 10000] 26 | 27 | data, labels = [], [] 28 | for data_batch in datalist: 29 | with open(data_batch, 'rb') as fr: 30 | entry = pickle.load(fr, encoding='latin1') 31 | lbs = entry['labels'] if 'labels' in entry.keys() else entry['fine_labels'] 32 | data.append(entry['data']) 33 | labels.append(lbs) 34 | data = np.concatenate(data, axis=0) 35 | labels = np.concatenate(labels, axis=0) 36 | n_labels = L // n_class 37 | data_x, label_x, data_u, label_u = [], [], [], [] 38 | for i in range(n_class): 39 | indices = np.where(labels == i)[0] 40 | np.random.shuffle(indices) 41 | inds_x, inds_u = indices[:n_labels], indices[n_labels:] 42 | data_x += [ 43 | data[i].reshape(3, 32, 32).transpose(1, 2, 0) 44 | for i in inds_x 45 | ] 46 | label_x += [labels[i] for i in inds_x] 47 | data_u += [ 48 | data[i].reshape(3, 32, 32).transpose(1, 2, 0) 49 | for i in inds_u 50 | ] 51 | label_u += [labels[i] for i in inds_u] 52 | return data_x, label_x, data_u, label_u 53 | 54 | 55 | # def load_data_train(L=250, dspth='./data'): 56 | # datalist = [ 57 | # osp.join(dspth, 'cifar-10-batches-py', 'data_batch_{}'.format(i + 1)) 58 | # for i in range(5) 59 | # ] 60 | # data, labels = [], [] 61 | # for data_batch in datalist: 62 | # with open(data_batch, 'rb') as fr: 63 | # entry = pickle.load(fr, encoding='latin1') 64 | # lbs = entry['labels'] if 'labels' in entry.keys() else entry['fine_labels'] 65 | # data.append(entry['data']) 66 | # labels.append(lbs) 67 | # data = np.concatenate(data, axis=0) 68 | # labels = np.concatenate(labels, axis=0) 69 | # n_labels = L // 10 70 | # data_x, label_x, data_u, label_u = [], [], [], [] 71 | # for i in range(10): 72 | # indices = np.where(labels == i)[0] 73 | # np.random.shuffle(indices) 74 | # inds_x, inds_u = indices[:n_labels], indices[n_labels:] 75 | # data_x += [ 76 | # data[i].reshape(3, 32, 32).transpose(1, 2, 0) 77 | # for i in inds_x 78 | # ] 79 | # label_x += [labels[i] for i in inds_x] 80 | # data_u += [ 81 | # data[i].reshape(3, 32, 32).transpose(1, 2, 0) 82 | # for i in inds_u 83 | # ] 84 | # label_u += [labels[i] for i in inds_u] 85 | # return data_x, label_x, data_u, label_u 86 | 87 | 88 | def load_data_val(dataset, dspth='./data'): 89 | if dataset == 'CIFAR10': 90 | datalist = [ 91 | osp.join(dspth, 'cifar-10-batches-py', 'test_batch') 92 | ] 93 | elif dataset == 'CIFAR100': 94 | datalist = [ 95 | osp.join(dspth, 'cifar-100-python', 'test') 96 | ] 97 | 98 | data, labels = [], [] 99 | for data_batch in datalist: 100 | with open(data_batch, 'rb') as fr: 101 | entry = pickle.load(fr, encoding='latin1') 102 | lbs = entry['labels'] if 'labels' in entry.keys() else entry['fine_labels'] 103 | data.append(entry['data']) 104 | labels.append(lbs) 105 | data = np.concatenate(data, axis=0) 106 | labels = np.concatenate(labels, axis=0) 107 | data = [ 108 | el.reshape(3, 32, 32).transpose(1, 2, 0) 109 | for el in data 110 | ] 111 | return data, labels 112 | 113 | 114 | def compute_mean_var(): 115 | data_x, label_x, data_u, label_u = load_data_train() 116 | data = data_x + data_u 117 | data = np.concatenate([el[None, ...] for el in data], axis=0) 118 | 119 | mean, var = [], [] 120 | for i in range(3): 121 | channel = (data[:, :, :, i].ravel() / 127.5) - 1 122 | # channel = (data[:, :, :, i].ravel() / 255) 123 | mean.append(np.mean(channel)) 124 | var.append(np.std(channel)) 125 | 126 | print('mean: ', mean) 127 | print('var: ', var) 128 | 129 | 130 | 131 | class Cifar(Dataset): 132 | def __init__(self, dataset, data, labels, is_train=True): 133 | super(Cifar, self).__init__() 134 | self.data, self.labels = data, labels 135 | self.is_train = is_train 136 | assert len(self.data) == len(self.labels) 137 | if dataset == 'CIFAR10': 138 | mean, std = (0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616) 139 | elif dataset == 'CIFAR100': 140 | mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761) 141 | 142 | if is_train: 143 | self.trans_weak = T.Compose([ 144 | T.Resize((32, 32)), 145 | T.PadandRandomCrop(border=4, cropsize=(32, 32)), 146 | T.RandomHorizontalFlip(p=0.5), 147 | T.Normalize(mean, std), 148 | T.ToTensor(), 149 | ]) 150 | self.trans_strong = T.Compose([ 151 | T.Resize((32, 32)), 152 | T.PadandRandomCrop(border=4, cropsize=(32, 32)), 153 | T.RandomHorizontalFlip(p=0.5), 154 | RandomAugment(2, 10), 155 | T.Normalize(mean, std), 156 | T.ToTensor(), 157 | ]) 158 | else: 159 | self.trans = T.Compose([ 160 | T.Resize((32, 32)), 161 | T.Normalize(mean, std), 162 | T.ToTensor(), 163 | ]) 164 | 165 | def __getitem__(self, idx): 166 | im, lb = self.data[idx], self.labels[idx] 167 | if self.is_train: 168 | return self.trans_weak(im), self.trans_strong(im), lb 169 | else: 170 | return self.trans(im), lb 171 | 172 | def __len__(self): 173 | leng = len(self.data) 174 | return leng 175 | 176 | 177 | def get_train_loader(dataset, batch_size, mu, n_iters_per_epoch, L, root='data'): 178 | data_x, label_x, data_u, label_u = load_data_train(L=L, dataset=dataset, dspth=root) 179 | 180 | ds_x = Cifar( 181 | dataset=dataset, 182 | data=data_x, 183 | labels=label_x, 184 | is_train=True 185 | ) # return an iter of num_samples length (all indices of samples) 186 | sampler_x = RandomSampler(ds_x, replacement=True, num_samples=n_iters_per_epoch * batch_size) 187 | batch_sampler_x = BatchSampler(sampler_x, batch_size, drop_last=True) # yield a batch of samples one time 188 | dl_x = torch.utils.data.DataLoader( 189 | ds_x, 190 | batch_sampler=batch_sampler_x, 191 | num_workers=2, 192 | pin_memory=True 193 | ) 194 | ds_u = Cifar( 195 | dataset=dataset, 196 | data=data_u, 197 | labels=label_u, 198 | is_train=True 199 | ) 200 | sampler_u = RandomSampler(ds_u, replacement=True, num_samples=mu * n_iters_per_epoch * batch_size) 201 | batch_sampler_u = BatchSampler(sampler_u, batch_size * mu, drop_last=True) 202 | dl_u = torch.utils.data.DataLoader( 203 | ds_u, 204 | batch_sampler=batch_sampler_u, 205 | num_workers=2, 206 | pin_memory=True 207 | ) 208 | return dl_x, dl_u 209 | 210 | 211 | def get_val_loader(dataset, batch_size, num_workers, pin_memory=True): 212 | data, labels = load_data_val(dataset) 213 | ds = Cifar( 214 | dataset=dataset, 215 | data=data, 216 | labels=labels, 217 | is_train=False 218 | ) 219 | dl = torch.utils.data.DataLoader( 220 | ds, 221 | shuffle=False, 222 | batch_size=batch_size, 223 | drop_last=False, 224 | num_workers=num_workers, 225 | pin_memory=pin_memory 226 | ) 227 | return dl 228 | 229 | 230 | class OneHot(object): 231 | def __init__( 232 | self, 233 | n_labels, 234 | lb_ignore=255, 235 | ): 236 | super(OneHot, self).__init__() 237 | self.n_labels = n_labels 238 | self.lb_ignore = lb_ignore 239 | 240 | def __call__(self, label): 241 | N, *S = label.size() 242 | size = [N, self.n_labels] + S 243 | lb_one_hot = torch.zeros(size) 244 | if label.is_cuda: 245 | lb_one_hot = lb_one_hot.cuda() 246 | ignore = label.data.cpu() == self.lb_ignore 247 | label[ignore] = 0 248 | lb_one_hot.scatter_(1, label.unsqueeze(1), 1) 249 | ignore = ignore.nonzero() 250 | _, M = ignore.size() 251 | a, *b = ignore.chunk(M, dim=1) 252 | lb_one_hot[[a, torch.arange(self.n_labels), *b]] = 0 253 | 254 | return lb_one_hot 255 | 256 | 257 | if __name__ == "__main__": 258 | compute_mean_var() 259 | # dl_x, dl_u = get_train_loader(64, 250, 2, 2) 260 | # dl_x2 = iter(dl_x) 261 | # dl_u2 = iter(dl_u) 262 | # ims, lb = next(dl_u2) 263 | # print(type(ims)) 264 | # print(len(ims)) 265 | # print(ims[0].size()) 266 | # print(len(dl_u2)) 267 | # for i in range(1024): 268 | # try: 269 | # ims_x, lbs_x = next(dl_x2) 270 | # # ims_u, lbs_u = next(dl_u2) 271 | # print(i, ": ", ims_x[0].size()) 272 | # except StopIteration: 273 | # dl_x2 = iter(dl_x) 274 | # dl_u2 = iter(dl_u) 275 | # ims_x, lbs_x = next(dl_x2) 276 | # # ims_u, lbs_u = next(dl_u2) 277 | # print('recreate iterator') 278 | # print(i, ": ", ims_x[0].size()) 279 | -------------------------------------------------------------------------------- /datasets/randaugment.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | ## aug functions 6 | def identity_func(img): 7 | return img 8 | 9 | 10 | def autocontrast_func(img, cutoff=0): 11 | ''' 12 | same output as PIL.ImageOps.autocontrast 13 | ''' 14 | n_bins = 256 15 | 16 | def tune_channel(ch): 17 | n = ch.size 18 | cut = cutoff * n // 100 19 | if cut == 0: 20 | high, low = ch.max(), ch.min() 21 | else: 22 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 23 | low = np.argwhere(np.cumsum(hist) > cut) 24 | low = 0 if low.shape[0] == 0 else low[0] 25 | high = np.argwhere(np.cumsum(hist[::-1]) > cut) 26 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] 27 | if high <= low: 28 | table = np.arange(n_bins) 29 | else: 30 | scale = (n_bins - 1) / (high - low) 31 | offset = -low * scale 32 | table = np.arange(n_bins) * scale + offset 33 | table[table < 0] = 0 34 | table[table > n_bins - 1] = n_bins - 1 35 | table = table.clip(0, 255).astype(np.uint8) 36 | return table[ch] 37 | 38 | channels = [tune_channel(ch) for ch in cv2.split(img)] 39 | out = cv2.merge(channels) 40 | return out 41 | 42 | 43 | def equalize_func(img): 44 | ''' 45 | same output as PIL.ImageOps.equalize 46 | PIL's implementation is different from cv2.equalize 47 | ''' 48 | n_bins = 256 49 | 50 | def tune_channel(ch): 51 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 52 | non_zero_hist = hist[hist != 0].reshape(-1) 53 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) 54 | if step == 0: return ch 55 | n = np.empty_like(hist) 56 | n[0] = step // 2 57 | n[1:] = hist[:-1] 58 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) 59 | return table[ch] 60 | 61 | channels = [tune_channel(ch) for ch in cv2.split(img)] 62 | out = cv2.merge(channels) 63 | return out 64 | 65 | 66 | def rotate_func(img, degree, fill=(0, 0, 0)): 67 | ''' 68 | like PIL, rotate by degree, not radians 69 | ''' 70 | H, W = img.shape[0], img.shape[1] 71 | center = W / 2, H / 2 72 | M = cv2.getRotationMatrix2D(center, degree, 1) 73 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill) 74 | return out 75 | 76 | 77 | def solarize_func(img, thresh=128): 78 | ''' 79 | same output as PIL.ImageOps.posterize 80 | ''' 81 | table = np.array([el if el < thresh else 255 - el for el in range(256)]) 82 | table = table.clip(0, 255).astype(np.uint8) 83 | out = table[img] 84 | return out 85 | 86 | 87 | def color_func(img, factor): 88 | ''' 89 | same output as PIL.ImageEnhance.Color 90 | ''' 91 | ## implementation according to PIL definition, quite slow 92 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis] 93 | # out = blend(degenerate, img, factor) 94 | # M = ( 95 | # np.eye(3) * factor 96 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor) 97 | # )[np.newaxis, np.newaxis, :] 98 | M = ( 99 | np.float32([ 100 | [0.886, -0.114, -0.114], 101 | [-0.587, 0.413, -0.587], 102 | [-0.299, -0.299, 0.701]]) * factor 103 | + np.float32([[0.114], [0.587], [0.299]]) 104 | ) 105 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8) 106 | return out 107 | 108 | 109 | def contrast_func(img, factor): 110 | """ 111 | same output as PIL.ImageEnhance.Contrast 112 | """ 113 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) 114 | table = np.array([( 115 | el - mean) * factor + mean 116 | for el in range(256) 117 | ]).clip(0, 255).astype(np.uint8) 118 | out = table[img] 119 | return out 120 | 121 | 122 | def brightness_func(img, factor): 123 | ''' 124 | same output as PIL.ImageEnhance.Contrast 125 | ''' 126 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) 127 | out = table[img] 128 | return out 129 | 130 | 131 | def sharpness_func(img, factor): 132 | ''' 133 | The differences the this result and PIL are all on the 4 boundaries, the center 134 | areas are same 135 | ''' 136 | kernel = np.ones((3, 3), dtype=np.float32) 137 | kernel[1][1] = 5 138 | kernel /= 13 139 | degenerate = cv2.filter2D(img, -1, kernel) 140 | if factor == 0.0: 141 | out = degenerate 142 | elif factor == 1.0: 143 | out = img 144 | else: 145 | out = img.astype(np.float32) 146 | degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] 147 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) 148 | out = out.astype(np.uint8) 149 | return out 150 | 151 | 152 | def shear_x_func(img, factor, fill=(0, 0, 0)): 153 | H, W = img.shape[0], img.shape[1] 154 | M = np.float32([[1, factor, 0], [0, 1, 0]]) 155 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 156 | return out 157 | 158 | 159 | def translate_x_func(img, offset, fill=(0, 0, 0)): 160 | ''' 161 | same output as PIL.Image.transform 162 | ''' 163 | H, W = img.shape[0], img.shape[1] 164 | M = np.float32([[1, 0, -offset], [0, 1, 0]]) 165 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 166 | return out 167 | 168 | 169 | def translate_y_func(img, offset, fill=(0, 0, 0)): 170 | ''' 171 | same output as PIL.Image.transform 172 | ''' 173 | H, W = img.shape[0], img.shape[1] 174 | M = np.float32([[1, 0, 0], [0, 1, -offset]]) 175 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 176 | return out 177 | 178 | 179 | def posterize_func(img, bits): 180 | ''' 181 | same output as PIL.ImageOps.posterize 182 | ''' 183 | out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) 184 | return out 185 | 186 | 187 | def shear_y_func(img, factor, fill=(0, 0, 0)): 188 | H, W = img.shape[0], img.shape[1] 189 | M = np.float32([[1, 0, 0], [factor, 1, 0]]) 190 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 191 | return out 192 | 193 | 194 | def cutout_func(img, pad_size, replace=(0, 0, 0)): 195 | replace = np.array(replace, dtype=np.uint8) 196 | H, W = img.shape[0], img.shape[1] 197 | rh, rw = np.random.random(2) 198 | pad_size = pad_size // 2 199 | ch, cw = int(rh * H), int(rw * W) 200 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) 201 | y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) 202 | out = img.copy() 203 | out[x1:x2, y1:y2, :] = replace 204 | return out 205 | 206 | 207 | ### level to args 208 | def enhance_level_to_args(MAX_LEVEL): 209 | def level_to_args(level): 210 | return ((level / MAX_LEVEL) * 1.8 + 0.1,) 211 | return level_to_args 212 | 213 | 214 | def shear_level_to_args(MAX_LEVEL, replace_value): 215 | def level_to_args(level): 216 | level = (level / MAX_LEVEL) * 0.3 217 | if np.random.random() > 0.5: level = -level 218 | return (level, replace_value) 219 | 220 | return level_to_args 221 | 222 | 223 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): 224 | def level_to_args(level): 225 | level = (level / MAX_LEVEL) * float(translate_const) 226 | if np.random.random() > 0.5: level = -level 227 | return (level, replace_value) 228 | 229 | return level_to_args 230 | 231 | 232 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): 233 | def level_to_args(level): 234 | level = int((level / MAX_LEVEL) * cutout_const) 235 | return (level, replace_value) 236 | 237 | return level_to_args 238 | 239 | 240 | def solarize_level_to_args(MAX_LEVEL): 241 | def level_to_args(level): 242 | level = int((level / MAX_LEVEL) * 256) 243 | return (level, ) 244 | return level_to_args 245 | 246 | 247 | def none_level_to_args(level): 248 | return () 249 | 250 | 251 | def posterize_level_to_args(MAX_LEVEL): 252 | def level_to_args(level): 253 | level = int((level / MAX_LEVEL) * 4) 254 | return (level, ) 255 | return level_to_args 256 | 257 | 258 | def rotate_level_to_args(MAX_LEVEL, replace_value): 259 | def level_to_args(level): 260 | level = (level / MAX_LEVEL) * 30 261 | if np.random.random() < 0.5: 262 | level = -level 263 | return (level, replace_value) 264 | 265 | return level_to_args 266 | 267 | 268 | func_dict = { 269 | 'Identity': identity_func, 270 | 'AutoContrast': autocontrast_func, 271 | 'Equalize': equalize_func, 272 | 'Rotate': rotate_func, 273 | 'Solarize': solarize_func, 274 | 'Color': color_func, 275 | 'Contrast': contrast_func, 276 | 'Brightness': brightness_func, 277 | 'Sharpness': sharpness_func, 278 | 'ShearX': shear_x_func, 279 | 'TranslateX': translate_x_func, 280 | 'TranslateY': translate_y_func, 281 | 'Posterize': posterize_func, 282 | 'ShearY': shear_y_func, 283 | } 284 | 285 | translate_const = 10 286 | MAX_LEVEL = 10 287 | replace_value = (128, 128, 128) 288 | arg_dict = { 289 | 'Identity': none_level_to_args, 290 | 'AutoContrast': none_level_to_args, 291 | 'Equalize': none_level_to_args, 292 | 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value), 293 | 'Solarize': solarize_level_to_args(MAX_LEVEL), 294 | 'Color': enhance_level_to_args(MAX_LEVEL), 295 | 'Contrast': enhance_level_to_args(MAX_LEVEL), 296 | 'Brightness': enhance_level_to_args(MAX_LEVEL), 297 | 'Sharpness': enhance_level_to_args(MAX_LEVEL), 298 | 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value), 299 | 'TranslateX': translate_level_to_args( 300 | translate_const, MAX_LEVEL, replace_value 301 | ), 302 | 'TranslateY': translate_level_to_args( 303 | translate_const, MAX_LEVEL, replace_value 304 | ), 305 | 'Posterize': posterize_level_to_args(MAX_LEVEL), 306 | 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value), 307 | } 308 | 309 | 310 | class RandomAugment(object): 311 | 312 | def __init__(self, N=2, M=10): 313 | self.N = N 314 | self.M = M 315 | 316 | def get_random_ops(self): 317 | sampled_ops = np.random.choice(list(func_dict.keys()), self.N) 318 | return [(op, 0.5, self.M) for op in sampled_ops] 319 | 320 | def __call__(self, img): 321 | ops = self.get_random_ops() 322 | for name, prob, level in ops: 323 | if np.random.random() > prob: 324 | continue 325 | args = arg_dict[name](level) 326 | img = func_dict[name](img, *args) 327 | img = cutout_func(img, 16, replace_value) 328 | return img 329 | 330 | 331 | if __name__ == '__main__': 332 | a = RandomAugment() 333 | img = np.random.randn(32, 32, 3) 334 | a(img) -------------------------------------------------------------------------------- /datasets/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch._six import int_classes as _int_classes 3 | 4 | 5 | class Sampler(object): 6 | r"""Base class for all Samplers. 7 | 8 | Every Sampler subclass has to provide an :meth:`__iter__` method, providing a 9 | way to iterate over indices of dataset elements, and a :meth:`__len__` method 10 | that returns the length of the returned iterators. 11 | 12 | .. note:: The :meth:`__len__` method isn't strictly required by 13 | :class:`~torch.utils.data.DataLoader`, but is expected in any 14 | calculation involving the length of a :class:`~torch.utils.data.DataLoader`. 15 | """ 16 | 17 | def __init__(self, data_source): 18 | pass 19 | 20 | def __iter__(self): 21 | raise NotImplementedError 22 | 23 | # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] 24 | # 25 | # Many times we have an abstract class representing a collection/iterable of 26 | # data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally 27 | # implementing a `__len__` method. In such cases, we must make sure to not 28 | # provide a default implementation, because both straightforward default 29 | # implementations have their issues: 30 | # 31 | # + `return NotImplemented`: 32 | # Calling `len(subclass_instance)` raises: 33 | # TypeError: 'NotImplementedType' object cannot be interpreted as an integer 34 | # 35 | # + `raise NotImplementedError()`: 36 | # This prevents triggering some fallback behavior. E.g., the built-in 37 | # `list(X)` tries to call `len(X)` first, and executes a different code 38 | # path if the method is not found or `NotImplemented` is returned, while 39 | # raising an `NotImplementedError` will propagate and and make the call 40 | # fail where it could have use `__iter__` to complete the call. 41 | # 42 | # Thus, the only two sensible things to do are 43 | # 44 | # + **not** provide a default `__len__`. 45 | # 46 | # + raise a `TypeError` instead, which is what Python uses when users call 47 | # a method that is not defined on an object. 48 | # (@ssnl verifies that this works on at least Python 3.7.) 49 | 50 | 51 | class SequentialSampler(Sampler): 52 | r"""Samples elements sequentially, always in the same order. 53 | 54 | Arguments: 55 | data_source (Dataset): dataset to sample from 56 | """ 57 | 58 | def __init__(self, data_source): 59 | self.data_source = data_source 60 | 61 | def __iter__(self): 62 | return iter(range(len(self.data_source))) 63 | 64 | def __len__(self): 65 | return len(self.data_source) 66 | 67 | 68 | class RandomSampler(Sampler): 69 | r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset. 70 | If with replacement, then user can specify :attr:`num_samples` to draw. 71 | 72 | Arguments: 73 | data_source (Dataset): dataset to sample from 74 | replacement (bool): samples are drawn with replacement if ``True``, default=``False`` 75 | num_samples (int): number of samples to draw, default=`len(dataset)`. This argument 76 | is supposed to be specified only when `replacement` is ``True``. 77 | """ 78 | 79 | def __init__(self, data_source, replacement=False, num_samples=None): 80 | self.data_source = data_source 81 | self.replacement = replacement 82 | self._num_samples = num_samples 83 | 84 | if not isinstance(self.replacement, bool): 85 | raise ValueError("replacement should be a boolean value, but got " 86 | "replacement={}".format(self.replacement)) 87 | 88 | if self._num_samples is not None and not replacement: 89 | raise ValueError("With replacement=False, num_samples should not be specified, " 90 | "since a random permute will be performed.") 91 | 92 | if not isinstance(self.num_samples, int) or self.num_samples <= 0: 93 | raise ValueError("num_samples should be a positive integer " 94 | "value, but got num_samples={}".format(self.num_samples)) 95 | 96 | @property 97 | def num_samples(self): 98 | # dataset size might change at runtime 99 | if self._num_samples is None: 100 | return len(self.data_source) 101 | return self._num_samples 102 | 103 | def __iter__(self): 104 | n = len(self.data_source) 105 | if self.replacement: 106 | n_repeats = self.num_samples // n 107 | n_remain = self.num_samples % n 108 | indices = [torch.randperm(n) for _ in range(n_repeats)] 109 | indices.append(torch.randperm(n)[:n_remain]) 110 | return iter(torch.cat(indices, dim=0).tolist()) 111 | return iter(torch.randperm(n).tolist()) 112 | 113 | def __len__(self): 114 | return self.num_samples 115 | 116 | 117 | class SubsetRandomSampler(Sampler): 118 | r"""Samples elements randomly from a given list of indices, without replacement. 119 | 120 | Arguments: 121 | indices (sequence): a sequence of indices 122 | """ 123 | 124 | def __init__(self, indices): 125 | self.indices = indices 126 | 127 | def __iter__(self): 128 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 129 | 130 | def __len__(self): 131 | return len(self.indices) 132 | 133 | 134 | class WeightedRandomSampler(Sampler): 135 | r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights). 136 | 137 | Args: 138 | weights (sequence) : a sequence of weights, not necessary summing up to one 139 | num_samples (int): number of samples to draw 140 | replacement (bool): if ``True``, samples are drawn with replacement. 141 | If not, they are drawn without replacement, which means that when a 142 | sample index is drawn for a row, it cannot be drawn again for that row. 143 | 144 | Example: 145 | >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True)) 146 | [0, 0, 0, 1, 0] 147 | >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False)) 148 | [0, 1, 4, 3, 2] 149 | """ 150 | 151 | def __init__(self, weights, num_samples, replacement=True): 152 | if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \ 153 | num_samples <= 0: 154 | raise ValueError("num_samples should be a positive integer " 155 | "value, but got num_samples={}".format(num_samples)) 156 | if not isinstance(replacement, bool): 157 | raise ValueError("replacement should be a boolean value, but got " 158 | "replacement={}".format(replacement)) 159 | self.weights = torch.as_tensor(weights, dtype=torch.double) 160 | self.num_samples = num_samples 161 | self.replacement = replacement 162 | 163 | def __iter__(self): 164 | return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist()) 165 | 166 | def __len__(self): 167 | return self.num_samples 168 | 169 | 170 | class BatchSampler(Sampler): 171 | r"""Wraps another sampler to yield a mini-batch of indices. 172 | 173 | Args: 174 | sampler (Sampler): Base sampler. 175 | batch_size (int): Size of mini-batch. 176 | drop_last (bool): If ``True``, the sampler will drop the last batch if 177 | its size would be less than ``batch_size`` 178 | 179 | Example: 180 | >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) 181 | [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] 182 | >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)) 183 | [[0, 1, 2], [3, 4, 5], [6, 7, 8]] 184 | """ 185 | 186 | def __init__(self, sampler, batch_size, drop_last): 187 | if not isinstance(sampler, Sampler): 188 | raise ValueError("sampler should be an instance of " 189 | "torch.utils.data.Sampler, but got sampler={}" 190 | .format(sampler)) 191 | if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \ 192 | batch_size <= 0: 193 | raise ValueError("batch_size should be a positive integer value, " 194 | "but got batch_size={}".format(batch_size)) 195 | if not isinstance(drop_last, bool): 196 | raise ValueError("drop_last should be a boolean value, but got " 197 | "drop_last={}".format(drop_last)) 198 | self.sampler = sampler 199 | self.batch_size = batch_size 200 | self.drop_last = drop_last 201 | 202 | def __iter__(self): 203 | batch = [] 204 | for idx in self.sampler: 205 | batch.append(idx) 206 | if len(batch) == self.batch_size: 207 | yield batch 208 | batch = [] 209 | if len(batch) > 0 and not self.drop_last: 210 | yield batch 211 | 212 | def __len__(self): 213 | if self.drop_last: 214 | return len(self.sampler) // self.batch_size 215 | else: 216 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size 217 | -------------------------------------------------------------------------------- /datasets/transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import cv2 4 | 5 | 6 | class PadandRandomCrop(object): 7 | ''' 8 | Input tensor is expected to have shape of (H, W, 3) 9 | ''' 10 | def __init__(self, border=4, cropsize=(32, 32)): 11 | self.border = border 12 | self.cropsize = cropsize 13 | 14 | def __call__(self, im): 15 | borders = [(self.border, self.border), (self.border, self.border), (0, 0)] # input is (h, w, c) 16 | convas = np.pad(im, borders, mode='reflect') 17 | H, W, C = convas.shape 18 | h, w = self.cropsize 19 | dh, dw = max(0, H-h), max(0, W-w) 20 | sh, sw = np.random.randint(0, dh), np.random.randint(0, dw) 21 | out = convas[sh:sh+h, sw:sw+w, :] 22 | return out 23 | 24 | 25 | class RandomHorizontalFlip(object): 26 | def __init__(self, p=0.5): 27 | self.p = p 28 | 29 | def __call__(self, im): 30 | if np.random.rand() < self.p: 31 | im = im[:, ::-1, :] 32 | return im 33 | 34 | 35 | class Resize(object): 36 | def __init__(self, size): 37 | self.size = size 38 | 39 | def __call__(self, im): 40 | im = cv2.resize(im, self.size) 41 | return im 42 | 43 | 44 | class Normalize(object): 45 | ''' 46 | Inputs are pixel values in range of [0, 255], channel order is 'rgb' 47 | ''' 48 | def __init__(self, mean, std): 49 | self.mean = np.array(mean, np.float32).reshape(1, 1, -1) 50 | self.std = np.array(std, np.float32).reshape(1, 1, -1) 51 | 52 | def __call__(self, im): 53 | if len(im.shape) == 4: 54 | mean, std = self.mean[None, ...], self.std[None, ...] 55 | elif len(im.shape) == 3: 56 | mean, std = self.mean, self.std 57 | im = im.astype(np.float32) / 255. 58 | # im = (im.astype(np.float32) / 127.5) - 1 59 | im -= mean 60 | im /= std 61 | return im 62 | 63 | 64 | class ToTensor(object): 65 | def __init__(self): 66 | pass 67 | 68 | def __call__(self, im): 69 | if len(im.shape) == 4: 70 | return torch.from_numpy(im.transpose(0, 3, 1, 2)) 71 | elif len(im.shape) == 3: 72 | return torch.from_numpy(im.transpose(2, 0, 1)) 73 | 74 | 75 | class Compose(object): 76 | def __init__(self, ops): 77 | self.ops = ops 78 | 79 | def __call__(self, im): 80 | for op in self.ops: 81 | im = op(im) 82 | return im 83 | -------------------------------------------------------------------------------- /label_guessor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class LabelGuessor(object): 5 | 6 | def __init__(self, thresh): 7 | self.thresh = thresh 8 | 9 | def __call__(self, model, ims): 10 | org_state = { 11 | k: v.clone().detach() 12 | for k, v in model.state_dict().items() 13 | } 14 | is_train = model.training 15 | with torch.no_grad(): 16 | model.train() 17 | logits = model(ims) 18 | probs = torch.softmax(logits, dim=1) 19 | scores, lbs = torch.max(probs, dim=1) 20 | mask = scores.ge(self.thresh).float() 21 | 22 | # note it is necessary to keep org_state! especially for bn layer 23 | # for k, v in org_state.items(): 24 | # if not all((model.state_dict()[k] == v).reshape(-1)): 25 | # print(f'{k} diff') 26 | 27 | model.load_state_dict(org_state) 28 | if is_train: 29 | model.train() 30 | else: 31 | model.eval() 32 | return mask, lbs 33 | -------------------------------------------------------------------------------- /lr_scheduler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | import math 4 | 5 | import torch 6 | from torch.optim.lr_scheduler import _LRScheduler, LambdaLR 7 | import numpy as np 8 | 9 | 10 | class WarmupExpLrScheduler(_LRScheduler): 11 | def __init__( 12 | self, 13 | optimizer, 14 | power, 15 | step_interval=1, 16 | warmup_iter=500, 17 | warmup_ratio=5e-4, 18 | warmup='exp', 19 | last_epoch=-1, 20 | ): 21 | self.power = power 22 | self.step_interval = step_interval 23 | self.warmup_iter = warmup_iter 24 | self.warmup_ratio = warmup_ratio 25 | self.warmup = warmup 26 | super(WarmupExpLrScheduler, self).__init__(optimizer, last_epoch) 27 | 28 | def get_lr(self): 29 | ratio = self.get_lr_ratio() 30 | lrs = [ratio * lr for lr in self.base_lrs] 31 | return lrs 32 | 33 | def get_lr_ratio(self): 34 | if self.last_epoch < self.warmup_iter: 35 | ratio = self.get_warmup_ratio() 36 | else: 37 | real_iter = self.last_epoch - self.warmup_iter 38 | ratio = self.power ** (real_iter // self.step_interval) 39 | return ratio 40 | 41 | def get_warmup_ratio(self): 42 | assert self.warmup in ('linear', 'exp') 43 | alpha = self.last_epoch / self.warmup_iter 44 | if self.warmup == 'linear': 45 | ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha 46 | elif self.warmup == 'exp': 47 | ratio = self.warmup_ratio ** (1. - alpha) 48 | return ratio 49 | 50 | 51 | class WarmupPolyLrScheduler(_LRScheduler): 52 | def __init__( 53 | self, 54 | optimizer, 55 | power, 56 | max_iter, 57 | warmup_iter, 58 | warmup_ratio=5e-4, 59 | warmup='exp', 60 | last_epoch=-1, 61 | ): 62 | self.power = power 63 | self.max_iter = max_iter 64 | self.warmup_iter = warmup_iter 65 | self.warmup_ratio = warmup_ratio 66 | self.warmup = warmup 67 | super(WarmupPolyLrScheduler, self).__init__(optimizer, last_epoch) 68 | 69 | def get_lr(self): 70 | ratio = self.get_lr_ratio() 71 | lrs = [ratio * lr for lr in self.base_lrs] 72 | return lrs 73 | 74 | def get_lr_ratio(self): 75 | if self.last_epoch < self.warmup_iter: 76 | ratio = self.get_warmup_ratio() 77 | else: 78 | real_iter = self.last_epoch - self.warmup_iter 79 | real_max_iter = self.max_iter - self.warmup_iter 80 | alpha = real_iter / real_max_iter 81 | ratio = (1 - alpha) ** self.power 82 | return ratio 83 | 84 | def get_warmup_ratio(self): 85 | assert self.warmup in ('linear', 'exp') 86 | alpha = self.last_epoch / self.warmup_iter 87 | if self.warmup == 'linear': 88 | ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha 89 | elif self.warmup == 'exp': 90 | ratio = self.warmup_ratio ** (1. - alpha) 91 | return ratio 92 | 93 | 94 | class WarmupCosineLrScheduler(_LRScheduler): 95 | ''' 96 | This is different from official definition, this is implemented according to 97 | the paper of fix-match 98 | ''' 99 | def __init__( 100 | self, 101 | optimizer, 102 | max_iter, 103 | warmup_iter, 104 | warmup_ratio=5e-4, 105 | warmup='exp', 106 | last_epoch=-1, 107 | ): 108 | self.max_iter = max_iter 109 | self.warmup_iter = warmup_iter 110 | self.warmup_ratio = warmup_ratio 111 | self.warmup = warmup 112 | super(WarmupCosineLrScheduler, self).__init__(optimizer, last_epoch) 113 | 114 | def get_lr(self): 115 | ratio = self.get_lr_ratio() 116 | lrs = [ratio * lr for lr in self.base_lrs] 117 | return lrs 118 | 119 | def get_lr_ratio(self): 120 | if self.last_epoch < self.warmup_iter: 121 | ratio = self.get_warmup_ratio() 122 | else: 123 | real_iter = self.last_epoch - self.warmup_iter 124 | real_max_iter = self.max_iter - self.warmup_iter 125 | ratio = np.cos((7 * np.pi * real_iter) / (16 * real_max_iter)) 126 | return ratio 127 | 128 | def get_warmup_ratio(self): 129 | assert self.warmup in ('linear', 'exp') 130 | alpha = self.last_epoch / self.warmup_iter 131 | if self.warmup == 'linear': 132 | ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha 133 | elif self.warmup == 'exp': 134 | ratio = self.warmup_ratio ** (1. - alpha) 135 | return ratio 136 | 137 | 138 | # from Fixmatch-pytorch 139 | def get_cosine_schedule_with_warmup(optimizer, 140 | num_warmup_steps, 141 | num_training_steps, 142 | num_cycles=7./16., 143 | last_epoch=-1): 144 | def _lr_lambda(current_step): 145 | if current_step < num_warmup_steps: 146 | return float(current_step) / float(max(1, num_warmup_steps)) 147 | no_progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 148 | # return max(0., math.cos(math.pi * num_cycles * no_progress)) 149 | 150 | return max(0., (math.cos(math.pi * num_cycles * no_progress) + 1) * 0.5) 151 | 152 | return LambdaLR(optimizer, _lr_lambda, last_epoch) 153 | 154 | if __name__ == "__main__": 155 | model = torch.nn.Conv2d(3, 16, 3, 1, 1) 156 | optim = torch.optim.SGD(model.parameters(), lr=1e-3) 157 | 158 | max_iter = 20000 159 | # lr_scheduler = WarmupCosineLrScheduler(optim, max_iter, 0) 160 | lr_scheduler = get_cosine_schedule_with_warmup( 161 | optim, 0, max_iter) 162 | 163 | lrs = [] 164 | for _ in range(max_iter): 165 | lr = lr_scheduler.get_lr()[0] 166 | print(lr) 167 | lrs.append(lr) 168 | lr_scheduler.step() 169 | import matplotlib 170 | import matplotlib.pyplot as plt 171 | import numpy as np 172 | lrs = np.array(lrs) 173 | n_lrs = len(lrs) 174 | plt.plot(np.arange(n_lrs), lrs) 175 | plt.title('3') 176 | plt.grid() 177 | plt.show() 178 | 179 | 180 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/valencebond/FixMatch_pytorch/3b7fc36ca0c71754a59c9c78465d681f259ee174/models/__init__.py -------------------------------------------------------------------------------- /models/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | """ 5 | 为什么 param 和 buffer 要采用不同的的更新策略 6 | param 是 指数移动平均数,buffer 不是 7 | """ 8 | 9 | 10 | class EMA(object): 11 | def __init__(self, model, alpha=0.999): 12 | self.step = 0 13 | self.model = model 14 | self.alpha = alpha 15 | self.shadow = self.get_model_state() 16 | self.backup = {} 17 | self.param_keys = [k for k, _ in self.model.named_parameters()] 18 | # num_batches_tracked, running_mean, running_var in bn 19 | self.buffer_keys = [k for k, _ in self.model.named_buffers()] 20 | 21 | def update_params(self): 22 | # decay = min(self.alpha, (self.step + 1) / (self.step + 10)) # ???? 23 | decay = self.alpha 24 | state = self.model.state_dict() # current params 25 | for name in self.param_keys: 26 | self.shadow[name].copy_( 27 | decay * self.shadow[name] + (1 - decay) * state[name] 28 | ) 29 | # for name in self.buffer_keys: 30 | # self.shadow[name].copy_( 31 | # decay * self.shadow[name] 32 | # + (1 - decay) * state[name] 33 | # ) 34 | 35 | self.step += 1 36 | 37 | def update_buffer(self): 38 | # without EMA 39 | state = self.model.state_dict() 40 | for name in self.buffer_keys: 41 | self.shadow[name].copy_(state[name]) 42 | 43 | def apply_shadow(self): 44 | self.backup = self.get_model_state() 45 | self.model.load_state_dict(self.shadow) 46 | 47 | def restore(self): 48 | self.model.load_state_dict(self.backup) 49 | 50 | def get_model_state(self): 51 | return { 52 | k: v.clone().detach() 53 | for k, v in self.model.state_dict().items() 54 | } 55 | 56 | 57 | if __name__ == '__main__': 58 | print('=====') 59 | model = torch.nn.BatchNorm1d(5) 60 | ema = EMA(model, 0.9, 0.02, 0.002) 61 | inten = torch.randn(10, 5) 62 | out = model(inten) 63 | ema.update_params() 64 | print(model.state_dict()) 65 | ema.update_buffer() 66 | print(model.state_dict()) 67 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from torch.nn import BatchNorm2d 9 | 10 | ''' 11 | As in the paper, the wide resnet only considers the resnet of the pre-activated version, 12 | and it only considers the basic blocks rather than the bottleneck blocks. 13 | ''' 14 | 15 | 16 | class BasicBlockPreAct(nn.Module): 17 | def __init__(self, in_chan, out_chan, drop_rate=0, stride=1, pre_res_act=False): 18 | super(BasicBlockPreAct, self).__init__() 19 | self.bn1 = BatchNorm2d(in_chan, momentum=0.001) 20 | self.relu1 = nn.LeakyReLU(inplace=True, negative_slope=0.1) 21 | self.conv1 = nn.Conv2d(in_chan, out_chan, kernel_size=3, stride=stride, padding=1, bias=False) 22 | self.bn2 = BatchNorm2d(out_chan, momentum=0.001) 23 | self.relu2 = nn.LeakyReLU(inplace=True, negative_slope=0.1) 24 | self.dropout = nn.Dropout(drop_rate) if not drop_rate == 0 else None 25 | self.conv2 = nn.Conv2d(out_chan, out_chan, kernel_size=3, stride=1, padding=1, bias=False) 26 | self.downsample = None 27 | if in_chan != out_chan or stride != 1: 28 | self.downsample = nn.Conv2d( 29 | in_chan, out_chan, kernel_size=1, stride=stride, bias=False 30 | ) 31 | self.pre_res_act = pre_res_act 32 | # self.init_weight() 33 | 34 | def forward(self, x): 35 | bn1 = self.bn1(x) 36 | act1 = self.relu1(bn1) 37 | residual = self.conv1(act1) 38 | residual = self.bn2(residual) 39 | residual = self.relu2(residual) 40 | if self.dropout is not None: 41 | residual = self.dropout(residual) 42 | residual = self.conv2(residual) 43 | 44 | shortcut = act1 if self.pre_res_act else x 45 | if self.downsample is not None: 46 | shortcut = self.downsample(shortcut) 47 | 48 | out = shortcut + residual 49 | return out 50 | 51 | def init_weight(self): 52 | # for _, md in self.named_modules(): 53 | # if isinstance(md, nn.Conv2d): 54 | # nn.init.kaiming_normal_( 55 | # md.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') 56 | # if md.bias is not None: 57 | # nn.init.constant_(md.bias, 0) 58 | for m in self.modules(): 59 | if isinstance(m, nn.Conv2d): 60 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') 61 | if m.bias is not None: 62 | nn.init.constant_(m.bias, 0) 63 | 64 | 65 | class WideResnetBackbone(nn.Module): 66 | def __init__(self, k=1, n=28, drop_rate=0): 67 | super(WideResnetBackbone, self).__init__() 68 | self.k, self.n = k, n 69 | assert (self.n - 4) % 6 == 0 70 | n_blocks = (self.n - 4) // 6 71 | n_layers = [16, ] + [self.k * 16 * (2 ** i) for i in range(3)] 72 | 73 | self.conv1 = nn.Conv2d( 74 | 3, 75 | n_layers[0], 76 | kernel_size=3, 77 | stride=1, 78 | padding=1, 79 | bias=False 80 | ) 81 | self.layer1 = self.create_layer( 82 | n_layers[0], 83 | n_layers[1], 84 | bnum=n_blocks, 85 | stride=1, 86 | drop_rate=drop_rate, 87 | pre_res_act=True, 88 | ) 89 | self.layer2 = self.create_layer( 90 | n_layers[1], 91 | n_layers[2], 92 | bnum=n_blocks, 93 | stride=2, 94 | drop_rate=drop_rate, 95 | pre_res_act=False, 96 | ) 97 | self.layer3 = self.create_layer( 98 | n_layers[2], 99 | n_layers[3], 100 | bnum=n_blocks, 101 | stride=2, 102 | drop_rate=drop_rate, 103 | pre_res_act=False, 104 | ) 105 | self.bn_last = BatchNorm2d(n_layers[3], momentum=0.001) 106 | self.relu_last = nn.LeakyReLU(inplace=True, negative_slope=0.1) 107 | self.init_weight() 108 | 109 | def create_layer( 110 | self, 111 | in_chan, 112 | out_chan, 113 | bnum, 114 | stride=1, 115 | drop_rate=0, 116 | pre_res_act=False, 117 | ): 118 | layers = [ 119 | BasicBlockPreAct( 120 | in_chan, 121 | out_chan, 122 | drop_rate=drop_rate, 123 | stride=stride, 124 | pre_res_act=pre_res_act), ] 125 | for _ in range(bnum - 1): 126 | layers.append( 127 | BasicBlockPreAct( 128 | out_chan, 129 | out_chan, 130 | drop_rate=drop_rate, 131 | stride=1, 132 | pre_res_act=False, )) 133 | return nn.Sequential(*layers) 134 | 135 | def forward(self, x): 136 | feat = self.conv1(x) 137 | 138 | feat = self.layer1(feat) 139 | feat2 = self.layer2(feat) # 1/2 140 | feat4 = self.layer3(feat2) # 1/4 141 | 142 | feat4 = self.bn_last(feat4) 143 | feat4 = self.relu_last(feat4) 144 | return feat2, feat4 145 | 146 | def init_weight(self): 147 | # for _, child in self.named_children(): 148 | # if isinstance(child, nn.Conv2d): 149 | # n = child.kernel_size[0] * child.kernel_size[0] * child.out_channels 150 | # nn.init.normal_(child.weight, 0, 1. / ((0.5 * n) ** 0.5)) 151 | # # nn.init.kaiming_normal_( 152 | # # child.weight, a=0.1, mode='fan_out', 153 | # # nonlinearity='leaky_relu' 154 | # # ) 155 | # 156 | # if child.bias is not None: 157 | # nn.init.constant_(child.bias, 0) 158 | for m in self.modules(): 159 | if isinstance(m, nn.Conv2d): 160 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 161 | m.weight.data.normal_(0, math.sqrt(2. / n)) 162 | 163 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') 164 | if m.bias is not None: 165 | nn.init.constant_(m.bias, 0) 166 | elif isinstance(m, nn.BatchNorm2d): 167 | m.weight.data.fill_(1) 168 | m.bias.data.zero_() 169 | elif isinstance(m, nn.Linear): 170 | m.bias.data.zero_() 171 | 172 | 173 | class WideResnet(nn.Module): 174 | ''' 175 | for wide-resnet-28-10, the definition should be WideResnet(n_classes, 10, 28) 176 | ''' 177 | 178 | def __init__(self, n_classes, k=1, n=28): 179 | super(WideResnet, self).__init__() 180 | self.n_layers, self.k = n, k 181 | self.backbone = WideResnetBackbone(k=k, n=n) 182 | self.classifier = nn.Linear(64 * self.k, n_classes, bias=True) 183 | 184 | def forward(self, x): 185 | feat = self.backbone(x)[-1] 186 | feat = torch.mean(feat, dim=(2, 3)) 187 | feat = self.classifier(feat) 188 | return feat 189 | 190 | def init_weight(self): 191 | nn.init.xavier_normal_(self.classifier.weight) 192 | if not self.classifier.bias is None: 193 | nn.init.constant_(self.classifier.bias, 0) 194 | 195 | 196 | if __name__ == "__main__": 197 | x = torch.randn(2, 3, 224, 224) 198 | lb = torch.randint(0, 10, (2,)).long() 199 | 200 | net = WideResnetBackbone() 201 | out = net(x) 202 | print(out[0].size()) 203 | del net, out 204 | 205 | net = WideResnet(n_classes=10) 206 | criteria = nn.CrossEntropyLoss() 207 | out = net(x) 208 | loss = criteria(out, lb) 209 | loss.backward() 210 | print(out.size()) 211 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv_python==4.1.0.25 2 | torch==1.3.1+cu100 3 | matplotlib==3.1.1 4 | numpy==1.17.2 5 | tensorboardX==2.0 6 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | export OMP_NUM_THREADS=1 2 | export MKL_NUM_THREADS=1 3 | 4 | python train.py 5 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import random 3 | 4 | import time 5 | import argparse 6 | import os 7 | import sys 8 | 9 | import numpy as np 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | from models.model import WideResnet 16 | from datasets.cifar import get_train_loader, get_val_loader 17 | from label_guessor import LabelGuessor 18 | from lr_scheduler import WarmupCosineLrScheduler 19 | from models.ema import EMA 20 | 21 | from utils import accuracy, setup_default_logging, interleave, de_interleave 22 | 23 | from utils import AverageMeter 24 | 25 | 26 | # os.environ['CUDA_VISIBLE_DEVICES'] = '1' 27 | 28 | 29 | def set_model(args): 30 | model = WideResnet(n_classes=10 if args.dataset == 'CIFAR10' else 100, 31 | k=args.wresnet_k, n=args.wresnet_n) # wresnet-28-2 32 | 33 | model.train() 34 | model.cuda() 35 | criteria_x = nn.CrossEntropyLoss().cuda() 36 | criteria_u = nn.CrossEntropyLoss(reduction='none').cuda() 37 | return model, criteria_x, criteria_u 38 | 39 | 40 | def train_one_epoch(epoch, 41 | model, 42 | criteria_x, 43 | criteria_u, 44 | optim, 45 | lr_schdlr, 46 | ema, 47 | dltrain_x, 48 | dltrain_u, 49 | lb_guessor, 50 | lambda_u, 51 | n_iters, 52 | logger, 53 | ): 54 | model.train() 55 | # loss_meter, loss_x_meter, loss_u_meter, loss_u_real_meter = [], [], [], [] 56 | loss_meter = AverageMeter() 57 | loss_x_meter = AverageMeter() 58 | loss_u_meter = AverageMeter() 59 | loss_u_real_meter = AverageMeter() 60 | # the number of correctly-predicted and gradient-considered unlabeled data 61 | n_correct_u_lbs_meter = AverageMeter() 62 | # the number of gradient-considered strong augmentation (logits above threshold) of unlabeled samples 63 | n_strong_aug_meter = AverageMeter() 64 | mask_meter = AverageMeter() 65 | 66 | epoch_start = time.time() # start time 67 | dl_x, dl_u = iter(dltrain_x), iter(dltrain_u) 68 | for it in range(n_iters): 69 | ims_x_weak, ims_x_strong, lbs_x = next(dl_x) 70 | ims_u_weak, ims_u_strong, lbs_u_real = next(dl_u) 71 | 72 | lbs_x = lbs_x.cuda() 73 | lbs_u_real = lbs_u_real.cuda() 74 | 75 | # -------------------------------------- 76 | 77 | bt = ims_x_weak.size(0) 78 | mu = int(ims_u_weak.size(0) // bt) 79 | imgs = torch.cat([ims_x_weak, ims_u_weak, ims_u_strong], dim=0).cuda() 80 | imgs = interleave(imgs, 2 * mu + 1) 81 | logits = model(imgs) 82 | logits = de_interleave(logits, 2 * mu + 1) 83 | 84 | logits_x = logits[:bt] 85 | logits_u_w, logits_u_s = torch.split(logits[bt:], bt * mu) 86 | 87 | loss_x = criteria_x(logits_x, lbs_x) 88 | 89 | with torch.no_grad(): 90 | probs = torch.softmax(logits_u_w, dim=1) 91 | scores, lbs_u_guess = torch.max(probs, dim=1) 92 | mask = scores.ge(0.95).float() 93 | 94 | loss_u = (criteria_u(logits_u_s, lbs_u_guess) * mask).mean() 95 | loss = loss_x + lambda_u * loss_u 96 | loss_u_real = (F.cross_entropy(logits_u_s, lbs_u_real) * mask).mean() 97 | 98 | # -------------------------------------- 99 | 100 | 101 | # mask, lbs_u_guess = lb_guessor(model, ims_u_weak.cuda()) 102 | # n_x = ims_x_weak.size(0) 103 | # ims_x_u = torch.cat([ims_x_weak, ims_u_strong]).cuda() 104 | # logits_x_u = model(ims_x_u) 105 | # logits_x, logits_u = logits_x_u[:n_x], logits_x_u[n_x:] 106 | # loss_x = criteria_x(logits_x, lbs_x) 107 | # loss_u = (criteria_u(logits_u, lbs_u_guess) * mask).mean() 108 | # loss = loss_x + lambda_u * loss_u 109 | # loss_u_real = (F.cross_entropy(logits_u, lbs_u_real) * mask).mean() 110 | 111 | optim.zero_grad() 112 | loss.backward() 113 | optim.step() 114 | ema.update_params() 115 | lr_schdlr.step() 116 | 117 | loss_meter.update(loss.item()) 118 | loss_x_meter.update(loss_x.item()) 119 | loss_u_meter.update(loss_u.item()) 120 | loss_u_real_meter.update(loss_u_real.item()) 121 | mask_meter.update(mask.mean().item()) 122 | 123 | corr_u_lb = (lbs_u_guess == lbs_u_real).float() * mask 124 | n_correct_u_lbs_meter.update(corr_u_lb.sum().item()) 125 | n_strong_aug_meter.update(mask.sum().item()) 126 | 127 | if (it + 1) % 512 == 0: 128 | t = time.time() - epoch_start 129 | 130 | lr_log = [pg['lr'] for pg in optim.param_groups] 131 | lr_log = sum(lr_log) / len(lr_log) 132 | 133 | logger.info("epoch:{}, iter: {}. loss: {:.4f}. loss_u: {:.4f}. loss_x: {:.4f}. loss_u_real: {:.4f}. " 134 | "n_correct_u: {:.2f}/{:.2f}. Mask:{:.4f} . LR: {:.4f}. Time: {:.2f}".format( 135 | epoch, it + 1, loss_meter.avg, loss_u_meter.avg, loss_x_meter.avg, loss_u_real_meter.avg, 136 | n_correct_u_lbs_meter.avg, n_strong_aug_meter.avg, mask_meter.avg, lr_log, t)) 137 | 138 | epoch_start = time.time() 139 | 140 | ema.update_buffer() 141 | return loss_meter.avg, loss_x_meter.avg, loss_u_meter.avg, loss_u_real_meter.avg, mask_meter.avg 142 | 143 | 144 | def evaluate(ema, dataloader, criterion): 145 | # using EMA params to evaluate performance 146 | ema.apply_shadow() 147 | ema.model.eval() 148 | ema.model.cuda() 149 | 150 | loss_meter = AverageMeter() 151 | top1_meter = AverageMeter() 152 | top5_meter = AverageMeter() 153 | 154 | # matches = [] 155 | with torch.no_grad(): 156 | for ims, lbs in dataloader: 157 | ims = ims.cuda() 158 | lbs = lbs.cuda() 159 | logits = ema.model(ims) 160 | loss = criterion(logits, lbs) 161 | scores = torch.softmax(logits, dim=1) 162 | top1, top5 = accuracy(scores, lbs, (1, 5)) 163 | loss_meter.update(loss.item()) 164 | top1_meter.update(top1.item()) 165 | top5_meter.update(top5.item()) 166 | 167 | # note roll back model current params to continue training 168 | ema.restore() 169 | return top1_meter.avg, top5_meter.avg, loss_meter.avg 170 | 171 | 172 | def main(): 173 | parser = argparse.ArgumentParser(description=' FixMatch Training') 174 | parser.add_argument('--wresnet-k', default=2, type=int, 175 | help='width factor of wide resnet') 176 | parser.add_argument('--wresnet-n', default=28, type=int, 177 | help='depth of wide resnet') 178 | parser.add_argument('--dataset', type=str, default='CIFAR10', 179 | help='number of classes in dataset') 180 | # parser.add_argument('--n-classes', type=int, default=100, 181 | # help='number of classes in dataset') 182 | parser.add_argument('--n-labeled', type=int, default=40, 183 | help='number of labeled samples for training') 184 | parser.add_argument('--n-epoches', type=int, default=1024, 185 | help='number of training epoches') 186 | parser.add_argument('--batchsize', type=int, default=64, 187 | help='train batch size of labeled samples') 188 | parser.add_argument('--mu', type=int, default=7, 189 | help='factor of train batch size of unlabeled samples') 190 | parser.add_argument('--thr', type=float, default=0.95, 191 | help='pseudo label threshold') 192 | parser.add_argument('--n-imgs-per-epoch', type=int, default=64 * 1024, 193 | help='number of training images for each epoch') 194 | parser.add_argument('--lam-u', type=float, default=1., 195 | help='coefficient of unlabeled loss') 196 | parser.add_argument('--ema-alpha', type=float, default=0.999, 197 | help='decay rate for ema module') 198 | parser.add_argument('--lr', type=float, default=0.03, 199 | help='learning rate for training') 200 | parser.add_argument('--weight-decay', type=float, default=5e-4, 201 | help='weight decay') 202 | parser.add_argument('--momentum', type=float, default=0.9, 203 | help='momentum for optimizer') 204 | parser.add_argument('--seed', type=int, default=-1, 205 | help='seed for random behaviors, no seed if negtive') 206 | 207 | args = parser.parse_args() 208 | 209 | logger, writer = setup_default_logging(args) 210 | logger.info(dict(args._get_kwargs())) 211 | 212 | # global settings 213 | # torch.multiprocessing.set_sharing_strategy('file_system') 214 | if args.seed > 0: 215 | torch.manual_seed(args.seed) 216 | random.seed(args.seed) 217 | np.random.seed(args.seed) 218 | # torch.backends.cudnn.deterministic = True 219 | 220 | n_iters_per_epoch = args.n_imgs_per_epoch // args.batchsize # 1024 221 | n_iters_all = n_iters_per_epoch * args.n_epoches # 1024 * 1024 222 | 223 | logger.info("***** Running training *****") 224 | logger.info(f" Task = {args.dataset}@{args.n_labeled}") 225 | logger.info(f" Num Epochs = {n_iters_per_epoch}") 226 | logger.info(f" Batch size per GPU = {args.batchsize}") 227 | # logger.info(f" Total train batch size = {args.batch_size * args.world_size}") 228 | logger.info(f" Total optimization steps = {n_iters_all}") 229 | 230 | model, criteria_x, criteria_u = set_model(args) 231 | logger.info("Total params: {:.2f}M".format( 232 | sum(p.numel() for p in model.parameters()) / 1e6)) 233 | 234 | dltrain_x, dltrain_u = get_train_loader( 235 | args.dataset, args.batchsize, args.mu, n_iters_per_epoch, L=args.n_labeled) 236 | dlval = get_val_loader(dataset=args.dataset, batch_size=64, num_workers=2) 237 | 238 | lb_guessor = LabelGuessor(thresh=args.thr) 239 | 240 | ema = EMA(model, args.ema_alpha) 241 | 242 | wd_params, non_wd_params = [], [] 243 | for name, param in model.named_parameters(): 244 | # if len(param.size()) == 1: 245 | if 'bn' in name: 246 | non_wd_params.append(param) # bn.weight, bn.bias and classifier.bias 247 | # print(name) 248 | else: 249 | wd_params.append(param) 250 | param_list = [ 251 | {'params': wd_params}, {'params': non_wd_params, 'weight_decay': 0}] 252 | optim = torch.optim.SGD(param_list, lr=args.lr, weight_decay=args.weight_decay, 253 | momentum=args.momentum, nesterov=True) 254 | lr_schdlr = WarmupCosineLrScheduler( 255 | optim, max_iter=n_iters_all, warmup_iter=0 256 | ) 257 | 258 | train_args = dict( 259 | model=model, 260 | criteria_x=criteria_x, 261 | criteria_u=criteria_u, 262 | optim=optim, 263 | lr_schdlr=lr_schdlr, 264 | ema=ema, 265 | dltrain_x=dltrain_x, 266 | dltrain_u=dltrain_u, 267 | lb_guessor=lb_guessor, 268 | lambda_u=args.lam_u, 269 | n_iters=n_iters_per_epoch, 270 | logger=logger 271 | ) 272 | best_acc = -1 273 | best_epoch = 0 274 | logger.info('-----------start training--------------') 275 | for epoch in range(args.n_epoches): 276 | train_loss, loss_x, loss_u, loss_u_real, mask_mean = train_one_epoch(epoch, **train_args) 277 | # torch.cuda.empty_cache() 278 | 279 | top1, top5, valid_loss = evaluate(ema, dlval, criteria_x) 280 | 281 | writer.add_scalars('train/1.loss', {'train': train_loss, 282 | 'test': valid_loss}, epoch) 283 | writer.add_scalar('train/2.train_loss_x', loss_x, epoch) 284 | writer.add_scalar('train/3.train_loss_u', loss_u, epoch) 285 | writer.add_scalar('train/4.train_loss_u_real', loss_u_real, epoch) 286 | writer.add_scalar('train/5.mask_mean', mask_mean, epoch) 287 | writer.add_scalars('test/1.test_acc', {'top1': top1, 'top5': top5}, epoch) 288 | # writer.add_scalar('test/2.test_loss', loss, epoch) 289 | 290 | # best_acc = top1 if best_acc < top1 else best_acc 291 | if best_acc < top1: 292 | best_acc = top1 293 | best_epoch = epoch 294 | 295 | logger.info("Epoch {}. Top1: {:.4f}. Top5: {:.4f}. best_acc: {:.4f} in epoch{}". 296 | format(epoch, top1, top5, best_acc, best_epoch)) 297 | 298 | writer.close() 299 | 300 | 301 | if __name__ == '__main__': 302 | main() 303 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import logging 3 | import os 4 | import sys 5 | import torch 6 | 7 | # from torch.utils.tensorboard import SummaryWriter 8 | from tensorboardX import SummaryWriter 9 | 10 | 11 | def interleave(x, bt): 12 | s = list(x.shape) 13 | return torch.reshape(torch.transpose(x.reshape([-1, bt] + s[1:]), 1, 0), [-1] + s[1:]) 14 | 15 | 16 | def de_interleave(x, bt): 17 | s = list(x.shape) 18 | return torch.reshape(torch.transpose(x.reshape([bt, -1] + s[1:]), 1, 0), [-1] + s[1:]) 19 | 20 | 21 | def setup_default_logging(args, default_level=logging.INFO, 22 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s"): 23 | output_dir = os.path.join(args.dataset, f'x{args.n_labeled}') 24 | os.makedirs(output_dir, exist_ok=True) 25 | 26 | writer = SummaryWriter(comment=f'{args.dataset}_{args.n_labeled}') 27 | 28 | logger = logging.getLogger('train') 29 | 30 | logging.basicConfig( # unlike the root logger, a custom logger can’t be configured using basicConfig() 31 | filename=os.path.join(output_dir, f'{time_str()}.log'), 32 | format=format, 33 | datefmt="%m/%d/%Y %H:%M:%S", 34 | level=default_level) 35 | 36 | # print 37 | # file_handler = logging.FileHandler() 38 | console_handler = logging.StreamHandler(sys.stdout) 39 | console_handler.setLevel(default_level) 40 | console_handler.setFormatter(logging.Formatter(format)) 41 | logger.addHandler(console_handler) 42 | 43 | return logger, writer 44 | 45 | 46 | def accuracy(output, target, topk=(1,)): 47 | """Computes the precision@k for the specified values of k""" 48 | maxk = max(topk) 49 | batch_size = target.size(0) 50 | 51 | _, pred = output.topk(maxk, 1, largest=True, sorted=True) # return value, indices 52 | pred = pred.t() 53 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 54 | 55 | res = [] 56 | for k in topk: 57 | correct_k = correct[:k].view(-1).float().sum(0) 58 | res.append(correct_k.mul_(100.0 / batch_size)) 59 | return res 60 | 61 | 62 | class AverageMeter(object): 63 | """ 64 | Computes and stores the average and current value 65 | 66 | """ 67 | 68 | def __init__(self): 69 | self.reset() 70 | 71 | def reset(self): 72 | self.val = 0 73 | self.avg = 0 74 | self.sum = 0 75 | self.count = 0 76 | 77 | def update(self, val, n=1): 78 | self.val = val 79 | self.sum += val * n 80 | self.count += n 81 | # self.avg = self.sum / (self.count + 1e-20) 82 | self.avg = self.sum / self.count 83 | 84 | 85 | def time_str(fmt=None): 86 | if fmt is None: 87 | fmt = '%Y-%m-%d_%H:%M:%S' 88 | 89 | # time.strftime(format[, t]) 90 | return datetime.today().strftime(fmt) 91 | 92 | 93 | if __name__ == '__main__': 94 | 95 | a = torch.tensor(range(30)) 96 | a_ = interleave(a, 15) 97 | a__ = de_interleave(a_, 15) 98 | print(a, a_, a__) 99 | --------------------------------------------------------------------------------