├── nasnet ├── __init__.py ├── droppath.py ├── optimizer.py ├── nasnet.py └── layers.py ├── README.md ├── .gitignore └── imagenet.py /nasnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .nasnet import * 2 | from .optimizer import PowersignCD -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # nasnet.pytorch 2 | 3 | ## Work In Progress 4 | 5 | PyTorch implementation of [Learning Transferable Architectures for Scalable Image Recognition](https://arxiv.org/abs/1707.07012). It currently crashes on PyTorch 0.2.0 and should work with PyTorch master or PyTorch 0.3.0. 6 | 7 | ## TODO 8 | 9 | * Clean up code 10 | * Refactor the quick-and-dirty PowersignCD optimizer 11 | * Pretrain nets on ImageNet 12 | * Write a better drop-path implementation 13 | -------------------------------------------------------------------------------- /nasnet/droppath.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.cuda 3 | import torch.nn as nn 4 | import torch.functional as F 5 | from random import random 6 | from torch.autograd import Variable 7 | 8 | # Currently there is a risk of dropping all paths...
9 | # We should create a version that take all paths into account to make sure one stays alive 10 | # But then keep_prob is meaningless and we have to copute/keep track of the conditional probability 11 | class DropPath(nn.Module): 12 | def __init__(self, module, keep_prob=0.9): 13 | super(DropPath, self).__init__() 14 | self.module = module 15 | self.keep_prob = keep_prob 16 | self.shape = None 17 | self.training = True 18 | self.dtype = torch.FloatTensor 19 | 20 | def forward(self, *input): 21 | if self.training: 22 | # If we don't now the shape we run the forward path once and store the output shape 23 | if self.shape is None: 24 | temp = self.module(*input) 25 | self.shape = temp.size() 26 | if temp.data.is_cuda: 27 | self.dtype = torch.cuda.FloatTensor 28 | del temp 29 | p = random() 30 | if p <= self.keep_prob: 31 | return Variable(self.dtype(self.shape).zero_()) 32 | else: 33 | return self.module(*input)/self.keep_prob # Inverted scaling 34 | else: 35 | return self.module(*input) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Pycharm 2 | .idea 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | .static_storage/ 59 | .media/ 60 | local_settings.py 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | env/ 91 | venv/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | 109 | -------------------------------------------------------------------------------- /nasnet/optimizer.py: -------------------------------------------------------------------------------- 1 | from torch.optim.optimizer import Optimizer, required 2 | import torch 3 | import math 4 | 5 | # PowerSign-style optimizer; step() scales every update by a cosine decay over `steps` calls. 6 | class PowersignCD(Optimizer): 7 | def __init__(self, params, steps, lr=required, momentum=0.9, dampening=0, 8 | weight_decay=0, nesterov=True): 9 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, 10 | weight_decay=weight_decay, nesterov=nesterov) # NOTE(review): step() never reads nesterov or dampening; they are only validated and stored 11 | if nesterov and (momentum <= 0 or dampening != 0): 12 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 13 | super(PowersignCD, self).__init__(params, defaults) 14 | self.t = 0 # step counter, advanced by step() 15 | self.T = steps # cosine-decay horizon used by step() 16 | 17 | def __setstate__(self, state): 18 | super(PowersignCD,
self).__setstate__(state) 19 | for group in self.param_groups: 20 | group.setdefault('nesterov', False) 21 | 22 | def step(self, closure=None): 23 | """Performs a single optimization step. 24 | 25 | Arguments: 26 | closure (callable, optional): A closure that reevaluates the model 27 | and returns the loss. 28 | """ 29 | loss = None 30 | if closure is not None: 31 | loss = closure() 32 | 33 | for group in self.param_groups: 34 | weight_decay = group['weight_decay'] 35 | momentum = group['momentum'] 36 | 37 | for p in group['params']: 38 | if p.grad is None: 39 | continue 40 | g = p.grad.data 41 | 42 | if weight_decay != 0: 43 | g.add_(weight_decay, p.data) 44 | 45 | param_state = self.state[p] 46 | if 'exp_avg' not in param_state: 47 | m = param_state['exp_avg'] = g.clone() 48 | else: 49 | m = param_state['exp_avg'] 50 | m.mul_(momentum).add_(1 - momentum, g) 51 | 52 | w = torch.sign(g).mul(torch.sign(m)) 53 | w.mul(.5*(1+math.cos(math.pi*self.t/self.T))) 54 | w.exp_() 55 | p.data.addcmul_(-.5 * (1 + math.cos(math.pi * self.t / self.T)) * group['lr'], w, g) 56 | 57 | self.t += 1 58 | 59 | return loss 60 | -------------------------------------------------------------------------------- /imagenet.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn.parallel 3 | import torch.optim 4 | import torch.utils.data 5 | 6 | from torch.autograd import Variable 7 | 8 | from nasnet import NASNet, nasnetmobile, nasnetlarge, PowersignCD 9 | 10 | import os 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | import torch.optim as optim 15 | from torchvision import datasets, transforms 16 | from tqdm import tqdm 17 | 18 | class Trainer(object): 19 | cuda = torch.cuda.is_available() 20 | 21 | def __init__(self, model, optimizer, loss_f, save_dir=None, save_freq=5): 22 | self.model = model 23 | if self.cuda: 24 | model.cuda() 25 | self.optimizer = optimizer 26 | self.loss_f = loss_f 27 | self.save_dir = save_dir 28 | 
self.save_freq = save_freq 29 | 30 | def _loop(self, data_loader, is_train=True): 31 | loop_loss = [] 32 | correct = [] 33 | for data, target in tqdm(data_loader): 34 | if self.cuda: 35 | data, target = data.cuda(), target.cuda() 36 | data, target = Variable(data, volatile=not is_train), Variable(target, volatile=not is_train) 37 | self.optimizer.zero_grad() 38 | output = self.model(data) 39 | loss = self.loss_f(output, target) 40 | loop_loss.append(loss.data[0] / len(data_loader)) 41 | correct.append((output.data.max(1)[1] == target.data).sum() / len(data_loader.dataset)) 42 | if is_train: 43 | loss.backward() 44 | self.optimizer.step() 45 | mode = "train" if is_train else "test" 46 | print(f">>>[{mode}] loss: {sum(loop_loss):.2f}/accuracy: {sum(correct):.2%}") 47 | return loop_loss, correct 48 | 49 | def train(self, data_loader): 50 | self.model.train() 51 | loss, correct = self._loop(data_loader) 52 | 53 | def test(self, data_loader): 54 | self.model.eval() 55 | loss, correct = self._loop(data_loader, is_train=False) 56 | 57 | def loop(self, epochs, train_data, test_data, scheduler=None): 58 | for ep in range(1, epochs + 1): 59 | if scheduler is not None: 60 | scheduler.step() 61 | print(f"epochs: {ep}") 62 | self.train(train_data) 63 | self.test(test_data) 64 | if ep % self.save_freq: 65 | self.save(ep) 66 | 67 | def save(self, epoch, **kwargs): 68 | if self.save_dir: 69 | name = f"weight-{epoch}-" + "-".join([f"{k}_{v}" for k, v in kwargs.items()]) + ".pkl" 70 | torch.save({"weight": self.model.state_dict(), 71 | "optimizer": self.optimizer.state_dict()}, 72 | os.path.join(self.save_dir, name)) 73 | 74 | def main(type, batch_size, data_root, n_epochs): 75 | if type == 'mobile': 76 | input_size = 224, 77 | model = nasnetmobile 78 | elif type == 'large': 79 | input_size = 331 80 | model = nasnetlarge 81 | else: 82 | input_size = 299 83 | model = nasnetlarge 84 | 85 | transform_train = transforms.Compose([ 86 | transforms.RandomSizedCrop(input_size), 87 | 
transforms.RandomHorizontalFlip(), 88 | transforms.ToTensor(), 89 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 90 | std=[0.229, 0.224, 0.225]), 91 | ]) 92 | transform_test = transforms.Compose([ 93 | transforms.CenterCrop(input_size), 94 | transforms.ToTensor(), 95 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 96 | std=[0.229, 0.224, 0.225]), 97 | ]) 98 | 99 | traindir = os.path.join(data_root, 'train') 100 | valdir = os.path.join(data_root, 'val') 101 | train = datasets.ImageFolder(traindir, transform_train) 102 | val = datasets.ImageFolder(valdir, transform_test) 103 | train_loader = torch.utils.data.DataLoader( 104 | train, batch_size=batch_size, shuffle=True, num_workers=8) 105 | test_loader = torch.utils.data.DataLoader( 106 | val, batch_size=batch_size, shuffle=True, num_workers=8) 107 | net = model(num_classes=1000) 108 | optimizer = PowersignCD(params=net.parameters(), steps=len(train)/batch_size*n_epochs, lr=0.6, momentum=0.9) 109 | trainer = Trainer(net, optimizer, F.cross_entropy, save_dir=".") 110 | trainer.loop(n_epochs, train_loader, test_loader) 111 | 112 | 113 | if __name__ == '__main__': 114 | import argparse 115 | 116 | p = argparse.ArgumentParser() 117 | p.add_argument("root", help="imagenet data root") 118 | p.add_argument("--batch_size", default=8, type=int) 119 | p.add_argument("--n_epochs", default=10, type=int) 120 | args = p.parse_args() 121 | main(args.batch_size, args.root, args.n_epochs) -------------------------------------------------------------------------------- /nasnet/nasnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.functional as F 4 | from torch.nn.functional import relu 5 | from .layers import NormalCell, ReductionCell, ResizeCell0, ResizeCell1 6 | 7 | class NASNet(nn.Module): 8 | def __init__(self, stem_filters, normals, filters, scaling, num_classes, use_aux=True, pretrained=True): 9 | super(NASNet, 
self).__init__() 10 | self.normals = normals 11 | self.use_aux = use_aux 12 | self.num_classes = num_classes 13 | 14 | self.stemcell = nn.Sequential( 15 | nn.Conv2d(3, stem_filters, kernel_size=3, stride=2), 16 | nn.BatchNorm2d(stem_filters, eps=0.001, momentum=0.1, affine=True) 17 | ) 18 | 19 | self.reduction1 = ReductionCell(in_channels_x=stem_filters, 20 | in_channels_h=stem_filters, 21 | out_channels=int(filters * scaling ** (-2)), 22 | resize_cell=ResizeCell1) 23 | self.reduction2 = ReductionCell(in_channels_x=int(4*filters * scaling ** (-2)), 24 | in_channels_h=stem_filters, 25 | out_channels=int(filters * scaling ** (-1)), 26 | resize_cell=ResizeCell0) 27 | 28 | x_channels = int(4*filters * scaling ** (-1)) 29 | h_channels = int(4*filters * scaling ** (-2)) 30 | 31 | self.add_module('normal_block1_0', 32 | NormalCell(in_channels_x=x_channels, 33 | in_channels_h=h_channels, 34 | out_channels=filters, 35 | resize_cell=ResizeCell0, 36 | keep_prob=0.9)) 37 | # TODO: Can we do that in a cleaner way? 
38 | h_channels = x_channels 39 | x_channels = 6*filters # a NormalCell outputs 6 * out_channels channels 40 | for i in range(normals-1): 41 | self.add_module('normal_block1_{}'.format(i+1), 42 | NormalCell(in_channels_x=x_channels, 43 | in_channels_h=h_channels, 44 | out_channels=filters, 45 | resize_cell=ResizeCell1, 46 | keep_prob=0.9)) 47 | h_channels = x_channels 48 | x_channels = 6*filters 49 | 50 | self.reduction3 = ReductionCell(in_channels_x=x_channels, 51 | in_channels_h=h_channels, 52 | out_channels=filters * scaling) 53 | # Each cell's h input is the previous cell's input (cells return it as `prev`), so h_channels trails x_channels by one cell 54 | h_channels = x_channels 55 | x_channels = 4 * filters * scaling # a ReductionCell outputs 4 * out_channels channels 56 | # First cell after a reduction: h is still at the pre-reduction resolution, so ResizeCell0 (stride-2 pooling on h) is used 57 | self.add_module('normal_block2_0', 58 | NormalCell(in_channels_x=x_channels, 59 | in_channels_h=h_channels, 60 | out_channels=filters*scaling, 61 | resize_cell=ResizeCell0, 62 | keep_prob=0.9)) 63 | h_channels = x_channels 64 | x_channels = 6 * filters * scaling 65 | for i in range(normals - 1): 66 | self.add_module('normal_block2_{}'.format(i + 1), 67 | NormalCell(in_channels_x=x_channels, 68 | in_channels_h=h_channels, 69 | out_channels=filters*scaling, 70 | resize_cell=ResizeCell1, keep_prob=0.9)) 71 | h_channels = x_channels 72 | x_channels = 6 * filters * scaling 73 | 74 | self.reduction4 = ReductionCell(in_channels_x=x_channels, 75 | in_channels_h=h_channels, 76 | out_channels=filters * scaling ** 2) 77 | 78 | h_channels = x_channels 79 | x_channels = 4 * filters * scaling ** 2 80 | 81 | self.add_module('normal_block3_0', 82 | NormalCell(in_channels_x=x_channels, 83 | in_channels_h=h_channels, 84 | out_channels=filters * scaling ** 2, 85 | resize_cell=ResizeCell0, keep_prob=0.9)) 86 | h_channels = x_channels 87 | x_channels = 6 * filters * scaling ** 2 88 | for i in range(normals - 1): 89 | self.add_module('normal_block3_{}'.format(i + 1), 90 | NormalCell(in_channels_x=x_channels, 91 | in_channels_h=h_channels, 92 | out_channels=filters * scaling ** 2, 93 | resize_cell=ResizeCell1, 94 | keep_prob=0.9)) 95 | h_channels = x_channels 96 | x_channels = 6 * filters * scaling ** 2 97 | 98 | self.avg_pool_0
= nn.AvgPool2d(11, stride=1, padding=0) 99 | self.dropout_0 = nn.Dropout() 100 | self.fc = nn.Linear(x_channels, self.num_classes) 101 | 102 | def features(self, x): 103 | x = self.stemcell(x) 104 | 105 | x, h = self.reduction1(x, x) 106 | x, h = self.reduction2(x, h) 107 | 108 | for i in range(self.normals): 109 | x, h = self._modules['normal_block1_{}'.format(i)](x, h) 110 | 111 | x, h = self.reduction3(x, h) 112 | 113 | for i in range(self.normals): 114 | x, h = self._modules['normal_block2_{}'.format(i)](x, h) 115 | 116 | # Should we check for training or not ? 117 | if self.use_aux and self.training: 118 | x_aux = x 119 | 120 | x, h = self.reduction4(x, h) 121 | 122 | for i in range(self.normals): 123 | x, h = self._modules['normal_block3_{}'.format(i)](x, h) 124 | 125 | if self.use_aux and self.training: 126 | return x, x_aux 127 | else: 128 | return x 129 | 130 | def classifier(self, x): 131 | x = relu(x) 132 | x = self.avg_pool_0(x) 133 | x = x.view(-1, self.fc.in_features) 134 | x = self.dropout_0(x) 135 | x = self.fc(x) 136 | return x 137 | 138 | def aux_classifier(self, x): 139 | x = relu(x) 140 | x = self.avg_pool_0(x) 141 | x = x.view(-1, self.fc.in_features) 142 | x = self.dropout_0(x) 143 | x = self.fc(x) 144 | return x 145 | 146 | def forward(self, x): 147 | if self.use_aux: 148 | x, x_b = self.features(x) 149 | x = self.classifier(x) 150 | x_b = self.aux_classifier(x_b) 151 | return x, x_b 152 | else: 153 | x = self.features(x) 154 | x = self.classifier(x) 155 | return x 156 | 157 | 158 | def nasnetmobile(num_classes=1000, pretrained=False): 159 | return NASNet(32, 4, 44, 2, num_classes=num_classes, use_aux=True, pretrained=pretrained) 160 | 161 | def nasnetlarge(num_classes=1000, pretrained=False): 162 | return NASNet(96, 6, 168, 2, num_classes=num_classes, use_aux=True, pretrained=pretrained) -------------------------------------------------------------------------------- /nasnet/layers.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.functional import relu 4 | from .droppath import DropPath 5 | 6 | class SeparableConv2d(nn.Module): 7 | 8 | def __init__(self, in_channels, out_channels, dw_kernel, dw_stride, dw_padding, bias=False): 9 | super(SeparableConv2d, self).__init__() 10 | self.depthwise_conv2d = nn.Conv2d(in_channels, in_channels, dw_kernel, 11 | stride=dw_stride, 12 | padding=dw_padding, 13 | bias=bias, 14 | groups=in_channels) 15 | self.pointwise_conv2d = nn.Conv2d(in_channels, out_channels, 1, stride=1, bias=bias) 16 | 17 | def forward(self, x): 18 | x = self.depthwise_conv2d(x) 19 | x = self.pointwise_conv2d(x) 20 | return x 21 | 22 | 23 | class TwoSeparables(nn.Module): 24 | 25 | def __init__(self, in_channels, out_channels, dw_kernel, dw_stride, dw_padding, bias=False): 26 | super(TwoSeparables, self).__init__() 27 | self.separable_0 = SeparableConv2d(in_channels, in_channels, dw_kernel, dw_stride, dw_padding, bias=bias) 28 | self.bn_0 = nn.BatchNorm2d(in_channels, eps=0.001, momentum=0.1, affine=True) 29 | self.separable_1 = SeparableConv2d(in_channels, out_channels, dw_kernel, 1, dw_padding, bias=bias) 30 | self.bn_1 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True) 31 | 32 | def forward(self, x): 33 | x = relu(x) 34 | x = self.separable_0(x) 35 | x = self.bn_0(x) 36 | x = relu(x) 37 | x = self.separable_1(x) 38 | x = self.bn_1(x) 39 | return x 40 | 41 | class ResizeCell0(nn.Module): 42 | def __init__(self, in_channels_x, in_channels_h, out_channels): 43 | super(ResizeCell0, self).__init__() 44 | self.pool_left_0 = nn.AvgPool2d(3, stride=2, padding=1) 45 | self.conv_left_0 = nn.Conv2d(in_channels_h, out_channels//2, 1, stride=1, bias=False) 46 | 47 | self.pool_left_1 = nn.AvgPool2d(3, stride=2, padding=1) 48 | self.conv_left_1 = nn.Conv2d(in_channels_h, out_channels//2, 1, stride=1, bias=False) 49 | 50 | self.bn_left = 
nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True) 51 | 52 | self.conv_right = nn.Conv2d(in_channels_x, out_channels, 1, stride=1, bias=False) 53 | self.bn_right = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True) 54 | 55 | def forward(self, x, h): 56 | h = relu(h) 57 | # h path: two stride-2 avg-pool + 1x1 conv halves, concatenated along channels, then BN 58 | h_0 = self.pool_left_0(h) 59 | h_0 = self.conv_left_0(h_0) 60 | 61 | h_1 = self.pool_left_1(h) 62 | h_1 = self.conv_left_1(h_1) 63 | 64 | h = torch.cat([h_0, h_1], 1) 65 | h = self.bn_left(h) 66 | # x path: ReLU -> 1x1 conv -> BN, spatial size unchanged 67 | x = relu(x) 68 | x = self.conv_right(x) 69 | x = self.bn_right(x) 70 | 71 | return x, h 72 | # ResizeCell1: ReLU -> 1x1 conv -> BN on each input, with no spatial resizing 73 | class ResizeCell1(nn.Module): 74 | def __init__(self, in_channels_x, in_channels_h, out_channels): 75 | super(ResizeCell1, self).__init__() 76 | self.conv_left = nn.Conv2d(in_channels_x, out_channels, 1, stride=1, bias=False) 77 | self.bn_left = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True) 78 | 79 | self.conv_right = nn.Conv2d(in_channels_h, out_channels, 1, stride=1, bias=False) 80 | self.bn_right = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True) 81 | 82 | def forward(self, x, h): 83 | x = relu(x) 84 | x = self.conv_left(x) 85 | x = self.bn_left(x) 86 | 87 | h = relu(h) 88 | h = self.conv_right(h) 89 | h = self.bn_right(h) 90 | 91 | return x, h 92 | 93 | # Reduction cell: stride-2 branches halve the spatial size; the output concatenates 4 branches -> 4 * out_channels 94 | class ReductionCell(nn.Module): 95 | def __init__(self, in_channels_x, in_channels_h, out_channels, resize_cell=ResizeCell1, keep_prob=0.9): 96 | super(ReductionCell, self).__init__() 97 | 98 | self.resize = resize_cell(in_channels_x, in_channels_h, out_channels) 99 | # Every branch op is wrapped in DropPath with the same keep_prob 100 | self.comb_iter_0_left = DropPath(TwoSeparables(out_channels, out_channels, 7, 2, 3, bias=False), keep_prob) 101 | self.comb_iter_0_right = DropPath(TwoSeparables(out_channels, out_channels, 5, 2, 2, bias=False), keep_prob) 102 | 103 | self.comb_iter_1_left = DropPath(nn.MaxPool2d(3, stride=2, padding=1), keep_prob) 104 | self.comb_iter_1_right = DropPath(TwoSeparables(out_channels, out_channels, 7, 2,
3, bias=False), keep_prob) 105 | 106 | self.comb_iter_2_left = DropPath(nn.AvgPool2d(3, stride=2, padding=1), keep_prob) 107 | self.comb_iter_2_right = DropPath(TwoSeparables(out_channels, out_channels, 5, 2, 2, bias=False), keep_prob) 108 | 109 | self.comb_iter_3_left = DropPath(nn.MaxPool2d(3, stride=2, padding=1), keep_prob) 110 | self.comb_iter_3_right = DropPath(TwoSeparables(out_channels, out_channels, 3, 1, 1, bias=False), keep_prob) # stride 1: applied to comb_iter_0, which is already downsampled 111 | 112 | self.comb_iter_4_left = DropPath(nn.AvgPool2d(3, stride=1, padding=1), keep_prob) # stride 1: also applied to comb_iter_0 113 | 114 | def forward(self, x, h): 115 | prev = x 116 | # resize brings x and h to out_channels; the unresized input (prev) is returned and becomes the next cell's h 117 | x, h = self.resize(x, h) 118 | 119 | comb_iter_0_left = self.comb_iter_0_left(h) 120 | comb_iter_0_right = self.comb_iter_0_right(x) 121 | comb_iter_0 = comb_iter_0_left + comb_iter_0_right 122 | 123 | comb_iter_1_left = self.comb_iter_1_left(x) 124 | comb_iter_1_right = self.comb_iter_1_right(h) 125 | comb_iter_1 = comb_iter_1_left + comb_iter_1_right 126 | 127 | comb_iter_2_left = self.comb_iter_2_left(x) 128 | comb_iter_2_right = self.comb_iter_2_right(h) 129 | x_comb_iter_2 = comb_iter_2_left + comb_iter_2_right 130 | 131 | comb_iter_3_left = self.comb_iter_3_left(x) 132 | comb_iter_3_right = self.comb_iter_3_right(comb_iter_0) 133 | comb_iter_3 = comb_iter_3_left + comb_iter_3_right 134 | 135 | comb_iter_4_left = self.comb_iter_4_left(comb_iter_0) 136 | comb_iter_4 = comb_iter_4_left + comb_iter_1 137 | # comb_iter_0 is only an intermediate: the output concatenates 4 branches -> 4 * out_channels channels 138 | return torch.cat([comb_iter_1, x_comb_iter_2, comb_iter_3, comb_iter_4], 1), prev 139 | # Normal cell: stride-1 branches only; outputs the adjusted x plus 5 combinations -> 6 * out_channels channels 140 | class NormalCell(nn.Module): 141 | def __init__(self, in_channels_x, in_channels_h, out_channels, resize_cell=ResizeCell1, keep_prob=0.9): 142 | super(NormalCell, self).__init__() 143 | self.adjust = resize_cell(in_channels_x, in_channels_h, out_channels) 144 | 145 | self.comb_iter_0_left = DropPath(TwoSeparables(out_channels, out_channels, 3, 1, 1, bias=False), keep_prob) 146 | 147 | self.comb_iter_1_left = DropPath(TwoSeparables(out_channels, out_channels, 3, 1, 1,
bias=False), keep_prob) 148 | self.comb_iter_1_right = DropPath(TwoSeparables(out_channels, out_channels, 5, 1, 2, bias=False), keep_prob) 149 | 150 | self.comb_iter_2_left = DropPath(nn.AvgPool2d(3, stride=1, padding=1), keep_prob) 151 | 152 | self.comb_iter_3_left = DropPath(nn.AvgPool2d(3, stride=1, padding=1), keep_prob) 153 | self.comb_iter_3_h = DropPath(nn.AvgPool2d(3, stride=1, padding=1), keep_prob) 154 | # comb_iter_3_left and comb_iter_3_h apply the same pooling to h; they differ only in their independent DropPath draws 155 | self.comb_iter_4_left = DropPath(TwoSeparables(out_channels, out_channels, 5, 1, 2, bias=False), keep_prob) 156 | self.comb_iter_4_right = DropPath(TwoSeparables(out_channels, out_channels, 3, 1, 1, bias=False), keep_prob) 157 | 158 | def forward(self, x, h): 159 | prev = x 160 | # adjust brings x and h to out_channels; the unadjusted input (prev) is returned and becomes the next cell's h 161 | x, h = self.adjust(x, h) 162 | # Identity terms (x, h) in the sums below are not wrapped in DropPath 163 | comb_iter_0_left = self.comb_iter_0_left(x) 164 | comb_iter_0 = comb_iter_0_left + x 165 | 166 | comb_iter_1_left = self.comb_iter_1_left(h) 167 | comb_iter_1_right = self.comb_iter_1_right(x) 168 | comb_iter_1 = comb_iter_1_left + comb_iter_1_right 169 | 170 | comb_iter_2_left = self.comb_iter_2_left(x) 171 | comb_iter_2 = comb_iter_2_left + h 172 | 173 | comb_iter_3_left = self.comb_iter_3_left(h) 174 | comb_iter_3_right = self.comb_iter_3_h(h) 175 | comb_iter_3 = comb_iter_3_left + comb_iter_3_right 176 | 177 | comb_iter_4_left = self.comb_iter_4_left(h) 178 | comb_iter_4_right = self.comb_iter_4_right(h) 179 | comb_iter_4 = comb_iter_4_left + comb_iter_4_right 180 | # Output: adjusted x plus 5 combinations -> 6 * out_channels channels 181 | return torch.cat([x, comb_iter_0, comb_iter_1, comb_iter_2, comb_iter_3, comb_iter_4], 1), prev 182 |