├── utils ├── losses.py ├── utils.py ├── summary.py ├── preprocessing.py ├── second_order_update.py └── dataset.py ├── README.md ├── genotypes.py ├── nets ├── operations.py ├── eval_model.py └── search_model.py ├── cifar_eval.py ├── imgnet_eval.py └── cifar_search.py /utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class CrossEntropyLabelSmooth(nn.Module): 5 | def __init__(self, num_classes, epsilon): 6 | super(CrossEntropyLabelSmooth, self).__init__() 7 | self.num_classes = num_classes 8 | self.epsilon = epsilon 9 | self.logsoftmax = nn.LogSoftmax(dim=1) 10 | 11 | def forward(self, inputs, targets): 12 | log_probs = self.logsoftmax(inputs) 13 | targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1) 14 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 15 | loss = (-targets * log_probs).mean(0).sum() 16 | return loss 17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A simple Pytorch implementation of Differentiable Architecture Search (DARTS) 2 | 3 | This repository is my pytorch implementation of [Differentiable Architecture Search (DARTS)](https://arxiv.org/abs/1806.09055). 4 | Some of the code is taken from the [offical implementation](https://github.com/quark0/darts). 5 | 6 | ## Requirements: 7 | - python >= 3.5 8 | - pytorch >= 1.0 9 | - tensorboardX (optional) 10 | 11 | ## Search 12 | ``` 13 | python3 cifar_search.py --log_name darts_cifar_search --order 1st --gpus 0 14 | ``` 15 | 16 | ## Evaluate 17 | * ```python cifar_eval.py --log_name darts_cifar_search --gpus 0``` 18 | * ```python imgnet_eval.py --log_name darts_cifar_search --data_dir YOUR_IMGNET_DIR --gpus 0,1,2,3``` 19 | 20 | ## CIFAR10 Results: 21 | Method|Acc.| 22 | :---:|:---: 23 | DARTS 1st order|97.06%| 24 | DARTS 2st order|97.36%| 25 | 26 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import glob 4 | import shutil 5 | import pickle 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | def snapshot(ckpt_name): 12 | os.makedirs(os.path.join('./scripts', ckpt_name), exist_ok=True) 13 | os.makedirs(os.path.join('./scripts', ckpt_name, 'nets'), exist_ok=True) 14 | os.makedirs(os.path.join('./scripts', ckpt_name, 'utils'), exist_ok=True) 15 | for script in glob.glob('*.py'): 16 | dst_file = os.path.join('scripts', ckpt_name, script) 17 | shutil.copyfile(script, dst_file) 18 | for script in glob.glob('nets/*.py'): 19 | dst_file = os.path.join('scripts', ckpt_name, script) 20 | shutil.copyfile(script, dst_file) 21 | for script in glob.glob('utils/*.py'): 22 | dst_file = os.path.join('scripts', ckpt_name, script) 23 | shutil.copyfile(script, dst_file) 24 | 25 | 26 | def save_genotype(model, ckpt_dir): 27 | print('Genotype: ', model.genotype()) 28 | genotype = {'genotype': model.genotype()} 29 | with open(os.path.join(ckpt_dir, 'genotype.pickle'), 'wb') as handle: 30 | pickle.dump(genotype, handle, protocol=pickle.HIGHEST_PROTOCOL) 31 | 32 | 33 | def save_checkpoint(model, ckpt_dir, append=''): 34 | torch.save(model.state_dict(), os.path.join(ckpt_dir, 'checkpoint%s.t7' % append)) 35 | print('checkpoint saved to %s !' % os.path.join(ckpt_dir, 'checkpoint%s.t7' % append)) 36 | 37 | 38 | def count_parameters(model): 39 | num_paras = [v.numel() / 1e6 for k, v in model.named_parameters() if 'aux' not in k] 40 | print("Total num of param = %f M" % sum(num_paras)) 41 | 42 | 43 | def count_flops(model, input_size=224): 44 | flops = [] 45 | handles = [] 46 | 47 | def conv_hook(self, input, output): 48 | flops.append(output.shape[2] ** 2 * 49 | self.kernel_size[0] ** 2 * 50 | self.in_channels * 51 | self.out_channels / 52 | self.groups / 1e6) 53 | 54 | def fc_hook(self, input, output): 55 | flops.append(self.in_features * self.out_features / 1e6) 56 | 57 | for m in model.modules(): 58 | if isinstance(m, nn.Conv2d): 59 | handles.append(m.register_forward_hook(conv_hook)) 60 | if isinstance(m, nn.Linear): 61 | handles.append(m.register_forward_hook(fc_hook)) 62 | 63 | _ = model(torch.randn(2, 3, input_size, input_size)) 64 | print("Total FLOPs = %f M" % sum(flops)) 65 | 66 | 67 | class DisablePrint: 68 | def __enter__(self): 69 | self._original_stdout = sys.stdout 70 | sys.stdout = open(os.devnull, 'w') 71 | 72 | def __exit__(self, exc_type, exc_val, exc_tb): 73 | sys.stdout.close() 74 | sys.stdout = self._original_stdout 75 | -------------------------------------------------------------------------------- /utils/summary.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import logging 4 | from datetime import datetime 5 | 6 | import torch 7 | 8 | # return a fake summarywriter if tensorbaordX is not installed 9 | try: 10 | from tensorboardX import SummaryWriter 11 | except ImportError: 12 | class SummaryWriter: 13 | def __init__(self, log_dir=None, comment='', **kwargs): 14 | print('\nunable to import tensorboardX, summary will be recorded by torch!\n') 15 | self.log_dir = log_dir if log_dir is not None else './logs' 16 | os.makedirs('./logs', exist_ok=True) 17 | self.logs = {'comment': comment} 18 | return 19 | 20 | def add_scalar(self, tag, scalar_value, global_step=None, walltime=None): 21 | if tag in self.logs: 22 | self.logs[tag].append((scalar_value, global_step, walltime)) 23 | else: 24 | self.logs[tag] = [(scalar_value, global_step, walltime)] 25 | return 26 | 27 | def close(self): 28 | timestamp = str(datetime.now()).replace(' ', '_').replace(':', '_') 29 | torch.save(self.logs, os.path.join(self.log_dir, 'summary_%s.pickle' % timestamp)) 30 | return 31 | 32 | 33 | class EmptySummaryWriter: 34 | def __init__(self, **kwargs): 35 | pass 36 | 37 | def add_scalar(self, tag, scalar_value, global_step=None, walltime=None): 38 | pass 39 | 40 | def close(self): 41 | pass 42 | 43 | 44 | def create_summary(distributed_rank=0, **kwargs): 45 | if distributed_rank > 0: 46 | return EmptySummaryWriter(**kwargs) 47 | else: 48 | return SummaryWriter(**kwargs) 49 | 50 | 51 | def create_logger(distributed_rank=0, save_dir=None): 52 | logger = logging.getLogger('logger') 53 | logger.setLevel(logging.DEBUG) 54 | 55 | filename = "log_%s.txt" % (datetime.now().strftime("%Y_%m_%d_%H_%M_%S")) 56 | 57 | # don't log results for the non-master process 58 | if distributed_rank > 0: 59 | return logger 60 | ch = logging.StreamHandler(stream=sys.stdout) 61 | ch.setLevel(logging.DEBUG) 62 | # formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 63 | formatter = logging.Formatter("[%(asctime)s] %(message)s") 64 | ch.setFormatter(formatter) 65 | logger.addHandler(ch) 66 | 67 | if save_dir is not None: 68 | fh = logging.FileHandler(os.path.join(save_dir, filename)) 69 | fh.setLevel(logging.DEBUG) 70 | fh.setFormatter(formatter) 71 | logger.addHandler(fh) 72 | 73 | return logger 74 | 75 | 76 | if __name__ == '__main__': 77 | sw = create_summary(distributed_rank=1, log_dir='./') 78 | sw.close() 79 | if __name__ == '__main__': 80 | logger = create_logger(save_dir='./', distributed_rank=0) 81 | logger.info('this is info') 82 | print(logging.getLogger('logger').info('this is info')) 83 | -------------------------------------------------------------------------------- /genotypes.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat') 4 | 5 | PRIMITIVES = [ 6 | 'none', 7 | 'max_pool_3x3', 8 | 'avg_pool_3x3', 9 | 'skip_connect', 10 | 'sep_conv_3x3', 11 | 'sep_conv_5x5', 12 | 'dil_conv_3x3', 13 | 'dil_conv_5x5' 14 | ] 15 | 16 | NASNet = Genotype( 17 | normal=[ 18 | ('sep_conv_5x5', 1), 19 | ('sep_conv_3x3', 0), 20 | ('sep_conv_5x5', 0), 21 | ('sep_conv_3x3', 0), 22 | ('avg_pool_3x3', 1), 23 | ('skip_connect', 0), 24 | ('avg_pool_3x3', 0), 25 | ('avg_pool_3x3', 0), 26 | ('sep_conv_3x3', 1), 27 | ('skip_connect', 1), 28 | ], 29 | normal_concat=[2, 3, 4, 5, 6], 30 | reduce=[ 31 | ('sep_conv_5x5', 1), 32 | ('sep_conv_7x7', 0), 33 | ('max_pool_3x3', 1), 34 | ('sep_conv_7x7', 0), 35 | ('avg_pool_3x3', 1), 36 | ('sep_conv_5x5', 0), 37 | ('skip_connect', 3), 38 | ('avg_pool_3x3', 2), 39 | ('sep_conv_3x3', 2), 40 | ('max_pool_3x3', 1), 41 | ], 42 | reduce_concat=[4, 5, 6], 43 | ) 44 | 45 | AmoebaNet = Genotype( 46 | normal=[ 47 | ('avg_pool_3x3', 0), 48 | ('max_pool_3x3', 1), 49 | ('sep_conv_3x3', 0), 50 | ('sep_conv_5x5', 2), 51 | ('sep_conv_3x3', 0), 52 | ('avg_pool_3x3', 3), 53 | ('sep_conv_3x3', 1), 54 | ('skip_connect', 1), 55 | ('skip_connect', 0), 56 | ('avg_pool_3x3', 1), 57 | ], 58 | normal_concat=[4, 5, 6], 59 | reduce=[ 60 | ('avg_pool_3x3', 0), 61 | ('sep_conv_3x3', 1), 62 | ('max_pool_3x3', 0), 63 | ('sep_conv_7x7', 2), 64 | ('sep_conv_7x7', 0), 65 | ('avg_pool_3x3', 1), 66 | ('max_pool_3x3', 0), 67 | ('max_pool_3x3', 1), 68 | ('conv_7x1_1x7', 0), 69 | ('sep_conv_3x3', 5), 70 | ], 71 | reduce_concat=[3, 4, 6] 72 | ) 73 | 74 | DARTS_V1 = Genotype( 75 | normal=[('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('skip_connect', 0), ('sep_conv_3x3', 1), 76 | ('skip_connect', 0), 77 | ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('skip_connect', 2)], 78 | normal_concat=[2, 3, 4, 5], 79 | reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 0), 80 | ('max_pool_3x3', 0), 81 | ('skip_connect', 2), ('skip_connect', 2), ('avg_pool_3x3', 0)], 82 | reduce_concat=[2, 3, 4, 5]) 83 | DARTS_V2 = Genotype( 84 | normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), 85 | ('sep_conv_3x3', 1), 86 | ('skip_connect', 0), ('skip_connect', 0), ('dil_conv_3x3', 2)], 87 | normal_concat=[2, 3, 4, 5], 88 | reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 1), 89 | ('max_pool_3x3', 0), 90 | ('skip_connect', 2), ('skip_connect', 2), ('max_pool_3x3', 1)], 91 | reduce_concat=[2, 3, 4, 5]) 92 | -------------------------------------------------------------------------------- /utils/preprocessing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torchvision.transforms as transforms 5 | 6 | 7 | class Cutout: 8 | def __init__(self, length): 9 | self.length = length 10 | 11 | def __call__(self, img): 12 | h, w = img.size(1), img.size(2) 13 | mask = np.ones((h, w), np.float32) 14 | y = np.random.randint(h) 15 | x = np.random.randint(w) 16 | 17 | y1 = np.clip(y - self.length // 2, 0, h) 18 | y2 = np.clip(y + self.length // 2, 0, h) 19 | x1 = np.clip(x - self.length // 2, 0, w) 20 | x2 = np.clip(x + self.length // 2, 0, w) 21 | 22 | mask[y1: y2, x1: x2] = 0. 23 | mask = torch.from_numpy(mask) 24 | mask = mask.expand_as(img) 25 | img *= mask 26 | return img 27 | 28 | 29 | def mnist_transform(is_training=True): 30 | if is_training: 31 | transform_list = transforms.Compose([transforms.ToTensor(), 32 | transforms.Normalize((0.1307,), (0.3081,))]) 33 | else: 34 | transform_list = transforms.Compose([transforms.ToTensor(), 35 | transforms.Normalize((0.1307,), (0.3081,))]) 36 | return transform_list 37 | 38 | 39 | def cifar_search_transform(is_training=True, cutout=None): 40 | transform_list = [] 41 | if is_training: 42 | transform_list += [transforms.RandomCrop(32, padding=4), 43 | transforms.RandomHorizontalFlip()] 44 | 45 | transform_list += [transforms.ToTensor(), 46 | transforms.Normalize([0.49139968, 0.48215827, 0.44653124], 47 | [0.24703233, 0.24348505, 0.26158768])] 48 | 49 | if cutout is not None: 50 | transform_list += [Cutout(cutout)] 51 | 52 | return transforms.Compose(transform_list) 53 | 54 | 55 | def cifar_transform(is_training=True): 56 | # Data 57 | if is_training: 58 | transform_list = transforms.Compose([transforms.RandomHorizontalFlip(), 59 | transforms.Pad(4, padding_mode='reflect'), 60 | transforms.RandomCrop(32, padding=0), 61 | transforms.ToTensor(), 62 | transforms.Normalize((0.4914, 0.4822, 0.4465), 63 | (0.2023, 0.1994, 0.2010))]) 64 | 65 | else: 66 | transform_list = transforms.Compose([transforms.ToTensor(), 67 | transforms.Normalize((0.4914, 0.4822, 0.4465), 68 | (0.2023, 0.1994, 0.2010))]) 69 | 70 | return transform_list 71 | 72 | 73 | def imgnet_transform(is_training=True): 74 | if is_training: 75 | transform_list = transforms.Compose([transforms.RandomResizedCrop(224), 76 | transforms.RandomHorizontalFlip(), 77 | transforms.ColorJitter(brightness=0.4, 78 | contrast=0.4, 79 | saturation=0.4, 80 | hue=0.2), 81 | transforms.ToTensor(), 82 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 83 | std=[0.229, 0.224, 0.225])]) 84 | else: 85 | transform_list = transforms.Compose([transforms.Resize(256), 86 | transforms.CenterCrop(224), 87 | transforms.ToTensor(), 88 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 89 | std=[0.229, 0.224, 0.225])]) 90 | return transform_list 91 | -------------------------------------------------------------------------------- /utils/second_order_update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def update(model, updated_model, criterion, optimizer_w, 7 | inputs_val, targets_val, inputs_train, targets_train): 8 | # extract some hyper-parameters 9 | lr = optimizer_w.param_groups[0]['lr'] 10 | wd = optimizer_w.param_groups[0]['weight_decay'] 11 | momentum = optimizer_w.param_groups[0]['momentum'] 12 | 13 | # ------------------------------------------------------------------------------------- 14 | # -----------------------step1: get an updated model w.r.t train loss------------------ 15 | # ------------------------------------------------------------------------------------- 16 | 17 | # forward & calc loss 18 | loss = criterion(model(inputs_train), targets_train) # L_train(w) 19 | 20 | # compute gradient 21 | weights = [v for k, v in model.named_parameters() if 'alpha' not in k] 22 | new_weights = [v for k, v in updated_model.named_parameters() if 'alpha' not in k] 23 | gradients = torch.autograd.grad(loss, weights) 24 | 25 | # do virtual step (update gradient) 26 | # below operations do not need gradient tracking 27 | with torch.no_grad(): 28 | # optimizer.state is a dict, which uses model's parameters as keys 29 | # however, the dict key is not the value, but the pointer. 30 | # so original network weight have to be iterated also. 31 | for w, new_w, grad in zip(weights, new_weights, gradients): 32 | mom = optimizer_w.state[w].get('momentum_buffer', 0.) * momentum 33 | new_w.copy_(w - lr * (mom + grad + wd * w)) 34 | 35 | alphas = [v for k, v in model.named_parameters() if 'alpha' in k] 36 | new_alphas = [v for k, v in updated_model.named_parameters() if 'alpha' in k] 37 | # simply copy the value of alphas 38 | for a, new_a in zip(alphas, new_alphas): 39 | new_a.copy_(a) 40 | 41 | # ------------------------------------------------------------------------------------- 42 | # ------------------step2: get dL_val(w', a)/dw' and dL_val(w', a)/da------------------ 43 | # ------------------------------------------------------------------------------------- 44 | 45 | # calc val loss on updated model 46 | val_loss = criterion(updated_model(inputs_val), targets_val) # L_val(w', a) 47 | 48 | # compute gradient 49 | grad_new = torch.autograd.grad(val_loss, new_alphas + new_weights) 50 | grad_new_alphas = grad_new[:len(new_alphas)] 51 | grad_new_weights = grad_new[len(new_alphas):] 52 | 53 | # ------------------------------------------------------------------------------------- 54 | # ---------------------------step3: compute approximated hessian----------------------- 55 | # ------------------------------------------------------------------------------------- 56 | 57 | # dw = dw' { L_val(w', a) } 58 | # w+ = w + eps * dw 59 | # w- = w - eps * dw 60 | # hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps) 61 | 62 | # eps = 0.01 / ||dw|| 63 | norm = torch.cat([g.view(-1) for g in grad_new_weights]).norm() 64 | eps = 0.01 / norm 65 | 66 | # w+ = w + eps*dw' 67 | with torch.no_grad(): 68 | for w, grad in zip(weights, grad_new_weights): 69 | w += eps * grad 70 | 71 | loss = criterion(model(inputs_train), targets_train) # L_train(w+) 72 | grad_alphas_pos = torch.autograd.grad(loss, alphas) # dalpha { L_train(w+) } 73 | 74 | # w- = w - eps*dw' 75 | with torch.no_grad(): 76 | for w, grad in zip(weights, grad_new_weights): 77 | w -= 2. * eps * grad 78 | 79 | loss = criterion(model(inputs_train), targets_train) # L_train(w-) 80 | grad_alphas_neg = torch.autograd.grad(loss, alphas) # dalpha { L_train(w-) } 81 | 82 | # recover w 83 | with torch.no_grad(): 84 | for w, grad in zip(weights, grad_new_weights): 85 | w += eps * grad 86 | 87 | hessian = [(g_pos - g_neg) / 2. * eps for g_pos, g_neg 88 | in zip(grad_alphas_pos, grad_alphas_neg)] 89 | 90 | # ------------------------------------------------------------------------------------- 91 | # -----------------------------------step4: update alphas------------------------------ 92 | # ------------------------------------------------------------------------------------- 93 | 94 | # update final gradient = dalpha - xi*hessian 95 | with torch.no_grad(): 96 | for a, grad_a, h in zip(alphas, grad_new_alphas, hessian): 97 | a.grad = grad_a - lr * h 98 | 99 | return val_loss 100 | 101 | 102 | def alpha_entropy_grad(alphas): 103 | with torch.no_grad(): 104 | probs = F.softmax(alphas, dim=1) 105 | dw = -torch.bmm(probs[:, :, None], probs[:, None, :]) 106 | dw[:, torch.arange(alphas.shape[1]), torch.arange(alphas.shape[1])] += probs 107 | grad = (-dw * (torch.log(probs[:, :, None]) + 1)).sum(1) 108 | return grad 109 | 110 | 111 | def alpha_entropy_loss(alphas, axis=-1): 112 | probs = F.softmax(alphas, dim=axis) 113 | entropy = ((-probs * probs.log()).sum(axis)).mean() 114 | return entropy 115 | 116 | # if __name__ == '__main__': 117 | # alpha_entropy_grad(torch.randn(14,8)) 118 | 119 | # if __name__ == '__main__': 120 | # from copy import deepcopy 121 | # from nets.cifar_search_model import * 122 | # 123 | # model = Network(6, 8, 4) 124 | # temp_model = Network(6, 8, 4) 125 | # 126 | # alphas=[v for k,v in model.named_parameters() if 'alpha' in k] 127 | # weights = [v for k, v in model.named_parameters() if 'alpha' not in k] 128 | # optimizer = torch.optim.SGD(weights, lr=0.01, momentum=0.9, weight_decay=5e-4) 129 | # alpha_optim = torch.optim.Adam(alphas, 1e-2, betas=(0.5, 0.999), weight_decay=1e-4) 130 | # criterion = nn.CrossEntropyLoss() 131 | # 132 | # loss = criterion(model(torch.randn(10, 3, 32, 32)), torch.randint(0, 10, [10]).long()) 133 | # loss.backward() 134 | # optimizer.step() 135 | # 136 | # alpha_optim.zero_grad() 137 | # before = deepcopy(model) 138 | # update(model, temp_model, criterion, optimizer, 139 | # torch.randn(10, 3, 32, 32), torch.randint(0, 10, [10]).long(), 140 | # torch.randn(10, 3, 32, 32), torch.randint(0, 10, [10]).long()) 141 | # alpha_optim.step() 142 | # after = model 143 | -------------------------------------------------------------------------------- /nets/operations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from genotypes import * 4 | 5 | 6 | class MixedLayer(nn.Module): 7 | def __init__(self, c, stride, op_names): 8 | super(MixedLayer, self).__init__() 9 | self.op_names = op_names 10 | self.layers = nn.ModuleList() 11 | """ 12 | PRIMITIVES = [ 13 | 'none', 14 | 'max_pool_3x3', 15 | 'avg_pool_3x3', 16 | 'skip_connect', 17 | 'sep_conv_3x3', 18 | 'sep_conv_5x5', 19 | 'dil_conv_3x3', 20 | 'dil_conv_5x5' 21 | ] 22 | """ 23 | for primitive in op_names: 24 | layer = OPS[primitive](c, stride, False) 25 | if 'pool' in primitive: 26 | layer = nn.Sequential(layer, nn.BatchNorm2d(c, affine=False)) 27 | 28 | self.layers.append(layer) 29 | 30 | def forward(self, x, weights): 31 | return sum([w * layer(x) for w, layer in zip(weights, self.layers)]) 32 | 33 | 34 | # OPS is a set of layers with same input/output channel. 35 | 36 | OPS = {'none': lambda C, stride, affine: Zero(stride), 37 | 'avg_pool_3x3': lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, 38 | padding=1, count_include_pad=False), 39 | 'max_pool_3x3': lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1), 40 | 'skip_connect': lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine), 41 | 'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine), 42 | 'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine), 43 | 'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine), 44 | 'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine), 45 | 'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine), 46 | 47 | 'conv_7x1_1x7': lambda C, stride, affine: nn.Sequential( 48 | nn.ReLU(inplace=False), 49 | nn.Conv2d(C, C, (1, 7), stride=(1, stride), padding=(0, 3), bias=False), 50 | nn.Conv2d(C, C, (7, 1), stride=(stride, 1), padding=(3, 0), bias=False), 51 | nn.BatchNorm2d(C, affine=affine))} 52 | 53 | 54 | class ReLUConvBN(nn.Module): 55 | """ 56 | Stack of relu-conv-bn 57 | """ 58 | 59 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 60 | """ 61 | 62 | :param C_in: 63 | :param C_out: 64 | :param kernel_size: 65 | :param stride: 66 | :param padding: 67 | :param affine: 68 | """ 69 | super(ReLUConvBN, self).__init__() 70 | 71 | self.op = nn.Sequential( 72 | nn.ReLU(inplace=False), 73 | nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False), 74 | nn.BatchNorm2d(C_out, affine=affine)) 75 | 76 | def forward(self, x): 77 | return self.op(x) 78 | 79 | 80 | class DilConv(nn.Module): 81 | """ 82 | relu-dilated conv-bn 83 | """ 84 | 85 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): 86 | """ 87 | 88 | :param C_in: 89 | :param C_out: 90 | :param kernel_size: 91 | :param stride: 92 | :param padding: 2/4 93 | :param dilation: 2 94 | :param affine: 95 | """ 96 | super(DilConv, self).__init__() 97 | 98 | self.op = nn.Sequential( 99 | nn.ReLU(inplace=False), 100 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, 101 | dilation=dilation, groups=C_in, bias=False), 102 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), 103 | nn.BatchNorm2d(C_out, affine=affine)) 104 | 105 | def forward(self, x): 106 | return self.op(x) 107 | 108 | 109 | class SepConv(nn.Module): 110 | """ 111 | implemented separate convolution via pytorch groups parameters 112 | """ 113 | 114 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 115 | """ 116 | 117 | :param C_in: 118 | :param C_out: 119 | :param kernel_size: 120 | :param stride: 121 | :param padding: 1/2 122 | :param affine: 123 | """ 124 | super(SepConv, self).__init__() 125 | 126 | self.op = nn.Sequential( 127 | nn.ReLU(inplace=False), 128 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, 129 | groups=C_in, bias=False), 130 | nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False), 131 | nn.BatchNorm2d(C_in, affine=affine), 132 | nn.ReLU(inplace=False), 133 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, 134 | groups=C_in, bias=False), 135 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), 136 | nn.BatchNorm2d(C_out, affine=affine)) 137 | 138 | def forward(self, x): 139 | return self.op(x) 140 | 141 | 142 | class Identity(nn.Module): 143 | 144 | def __init__(self): 145 | super(Identity, self).__init__() 146 | 147 | def forward(self, x): 148 | return x 149 | 150 | 151 | class Zero(nn.Module): 152 | """ 153 | zero by stride 154 | """ 155 | 156 | def __init__(self, stride): 157 | super(Zero, self).__init__() 158 | 159 | self.stride = stride 160 | 161 | def forward(self, x): 162 | if self.stride == 1: 163 | return x.mul(0.) 164 | return x[:, :, ::self.stride, ::self.stride].mul(0.) 165 | 166 | 167 | class FactorizedReduce(nn.Module): 168 | """ 169 | reduce feature maps height/width by half while keeping channel same 170 | """ 171 | 172 | def __init__(self, C_in, C_out, affine=True): 173 | """ 174 | 175 | :param C_in: 176 | :param C_out: 177 | :param affine: 178 | """ 179 | super(FactorizedReduce, self).__init__() 180 | 181 | assert C_out % 2 == 0 182 | 183 | self.relu = nn.ReLU(inplace=False) 184 | self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) 185 | self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) 186 | self.bn = nn.BatchNorm2d(C_out, affine=affine) 187 | 188 | def forward(self, x): 189 | x = self.relu(x) 190 | 191 | # x: torch.Size([32, 32, 32, 32]) 192 | # conv1: [b, c_out//2, d//2, d//2] 193 | # conv2: [] 194 | # out: torch.Size([32, 32, 16, 16]) 195 | 196 | out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1) 197 | out = self.bn(out) 198 | return out 199 | -------------------------------------------------------------------------------- /cifar_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.distributed as dist 8 | 9 | from nets.eval_model import NetworkCIFAR 10 | 11 | from utils.utils import count_parameters, count_flops, DisablePrint 12 | from utils.dataset import CIFAR_split 13 | from utils.preprocessing import cifar_search_transform 14 | from utils.summary import create_summary, create_logger 15 | 16 | torch.backends.cudnn.benchmark = True 17 | 18 | # Training settings 19 | parser = argparse.ArgumentParser(description='darts') 20 | 21 | parser.add_argument('--local_rank', type=int, default=0) 22 | parser.add_argument('--dist', action='store_true') 23 | 24 | parser.add_argument('--root_dir', type=str, default='./') 25 | parser.add_argument('--data_dir', type=str, default='./data') 26 | parser.add_argument('--log_name', type=str, default='test') 27 | 28 | parser.add_argument('--lr', type=float, default=0.025) 29 | parser.add_argument('--wd', type=float, default=3e-4) 30 | 31 | parser.add_argument('--init_ch', type=int, default=36) 32 | parser.add_argument('--num_cells', type=int, default=20) 33 | 34 | parser.add_argument('--auxiliary', type=float, default=0.4) 35 | parser.add_argument('--cutout', type=int, default=16) 36 | parser.add_argument('--drop_path_prob', type=float, default=0.2) 37 | 38 | parser.add_argument('--batch_size', type=int, default=96) 39 | parser.add_argument('--max_epochs', type=int, default=600) 40 | 41 | parser.add_argument('--log_interval', type=int, default=10) 42 | parser.add_argument('--gpus', type=str, default='0') 43 | parser.add_argument('--num_workers', type=int, default=3) 44 | 45 | cfg = parser.parse_args() 46 | 47 | os.chdir(cfg.root_dir) 48 | 49 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152 50 | os.environ["CUDA_VISIBLE_DEVICES"] = cfg.gpus 51 | 52 | cfg.log_dir = os.path.join(cfg.root_dir, 'logs', cfg.log_name + '_eval') 53 | cfg.ckpt_dir = os.path.join(cfg.root_dir, 'ckpt', cfg.log_name) 54 | 55 | os.makedirs(cfg.log_dir, exist_ok=True) 56 | os.makedirs(cfg.ckpt_dir, exist_ok=True) 57 | 58 | 59 | def main(): 60 | logger = create_logger(cfg.local_rank, save_dir=cfg.log_dir) 61 | summary_writer = create_summary(cfg.local_rank, log_dir=cfg.log_dir) 62 | print = logger.info 63 | 64 | print(cfg) 65 | num_gpus = torch.cuda.device_count() 66 | if cfg.dist: 67 | device = torch.device('cuda:%d' % cfg.local_rank) if cfg.dist else torch.device('cuda') 68 | torch.cuda.set_device(cfg.local_rank) 69 | dist.init_process_group(backend='nccl', init_method='env://', 70 | world_size=num_gpus, rank=cfg.local_rank) 71 | else: 72 | device = torch.device('cuda') 73 | 74 | print('==> Preparing data..') 75 | cifar = 100 if 'cifar100' in cfg.log_name else 10 76 | train_dataset = CIFAR_split(cifar=cifar, root=cfg.data_dir, split='train', ratio=1.0, 77 | transform=cifar_search_transform(is_training=True, cutout=cfg.cutout)) 78 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, 79 | num_replicas=num_gpus, 80 | rank=cfg.local_rank) 81 | train_loader = torch.utils.data.DataLoader(train_dataset, 82 | batch_size=cfg.batch_size // num_gpus if cfg.dist 83 | else cfg.batch_size, 84 | shuffle=not cfg.dist, 85 | num_workers=cfg.num_workers, 86 | sampler=train_sampler if cfg.dist else None) 87 | 88 | test_dataset = CIFAR_split(cifar=cifar, root=cfg.data_dir, split='test', 89 | transform=cifar_search_transform(is_training=False)) 90 | test_loader = torch.utils.data.DataLoader(test_dataset, 91 | batch_size=cfg.batch_size, 92 | shuffle=False, 93 | num_workers=cfg.num_workers) 94 | 95 | print('==> Building model..') 96 | genotype = torch.load(os.path.join(cfg.ckpt_dir, 'genotype.pickle'))['genotype'] 97 | model = NetworkCIFAR(genotype, cfg.init_ch, cfg.num_cells, cfg.auxiliary, num_classes=cifar) 98 | 99 | if not cfg.dist: 100 | model = nn.DataParallel(model).to(device) 101 | else: 102 | # model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 103 | model = model.to(device) 104 | model = nn.parallel.DistributedDataParallel(model, 105 | device_ids=[cfg.local_rank, ], 106 | output_device=cfg.local_rank) 107 | 108 | optimizer = torch.optim.SGD(model.parameters(), cfg.lr, momentum=0.9, weight_decay=cfg.wd) 109 | criterion = nn.CrossEntropyLoss().to(device) 110 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cfg.max_epochs) 111 | 112 | # Training 113 | def train(epoch): 114 | model.train() 115 | 116 | start_time = time.time() 117 | for batch_idx, (inputs, targets) in enumerate(train_loader): 118 | inputs, targets = inputs.to(device), targets.to(device, non_blocking=True) 119 | 120 | outputs, outputs_aux = model(inputs) 121 | loss = criterion(outputs, targets) 122 | loss_aux = criterion(outputs_aux, targets) 123 | loss += cfg.auxiliary * loss_aux 124 | 125 | optimizer.zero_grad() 126 | loss.backward() 127 | nn.utils.clip_grad_norm_(model.parameters(), 5.0) 128 | optimizer.step() 129 | 130 | if batch_idx % cfg.log_interval == 0: 131 | step = len(train_loader) * epoch + batch_idx 132 | duration = time.time() - start_time 133 | 134 | print('[%d/%d - %d/%d] cls_loss= %.5f (%d samples/sec)' % 135 | (epoch, cfg.max_epochs, batch_idx, len(train_loader), 136 | loss.item(), cfg.batch_size * cfg.log_interval / duration)) 137 | 138 | start_time = time.time() 139 | summary_writer.add_scalar('cls_loss', loss.item(), step) 140 | summary_writer.add_scalar('learning rate', optimizer.param_groups[0]['lr'], step) 141 | 142 | def test(epoch): 143 | model.eval() 144 | correct = 0 145 | with torch.no_grad(): 146 | for batch_idx, (inputs, targets) in enumerate(test_loader): 147 | inputs, targets = inputs.to(device), targets.to(device, non_blocking=True) 148 | 149 | outputs, _ = model(inputs) 150 | _, predicted = torch.max(outputs.data, 1) 151 | correct += predicted.eq(targets.data).cpu().sum().item() 152 | 153 | acc = 100. * correct / len(test_loader.dataset) 154 | print(' Precision@1 ==> %.2f%% \n' % acc) 155 | summary_writer.add_scalar('Precision@1', acc, global_step=epoch) 156 | return 157 | 158 | for epoch in range(cfg.max_epochs): 159 | print('\nEpoch: %d lr: %.5f drop_path_prob: %.3f' % 160 | (epoch, scheduler.get_lr()[0], cfg.drop_path_prob * epoch / cfg.max_epochs)) 161 | model._modules['module'].drop_path_prob = cfg.drop_path_prob * epoch / cfg.max_epochs 162 | train_sampler.set_epoch(epoch) 163 | train(epoch) 164 | test(epoch) 165 | scheduler.step(epoch) # move to here after pytorch1.1.0 166 | print(model.module.genotype()) 167 | if cfg.local_rank == 0: 168 | torch.save(model.state_dict(), os.path.join(cfg.ckpt_dir, 'checkpoint.t7')) 169 | 170 | summary_writer.close() 171 | count_parameters(model) 172 | count_flops(model, input_size=32) 173 | 174 | 175 | if __name__ == '__main__': 176 | if cfg.local_rank == 0: 177 | main() 178 | else: 179 | with DisablePrint(): 180 | main() 181 | -------------------------------------------------------------------------------- /imgnet_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.distributed as dist 8 | 9 | from nets.eval_model import NetworkImageNet 10 | 11 | from utils.utils import count_parameters, count_flops, DisablePrint 12 | from utils.losses import CrossEntropyLabelSmooth 13 | from utils.dataset import ImgNet_split 14 | from utils.preprocessing import imgnet_transform 15 | from utils.summary import create_summary, create_logger 16 | 17 | torch.backends.cudnn.benchmark = True 18 | 19 | # Training settings 20 | parser = argparse.ArgumentParser(description='darts') 21 | 22 | parser.add_argument('--local_rank', type=int, default=0) 23 | parser.add_argument('--dist', action='store_true') 24 | 25 | parser.add_argument('--root_dir', type=str, default='./') 26 | parser.add_argument('--data_dir', type=str, default='./data') 27 | parser.add_argument('--log_name', type=str, default='test') 28 | 29 | parser.add_argument('--lr', type=float, default=0.05) 30 | parser.add_argument('--lr_steps', type=str, default='30,60,90') 31 | parser.add_argument('--wd', type=float, default=3e-5) 32 | 33 | parser.add_argument('--init_ch', type=int, default=48) 34 | parser.add_argument('--num_cells', type=int, default=16) 35 | 36 | parser.add_argument('--auxiliary', type=float, default=0.4) 37 | parser.add_argument('--drop_path_prob', type=float, default=0) 38 | parser.add_argument('--label_smooth', type=float, default=0.1) 39 | 40 | parser.add_argument('--batch_size', type=int, default=1024) 41 | parser.add_argument('--max_epochs', type=int, default=250) 42 | 43 | parser.add_argument('--log_interval', type=int, default=100) 44 | parser.add_argument('--gpus', type=str, default='0') 45 | parser.add_argument('--num_workers', type=int, default=7) 46 | 47 | cfg = parser.parse_args() 48 | 49 | cfg.lr_steps = [int(s) for s in cfg.lr_steps.split(',')] 50 | 51 | os.chdir(cfg.root_dir) 52 | 53 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152 54 | os.environ["CUDA_VISIBLE_DEVICES"] = cfg.gpus 55 | 56 | cfg.log_dir = os.path.join(cfg.root_dir, 'logs', cfg.log_name + '_imgnet_eval') 57 | cfg.ckpt_dir = os.path.join(cfg.root_dir, 'ckpt', cfg.log_name) 58 | 59 | os.makedirs(cfg.log_dir, exist_ok=True) 60 | os.makedirs(cfg.ckpt_dir, exist_ok=True) 61 | 62 | 63 | def main(): 64 | logger = create_logger(cfg.local_rank, save_dir=cfg.log_dir) 65 | summary_writer = create_summary(cfg.local_rank, log_dir=cfg.log_dir) 66 | print = logger.info 67 | 68 | print(cfg) 69 | num_gpus = torch.cuda.device_count() 70 | if cfg.dist: 71 | device = torch.device('cuda:%d' % cfg.local_rank) if cfg.dist else torch.device('cuda') 72 | torch.cuda.set_device(cfg.local_rank) 73 | dist.init_process_group(backend='nccl', init_method='env://', 74 | world_size=num_gpus, rank=cfg.local_rank) 75 | else: 76 | device = torch.device('cuda') 77 | 78 | print('==> Preparing data..') 79 | train_dataset = ImgNet_split(root=os.path.join(cfg.data_dir, 'train'), 80 | transform=imgnet_transform(is_training=True)) 81 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, 82 | num_replicas=num_gpus, 83 | rank=cfg.local_rank) 84 | train_loader = torch.utils.data.DataLoader(train_dataset, 85 | batch_size=cfg.batch_size // num_gpus if cfg.dist 86 | else cfg.batch_size, 87 | shuffle=not cfg.dist, 88 | num_workers=cfg.num_workers, 89 | sampler=train_sampler if cfg.dist else None) 90 | 91 | val_dataset = ImgNet_split(root=os.path.join(cfg.data_dir, 'val'), 92 | transform=imgnet_transform(is_training=False)) 93 | val_loader = torch.utils.data.DataLoader(val_dataset, 94 | batch_size=cfg.batch_size, 95 | shuffle=False, 96 | num_workers=cfg.num_workers) 97 | 98 | print('==> Building model..') 99 | genotype = torch.load(os.path.join(cfg.ckpt_dir, 'genotype.pickle'))['genotype'] 100 | model = NetworkImageNet(genotype, cfg.init_ch, cfg.num_cells, cfg.auxiliary, num_classes=1000) 101 | 102 | if not cfg.dist: 103 | model = nn.DataParallel(model).to(device) 104 | else: 105 | # model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 106 | model = model.to(device) 107 | model = nn.parallel.DistributedDataParallel(model, 108 | device_ids=[cfg.local_rank, ], 109 | output_device=cfg.local_rank) 110 | 111 | optimizer = torch.optim.SGD(model.parameters(), cfg.lr, momentum=0.9, weight_decay=cfg.wd) 112 | criterion = CrossEntropyLabelSmooth(num_classes=1000, epsilon=cfg.label_smooth).to(device) 113 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.97) 114 | warmup = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=2) 115 | 116 | # Training 117 | def train(epoch): 118 | model.train() 119 | 120 | start_time = time.time() 121 | for batch_idx, (inputs, targets) in enumerate(train_loader): 122 | inputs, targets = inputs.to(device), targets.to(device, non_blocking=True) 123 | 124 | outputs, outputs_aux = model(inputs) 125 | loss = criterion(outputs, targets) 126 | loss_aux = criterion(outputs_aux, targets) 127 | loss += cfg.auxiliary * loss_aux 128 | 129 | optimizer.zero_grad() 130 | loss.backward() 131 | nn.utils.clip_grad_norm_(model.parameters(), 5.0) 132 | optimizer.step() 133 | 134 | if batch_idx % cfg.log_interval == 0: 135 | step = len(train_loader) * epoch + batch_idx 136 | duration = time.time() - start_time 137 | 138 | print('[%d/%d - %d/%d] cls_loss= %.5f (%d samples/sec)' % 139 | (epoch, cfg.max_epochs, batch_idx, len(train_loader), 140 | loss.item(), cfg.batch_size * cfg.log_interval / duration)) 141 | 142 | start_time = time.time() 143 | summary_writer.add_scalar('cls_loss', loss.item(), step) 144 | summary_writer.add_scalar('learning rate', optimizer.param_groups[0]['lr'], step) 145 | 146 | def val(epoch): 147 | # switch to evaluate mode 148 | model.eval() 149 | top1 = 0 150 | top5 = 0 151 | with torch.no_grad(): 152 | for i, (inputs, targets) in enumerate(val_loader): 153 | inputs, targets = inputs.to(device), targets.to(device, non_blocking=True) 154 | 155 | output, _ = model(inputs) 156 | 157 | # measure accuracy and record loss 158 | _, pred = output.data.topk(5, dim=1, largest=True, sorted=True) 159 | pred = pred.t() 160 | correct = pred.eq(targets.view(1, -1).expand_as(pred)) 161 | 162 | top1 += correct[:1].view(-1).float().sum(0, keepdim=True).item() 163 | top5 += correct[:5].view(-1).float().sum(0, keepdim=True).item() 164 | 165 | top1 *= 100 / len(val_dataset) 166 | top5 *= 100 / len(val_dataset) 167 | print(' Precision@1 ==> %.2f%% Precision@1: %.2f%%\n' % (top1, top5)) 168 | summary_writer.add_scalar('Precision@1', top1, epoch) 169 | summary_writer.add_scalar('Precision@5', top5, epoch) 170 | return 171 | 172 | for epoch in range(cfg.max_epochs): 173 | print('\nEpoch: %d lr: %.5f drop_path_prob: %.3f' % 174 | (epoch, scheduler.get_lr()[0], cfg.drop_path_prob * epoch / cfg.max_epochs)) 175 | model.module.drop_path_prob = cfg.drop_path_prob * epoch / cfg.max_epochs 176 | train_sampler.set_epoch(epoch) 177 | train(epoch) 178 | val(epoch) 179 | if epoch < 5: 180 | warmup.step(epoch) 181 | else: 182 | scheduler.step(epoch) # move to here after pytorch1.1.0 183 | print(model.module.genotype()) 184 | if cfg.local_rank == 0: 185 | torch.save(model.state_dict(), os.path.join(cfg.ckpt_dir, 'checkpoint.t7')) 186 | 187 | summary_writer.close() 188 | count_parameters(model) 189 | count_flops(model, input_size=224) 190 | 191 | 192 | if __name__ == '__main__': 193 | if cfg.local_rank == 0: 194 | main() 195 | else: 196 | with DisablePrint(): 197 | main() 198 | -------------------------------------------------------------------------------- /nets/eval_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from nets.operations import * 4 | 5 | 6 | def drop_path(x, drop_prob): 7 | if drop_prob > 0.: 8 | keep_prob = 1. - drop_prob 9 | mask = torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob) 10 | x.div_(keep_prob) 11 | x.mul_(mask) 12 | return x 13 | 14 | 15 | class Cell(nn.Module): 16 | 17 | def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev): 18 | """ 19 | :param genotype: 20 | :param C_prev_prev: 21 | :param C_prev: 22 | :param C: 23 | :param reduction: 24 | :param reduction_prev: 25 | """ 26 | super(Cell, self).__init__() 27 | 28 | print(C_prev_prev, C_prev, C) 29 | 30 | if reduction_prev: 31 | self.preprocess0 = FactorizedReduce(C_prev_prev, C) 32 | else: 33 | self.preprocess0 = ReLUConvBN(C_prev_prev, C, kernel_size=1, stride=1, padding=0) 34 | self.preprocess1 = ReLUConvBN(C_prev, C, kernel_size=1, stride=1, padding=0) 35 | 36 | if reduction: 37 | op_names, indices = zip(*genotype.reduce) 38 | concat = genotype.reduce_concat 39 | else: 40 | op_names, indices = zip(*genotype.normal) 41 | concat = genotype.normal_concat 42 | 43 | assert len(op_names) == len(indices) 44 | 45 | self._num_nodes = len(op_names) // 2 46 | self._concat = concat 47 | self.multiplier = len(concat) 48 | 49 | self._ops = nn.ModuleList() 50 | for name, index in zip(op_names, indices): 51 | stride = 2 if reduction and index < 2 else 1 52 | op = OPS[name](C, stride, affine=True) 53 | self._ops += [op] 54 | self._indices = indices 55 | 56 | def forward(self, s0, s1, drop_prob): 57 | """ 58 | 59 | :param s0: 60 | :param s1: 61 | :param drop_prob: 62 | :return: 63 | """ 64 | s0 = self.preprocess0(s0) 65 | s1 = self.preprocess1(s1) 66 | 67 | states = [s0, s1] 68 | for i in range(self._num_nodes): 69 | h1 = states[self._indices[2 * i]] 70 | h2 = states[self._indices[2 * i + 1]] 71 | op1 = self._ops[2 * i] 72 | op2 = self._ops[2 * i + 1] 73 | h1 = op1(h1) 74 | h2 = op2(h2) 75 | 76 | if self.training and drop_prob > 0.: 77 | if not isinstance(op1, Identity): 78 | h1 = drop_path(h1, drop_prob) 79 | if not isinstance(op2, Identity): 80 | h2 = drop_path(h2, drop_prob) 81 | 82 | s = (h1 + h2) / 2 83 | states += [s] 84 | return torch.cat([states[i] for i in self._concat], dim=1) 85 | 86 | 87 | class AuxiliaryHeadCIFAR(nn.Module): 88 | 89 | def __init__(self, C, num_classes): 90 | """assuming input size 8x8""" 91 | super(AuxiliaryHeadCIFAR, self).__init__() 92 | 93 | self.features = nn.Sequential( 94 | nn.ReLU(inplace=True), 95 | nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2 96 | nn.Conv2d(C, 128, kernel_size=1, bias=False), 97 | nn.BatchNorm2d(128), 98 | nn.ReLU(inplace=True), 99 | nn.Conv2d(128, 768, kernel_size=2, bias=False), 100 | nn.BatchNorm2d(768), 101 | nn.ReLU(inplace=True) 102 | ) 103 | self.classifier = nn.Linear(768, num_classes) 104 | 105 | def forward(self, x): 106 | x = self.features(x) 107 | x = self.classifier(x.view(x.size(0), -1)) 108 | return x 109 | 110 | 111 | class AuxiliaryHeadImageNet(nn.Module): 112 | 113 | def __init__(self, C, num_classes): 114 | """assuming input size 14x14""" 115 | super(AuxiliaryHeadImageNet, self).__init__() 116 | self.features = nn.Sequential( 117 | nn.ReLU(inplace=True), 118 | nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False), 119 | nn.Conv2d(C, 128, kernel_size=1, bias=False), 120 | nn.BatchNorm2d(128), 121 | nn.ReLU(inplace=True), 122 | nn.Conv2d(128, 768, kernel_size=2, bias=False), 123 | # NOTE: This batchnorm was omitted in my earlier implementation due to a typo. 124 | # Commenting it out for consistency with the experiments in the paper. 125 | # nn.BatchNorm2d(768), 126 | nn.ReLU(inplace=True) 127 | ) 128 | self.classifier = nn.Linear(768, num_classes) 129 | 130 | def forward(self, x): 131 | x = self.features(x) 132 | x = self.classifier(x.view(x.size(0), -1)) 133 | return x 134 | 135 | 136 | class NetworkCIFAR(nn.Module): 137 | 138 | def __init__(self, genotype, C, layers, auxiliary, num_classes): 139 | super(NetworkCIFAR, self).__init__() 140 | self.drop_path_prob = 0.0 141 | self._layers = layers 142 | self._auxiliary = auxiliary 143 | 144 | stem_multiplier = 3 145 | C_curr = stem_multiplier * C 146 | self.stem = nn.Sequential(nn.Conv2d(3, C_curr, kernel_size=3, padding=1, bias=False), 147 | nn.BatchNorm2d(C_curr)) 148 | 149 | C_prev_prev, C_prev, C_curr = C_curr, C_curr, C 150 | self.cells = nn.ModuleList() 151 | reduction_prev = False 152 | for i in range(layers): 153 | if i in [layers // 3, 2 * layers // 3]: 154 | C_curr *= 2 155 | reduction = True 156 | else: 157 | reduction = False 158 | cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev) 159 | reduction_prev = reduction 160 | self.cells += [cell] 161 | C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr 162 | if i == 2 * layers // 3: 163 | C_to_auxiliary = C_prev 164 | 165 | if auxiliary: 166 | self.auxiliary_head = AuxiliaryHeadCIFAR(C_to_auxiliary, num_classes) 167 | self.global_pooling = nn.AdaptiveAvgPool2d(1) 168 | self.classifier = nn.Linear(C_prev, num_classes) 169 | 170 | def forward(self, input): 171 | logits_aux = None 172 | s0 = s1 = self.stem(input) 173 | for i, cell in enumerate(self.cells): 174 | s0, s1 = s1, cell(s0, s1, self.drop_path_prob) 175 | if i == 2 * self._layers // 3: 176 | if self._auxiliary and self.training: 177 | logits_aux = self.auxiliary_head(s1) 178 | out = self.global_pooling(s1) 179 | logits = self.classifier(out.view(out.size(0), -1)) 180 | return logits, logits_aux 181 | 182 | 183 | class NetworkImageNet(nn.Module): 184 | 185 | def __init__(self, genotype, C, layers, auxiliary, num_classes): 186 | super(NetworkImageNet, self).__init__() 187 | self.drop_path_prob = 0.0 188 | self._layers = layers 189 | self._auxiliary = auxiliary 190 | 191 | self.stem0 = nn.Sequential( 192 | nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False), 193 | nn.BatchNorm2d(C // 2), 194 | nn.ReLU(inplace=True), 195 | nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False), 196 | nn.BatchNorm2d(C)) 197 | 198 | self.stem1 = nn.Sequential( 199 | nn.ReLU(inplace=True), 200 | nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False), 201 | nn.BatchNorm2d(C)) 202 | 203 | C_prev_prev, C_prev, C_curr = C, C, C 204 | 205 | self.cells = nn.ModuleList() 206 | reduction_prev = True 207 | for i in range(layers): 208 | if i in [layers // 3, 2 * layers // 3]: 209 | C_curr *= 2 210 | reduction = True 211 | else: 212 | reduction = False 213 | cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev) 214 | reduction_prev = reduction 215 | self.cells += [cell] 216 | C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr 217 | if i == 2 * layers // 3: 218 | C_to_auxiliary = C_prev 219 | 220 | if auxiliary: 221 | self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes) 222 | self.global_pooling = nn.AvgPool2d(7) 223 | self.classifier = nn.Linear(C_prev, num_classes) 224 | 225 | def forward(self, input): 226 | logits_aux = None 227 | s0 = self.stem0(input) 228 | s1 = self.stem1(s0) 229 | for i, cell in enumerate(self.cells): 230 | s0, s1 = s1, cell(s0, s1, self.drop_path_prob) 231 | if i == 2 * self._layers // 3: 232 | if self._auxiliary and self.training: 233 | logits_aux = self.auxiliary_head(s1) 234 | out = self.global_pooling(s1) 235 | logits = self.classifier(out.view(out.size(0), -1)) 236 | return logits, logits_aux 237 | 238 | 239 | # if __name__ == '__main__': 240 | # import os 241 | # import pickle 242 | # from genotypes import * 243 | # from utils.utils import count_flops, count_parameters 244 | # 245 | # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152 246 | # os.environ["CUDA_VISIBLE_DEVICES"] = '0' 247 | # 248 | # 249 | # def hook(self, input, output): 250 | # print(output.data.cpu().numpy().shape) 251 | # pass 252 | # 253 | # 254 | # genotype = Genotype(normal=[('dil_conv_5x5', 0), ('skip_connect', 1), 255 | # ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), 256 | # ('sep_conv_3x3', 0), ('sep_conv_5x5', 2), 257 | # ('sep_conv_3x3', 0), ('dil_conv_5x5', 3)], 258 | # normal_concat=range(2, 6), 259 | # reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), 260 | # ('sep_conv_3x3', 0), ('avg_pool_3x3', 1), 261 | # ('dil_conv_3x3', 3), ('sep_conv_3x3', 0), 262 | # ('avg_pool_3x3', 1), ('max_pool_3x3', 0)], 263 | # reduce_concat=range(2, 6)) 264 | # 265 | # net = NetworkCIFAR(genotype=genotype, C=36, layers=20, auxiliary=0.4, num_classes=10) 266 | # 267 | # for m in net.modules(): 268 | # if isinstance(m, nn.Conv2d): 269 | # m.register_forward_hook(hook) 270 | # 271 | # y = net(torch.randn(2, 3, 32, 32)) 272 | # print(y[0].size()) 273 | # 274 | # count_parameters(net) 275 | # count_flops(net, input_size=32) 276 | -------------------------------------------------------------------------------- /cifar_search.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import torch.distributed as dist 9 | 10 | from nets.search_model import Network 11 | 12 | from utils.utils import count_parameters, DisablePrint 13 | from utils.dataset import CIFAR_split 14 | from utils.preprocessing import cifar_search_transform 15 | from utils.second_order_update import * 16 | from utils.summary import create_summary, create_logger 17 | 18 | torch.backends.cudnn.benchmark = True 19 | 20 | # Training settings 21 | parser = argparse.ArgumentParser(description='darts') 22 | 23 | parser.add_argument('--local_rank', type=int, default=0) 24 | parser.add_argument('--dist', action='store_true') 25 | 26 | parser.add_argument('--root_dir', type=str, default='./') 27 | parser.add_argument('--data_dir', type=str, default='./data') 28 | parser.add_argument('--log_name', type=str, default='test') 29 | 30 | parser.add_argument('--order', type=str, default='1st', choices=['1st', '2nd']) 31 | 32 | parser.add_argument('--w_lr', type=float, default=0.025) 33 | parser.add_argument('--w_min_lr', type=float, default=0.001) 34 | parser.add_argument('--w_wd', type=float, default=3e-4) 35 | 36 | parser.add_argument('--a_lr', type=float, default=3e-4) 37 | parser.add_argument('--a_wd', type=float, default=1e-3) 38 | parser.add_argument('--a_start', type=int, default=-1) 39 | 40 | parser.add_argument('--init_ch', type=int, default=16) 41 | parser.add_argument('--num_cells', type=int, default=8) 42 | parser.add_argument('--num_nodes', type=int, default=4) 43 | parser.add_argument('--replica', type=int, default=1) 44 | 45 | parser.add_argument('--batch_size', type=int, default=64) 46 | parser.add_argument('--max_epochs', type=int, default=50) 47 | 48 | parser.add_argument('--log_interval', type=int, default=10) 49 | parser.add_argument('--gpus', type=str, default='0') 50 | parser.add_argument('--num_workers', type=int, default=2) 51 | 52 | cfg = parser.parse_args() 53 | 54 | os.chdir(cfg.root_dir) 55 | 56 | cfg.log_dir = os.path.join(cfg.root_dir, 'logs', cfg.log_name) 57 | cfg.ckpt_dir = os.path.join(cfg.root_dir, 'ckpt', cfg.log_name) 58 | 59 | os.makedirs(cfg.log_dir, exist_ok=True) 60 | os.makedirs(cfg.ckpt_dir, exist_ok=True) 61 | 62 | 63 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152 64 | os.environ["CUDA_VISIBLE_DEVICES"] = cfg.gpus 65 | 66 | 67 | def main(): 68 | logger = create_logger(cfg.local_rank, save_dir=cfg.log_dir) 69 | summary_writer = create_summary(cfg.local_rank, log_dir=cfg.log_dir) 70 | print = logger.info 71 | 72 | print(cfg) 73 | num_gpus = torch.cuda.device_count() 74 | if cfg.dist: 75 | device = torch.device('cuda:%d' % cfg.local_rank) if cfg.dist else torch.device('cuda') 76 | torch.cuda.set_device(cfg.local_rank) 77 | dist.init_process_group(backend='nccl', init_method='env://', 78 | world_size=num_gpus, rank=cfg.local_rank) 79 | else: 80 | device = torch.device('cuda') 81 | 82 | print('==> Preparing data..') 83 | cifar = 100 if 'cifar100' in cfg.log_name else 10 84 | 85 | train_dataset = CIFAR_split(cifar=cifar, root=cfg.data_dir, split='train', ratio=0.5, 86 | transform=cifar_search_transform(is_training=True)) 87 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, 88 | num_replicas=num_gpus, 89 | rank=cfg.local_rank) 90 | train_loader = torch.utils.data.DataLoader(train_dataset, 91 | batch_size=cfg.batch_size // num_gpus if cfg.dist 92 | else cfg.batch_size, 93 | shuffle=not cfg.dist, 94 | num_workers=cfg.num_workers, 95 | sampler=train_sampler if cfg.dist else None) 96 | 97 | val_dataset = CIFAR_split(cifar=cifar, root=cfg.data_dir, split='val', ratio=0.5, 98 | transform=cifar_search_transform(is_training=False)) 99 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, 100 | num_replicas=num_gpus, 101 | rank=cfg.local_rank) 102 | val_loader = torch.utils.data.DataLoader(val_dataset, 103 | batch_size=cfg.batch_size // num_gpus if cfg.dist 104 | else cfg.batch_size, 105 | shuffle=not cfg.dist, 106 | num_workers=cfg.num_workers, 107 | sampler=val_sampler if cfg.dist else None) 108 | 109 | print('==> Building model..') 110 | model = Network(C=cfg.init_ch, num_cells=cfg.num_cells, 111 | num_nodes=cfg.num_nodes, multiplier=cfg.num_nodes, num_classes=cifar) 112 | 113 | if not cfg.dist: 114 | model = nn.DataParallel(model).to(device) 115 | else: 116 | # model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 117 | model = model.to(device) 118 | model = nn.parallel.DistributedDataParallel(model, 119 | device_ids=[cfg.local_rank, ], 120 | output_device=cfg.local_rank) 121 | 122 | # proxy_model is used for 2nd order update 123 | if cfg.order == '2nd': 124 | proxy_model = Network(cfg.init_ch, cfg.num_cells, cfg.num_nodes).cuda() 125 | 126 | count_parameters(model) 127 | 128 | weights = [v for k, v in model.named_parameters() if 'alpha' not in k] 129 | alphas = [v for k, v in model.named_parameters() if 'alpha' in k] 130 | optimizer_w = optim.SGD(weights, cfg.w_lr, momentum=0.9, weight_decay=cfg.w_wd) 131 | optimizer_a = optim.Adam(alphas, lr=cfg.a_lr, betas=(0.5, 0.999), weight_decay=cfg.a_wd) 132 | criterion = nn.CrossEntropyLoss().cuda() 133 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer_w, cfg.max_epochs, eta_min=cfg.w_min_lr) 134 | 135 | alphas = [] 136 | 137 | def train(epoch): 138 | model.train() 139 | print('\nEpoch: %d lr: %f' % (epoch, scheduler.get_lr()[0])) 140 | alphas.append([]) 141 | start_time = time.time() 142 | 143 | for batch_idx, ((inputs_w, targets_w), (inputs_a, targets_a)) \ 144 | in enumerate(zip(train_loader, val_loader)): 145 | 146 | inputs_w, targets_w = inputs_w.to(device), targets_w.to(device, non_blocking=True) 147 | inputs_a, targets_a = inputs_a.to(device), targets_a.to(device, non_blocking=True) 148 | 149 | # 1. update alpha 150 | if epoch > cfg.a_start: 151 | optimizer_a.zero_grad() 152 | 153 | if cfg.order == '1st': 154 | # using 1st order update 155 | outputs = model(inputs_a) 156 | val_loss = criterion(outputs, targets_a) 157 | val_loss.backward() 158 | else: 159 | # using 2nd order update 160 | val_loss = update(model, proxy_model, criterion, optimizer_w, 161 | inputs_a, targets_a, inputs_w, targets_w) 162 | 163 | optimizer_a.step() 164 | else: 165 | val_loss = torch.tensor([0]).cuda() 166 | 167 | # 2. update weights 168 | outputs = model(inputs_w) 169 | cls_loss = criterion(outputs, targets_w) 170 | 171 | optimizer_w.zero_grad() 172 | cls_loss.backward() 173 | nn.utils.clip_grad_norm_(model.parameters(), 5.0) 174 | optimizer_w.step() 175 | 176 | if batch_idx % cfg.log_interval == 0: 177 | step = len(train_loader) * epoch + batch_idx 178 | duration = time.time() - start_time 179 | 180 | print('[%d/%d - %d/%d] cls_loss: %5f val_loss: %5f (%d samples/sec)' % 181 | (epoch, cfg.max_epochs, batch_idx, len(train_loader), 182 | cls_loss.item(), val_loss.item(), cfg.batch_size * cfg.log_interval / duration)) 183 | 184 | start_time = time.time() 185 | summary_writer.add_scalar('cls_loss', cls_loss.item(), step) 186 | summary_writer.add_scalar('val_loss', val_loss.item(), step) 187 | summary_writer.add_scalar('learning rate', optimizer_w.param_groups[0]['lr'], step) 188 | 189 | alphas[-1].append(model.module.alpha_normal.detach().cpu().numpy()) 190 | alphas[-1].append(model.module.alpha_reduce.detach().cpu().numpy()) 191 | return 192 | 193 | def eval(epoch): 194 | model.eval() 195 | 196 | correct = 0 197 | total_loss = 0 198 | with torch.no_grad(): 199 | for step, (inputs, targets) in enumerate(val_loader): 200 | inputs, targets = inputs.to(device), targets.to(device, non_blocking=True) 201 | 202 | outputs = model(inputs) 203 | total_loss += criterion(outputs, targets).item() 204 | _, predicted = torch.max(outputs.data, 1) 205 | correct += predicted.eq(targets.data).cpu().sum().item() 206 | 207 | acc = 100. * correct / len(val_loader.dataset) 208 | total_loss = total_loss / len(val_loader) 209 | print('Val_loss==> %.5f Precision@1 ==> %.2f%% \n' % (total_loss, acc)) 210 | summary_writer.add_scalar('Precision@1', acc, global_step=epoch) 211 | summary_writer.add_scalar('val_loss_per_epoch', total_loss, global_step=epoch) 212 | return 213 | 214 | for epoch in range(cfg.max_epochs): 215 | train_sampler.set_epoch(epoch) 216 | val_sampler.set_epoch(epoch) 217 | train(epoch) 218 | eval(epoch) 219 | scheduler.step(epoch) # move to here after pytorch1.1.0 220 | print(model.module.genotype()) 221 | if cfg.local_rank == 0: 222 | torch.save(alphas, os.path.join(cfg.ckpt_dir, 'alphas.t7')) 223 | torch.save(model.state_dict(), os.path.join(cfg.ckpt_dir, 'search_checkpoint.t7')) 224 | torch.save({'genotype': model.module.genotype()}, os.path.join(cfg.ckpt_dir, 'genotype.t7')) 225 | 226 | summary_writer.close() 227 | 228 | 229 | if __name__ == '__main__': 230 | if cfg.local_rank == 0: 231 | main() 232 | else: 233 | with DisablePrint(): 234 | main() 235 | -------------------------------------------------------------------------------- /nets/search_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from nets.operations import FactorizedReduce, ReLUConvBN, MixedLayer 8 | from genotypes import PRIMITIVES, Genotype 9 | 10 | 11 | class Cell(nn.Module): 12 | 13 | def __init__(self, num_nodes, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev): 14 | """ 15 | :param num_nodes: 4, number of layers inside a cell 16 | :param multiplier: 4 17 | :param C_prev_prev: 48 18 | :param C_prev: 48 19 | :param C: 16 20 | :param reduction: indicates whether to reduce the output maps width 21 | :param reduction_prev: when previous cell reduced width, s1_d = s0_d//2 22 | in order to keep same shape between s1 and s0, we adopt prep0 layer to 23 | reduce the s0 width by half. 24 | """ 25 | super(Cell, self).__init__() 26 | 27 | # indicating current cell is reduction or not 28 | self.reduction = reduction 29 | self.reduction_prev = reduction_prev 30 | 31 | # preprocess0 deal with output from prev_prev cell 32 | if reduction_prev: 33 | # if prev cell has reduced channel/double width, 34 | # it will reduce width by half 35 | self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False) 36 | else: 37 | self.preprocess0 = ReLUConvBN(C_prev_prev, C, kernel_size=1, 38 | stride=1, padding=0, affine=False) 39 | # preprocess1 deal with output from prev cell 40 | self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False) 41 | 42 | # steps inside a cell 43 | self.num_nodes = num_nodes # 4 44 | self.multiplier = multiplier # 4 45 | 46 | self.layers = nn.ModuleList() 47 | 48 | for i in range(self.num_nodes): 49 | # for each i inside cell, it connects with all previous output 50 | # plus previous two cells' output 51 | for j in range(2 + i): 52 | # for reduction cell, it will reduce the heading 2 inputs only 53 | stride = 2 if reduction and j < 2 else 1 54 | layer = MixedLayer(C, stride, op_names=PRIMITIVES) 55 | self.layers.append(layer) 56 | 57 | def forward(self, s0, s1, weights): 58 | """ 59 | :param s0: 60 | :param s1: 61 | :param weights: [14, 8] 62 | :return: 63 | """ 64 | # print('s0:', s0.shape,end='=>') 65 | s0 = self.preprocess0(s0) # [40, 48, 32, 32], [40, 16, 32, 32] 66 | # print(s0.shape, self.reduction_prev) 67 | # print('s1:', s1.shape,end='=>') 68 | s1 = self.preprocess1(s1) # [40, 48, 32, 32], [40, 16, 32, 32] 69 | # print(s1.shape) 70 | 71 | states = [s0, s1] 72 | offset = 0 73 | # for each node, receive input from all previous intermediate nodes and s0, s1 74 | for i in range(self.num_nodes): # 4 75 | # [40, 16, 32, 32] 76 | s = sum(self.layers[offset + j](h, weights[offset + j]) 77 | for j, h in enumerate(states)) / len(states) 78 | offset += len(states) 79 | # append one state since s is the elem-wise addition of all output 80 | states.append(s) 81 | # print('node:',i, s.shape, self.reduction) 82 | 83 | # concat along dim=channel 84 | return torch.cat(states[-self.multiplier:], dim=1) # 6 of [40, 16, 32, 32] 85 | 86 | 87 | class Network(nn.Module): 88 | """ 89 | stack number:layer of cells and then flatten to fed a linear layer 90 | """ 91 | 92 | def __init__(self, C, num_cells, 93 | num_nodes=4, multiplier=4, stem_multiplier=3, num_classes=10, img_channel=3): 94 | """ 95 | 96 | :param C: 16 97 | :param num_cells: number of cells of current network 98 | :param num_nodes: nodes num inside cell 99 | :param multiplier: output channel of cell = multiplier * ch 100 | :param stem_multiplier: output channel of stem net = stem_multiplier * ch 101 | :param num_classes: 10 102 | """ 103 | super(Network, self).__init__() 104 | 105 | self.C = C 106 | self.num_classes = num_classes 107 | self.num_cells = num_cells 108 | self.num_nodes = num_nodes 109 | self.multiplier = multiplier 110 | 111 | # stem_multiplier is for stem network, 112 | # and multiplier is for general cell 113 | C_curr = stem_multiplier * C # 3*16 114 | # stem network, convert 3 channel to c_curr 115 | self.stem = nn.Sequential( # 3 => 48 116 | nn.Conv2d(img_channel, C_curr, 3, padding=1, bias=False), 117 | nn.BatchNorm2d(C_curr)) 118 | 119 | # c_curr means a factor of the output channels of current cell 120 | # output channels = multiplier * c_curr 121 | C_prev_prev, C_prev, C_curr = C_curr, C_curr, C # 48, 48, 16 122 | self.cells = nn.ModuleList() 123 | reduction_prev = False 124 | for i in range(num_cells): 125 | 126 | # for layer in the middle [1/3, 2/3], reduce via stride=2 127 | if i in [num_cells // 3, 2 * num_cells // 3]: 128 | C_curr *= 2 129 | reduction = True 130 | else: 131 | reduction = False 132 | 133 | # [cp, h, h] => [multiplier*c_curr, h/h//2, h/h//2] 134 | # the output channels = multiplier * c_curr 135 | cell = Cell(num_nodes, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev) 136 | # update reduction_prev 137 | reduction_prev = reduction 138 | 139 | self.cells += [cell] 140 | 141 | C_prev_prev, C_prev = C_prev, multiplier * C_curr 142 | 143 | # adaptive pooling output size to 1x1 144 | self.global_pooling = nn.AdaptiveAvgPool2d(1) 145 | # since cp records last cell's output channels 146 | # it indicates the input channel number 147 | self.classifier = nn.Linear(C_prev, num_classes) 148 | 149 | # k is the total number of edges inside single cell, 14 150 | k = sum(1 for i in range(self.num_nodes) for j in range(2 + i)) 151 | num_ops = len(PRIMITIVES) # 8 152 | 153 | self.alpha_normal = nn.Parameter(torch.randn(k, num_ops)) 154 | self.alpha_reduce = nn.Parameter(torch.randn(k, num_ops)) 155 | with torch.no_grad(): 156 | # initialize to smaller value 157 | self.alpha_normal.mul_(1e-3) 158 | self.alpha_reduce.mul_(1e-3) 159 | self._arch_parameters = [self.alpha_normal, self.alpha_reduce] 160 | 161 | def forward(self, x): 162 | """ 163 | in: torch.Size([3, 3, 32, 32]) 164 | stem: torch.Size([3, 48, 32, 32]) 165 | cell: 0 torch.Size([3, 64, 32, 32]) False 166 | cell: 1 torch.Size([3, 64, 32, 32]) False 167 | cell: 2 torch.Size([3, 128, 16, 16]) True 168 | cell: 3 torch.Size([3, 128, 16, 16]) False 169 | cell: 4 torch.Size([3, 128, 16, 16]) False 170 | cell: 5 torch.Size([3, 256, 8, 8]) True 171 | cell: 6 torch.Size([3, 256, 8, 8]) False 172 | cell: 7 torch.Size([3, 256, 8, 8]) False 173 | pool: torch.Size([16, 256, 1, 1]) 174 | linear: [b, 10] 175 | :param x: 176 | :return: 177 | """ 178 | # print('in:', x.shape) 179 | # s0 & s1 means the last cells' output 180 | s0 = s1 = self.stem(x) # [b, 3, 32, 32] => [b, 48, 32, 32] 181 | # print('stem:', s0.shape) 182 | 183 | for i, cell in enumerate(self.cells): 184 | # weights are shared across all reduction cell or normal cell 185 | # according to current cell's type, it choose which architecture parameters 186 | # to use 187 | if cell.reduction: # if current cell is reduction cell 188 | weights = F.softmax(self.alpha_reduce, dim=-1) 189 | else: 190 | weights = F.softmax(self.alpha_normal, dim=-1) # [14, 8] 191 | # execute cell() firstly and then assign s0=s1, s1=result 192 | s0, s1 = s1, cell(s0, s1, weights) # [40, 64, 32, 32] 193 | # print('cell:',i, s1.shape, cell.reduction, cell.reduction_prev) 194 | # print('\n') 195 | 196 | # s1 is the last cell's output 197 | out = self.global_pooling(s1) 198 | # print('pool', out.shape) 199 | logits = self.classifier(out.view(out.size(0), -1)) 200 | 201 | return logits 202 | 203 | def genotype(self): 204 | def _parse(weights): 205 | """ 206 | :param weights: [14, 8] 207 | :return: 208 | """ 209 | gene = [] 210 | n = 2 211 | start = 0 212 | for i in range(self.num_nodes): # for each node 213 | end = start + n 214 | W = weights[start:end].copy() # shape=[2, 8], [3, 8], [4, 8], [5, 8] 215 | # i+2 is the number of connection for node i 216 | # sort by descending order, get strongest 2 edges 217 | # note here we assume the 0th op is none op, if it's not the case this will be wrong! 218 | edges = np.argsort(-np.max(W[:, 1:], axis=1))[:2] 219 | ops = np.argmax(W[edges, 1:], axis=1) + 1 220 | gene += [(PRIMITIVES[op], edge) for op, edge in zip(ops, edges)] 221 | start = end 222 | n += 1 223 | return gene 224 | 225 | gene_normal = _parse(F.softmax(self.alpha_normal, dim=-1).data.cpu().numpy()) 226 | gene_reduce = _parse(F.softmax(self.alpha_reduce, dim=-1).data.cpu().numpy()) 227 | 228 | concat = range(2 + self.num_nodes - self.multiplier, self.num_nodes + 2) 229 | genotype = Genotype(normal=gene_normal, normal_concat=concat, 230 | reduce=gene_reduce, reduce_concat=concat) 231 | 232 | return genotype 233 | 234 | 235 | # if __name__ == '__main__': 236 | # import numpy as np 237 | # from utils.utils import create_logger 238 | # 239 | # 240 | # def hook(self, input, output): 241 | # # print(output.data.cpu().numpy().shape) 242 | # pass 243 | # 244 | # 245 | # logger = create_logger(0) 246 | # net = Network(16, 8, 4) 247 | # print(net.genotype()) 248 | # logger.info(net.genotype()) 249 | # print(net.genotype()) 250 | # 251 | # for m in net.modules(): 252 | # if isinstance(m, nn.Conv2d): 253 | # m.register_forward_hook(hook) 254 | # 255 | # y = net(torch.randn(1, 3, 32, 32)) 256 | # print(y.size()) 257 | # 258 | # sep_size = 0 259 | # for k, v in net.named_parameters(): 260 | # print('%s: %f MB' % (k, v.numel() / 1024 / 1024)) 261 | # if '4.op' in k or '5.op' in k: 262 | # sep_size += v.numel() / 1024 / 1024 263 | # print("Sep conv size = %f MB" % sep_size) 264 | # print("Total param size = %f MB" % (sum(v.numel() for v in net.parameters()) / 1024 / 1024)) 265 | -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import torch 4 | import pickle 5 | import logging 6 | import numpy as np 7 | 8 | import torch.utils.data as data 9 | from PIL import Image 10 | 11 | 12 | class MNIST_split(data.Dataset): 13 | raw_folder = 'raw' 14 | processed_folder = 'processed' 15 | training_file = 'training.pt' 16 | test_file = 'test.pt' 17 | 18 | def __init__(self, root, split='train', split_size=5000, 19 | transform=None, target_transform=None): 20 | assert split in ['train', 'val', 'test'] 21 | self.root = os.path.expanduser(root) 22 | self.transform = transform 23 | self.target_transform = target_transform 24 | self.split = split # training set or test set 25 | 26 | if not self._check_exists(): 27 | raise RuntimeError('Dataset not found.' + 28 | ' You can use download=True to download it') 29 | 30 | if self.split == 'test': 31 | self.test_data, self.test_labels = torch.load( 32 | os.path.join(self.root, self.processed_folder, self.test_file)) 33 | else: 34 | self.train_data, self.train_labels = torch.load( 35 | os.path.join(self.root, self.processed_folder, self.training_file)) 36 | if self.split == 'train': 37 | self.train_data = self.train_data[:split_size] 38 | self.train_labels = self.train_labels[:split_size] 39 | else: 40 | self.train_data = self.train_data[-split_size:] 41 | self.train_labels = self.train_labels[-split_size:] 42 | 43 | def __getitem__(self, index): 44 | if self.split == 'test': 45 | img, target = self.test_data[index], self.test_labels[index] 46 | else: 47 | img, target = self.train_data[index], self.train_labels[index] 48 | 49 | # doing this so that it is consistent with all other datasets 50 | # to return a PIL Image 51 | img = Image.fromarray(img.numpy(), mode='L') 52 | 53 | if self.transform is not None: 54 | img = self.transform(img) 55 | 56 | if self.target_transform is not None: 57 | target = self.target_transform(target) 58 | 59 | return img, target 60 | 61 | def __len__(self): 62 | if self.split == 'test': 63 | return len(self.test_data) 64 | else: 65 | return len(self.train_data) 66 | 67 | def _check_exists(self): 68 | return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \ 69 | os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file)) 70 | 71 | 72 | class CIFAR10_split(data.Dataset): 73 | """`CIFAR10 `_ Dataset. 74 | """ 75 | base_folder = 'cifar-10-batches-py' 76 | train_list = [['data_batch_1', 'c99cafc152244af753f735de768cd75f'], 77 | ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], 78 | ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], 79 | ['data_batch_4', '634d18415352ddfa80567beed471001a'], 80 | ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb']] 81 | 82 | test_list = [['test_batch', '40351d587109b95175f43aff81a1287e']] 83 | 84 | def __init__(self, root, split, ratio, transform=None, target_transform=None): 85 | assert split in ['train', 'val', 'test'] 86 | self.root = os.path.expanduser(root) 87 | self.transform = transform 88 | self.target_transform = target_transform 89 | self.split = split # training set or test set 90 | 91 | # now load the picked numpy arrays 92 | if self.split == 'test': 93 | f = self.test_list[0][0] 94 | file = os.path.join(self.root, self.base_folder, f) 95 | fo = open(file, 'rb') 96 | entry = pickle.load(fo, encoding='latin1') 97 | self.data = entry['data'] 98 | if 'labels' in entry: 99 | self.labels = entry['labels'] 100 | else: 101 | self.labels = entry['fine_labels'] 102 | fo.close() 103 | 104 | self.data = self.data.reshape((-1, 3, 32, 32)) 105 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 106 | else: 107 | self.data = [] 108 | self.labels = [] 109 | for fentry in self.train_list: 110 | f = fentry[0] 111 | file = os.path.join(self.root, self.base_folder, f) 112 | fo = open(file, 'rb') 113 | entry = pickle.load(fo, encoding='latin1') 114 | self.data.append(entry['data']) 115 | if 'labels' in entry: 116 | self.labels += entry['labels'] 117 | else: 118 | self.labels += entry['fine_labels'] 119 | fo.close() 120 | 121 | self.data = np.concatenate(self.data) 122 | self.data = self.data.reshape((-1, 3, 32, 32)) 123 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 124 | if self.split == 'train' and 0.0 < ratio < 1.0: 125 | split_size = int(np.clip(len(self.data) * ratio, 1.0, len(self.data))) 126 | print('using %d images from start ...' % split_size) 127 | # logging.getLogger('logger').info('using %d images from start ...' % split_size) 128 | self.data = self.data[:split_size] 129 | self.labels = self.labels[:split_size] 130 | elif self.split == 'val' and 0.0 < ratio < 1.0: 131 | split_size = int(np.clip(len(self.data) * ratio, 1.0, len(self.data))) 132 | print('using %d images from end ...' % split_size) 133 | # logging.getLogger('logger').info('using %d images from end ...' % split_size) 134 | self.data = self.data[-split_size:] 135 | self.labels = self.labels[-split_size:] 136 | 137 | def __getitem__(self, index): 138 | """ 139 | Args: 140 | index (int): Index 141 | 142 | Returns: 143 | tuple: (image, target) where target is index of the target class. 144 | """ 145 | img, target = self.data[index], self.labels[index] 146 | # doing this so that it is consistent with all other datasets 147 | # to return a PIL Image 148 | img = Image.fromarray(img) 149 | 150 | if self.transform is not None: 151 | img = self.transform(img) 152 | 153 | if self.target_transform is not None: 154 | target = self.target_transform(target) 155 | 156 | return img, target 157 | 158 | def __len__(self): 159 | return len(self.data) 160 | 161 | 162 | class CIFAR100_split(CIFAR10_split): 163 | base_folder = 'cifar-100-python' 164 | train_list = [['train', '16019d7e3df5f24257cddd939b257f8d']] 165 | test_list = [['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc']] 166 | 167 | 168 | def CIFAR_split(cifar, root, split='train', ratio=1.0, transform=None, target_transform=None): 169 | if cifar == 10: 170 | return CIFAR10_split(root, split, ratio, transform, target_transform) 171 | elif cifar == 100: 172 | return CIFAR100_split(root, split, ratio, transform, target_transform) 173 | else: 174 | raise NotImplementedError 175 | 176 | 177 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] 178 | 179 | 180 | def pil_loader(path): 181 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 182 | with open(path, 'rb') as f: 183 | img = Image.open(f) 184 | return img.convert('RGB') 185 | 186 | 187 | def accimage_loader(path): 188 | import accimage 189 | try: 190 | return accimage.Image(path) 191 | except IOError: 192 | # Potentially a decoding problem, fall back to PIL.Image 193 | return pil_loader(path) 194 | 195 | 196 | def default_loader(path): 197 | from torchvision import get_image_backend 198 | if get_image_backend() == 'accimage': 199 | return accimage_loader(path) 200 | else: 201 | return pil_loader(path) 202 | 203 | 204 | def has_file_allowed_extension(filename, allowed_extensions): 205 | return any(filename.lower().endswith(ext) for ext in allowed_extensions) 206 | 207 | 208 | def find_classes(dir): 209 | class_folder_names = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 210 | class_folder_names.sort() 211 | class_to_idx = {class_folder_names[i]: i for i in range(len(class_folder_names))} 212 | return class_folder_names, class_to_idx 213 | 214 | 215 | def make_dataset(dir, class_to_idx, allowed_extensions, split, ratio): 216 | samples = [] 217 | for folder_name in sorted(os.listdir(os.path.expanduser(dir))): 218 | if not os.path.isdir(os.path.join(dir, folder_name)): 219 | continue 220 | 221 | for root, _, img_names in sorted(os.walk(os.path.join(dir, folder_name))): 222 | split_size = int(np.clip(len(img_names) * ratio, 1.0, len(img_names))) 223 | if split == 'start' and 0 < ratio < 1.0: 224 | img_names = img_names[:split_size] 225 | elif split == 'end' and 0 < ratio < 1.0: 226 | img_names = img_names[-split_size:] 227 | for img_name in sorted(img_names): 228 | if has_file_allowed_extension(img_name, allowed_extensions): 229 | samples.append((os.path.join(root, img_name), class_to_idx[folder_name])) 230 | 231 | return samples 232 | 233 | 234 | class ImgNet_split(data.Dataset): 235 | ''' 236 | when split is 'train', split_size means num of images excluded for EACH class 237 | when split is 'val', split_size means num of images choosed for EACH class 238 | when split is 'test', split_size will not be used! 239 | ''' 240 | 241 | def __init__(self, root, loader=default_loader, 242 | split='all', ratio=1.0, transform=None, target_transform=None): 243 | assert split in ['all', 'start', 'end'] 244 | classes, class_to_idx = find_classes(root) 245 | samples = make_dataset(root, class_to_idx, IMG_EXTENSIONS, split, ratio) 246 | if len(samples) == 0: 247 | raise (RuntimeError("Found 0 files in subfolders of: " + 248 | root + "\nSupported extensions are: " + 249 | ",".join(IMG_EXTENSIONS))) 250 | 251 | self.root = root 252 | self.loader = loader 253 | self.extensions = IMG_EXTENSIONS 254 | 255 | self.classes = classes 256 | self.class_to_idx = class_to_idx 257 | self.samples = samples 258 | self.imgs = samples 259 | 260 | self.transform = transform 261 | self.target_transform = target_transform 262 | print(len(samples)) 263 | # logging.getLogger('logger').info(len(samples)) 264 | 265 | def __getitem__(self, index): 266 | path, target = self.samples[index] 267 | sample = self.loader(path) 268 | if self.transform is not None: 269 | sample = self.transform(sample) 270 | if self.target_transform is not None: 271 | target = self.target_transform(target) 272 | 273 | return sample, target 274 | 275 | def __len__(self): 276 | return len(self.samples) 277 | 278 | def __repr__(self): 279 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 280 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 281 | fmt_str += ' Root Location: {}\n'.format(self.root) 282 | tmp = ' Transforms (if any): ' 283 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 284 | tmp = ' Target Transforms (if any): ' 285 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 286 | return fmt_str 287 | 288 | 289 | if __name__ == '__main__': 290 | # ds = MNIST_split('../data', split='train', split_size=30000) 291 | # pass 292 | ds = CIFAR100_split('../data', split='train') 293 | 294 | # dataset = ImgNet_split(root='E:\\imagenet_raw\\train', split='test', split_size=50) 295 | # print(len(dataset)) 296 | --------------------------------------------------------------------------------