├── LICENSE
├── README.md
├── main.py
├── meta.py
├── model.py
├── noisy_long_tail_CIFAR.py
└── utils.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 ShiYunyi
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Meta-Weight-Net_Code-Optimization
2 | A code framework that uses PyTorch to implement meta-learning, taking Meta-Weight-Net as an example.
3 | 
4 | ---
5 | 
6 | Thanks to a simple trick, meta-learning and meta-networks become plug-and-play: the meta-learning
7 | algorithm can be applied directly to an existing PyTorch model without rewriting it.
8 | 
9 | This code takes Meta-Weight-Net ([Meta-Weight-Net: Learning an Explicit Mapping For Sample Weighting](https://arxiv.org/abs/1902.07379))
10 | as an example to show how to use this trick. It rewrites the optimizer so that it assigns non-leaf tensors to the model's parameters.
11 | See `meta.py` and lines 90-120 of `main.py` for details.
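12 | 
13 | In short, each meta step clones the model, takes one differentiable "virtual" SGD step with
14 | `MetaSGD`, and then updates the meta-network through that step on clean meta data. A condensed
15 | sketch of one iteration (setup of nets, optimizers and data omitted; see `main.py` for the
16 | runnable version):
17 | ```python
18 | # weighted loss on the noisy training batch, with weights predicted by the meta-net
19 | pseudo_net = ResNet32(10).to(device)
20 | pseudo_net.load_state_dict(net.state_dict())
21 | loss_vector = functional.cross_entropy(pseudo_net(inputs), labels, reduction='none').reshape(-1, 1)
22 | pseudo_loss = torch.mean(meta_net(loss_vector.data) * loss_vector)
23 | 
24 | # virtual update: MetaSGD writes the updated, non-leaf tensors back into pseudo_net,
25 | # keeping the graph from meta_net to the new weights alive
26 | pseudo_grads = torch.autograd.grad(pseudo_loss, pseudo_net.parameters(), create_graph=True)
27 | pseudo_optimizer = MetaSGD(pseudo_net, pseudo_net.parameters(), lr=lr)
28 | pseudo_optimizer.load_state_dict(optimizer.state_dict())
29 | pseudo_optimizer.meta_step(pseudo_grads)
30 | 
31 | # meta-loss on clean meta data backpropagates through the virtual step into meta_net
32 | meta_loss = criterion(pseudo_net(meta_inputs), meta_labels)
33 | meta_optimizer.zero_grad()
34 | meta_loss.backward()
35 | meta_optimizer.step()
36 | ```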
37 | ## Environment
38 | - python 3.8
39 | - pytorch 1.9.0
40 | - torchvision 0.10.0
41 | 
42 | `noisy_long_tail_CIFAR.py` generates noisy and long-tailed CIFAR datasets by calling `torchvision.datasets`. Because
43 | some class attribute names have changed across torchvision releases, earlier versions may raise attribute errors;
44 | renaming the corresponding attributes fixes this.
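45 | For example, old torchvision releases exposed the CIFAR arrays as `train_data`/`train_labels`
46 | rather than today's `data`/`targets`. A hypothetical shim for such a version (adapt it to
47 | whatever attribute your installed release actually complains about):
48 | ```python
49 | if not hasattr(train_dataset, 'data'):  # old torchvision
50 |     train_dataset.data = train_dataset.train_data
51 |     train_dataset.targets = train_dataset.train_labels
52 | ```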
53 | ## Running this example
54 | ResNet32 on CIFAR10-LT with an imbalanced factor of 50:
55 | ```
56 | python main.py --imbalanced_factor 50
57 | ```
58 | ResNet32 on CIFAR10 with 40% uniform noise:
59 | ```
60 | python main.py --meta_lr 1e-3 --meta_weight_decay 1e-4 --corruption_type uniform --corruption_ratio 0.4
61 | ```
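62 | The long-tailed split keeps `n_max / factor^(i / (C - 1))` images of class `i` (see
63 | `build_dataloader` in `noisy_long_tail_CIFAR.py`) and then shuffles the counts across classes.
64 | For CIFAR10 with 1000 meta images and an imbalanced factor of 50, the per-class counts are:
65 | ```python
66 | sample_num = (50000 - 1000) // 10                        # 4900 images per class before decay
67 | counts = [int(sample_num / 50 ** (i / 9)) for i in range(10)]
68 | # -> [4900, 3172, 2054, 1330, 861, 557, 361, 233, 151, 98]
69 | ```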
70 | ## Results (CIFAR10)
71 | |Data Setting|Test Accuracy|
72 | |:----------:|:-----------:|
73 | |imbalanced factor 50|80.43%|
74 | |imbalanced factor 100|75.92%|
75 | |imbalanced factor 200|68.89%|
76 | |40% uniform noise|87.83%|
77 | ## Acknowledgements
78 | Thanks to the original Meta-Weight-Net code (https://github.com/xjtushujun/meta-weight-net).
79 | 
80 | Contact: Shi Yunyi (2404208668@qq.com)
81 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch.optim
3 | # from torch.utils.tensorboard import SummaryWriter
4 | from meta import *
5 | from model import *
6 | from noisy_long_tail_CIFAR import *
7 | from utils import *
8 | 
9 | 
10 | parser = argparse.ArgumentParser(description='Meta_Weight_Net')
11 | parser.add_argument('--device', type=str, default='cuda')
12 | parser.add_argument('--seed', type=int, default=1)
13 | parser.add_argument('--meta_net_hidden_size', type=int, default=100)
14 | parser.add_argument('--meta_net_num_layers', type=int, default=1)
15 | 
16 | parser.add_argument('--lr', type=float, default=.1)
17 | parser.add_argument('--momentum', type=float, default=.9)
18 | parser.add_argument('--dampening', type=float, default=0.)
19 | parser.add_argument('--nesterov', action='store_true')  # was type=bool, which parses any non-empty string as True
20 | parser.add_argument('--weight_decay', type=float, default=5e-4)
21 | parser.add_argument('--meta_lr', type=float, default=1e-5)
22 | parser.add_argument('--meta_weight_decay', type=float, default=0.)
23 | 
24 | parser.add_argument('--dataset', type=str, default='cifar10')
25 | parser.add_argument('--num_meta', type=int, default=1000)
26 | parser.add_argument('--imbalanced_factor', type=int, default=None)
27 | parser.add_argument('--corruption_type', type=str, default=None)
28 | parser.add_argument('--corruption_ratio', type=float, default=0.)
29 | parser.add_argument('--batch_size', type=int, default=100)
30 | parser.add_argument('--max_epoch', type=int, default=120)
31 | 
32 | parser.add_argument('--meta_interval', type=int, default=1)
33 | parser.add_argument('--paint_interval', type=int, default=20)
34 | 
35 | args = parser.parse_args()
36 | print(args)
37 | 
38 | 
39 | def meta_weight_net():
40 |     set_cudnn(device=args.device)
41 |     set_seed(seed=args.seed)
42 |     # writer = SummaryWriter(log_dir='.\\mwn')
43 | 
44 |     meta_net = MLP(hidden_size=args.meta_net_hidden_size, num_layers=args.meta_net_num_layers).to(device=args.device)
45 |     net = ResNet32(10 if args.dataset == 'cifar10' else 100).to(device=args.device)
46 | 
47 |     criterion = nn.CrossEntropyLoss().to(device=args.device)
48 | 
49 |     optimizer = torch.optim.SGD(
50 |         net.parameters(),
51 |         lr=args.lr,
52 |         momentum=args.momentum,
53 |         dampening=args.dampening,
54 |         weight_decay=args.weight_decay,
55 |         nesterov=args.nesterov,
56 |     )
57 |     meta_optimizer = torch.optim.Adam(meta_net.parameters(), lr=args.meta_lr, weight_decay=args.meta_weight_decay)
58 |     lr = args.lr
59 | 
60 |     train_dataloader, meta_dataloader, test_dataloader, imbalanced_num_list = build_dataloader(
61 |         seed=args.seed,
62 |         dataset=args.dataset,
63 |         num_meta_total=args.num_meta,
64 |         imbalanced_factor=args.imbalanced_factor,
65 |         corruption_type=args.corruption_type,
66 |         corruption_ratio=args.corruption_ratio,
67 |         batch_size=args.batch_size,
68 |     )
69 | 
70 |     meta_dataloader_iter = iter(meta_dataloader)
71 |     # with torch.no_grad():
72 |     #     for point in range(500):
73 |     #         x = torch.tensor(point / 10).unsqueeze(0).to(args.device)
74 |     #         fx = meta_net(x)
75 |     #         writer.add_scalar('Initial Meta Net', fx, point)
76 | 
77 |     for epoch in range(args.max_epoch):
78 | 
79 |         if epoch >= 80 and epoch % 20 == 0:  # decay lr by 10x at epochs 80 and 100
80 |             lr = lr / 10
81 |             for group in optimizer.param_groups:
82 |                 group['lr'] = lr
83 | 
84 |         print('Training...')
85 |         for iteration, (inputs, labels) in enumerate(train_dataloader):
86 |             net.train()
87 |             inputs, labels = inputs.to(args.device), labels.to(args.device)
88 | 
89 |             if (iteration + 1) % args.meta_interval == 0:
90 |                 pseudo_net = ResNet32(10 if args.dataset == 'cifar10' else 100).to(args.device)
91 |                 pseudo_net.load_state_dict(net.state_dict())  # clone the current model for a virtual step
92 |                 pseudo_net.train()
93 | 
94 |                 pseudo_outputs = pseudo_net(inputs)
95 |                 pseudo_loss_vector = functional.cross_entropy(pseudo_outputs, labels.long(), reduction='none')
96 |                 pseudo_loss_vector_reshape = torch.reshape(pseudo_loss_vector, (-1, 1))
97 |                 pseudo_weight = meta_net(pseudo_loss_vector_reshape.data)  # .data: gradients flow through meta_net only, not the loss input
98 |                 pseudo_loss = torch.mean(pseudo_weight * pseudo_loss_vector_reshape)
99 | 
100 |                 pseudo_grads = torch.autograd.grad(pseudo_loss, pseudo_net.parameters(), create_graph=True)
101 | 
102 |                 pseudo_optimizer = MetaSGD(pseudo_net, pseudo_net.parameters(), lr=lr)
103 |                 pseudo_optimizer.load_state_dict(optimizer.state_dict())  # reuse the real optimizer's momentum buffers
104 |                 pseudo_optimizer.meta_step(pseudo_grads)  # differentiable update of pseudo_net's parameters
105 | 
106 |                 del pseudo_grads
107 | 
108 |                 try:
109 |                     meta_inputs, meta_labels = next(meta_dataloader_iter)
110 |                 except StopIteration:
111 |                     meta_dataloader_iter = iter(meta_dataloader)
112 |                     meta_inputs, meta_labels = next(meta_dataloader_iter)
113 | 
114 |                 meta_inputs, meta_labels = meta_inputs.to(args.device), meta_labels.to(args.device)
115 |                 meta_outputs = pseudo_net(meta_inputs)
116 |                 meta_loss = criterion(meta_outputs, meta_labels.long())
117 | 
118 |                 meta_optimizer.zero_grad()
119 |                 meta_loss.backward()  # reaches meta_net through the virtual step
120 |                 meta_optimizer.step()
121 | 
122 |             outputs = net(inputs)
123 |             loss_vector = functional.cross_entropy(outputs, labels.long(), reduction='none')
124 |             loss_vector_reshape = torch.reshape(loss_vector, (-1, 1))
125 | 
126 |             with torch.no_grad():
127 |                 weight = meta_net(loss_vector_reshape)
128 | 
129 |             loss = torch.mean(weight * loss_vector_reshape)
130 | 
131 |             optimizer.zero_grad()
132 |             loss.backward()
133 |             optimizer.step()
134 | 
135 |         print('Computing Test Result...')
136 |         test_loss, test_accuracy = compute_loss_accuracy(
137 |             net=net,
138 |             data_loader=test_dataloader,
139 |             criterion=criterion,
140 |             device=args.device,
141 |         )
142 |         # writer.add_scalar('Loss', test_loss, epoch)
143 |         # writer.add_scalar('Accuracy', test_accuracy, epoch)
144 | 
145 |         print('Epoch: {}, (Loss, Accuracy) Test: ({:.4f}, {:.2%}) LR: {}'.format(
146 |             epoch,
147 |             test_loss,
148 |             test_accuracy,
149 |             lr,
150 |         ))
151 | 
152 |         # if (epoch + 1) % args.paint_interval == 0:
153 |         #     with torch.no_grad():
154 |         #         for point in range(500):
155 |         #             x = torch.tensor(point / 10).unsqueeze(0).to(args.device)
156 |         #             fx = meta_net(x)
157 |         #             writer.add_scalar('Meta Net of Epoch {}'.format(epoch), fx, point)
158 | 
159 |     # writer.close()
160 | 
161 | 
162 | if __name__ == '__main__':
163 |     meta_weight_net()
164 | 
--------------------------------------------------------------------------------
/meta.py:
--------------------------------------------------------------------------------
1 | from torch.optim.sgd import SGD
2 | 
3 | 
4 | class MetaSGD(SGD):
5 |     def __init__(self, net, *args, **kwargs):
6 |         super(MetaSGD, self).__init__(*args, **kwargs)
7 |         self.net = net
8 | 
9 |     def set_parameter(self, current_module, name, parameters):
10 |         if '.' in name:
11 |             name_split = name.split('.')
12 |             module_name = name_split[0]
13 |             rest_name = '.'.join(name_split[1:])
14 |             for children_name, children in current_module.named_children():
15 |                 if module_name == children_name:
16 |                     self.set_parameter(children, rest_name, parameters)
17 |                     break
18 |         else:
19 |             current_module._parameters[name] = parameters  # replace the leaf Parameter with a non-leaf tensor
20 | 
21 |     def meta_step(self, grads):
22 |         group = self.param_groups[0]
23 |         weight_decay = group['weight_decay']
24 |         momentum = group['momentum']
25 |         dampening = group['dampening']
26 |         nesterov = group['nesterov']
27 |         lr = group['lr']
28 | 
29 |         for (name, parameter), grad in zip(self.net.named_parameters(), grads):
30 |             parameter.detach_()  # cut the old graph; the updated value is re-attached below
31 |             if weight_decay != 0:
32 |                 grad_wd = grad.add(parameter, alpha=weight_decay)
33 |             else:
34 |                 grad_wd = grad
35 |             if momentum != 0 and 'momentum_buffer' in self.state[parameter]:
36 |                 buffer = self.state[parameter]['momentum_buffer']
37 |                 grad_b = buffer.mul(momentum).add(grad_wd, alpha=1-dampening)
38 |             else:
39 |                 grad_b = grad_wd
40 |             if nesterov:
41 |                 grad_n = grad_wd.add(grad_b, alpha=momentum)
42 |             else:
43 |                 grad_n = grad_b
44 |             self.set_parameter(self.net, name, parameter.add(grad_n, alpha=-lr))
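45 | 
46 | # Note (added commentary): meta_step mirrors torch.optim.SGD's update rule
47 | # (weight decay, momentum, dampening, Nesterov) but applies it functionally:
48 | # instead of mutating leaf parameters in place, it writes the updated values
49 | # back into the modules as non-leaf tensors via set_parameter, so the graph
50 | # from the meta-net's output to the updated weights survives and
51 | # meta_loss.backward() can reach the meta-net's parameters.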
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as functional
4 | import torch.nn.init as init
5 | 
6 | 
7 | def _weights_init(m):
8 |     if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
9 |         init.kaiming_normal_(m.weight)
10 | 
11 | 
12 | class LambdaLayer(nn.Module):
13 |     def __init__(self, lambd):
14 |         super(LambdaLayer, self).__init__()
15 |         self.lambd = lambd
16 | 
17 |     def forward(self, x):
18 |         return self.lambd(x)
19 | 
20 | 
21 | class BasicBlock(nn.Module):
22 |     expansion = 1
23 | 
24 |     def __init__(self, in_planes, planes, stride=1, option='A'):
25 |         super(BasicBlock, self).__init__()
26 |         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
27 |         self.bn1 = nn.BatchNorm2d(planes)
28 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
29 |         self.bn2 = nn.BatchNorm2d(planes)
30 | 
31 |         self.shortcut = nn.Sequential()
32 |         if stride != 1 or in_planes != planes:
33 |             if option == 'A':
34 |                 """
35 |                 For CIFAR10, the ResNet paper uses option A: subsample spatially and zero-pad the extra channels.
36 |                 """
37 |                 self.shortcut = LambdaLayer(lambda x:
38 |                                             functional.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
39 |             elif option == 'B':
40 |                 self.shortcut = nn.Sequential(
41 |                     nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
42 |                     nn.BatchNorm2d(self.expansion * planes)
43 |                 )
44 | 
45 |     def forward(self, x):
46 |         out = functional.relu(self.bn1(self.conv1(x)))
47 |         out = self.bn2(self.conv2(out))
48 |         out += self.shortcut(x)
49 |         out = functional.relu(out)
50 |         return out
51 | 
52 | 
53 | class ResNet32(nn.Module):
54 |     def __init__(self, num_classes=10, block=BasicBlock, num_blocks=[5, 5, 5]):
55 |         super(ResNet32, self).__init__()
56 |         self.in_planes = 16
57 | 
58 |         self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
59 |         self.bn1 = nn.BatchNorm2d(16)
60 |         self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
61 |         self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
62 |         self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
63 |         self.linear = nn.Linear(64, num_classes)
64 | 
65 |         self.apply(_weights_init)
66 | 
67 |     def _make_layer(self, block, planes, num_blocks, stride):
68 |         strides = [stride] + [1]*(num_blocks-1)
69 |         layers = []
70 |         for stride in strides:
71 |             layers.append(block(self.in_planes, planes, stride))
72 |             self.in_planes = planes * block.expansion
73 | 
74 |         return nn.Sequential(*layers)
75 | 
76 |     def forward(self, x):
77 |         out = functional.relu(self.bn1(self.conv1(x)))
78 |         out = self.layer1(out)
79 |         out = self.layer2(out)
80 |         out = self.layer3(out)
81 |         out = functional.avg_pool2d(out, out.size()[3])
82 |         out = out.view(out.size(0), -1)
83 |         out = self.linear(out)
84 |         return out
85 | 
86 | 
87 | class HiddenLayer(nn.Module):
88 |     def __init__(self, input_size, output_size):
89 |         super(HiddenLayer, self).__init__()
90 |         self.fc = nn.Linear(input_size, output_size)
91 |         self.relu = nn.ReLU()
92 | 
93 |     def forward(self, x):
94 |         return self.relu(self.fc(x))
95 | 
96 | 
97 | class MLP(nn.Module):
98 |     def __init__(self, hidden_size=100, num_layers=1):
99 |         super(MLP, self).__init__()
100 |         self.first_hidden_layer = HiddenLayer(1, hidden_size)
101 |         self.rest_hidden_layers = nn.Sequential(*[HiddenLayer(hidden_size, hidden_size) for _ in range(num_layers - 1)])
102 |         self.output_layer = nn.Linear(hidden_size, 1)
103 | 
104 |     def forward(self, x):
105 |         x = self.first_hidden_layer(x)
106 |         x = self.rest_hidden_layers(x)
107 |         x = self.output_layer(x)
108 |         return torch.sigmoid(x)
109 | 
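110 | 
111 | # Note (added commentary): the MLP meta-net maps a per-sample loss of shape
112 | # (batch, 1) to a weight in (0, 1) via the final sigmoid; ResNet32 is the
113 | # standard CIFAR ResNet (three stages of five BasicBlocks, 16/32/64 planes).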
--------------------------------------------------------------------------------
/noisy_long_tail_CIFAR.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import numpy as np
3 | import torchvision.datasets
4 | import torchvision.transforms as transforms
5 | from torch.utils.data import DataLoader
6 | 
7 | 
8 | def uniform_corruption(corruption_ratio, num_classes):
9 |     eye = np.eye(num_classes)
10 |     noise = np.full((num_classes, num_classes), 1/num_classes)
11 |     corruption_matrix = eye * (1 - corruption_ratio) + noise * corruption_ratio  # each row sums to 1
12 |     return corruption_matrix
13 | 
14 | 
15 | def flip1_corruption(corruption_ratio, num_classes):
16 |     corruption_matrix = np.eye(num_classes) * (1 - corruption_ratio)
17 |     row_indices = np.arange(num_classes)
18 |     for i in range(num_classes):
19 |         corruption_matrix[i][np.random.choice(row_indices[row_indices != i])] = corruption_ratio
20 |     return corruption_matrix
21 | 
22 | 
23 | def flip2_corruption(corruption_ratio, num_classes):
24 |     corruption_matrix = np.eye(num_classes) * (1 - corruption_ratio)
25 |     row_indices = np.arange(num_classes)
26 |     for i in range(num_classes):
27 |         corruption_matrix[i][np.random.choice(row_indices[row_indices != i], 2, replace=False)] = corruption_ratio / 2
28 |     return corruption_matrix
29 | 
30 | 
31 | def build_dataloader(
32 |     seed=1,
33 |     dataset='cifar10',
34 |     num_meta_total=1000,
35 |     imbalanced_factor=None,
36 |     corruption_type=None,
37 |     corruption_ratio=0.,
38 |     batch_size=100,
39 | ):
40 | 
41 |     np.random.seed(seed)
42 |     normalize = transforms.Normalize(
43 |         mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
44 |         std=[x / 255.0 for x in [63.0, 62.1, 66.7]],
45 |     )
46 | 
47 |     train_transforms = transforms.Compose([
48 |         transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
49 |         transforms.RandomHorizontalFlip(),
50 |         transforms.ToTensor(),
51 |         normalize,
52 |     ])
53 | 
54 |     test_transforms = transforms.Compose([
55 |         transforms.ToTensor(),
56 |         normalize,
57 |     ])
58 | 
59 |     dataset_list = {
60 |         'cifar10': torchvision.datasets.CIFAR10,
61 |         'cifar100': torchvision.datasets.CIFAR100,
62 |     }
63 | 
64 |     corruption_list = {
65 |         'uniform': uniform_corruption,
66 |         'flip1': flip1_corruption,
67 |         'flip2': flip2_corruption,
68 |     }
69 | 
70 |     train_dataset = dataset_list[dataset](root='../data', train=True, download=True, transform=train_transforms)
71 |     test_dataset = dataset_list[dataset](root='../data', train=False, transform=test_transforms)
72 | 
73 |     num_classes = len(train_dataset.classes)
74 |     num_meta = int(num_meta_total / num_classes)
75 | 
76 |     index_to_meta = []
77 |     index_to_train = []
78 | 
79 |     if imbalanced_factor is not None:
80 |         imbalanced_num_list = []
81 |         sample_num = int((len(train_dataset.targets) - num_meta_total) / num_classes)
82 |         for class_index in range(num_classes):
83 |             imbalanced_num = sample_num / (imbalanced_factor ** (class_index / (num_classes - 1)))  # exponential decay from head to tail class
84 |             imbalanced_num_list.append(int(imbalanced_num))
85 |         np.random.shuffle(imbalanced_num_list)  # assign the class sizes to classes in random order
86 |         print(imbalanced_num_list)
87 |     else:
88 |         imbalanced_num_list = None
89 | 
90 |     for class_index in range(num_classes):
91 |         index_to_class = [index for index, label in enumerate(train_dataset.targets) if label == class_index]
92 |         np.random.shuffle(index_to_class)
93 |         index_to_meta.extend(index_to_class[:num_meta])
94 |         index_to_class_for_train = index_to_class[num_meta:]
95 | 
96 |         if imbalanced_num_list is not None:
97 |             index_to_class_for_train = index_to_class_for_train[:imbalanced_num_list[class_index]]
98 | 
99 |         index_to_train.extend(index_to_class_for_train)
100 | 
101 |     meta_dataset = copy.deepcopy(train_dataset)
102 |     train_dataset.data = train_dataset.data[index_to_train]
103 |     train_dataset.targets = list(np.array(train_dataset.targets)[index_to_train])
104 |     meta_dataset.data = meta_dataset.data[index_to_meta]
105 |     meta_dataset.targets = list(np.array(meta_dataset.targets)[index_to_meta])
106 | 
107 |     if corruption_type is not None:
108 |         corruption_matrix = corruption_list[corruption_type](corruption_ratio, num_classes)
109 |         print(corruption_matrix)
110 |         for index in range(len(train_dataset.targets)):
111 |             p = corruption_matrix[train_dataset.targets[index]]
112 |             train_dataset.targets[index] = np.random.choice(num_classes, p=p)
113 | 
114 |     train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
115 |     meta_dataloader = DataLoader(meta_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
116 |     test_dataloader = DataLoader(test_dataset, batch_size=batch_size, pin_memory=True)
117 | 
118 |     return train_dataloader, meta_dataloader, test_dataloader, imbalanced_num_list
119 | 
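120 | # Note (added commentary): uniform_corruption(0.4, num_classes=10) keeps the
121 | # true label with probability 0.6 + 0.4 / 10 = 0.64 and moves 0.04 of the
122 | # probability mass to each of the other nine classes; flip1/flip2 concentrate
123 | # the flipped mass on one or two random classes per row instead.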
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | from time import sleep
5 | 
6 | 
7 | def set_cudnn(device='cuda'):
8 |     torch.backends.cudnn.enabled = (device == 'cuda')
9 |     torch.backends.cudnn.benchmark = (device == 'cuda')
10 | 
11 | 
12 | def set_seed(seed=1):
13 |     random.seed(seed)
14 |     np.random.seed(seed)
15 |     torch.manual_seed(seed)
16 |     torch.cuda.manual_seed(seed)
17 | 
18 | 
19 | def stop_epoch(time=3):
20 |     try:
21 |         print('can break now')
22 |         for i in range(time):
23 |             sleep(1)
24 |         print('wait for next epoch')
25 |         return False
26 |     except KeyboardInterrupt:
27 |         return True
28 | 
29 | 
30 | def compute_loss_accuracy(net, data_loader, criterion, device):
31 |     net.eval()
32 |     correct = 0
33 |     total_loss = 0.
34 | 
35 |     with torch.no_grad():
36 |         for batch_idx, (inputs, labels) in enumerate(data_loader):
37 |             inputs, labels = inputs.to(device), labels.to(device)
38 |             outputs = net(inputs)
39 |             total_loss += criterion(outputs, labels).item()
40 |             _, pred = outputs.max(1)
41 |             correct += pred.eq(labels).sum().item()
42 | 
43 |     return total_loss / (batch_idx + 1), correct / len(data_loader.dataset)  # mean batch loss, dataset accuracy
44 | 
--------------------------------------------------------------------------------