├── .gitignore
├── data
│   └── cat.jpg
├── genotypes.py
├── utils.py
├── main.py
├── README.md
├── operations.py
├── model.py
├── convert.py
└── LICENSE

/.gitignore:
--------------------------------------------------------------------------------
data/
*.pyc
--------------------------------------------------------------------------------
/data/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxi116/PNASNet.pytorch/HEAD/data/cat.jpg
--------------------------------------------------------------------------------
/genotypes.py:
--------------------------------------------------------------------------------
from collections import namedtuple

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

PNASNet = Genotype(
  normal = [
    ('sep_conv_5x5', 0),
    ('max_pool_3x3', 0),
    ('sep_conv_7x7', 1),
    ('max_pool_3x3', 1),
    ('sep_conv_5x5', 1),
    ('sep_conv_3x3', 1),
    ('sep_conv_3x3', 4),
    ('max_pool_3x3', 1),
    ('sep_conv_3x3', 0),
    ('skip_connect', 1),
  ],
  normal_concat = [2, 3, 4, 5, 6],
  reduce = [
    ('sep_conv_5x5', 0),
    ('max_pool_3x3', 0),
    ('sep_conv_7x7', 1),
    ('max_pool_3x3', 1),
    ('sep_conv_5x5', 1),
    ('sep_conv_3x3', 1),
    ('sep_conv_3x3', 4),
    ('max_pool_3x3', 1),
    ('sep_conv_3x3', 0),
    ('skip_connect', 1),
  ],
  reduce_concat = [2, 3, 4, 5, 6],
)

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import tensorflow as tf


def preprocess_for_eval(image, height, width,
                        central_fraction=0.875, scope=None):
  """Prepare one image for evaluation.

  If height and width are specified it would output an image with that size by
  applying resize_bilinear.

  If central_fraction is specified it would crop the central fraction of the
  input image.

  Args:
    image: 3-D Tensor of image. If dtype is tf.float32 then the range should be
      [0, 1], otherwise it would be converted to tf.float32 assuming that the
      range is [0, MAX], where MAX is the largest positive representable number
      for int(8/16/32) data type (see `tf.image.convert_image_dtype` for details).
    height: integer
    width: integer
    central_fraction: Optional Float, fraction of the image to crop.
    scope: Optional scope for name_scope.
  Returns:
    3-D float Tensor of prepared image.
  """
  with tf.name_scope(scope, 'eval_image', [image, height, width]):
    if image.dtype != tf.float32:
      image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    # Crop the central region of the image with an area containing 87.5% of
    # the original image.
    if central_fraction:
      image = tf.image.central_crop(image, central_fraction=central_fraction)

    if height and width:
      # Resize the image to the specified height and width.
      image = tf.expand_dims(image, 0)
      image = tf.image.resize_bilinear(image, [height, width],
                                       align_corners=False)
      image = tf.squeeze(image, [0])
    image = tf.subtract(image, 0.5)
    image = tf.multiply(image, 2.0)
    return image
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import tensorflow as tf
import torch
import torchvision.datasets as datasets
from torch.autograd import Variable
from model import NetworkImageNet
from genotypes import PNASNet
from utils import preprocess_for_eval

parser = argparse.ArgumentParser()
parser.add_argument('--valdir', type=str, default='data/val',
                    help='path to ImageNet val folder')
parser.add_argument('--image_size', type=int, default=331,
                    help='image size')
parser.add_argument('--num_conv_filters', type=int, default=216,
                    help='number of filters')
parser.add_argument('--num_classes', type=int, default=1001,
                    help='number of categories')
parser.add_argument('--num_cells', type=int, default=12,
                    help='number of cells')


def main():
  args = parser.parse_args()
  assert torch.cuda.is_available()

  image_ph = tf.placeholder(tf.uint8, (None, None, 3))
  image_proc = preprocess_for_eval(image_ph, args.image_size, args.image_size)
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)

  model = NetworkImageNet(args.num_conv_filters, args.num_classes,
                          args.num_cells, False, PNASNet)
  model.drop_path_prob = 0
  model.eval()
  model.load_state_dict(torch.load('data/PNASNet-5_Large.pth'))
  model = model.cuda()

  c1, c5 = 0, 0
  val_dataset = datasets.ImageFolder(args.valdir)
  for i, (image, label) in enumerate(val_dataset):
    tf_image_proc = sess.run(image_proc, feed_dict={image_ph: image})
    image = torch.from_numpy(tf_image_proc.transpose((2, 0, 1)))
    image = Variable(image).cuda()
    logits, _ = model(image.unsqueeze(0))
    top5 = logits.data.cpu().numpy().squeeze().argsort()[::-1][:5]
    top1 = top5[0]
    # The TF checkpoint has 1001 classes (index 0 is background), hence label + 1.
    if label + 1 == top1:
      c1 += 1
    if label + 1 in top5:
      c5 += 1
    print('Test: [{0}/{1}]\t'
          'Prec@1 {2:.3f}\t'
          'Prec@5 {3:.3f}\t'.format(
          i + 1, len(val_dataset), c1 / (i + 1.), c5 / (i + 1.)))


if __name__ == '__main__':
  main()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PNASNet.pytorch

PyTorch implementation of [PNASNet-5](https://arxiv.org/abs/1712.00559). Specifically, PyTorch code from [this repository](https://github.com/quark0/darts) is adapted to completely match both [my implementation](https://github.com/chenxi116/PNASNet.TF) and the [official implementation](https://github.com/tensorflow/models/blob/master/research/slim/nets/nasnet/pnasnet.py) of PNASNet-5, both written in TensorFlow. This complete match allows the pretrained TF model to be exactly converted to PyTorch: see `convert.py`.
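
For a quick sanity check of the converted model, here is a minimal single-image inference sketch (my own, not part of the repository; it assumes `convert.py` has already produced `data/PNASNet-5_Large.pth`, and it reuses the TF preprocessing from `utils.py` exactly as `main.py` does):

```python
import tensorflow as tf
import torch
from PIL import Image

from model import NetworkImageNet
from genotypes import PNASNet
from utils import preprocess_for_eval

# Same hyperparameters as main.py: 216 filters, 1001 classes, 12 cells.
model = NetworkImageNet(216, 1001, 12, False, PNASNet)
model.drop_path_prob = 0
model.load_state_dict(torch.load('data/PNASNet-5_Large.pth'))
model = model.cuda().eval()

# Preprocess with TF so the input matches the original pipeline bit for bit.
image_ph = tf.placeholder(tf.uint8, (None, None, 3))
image_proc = preprocess_for_eval(image_ph, 331, 331)
with tf.Session() as sess:
  image = sess.run(image_proc, feed_dict={image_ph: Image.open('data/cat.jpg')})

x = torch.from_numpy(image.transpose((2, 0, 1))).unsqueeze(0).cuda()
logits, _ = model(x)
print(logits.data.cpu().numpy().argmax())  # 1001-way index; 0 is background
```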

If you use the code, please cite:
```
@inproceedings{liu2018progressive,
  author    = {Chenxi Liu and
               Barret Zoph and
               Maxim Neumann and
               Jonathon Shlens and
               Wei Hua and
               Li{-}Jia Li and
               Li Fei{-}Fei and
               Alan L. Yuille and
               Jonathan Huang and
               Kevin Murphy},
  title     = {Progressive Neural Architecture Search},
  booktitle = {European Conference on Computer Vision},
  year      = {2018}
}
```

## Requirements

- TensorFlow 1.8.0 (for image preprocessing)
- PyTorch 0.4.0
- torchvision 0.2.1

## Data and Model Preparation

- Download the ImageNet validation set and move images to labeled subfolders. To do the latter, you can use [this script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh). Make sure the folder `val` is under `data/`.
- Download [PNASNet.TF](https://github.com/chenxi116/PNASNet.TF) and follow its README to download the `PNASNet-5_Large_331` pretrained model.
- Convert the TensorFlow model to a PyTorch model:
```bash
python convert.py
```

## Notes on Model Conversion

- In both TensorFlow implementations, `net[0]` means `prev` and `net[1]` means `prev_prev`. However, in the [PyTorch implementation](https://github.com/quark0/darts), `states[0]` means `prev_prev` and `states[1]` means `prev`. I followed the PyTorch implementation in this repository. This is why the 0 and 1 in the PNASCell specification are reversed.
- The default value of `eps` in BatchNorm layers is `1e-3` in TensorFlow and `1e-5` in PyTorch. I changed all BatchNorm `eps` values to `1e-3` (see `operations.py`) to exactly match the TensorFlow pretrained model.
- The TensorFlow pretrained model uses `tf.image.resize_bilinear` to resize the image (see `utils.py`). I cannot find a Python function that exactly matches this function's behavior (also see [this thread](https://github.com/tensorflow/tensorflow/issues/6720) and [this post](https://hackernoon.com/how-tensorflows-tf-image-resize-stole-60-days-of-my-life-aba5eb093f35) on this topic), so currently in `main.py` I call TensorFlow to do the image preprocessing, to guarantee that both models receive identical input.
- When converting the model from TensorFlow to PyTorch (i.e. `convert.py`), I use an input image size of 323 instead of 331. This is because the 'SAME' padding in TensorFlow may differ from the padding in PyTorch in some layers (see [this link](https://stackoverflow.com/questions/37674306/what-is-the-difference-between-same-and-valid-padding-in-tf-nn-max-pool-of-t); basically, TF may pad only 1 pixel on the right and bottom, whereas PyTorch always pads 1 pixel on all four margins). However, the two behave exactly the same when the image size is 323: `conv0` has no padding, so the feature size becomes 161, then 81, 41, etc., as the sketch below illustrates.
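
  A parity sketch of my own (not from the repository) makes this concrete, assuming kernel 3 and stride 2 at `conv0` and at every reduction: TF 'SAME' padding totals `(ceil(h/2) - 1) * 2 + 3 - h`, which is 2 (symmetric 1 + 1, identical to PyTorch's `padding=1`) when the input size `h` is odd, but 1 (bottom/right only) when `h` is even.

  ```python
  import math

  def same_pad_total(h, k=3, s=2):
    # Total padding TF 'SAME' adds for a kernel-k, stride-s layer at size h.
    return max((int(math.ceil(h / float(s))) - 1) * s + k - h, 0)

  def trace(image_size):
    h = (image_size - 3) // 2 + 1   # conv0: kernel 3, stride 2, no padding
    sizes = [h]
    for _ in range(4):              # stem1, stem2, and the two reduction cells
      pad = same_pad_total(h)       # 2 -> symmetric, same as PyTorch
      h = (h - 3 + pad) // 2 + 1    # 1 -> TF pads bottom/right only: mismatch
      sizes.append(h)
    return sizes

  print(trace(323))  # [161, 81, 41, 21, 11] -- every stride-2 input is odd
  print(trace(331))  # [165, 83, 42, 21, 11] -- 42 is even: one-sided padding
  ```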
- The exact conversion when the image size is 323 is also corroborated by the following table:

Image Size | Official TensorFlow Model (Prec@1, Prec@5) | Converted PyTorch Model (Prec@1, Prec@5)
--- | --- | ---
(331, 331) | (0.829, 0.962) | (0.828, 0.961)
(323, 323) | (0.827, 0.961) | (0.827, 0.961)


## Usage

```bash
python main.py
```

The last printed line should read:
```
Test: [50000/50000] Prec@1 0.828 Prec@5 0.961
```
--------------------------------------------------------------------------------
/operations.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

# Each entry maps (C_in, C_out, stride, affine) to an nn.Module; the pooling
# ops grow a 1x1 conv + BatchNorm when the channel counts differ.
OPS = {
  'none' : lambda C_in, C_out, stride, affine: Zero(stride),
  'avg_pool_3x3' : lambda C_in, C_out, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False) if C_in == C_out else nn.Sequential(
    nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
    nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(C_out, eps=1e-3, affine=affine)
    ),
  'max_pool_3x3' : lambda C_in, C_out, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1) if C_in == C_out else nn.Sequential(
    nn.MaxPool2d(3, stride=stride, padding=1),
    nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(C_out, eps=1e-3, affine=affine)
    ),
  'skip_connect' : lambda C_in, C_out, stride, affine: Identity() if stride == 1 else ReLUConvBN(C_in, C_out, 1, stride, 0, affine=affine),
  'sep_conv_3x3' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 3, stride, 1, affine=affine),
  'sep_conv_5x5' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 5, stride, 2, affine=affine),
  'sep_conv_7x7' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 7, stride, 3, affine=affine),
  'dil_conv_3x3' : lambda C_in, C_out, stride, affine: DilConv(C_in, C_out, 3, stride, 2, 2, affine=affine),
  'dil_conv_5x5' : lambda C_in, C_out, stride, affine: DilConv(C_in, C_out, 5, stride, 4, 2, affine=affine),
  'conv_7x1_1x7' : lambda C_in, C_out, stride, affine: nn.Sequential(
    nn.ReLU(inplace=False),
    nn.Conv2d(C_in, C_in, (1,7), stride=(1, stride), padding=(0, 3), bias=False),
    nn.Conv2d(C_in, C_out, (7,1), stride=(stride, 1), padding=(3, 0), bias=False),
    nn.BatchNorm2d(C_out, eps=1e-3, affine=affine)
    ),
}

class ReLUConvBN(nn.Module):

  def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
    super(ReLUConvBN, self).__init__()
    self.op = nn.Sequential(
      nn.ReLU(inplace=False),
      nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
      nn.BatchNorm2d(C_out, eps=1e-3, affine=affine)
    )

  def forward(self, x):
    return self.op(x)

class DilConv(nn.Module):

  def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
    super(DilConv, self).__init__()
    self.op = nn.Sequential(
      nn.ReLU(inplace=False),
      nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
      nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
      nn.BatchNorm2d(C_out, eps=1e-3, affine=affine),
    )

  def forward(self, x):
    return self.op(x)


class SepConv(nn.Module):

  def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
    super(SepConv, self).__init__()
    self.op = nn.Sequential(
      nn.ReLU(inplace=False),
      nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
      nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
      nn.BatchNorm2d(C_out, eps=1e-3, affine=affine),
      nn.ReLU(inplace=False),
      nn.Conv2d(C_out, C_out, kernel_size=kernel_size, stride=1, padding=padding, groups=C_out, bias=False),
      nn.Conv2d(C_out, C_out, kernel_size=1, padding=0, bias=False),
      nn.BatchNorm2d(C_out, eps=1e-3, affine=affine),
    )

  def forward(self, x):
    return self.op(x)


class Identity(nn.Module):

  def __init__(self):
    super(Identity, self).__init__()

  def forward(self, x):
    return x


class Zero(nn.Module):

  def __init__(self, stride):
    super(Zero, self).__init__()
    self.stride = stride

  def forward(self, x):
    if self.stride == 1:
      return x.mul(0.)
    return x[:,:,::self.stride,::self.stride].mul(0.)


class FactorizedReduce(nn.Module):

  def __init__(self, C_in, C_out, affine=True):
    super(FactorizedReduce, self).__init__()
    assert C_out % 2 == 0
    self.relu = nn.ReLU(inplace=False)
    self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
    self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
    self.bn = nn.BatchNorm2d(C_out, eps=1e-3, affine=affine)
    self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)

  def forward(self, x):
    x = self.relu(x)
    # conv_2 sees the input shifted by one pixel (pad bottom/right, then drop
    # the first row/column), so the two stride-2 convs cover complementary grids.
    y = self.pad(x)
    out = torch.cat([self.conv_1(x), self.conv_2(y[:,:,1:,1:])], dim=1)
    out = self.bn(out)
    return out

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from operations import *
from torch.autograd import Variable
# from utils import drop_path


class Cell(nn.Module):

  def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
    super(Cell, self).__init__()
    print(C_prev_prev, C_prev, C)
    self.reduction = reduction

    if reduction_prev is None:
      self.preprocess0 = Identity()
    elif reduction_prev is True:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)

    if reduction:
      op_names, indices = zip(*genotype.reduce)
      concat = genotype.reduce_concat
    else:
      op_names, indices = zip(*genotype.normal)
      concat = genotype.normal_concat

    assert len(op_names) == len(indices)
    self._steps = len(op_names) // 2
    self._concat = concat
    self.multiplier = len(concat)

    self._ops = nn.ModuleList()
    for name, index in zip(op_names, indices):
      stride = 2 if reduction and index < 2 else 1
      if reduction_prev is None and index == 0:
        op = OPS[name](C_prev_prev, C, stride, True)
      else:
        op = OPS[name](C, C, stride, True)
      self._ops += [op]
    self._indices = indices

  def forward(self, s0, s1, drop_prob):
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)

    # Indices reference earlier states: 0 = prev_prev, 1 = prev, 2+ = outputs
    # of this cell's earlier steps.
    states = [s0, s1]
    for i in range(self._steps):
      h1 = states[self._indices[2*i]]
      h2 = states[self._indices[2*i+1]]
      op1 = self._ops[2*i]
      op2 = self._ops[2*i+1]
      h1 = op1(h1)
      h2 = op2(h2)
      # if self.training and drop_prob > 0.:
      #   if not isinstance(op1, Identity):
      #     h1 = drop_path(h1, drop_prob)
      #   if not isinstance(op2, Identity):
      #     h2 = drop_path(h2, drop_prob)
      s = h1 + h2
      states += [s]
    return torch.cat([states[i] for i in self._concat], dim=1)


class AuxiliaryHeadImageNet(nn.Module):

  def __init__(self, C, num_classes):
    """assuming input size 14x14"""
    super(AuxiliaryHeadImageNet, self).__init__()
    self.features = nn.Sequential(
      nn.ReLU(inplace=True),
      nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
      nn.Conv2d(C, 128, 1, bias=False),
      nn.BatchNorm2d(128),
      nn.ReLU(inplace=True),
      nn.Conv2d(128, 768, 2, bias=False),
      nn.BatchNorm2d(768),
      nn.ReLU(inplace=True)
    )
    self.classifier = nn.Linear(768, num_classes)

  def forward(self, x):
    x = self.features(x)
    x = self.classifier(x.view(x.size(0),-1))
    return x


class NetworkImageNet(nn.Module):

  def __init__(self, C, num_classes, layers, auxiliary, genotype):
    super(NetworkImageNet, self).__init__()
    self._layers = layers
    self._auxiliary = auxiliary

    self.conv0 = nn.Conv2d(3, 96, kernel_size=3, stride=2, padding=0, bias=False)
    self.conv0_bn = nn.BatchNorm2d(96, eps=1e-3)
    self.stem1 = Cell(genotype, 96, 96, C // 4, True, None)
    self.stem2 = Cell(genotype, 96, C * self.stem1.multiplier // 4, C // 2, True, True)

    C_prev_prev, C_prev, C_curr = C * self.stem1.multiplier // 4, C * self.stem2.multiplier // 2, C

    self.cells = nn.ModuleList()
    reduction_prev = True
    for i in range(layers):
      if i in [layers // 3, 2 * layers // 3]:
        C_curr *= 2
        reduction = True
      else:
        reduction = False
      cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
      reduction_prev = reduction
      self.cells += [cell]
      C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
      if i == 2 * layers // 3:
        C_to_auxiliary = C_prev

    if auxiliary:
      self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
    self.relu = nn.ReLU(inplace=False)
    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)

  def forward(self, input):
    logits_aux = None
    s0 = self.conv0(input)
    s0 = self.conv0_bn(s0)
    s1 = self.stem1(s0, s0, self.drop_path_prob)
    s0, s1 = s1, self.stem2(s0, s1, self.drop_path_prob)
    for i, cell in enumerate(self.cells):
      s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
      if i == 2 * self._layers // 3:
        if self._auxiliary and self.training:
          logits_aux = self.auxiliary_head(s1)
    s1 = self.relu(s1)
    out = self.global_pooling(s1)
    logits = self.classifier(out.view(out.size(0), -1))
    return logits, logits_aux

--------------------------------------------------------------------------------
/convert.py:
--------------------------------------------------------------------------------
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
from torch.autograd import Variable
from model import NetworkImageNet
from genotypes import PNASNet
from operations import *
from utils import preprocess_for_eval

import sys
import os
sys.path.append('../PNASNet.TF')
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import tensorflow as tf
from pnasnet import build_pnasnet_large, pnasnet_large_arg_scope
slim = tf.contrib.slim


class ConvertPNASNet(object):

  def __init__(self):
    self.image = Image.open('data/cat.jpg')
    self.read_tf_weight()
    self.write_pytorch_weight()

  def read_tf_weight(self):
    self.weight_dict = {}
    image_ph = tf.placeholder(tf.uint8, (None, None, 3))
    image_proc = preprocess_for_eval(image_ph, 323, 323)
    with slim.arg_scope(pnasnet_large_arg_scope()):
      logits, end_points = build_pnasnet_large(
          tf.expand_dims(image_proc, 0), num_classes=1001, is_training=False)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    ckpt_restorer = tf.train.Saver()
    ckpt_restorer.restore(sess, '../PNASNet.TF/data/model.ckpt')

    weight_keys = [var.name[:-2] for var in tf.global_variables()]
    weight_vals = sess.run(tf.global_variables())
    for weight_key, weight_val in zip(weight_keys, weight_vals):
      self.weight_dict[weight_key] = weight_val

    self.tf_logits, self.tf_end_points, self.tf_image_proc = sess.run(
        [logits, end_points, image_proc], feed_dict={image_ph: self.image})

  def write_pytorch_weight(self):
    model = NetworkImageNet(216, 1001, 12, False, PNASNet)
    model.drop_path_prob = 0
    model.eval()

    self.used_keys = []
    self.convert_conv(model.conv0, 'conv0/weights')
    self.convert_bn(model.conv0_bn, 'conv0_bn/gamma', 'conv0_bn/beta',
                    'conv0_bn/moving_mean', 'conv0_bn/moving_variance')
    self.convert_cell(model.stem1, 'cell_stem_0/')
    self.convert_cell(model.stem2, 'cell_stem_1/')

    for i in range(12):
      self.convert_cell(model.cells[i], 'cell_{}/'.format(i))

    self.convert_fc(model.classifier, 'final_layer/FC/weights',
                    'final_layer/FC/biases')

    print('Conversion complete!')
    print('Check 1: whether all TF variables are used...')
    assert len(self.weight_dict) == len(self.used_keys)
    print('Pass!')

    model = model.cuda()
    image = self.tf_image_proc.transpose((2, 0, 1))
    image = Variable(self.Tensor(image)).cuda()
    logits, _ = model(image.unsqueeze(0))
    self.pytorch_logits = logits.data.cpu().numpy()

    print('Check 2: whether logits have small diff...')
    assert np.max(np.abs(self.tf_logits - self.pytorch_logits)) < 1e-5
    print('Pass!')

    model_path = 'data/PNASNet-5_Large.pth'
    torch.save(model.state_dict(), model_path)
    print('PyTorch model saved to {}'.format(model_path))

  def convert_cell(self, cell, name):
    # cell.preprocess0
    assert isinstance(cell.preprocess0, FactorizedReduce) or isinstance(cell.preprocess0, ReLUConvBN) or isinstance(cell.preprocess0, Identity)
    if isinstance(cell.preprocess0, FactorizedReduce):
      self.convert_conv(cell.preprocess0.conv_1, name + 'path1_conv/weights')
      self.convert_conv(cell.preprocess0.conv_2, name + 'path2_conv/weights')
      self.convert_bn(cell.preprocess0.bn, name + 'final_path_bn/gamma',
                      name + 'final_path_bn/beta', name + 'final_path_bn/moving_mean',
                      name + 'final_path_bn/moving_variance')
    else:
      if name + 'prev_1x1/weights' in self.weight_dict:
        self.convert_conv(cell.preprocess0.op[1], name + 'prev_1x1/weights')
        self.convert_bn(cell.preprocess0.op[2], name + 'prev_bn/gamma',
                        name + 'prev_bn/beta', name + 'prev_bn/moving_mean',
                        name + 'prev_bn/moving_variance')
      # else preprocess0 is Identity or = preprocess1; do nothing

    # cell.preprocess1
    assert isinstance(cell.preprocess1, ReLUConvBN)
    self.convert_conv(cell.preprocess1.op[1], name + '1x1/weights')
    self.convert_bn(cell.preprocess1.op[2], name + 'beginning_bn/gamma',
                    name + 'beginning_bn/beta', name + 'beginning_bn/moving_mean',
                    name + 'beginning_bn/moving_variance')

    # cell._ops
    for i in range(len(cell._ops)):
      side = 'left/' if i % 2 == 0 else 'right/'
      prefix = name + 'comb_iter_{}/'.format(i // 2) + side
      if isinstance(cell._ops[i], SepConv):
        suffix = '{0}x{0}'.format(cell._ops[i].op[1].kernel_size[0])

        self.convert_conv(cell._ops[i].op[1],
            prefix + 'separable_' + suffix + '_1/depthwise_weights', sep=True)
        self.convert_conv(cell._ops[i].op[2],
            prefix + 'separable_' + suffix + '_1/pointwise_weights', sep=False)
        self.convert_bn(cell._ops[i].op[3],
            prefix + 'bn_sep_' + suffix + '_1/gamma',
            prefix + 'bn_sep_' + suffix + '_1/beta',
            prefix + 'bn_sep_' + suffix + '_1/moving_mean',
            prefix + 'bn_sep_' + suffix + '_1/moving_variance')
        self.convert_conv(cell._ops[i].op[5],
            prefix + 'separable_' + suffix + '_2/depthwise_weights', sep=True)
        self.convert_conv(cell._ops[i].op[6],
            prefix + 'separable_' + suffix + '_2/pointwise_weights', sep=False)
        self.convert_bn(cell._ops[i].op[7],
            prefix + 'bn_sep_' + suffix + '_2/gamma',
            prefix + 'bn_sep_' + suffix + '_2/beta',
            prefix + 'bn_sep_' + suffix + '_2/moving_mean',
            prefix + 'bn_sep_' + suffix + '_2/moving_variance')
      elif isinstance(cell._ops[i], ReLUConvBN):
        # skip_connect with stride > 1
        self.convert_conv(cell._ops[i].op[1], prefix + '1x1/weights')
        self.convert_bn(cell._ops[i].op[2],
            prefix + 'bn_1/gamma', prefix + 'bn_1/beta',
            prefix + 'bn_1/moving_mean', prefix + 'bn_1/moving_variance')
      elif isinstance(cell._ops[i], nn.Sequential):
        # max_pool or avg_pool with C_in != C_out
        self.convert_conv(cell._ops[i][1], prefix + '1x1/weights')
        self.convert_bn(cell._ops[i][2],
            prefix + 'bn_1/gamma', prefix + 'bn_1/beta',
            prefix + 'bn_1/moving_mean', prefix + 'bn_1/moving_variance')

  def convert_conv(self, conv2d, weights_key, sep=False):
    weights = self.weight_dict[weights_key]
    if sep:
      # TF: [filter_height, filter_width, in_channels, channel_multiplier]
      # TF: [1, 1, channel_multiplier * in_channels, channel_multiplier]
      # PyTorch: [out_channels, in_channels // groups, *kernel_size]
      weights = np.transpose(weights, (2, 3, 0, 1))
    else:
      # TF: [filter_height, filter_width, in_channels, out_channels]
      # PyTorch: [out_channels, in_channels, *kernel_size]
      weights = np.transpose(weights, (3, 2, 0, 1))
    assert conv2d.weight.shape == self.Param(weights).shape, '{0} vs {1}'.format(conv2d.weight.shape, self.Param(weights).shape)
    conv2d.weight = self.Param(weights)
    self.used_keys += [weights_key]

  def convert_bn(self, bn, gamma_key, beta_key, moving_mean_key, moving_var_key):
    gamma = self.weight_dict[gamma_key]
    beta = self.weight_dict[beta_key]
    moving_mean = self.weight_dict[moving_mean_key]
    moving_var = self.weight_dict[moving_var_key]
    assert bn.weight.shape == self.Param(gamma).shape
    assert bn.bias.shape == self.Param(beta).shape
    assert bn.running_mean.shape == self.Tensor(moving_mean).shape
    assert bn.running_var.shape == self.Tensor(moving_var).shape
    bn.weight = self.Param(gamma)
    bn.bias = self.Param(beta)
    bn.running_mean = self.Tensor(moving_mean)
    bn.running_var = self.Tensor(moving_var)
    self.used_keys += [gamma_key, beta_key, moving_mean_key, moving_var_key]

  def convert_fc(self, fc, weights_key, biases_key):
    weights = self.weight_dict[weights_key]
    biases = self.weight_dict[biases_key]
    weights = np.transpose(weights)
    assert fc.weight.shape == self.Param(weights).shape
    assert fc.bias.shape == self.Param(biases).shape
    fc.weight = self.Param(weights)
    fc.bias = self.Param(biases)
    self.used_keys += [weights_key, biases_key]

  def Param(self, x):
    return torch.nn.Parameter(torch.from_numpy(x))

  def Tensor(self, x):
    return torch.from_numpy(x)


if __name__ == '__main__':
  ConvertPNASNet()
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Copyright 2018 Chenxi Liu. All rights reserved.

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
--------------------------------------------------------------------------------