├── .gitignore ├── LICENSE ├── MW-Net.py ├── NeurIPS2019.pdf ├── README.md ├── load_corrupted_data.py ├── resnet.py ├── train_WRN-28-10_Meta_PGC.py └── wideresnet.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 xjtushujun 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MW-Net.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import argparse 4 | import os 5 | import shutil 6 | import time 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.optim 14 | import torch.utils.data 15 | import torchvision.transforms as transforms 16 | import torchvision.datasets as datasets 17 | from torch.autograd import Variable 18 | from torch.utils.data.sampler import SubsetRandomSampler 19 | import matplotlib.pyplot as plt 20 | # import sklearn.metrics as sm 21 | # import pandas as pd 22 | # import sklearn.metrics as sm 23 | import random 24 | import numpy as np 25 | 26 | # from wideresnet import WideResNet, VNet 27 | from resnet import ResNet32,VNet 28 | from load_corrupted_data import CIFAR10, CIFAR100 29 | 30 | parser = argparse.ArgumentParser(description='PyTorch WideResNet Training') 31 | parser.add_argument('--dataset', default='cifar10', type=str, 32 | help='dataset (cifar10 [default] or cifar100)') 33 | parser.add_argument('--corruption_prob', type=float, default=0.4, 34 | help='label noise') 35 | parser.add_argument('--corruption_type', '-ctype', type=str, default='unif', 36 | help='Type of corruption ("unif" or "flip" or "flip2").') 37 | parser.add_argument('--num_meta', type=int, default=1000) 38 | parser.add_argument('--epochs', default=120, type=int, 39 | help='number of total epochs to run') 40 | parser.add_argument('--iters', default=60000, type=int, 41 | help='number of total iters to run') 42 | parser.add_argument('--start-epoch', default=0, type=int, 43 | help='manual epoch number (useful on restarts)') 44 | parser.add_argument('--batch_size', '--batch-size', default=100, type=int, 45 | help='mini-batch size (default: 100)') 46 | parser.add_argument('--lr', '--learning-rate', default=1e-1, type=float, 47 | help='initial learning rate') 48 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum') 49 | parser.add_argument('--nesterov', default=True, type=bool, help='nesterov momentum') 50 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, 51 | help='weight decay (default: 5e-4)') 52 | parser.add_argument('--print-freq', '-p', default=10, type=int, 53 | help='print frequency (default: 10)') 54 | parser.add_argument('--layers', default=28, type=int, 55 | help='total number of layers (default: 28)') 56 | parser.add_argument('--widen-factor', default=10, type=int, 57 | help='widen factor (default: 10)') 58 | parser.add_argument('--droprate', default=0, type=float, 59 | help='dropout probability (default: 0.0)') 60 | parser.add_argument('--no-augment', dest='augment', action='store_false', 61 | help='whether to use standard augmentation (default: True)') 62 | parser.add_argument('--resume', default='', type=str, 63 | help='path to latest checkpoint (default: none)') 64 | parser.add_argument('--name', default='WideResNet-28-10', type=str, 65 | help='name of experiment') 66 | parser.add_argument('--seed', type=int, default=1) 67 | parser.add_argument('--prefetch', type=int, default=0, help='Pre-fetching threads.') 68 | parser.set_defaults(augment=True) 69 | 70 | args = parser.parse_args() 71 | use_cuda = True 72 | torch.manual_seed(args.seed) 73 | device = torch.device("cuda" if use_cuda else "cpu") 74 | 75 | 76 | print() 77 | print(args) 78 | 79 | 
def build_dataset(): 80 | normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]], 81 | std=[x / 255.0 for x in [63.0, 62.1, 66.7]]) 82 | if args.augment: 83 | train_transform = transforms.Compose([ 84 | transforms.ToTensor(), 85 | transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), 86 | (4, 4, 4, 4), mode='reflect').squeeze()), 87 | transforms.ToPILImage(), 88 | transforms.RandomCrop(32), 89 | transforms.RandomHorizontalFlip(), 90 | transforms.ToTensor(), 91 | normalize, 92 | ]) 93 | else: 94 | train_transform = transforms.Compose([ 95 | transforms.ToTensor(), 96 | normalize, 97 | ]) 98 | test_transform = transforms.Compose([ 99 | transforms.ToTensor(), 100 | normalize 101 | ]) 102 | 103 | if args.dataset == 'cifar10': 104 | train_data_meta = CIFAR10( 105 | root='../data', train=True, meta=True, num_meta=args.num_meta, corruption_prob=args.corruption_prob, 106 | corruption_type=args.corruption_type, transform=train_transform, download=True) 107 | train_data = CIFAR10( 108 | root='../data', train=True, meta=False, num_meta=args.num_meta, corruption_prob=args.corruption_prob, 109 | corruption_type=args.corruption_type, transform=train_transform, download=True, seed=args.seed) 110 | test_data = CIFAR10(root='../data', train=False, transform=test_transform, download=True) 111 | 112 | 113 | elif args.dataset == 'cifar100': 114 | train_data_meta = CIFAR100( 115 | root='../data', train=True, meta=True, num_meta=args.num_meta, corruption_prob=args.corruption_prob, 116 | corruption_type=args.corruption_type, transform=train_transform, download=True) 117 | train_data = CIFAR100( 118 | root='../data', train=True, meta=False, num_meta=args.num_meta, corruption_prob=args.corruption_prob, 119 | corruption_type=args.corruption_type, transform=train_transform, download=True, seed=args.seed) 120 | test_data = CIFAR100(root='../data', train=False, transform=test_transform, download=True) 121 | 122 | 123 | train_loader = torch.utils.data.DataLoader( 124 | train_data, batch_size=args.batch_size, shuffle=True, 125 | num_workers=args.prefetch, pin_memory=True) 126 | train_meta_loader = torch.utils.data.DataLoader( 127 | train_data_meta, batch_size=args.batch_size, shuffle=True, 128 | num_workers=args.prefetch, pin_memory=True) 129 | test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False, 130 | num_workers=args.prefetch, pin_memory=True) 131 | 132 | return train_loader, train_meta_loader, test_loader 133 | 134 | 135 | def build_model(): 136 | model = ResNet32(args.dataset == 'cifar10' and 10 or 100) 137 | 138 | if torch.cuda.is_available(): 139 | model.cuda() 140 | torch.backends.cudnn.benchmark = True 141 | 142 | return model 143 | 144 | def accuracy(output, target, topk=(1,)): 145 | """Computes the precision@k for the specified values of k""" 146 | maxk = max(topk) 147 | batch_size = target.size(0) 148 | 149 | _, pred = output.topk(maxk, 1, True, True) 150 | pred = pred.t() 151 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 152 | 153 | res = [] 154 | for k in topk: 155 | correct_k = correct[:k].view(-1).float().sum(0) 156 | res.append(correct_k.mul_(100.0 / batch_size)) 157 | return res 158 | 159 | 160 | def adjust_learning_rate(optimizer, epochs): 161 | lr = args.lr * ((0.1 ** int(epochs >= 80)) * (0.1 ** int(epochs >= 100))) # For WRN-28-10 162 | for param_group in optimizer.param_groups: 163 | param_group['lr'] = lr 164 | 165 | 166 | 167 | def test(model, test_loader): 168 | model.eval() 169 | correct = 0 170 | test_loss = 0 171 | 172 
| with torch.no_grad(): 173 | for batch_idx, (inputs, targets) in enumerate(test_loader): 174 | inputs, targets = inputs.to(device), targets.to(device) 175 | outputs = model(inputs) 176 | test_loss +=F.cross_entropy(outputs, targets).item() 177 | _, predicted = outputs.max(1) 178 | correct += predicted.eq(targets).sum().item() 179 | 180 | test_loss /= len(test_loader.dataset) 181 | accuracy = 100. * correct / len(test_loader.dataset) 182 | 183 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.4f}%)\n'.format( 184 | test_loss, correct, len(test_loader.dataset), 185 | accuracy)) 186 | 187 | return accuracy 188 | 189 | 190 | def train(train_loader,train_meta_loader,model, vnet,optimizer_model,optimizer_vnet,epoch): 191 | print('\nEpoch: %d' % epoch) 192 | 193 | train_loss = 0 194 | meta_loss = 0 195 | 196 | train_meta_loader_iter = iter(train_meta_loader) 197 | for batch_idx, (inputs, targets) in enumerate(train_loader): 198 | model.train() 199 | inputs, targets = inputs.to(device), targets.to(device) 200 | meta_model = build_model().cuda() 201 | meta_model.load_state_dict(model.state_dict()) 202 | outputs = meta_model(inputs) 203 | 204 | cost = F.cross_entropy(outputs, targets, reduce=False) 205 | cost_v = torch.reshape(cost, (len(cost), 1)) 206 | v_lambda = vnet(cost_v.data) 207 | l_f_meta = torch.sum(cost_v * v_lambda)/len(cost_v) 208 | meta_model.zero_grad() 209 | grads = torch.autograd.grad(l_f_meta, (meta_model.params()), create_graph=True) 210 | meta_lr = args.lr * ((0.1 ** int(epoch >= 80)) * (0.1 ** int(epoch >= 100))) # For ResNet32 211 | meta_model.update_params(lr_inner=meta_lr, source_params=grads) 212 | del grads 213 | 214 | try: 215 | inputs_val, targets_val = next(train_meta_loader_iter) 216 | except StopIteration: 217 | train_meta_loader_iter = iter(train_meta_loader) 218 | inputs_val, targets_val = next(train_meta_loader_iter) 219 | inputs_val, targets_val = inputs_val.to(device), targets_val.to(device) 220 | y_g_hat = meta_model(inputs_val) 221 | l_g_meta = F.cross_entropy(y_g_hat, targets_val) 222 | prec_meta = accuracy(y_g_hat.data, targets_val.data, topk=(1,))[0] 223 | 224 | 225 | optimizer_vnet.zero_grad() 226 | l_g_meta.backward() 227 | optimizer_vnet.step() 228 | 229 | outputs = model(inputs) 230 | cost_w = F.cross_entropy(outputs, targets, reduce=False) 231 | cost_v = torch.reshape(cost_w, (len(cost_w), 1)) 232 | prec_train = accuracy(outputs.data, targets.data, topk=(1,))[0] 233 | 234 | with torch.no_grad(): 235 | w_new = vnet(cost_v) 236 | 237 | loss = torch.sum(cost_v * w_new)/len(cost_v) 238 | 239 | optimizer_model.zero_grad() 240 | loss.backward() 241 | optimizer_model.step() 242 | 243 | 244 | train_loss += loss.item() 245 | meta_loss += l_g_meta.item() 246 | 247 | 248 | if (batch_idx + 1) % 50 == 0: 249 | print('Epoch: [%d/%d]\t' 250 | 'Iters: [%d/%d]\t' 251 | 'Loss: %.4f\t' 252 | 'MetaLoss:%.4f\t' 253 | 'Prec@1 %.2f\t' 254 | 'Prec_meta@1 %.2f' % ( 255 | (epoch + 1), args.epochs, batch_idx + 1, len(train_loader.dataset)/args.batch_size, (train_loss / (batch_idx + 1)), 256 | (meta_loss / (batch_idx + 1)), prec_train, prec_meta)) 257 | 258 | 259 | 260 | 261 | train_loader, train_meta_loader, test_loader = build_dataset() 262 | # create model 263 | model = build_model() 264 | vnet = VNet(1, 100, 1).cuda() 265 | 266 | if args.dataset == 'cifar10': 267 | num_classes = 10 268 | if args.dataset == 'cifar100': 269 | num_classes = 100 270 | 271 | 272 | optimizer_model = torch.optim.SGD(model.params(), args.lr, 273 | momentum=args.momentum, 
weight_decay=args.weight_decay) 274 | optimizer_vnet = torch.optim.Adam(vnet.params(), 1e-3, 275 | weight_decay=1e-4) 276 | 277 | def main(): 278 | best_acc = 0 279 | for epoch in range(args.epochs): 280 | adjust_learning_rate(optimizer_model, epoch) 281 | train(train_loader,train_meta_loader,model, vnet,optimizer_model,optimizer_vnet,epoch) 282 | test_acc = test(model=model, test_loader=test_loader) 283 | if test_acc >= best_acc: 284 | best_acc = test_acc 285 | 286 | print('best accuracy:', best_acc) 287 | 288 | 289 | if __name__ == '__main__': 290 | main() 291 | -------------------------------------------------------------------------------- /NeurIPS2019.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xjtushujun/meta-weight-net/3f1800ff26bc66ceda6075f3fa25f3a7e0665a4d/NeurIPS2019.pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Meta-Weight-Net 2 | NeurIPS'19: Meta-Weight-Net: Learning an Explicit Mapping For Sample Weighting (Official PyTorch implementation for noisy labels). 3 | The implementation for class imbalance is available at https://github.com/xjtushujun/Meta-weight-net_class-imbalance. 4 | 5 | 6 | ================================================================================================================================================================ 7 | 8 | 9 | This is the code for the paper: 10 | [Meta-Weight-Net: Learning an Explicit Mapping For Sample Weighting](https://arxiv.org/abs/1902.07379) 11 | Jun Shu, Qi Xie, Lixuan Yi, Qian Zhao, Sanping Zhou, Zongben Xu, Deyu Meng* 12 | To be presented at [NeurIPS 2019](https://nips.cc/Conferences/2019/). 13 | 14 | If you find this code useful in your research, please cite: 15 | ```bash 16 | @inproceedings{shu2019metaweightnet, 17 | title={Meta-Weight-Net: Learning an Explicit Mapping For Sample Weighting}, 18 | author={Shu, Jun and Xie, Qi and Yi, Lixuan and Zhao, Qian and Zhou, Sanping and Xu, Zongben and Meng, Deyu}, 19 | booktitle={NeurIPS}, 20 | year={2019} 21 | } 22 | ``` 23 | 24 | 25 | ## Setups 26 | The required environment is as below: 27 | 28 | - Linux 29 | - Python 3+ 30 | - PyTorch 0.4.0 31 | - Torchvision 0.2.0 32 | 33 | 34 | ## Running Meta-Weight-Net on benchmark datasets (CIFAR-10 and CIFAR-100) 35 | Here is an example (set `--corruption_type` to `unif`, `flip`, or `flip2` as needed): 36 | ```bash 37 | python train_WRN-28-10_Meta_PGC.py --dataset cifar10 --corruption_type unif --corruption_prob 0.6 38 | ``` 39 | 40 | The default network structure is WRN-28-10; if you want to train the ResNet32 model instead, please adjust the learning rate decay policy accordingly. 41 | 42 | A stable version, MW-Net.py, is also released (a condensed sketch of one of its training iterations is given after the acknowledgements below): 43 | ```bash 44 | python MW-Net.py --dataset cifar10 --corruption_type unif --corruption_prob 0.6 45 | ``` 46 | ## Important Update 47 | 48 | The new code on GitHub (https://github.com/ShiYunyi/Meta-Weight-Net_Code-Optimization) implements MW-Net on top of the newest PyTorch and torchvision versions. It rewrites the optimizer so that non-leaf tensors can be assigned to the model parameters, and therefore does not need to rewrite nn.Module as this version does. Many thanks to Shi Yunyi (2404208668@qq.com)! 49 | 50 | 51 | ## Acknowledgements 52 | We thank the PyTorch implementations of GLC (https://github.com/mmazeika/glc) and learning-to-reweight-examples (https://github.com/danieltan07/learning-to-reweight-examples).
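## Training step at a glance

For reference, the sketch below condenses one MW-Net iteration from `MW-Net.py` into a single function. It is a minimal sketch, not a drop-in replacement: the name `mwnet_step` and the `build_model` argument are illustrative, it assumes the `MetaModule`-based classifier from `resnet.py` (so that `.params()` and `.update_params()` exist), and it uses the newer `reduction='none'` argument of `F.cross_entropy` rather than the `reduce=False` flag used by the PyTorch-0.4 code in this repository.

```python
import torch
import torch.nn.functional as F


def mwnet_step(model, vnet, opt_model, opt_vnet, inputs, targets,
               meta_inputs, meta_targets, meta_lr, build_model):
    # 1) Virtual step: copy the classifier and take one weighted SGD step on the copy.
    meta_model = build_model()                      # fresh copy of the classifier architecture
    meta_model.load_state_dict(model.state_dict())
    cost = F.cross_entropy(meta_model(inputs), targets, reduction='none').view(-1, 1)
    weights = vnet(cost.data)                       # per-sample weights from the current VNet
    weighted_loss = torch.sum(cost * weights) / len(cost)
    grads = torch.autograd.grad(weighted_loss, meta_model.params(), create_graph=True)
    meta_model.update_params(lr_inner=meta_lr, source_params=grads)

    # 2) Meta step: evaluate the virtually-updated classifier on clean meta data and
    #    backpropagate the meta loss through the virtual update into the VNet.
    meta_loss = F.cross_entropy(meta_model(meta_inputs), meta_targets)
    opt_vnet.zero_grad()
    meta_loss.backward()
    opt_vnet.step()

    # 3) Actual step: re-weight the training loss with the updated VNet and
    #    update the real classifier.
    cost = F.cross_entropy(model(inputs), targets, reduction='none').view(-1, 1)
    with torch.no_grad():
        weights = vnet(cost)
    loss = torch.sum(cost * weights) / len(cost)
    opt_model.zero_grad()
    loss.backward()
    opt_model.step()
    return loss.item(), meta_loss.item()
```

The classifier update in step 1 is only "virtual": it exists so that the clean-set meta loss in step 2 can be differentiated with respect to the VNet parameters, which is why `create_graph=True` is passed to `torch.autograd.grad`.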
53 | 54 | 55 | Contact: Jun Shu (xjtushujun@gmail.com); Deyu Meng (dymeng@mail.xjtu.edu.cn). 56 | 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /load_corrupted_data.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | import os.path 4 | import errno 5 | import numpy as np 6 | import sys 7 | import pickle 8 | 9 | 10 | import torch.utils.data as data 11 | from torchvision.datasets.utils import download_url, check_integrity 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.autograd import Variable as V 16 | import wideresnet as wrn 17 | import torchvision.transforms as transforms 18 | 19 | 20 | def uniform_mix_C(mixing_ratio, num_classes): 21 | ''' 22 | returns a linear interpolation of a uniform matrix and an identity matrix 23 | ''' 24 | return mixing_ratio * np.full((num_classes, num_classes), 1 / num_classes) + \ 25 | (1 - mixing_ratio) * np.eye(num_classes) 26 | 27 | def flip_labels_C(corruption_prob, num_classes, seed=1): 28 | ''' 29 | returns a matrix with (1 - corruption_prob) on the diagonals, and corruption_prob 30 | concentrated in only one other entry for each row 31 | ''' 32 | np.random.seed(seed) 33 | C = np.eye(num_classes) * (1 - corruption_prob) 34 | row_indices = np.arange(num_classes) 35 | for i in range(num_classes): 36 | C[i][np.random.choice(row_indices[row_indices != i])] = corruption_prob 37 | return C 38 | 39 | def flip_labels_C_two(corruption_prob, num_classes, seed=1): 40 | ''' 41 | returns a matrix with (1 - corruption_prob) on the diagonals, and corruption_prob 42 | split evenly (corruption_prob / 2 each) across two other entries for each row 43 | ''' 44 | np.random.seed(seed) 45 | C = np.eye(num_classes) * (1 - corruption_prob) 46 | row_indices = np.arange(num_classes) 47 | for i in range(num_classes): 48 | C[i][np.random.choice(row_indices[row_indices != i], 2, replace=False)] = corruption_prob / 2 49 | return C 50 | 51 | 52 | class CIFAR10(data.Dataset): 53 | base_folder = 'cifar-10-batches-py' 54 | url = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 55 | filename = "cifar-10-python.tar.gz" 56 | tgz_md5 = 'c58f30108f718f92721af3b95e74349a' 57 | train_list = [ 58 | ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], 59 | ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], 60 | ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], 61 | ['data_batch_4', '634d18415352ddfa80567beed471001a'], 62 | ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'], 63 | ] 64 | 65 | test_list = [ 66 | ['test_batch', '40351d587109b95175f43aff81a1287e'], 67 | ] 68 | 69 | def __init__(self, root='', train=True, meta=True, num_meta=1000, 70 | corruption_prob=0, corruption_type='unif', transform=None, target_transform=None, 71 | download=False, seed=1): 72 | self.root = root 73 | self.transform = transform 74 | self.target_transform = target_transform 75 | self.train = train # training set or test set 76 | self.meta = meta 77 | self.corruption_prob = corruption_prob 78 | self.num_meta = num_meta 79 | np.random.seed(seed) 80 | if download: 81 | self.download() 82 | 83 | if not self._check_integrity(): 84 | raise RuntimeError('Dataset not found or corrupted.'
+ 85 | ' You can use download=True to download it') 86 | 87 | # now load the picked numpy arrays 88 | if self.train: 89 | self.train_data = [] 90 | self.train_labels = [] 91 | self.train_coarse_labels = [] 92 | for fentry in self.train_list: 93 | f = fentry[0] 94 | file = os.path.join(root, self.base_folder, f) 95 | fo = open(file, 'rb') 96 | if sys.version_info[0] == 2: 97 | entry = pickle.load(fo) 98 | else: 99 | entry = pickle.load(fo, encoding='latin1') 100 | self.train_data.append(entry['data']) 101 | if 'labels' in entry: 102 | self.train_labels += entry['labels'] 103 | img_num_list = [int(self.num_meta/10)] * 10 104 | num_classes = 10 105 | else: 106 | self.train_labels += entry['fine_labels'] 107 | self.train_coarse_labels += entry['coarse_labels'] 108 | img_num_list = [int(self.num_meta/100)] * 100 109 | num_classes = 100 110 | fo.close() 111 | 112 | self.train_data = np.concatenate(self.train_data) 113 | self.train_data = self.train_data.reshape((50000, 3, 32, 32)) 114 | self.train_data = self.train_data.transpose((0, 2, 3, 1)) # convert to HWC 115 | 116 | data_list_val = {} 117 | for j in range(num_classes): 118 | data_list_val[j] = [i for i, label in enumerate(self.train_labels) if label == j] 119 | 120 | 121 | idx_to_meta = [] 122 | idx_to_train = [] 123 | print(img_num_list) 124 | 125 | for cls_idx, img_id_list in data_list_val.items(): 126 | np.random.shuffle(img_id_list) 127 | img_num = img_num_list[int(cls_idx)] 128 | idx_to_meta.extend(img_id_list[:img_num]) 129 | idx_to_train.extend(img_id_list[img_num:]) 130 | 131 | 132 | if meta is True: 133 | self.train_data = self.train_data[idx_to_meta] 134 | self.train_labels = list(np.array(self.train_labels)[idx_to_meta]) 135 | else: 136 | self.train_data = self.train_data[idx_to_train] 137 | self.train_labels = list(np.array(self.train_labels)[idx_to_train]) 138 | if corruption_type == 'hierarchical': 139 | self.train_coarse_labels = list(np.array(self.train_coarse_labels)[idx_to_train]) 140 | 141 | if corruption_type == 'unif': 142 | C = uniform_mix_C(self.corruption_prob, num_classes) 143 | print(C) 144 | self.C = C 145 | elif corruption_type == 'flip': 146 | C = flip_labels_C(self.corruption_prob, num_classes) 147 | print(C) 148 | self.C = C 149 | elif corruption_type == 'flip2': 150 | C = flip_labels_C_two(self.corruption_prob, num_classes) 151 | print(C) 152 | self.C = C 153 | elif corruption_type == 'hierarchical': 154 | assert num_classes == 100, 'You must use CIFAR-100 with the hierarchical corruption.' 155 | coarse_fine = [] 156 | for i in range(20): 157 | coarse_fine.append(set()) 158 | for i in range(len(self.train_labels)): 159 | coarse_fine[self.train_coarse_labels[i]].add(self.train_labels[i]) 160 | for i in range(20): 161 | coarse_fine[i] = list(coarse_fine[i]) 162 | 163 | C = np.eye(num_classes) * (1 - corruption_prob) 164 | 165 | for i in range(20): 166 | tmp = np.copy(coarse_fine[i]) 167 | for j in range(len(tmp)): 168 | tmp2 = np.delete(np.copy(tmp), j) 169 | C[tmp[j], tmp2] += corruption_prob * 1/len(tmp2) 170 | self.C = C 171 | print(C) 172 | elif corruption_type == 'clabels': 173 | net = wrn.WideResNet(40, num_classes, 2, dropRate=0.3).cuda() 174 | model_name = './cifar{}_labeler'.format(num_classes) 175 | net.load_state_dict(torch.load(model_name)) 176 | net.eval() 177 | else: 178 | assert False, "Invalid corruption type '{}' given. 
Must be in {'unif', 'flip', 'hierarchical'}".format(corruption_type) 179 | 180 | 181 | if corruption_type == 'clabels': 182 | mean = [x / 255 for x in [125.3, 123.0, 113.9]] 183 | std = [x / 255 for x in [63.0, 62.1, 66.7]] 184 | 185 | test_transform = transforms.Compose( 186 | [transforms.ToTensor(), transforms.Normalize(mean, std)]) 187 | 188 | # obtain sampling probabilities 189 | sampling_probs = [] 190 | print('Starting labeling') 191 | 192 | for i in range((len(self.train_labels) // 64) + 1): 193 | current = self.train_data[i*64:(i+1)*64] 194 | current = [Image.fromarray(current[i]) for i in range(len(current))] 195 | current = torch.cat([test_transform(current[i]).unsqueeze(0) for i in range(len(current))], dim=0) 196 | 197 | data = V(current).cuda() 198 | logits = net(data) 199 | smax = F.softmax(logits / 5) # temperature of 1 200 | sampling_probs.append(smax.data.cpu().numpy()) 201 | 202 | 203 | sampling_probs = np.concatenate(sampling_probs, 0) 204 | print('Finished labeling 1') 205 | 206 | new_labeling_correct = 0 207 | argmax_labeling_correct = 0 208 | for i in range(len(self.train_labels)): 209 | old_label = self.train_labels[i] 210 | new_label = np.random.choice(num_classes, p=sampling_probs[i]) 211 | self.train_labels[i] = new_label 212 | if old_label == new_label: 213 | new_labeling_correct += 1 214 | if old_label == np.argmax(sampling_probs[i]): 215 | argmax_labeling_correct += 1 216 | print('Finished labeling 2') 217 | print('New labeling accuracy:', new_labeling_correct / len(self.train_labels)) 218 | print('Argmax labeling accuracy:', argmax_labeling_correct / len(self.train_labels)) 219 | else: 220 | for i in range(len(self.train_labels)): 221 | self.train_labels[i] = np.random.choice(num_classes, p=C[self.train_labels[i]]) 222 | self.corruption_matrix = C 223 | 224 | else: 225 | f = self.test_list[0][0] 226 | file = os.path.join(root, self.base_folder, f) 227 | fo = open(file, 'rb') 228 | if sys.version_info[0] == 2: 229 | entry = pickle.load(fo) 230 | else: 231 | entry = pickle.load(fo, encoding='latin1') 232 | self.test_data = entry['data'] 233 | if 'labels' in entry: 234 | self.test_labels = entry['labels'] 235 | else: 236 | self.test_labels = entry['fine_labels'] 237 | fo.close() 238 | self.test_data = self.test_data.reshape((10000, 3, 32, 32)) 239 | self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC 240 | 241 | def __getitem__(self, index): 242 | if self.train: 243 | img, target = self.train_data[index], self.train_labels[index] 244 | else: 245 | img, target = self.test_data[index], self.test_labels[index] 246 | 247 | # doing this so that it is consistent with all other datasets 248 | # to return a PIL Image 249 | img = Image.fromarray(img) 250 | 251 | if self.transform is not None: 252 | img = self.transform(img) 253 | 254 | if self.target_transform is not None: 255 | target = self.target_transform(target) 256 | 257 | return img, target 258 | 259 | def __len__(self): 260 | if self.train: 261 | if self.meta is True: 262 | return self.num_meta 263 | else: 264 | return 50000 - self.num_meta 265 | else: 266 | return 10000 267 | 268 | def _check_integrity(self): 269 | root = self.root 270 | for fentry in (self.train_list + self.test_list): 271 | filename, md5 = fentry[0], fentry[1] 272 | fpath = os.path.join(root, self.base_folder, filename) 273 | if not check_integrity(fpath, md5): 274 | return False 275 | return True 276 | 277 | def download(self): 278 | import tarfile 279 | 280 | if self._check_integrity(): 281 | print('Files already 
downloaded and verified') 282 | return 283 | 284 | root = self.root 285 | download_url(self.url, root, self.filename, self.tgz_md5) 286 | 287 | # extract file 288 | cwd = os.getcwd() 289 | tar = tarfile.open(os.path.join(root, self.filename), "r:gz") 290 | os.chdir(root) 291 | tar.extractall() 292 | tar.close() 293 | os.chdir(cwd) 294 | 295 | 296 | class CIFAR100(CIFAR10): 297 | base_folder = 'cifar-100-python' 298 | url = "http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 299 | filename = "cifar-100-python.tar.gz" 300 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 301 | train_list = [ 302 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 303 | ] 304 | 305 | test_list = [ 306 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 307 | ] 308 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from torch.autograd import Variable 6 | import torch.nn.init as init 7 | 8 | 9 | def to_var(x, requires_grad=True): 10 | if torch.cuda.is_available(): 11 | x = x.cuda() 12 | return Variable(x, requires_grad=requires_grad) 13 | 14 | 15 | class MetaModule(nn.Module): 16 | # adopted from: Adrien Ecoffet https://github.com/AdrienLE 17 | def params(self): 18 | for name, param in self.named_params(self): 19 | yield param 20 | 21 | def named_leaves(self): 22 | return [] 23 | 24 | def named_submodules(self): 25 | return [] 26 | 27 | def named_params(self, curr_module=None, memo=None, prefix=''): 28 | if memo is None: 29 | memo = set() 30 | 31 | if hasattr(curr_module, 'named_leaves'): 32 | for name, p in curr_module.named_leaves(): 33 | if p is not None and p not in memo: 34 | memo.add(p) 35 | yield prefix + ('.' if prefix else '') + name, p 36 | else: 37 | for name, p in curr_module._parameters.items(): 38 | if p is not None and p not in memo: 39 | memo.add(p) 40 | yield prefix + ('.' if prefix else '') + name, p 41 | 42 | for mname, module in curr_module.named_children(): 43 | submodule_prefix = prefix + ('.' if prefix else '') + mname 44 | for name, p in self.named_params(module, memo, submodule_prefix): 45 | yield name, p 46 | 47 | def update_params(self, lr_inner, first_order=False, source_params=None, detach=False): 48 | if source_params is not None: 49 | for tgt, src in zip(self.named_params(self), source_params): 50 | name_t, param_t = tgt 51 | # name_s, param_s = src 52 | # grad = param_s.grad 53 | # name_s, param_s = src 54 | grad = src 55 | if first_order: 56 | grad = to_var(grad.detach().data) 57 | tmp = param_t - lr_inner * grad 58 | self.set_param(self, name_t, tmp) 59 | else: 60 | 61 | for name, param in self.named_params(self): 62 | if not detach: 63 | grad = param.grad 64 | if first_order: 65 | grad = to_var(grad.detach().data) 66 | tmp = param - lr_inner * grad 67 | self.set_param(self, name, tmp) 68 | else: 69 | param = param.detach_() # https://blog.csdn.net/qq_39709535/article/details/81866686 70 | self.set_param(self, name, param) 71 | 72 | def set_param(self, curr_mod, name, param): 73 | if '.' 
in name: 74 | n = name.split('.') 75 | module_name = n[0] 76 | rest = '.'.join(n[1:]) 77 | for name, mod in curr_mod.named_children(): 78 | if module_name == name: 79 | self.set_param(mod, rest, param) 80 | break 81 | else: 82 | setattr(curr_mod, name, param) 83 | 84 | def detach_params(self): 85 | for name, param in self.named_params(self): 86 | self.set_param(self, name, param.detach()) 87 | 88 | def copy(self, other, same_var=False): 89 | for name, param in other.named_params(): 90 | if not same_var: 91 | param = to_var(param.data.clone(), requires_grad=True) 92 | self.set_param(name, param) 93 | 94 | 95 | class MetaLinear(MetaModule): 96 | def __init__(self, *args, **kwargs): 97 | super().__init__() 98 | ignore = nn.Linear(*args, **kwargs) 99 | 100 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 101 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 102 | 103 | def forward(self, x): 104 | return F.linear(x, self.weight, self.bias) 105 | 106 | def named_leaves(self): 107 | return [('weight', self.weight), ('bias', self.bias)] 108 | 109 | 110 | class MetaConv2d(MetaModule): 111 | def __init__(self, *args, **kwargs): 112 | super().__init__() 113 | ignore = nn.Conv2d(*args, **kwargs) 114 | 115 | self.in_channels = ignore.in_channels 116 | self.out_channels = ignore.out_channels 117 | self.stride = ignore.stride 118 | self.padding = ignore.padding 119 | self.dilation = ignore.dilation 120 | self.groups = ignore.groups 121 | self.kernel_size = ignore.kernel_size 122 | 123 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 124 | 125 | if ignore.bias is not None: 126 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 127 | else: 128 | self.register_buffer('bias', None) 129 | 130 | def forward(self, x): 131 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 132 | 133 | def named_leaves(self): 134 | return [('weight', self.weight), ('bias', self.bias)] 135 | 136 | 137 | class MetaConvTranspose2d(MetaModule): 138 | def __init__(self, *args, **kwargs): 139 | super().__init__() 140 | ignore = nn.ConvTranspose2d(*args, **kwargs) 141 | 142 | self.stride = ignore.stride 143 | self.padding = ignore.padding 144 | self.dilation = ignore.dilation 145 | self.groups = ignore.groups 146 | 147 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 148 | 149 | if ignore.bias is not None: 150 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 151 | else: 152 | self.register_buffer('bias', None) 153 | 154 | def forward(self, x, output_size=None): 155 | output_padding = self._output_padding(x, output_size) 156 | return F.conv_transpose2d(x, self.weight, self.bias, self.stride, self.padding, 157 | output_padding, self.groups, self.dilation) 158 | 159 | def named_leaves(self): 160 | return [('weight', self.weight), ('bias', self.bias)] 161 | 162 | 163 | class MetaBatchNorm2d(MetaModule): 164 | def __init__(self, *args, **kwargs): 165 | super().__init__() 166 | ignore = nn.BatchNorm2d(*args, **kwargs) 167 | 168 | self.num_features = ignore.num_features 169 | self.eps = ignore.eps 170 | self.momentum = ignore.momentum 171 | self.affine = ignore.affine 172 | self.track_running_stats = ignore.track_running_stats 173 | 174 | if self.affine: 175 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 176 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 177 
| 178 | if self.track_running_stats: 179 | self.register_buffer('running_mean', torch.zeros(self.num_features)) 180 | self.register_buffer('running_var', torch.ones(self.num_features)) 181 | else: 182 | self.register_parameter('running_mean', None) 183 | self.register_parameter('running_var', None) 184 | 185 | def forward(self, x): 186 | return F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, 187 | self.training or not self.track_running_stats, self.momentum, self.eps) 188 | 189 | def named_leaves(self): 190 | return [('weight', self.weight), ('bias', self.bias)] 191 | 192 | 193 | def _weights_init(m): 194 | classname = m.__class__.__name__ 195 | # print(classname) 196 | if isinstance(m, MetaLinear) or isinstance(m, MetaConv2d): 197 | init.kaiming_normal(m.weight) 198 | 199 | class LambdaLayer(MetaModule): 200 | def __init__(self, lambd): 201 | super(LambdaLayer, self).__init__() 202 | self.lambd = lambd 203 | 204 | def forward(self, x): 205 | return self.lambd(x) 206 | 207 | 208 | class BasicBlock(MetaModule): 209 | expansion = 1 210 | 211 | def __init__(self, in_planes, planes, stride=1, option='A'): 212 | super(BasicBlock, self).__init__() 213 | self.conv1 = MetaConv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 214 | self.bn1 = MetaBatchNorm2d(planes) 215 | self.conv2 = MetaConv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 216 | self.bn2 = MetaBatchNorm2d(planes) 217 | 218 | self.shortcut = nn.Sequential() 219 | if stride != 1 or in_planes != planes: 220 | if option == 'A': 221 | """ 222 | For CIFAR10 ResNet paper uses option A. 223 | """ 224 | self.shortcut = LambdaLayer(lambda x: 225 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 226 | elif option == 'B': 227 | self.shortcut = nn.Sequential( 228 | MetaConv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 229 | MetaBatchNorm2d(self.expansion * planes) 230 | ) 231 | 232 | def forward(self, x): 233 | out = F.relu(self.bn1(self.conv1(x))) 234 | out = self.bn2(self.conv2(out)) 235 | out += self.shortcut(x) 236 | out = F.relu(out) 237 | return out 238 | 239 | 240 | class ResNet32(MetaModule): 241 | def __init__(self, num_classes, block=BasicBlock, num_blocks=[5, 5, 5]): 242 | super(ResNet32, self).__init__() 243 | self.in_planes = 16 244 | 245 | self.conv1 = MetaConv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 246 | self.bn1 = MetaBatchNorm2d(16) 247 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 248 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 249 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 250 | self.linear = MetaLinear(64, num_classes) 251 | 252 | self.apply(_weights_init) 253 | 254 | def _make_layer(self, block, planes, num_blocks, stride): 255 | strides = [stride] + [1]*(num_blocks-1) 256 | layers = [] 257 | for stride in strides: 258 | layers.append(block(self.in_planes, planes, stride)) 259 | self.in_planes = planes * block.expansion 260 | 261 | return nn.Sequential(*layers) 262 | 263 | def forward(self, x): 264 | out = F.relu(self.bn1(self.conv1(x))) 265 | out = self.layer1(out) 266 | out = self.layer2(out) 267 | out = self.layer3(out) 268 | out = F.avg_pool2d(out, out.size()[3]) 269 | out = out.view(out.size(0), -1) 270 | out = self.linear(out) 271 | return out 272 | 273 | 274 | 275 | class VNet(MetaModule): 276 | def __init__(self, input, hidden1, output): 277 | super(VNet, self).__init__() 278 | self.linear1 
= MetaLinear(input, hidden1) 279 | self.relu1 = nn.ReLU(inplace=True) 280 | self.linear2 = MetaLinear(hidden1, output) 281 | # self.linear3 = MetaLinear(hidden2, output) 282 | 283 | def forward(self, x): 284 | x = self.linear1(x) 285 | x = self.relu1(x) 286 | # x = self.linear2(x) 287 | # x = self.relu1(x) 288 | out = self.linear2(x) 289 | return F.sigmoid(out) 290 | 291 | 292 | 293 | 294 | 295 | 296 | 297 | -------------------------------------------------------------------------------- /train_WRN-28-10_Meta_PGC.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import argparse 4 | import os 5 | import shutil 6 | import time 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.optim 14 | import torch.utils.data 15 | import torchvision.transforms as transforms 16 | import torchvision.datasets as datasets 17 | from torch.autograd import Variable 18 | from torch.utils.data.sampler import SubsetRandomSampler 19 | import matplotlib.pyplot as plt 20 | import sklearn.metrics as sm 21 | import pandas as pd 22 | import sklearn.metrics as sm 23 | import random 24 | import numpy as np 25 | 26 | from wideresnet import WideResNet, VNet 27 | from resnet import ResNet32,VNet 28 | from load_corrupted_data import CIFAR10, CIFAR100 29 | 30 | parser = argparse.ArgumentParser(description='PyTorch WideResNet Training') 31 | parser.add_argument('--dataset', default='cifar10', type=str, 32 | help='dataset (cifar10 [default] or cifar100)') 33 | parser.add_argument('--corruption_prob', type=float, default=0.4, 34 | help='label noise') 35 | parser.add_argument('--corruption_type', '-ctype', type=str, default='unif', 36 | help='Type of corruption ("unif" or "flip" or "flip2").') 37 | parser.add_argument('--num_meta', type=int, default=1000) 38 | parser.add_argument('--epochs', default=60, type=int, 39 | help='number of total epochs to run') 40 | parser.add_argument('--iters', default=20000, type=int, 41 | help='number of total iters to run') 42 | parser.add_argument('--start-epoch', default=0, type=int, 43 | help='manual epoch number (useful on restarts)') 44 | parser.add_argument('-b', '--batch-size', default=100, type=int, 45 | help='mini-batch size (default: 100)') 46 | parser.add_argument('--lr', '--learning-rate', default=1e-1, type=float, 47 | help='initial learning rate') 48 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum') 49 | parser.add_argument('--nesterov', default=True, type=bool, help='nesterov momentum') 50 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, 51 | help='weight decay (default: 5e-4)') 52 | parser.add_argument('--print-freq', '-p', default=10, type=int, 53 | help='print frequency (default: 10)') 54 | parser.add_argument('--layers', default=28, type=int, 55 | help='total number of layers (default: 28)') 56 | parser.add_argument('--widen-factor', default=10, type=int, 57 | help='widen factor (default: 10)') 58 | parser.add_argument('--droprate', default=0, type=float, 59 | help='dropout probability (default: 0.0)') 60 | parser.add_argument('--no-augment', dest='augment', action='store_false', 61 | help='whether to use standard augmentation (default: True)') 62 | parser.add_argument('--resume', default='', type=str, 63 | help='path to latest checkpoint (default: none)') 64 | parser.add_argument('--name', default='WideResNet-28-10', type=str, 65 | 
help='name of experiment') 66 | parser.add_argument('--seed', type=int, default=1) 67 | parser.add_argument('--prefetch', type=int, default=0, help='Pre-fetching threads.') 68 | parser.set_defaults(augment=True) 69 | 70 | #os.environ['CUD_DEVICE_ORDER'] = "1" 71 | #ids = [1] 72 | 73 | 74 | best_prec1 = 0 75 | 76 | #use_cuda = True 77 | #device = torch.device("cuda" if use_cuda else "cpu") 78 | 79 | 80 | def main(): 81 | global args, best_prec1 82 | args = parser.parse_args() 83 | print() 84 | print(args) 85 | 86 | train_loader, train_meta_loader, test_loader = build_dataset() 87 | # create model 88 | model = build_model() 89 | optimizer_a = torch.optim.SGD(model.params(), args.lr, 90 | momentum=args.momentum, nesterov=args.nesterov, 91 | weight_decay=args.weight_decay) 92 | 93 | 94 | vnet = VNet(1, 100, 1).cuda() 95 | 96 | optimizer_c = torch.optim.SGD(vnet.params(), 1e-3, 97 | momentum=args.momentum, nesterov=args.nesterov, 98 | weight_decay=args.weight_decay) 99 | 100 | cudnn.benchmark = True 101 | 102 | # define loss function (criterion) and optimizer 103 | criterion = nn.CrossEntropyLoss().cuda() 104 | 105 | model_loss = [] 106 | meta_model_loss = [] 107 | smoothing_alpha = 0.9 108 | 109 | meta_l = 0 110 | net_l = 0 111 | accuracy_log = [] 112 | train_acc = [] 113 | 114 | for iters in range(args.iters): 115 | adjust_learning_rate(optimizer_a, iters + 1) 116 | # adjust_learning_rate(optimizer_c, iters + 1) 117 | model.train() 118 | 119 | input, target = next(iter(train_loader)) 120 | input_var = to_var(input, requires_grad=False) 121 | target_var = to_var(target, requires_grad=False) 122 | 123 | meta_model = build_model() 124 | 125 | meta_model.load_state_dict(model.state_dict()) 126 | y_f_hat = meta_model(input_var) 127 | cost = F.cross_entropy(y_f_hat, target_var, reduce=False) 128 | cost_v = torch.reshape(cost, (len(cost), 1)) 129 | 130 | 131 | v_lambda = vnet(cost_v.data) 132 | 133 | norm_c = torch.sum(v_lambda) 134 | 135 | if norm_c != 0: 136 | v_lambda_norm = v_lambda / norm_c 137 | else: 138 | v_lambda_norm = v_lambda 139 | 140 | l_f_meta = torch.sum(cost_v * v_lambda_norm) 141 | meta_model.zero_grad() 142 | grads = torch.autograd.grad(l_f_meta,(meta_model.params()),create_graph=True) 143 | meta_lr = args.lr * ((0.1 ** int(iters >= 18000)) * (0.1 ** int(iters >= 19000))) # For WRN-28-10 144 | #meta_lr = args.lr * ((0.1 ** int(iters >= 20000)) * (0.1 ** int(iters >= 25000))) # For ResNet32 145 | meta_model.update_params(lr_inner=meta_lr,source_params=grads) 146 | del grads 147 | 148 | 149 | 150 | input_validation, target_validation = next(iter(train_meta_loader)) 151 | input_validation_var = to_var(input_validation, requires_grad=False) 152 | target_validation_var = to_var(target_validation.type(torch.LongTensor), requires_grad=False) 153 | 154 | y_g_hat = meta_model(input_validation_var) 155 | l_g_meta = F.cross_entropy(y_g_hat, target_validation_var) 156 | prec_meta = accuracy(y_g_hat.data, target_validation_var.data, topk=(1,))[0] 157 | 158 | 159 | optimizer_c.zero_grad() 160 | l_g_meta.backward() 161 | optimizer_c.step() 162 | 163 | 164 | y_f = model(input_var) 165 | cost_w = F.cross_entropy(y_f, target_var, reduce=False) 166 | cost_v = torch.reshape(cost_w, (len(cost_w), 1)) 167 | prec_train = accuracy(y_f.data, target_var.data, topk=(1,))[0] 168 | 169 | 170 | with torch.no_grad(): 171 | w_new = vnet(cost_v) 172 | norm_v = torch.sum(w_new) 173 | 174 | if norm_v != 0: 175 | w_v = w_new / norm_v 176 | else: 177 | w_v = w_new 178 | 179 | l_f = torch.sum(cost_v * w_v) 180 | 181 | 
182 | optimizer_a.zero_grad() 183 | l_f.backward() 184 | optimizer_a.step() 185 | 186 | meta_l = smoothing_alpha * meta_l + (1 - smoothing_alpha) * l_g_meta.item() 187 | meta_model_loss.append(meta_l / (1 - smoothing_alpha ** (iters + 1))) 188 | 189 | net_l = smoothing_alpha * net_l + (1 - smoothing_alpha) * l_f.item() 190 | model_loss.append(net_l / (1 - smoothing_alpha ** (iters + 1))) 191 | 192 | 193 | if (iters + 1) % 100 == 0: 194 | print('Epoch: [%d/%d]\t' 195 | 'Iters: [%d/%d]\t' 196 | 'Loss: %.4f\t' 197 | 'MetaLoss:%.4f\t' 198 | 'Prec@1 %.2f\t' 199 | 'Prec_meta@1 %.2f' % ( 200 | (iters + 1) // 500 + 1, args.epochs, iters + 1, args.iters, model_loss[iters], 201 | meta_model_loss[iters], prec_train, prec_meta)) 202 | 203 | losses_test = AverageMeter() 204 | top1_test = AverageMeter() 205 | model.eval() 206 | 207 | 208 | for i, (input_test, target_test) in enumerate(test_loader): 209 | input_test_var = to_var(input_test, requires_grad=False) 210 | target_test_var = to_var(target_test, requires_grad=False) 211 | 212 | # compute output 213 | with torch.no_grad(): 214 | output_test = model(input_test_var) 215 | loss_test = criterion(output_test, target_test_var) 216 | prec_test = accuracy(output_test.data, target_test_var.data, topk=(1,))[0] 217 | 218 | losses_test.update(loss_test.data.item(), input_test_var.size(0)) 219 | top1_test.update(prec_test.item(), input_test_var.size(0)) 220 | 221 | print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1_test)) 222 | 223 | accuracy_log.append(np.array([iters, top1_test.avg])[None]) 224 | train_acc.append(np.array([iters, prec_train])[None]) 225 | 226 | best_prec1 = max(top1_test.avg, best_prec1) 227 | 228 | #np.save('meta_model_loss_%s_%s.npy' % (args.dataset, args.label_corrupt_prob), meta_model_loss) 229 | #np.save('model_loss_%s_%s.npy' % (args.dataset, args.label_corrupt_prob), model_loss) 230 | fig, axes = plt.subplots(1, 3, figsize=(13, 5)) 231 | ax1, ax2, ax3 = axes.ravel() 232 | 233 | ax1.plot(meta_model_loss, label='meta_model_loss') 234 | ax1.plot(model_loss, label='model_loss') 235 | ax1.set_ylabel("Losses") 236 | ax1.set_xlabel("Iteration") 237 | ax1.legend() 238 | 239 | acc_log = np.concatenate(accuracy_log, axis=0) 240 | train_acc_log = np.concatenate(train_acc, axis=0) 241 | #np.save('L2SPL_train_acc.npy', train_acc_log) 242 | #np.save('L2SPL_val_acc.npy', acc_log) 243 | # lr_log = np.concatenate(lr_log, axis=0) 244 | 245 | ax2.plot(acc_log[:, 0], acc_log[:, 1]) 246 | ax2.set_ylabel('Accuracy') 247 | ax2.set_xlabel('Iteration') 248 | 249 | ax3.plot(train_acc_log[:, 0], train_acc_log[:, 1]) 250 | ax3.set_ylabel('Accuracy') 251 | ax3.set_xlabel('Iteration') 252 | 253 | plt.show() 254 | 255 | 256 | 257 | def build_dataset(): 258 | kwargs = {'num_workers': 0, 'pin_memory': True} 259 | # assert (args.dataset == 'cifar10' or args.dataset == 'cifar100') 260 | normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]], 261 | std=[x / 255.0 for x in [63.0, 62.1, 66.7]]) 262 | if args.augment: 263 | train_transform = transforms.Compose([ 264 | transforms.ToTensor(), 265 | transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), 266 | (4, 4, 4, 4), mode='reflect').squeeze()), 267 | transforms.ToPILImage(), 268 | transforms.RandomCrop(32), 269 | transforms.RandomHorizontalFlip(), 270 | transforms.ToTensor(), 271 | normalize, 272 | ]) 273 | else: 274 | train_transform = transforms.Compose([ 275 | transforms.ToTensor(), 276 | normalize, 277 | ]) 278 | test_transform = transforms.Compose([ 279 | transforms.ToTensor(), 280 | normalize 
281 | ]) 282 | 283 | if args.dataset == 'cifar10': 284 | train_data_meta = CIFAR10( 285 | root='../data', train=True, meta=True, num_meta=args.num_meta, corruption_prob=args.corruption_prob, 286 | corruption_type=args.corruption_type, transform=train_transform, download=True) 287 | train_data = CIFAR10( 288 | root='../data', train=True, meta=False, num_meta=args.num_meta, corruption_prob=args.corruption_prob, 289 | corruption_type=args.corruption_type, transform=train_transform, download=True, seed=args.seed) 290 | test_data = CIFAR10(root='../data', train=False, transform=test_transform, download=True) 291 | 292 | 293 | elif args.dataset == 'cifar100': 294 | train_data_meta = CIFAR100( 295 | root='../data', train=True, meta=True, num_meta=args.num_meta, corruption_prob=args.corruption_prob, 296 | corruption_type=args.corruption_type, transform=train_transform, download=True) 297 | train_data = CIFAR100( 298 | root='../data', train=True, meta=False, num_meta=args.num_meta, corruption_prob=args.corruption_prob, 299 | corruption_type=args.corruption_type, transform=train_transform, download=True, seed=args.seed) 300 | test_data = CIFAR100(root='../data', train=False, transform=test_transform, download=True) 301 | 302 | 303 | train_loader = torch.utils.data.DataLoader( 304 | train_data, batch_size=args.batch_size, shuffle=True, 305 | num_workers=args.prefetch, pin_memory=True) 306 | train_meta_loader = torch.utils.data.DataLoader( 307 | train_data_meta, batch_size=args.batch_size, shuffle=True, 308 | num_workers=args.prefetch, pin_memory=True) 309 | test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False, 310 | num_workers=args.prefetch, pin_memory=True) 311 | 312 | return train_loader, train_meta_loader, test_loader 313 | 314 | 315 | def build_model(): 316 | # model = ResNet32(args.dataset == 'cifar10' and 10 or 100) 317 | model = WideResNet(args.layers, args.dataset == 'cifar10' and 10 or 100, 318 | args.widen_factor, dropRate=args.droprate) 319 | # weights_init(model) 320 | 321 | # print('Number of model parameters: {}'.format( 322 | # sum([p.data.nelement() for p in model.params()]))) 323 | 324 | if torch.cuda.is_available(): 325 | model.cuda() 326 | torch.backends.cudnn.benchmark = True 327 | 328 | return model 329 | 330 | 331 | 332 | 333 | def to_var(x, requires_grad=True): 334 | if torch.cuda.is_available(): 335 | x = x.cuda() 336 | return Variable(x, requires_grad=requires_grad) 337 | 338 | 339 | def adjust_learning_rate(optimizer, iters): 340 | 341 | lr = args.lr * ((0.1 ** int(iters >= 18000)) * (0.1 ** int(iters >= 19000))) # For WRN-28-10 342 | #lr = args.lr * ((0.1 ** int(iters >= 20000)) * (0.1 ** int(iters >= 25000))) # For ResNet32 343 | # log to TensorBoard 344 | for param_group in optimizer.param_groups: 345 | param_group['lr'] = lr 346 | 347 | 348 | class AverageMeter(object): 349 | """Computes and stores the average and current value""" 350 | 351 | def __init__(self): 352 | self.reset() 353 | 354 | def reset(self): 355 | self.val = 0 356 | self.avg = 0 357 | self.sum = 0 358 | self.count = 0 359 | 360 | def update(self, val, n=1): 361 | self.val = val 362 | self.sum += val * n 363 | self.count += n 364 | self.avg = self.sum / self.count 365 | 366 | 367 | def accuracy(output, target, topk=(1,)): 368 | """Computes the precision@k for the specified values of k""" 369 | maxk = max(topk) 370 | batch_size = target.size(0) 371 | 372 | _, pred = output.topk(maxk, 1, True, True) 373 | pred = pred.t() 374 | correct = pred.eq(target.view(1, 
-1).expand_as(pred)) 375 | 376 | res = [] 377 | for k in topk: 378 | correct_k = correct[:k].view(-1).float().sum(0) 379 | res.append(correct_k.mul_(100.0 / batch_size)) 380 | return res 381 | 382 | 383 | 384 | if __name__ == '__main__': 385 | main() 386 | -------------------------------------------------------------------------------- /wideresnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from torch.autograd import Variable 6 | import torch.nn.init as init 7 | 8 | 9 | def to_var(x, requires_grad=True): 10 | if torch.cuda.is_available(): 11 | x = x.cuda() 12 | return Variable(x, requires_grad=requires_grad) 13 | 14 | 15 | class MetaModule(nn.Module): 16 | # adopted from: Adrien Ecoffet https://github.com/AdrienLE 17 | def params(self): 18 | for name, param in self.named_params(self): 19 | yield param 20 | 21 | def named_leaves(self): 22 | return [] 23 | 24 | def named_submodules(self): 25 | return [] 26 | 27 | def named_params(self, curr_module=None, memo=None, prefix=''): 28 | if memo is None: 29 | memo = set() 30 | 31 | if hasattr(curr_module, 'named_leaves'): 32 | for name, p in curr_module.named_leaves(): 33 | if p is not None and p not in memo: 34 | memo.add(p) 35 | yield prefix + ('.' if prefix else '') + name, p 36 | else: 37 | for name, p in curr_module._parameters.items(): 38 | if p is not None and p not in memo: 39 | memo.add(p) 40 | yield prefix + ('.' if prefix else '') + name, p 41 | 42 | for mname, module in curr_module.named_children(): 43 | submodule_prefix = prefix + ('.' if prefix else '') + mname 44 | for name, p in self.named_params(module, memo, submodule_prefix): 45 | yield name, p 46 | 47 | def update_params(self, lr_inner, first_order=False, source_params=None, detach=False): 48 | if source_params is not None: 49 | for tgt, src in zip(self.named_params(self), source_params): 50 | name_t, param_t = tgt 51 | # name_s, param_s = src 52 | # grad = param_s.grad 53 | # name_s, param_s = src 54 | grad = src 55 | if first_order: 56 | grad = to_var(grad.detach().data) 57 | tmp = param_t - lr_inner * grad 58 | self.set_param(self, name_t, tmp) 59 | else: 60 | 61 | for name, param in self.named_params(self): 62 | if not detach: 63 | grad = param.grad 64 | if first_order: 65 | grad = to_var(grad.detach().data) 66 | tmp = param - lr_inner * grad 67 | self.set_param(self, name, tmp) 68 | else: 69 | param = param.detach_() # https://blog.csdn.net/qq_39709535/article/details/81866686 70 | self.set_param(self, name, param) 71 | 72 | def set_param(self, curr_mod, name, param): 73 | if '.' 
in name: 74 | n = name.split('.') 75 | module_name = n[0] 76 | rest = '.'.join(n[1:]) 77 | for name, mod in curr_mod.named_children(): 78 | if module_name == name: 79 | self.set_param(mod, rest, param) 80 | break 81 | else: 82 | setattr(curr_mod, name, param) 83 | 84 | def detach_params(self): 85 | for name, param in self.named_params(self): 86 | self.set_param(self, name, param.detach()) 87 | 88 | def copy(self, other, same_var=False): 89 | for name, param in other.named_params(): 90 | if not same_var: 91 | param = to_var(param.data.clone(), requires_grad=True) 92 | self.set_param(name, param) 93 | 94 | 95 | class MetaLinear(MetaModule): 96 | def __init__(self, *args, **kwargs): 97 | super().__init__() 98 | ignore = nn.Linear(*args, **kwargs) 99 | 100 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 101 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 102 | 103 | def forward(self, x): 104 | return F.linear(x, self.weight, self.bias) 105 | 106 | def named_leaves(self): 107 | return [('weight', self.weight), ('bias', self.bias)] 108 | 109 | 110 | class MetaConv2d(MetaModule): 111 | def __init__(self, *args, **kwargs): 112 | super().__init__() 113 | ignore = nn.Conv2d(*args, **kwargs) 114 | 115 | self.in_channels = ignore.in_channels 116 | self.out_channels = ignore.out_channels 117 | self.stride = ignore.stride 118 | self.padding = ignore.padding 119 | self.dilation = ignore.dilation 120 | self.groups = ignore.groups 121 | self.kernel_size = ignore.kernel_size 122 | 123 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 124 | 125 | if ignore.bias is not None: 126 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 127 | else: 128 | self.register_buffer('bias', None) 129 | 130 | def forward(self, x): 131 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 132 | 133 | def named_leaves(self): 134 | return [('weight', self.weight), ('bias', self.bias)] 135 | 136 | 137 | class MetaConvTranspose2d(MetaModule): 138 | def __init__(self, *args, **kwargs): 139 | super().__init__() 140 | ignore = nn.ConvTranspose2d(*args, **kwargs) 141 | 142 | self.stride = ignore.stride 143 | self.padding = ignore.padding 144 | self.dilation = ignore.dilation 145 | self.groups = ignore.groups 146 | 147 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 148 | 149 | if ignore.bias is not None: 150 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 151 | else: 152 | self.register_buffer('bias', None) 153 | 154 | def forward(self, x, output_size=None): 155 | output_padding = self._output_padding(x, output_size) 156 | return F.conv_transpose2d(x, self.weight, self.bias, self.stride, self.padding, 157 | output_padding, self.groups, self.dilation) 158 | 159 | def named_leaves(self): 160 | return [('weight', self.weight), ('bias', self.bias)] 161 | 162 | 163 | class MetaBatchNorm2d(MetaModule): 164 | def __init__(self, *args, **kwargs): 165 | super().__init__() 166 | ignore = nn.BatchNorm2d(*args, **kwargs) 167 | 168 | self.num_features = ignore.num_features 169 | self.eps = ignore.eps 170 | self.momentum = ignore.momentum 171 | self.affine = ignore.affine 172 | self.track_running_stats = ignore.track_running_stats 173 | 174 | if self.affine: 175 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 176 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 177 
| 178 | if self.track_running_stats: 179 | self.register_buffer('running_mean', torch.zeros(self.num_features)) 180 | self.register_buffer('running_var', torch.ones(self.num_features)) 181 | else: 182 | self.register_parameter('running_mean', None) 183 | self.register_parameter('running_var', None) 184 | 185 | def forward(self, x): 186 | return F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, 187 | self.training or not self.track_running_stats, self.momentum, self.eps) 188 | 189 | def named_leaves(self): 190 | return [('weight', self.weight), ('bias', self.bias)] 191 | 192 | 193 | class MetaBasicBlock(MetaModule): 194 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 195 | super(MetaBasicBlock, self).__init__() 196 | 197 | self.bn1 = MetaBatchNorm2d(in_planes) 198 | self.relu1 = nn.ReLU(inplace=True) 199 | self.conv1 = MetaConv2d(in_planes, out_planes, kernel_size=3, stride=stride, 200 | padding=1, bias=False) 201 | self.bn2 = MetaBatchNorm2d(out_planes) 202 | self.relu2 = nn.ReLU(inplace=True) 203 | self.conv2 = MetaConv2d(out_planes, out_planes, kernel_size=3, stride=1, 204 | padding=1, bias=False) 205 | self.droprate = dropRate 206 | self.equalInOut = (in_planes == out_planes) 207 | self.convShortcut = (not self.equalInOut) and MetaConv2d(in_planes, out_planes, kernel_size=1, stride=stride, 208 | padding=0, bias=False) or None 209 | def forward(self, x): 210 | if not self.equalInOut: 211 | x = self.relu1(self.bn1(x)) 212 | else: 213 | out = self.relu1(self.bn1(x)) 214 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 215 | if self.droprate > 0: 216 | out = F.dropout(out, p=self.droprate, training=self.training) 217 | out = self.conv2(out) 218 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 219 | 220 | 221 | class MetaNetworkBlock(MetaModule): 222 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 223 | super(MetaNetworkBlock, self).__init__() 224 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 225 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 226 | layers = [] 227 | for i in range(int(nb_layers)): 228 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 229 | return nn.Sequential(*layers) 230 | def forward(self, x): 231 | return self.layer(x) 232 | 233 | class WideResNet(MetaModule): 234 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): 235 | super(WideResNet, self).__init__() 236 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 237 | assert((depth - 4) % 6 == 0) 238 | n = (depth - 4) / 6 239 | block = MetaBasicBlock 240 | # 1st conv before any network block 241 | self.conv1 = MetaConv2d(3, nChannels[0], kernel_size=3, stride=1, 242 | padding=1, bias=False) 243 | # 1st block 244 | self.block1 = MetaNetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 245 | # 2nd block 246 | self.block2 = MetaNetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 247 | # 3rd block 248 | self.block3 = MetaNetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 249 | # global average pooling and classifier 250 | self.bn1 = MetaBatchNorm2d(nChannels[3]) 251 | self.relu = nn.ReLU(inplace=True) 252 | self.fc = MetaLinear(nChannels[3], num_classes) 253 | self.nChannels = nChannels[3] 254 | 255 | for m in self.modules(): 256 | if isinstance(m, MetaConv2d): 257 | n = m.kernel_size[0] * 
m.kernel_size[1] * m.out_channels 258 | m.weight.data.normal_(0, math.sqrt(2. / n)) 259 | elif isinstance(m, MetaBatchNorm2d): 260 | m.weight.data.fill_(1) 261 | m.bias.data.zero_() 262 | elif isinstance(m, MetaLinear): 263 | m.bias.data.zero_() 264 | def forward(self, x): 265 | out = self.conv1(x) 266 | out = self.block1(out) 267 | out = self.block2(out) 268 | out = self.block3(out) 269 | out = self.relu(self.bn1(out)) 270 | out = F.avg_pool2d(out, 8) 271 | out = out.view(-1, self.nChannels) 272 | return self.fc(out) 273 | 274 | 275 | 276 | class VNet(MetaModule): 277 | def __init__(self, input, hidden, output): 278 | super(VNet, self).__init__() 279 | self.linear1 = MetaLinear(input, hidden) 280 | self.relu = nn.ReLU(inplace=True) 281 | self.linear2 = MetaLinear(hidden, output) 282 | 283 | 284 | 285 | def forward(self, x): 286 | x = self.linear1(x) 287 | x = self.relu(x) 288 | out = self.linear2(x) 289 | return F.sigmoid(out) 290 | --------------------------------------------------------------------------------
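The `VNet` defined above is the entire weighting function of Meta-Weight-Net: a one-hidden-layer MLP that maps each per-sample loss to a weight in (0, 1). Below is a minimal, self-contained usage sketch; the batch size, class count, and the `reduction='none'` argument are illustrative (the training scripts in this repository pass the losses to `VNet` in exactly this `(batch, 1)` shape, but with the older `reduce=False` flag of PyTorch 0.4).

```python
import torch
import torch.nn.functional as F

from wideresnet import VNet  # the VNet in resnet.py behaves the same way

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vnet = VNet(1, 100, 1)                       # per-sample loss (scalar) -> weight in (0, 1)

logits = torch.randn(8, 10, device=device)   # dummy batch: 8 samples, 10 classes
targets = torch.randint(0, 10, (8,), device=device)

per_sample_loss = F.cross_entropy(logits, targets, reduction='none')
weights = vnet(per_sample_loss.detach().view(-1, 1))          # shape (8, 1)
weighted_loss = torch.sum(per_sample_loss.view(-1, 1) * weights) / len(weights)

print(weights.squeeze(1))
print(weighted_loss.item())
```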