├── NasBench101 ├── nas_model_search.py ├── nas_train_search.py └── nas_utils.py ├── README.md └── S1 ├── model_search.py ├── train_search.py └── utils.py /NasBench101/nas_model_search.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from nas_utils import random_choice 6 | 7 | 8 | class ConvBnRelu(nn.Module): 9 | def __init__(self, inplanes, outplanes, k): 10 | super(ConvBnRelu, self).__init__() 11 | 12 | self.op = nn.Sequential( 13 | nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, bias=False), 14 | nn.BatchNorm2d(outplanes), 15 | nn.ReLU(), 16 | 17 | nn.Conv2d(outplanes, outplanes, kernel_size=k, stride=1, padding=k // 2, bias=False), 18 | nn.BatchNorm2d(outplanes), 19 | nn.ReLU() 20 | ) 21 | 22 | def forward(self, x): 23 | return self.op(x) 24 | 25 | 26 | class MaxPool(nn.Module): 27 | def __init__(self, inplanes, outplanes): 28 | super(MaxPool, self).__init__() 29 | 30 | self.op = nn.Sequential( 31 | nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, bias=False), 32 | nn.BatchNorm2d(outplanes), 33 | nn.ReLU(), 34 | 35 | nn.MaxPool2d(3, 1, padding=1) 36 | ) 37 | 38 | def forward(self, x): 39 | return self.op(x) 40 | 41 | 42 | class Cell(nn.Module): 43 | def __init__(self, inplanes, outplanes, shadow_bn): 44 | super(Cell, self).__init__() 45 | self.inplanes = inplanes 46 | self.outplanes = outplanes 47 | self.shadow_bn = shadow_bn 48 | 49 | self.nodes = nn.ModuleList([]) 50 | for i in range(4): 51 | self.nodes.append(ConvBnRelu(self.inplanes, self.outplanes, 1)) 52 | self.nodes.append(ConvBnRelu(self.inplanes, self.outplanes, 3)) 53 | self.nodes.append(MaxPool(self.inplanes, self.outplanes)) 54 | self.nodes.append(nn.Conv2d(outplanes, outplanes, kernel_size=1, stride=1)) 55 | 56 | self.bn_list = nn.ModuleList([]) 57 | if self.shadow_bn: 58 | for j in range(4): 59 | self.bn_list.append(nn.BatchNorm2d(outplanes)) 60 | else: 61 | self.bn = nn.BatchNorm2d(outplanes) 62 | 63 | def forward(self, x, choice): 64 | path_ids = choice['path'] # eg.[0, 2, 3] 65 | op_ids = choice['op'] # eg.[1, 1, 2] 66 | x_list = [] 67 | for i, id in enumerate(path_ids): 68 | x_list.append(self.nodes[id * 3 + op_ids[i]](x)) 69 | 70 | x = sum(x_list) 71 | out = self.nodes[-1](x) 72 | return F.relu(out) 73 | 74 | 75 | class SuperNetwork(nn.Module): 76 | def __init__(self, init_channels, classes=10, shadow_bn=True): 77 | super(SuperNetwork, self).__init__() 78 | self.init_channels = init_channels 79 | 80 | self.stem = nn.Sequential( 81 | nn.Conv2d(3, self.init_channels, kernel_size=3, stride=1, padding=1, bias=False), 82 | nn.BatchNorm2d(self.init_channels), 83 | nn.ReLU(inplace=True) 84 | ) 85 | 86 | self.cell_list = nn.ModuleList([]) 87 | for i in range(9): 88 | if i in [3, 6]: 89 | self.cell_list.append(Cell(self.init_channels, self.init_channels * 2, shadow_bn=shadow_bn)) 90 | self.init_channels *= 2 91 | else: 92 | self.cell_list.append(Cell(self.init_channels, self.init_channels, shadow_bn=shadow_bn)) 93 | 94 | self.global_pooling = nn.AdaptiveAvgPool2d(1) 95 | self.classifier = nn.Linear(self.init_channels, classes) 96 | self._initialize_weights() 97 | 98 | def _initialize_weights(self): 99 | for m in self.modules(): 100 | if isinstance(m, nn.Conv2d): 101 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 102 | m.weight.data.normal_(0, math.sqrt(2. / n)) 103 | if m.bias is not None: 104 | m.bias.data.zero_() 105 | elif isinstance(m, nn.BatchNorm2d): 106 | m.weight.data.fill_(1.0) 107 | m.bias.data.zero_() 108 | elif isinstance(m, nn.Linear): 109 | n = m.weight.size(0) # fan-out 110 | init_range = 1.0 / math.sqrt(n) 111 | m.weight.data.uniform_(-init_range, init_range) 112 | m.bias.data.zero_() 113 | 114 | def forward(self, x, choice): 115 | x = self.stem(x) 116 | for i in range(9): 117 | x = self.cell_list[i](x, choice) 118 | if i in [2, 5]: 119 | x = nn.MaxPool2d(2, 2, padding=0)(x) 120 | x = self.global_pooling(x) 121 | x = x.view(-1, self.init_channels) 122 | out = self.classifier(x) 123 | 124 | return out 125 | 126 | 127 | if __name__ == '__main__': 128 | # ['conv1x1-bn-relu', 'conv3x3-bn-relu', 'maxpool3x3'] 129 | # choice = {'path': [0, 1, 2], # a list of shape (4, ) 130 | # 'op': [0, 0, 0]} # possible shapes: (), (1, ), (2, ), (3, ) 131 | choice = random_choice(3) 132 | print(choice) 133 | model = SuperNetwork(init_channels=128) 134 | input = torch.randn((1, 3, 32, 32)) 135 | print(model(input, choice)) 136 | -------------------------------------------------------------------------------- /NasBench101/nas_train_search.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import ast 4 | import argparse 5 | import nas_utils 6 | from tqdm import tqdm 7 | import torch.nn as nn 8 | from nas_utils import * 9 | from scipy.stats import kendalltau 10 | import torchvision.datasets as dset 11 | import torch.backends.cudnn as cudnn 12 | from nas_model_search import SuperNetwork 13 | 14 | 15 | def get_args(): 16 | parser = argparse.ArgumentParser("MixPath") 17 | parser.add_argument('--exp_name', type=str, required=True, help='search model name') 18 | parser.add_argument('--m', type=int, default=2, required=True, help='num of selected paths as most') 19 | parser.add_argument('--shadow_bn', action='store_false', default=True, help='shadow bn or not, default: True') 20 | parser.add_argument('--data_dir', type=str, default='/home/work/dataset/cifar', help='dataset dir') 21 | parser.add_argument('--classes', type=int, default=10, help='classes') 22 | parser.add_argument('--layers', type=int, default=12, help='num of MB_layers') 23 | parser.add_argument('--kernels', type=list, default=[3, 5, 7], help='selective kernels') 24 | parser.add_argument('--batch_size', type=int, default=96, help='batch size') 25 | parser.add_argument('--epochs', type=int, default=200, help='num of epochs') 26 | parser.add_argument('--seed', type=int, default=2020, help='seed') 27 | parser.add_argument('--search_num', type=int, default=1000, help='num of epochs') 28 | parser.add_argument('--learning_rate', type=float, default=0.025, help='initial learning rate') 29 | parser.add_argument('--learning_rate_min', type=float, default=1e-8, help='min learning rate') 30 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum') 31 | parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay') 32 | parser.add_argument('--train_interval', type=int, default=1, help='train to print frequency') 33 | parser.add_argument('--val_interval', type=int, default=5, help='evaluate and save frequency') 34 | parser.add_argument('--dropout_rate', type=float, default=0.2, help='drop out rate') 35 | parser.add_argument('--drop_path_prob', type=float, default=0.0, help='drop_path_prob') 36 | parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping') 37 | parser.add_argument('--gpu', type=int, default=0, help='gpu id') 38 | parser.add_argument('--resume', type=bool, default=False, help='resume') 39 | # ******************************* dataset *******************************# 40 | parser.add_argument('--data', type=str, default='cifar10', help='[cifar10, imagenet]') 41 | parser.add_argument('--cutout', action='store_false', default=True, help='use cutout') 42 | parser.add_argument('--cutout_length', type=int, default=16, help='cutout length') 43 | parser.add_argument('--resize', action='store_true', default=False, help='use resize') 44 | 45 | arguments = parser.parse_args() 46 | 47 | return arguments 48 | 49 | 50 | def validate_cali(args, val_data, device, model, choice): 51 | model.eval() 52 | val_loss = 0.0 53 | val_top1 = AvgrageMeter() 54 | val_top5 = AvgrageMeter() 55 | criterion = nn.CrossEntropyLoss() 56 | with torch.no_grad(): 57 | for step, (inputs, targets) in enumerate(val_data): 58 | inputs, targets = inputs.to(device), targets.to(device) 59 | 60 | outputs = model(inputs, choice) 61 | 62 | loss = criterion(outputs, targets) 63 | val_loss += loss.item() 64 | prec1, prec5 = accuracy(outputs, targets, topk=(1, 5)) 65 | n = inputs.size(0) 66 | val_top1.update(prec1.item(), n) 67 | val_top5.update(prec5.item(), n) 68 | print(val_top1.avg, ',') 69 | return val_top1.avg, val_top5.avg, val_loss / (step + 1) 70 | 71 | check_dict = [] 72 | def validate_search(args, val_data, device, model): 73 | model.eval() 74 | choice_dict = {} 75 | val_loss = 0.0 76 | val_top1 = AvgrageMeter() 77 | val_top5 = AvgrageMeter() 78 | criterion = nn.CrossEntropyLoss() 79 | choice = random_choice(m=args.m) 80 | 81 | while choice in check_dict: 82 | print('Duplicate Index !') 83 | choice = random_choice(m=args.m) 84 | check_dict.append(choice) 85 | with torch.no_grad(): 86 | for step, (inputs, targets) in enumerate(val_data): 87 | inputs, targets = inputs.to(device), targets.to(device) 88 | outputs = model(inputs, choice) 89 | loss = criterion(outputs, targets) 90 | val_loss += loss.item() 91 | prec1, prec5 = accuracy(outputs, targets, topk=(1, 5)) 92 | n = inputs.size(0) 93 | val_top1.update(prec1.item(), n) 94 | val_top5.update(prec5.item(), n) 95 | choice_dict['op'] = choice['op'] 96 | choice_dict['path'] = choice['path'] 97 | choice_dict['val_loss'] = val_loss / (step + 1) 98 | choice_dict['val_top1'] = val_top1.avg 99 | 100 | return choice_dict 101 | 102 | 103 | def train(args, epoch, train_data, device, model, criterion, optimizer): 104 | model.train() 105 | train_loss = 0.0 106 | top1 = AvgrageMeter() 107 | top5 = AvgrageMeter() 108 | 109 | for step, (inputs, targets) in enumerate(train_data): 110 | inputs, targets = inputs.to(device), targets.to(device) 111 | 112 | optimizer.zero_grad() 113 | choice = random_choice(m=args.m) 114 | outputs = model(inputs, choice) 115 | 116 | loss = criterion(outputs, targets) 117 | loss.backward() 118 | 119 | for p in model.parameters(): 120 | if p.grad is not None and p.grad.sum() == 0: 121 | p.grad = None 122 | 123 | nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) 124 | 125 | prec1, prec5 = accuracy(outputs, targets, topk=(1, 5)) 126 | n = inputs.size(0) 127 | top1.update(prec1.item(), n) 128 | top5.update(prec5.item(), n) 129 | optimizer.step() 130 | train_loss += loss.item() 131 | 132 | postfix = {'loss': '%.6f' % (train_loss / (step + 1)), 'top1': '%.3f' % top1.avg} 133 | 134 | train_data.set_postfix(postfix) 135 | 136 | 137 | def validate(args, val_data, device, model): 138 | model.eval() 139 | val_loss = 0.0 140 | val_top1 = AvgrageMeter() 141 | val_top5 = AvgrageMeter() 142 | criterion = nn.CrossEntropyLoss() 143 | 144 | with torch.no_grad(): 145 | top1_m = [] 146 | top5_m = [] 147 | loss_m = [] 148 | for _ in range(20): 149 | choice = random_choice(m=args.m) 150 | for step, (inputs, targets) in enumerate(val_data): 151 | inputs, targets = inputs.to(device), targets.to(device) 152 | outputs = model(inputs, choice) 153 | loss = criterion(outputs, targets) 154 | val_loss += loss.item() 155 | prec1, prec5 = accuracy(outputs, targets, topk=(1, 5)) 156 | n = inputs.size(0) 157 | val_top1.update(prec1.item(), n) 158 | val_top5.update(prec5.item(), n) 159 | top1_m.append(val_top1.avg), top5_m.append(val_top5.avg), loss_m.append(val_loss / (step + 1)) 160 | 161 | return np.mean(top1_m), np.mean(top5_m), np.mean(loss_m) 162 | 163 | 164 | def separate_bn_params(model): 165 | bn_index = [] 166 | bn_params = [] 167 | for m in model.modules(): 168 | if isinstance(m, nn.BatchNorm2d): 169 | bn_index += list(map(id, m.parameters())) 170 | bn_params += m.parameters() 171 | base_params = list(filter(lambda p: id(p) not in bn_index, model.parameters())) 172 | return base_params, bn_params 173 | 174 | 175 | def main(): 176 | args = get_args() 177 | print(args) 178 | 179 | if not os.path.exists('./snapshots'): 180 | os.mkdir('./snapshots') 181 | 182 | # device 183 | if not torch.cuda.is_available(): 184 | device = torch.device('cpu') 185 | else: 186 | torch.cuda.set_device(args.gpu) 187 | cudnn.benchmark = True 188 | cudnn.enabled = True 189 | device = torch.device("cuda") 190 | 191 | set_seed(args.seed) 192 | 193 | criterion = nn.CrossEntropyLoss() 194 | model = SuperNetwork(init_channels=128, shadow_bn=args.shadow_bn) 195 | model = model.to(device) 196 | print("param size = %fMB" % count_parameters_in_MB(model)) 197 | 198 | base_params, bn_params = separate_bn_params(model) 199 | 200 | optimizer = torch.optim.SGD([ 201 | {'params': base_params, 'weight_decay': args.weight_decay}, 202 | {'params': bn_params, 'weight_decay': 0.0}], 203 | lr=args.learning_rate, 204 | momentum=args.momentum) 205 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 206 | optimizer, float(args.epochs), eta_min=args.learning_rate_min, last_epoch=-1) 207 | 208 | if args.resume: 209 | resume_path = './snapshots/{}_states.pt.tar'.format(args.exp_name) 210 | if os.path.isfile(resume_path): 211 | print("Loading checkpoint '{}'".format(resume_path)) 212 | checkpoint = torch.load(resume_path) 213 | 214 | start_epoch = checkpoint['epoch'] 215 | optimizer.load_state_dict(checkpoint['optimizer_state']) 216 | model.load_state_dict(checkpoint['supernet_state']) 217 | scheduler.load_state_dict(checkpoint['scheduler_state']) 218 | else: 219 | raise ValueError("No checkpoint found at '{}'".format(resume_path)) 220 | else: 221 | start_epoch = 0 222 | 223 | train_transform, valid_transform = data_transforms_cifar(args) 224 | trainset = dset.CIFAR10(root=args.data_dir, train=True, download=False, transform=train_transform) 225 | train_queue = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, 226 | shuffle=True, pin_memory=True, num_workers=8) 227 | valset = dset.CIFAR10(root=args.data_dir, train=False, download=False, transform=valid_transform) 228 | valid_queue = torch.utils.data.DataLoader(valset, batch_size=args.batch_size, 229 | shuffle=False, pin_memory=True, num_workers=8) 230 | 231 | top1 = [] 232 | top5 = [] 233 | loss = [] 234 | for epoch in range(start_epoch, args.epochs): 235 | train_data = tqdm(train_queue) 236 | train_data.set_description( 237 | '[%s%04d/%04d %s%f]' % ('Epoch:', epoch+1, args.epochs, 'lr:', scheduler.get_lr()[0])) 238 | train(args, epoch, train_data, device, model, criterion=criterion, optimizer=optimizer) 239 | scheduler.step() 240 | 241 | if epoch % 2 == 1: 242 | # validation the model 243 | val_top1, val_top5, val_loss = validate(args, val_data=valid_queue, device=device, model=model) 244 | print('val loss: {:.6}, val top1: {:.6}'.format(val_loss, val_top1)) 245 | 246 | # save the states of this epoch 247 | state = { 248 | 'epoch': epoch, 249 | 'args': args, 250 | 'optimizer_state': optimizer.state_dict(), 251 | 'supernet_state': model.state_dict(), 252 | 'scheduler_state': scheduler.state_dict() 253 | } 254 | path = './snapshots/{}_states.pt.tar'.format(args.exp_name) 255 | torch.save(state, path) 256 | top1.append(val_top1), top5.append(val_top5), loss.append(val_loss) 257 | print('top1:', top1) 258 | print('top5:', top5) 259 | print('loss:', loss) 260 | 261 | 262 | candidate_dict = {} 263 | for epoch in range(args.search_num): 264 | # validation 265 | choice_dict = validate_search(args, val_data=valid_queue, device=device, model=model) 266 | candidate_dict[str(choice_dict)] = choice_dict['val_top1'] 267 | print('epoch: {:d},val loss: {:.6}, val top1: {:.6}'.format( 268 | epoch, choice_dict['val_loss'], choice_dict['val_top1'])) 269 | print(candidate_dict) 270 | 271 | # sort candidate_dict 272 | print('\n', '****************************** supernet *********************************') 273 | cand_dict = {k: v for k, v in sorted(candidate_dict.items(), key=lambda item: item[1])} 274 | for key in cand_dict.keys(): 275 | key = ast.literal_eval(key) 276 | print(key['val_top1'], ',') 277 | 278 | # look-up nasbench 279 | print('\n', '****************************** nas_bench *********************************') 280 | NASBENCH_TFRECORD = './nasbench_only108.tfrecord' 281 | nasbench = api.NASBench(NASBENCH_TFRECORD) 282 | nasbench_acc = [] 283 | for key in cand_dict.keys(): 284 | key = ast.literal_eval(key) 285 | choice = {} 286 | choice['op'] = key['op'] 287 | choice['path'] = key['path'] 288 | model_spec = nas_utils.conv_2_matrix(choice) 289 | data = nasbench.query(model_spec) 290 | nasbench_acc.append(data['validation_accuracy']) 291 | print(data['validation_accuracy'], ',') 292 | 293 | print('\n', '****************************** supernet **********************************') 294 | supernet_acc = [] 295 | for key in cand_dict.keys(): 296 | key = ast.literal_eval(key) 297 | supernet_acc.append(key['val_top1']) 298 | print(key['val_top1'], ',') 299 | 300 | # cali_bn 301 | print('\n', '****************************** cali_bn **********************************') 302 | cali_bn_acc = [] 303 | checkpoint = torch.load('./snapshots/{}_states.pt.tar'.format(args.exp_name)) 304 | for key in cand_dict.keys(): 305 | with torch.no_grad(): 306 | choice = {} 307 | key = ast.literal_eval(key) 308 | choice['op'] = key['op'] 309 | choice['path'] = key['path'] 310 | model.train() 311 | for inputs, targets in valid_queue: 312 | inputs, targets = inputs.to(device), targets.to(device) 313 | model(inputs, choice) 314 | top1_acc, _, _ = validate_cali(args, valid_queue, device, model, choice) 315 | cali_bn_acc.append(top1_acc) 316 | model.load_state_dict(checkpoint['supernet_state']) 317 | 318 | # ranking 319 | print('\n', '****************************** ranking **********************************') 320 | print('before_cali:', kendalltau(supernet_acc[30:], nasbench_acc[30:])) 321 | print('after_cali:', kendalltau(cali_bn_acc[30:], nasbench_acc[30:])) 322 | 323 | 324 | if __name__ == '__main__': 325 | main() 326 | -------------------------------------------------------------------------------- /NasBench101/nas_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | from nasbench import api 5 | import torchvision.transforms as transforms 6 | 7 | 8 | class AvgrageMeter(object): 9 | def __init__(self): 10 | self.reset() 11 | 12 | def reset(self): 13 | self.avg = 0 14 | self.sum = 0 15 | self.cnt = 0 16 | 17 | def update(self, val, n=1): 18 | self.sum += val * n 19 | self.cnt += n 20 | self.avg = self.sum / self.cnt 21 | 22 | 23 | class Cutout(object): 24 | def __init__(self, length): 25 | self.length = length 26 | 27 | def __call__(self, img): 28 | h, w = img.size(1), img.size(2) 29 | mask = np.ones((h, w), np.float32) 30 | y = np.random.randint(h) 31 | x = np.random.randint(w) 32 | 33 | y1 = np.clip(y - self.length // 2, 0, h) 34 | y2 = np.clip(y + self.length // 2, 0, h) 35 | x1 = np.clip(x - self.length // 2, 0, w) 36 | x2 = np.clip(x + self.length // 2, 0, w) 37 | 38 | mask[y1: y2, x1: x2] = 0. 39 | mask = torch.from_numpy(mask) 40 | mask = mask.expand_as(img) 41 | img *= mask 42 | return img 43 | 44 | 45 | def accuracy(output, label, topk=(1,)): 46 | maxk = max(topk) 47 | batch_size = label.size(0) 48 | 49 | _, pred = output.topk(maxk, 1, True, True) 50 | pred = pred.t() 51 | correct = pred.eq(label.view(1, -1).expand_as(pred)) 52 | 53 | res = [] 54 | for k in topk: 55 | correct_k = correct[:k].view(-1).float().sum(0) 56 | res.append(correct_k.mul_(100.0 / batch_size)) 57 | return res 58 | 59 | 60 | def data_transforms_cifar(args): 61 | assert args.data in ['cifar10', 'imagenet'] 62 | if args.data == 'cifar10': 63 | MEAN = [0.49139968, 0.48215827, 0.44653124] 64 | STD = [0.24703233, 0.24348505, 0.26158768] 65 | elif args.data == 'imagenet': 66 | MEAN = [0.485, 0.456, 0.406] 67 | STD = [0.229, 0.224, 0.225] 68 | 69 | if args.resize: # cifar10 or imagenet 70 | train_transform = transforms.Compose([ 71 | transforms.RandomResizedCrop(224), 72 | transforms.RandomHorizontalFlip(), 73 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2), 74 | transforms.ToTensor(), 75 | transforms.Normalize(MEAN, STD) 76 | ]) 77 | valid_transform = transforms.Compose([ 78 | transforms.Resize(256), 79 | transforms.CenterCrop(224), 80 | transforms.ToTensor(), 81 | transforms.Normalize(MEAN, STD) 82 | ]) 83 | else: # cifar10 84 | train_transform = transforms.Compose([ 85 | transforms.RandomCrop(32, padding=4), 86 | transforms.RandomHorizontalFlip(), 87 | transforms.ToTensor(), 88 | transforms.Normalize(MEAN, STD) 89 | ]) 90 | valid_transform = transforms.Compose([ 91 | transforms.ToTensor(), 92 | transforms.Normalize(MEAN, STD) 93 | ]) 94 | 95 | if args.cutout: 96 | train_transform.transforms.append(Cutout(args.cutout_length)) 97 | 98 | return train_transform, valid_transform 99 | 100 | 101 | def random_choice(m): 102 | assert m >= 1 103 | 104 | choice = {} 105 | m_ = np.random.randint(low=1, high=m+1, size=1)[0] 106 | path_list = random.sample(range(m), m_) 107 | 108 | ops = [] 109 | for i in range(m_): 110 | ops.append(random.sample(range(3), 1)[0]) 111 | # ops.append(random.sample(range(2), 1)[0]) 112 | 113 | choice['op'] = ops 114 | choice['path'] = path_list 115 | 116 | return choice 117 | 118 | 119 | def conv_2_matrix(choice): 120 | op_ids = choice['op'] 121 | path_ids = choice['path'] 122 | selections = ['conv1x1-bn-relu', 'conv3x3-bn-relu', 'maxpool3x3'] 123 | 124 | ops = ['input'] 125 | for i in range(4): # 初始默认操作 126 | ops.append(selections[0]) 127 | for i, id in enumerate(path_ids): # 按choice修改 128 | ops[id + 1] = selections[op_ids[i]] 129 | ops.append('conv1x1-bn-relu') 130 | ops.append('output') 131 | 132 | matrix = np.zeros((7, 7), dtype=np.int) 133 | for id in path_ids: 134 | matrix[0, id + 1] = 1 135 | matrix[id + 1, 5] = 1 136 | matrix[5, -1] = 1 137 | matrix = matrix.tolist() 138 | model_spec = api.ModelSpec(matrix=matrix, ops=ops) 139 | 140 | return model_spec 141 | 142 | 143 | def count_parameters_in_MB(model): 144 | return np.sum(np.prod(v.size()) for v in model.parameters())/1e6 145 | 146 | 147 | def set_seed(seed): 148 | # seed 149 | np.random.seed(seed) 150 | random.seed(seed) 151 | torch.manual_seed(seed) 152 | if torch.cuda.is_available(): 153 | torch.cuda.manual_seed_all(seed) 154 | torch.backends.cudnn.deterministic = True 155 | torch.backends.cudnn.benchmark = False 156 | 157 | 158 | 159 | if __name__ == '__main__': 160 | set_seed(2020) 161 | for i in range(10): 162 | choice = random_choice(m=2) 163 | print(choice) 164 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MixPath: A Unified Approach for One-shot Neural Architecture Search 2 | 3 | This repo provides the supernet of S1 and our confirmatory experiments on NAS-Bench-101. 4 | 5 | 6 | ## Requirements 7 | 8 | ``` 9 | Python >= 3.6, Pytorch >= 1.0.0, torchvision >= 0.2.0 10 | ``` 11 | 12 | ## Datasets 13 | 14 | CIFAR-10 can be automatically downloaded by `torchvision`. It has 50,000 images for 15 | training and 10,000 images for validation. 16 | 17 | ## Usage 18 | 19 | ``` 20 | python S1/train_search.py \ 21 | --exp_name experiment_name \ 22 | --m number_of_paths[1,2,3,4] 23 | --data_dir /path/to/dataset \ 24 | --seed 2020 \ 25 | ``` 26 | ``` 27 | python NasBench101/nas_train_search.py \ 28 | --exp_name experiment_name \ 29 | --m number_of_paths[1,2,3,4] 30 | --data_dir /path/to/dataset \ 31 | --seed 2020 \ 32 | ``` 33 | 34 | ## Citation 35 | 36 | 37 | ``` 38 | @article{chu2020mixpath, 39 | title={MixPath: A Unified Approach for One-shot Neural Architecture Search}, 40 | author={Chu, Xiangxiang and Li, Xudong and Lu, Yi and Zhang, Bo and Li, Jixiang}, 41 | journal={arXiv preprint arXiv:2001.05887}, 42 | url={https://arxiv.org/abs/2001.05887}, 43 | year={2020} 44 | } 45 | ``` 46 | -------------------------------------------------------------------------------- /S1/model_search.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class Inverted_Bottleneck(nn.Module): 7 | def __init__(self, inplanes, outplanes, shadow_bn, stride, activation=nn.ReLU6): 8 | super(Inverted_Bottleneck, self).__init__() 9 | self.inplanes = inplanes 10 | self.outplanes = outplanes 11 | self.shadow_bn = shadow_bn 12 | self.stride = stride 13 | self.kernel_list = [3, 5, 7, 9] 14 | self.expansion_rate = [3, 6] 15 | self.activation = activation(inplace=True) 16 | 17 | self.pw = nn.ModuleList([]) 18 | self.mix_conv = nn.ModuleList([]) 19 | self.mix_bn = nn.ModuleList([]) 20 | self.pw_linear = nn.ModuleList([]) 21 | 22 | for t in self.expansion_rate: 23 | # pw 24 | self.pw.append(nn.Sequential( 25 | nn.Conv2d(inplanes, inplanes * t, kernel_size=1, bias=False), 26 | nn.BatchNorm2d(inplanes * t), 27 | activation(inplace=True) 28 | )) 29 | # dw 30 | conv_list = nn.ModuleList([]) 31 | for j in self.kernel_list: 32 | conv_list.append(nn.Sequential( 33 | nn.Conv2d(inplanes * t, inplanes * t, kernel_size=j, stride=stride, padding=j // 2, 34 | bias=False, groups=inplanes * t), 35 | nn.BatchNorm2d(inplanes * t), 36 | activation(inplace=True) 37 | )) 38 | 39 | self.mix_conv.append(conv_list) 40 | del conv_list 41 | # pw 42 | self.pw_linear.append(nn.Conv2d(inplanes * t, outplanes, kernel_size=1, bias=False)) 43 | 44 | bn_list = nn.ModuleList([]) 45 | if self.shadow_bn: 46 | for j in range(len(self.kernel_list)): 47 | bn_list.append(nn.BatchNorm2d(outplanes)) 48 | self.mix_bn.append(bn_list) 49 | else: 50 | self.mix_bn.append(nn.BatchNorm2d(outplanes)) 51 | del bn_list 52 | 53 | def forward(self, x, choice): 54 | # choice: {'conv', 'rate'} 55 | conv_ids = choice['conv'] # conv_ids, e.g. [0], [1], [2], [0, 1], [0, 2], [1, 2], [0, 1, 2] 56 | m_ = len(conv_ids) # num of selected paths 57 | rate_id = choice['rate'] # rate_ids, e.g. 0, 1 58 | assert m_ in [1, 2, 3, 4] 59 | assert rate_id in [0, 1] 60 | residual = x 61 | # pw 62 | out = self.pw[rate_id](x) 63 | # dw 64 | if m_ == 1: 65 | out = self.mix_conv[rate_id][conv_ids[0]](out) 66 | else: 67 | temp = [] 68 | for id in conv_ids: 69 | temp.append(self.mix_conv[rate_id][id](out)) 70 | out = sum(temp) 71 | # pw 72 | out = self.pw_linear[rate_id](out) 73 | if self.shadow_bn: 74 | out = self.mix_bn[rate_id][m_ - 1](out) 75 | else: 76 | out = self.mix_bn[rate_id](out) 77 | 78 | # residual 79 | if self.stride == 1 and self.inplanes == self.outplanes: 80 | out = out + residual 81 | return out 82 | 83 | 84 | channel = [32, 48, 48, 96, 96, 96, 192, 192, 192, 256, 256, 320, 320] 85 | last_channel = 1280 86 | 87 | 88 | class SuperNetwork(nn.Module): 89 | def __init__(self, shadow_bn, layers=12, classes=10): 90 | super(SuperNetwork, self).__init__() 91 | self.layers = layers 92 | 93 | self.stem = nn.Sequential( 94 | nn.Conv2d(3, channel[0], kernel_size=3, stride=1, padding=1, bias=False), 95 | nn.BatchNorm2d(channel[0]), 96 | nn.ReLU6(inplace=True) 97 | ) 98 | 99 | self.Inverted_Block = nn.ModuleList([]) 100 | for i in range(self.layers): 101 | if i in [2, 5]: 102 | self.Inverted_Block.append(Inverted_Bottleneck(channel[i], channel[i + 1], shadow_bn, stride=2)) 103 | else: 104 | self.Inverted_Block.append(Inverted_Bottleneck(channel[i], channel[i + 1], shadow_bn, stride=1)) 105 | self.last_conv = nn.Sequential( 106 | nn.Conv2d(channel[-1], last_channel, kernel_size=1, stride=1, padding=0, bias=False), 107 | nn.BatchNorm2d(last_channel), 108 | nn.ReLU6(inplace=True) 109 | ) 110 | 111 | self.global_pooling = nn.AdaptiveAvgPool2d(1) 112 | self.classifier = nn.Linear(last_channel, classes) 113 | self._initialize_weights() 114 | 115 | def _initialize_weights(self): 116 | for m in self.modules(): 117 | if isinstance(m, nn.Conv2d): 118 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 119 | m.weight.data.normal_(0, math.sqrt(2. / n)) 120 | if m.bias is not None: 121 | m.bias.data.zero_() 122 | elif isinstance(m, nn.BatchNorm2d): 123 | m.weight.data.fill_(1.0) 124 | m.bias.data.zero_() 125 | elif isinstance(m, nn.Linear): 126 | n = m.weight.size(0) # fan-out 127 | init_range = 1.0 / math.sqrt(n) 128 | m.weight.data.uniform_(-init_range, init_range) 129 | m.bias.data.zero_() 130 | 131 | def forward(self, x, choice=None): 132 | x = self.stem(x) 133 | for i in range(self.layers): 134 | x = self.Inverted_Block[i](x, choice[i]) 135 | x = self.last_conv(x) 136 | x = self.global_pooling(x) 137 | x = x.view(-1, last_channel) 138 | x = self.classifier(x) 139 | return x 140 | 141 | 142 | if __name__ == '__main__': 143 | choice = { 144 | 0: {'conv': [0, 0], 'rate': 1}, 145 | 1: {'conv': [0, 0], 'rate': 1}, 146 | 2: {'conv': [0, 0], 'rate': 1}, 147 | 3: {'conv': [0, 0], 'rate': 1}, 148 | 4: {'conv': [0, 0], 'rate': 1}, 149 | 5: {'conv': [0, 0], 'rate': 1}, 150 | 6: {'conv': [0, 0], 'rate': 1}, 151 | 7: {'conv': [0, 0], 'rate': 1}, 152 | 8: {'conv': [0, 0], 'rate': 1}, 153 | 9: {'conv': [0, 0], 'rate': 1}, 154 | 10: {'conv': [0, 0], 'rate': 1}, 155 | 11: {'conv': [0, 0], 'rate': 1}} 156 | 157 | model = SuperNetwork(shadow_bn=False, layers=12, classes=10) 158 | print(model) 159 | input = torch.randn((1, 3, 32, 32)) 160 | print(model(input, choice)) 161 | -------------------------------------------------------------------------------- /S1/train_search.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from utils import * 4 | import torch.nn as nn 5 | from tqdm import tqdm 6 | import torchvision.datasets as dset 7 | import torch.backends.cudnn as cudnn 8 | from model_search import SuperNetwork 9 | 10 | 11 | def get_args(): 12 | parser = argparse.ArgumentParser("MixPath") 13 | parser.add_argument('--exp_name', type=str, required=True, help='search model name') 14 | parser.add_argument('--m', type=int, default=2, required=True, help='num of selected paths as most') 15 | parser.add_argument('--shadow_bn', action='store_false', default=True, help='shadow bn or not, default: True') 16 | parser.add_argument('--data_dir', type=str, default='/home/work/dataset/cifar', help='dataset dir') 17 | parser.add_argument('--classes', type=int, default=10, help='classes') 18 | parser.add_argument('--layers', type=int, default=12, help='num of MB_layers') 19 | parser.add_argument('--kernels', type=list, default=[3, 5, 7, 9], help='selective kernels') 20 | parser.add_argument('--batch_size', type=int, default=96, help='batch size') 21 | parser.add_argument('--epochs', type=int, default=200, help='num of epochs') 22 | parser.add_argument('--seed', type=int, default=2020, help='seed') 23 | parser.add_argument('--search_num', type=int, default=1000, help='num of epochs') 24 | parser.add_argument('--learning_rate', type=float, default=0.025, help='initial learning rate') 25 | parser.add_argument('--learning_rate_min', type=float, default=1e-8, help='min learning rate') 26 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum') 27 | parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay') 28 | parser.add_argument('--train_interval', type=int, default=1, help='train to print frequency') 29 | parser.add_argument('--val_interval', type=int, default=5, help='evaluate and save frequency') 30 | parser.add_argument('--dropout_rate', type=float, default=0.2, help='drop out rate') 31 | parser.add_argument('--drop_path_prob', type=float, default=0.0, help='drop_path_prob') 32 | parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping') 33 | parser.add_argument('--gpu', type=int, default=0, help='gpu id') 34 | parser.add_argument('--resume', type=bool, default=False, help='resume') 35 | # ******************************* dataset *******************************# 36 | parser.add_argument('--dataset', type=str, default='cifar10', help='[cifar10, imagenet]') 37 | parser.add_argument('--cutout', action='store_false', default=True, help='use cutout') 38 | parser.add_argument('--cutout_length', type=int, default=16, help='cutout length') 39 | parser.add_argument('--colorjitter', action='store_true', default=False, help='use colorjitter') 40 | arguments = parser.parse_args() 41 | 42 | return arguments 43 | 44 | 45 | def train(args, epoch, train_data, device, model, criterion, optimizer): 46 | model.train() 47 | train_loss = 0.0 48 | top1 = AvgrageMeter() 49 | top5 = AvgrageMeter() 50 | 51 | for step, (inputs, targets) in enumerate(train_data): 52 | inputs, targets = inputs.to(device), targets.to(device) 53 | 54 | optimizer.zero_grad() 55 | choice = random_choice(path_num=len(args.kernels), m=args.m, layers=args.layers) 56 | outputs = model(inputs, choice) 57 | 58 | loss = criterion(outputs, targets) 59 | loss.backward() 60 | nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) 61 | 62 | prec1, prec5 = accuracy(outputs, targets, topk=(1, 5)) 63 | n = inputs.size(0) 64 | top1.update(prec1.item(), n) 65 | top5.update(prec5.item(), n) 66 | optimizer.step() 67 | train_loss += loss.item() 68 | 69 | postfix = {'train loss: {:.6}, train top1: {:.6}, train top5: {:.6}'.format( 70 | train_loss / (step + 1), top1.avg, top5.avg 71 | )} 72 | train_data.set_postfix(log=postfix) 73 | 74 | 75 | def validate(args, val_data, device, model, choice=None): 76 | model.eval() 77 | val_loss = 0.0 78 | val_top1 = AvgrageMeter() 79 | val_top5 = AvgrageMeter() 80 | criterion = nn.CrossEntropyLoss() 81 | acc_list = [] 82 | 83 | with torch.no_grad(): 84 | for step, (inputs, targets) in enumerate(val_data): 85 | inputs, targets = inputs.to(device), targets.to(device) 86 | if choice is None: 87 | choice = random_choice(path_num=len(args.kernels), m=args.m, layers=args.layers) 88 | outputs = model(inputs, choice) 89 | 90 | loss = criterion(outputs, targets) 91 | val_loss += loss.item() 92 | prec1, prec5 = accuracy(outputs, targets, topk=(1, 5)) 93 | n = inputs.size(0) 94 | val_top1.update(prec1.item(), n) 95 | val_top5.update(prec5.item(), n) 96 | acc_list.append(val_top1.avg) 97 | 98 | return val_top1.avg, val_top5.avg, val_loss / (step + 1), acc_list 99 | 100 | 101 | def main(): 102 | args = get_args() 103 | print(args) 104 | # seed 105 | set_seed(args.seed) 106 | 107 | # prepare dir 108 | if not os.path.exists('./super_train'): 109 | os.mkdir('./super_train') 110 | if not os.path.exists('./super_train/{}'.format(args.exp_name)): 111 | os.mkdir('./super_train/{}'.format(args.exp_name)) 112 | 113 | # device 114 | if not torch.cuda.is_available(): 115 | device = torch.device('cpu') 116 | else: 117 | torch.cuda.set_device(args.gpu) 118 | cudnn.benchmark = True 119 | cudnn.enabled = True 120 | device = torch.device("cuda") 121 | 122 | criterion = nn.CrossEntropyLoss().to(device) 123 | model = SuperNetwork(shadow_bn=args.shadow_bn, layers=args.layers, classes=args.classes) 124 | model = model.to(device) 125 | print("param size = %fMB" % count_parameters_in_MB(model)) 126 | 127 | optimizer = torch.optim.SGD( 128 | model.parameters(), 129 | args.learning_rate, 130 | momentum=args.momentum, 131 | weight_decay=args.weight_decay) 132 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 133 | optimizer, float(args.epochs), eta_min=args.learning_rate_min, last_epoch=-1) 134 | 135 | if args.resume: 136 | resume_path = './super_train/{}/super_train_states.pt.tar'.format(args.exp_name) 137 | if os.path.isfile(resume_path): 138 | print("Loading checkpoint '{}'".format(resume_path)) 139 | checkpoint = torch.load(resume_path) 140 | 141 | start_epoch = checkpoint['epoch'] 142 | optimizer.load_state_dict(checkpoint['optimizer_state']) 143 | model.load_state_dict(checkpoint['supernet_state']) 144 | scheduler.laod_state_dict(checkpoint['scheduler_state']) 145 | else: 146 | raise ValueError("No checkpoint found at '{}'".format(resume_path)) 147 | else: 148 | start_epoch = 0 149 | 150 | train_transform, valid_transform = data_transforms_cifar(args) 151 | trainset = dset.CIFAR10(root=args.data_dir, train=True, download=False, transform=train_transform) 152 | train_queue = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, 153 | shuffle=True, pin_memory=True, num_workers=8) 154 | valset = dset.CIFAR10(root=args.data_dir, train=False, download=False, transform=valid_transform) 155 | valid_queue = torch.utils.data.DataLoader(valset, batch_size=args.batch_size, 156 | shuffle=False, pin_memory=True, num_workers=8) 157 | 158 | val_acc_list = [] 159 | for epoch in range(start_epoch, args.epochs): 160 | # train 161 | train_data = tqdm(train_queue) 162 | train_data.set_description( 163 | '[%s%04d/%04d %s%f]' % ('Epoch:', epoch, args.epochs, 'lr:', scheduler.get_lr()[0])) 164 | train(args, epoch, train_data, device, model, criterion=criterion, optimizer=optimizer) 165 | scheduler.step() 166 | 167 | # validate 168 | val_top1, val_top5, val_obj, val_acc = validate(args, val_data=valid_queue, device=device, model=model) 169 | val_acc_list.append(val_acc) 170 | print('val loss: {:.6}, val top1: {:.6}, val top5: {:.6}'.format(val_obj, val_top1, val_top5)) 171 | print(val_acc) 172 | 173 | # save the states of this epoch 174 | state = { 175 | 'epoch': epoch, 176 | 'args': args, 177 | 'optimizer_state': optimizer.state_dict(), 178 | 'supernet_state': model.state_dict(), 179 | 'scheduler_state': scheduler.state_dict() 180 | } 181 | path = './super_train/{}/super_train_states.pt.tar'.format(args.exp_name) 182 | torch.save(state, path) 183 | print(val_acc_list) 184 | 185 | 186 | if __name__ == '__main__': 187 | main() 188 | -------------------------------------------------------------------------------- /S1/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import random 4 | import numpy as np 5 | import collections 6 | from torch.autograd import Variable 7 | import torchvision.transforms as transforms 8 | 9 | 10 | class AvgrageMeter(object): 11 | def __init__(self): 12 | self.reset() 13 | 14 | def reset(self): 15 | self.avg = 0 16 | self.sum = 0 17 | self.cnt = 0 18 | 19 | def update(self, val, n=1): 20 | self.sum += val * n 21 | self.cnt += n 22 | self.avg = self.sum / self.cnt 23 | 24 | 25 | class Cutout(object): 26 | def __init__(self, length): 27 | self.length = length 28 | 29 | def __call__(self, img): 30 | h, w = img.size(1), img.size(2) 31 | mask = np.ones((h, w), np.float32) 32 | y = np.random.randint(h) 33 | x = np.random.randint(w) 34 | 35 | y1 = np.clip(y - self.length // 2, 0, h) 36 | y2 = np.clip(y + self.length // 2, 0, h) 37 | x1 = np.clip(x - self.length // 2, 0, w) 38 | x2 = np.clip(x + self.length // 2, 0, w) 39 | 40 | mask[y1: y2, x1: x2] = 0. 41 | mask = torch.from_numpy(mask) 42 | mask = mask.expand_as(img) 43 | img *= mask 44 | return img 45 | 46 | def accuracy(output, label, topk=(1,)): 47 | maxk = max(topk) 48 | batch_size = label.size(0) 49 | 50 | _, pred = output.topk(maxk, 1, True, True) 51 | pred = pred.t() 52 | correct = pred.eq(label.view(1, -1).expand_as(pred)) 53 | 54 | res = [] 55 | for k in topk: 56 | correct_k = correct[:k].view(-1).float().sum(0) 57 | res.append(correct_k.mul_(100.0 / batch_size)) 58 | return res 59 | 60 | def data_transforms_cifar(args): 61 | assert args.dataset in ['cifar10', 'imagenet'] 62 | if args.dataset == 'cifar10': 63 | MEAN = [0.49139968, 0.48215827, 0.44653124] 64 | STD = [0.24703233, 0.24348505, 0.26158768] 65 | elif args.dataset == 'imagenet': 66 | MEAN = [0.485, 0.456, 0.406] 67 | STD = [0.229, 0.224, 0.225] 68 | 69 | if args.dataset == 'cifar10': 70 | random_transform = [ 71 | transforms.RandomCrop(32, padding=4), 72 | transforms.RandomHorizontalFlip()] 73 | if args.colorjitter: 74 | random_transform += [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2)] 75 | normalize_transform = [ 76 | transforms.ToTensor(), 77 | transforms.Normalize(MEAN, STD)] 78 | train_transform = transforms.Compose( 79 | random_transform + normalize_transform 80 | ) 81 | valid_transform = transforms.Compose([ 82 | transforms.ToTensor(), 83 | transforms.Normalize(MEAN, STD) 84 | ]) 85 | elif args.dataset == 'imagenet': 86 | train_transform = transforms.Compose([ 87 | transforms.RandomResizedCrop(224), 88 | transforms.RandomHorizontalFlip(), 89 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2), 90 | transforms.ToTensor(), 91 | transforms.Normalize(MEAN, STD) 92 | ]) 93 | valid_transform = transforms.Compose([ 94 | transforms.Resize(256), 95 | transforms.CenterCrop(224), 96 | transforms.ToTensor(), 97 | transforms.Normalize(MEAN, STD) 98 | ]) 99 | 100 | if args.cutout: 101 | train_transform.transforms.append(Cutout(args.cutout_length)) 102 | 103 | return train_transform, valid_transform 104 | 105 | 106 | def random_choice(path_num, m, layers): 107 | # choice = {} 108 | choice = collections.OrderedDict() 109 | for i in range(layers): 110 | # expansion rate 111 | rate = np.random.randint(low=0, high=2, size=1)[0] 112 | # conv 113 | m_ = np.random.randint(low=1, high=(m+1), size=1)[0] 114 | rand_conv = random.sample(range(path_num), m_) 115 | choice[i] = {'conv': rand_conv, 'rate': rate} 116 | return choice 117 | 118 | 119 | def count_parameters_in_MB(model): 120 | return np.sum(np.prod(v.size()) for v in model.parameters())/1e6 121 | 122 | 123 | def eta_time(elapse, epoch): 124 | eta = epoch * elapse 125 | hour = eta // 3600 126 | minute = (eta - hour * 3600) // 60 127 | second = eta - hour * 3600 - minute * 60 128 | return hour, minute, second 129 | 130 | 131 | def time_record(start): 132 | end = time.time() 133 | duration = end - start 134 | hour = duration // 3600 135 | minute = (duration - hour * 3600) // 60 136 | second = duration - hour * 3600 - minute * 60 137 | print('Elapsed: hour: %d, minute: %d, second: %f' % (hour, minute, second)) 138 | 139 | 140 | def drop_path(x, drop_prob): 141 | if drop_prob > 0.: 142 | keep_prob = 1. - drop_prob 143 | if str(x.device) == 'cpu': 144 | mask = Variable(torch.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob)) 145 | else: 146 | mask = Variable(torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob)) 147 | x.div_(keep_prob) 148 | x.mul_(mask) 149 | return x 150 | 151 | def set_seed(seed): 152 | # seed 153 | np.random.seed(seed) 154 | random.seed(seed) 155 | torch.manual_seed(seed) 156 | if torch.cuda.is_available(): 157 | torch.cuda.manual_seed_all(seed) 158 | torch.backends.cudnn.deterministic = True 159 | 160 | 161 | if __name__ == '__main__': 162 | for i in range(8): 163 | np.random.seed(12) 164 | random.seed(12) 165 | choice = random_choice(path_num=3, m=2, layers=12) 166 | print(choice) 167 | --------------------------------------------------------------------------------