├── README.md
├── Safe Deep Semi-Supervised Learning for Unseen-Class Unlabeled Data.pdf
├── load_dataset.py
├── train.py
├── transform.py
└── wideresnet.py

/README.md:
--------------------------------------------------------------------------------
# DS3L
This is the code for the paper "Safe Deep Semi-Supervised Learning for Unseen-Class Unlabeled Data", published at ICML 2020.

# Setups

The code is implemented with Python and PyTorch.

# Running DS3L for benchmark datasets

Here is an example:

```bash
python train.py --dataset MNIST --ratio 0.6 --n_labels 60 --iterations 200000
```

# Acknowledgements
We thank the PyTorch implementations of Meta-Weight-Net (https://github.com/xjtushujun/meta-weight-net) and learning-to-reweight-examples (https://github.com/danieltan07/learning-to-reweight-examples).

# Contact
If you have any questions, please contact Lan-Zhe Guo (guolz@lamda.nju.edu.cn).

--------------------------------------------------------------------------------
/Safe Deep Semi-Supervised Learning for Unseen-Class Unlabeled Data.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guolz-ml/DS3L/ede547cfa8a62f1e53bbe5a7ba1d4e1410cb4c5e/Safe Deep Semi-Supervised Learning for Unseen-Class Unlabeled Data.pdf
--------------------------------------------------------------------------------
/load_dataset.py:
--------------------------------------------------------------------------------
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import Sampler
import numpy as np
import torch
import torchvision.transforms as tv_transforms

COUNTS = {
    "svhn": {"train": 73257, "test": 26032, "valid": 7326},
    "cifar10": {"train": 50000, "test": 10000, "valid": 5000},
    "imagenet_32": {
        "train": 1281167,
        "test": 50000,
        "valid": 50050
    },
}

rng = np.random.RandomState(seed=1)


class SimpleDataset(Dataset):
    def __init__(self, dataset, transform=True):
        self.dataset = dataset
        self.transform = transform

    def __getitem__(self, index):
        image = self.dataset['images'][index]
        label = self.dataset['labels'][index]
        if(self.transform):
            image = (image / 255. - 0.5) / 0.5
        return image, label, index

    def __len__(self):
        return len(self.dataset['images'])


class RandomSampler(Sampler):
    """ sampling without replacement """
    def __init__(self, num_data, num_sample):
        iterations = num_sample // num_data + 1
        self.indices = torch.cat([torch.randperm(num_data) for _ in range(iterations)]).tolist()[:num_sample]

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)


data_path = "./data"


def split_l_u(train_set, n_labels, n_unlabels, tot_class=6, ratio=0.5):
    # NOTE: this function assumes that train_set is shuffled.
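    # Illustrative arithmetic, not part of the original code: with the README defaults
    # (n_labels=60, n_unlabels=20000, tot_class=6, ratio=0.6) the split below gives
    #   n_labels_per_cls   = 60 // 6                        = 10 labeled examples per seen class,
    #   n_unlabels_per_cls = int(20000 * 0.4) // 6          = 1333 in-distribution unlabeled per seen class,
    #   n_unlabels_shift   = (20000 - 6 * 1333) // (10 - 6) = 3000 unlabeled examples per unseen class,
    # i.e. roughly 60% of the unlabeled pool (12000 of 19998 examples) comes from the
    # unseen classes, which is the fraction controlled by the --ratio argument.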
    images = train_set["images"]
    labels = train_set["labels"]
    classes = np.unique(labels)
    n_labels_per_cls = n_labels // tot_class
    n_unlabels_per_cls = int(n_unlabels * (1.0 - ratio)) // tot_class
    if(tot_class < len(classes)):
        n_unlabels_shift = (n_unlabels - (n_unlabels_per_cls * tot_class)) // (len(classes) - tot_class)
    l_images = []
    l_labels = []
    u_images = []
    u_labels = []
    for c in classes[:tot_class]:
        cls_mask = (labels == c)
        c_images = images[cls_mask]
        c_labels = labels[cls_mask]
        l_images += [c_images[:n_labels_per_cls]]
        l_labels += [c_labels[:n_labels_per_cls]]
        u_images += [c_images[n_labels_per_cls:n_labels_per_cls + n_unlabels_per_cls]]
        u_labels += [c_labels[n_labels_per_cls:n_labels_per_cls + n_unlabels_per_cls]]
    for c in classes[tot_class:]:
        cls_mask = (labels == c)
        c_images = images[cls_mask]
        c_labels = labels[cls_mask]
        u_images += [c_images[:n_unlabels_shift]]
        u_labels += [c_labels[:n_unlabels_shift]]

    l_train_set = {"images": np.concatenate(l_images, 0), "labels": np.concatenate(l_labels, 0)}
    u_train_set = {"images": np.concatenate(u_images, 0), "labels": np.concatenate(u_labels, 0)}

    indices = rng.permutation(len(l_train_set["images"]))
    l_train_set["images"] = l_train_set["images"][indices]
    l_train_set["labels"] = l_train_set["labels"][indices]

    indices = rng.permutation(len(u_train_set["images"]))
    u_train_set["images"] = u_train_set["images"][indices]
    u_train_set["labels"] = u_train_set["labels"][indices]
    return l_train_set, u_train_set


def split_test(test_set, tot_class=6):
    images = test_set["images"]
    labels = test_set['labels']
    classes = np.unique(labels)
    l_images = []
    l_labels = []
    for c in classes[:tot_class]:
        cls_mask = (labels == c)
        c_images = images[cls_mask]
        c_labels = labels[cls_mask]
        l_images += [c_images[:]]
        l_labels += [c_labels[:]]
    test_set = {"images": np.concatenate(l_images, 0), "labels": np.concatenate(l_labels, 0)}

    indices = rng.permutation(len(test_set["images"]))
    test_set["images"] = test_set["images"][indices]
    test_set["labels"] = test_set["labels"][indices]
    return test_set


def load_mnist():
    splits = {}
    trans = tv_transforms.Compose([tv_transforms.ToPILImage(), tv_transforms.ToTensor(), tv_transforms.Normalize((0.5,), (1.0,))])
    for train in [True, False]:
        dataset = datasets.MNIST(data_path, train, transform=trans, download=True)
        data = {}
        data['images'] = dataset.data
        data['labels'] = np.array(dataset.targets)
        splits['train' if train else 'test'] = data
    return splits.values()


def load_cifar10():
    splits = {}
    for train in [True, False]:
        dataset = datasets.CIFAR10(data_path, train, download=True)
        data = {}
        data['images'] = dataset.data
        data['labels'] = np.array(dataset.targets)
        splits["train" if train else "test"] = data
    return splits.values()


def gcn(images, multiplier=55, eps=1e-10):
    # global contrast normalization
    images = images.astype(np.float64)  # np.float is removed in recent NumPy versions
    images -= images.mean(axis=(1, 2, 3), keepdims=True)
    per_image_norm = np.sqrt(np.square(images).sum((1, 2, 3), keepdims=True))
    per_image_norm[per_image_norm < eps] = 1
    images = multiplier * images / per_image_norm
    return images
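
# Illustrative note, not part of the original code: for CIFAR-10, get_dataloaders()
# below applies these two preprocessing steps in sequence -- global contrast
# normalization first, then ZCA whitening fitted on the GCN-normalized training
# images and applied to both the train and test splits:
#
#   train_set["images"] = gcn(train_set["images"])
#   mean, zca_decomp = get_zca_normalization_param(train_set["images"])
#   train_set["images"] = zca_normalization(train_set["images"], mean, zca_decomp)
#   test_set["images"] = zca_normalization(test_set["images"], mean, zca_decomp)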

def get_zca_normalization_param(images, scale=0.1, eps=1e-10):
    n_data, height, width, channels = images.shape
    images = images.reshape(n_data, height * width * channels)
    image_cov = np.cov(images, rowvar=False)
    U, S, _ = np.linalg.svd(image_cov + scale * np.eye(image_cov.shape[0]))
    zca_decomp = np.dot(U, np.dot(np.diag(1 / np.sqrt(S + eps)), U.T))
    mean = images.mean(axis=0)
    return mean, zca_decomp


def zca_normalization(images, mean, decomp):
    n_data, height, width, channels = images.shape
    images = images.reshape(n_data, -1)
    images = np.dot((images - mean), decomp)
    return images.reshape(n_data, height, width, channels)


def get_dataloaders(dataset, n_labels, n_unlabels, n_valid, l_batch_size, ul_batch_size, test_batch_size, iterations,
                    tot_class, ratio):

    rng = np.random.RandomState(seed=1)

    if dataset == "MNIST":
        train_set, test_set = load_mnist()
        transform = False
    elif dataset == "CIFAR10":
        train_set, test_set = load_cifar10()
        train_set["images"] = gcn(train_set["images"])
        test_set["images"] = gcn(test_set["images"])
        mean, zca_decomp = get_zca_normalization_param(train_set["images"])
        train_set["images"] = zca_normalization(train_set["images"], mean, zca_decomp)
        test_set["images"] = zca_normalization(test_set["images"], mean, zca_decomp)
        # N x H x W x C -> N x C x H x W
        train_set["images"] = np.transpose(train_set["images"], (0, 3, 1, 2))
        test_set["images"] = np.transpose(test_set["images"], (0, 3, 1, 2))

        # move the classes "plane" and "car" to labels 8 and 9
        train_set['labels'] -= 2
        test_set['labels'] -= 2
        train_set['labels'][np.where(train_set['labels'] == -2)] = 8
        train_set['labels'][np.where(train_set['labels'] == -1)] = 9
        test_set['labels'][np.where(test_set['labels'] == -2)] = 8
        test_set['labels'][np.where(test_set['labels'] == -1)] = 9

        transform = False

    # permute index of training set
    indices = rng.permutation(len(train_set['images']))
    train_set['images'] = train_set['images'][indices]
    train_set['labels'] = train_set['labels'][indices]

    # split training set into training and validation
    train_images = train_set['images'][n_valid:]
    train_labels = train_set['labels'][n_valid:]
    validation_images = train_set['images'][:n_valid]
    validation_labels = train_set['labels'][:n_valid]
    validation_set = {'images': validation_images, 'labels': validation_labels}
    train_set = {'images': train_images, 'labels': train_labels}

    # split training set into labeled and unlabeled data
    validation_set = split_test(validation_set, tot_class=tot_class)
    test_set = split_test(test_set, tot_class=tot_class)
    l_train_set, u_train_set = split_l_u(train_set, n_labels, n_unlabels, tot_class=tot_class, ratio=ratio)

    print("Unlabeled data in-distribution : {}, unlabeled data out-of-distribution : {}".format(
        np.sum(u_train_set['labels'] < tot_class), np.sum(u_train_set['labels'] >= tot_class)))
    l_train_set = SimpleDataset(l_train_set, transform)
    u_train_set = SimpleDataset(u_train_set, transform)
    validation_set = SimpleDataset(validation_set, transform)
    test_set = SimpleDataset(test_set, transform)

    print("labeled data : {}, unlabeled data : {}, training data : {}".format(
        len(l_train_set), len(u_train_set), len(l_train_set) + len(u_train_set)))
    print("validation data : {}, test data : {}".format(len(validation_set), len(test_set)))

    data_loaders = {
        'labeled': DataLoader(
            l_train_set, l_batch_size, drop_last=True,
            sampler=RandomSampler(len(l_train_set), iterations * l_batch_size)),
        'unlabeled': DataLoader(
            u_train_set, ul_batch_size, drop_last=True,
            sampler=RandomSampler(len(u_train_set), iterations * ul_batch_size)),
        'valid': DataLoader(
            validation_set, test_batch_size, shuffle=False, drop_last=False),
        'test': DataLoader(
            test_set, test_batch_size, shuffle=False, drop_last=False)
    }
    return data_loaders
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from load_dataset import *
import transform
from wideresnet import WideResNet, CNN, WNet

import argparse
import math
import time

import os

parser = argparse.ArgumentParser(description='manual to this script')

# model
parser.add_argument('--depth', type=int, default=28)
parser.add_argument('--width', type=int, default=2)


# optimization
parser.add_argument('--optim', default='adam')
parser.add_argument('--iterations', type=int, default=200000)
parser.add_argument('--l_batch_size', type=int, default=100)
parser.add_argument('--ul_batch_size', type=int, default=100)
parser.add_argument('--test_batch_size', type=int, default=128)
parser.add_argument('--lr_decay_iter', type=int, default=400000)
parser.add_argument('--lr_decay_factor', type=float, default=0.2)
parser.add_argument('--warmup', type=int, default=200000)
parser.add_argument('--meta_lr', type=float, default=0.001)
parser.add_argument('--lr_wnet', type=float, default=6e-5)  # this parameter needs to be carefully tuned for different settings

# dataset
parser.add_argument('--dataset', default='MNIST')
parser.add_argument('--n_labels', type=int, default=60)
parser.add_argument('--n_unlabels', type=int, default=20000)
parser.add_argument('--n_valid', type=int, default=5000)
parser.add_argument('--n_class', type=int, default=6)
parser.add_argument('--tot_class', type=int, default=10)
parser.add_argument('--ratio', type=float, default=0.6)


args = parser.parse_args()

if torch.cuda.is_available():
    device = "cuda"
    torch.backends.cudnn.benchmark = True
else:
    device = "cpu"


class MSE_Loss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y, model, mask):
        y_hat = model(x)
        return (F.mse_loss(y_hat.softmax(1), y.softmax(1).detach(), reduction='none').mean(1) * mask)


def build_model():

    if(args.dataset == 'CIFAR10'):
        transform_fn = transform.transform()
        model = WideResNet(widen_factor=args.width, n_classes=args.n_class, transform_fn=transform_fn).to(device)
    if(args.dataset == 'MNIST'):
        model = CNN(n_out=args.n_class).to(device)
    return model


def bi_train(model, label_loader, unlabeled_loader, val_loader, test_loader, optimizer, ssl_obj):
    wnet = WNet(6, 100, 1).to(device)  # weighting network over the 6-way (n_class) softmax output

    wnet.train()

    t = time.time()
    best_acc = 0.0
    test_acc = 0.0
    iteration = 0

    optimizer_wnet = torch.optim.Adam(wnet.params(), lr=args.lr_wnet)

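    # The loop below performs one bi-level (meta) update per iteration.  Illustrative
    # summary, not part of the original code:
    #   1. Copy the current model into meta_net and take a virtual gradient step on the
    #      weighted loss, with create_graph=True so the step stays differentiable
    #      w.r.t. the per-example weights produced by wnet.
    #   2. Evaluate the updated meta_net on the clean labeled batch and backpropagate
    #      that loss into wnet (optimizer_wnet), so wnet learns to down-weight unlabeled
    #      examples that hurt labeled performance (e.g. unseen-class data).
    #   3. Recompute the weighted loss with the updated, now-fixed wnet weights and take
    #      the real optimizer step on the main model.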
    for l_data, u_data in zip(label_loader, unlabeled_loader):

        # load data
        iteration += 1
        l_images, l_labels, _ = l_data
        u_images, u_labels, idx = u_data
        if args.dataset == 'MNIST':
            l_images = l_images.unsqueeze(1)
            u_images = u_images.unsqueeze(1)
        l_images, l_labels = l_images.to(device).float(), l_labels.to(device).long()
        u_images, u_labels = u_images.to(device).float(), u_labels.to(device).long()

        model.train()
        meta_net = build_model()
        meta_net.load_state_dict(model.state_dict())

        # cat labeled and unlabeled data
        labels = torch.cat([l_labels, u_labels], 0)
        labels[-len(u_labels):] = -1  # unlabeled mask
        unlabeled_mask = (labels == -1).float()
        images = torch.cat([l_images, u_images], 0)

        # coefficient for the unsupervised loss, ramped up from ~0.07 to 10 over the warmup iterations
        coef = 10.0 * math.exp(-5 * (1 - min(iteration / args.warmup, 1)) ** 2)

        out = meta_net(images)
        ssl_loss = ssl_obj(images, out.detach(), meta_net, unlabeled_mask)

        cost_w = torch.reshape(ssl_loss[len(l_labels):], (len(ssl_loss[len(l_labels):]), 1))

        weight = wnet(out.softmax(1)[len(l_labels):])
        norm = torch.sum(weight)

        cls_loss = F.cross_entropy(out, labels, reduction='none', ignore_index=-1).mean()
        if norm != 0:
            loss_hat = cls_loss + coef * (torch.sum(cost_w * weight) / norm + ssl_loss[:len(l_labels)].mean())
        else:
            loss_hat = cls_loss + coef * (torch.sum(cost_w * weight) + ssl_loss[:len(l_labels)].mean())

        meta_net.zero_grad()
        grads = torch.autograd.grad(loss_hat, (meta_net.params()), create_graph=True)
        meta_net.update_params(lr_inner=args.meta_lr, source_params=grads)
        del grads

        # compute the upper-level objective
        y_g_hat = meta_net(l_images)
        l_g_meta = F.cross_entropy(y_g_hat, l_labels)

        optimizer_wnet.zero_grad()
        l_g_meta.backward()
        optimizer_wnet.step()

        out = model(images)

        ssl_loss = ssl_obj(images, out.detach(), model, unlabeled_mask)
        cls_loss = F.cross_entropy(out, labels, reduction='none', ignore_index=-1).mean()
        cost_w = torch.reshape(ssl_loss[len(l_labels):], (len(ssl_loss[len(l_labels):]), 1))
        with torch.no_grad():
            weight = wnet(out.softmax(1)[len(l_labels):])
            norm = torch.sum(weight)

        if norm != 0:
            loss = cls_loss + coef * (torch.sum(cost_w * weight) / norm + ssl_loss[:len(l_labels)].mean())
        else:
            loss = cls_loss + coef * (torch.sum(cost_w * weight) + ssl_loss[:len(l_labels)].mean())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if iteration == 1 or (iteration % 1000) == 0:
            time_cost = time.time() - t
            print("iteration [{}/{}] cls loss : {:.6e}, time : {:.3f} sec/iter, lr : {}, coef: {}".format(
                iteration, args.iterations, loss.item(), time_cost / 100, optimizer.param_groups[0]["lr"], coef))
            t = time.time()

        if (iteration % 10000) == 0 or iteration == args.iterations:
            acc = test(model, val_loader)
            print("Validation Accuracy: {}".format(acc))
            if (acc > best_acc):
                best_acc = acc
                test_acc = test(model, test_loader)
            model.train()
        if iteration == args.lr_decay_iter:
            optimizer.param_groups[0]['lr'] *= args.lr_decay_factor

    print("Last Model Accuracy: {}".format(test(model, test_loader)))
    print("Test Accuracy: {}".format(test_acc))

def test(model, test_loader):
    with torch.no_grad():
        model.eval()
        correct = 0.
        tot = 0.
        for i, data in enumerate(test_loader):
            images, labels, _ = data

            if args.dataset == 'MNIST':
                images = images.unsqueeze(1)

            images = images.to(device).float()
            labels = labels.to(device).long()

            out = model(images)

            pred_label = out.max(1)[1]
            correct += (pred_label == labels).float().sum()
            tot += pred_label.size(0)
        acc = correct / tot
    return acc


def main():

    args.l_batch_size = args.l_batch_size // 2
    args.ul_batch_size = args.ul_batch_size // 2

    data_loaders = get_dataloaders(dataset=args.dataset, n_labels=args.n_labels, n_unlabels=args.n_unlabels, n_valid=args.n_valid,
                                   l_batch_size=args.l_batch_size, ul_batch_size=args.ul_batch_size,
                                   test_batch_size=args.test_batch_size, iterations=args.iterations,
                                   tot_class=args.n_class, ratio=args.ratio)
    label_loader = data_loaders['labeled']
    unlabel_loader = data_loaders['unlabeled']
    test_loader = data_loaders['test']
    val_loader = data_loaders['valid']

    model = build_model()

    if(args.dataset == "MNIST"):
        optimizer = torch.optim.SGD(model.params(), lr=1e-3)
    else:
        optimizer = torch.optim.Adam(model.params(), lr=3e-4)

    U_Loss = MSE_Loss()
    bi_train(model, label_loader, unlabel_loader, val_loader, test_loader, optimizer, U_Loss)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/transform.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
import random


class transform:
    def __init__(self, flip=True, r_crop=True, g_noise=True):
        self.flip = flip
        self.r_crop = r_crop
        self.g_noise = g_noise

    def __call__(self, x):
        if self.flip and random.random() > 0.5:
            x = x.flip(-1)
        if self.r_crop:
            h, w = x.shape[-2:]
            x = F.pad(x, [2, 2, 2, 2], mode="reflect")
            l, t = random.randint(0, 4), random.randint(0, 4)
            x = x[:, :, t:t + h, l:l + w]
        if self.g_noise:
            n = torch.randn_like(x) * 0.15
            x = n + x
        return x
--------------------------------------------------------------------------------
/wideresnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.autograd import Variable
import random


def to_var(x, requires_grad=True):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x, requires_grad=requires_grad)

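# Note (illustrative summary, not part of the original code): the Meta* layers below
# register their weights as buffers created by to_var(..., requires_grad=True) rather
# than as nn.Parameter.  MetaModule.named_params()/set_param() expose those tensors,
# and update_params() overwrites them with "param - lr_inner * grad" results that stay
# attached to the autograd graph, which is what makes the virtual inner-loop update of
# meta_net in train.py differentiable with respect to the weights produced by WNet.
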
class MetaModule(nn.Module):
    # adopted from: Adrien Ecoffet https://github.com/AdrienLE
    def params(self):
        for name, param in self.named_params(self):
            yield param

    def named_leaves(self):
        return []

    def named_submodules(self):
        return []

    def named_params(self, curr_module=None, memo=None, prefix=''):
        if memo is None:
            memo = set()

        if hasattr(curr_module, 'named_leaves'):
            for name, p in curr_module.named_leaves():
                if p is not None and p not in memo:
                    memo.add(p)
                    yield prefix + ('.' if prefix else '') + name, p
        else:
            for name, p in curr_module._parameters.items():
                if p is not None and p not in memo:
                    memo.add(p)
                    yield prefix + ('.' if prefix else '') + name, p

        for mname, module in curr_module.named_children():
            submodule_prefix = prefix + ('.' if prefix else '') + mname
            for name, p in self.named_params(module, memo, submodule_prefix):
                yield name, p

    def update_params(self, lr_inner, first_order=False, source_params=None, detach=False):
        if source_params is not None:
            for tgt, src in zip(self.named_params(self), source_params):
                name_t, param_t = tgt
                # name_s, param_s = src
                # grad = param_s.grad
                # name_s, param_s = src
                grad = src
                if first_order:
                    grad = to_var(grad.detach().data)
                tmp = param_t - lr_inner * grad
                self.set_param(self, name_t, tmp)
        else:

            for name, param in self.named_params(self):
                if not detach:
                    grad = param.grad
                    if first_order:
                        grad = to_var(grad.detach().data)
                    tmp = param - lr_inner * grad
                    self.set_param(self, name, tmp)
                else:
                    param = param.detach_()  # https://blog.csdn.net/qq_39709535/article/details/81866686
                    self.set_param(self, name, param)

    def set_param(self, curr_mod, name, param):
        if '.' in name:
            n = name.split('.')
            module_name = n[0]
            rest = '.'.join(n[1:])
            for name, mod in curr_mod.named_children():
                if module_name == name:
                    self.set_param(mod, rest, param)
                    break
        else:
            setattr(curr_mod, name, param)

    def detach_params(self):
        for name, param in self.named_params(self):
            self.set_param(self, name, param.detach())

    def copy(self, other, same_var=False):
        for name, param in other.named_params():
            if not same_var:
                param = to_var(param.data.clone(), requires_grad=True)
            self.set_param(name, param)


class MetaLinear(MetaModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.Linear(*args, **kwargs)

        self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))
        self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)

    def named_leaves(self):
        return [('weight', self.weight), ('bias', self.bias)]


class MetaConv2d(MetaModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.Conv2d(*args, **kwargs)

        self.in_channels = ignore.in_channels
        self.out_channels = ignore.out_channels
        self.stride = ignore.stride
        self.padding = ignore.padding
        self.dilation = ignore.dilation
        self.groups = ignore.groups
        self.kernel_size = ignore.kernel_size

        self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))

        if ignore.bias is not None:
            self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))
        else:
            self.register_buffer('bias', None)

    def forward(self, x):
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

    def named_leaves(self):
        return [('weight', self.weight), ('bias', self.bias)]


class MetaConvTranspose2d(MetaModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.ConvTranspose2d(*args, **kwargs)
        self.stride = ignore.stride
        self.padding = ignore.padding
        self.dilation = ignore.dilation
        self.groups = ignore.groups

        self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))

        if ignore.bias is not None:
            self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))
        else:
            self.register_buffer('bias', None)

    def forward(self, x, output_size=None):
        output_padding = self._output_padding(x, output_size)
        return F.conv_transpose2d(x, self.weight, self.bias, self.stride, self.padding,
                                  output_padding, self.groups, self.dilation)

    def named_leaves(self):
        return [('weight', self.weight), ('bias', self.bias)]


class MetaBatchNorm2d(MetaModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.BatchNorm2d(*args, **kwargs)

        self.num_features = ignore.num_features
        self.eps = ignore.eps
        self.momentum = ignore.momentum
        self.affine = ignore.affine
        self.track_running_stats = ignore.track_running_stats

        self.update_batch_stats = True

        if self.affine:
            self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))
            self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))

        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(self.num_features))
            self.register_buffer('running_var', torch.ones(self.num_features))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)

    def forward(self, x):
        # if self.update_batch_stats:
        #     return super().forward(x)
        # else:
        return F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
                            self.training or not self.track_running_stats, self.momentum, self.eps)

    def named_leaves(self):
        return [('weight', self.weight), ('bias', self.bias)]


class MetaBasicBlock(MetaModule):
    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super(MetaBasicBlock, self).__init__()

        self.bn1 = MetaBatchNorm2d(in_planes)
        self.relu1 = nn.LeakyReLU(0.1)
        self.conv1 = MetaConv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                                padding=1, bias=False)
        self.bn2 = MetaBatchNorm2d(out_planes)
        self.relu2 = nn.LeakyReLU(0.1)
        self.conv2 = MetaConv2d(out_planes, out_planes, kernel_size=3, stride=1,
                                padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        self.convShortcut = (not self.equalInOut) and MetaConv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                                                                 padding=0, bias=False) or None

    def forward(self, x):
        if not self.equalInOut:
            x = self.relu1(self.bn1(x))
        else:
            out = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        return torch.add(x if self.equalInOut else self.convShortcut(x), out)


class MetaNetworkBlock(MetaModule):
    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super(MetaNetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        layers = []
        for i in range(int(nb_layers)):
            layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)


class WideResNet(MetaModule):
    def __init__(self, depth=28, widen_factor=2, n_classes=10, dropRate=0.0, transform_fn=None):
        super(WideResNet, self).__init__()
        nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        assert ((depth - 4) % 6 == 0)
        n = (depth - 4) / 6
        block = MetaBasicBlock
        # 1st conv before any network block
        self.conv1 = MetaConv2d(3, nChannels[0], kernel_size=3, stride=1,
                                padding=1, bias=False)
        # 1st block
        self.block1 = MetaNetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        # 2nd block
        self.block2 = MetaNetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = MetaNetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = MetaBatchNorm2d(nChannels[3])
        self.relu = nn.LeakyReLU(0.1)
        self.fc = MetaLinear(nChannels[3], n_classes)
        self.nChannels = nChannels[3]

        for m in self.modules():
            if isinstance(m, MetaConv2d):
                # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                # m.weight.data.normal_(0, math.sqrt(2. / n))
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, MetaBatchNorm2d):
                # m.weight.data.fill_(1)
                # m.bias.data.zero_()
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, MetaLinear):
                # m.bias.data.zero_()
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

        self.transform_fn = transform_fn

    def forward(self, x):
        if self.training and self.transform_fn is not None:
            x = self.transform_fn(x)
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        out = F.avg_pool2d(out, 8)
        out = out.view(-1, self.nChannels)
        return self.fc(out)

    def update_batch_stats(self, flag):
        for m in self.modules():
            if isinstance(m, MetaBatchNorm2d):
                m.update_batch_stats = flag

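# WNet below is the weighting network used in bi_train() (train.py): it maps the
# model's softmax output on an unlabeled example (input size = n_class = 6) through a
# hidden layer to a single sigmoid weight in (0, 1).  A minimal usage sketch, purely
# illustrative (`probs` stands for any (batch, 6) softmax output):
#
#   wnet = WNet(6, 100, 1)
#   w = wnet(probs)   # shape (batch, 1); used to weight the per-example SSL loss
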
class WNet(MetaModule):
    def __init__(self, input, hidden, output):
        super(WNet, self).__init__()
        self.linear1 = MetaLinear(input, hidden)
        self.relu = nn.ReLU(inplace=True)
        self.linear2 = MetaLinear(hidden, output)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        out = self.linear2(x)
        return torch.sigmoid(out)


class LeNet(MetaModule):
    def __init__(self, n_out):
        super(LeNet, self).__init__()

        layers = []
        layers.append(MetaConv2d(1, 6, 3, padding=1))
        layers.append(nn.MaxPool2d(3, stride=2, padding=1))
        layers.append(nn.ReLU())

        layers.append(MetaConv2d(6, 16, 3, padding=1))
        layers.append(nn.MaxPool2d(3, stride=2, padding=1))
        layers.append(nn.ReLU())

        layers.append(MetaConv2d(16, 120, 3, padding=1))
        layers.append(nn.ReLU())

        self.main = nn.Sequential(*layers)

        layers = []
        layers.append(MetaLinear(120 * 7 * 7, 84))
        layers.append(nn.ReLU())
        layers.append(MetaLinear(84, n_out))

        self.fc_layers = nn.Sequential(*layers)

    def forward(self, x):
        x = self.main(x)
        x = x.view(-1, 120 * 7 * 7)
        x = self.fc_layers(x)
        return x


class CNN(MetaModule):
    def __init__(self, n_out):
        super(CNN, self).__init__()

        self.conv = torch.nn.Sequential(MetaConv2d(1, 16, 3, padding=1),
                                        nn.MaxPool2d(3, stride=2, padding=1),
                                        nn.ReLU(),
                                        MetaConv2d(16, 32, 3, padding=1),
                                        nn.MaxPool2d(3, stride=2, padding=1),
                                        nn.ReLU()
                                        )
        self.dense = torch.nn.Sequential(nn.Dropout(p=0.5),
                                         MetaLinear(32 * 7 * 7, n_out))

    def forward(self, x):
        x = self.conv(x)
        x = x.view(-1, 32 * 7 * 7)
        x = self.dense(x)
        return x
--------------------------------------------------------------------------------
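
Below is a minimal, illustrative smoke test, not part of the original repository, showing the input and output shapes the models above expect, assuming the MNIST-sized (1x28x28) and CIFAR-10-sized (3x32x32) inputs used in train.py:

```python
import torch

from wideresnet import CNN, WideResNet, WNet

device = "cuda" if torch.cuda.is_available() else "cpu"

cnn = CNN(n_out=6).to(device)                        # MNIST backbone used by build_model()
x_mnist = torch.randn(4, 1, 28, 28, device=device)
print(cnn(x_mnist).shape)                            # torch.Size([4, 6])

wrn = WideResNet(depth=28, widen_factor=2, n_classes=6).to(device)  # CIFAR-10 backbone
x_cifar = torch.randn(4, 3, 32, 32, device=device)
print(wrn(x_cifar).shape)                            # torch.Size([4, 6])

wnet = WNet(6, 100, 1).to(device)                    # weighting network
weights = wnet(cnn(x_mnist).softmax(1))
print(weights.shape)                                 # torch.Size([4, 1]), entries in (0, 1)
```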