├── Model
│   └── read me.txt
├── POCINet.py
├── README.md
├── Result
│   └── read me.txt
├── TSPOANet.py
├── TSPORTNet-test.py
├── TSPORTNet-train.py
├── TestingModel.py
├── TrainingModel-2.py
├── __pycache__
├── config.py
├── data
│   ├── data
│   │   ├── t10k-images-idx3-ubyte.gz
│   │   ├── t10k-images.idx3-ubyte
│   │   ├── t10k-labels-idx1-ubyte.gz
│   │   ├── t10k-labels.idx1-ubyte
│   │   └── train-images-idx3-ubyte.gz
│   ├── mnist
│   │   ├── t10k-images.idx3-ubyte
│   │   └── t10k-labels.idx1-ubyte
│   └── smallNORB.py
├── eval_accuracy.py
├── eval_recon_histogram.py
├── imgs
│   ├── NLDF(1).py
│   ├── read me.txt
│   ├── spread_loss_norb.png
│   ├── test_accuracy.png
│   ├── test_accuracy_norb.png
│   ├── training_loss.png
│   └── vgg16(1).py
├── index.md
├── logdir
│   └── NLDF.py
└── nets
    ├── __init__.py
    ├── alexnet.py
    ├── alexnet_test.py
    ├── cifarnet.py
    ├── cyclegan.py
    ├── cyclegan_test.py
    ├── dcgan.py
    ├── dcgan_test.py
    ├── inception.py
    ├── inception_resnet_v2.py
    ├── inception_resnet_v2_test.py
    ├── inception_utils.py
    ├── inception_v1.py
    ├── inception_v1_test.py
    ├── inception_v2.py
    ├── inception_v2_test.py
    ├── inception_v3.py
    ├── inception_v3_test.py
    ├── inception_v4.py
    ├── inception_v4_test.py
    ├── lenet.py
    ├── mobilenet
    │   ├── README.md
    │   ├── conv_blocks.py
    │   ├── madds_top1_accuracy.png
    │   ├── mnet_v1_vs_v2_pixel1_latency.png
    │   ├── mobilenet.py
    │   ├── mobilenet_example.ipynb
    │   ├── mobilenet_v2.py
    │   └── mobilenet_v2_test.py
    ├── mobilenet_v1.md
    ├── mobilenet_v1.png
    ├── mobilenet_v1.py
    ├── mobilenet_v1_eval.py
    ├── mobilenet_v1_test.py
    ├── mobilenet_v1_train.py
    ├── nasnet
    │   ├── README.md
    │   ├── __init__.py
    │   ├── nasnet.py
    │   ├── nasnet_test.py
    │   ├── nasnet_utils.py
    │   ├── nasnet_utils_test.py
    │   ├── pnasnet.py
    │   └── pnasnet_test.py
    ├── nets_factory.py
    ├── nets_factory_test.py
    ├── overfeat.py
    ├── overfeat_test.py
    ├── pix2pix.py
    ├── pix2pix_test.py
    ├── resnet_utils.py
    ├── resnet_v1.py
    ├── resnet_v1_test.py
    ├── resnet_v2.py
    ├── resnet_v2_test.py
    ├── vgg.py
    └── vgg_test.py

/Model/read me.txt:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TSPORTNet
2 | 
3 | 1. Implementation: Python 3.6, TensorFlow 1.4.0, NumPy 1.16.0. The VGG-16 npy file can be downloaded from https://github.com/machrisaa/tensorflow-vgg.
4 | 
5 | 2. TSPOANet:
6 | 
7 | @inproceedings{liu2019employing,
8 | 
9 | title={Employing Deep Part-Object Relationships for Salient Object Detection},
10 | 
11 | author={Liu, Yi and Zhang, Qiang and Zhang, Dingwen and Han, Jungong},
12 | 
13 | booktitle={Proceedings of the IEEE International Conference on Computer Vision},
14 | 
15 | pages={1232--1241},
16 | 
17 | year={2019}
18 | 
19 | }
20 | 
21 | The corresponding code is TSPOANet.py.
22 | 
23 | The saliency maps can be downloaded from https://pan.baidu.com/s/1kGbfkDY6Juf1RDXm_-AzLg (extraction code: rs3p).
24 | 
25 | 3. TSPORTNet: an extended version has been published in T-PAMI:
26 | 
27 | @article{liu2021part,
28 | 
29 | title={Part-Object Relational Visual Saliency},
30 | 
31 | author={Liu, Yi and Zhang, Dingwen and Zhang, Qiang and Han, Jungong},
32 | 
33 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
34 | 
35 | year={2021},
36 | 
37 | publisher={IEEE}
38 | 
39 | }
40 | 
41 | The corresponding code is TSPORTNet.py.
42 | 
43 | The saliency maps can be downloaded from https://pan.baidu.com/s/1c-5gNo8Rj3VMboIkJiSJEA
44 | (extraction code: tspo).
45 | 
46 | 4. Integrating Part-Object Relationship and Contrast for Camouflaged Object Detection:
47 | 
48 | @article{liu2021integrating,
49 | 
50 | title={Integrating Part-Object Relationship and Contrast for Camouflaged Object Detection},
51 | 
52 | author={Liu, Yi and Zhang, Dingwen and Zhang, Qiang and Han, Jungong},
53 | 
54 | journal={IEEE Transactions on Information Forensics and Security},
55 | 
56 | volume={16},
57 | 
58 | pages={5154--5166},
59 | 
60 | year={2021},
61 | 
62 | publisher={IEEE}
63 | 
64 | }
65 | 
66 | The corresponding code is POCINet.py.
67 | 
68 | The camouflage maps can be found at https://pan.baidu.com/s/1zExyx1npIlRk65foO3r3Qg
69 | (extraction code: poci).
70 | 
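For quick reference, below is a minimal single-image inference sketch distilled from TestingModel.py. The checkpoint location under Model/, the placement of the VGG-16 npy weights expected by vgg16.py, and the example file names are assumptions, not part of the original scripts.

```python
# Minimal inference sketch (assumes a trained checkpoint saved under Model/ and
# the VGG-16 npy weights placed where vgg16.py expects them).
import cv2
import numpy as np
import tensorflow as tf
import TSPORTNet
import vgg16

model = TSPORTNet.Model()
model.build_model()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state('Model/')
    tf.train.Saver().restore(sess, ckpt.model_checkpoint_path)

    img = cv2.imread('example.jpg').astype(np.float32)   # hypothetical input image
    h, w = img.shape[:2]
    inp = cv2.resize(img, (TSPORTNet.img_size, TSPORTNet.img_size)) - vgg16.VGG_MEAN
    inp = inp.reshape((1, TSPORTNet.img_size, TSPORTNet.img_size, 3))

    # Channel 0 of the softmax output is the saliency probability map.
    prob = sess.run(model.Prob, feed_dict={model.input_holder: inp})
    sal = np.reshape(prob, (TSPORTNet.label_size, TSPORTNet.label_size, 2))[:, :, 0]
    cv2.imwrite('example_saliency.png', (cv2.resize(sal, (w, h)) * 255).astype(np.uint8))
```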
--------------------------------------------------------------------------------
/Result/read me.txt:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/TestingModel.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import TSPORTNet
4 | import os
5 | import sys
6 | import tensorflow as tf
7 | import time
8 | import vgg16
9 | 
10 | 
11 | def load_img_list(dataset):
12 | 
13 |     if dataset == 'ECSSD':
14 |         path = '/home/hpc/LY/dataset/imgs/ECSSD'
15 |     elif dataset == 'HKU-IS-TE':
16 |         path = '/home/hpc/LY/dataset/imgs/HKU-IS-TE'
17 |     elif dataset == 'SOD':
18 |         path = '/home/hpc/LY/dataset/imgs/SOD'
19 |     elif dataset == 'PASCAL-S':
20 |         path = '/home/hpc/LY/dataset/imgs/pascal'
21 |     elif dataset == 'DUT-OMRON':
22 |         path = '/home/hpc/LY/dataset/imgs/DUT-OMRON'
23 |     elif dataset == 'DUTS-TE':
24 |         path = '/home/hpc/LY/dataset/imgs/DUTS-TE'
25 | 
26 |     imgs = os.listdir(path)
27 | 
28 |     return path, imgs
29 | 
30 | 
31 | if __name__ == "__main__":
32 | 
33 |     model = TSPORTNet.Model()
34 |     model.build_model()
35 |     os.environ['CUDA_VISIBLE_DEVICES'] = "2"
36 | 
37 |     sess = tf.Session()
38 |     sess.run(tf.global_variables_initializer())
39 |     img_size = TSPORTNet.img_size
40 |     label_size = TSPORTNet.label_size
41 | 
42 |     ckpt = tf.train.get_checkpoint_state('Model/')
43 |     saver = tf.train.Saver()
44 |     saver.restore(sess, ckpt.model_checkpoint_path)
45 | 
46 |     #datasets = ['MSRA-B', 'HKU-IS', 'DUT-OMRON',
47 |     #            'PASCAL-S', 'ECSSD', 'SOD']
48 |     datasets = ['ECSSD']
49 | 
50 |     if not os.path.exists('Result'):
51 |         os.mkdir('Result')
52 | 
53 |     for dataset in datasets:
54 |         path, imgs = load_img_list(dataset)
55 | 
56 |         save_dir = 'Result/' + dataset
57 |         if not os.path.exists(save_dir):
58 |             os.mkdir(save_dir)
59 | 
60 |         save_dir = 'Result/' + dataset
61 |         if not os.path.exists(save_dir):
62 |             os.mkdir(save_dir)
63 | 
64 |         for f_img in imgs:
65 | 
66 |             img = cv2.imread(os.path.join(path, f_img))
67 |             img_name, ext = os.path.splitext(f_img)
68 | 
69 |             if img is not None:
70 |                 ori_img = img.copy()
71 |                 img_shape = img.shape
72 |                 img = cv2.resize(img, (img_size, img_size)) - vgg16.VGG_MEAN
73 |                 img = img.reshape((1, img_size, img_size, 3))
74 | 
75 |                 start_time = time.time()
76 |                 #result = sess.run(model.Prob,
77 |                 #                  feed_dict={model.input_holder: img,
78 |                 #                             model.keep_prob: 1})
79 |                 result = sess.run(model.Prob,
80 |                                   feed_dict={model.input_holder: img,
81 |                                              })
82 |                 print("--- %s seconds ---" % (time.time() - start_time))
83 | 
84 |                 result = np.reshape(result, (label_size, label_size, 2))
85 |                 result = result[:, :, 0]
86 | 
87 |                 result = cv2.resize(np.squeeze(result), (img_shape[1], img_shape[0]))
88 | 
89 |                 save_name = os.path.join(save_dir, img_name+'.png')
90 |                 cv2.imwrite(save_name,
(result*255).astype(np.uint8)) 91 | 92 | sess.close() 93 | -------------------------------------------------------------------------------- /TrainingModel-2.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import TSPORTNet 4 | #import NLDFnew 5 | import vgg16 6 | import tensorflow as tf 7 | import os 8 | from config import cfg 9 | 10 | 11 | def load_training_list(): 12 | 13 | with open('/home/hpc/LY/NLDF/dataset/train/trainDUTS.txt') as f: 14 | lines = f.read().splitlines() 15 | 16 | files = [] 17 | labels = [] 18 | 19 | for line in lines: 20 | labels.append('%s' % line) 21 | files.append('%s' % line.replace('.png', '.jpg')) 22 | 23 | return files, labels 24 | 25 | 26 | def load_train_val_list(): 27 | 28 | files = [] 29 | labels = [] 30 | 31 | with open('/home/hpc/LY/NLDF/dataset/valid.txt') as f: 32 | lines = f.read().splitlines() 33 | 34 | for line in lines: 35 | labels.append('/home/hpc/LY/NLDF/dataset/valid_mask/%s' % line) 36 | files.append('/home/hpc/LY/NLDF/dataset/valid_img/%s' % line.replace('.png', '.jpg')) 37 | 38 | #with open('F:/related code/Non-Local Deep Features for Salient Object Detection (CVPR 2017)/NLDF-master/dataset/valid_imgs.txt') as f: 39 | # lines = f.read().splitlines() 40 | 41 | #for line in lines: 42 | # labels.append('dataset/MSRA-B/annotation/%s' % line) 43 | # files.append('dataset/MSRA-B/image/%s' % line.replace('.png', '.jpg')) 44 | 45 | return files, labels 46 | 47 | 48 | if __name__ == "__main__": 49 | 50 | model = TSPORTNet.Model() 51 | model.build_model() 52 | 53 | #modelnew = NLDFnew.Model() 54 | #modelnew.build_model() 55 | os.environ['CUDA_VISIBLE_DEVICES'] = "3" 56 | 57 | #sess = tf.Session(config=tf.ConfigProto(log_device_placement=True)) 58 | 59 | sess = tf.Session() 60 | 61 | global_step = 0 62 | max_grad_norm = 20 63 | tvars = tf.trainable_variables() 64 | grads, _ = tf.clip_by_global_norm(tf.gradients(model.Loss_Mean, tvars), max_grad_norm) 65 | lr = 1e-5 66 | #lr = tf.train.exponential_decay(0.01, global_step, 1, 0.5, staircase = False) 67 | opt = tf.train.AdamOptimizer(lr) 68 | train_op = opt.apply_gradients(zip(grads, tvars)) 69 | 70 | sess.run(tf.global_variables_initializer()) 71 | saver = tf.train.Saver() 72 | 73 | train_list, label_list = load_training_list() 74 | 75 | n_epochs = 5 76 | img_size = TSPORTNet.img_size 77 | label_size = TSPORTNet.label_size 78 | 79 | f = open("./Model/loss.txt", "w") 80 | f.truncate() 81 | f.close() 82 | 83 | 84 | for i in range(n_epochs): 85 | 86 | whole_loss = 0.0 87 | whole_acc = 0.0 88 | count = 0 89 | for f_img, f_label in zip(train_list, label_list): 90 | 91 | img = cv2.imread(f_img).astype(np.float32) 92 | img_flip0 = cv2.flip(img, 0) 93 | img_flip1 = cv2.flip(img, 1) 94 | img_flip2 = cv2.flip(img, -1) 95 | 96 | label = cv2.imread(f_label)[:, :, 0].astype(np.float32) 97 | label_flip0 = cv2.flip(label, 0) 98 | label_flip1 = cv2.flip(label, 1) 99 | label_flip2 = cv2.flip(label, -1) 100 | 101 | img = cv2.resize(img, (img_size, img_size)) - vgg16.VGG_MEAN 102 | label = cv2.resize(label, (label_size, label_size)) 103 | label = label.astype(np.float32) / 255. 
104 | 105 | img = img.reshape((1, img_size, img_size, 3)) 106 | label = np.stack((label, 1-label), axis=2) 107 | label = np.reshape(label, [-1, 2]) 108 | 109 | if count==0: 110 | _, loss, acc, S11, S12, S21, S22 = sess.run([train_op, model.Loss_Mean, model.accuracy, model.S11, model.S12, model.S21, model.S22], 111 | feed_dict={model.input_holder: img, 112 | model.label_holder: label, 113 | model.S11_holder: np.ones([3*3*int(np.floor(cfg.B/2)), int(np.floor(cfg.C/2))], dtype=np.float32), 114 | model.S12_holder: np.ones([3*3*int(np.floor(cfg.B/2)), int(np.floor(cfg.C/2))], dtype=np.float32), 115 | model.S21_holder: np.ones([3*3*int(np.floor(cfg.C/2)), int(np.floor(cfg.D/2))], dtype=np.float32), 116 | model.S22_holder: np.ones([3*3*int(np.floor(cfg.C/2)), int(np.floor(cfg.D/2))], dtype=np.float32) 117 | }) 118 | else: 119 | _, loss, acc, S11, S12, S21, S22 = sess.run([train_op, model.Loss_Mean, model.accuracy, model.S11, model.S12, model.S21, model.S22], 120 | feed_dict={model.input_holder: img, 121 | model.label_holder: label, 122 | model.S11_holder: S11, 123 | model.S12_holder: S12, 124 | model.S21_holder: S21, 125 | model.S22_holder: S22 126 | }) 127 | 128 | whole_loss += loss 129 | whole_acc += acc 130 | count = count + 1 131 | 132 | # add vertical flip image for training 133 | img_flip1 = cv2.resize(img_flip1, (img_size, img_size)) - vgg16.VGG_MEAN 134 | label_flip1 = cv2.resize(label_flip1, (label_size, label_size)) 135 | label_flip1 = label_flip1.astype(np.float32) / 255. 136 | 137 | 138 | img_flip1 = img_flip1.reshape((1, img_size, img_size, 3)) 139 | label_flip1 = np.stack((label_flip1, 1 - label_flip1), axis=2) 140 | label_flip1 = np.reshape(label_flip1, [-1, 2]) 141 | 142 | _, loss, acc, S11, S12, S21, S22 = sess.run([train_op, model.Loss_Mean, model.accuracy, model.S11, model.S12, model.S21, model.S22], 143 | feed_dict={model.input_holder: img_flip1, 144 | model.label_holder: label_flip1, 145 | model.S11_holder: S11, 146 | model.S12_holder: S12, 147 | model.S21_holder: S21, 148 | model.S22_holder: S22 149 | }) 150 | 151 | whole_loss += loss 152 | whole_acc += acc 153 | count = count + 1 154 | 155 | ## add horizon flip image for training 156 | #img_flip1 = cv2.resize(img_flip1, (img_size, img_size)) - vgg16.VGG_MEAN 157 | #label_flip1 = cv2.resize(label_flip1, (label_size, label_size)) 158 | #label_flip1 = label_flip1.astype(np.float32) / 255. 159 | 160 | 161 | #img_flip1 = img_flip1.reshape((1, img_size, img_size, 3)) 162 | #label_flip1 = np.stack((label_flip1, 1 - label_flip1), axis=2) 163 | #label_flip1 = np.reshape(label_flip1, [-1, 2]) 164 | 165 | #_, loss, acc = sess.run([train_op, model.Loss_Mean, model.accuracy], 166 | #feed_dict={model.input_holder: img_flip1, 167 | #model.label_holder: label_flip1}) 168 | 169 | #whole_loss += loss 170 | #whole_acc += acc 171 | #count = count + 1 172 | 173 | ## add horizon and vertical flip image for training 174 | #img_flip2 = cv2.resize(img_flip2, (img_size, img_size)) - vgg16.VGG_MEAN 175 | #label_flip2 = cv2.resize(label_flip2, (label_size, label_size)) 176 | #label_flip2 = label_flip2.astype(np.float32) / 255. 
177 | 178 | 179 | #img_flip2 = img_flip2.reshape((1, img_size, img_size, 3)) 180 | #label_flip2 = np.stack((label_flip2, 1 - label_flip2), axis=2) 181 | #label_flip2 = np.reshape(label_flip2, [-1, 2]) 182 | 183 | #_, loss, acc = sess.run([train_op, model.Loss_Mean, model.accuracy], 184 | #feed_dict={model.input_holder: img_flip2, 185 | #model.label_holder: label_flip2}) 186 | 187 | #whole_loss += loss 188 | #whole_acc += acc 189 | #count = count + 1 190 | 191 | if count % 1 == 0: 192 | print ("Loss of %d images: %f, Accuracy: %f" % (count, (whole_loss/count), (whole_acc/count))) 193 | 194 | #if whole_loss != whole_loss: 195 | # lr = lr*0.1 196 | # opt = tf.train.AdamOptimizer(lr) 197 | # train_op = opt.apply_gradients(zip(grads, tvars)) 198 | # ckpt = tf.train.get_checkpoint_state("./Model") 199 | # if ckpt and ckpt.model_checkpoint_path: 200 | # print("Continue training from the model {}".format(ckpt.model_checkpoint_path)) 201 | # saver.restore(sess, ckpt.model_checkpoint_path) 202 | #else: 203 | print ("Epoch %d: Loss: %f, Accuracy: %f " % (i+1, (whole_loss/count), (whole_acc/count))) 204 | saver.save(sess, 'Model/model.ckpt', global_step=i+1) 205 | 206 | f = open("./Model/loss.txt", "a+") 207 | #f.write("Lr: %f, Epoch %d: Loss: %f, Accuracy: %f \n" % (lr, i+1, (whole_loss/count), (whole_acc/count))) 208 | f.write("Epoch %d: Loss: %f, Accuracy: %f \n" % (i+1, (whole_loss/count), (whole_acc/count))) 209 | f.close() 210 | 211 | 212 | 213 | #os.mkdir('Model') 214 | #saver.save(sess, 'Model/model.ckpt', global_step=n_epochs) 215 | -------------------------------------------------------------------------------- /__pycache__: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | flags = tf.app.flags 4 | 5 | ############################ 6 | # hyper parameters # 7 | ############################ 8 | flags.DEFINE_float('ac_lambda0', 0.01, '\lambda in the activation function a_c, iteration 0') 9 | flags.DEFINE_float('ac_lambda_step', 0.01, 10 | 'It is described that \lambda increases at each iteration with a fixed schedule, however specific super parameters is absent.') 11 | 12 | flags.DEFINE_integer('batch_size', 1, 'batch size') 13 | flags.DEFINE_integer('epoch', 20, 'epoch') 14 | flags.DEFINE_integer('iter_routing', 2, 'number of iterations') 15 | flags.DEFINE_float('m_schedule', 0.2, 'the m will get to 0.9 at current epoch') 16 | flags.DEFINE_float('epsilon', 1e-9, 'epsilon') 17 | flags.DEFINE_float('m_plus', 0.9, 'the parameter of m plus') 18 | flags.DEFINE_float('m_minus', 0.1, 'the parameter of m minus') 19 | flags.DEFINE_float('lambda_val', 0.5, 'down weight of the loss for absent digit classes') 20 | flags.DEFINE_boolean('weight_reg', False, 'train with regularization of weights') 21 | flags.DEFINE_string('norm', 'norm2', 'norm type') 22 | ################################ 23 | # structure parameters # 24 | ################################ 25 | flags.DEFINE_integer('A', 32, 'number of channels in output from ReLU Conv1') 26 | flags.DEFINE_integer('B', 16, 'number of capsules in output from PrimaryCaps') 27 | flags.DEFINE_integer('C', 16, 'number of channels in output from ConvCaps1') 28 | flags.DEFINE_integer('D', 8, 'number of channels in output from ConvCaps2') 29 | 30 | ############################ 31 | # environment setting # 32 | 
############################
33 | flags.DEFINE_string('dataset', 'data/MSRA-B', 'the path for dataset')
34 | flags.DEFINE_string('dataset_fashion_MSRA-B', 'data/fashion_MSRA-B', 'the path for dataset')
35 | flags.DEFINE_boolean('is_train', True, 'train or predict phase')
36 | flags.DEFINE_integer('num_threads', 8, 'number of threads for enqueueing examples')
37 | flags.DEFINE_string('logdir', 'logdir', 'logs directory')
38 | flags.DEFINE_string('test_logdir', 'test_logdir', 'test logs directory')
39 | 
40 | cfg = tf.app.flags.FLAGS
41 | 
42 | 
43 | def get_coord_add(dataset_name: str):
44 |     import numpy as np
45 |     # TODO: get coord add for cifar10/100 datasets (32x32x3)
46 |     options = {'MSRA-B': ([[[8., 8.], [12., 8.], [16., 8.]],
47 |                            [[8., 12.], [12., 12.], [16., 12.]],
48 |                            [[8., 16.], [12., 16.], [16., 16.]]], 28.),
49 |                'smallNORB': ([[[8., 8.], [12., 8.], [16., 8.], [24., 8.]],
50 |                               [[8., 12.], [12., 12.], [16., 12.], [24., 12.]],
51 |                               [[8., 16.], [12., 16.], [16., 16.], [24., 16.]],
52 |                               [[8., 24.], [12., 24.], [16., 24.], [24., 24.]]], 32.)
53 |                }
54 |     coord_add, scale = options[dataset_name]
55 | 
56 |     coord_add = np.array(coord_add, dtype=np.float32) / scale
57 | 
58 |     return coord_add
59 | 
60 | 
61 | def get_dataset_size_train(dataset_name: str):
62 |     options = {'MSRA-B': 3000, 'smallNORB': 23400 * 2,
63 |                'fashion_MSRA-B': 55000, 'cifar10': 50000, 'cifar100': 50000}
64 |     return options[dataset_name]
65 | 
66 | 
67 | def get_dataset_size_test(dataset_name: str):
68 |     options = {'MSRA-B': 10000, 'smallNORB': 23400 * 2,
69 |                'fashion_MSRA-B': 10000, 'cifar10': 10000, 'cifar100': 10000}
70 |     return options[dataset_name]
71 | 
72 | 
73 | def get_num_classes(dataset_name: str):
74 |     options = {'MSRA-B': 2, 'smallNORB': 5, 'fashion_MSRA-B': 10, 'cifar10': 10, 'cifar100': 100}
75 |     return options[dataset_name]
76 | 
77 | 
78 | from utils import create_inputs_MSRA_B, create_inputs_norb, create_inputs_cifar10, create_inputs_cifar100
79 | 
80 | 
81 | #def get_create_inputs(dataset_name: str, f_img, f_label, is_train: bool, epochs: int):
82 | #    options = {'MSRA-B': lambda: create_inputs_MSRA_B(f_img, f_label, is_train),
83 | #               'fashion_MSRA-B': lambda: create_inputs_MSRA_B(is_train),
84 | #               'smallNORB': lambda: create_inputs_norb(is_train, epochs),
85 | #               'cifar10': lambda: create_inputs_cifar10(is_train),
86 | #               'cifar100': lambda: create_inputs_cifar100(is_train)}
87 | #    return options[dataset_name]
88 | 
89 | def get_create_inputs(dataset_name: str, f_img, f_label, is_train: bool, epochs: int):
90 |     return create_inputs_MSRA_B(f_img, f_label, is_train)
91 | 
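# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the original config.py): how the
# flags and helper functions above are consumed by the training and evaluation
# scripts in this repository. The dataset name 'MSRA-B' matches the options
# defined above; the printed shapes are derived from the structure flags B, C, D.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np

    coord_add = get_coord_add('MSRA-B')
    print(coord_add.shape)             # (3, 3, 2): a 3x3 grid of (x, y) coordinates scaled by 1/28
    print(get_num_classes('MSRA-B'))   # 2: salient vs. non-salient

    # Shapes of the S11/S12 and S21/S22 placeholders fed in TrainingModel-2.py,
    # derived from the structure flags B, C and D defined above.
    s1_shape = [3 * 3 * int(np.floor(cfg.B / 2)), int(np.floor(cfg.C / 2))]   # [72, 8]
    s2_shape = [3 * 3 * int(np.floor(cfg.C / 2)), int(np.floor(cfg.D / 2))]   # [72, 4]
    print(s1_shape, s2_shape)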
--------------------------------------------------------------------------------
/data/data/t10k-images-idx3-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuyi1989/TSPORTNet/58e5c54b9b613a225cb4b3892bd8316b0b328897/data/data/t10k-images-idx3-ubyte.gz
--------------------------------------------------------------------------------
/data/data/t10k-images.idx3-ubyte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuyi1989/TSPORTNet/58e5c54b9b613a225cb4b3892bd8316b0b328897/data/data/t10k-images.idx3-ubyte
--------------------------------------------------------------------------------
/data/data/t10k-labels-idx1-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuyi1989/TSPORTNet/58e5c54b9b613a225cb4b3892bd8316b0b328897/data/data/t10k-labels-idx1-ubyte.gz
--------------------------------------------------------------------------------
/data/data/t10k-labels.idx1-ubyte:
--------------------------------------------------------------------------------
[binary idx1 label file; raw bytes omitted]
--------------------------------------------------------------------------------
/data/data/train-images-idx3-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuyi1989/TSPORTNet/58e5c54b9b613a225cb4b3892bd8316b0b328897/data/data/train-images-idx3-ubyte.gz
--------------------------------------------------------------------------------
/data/mnist/t10k-images.idx3-ubyte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuyi1989/TSPORTNet/58e5c54b9b613a225cb4b3892bd8316b0b328897/data/mnist/t10k-images.idx3-ubyte
--------------------------------------------------------------------------------
/data/mnist/t10k-labels.idx1-ubyte:
--------------------------------------------------------------------------------
[binary idx1 label file; raw bytes omitted]
--------------------------------------------------------------------------------
/data/smallNORB.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import argh
3 | import sys
4 | import os
5 | import tensorflow as tf
6 | import numpy as np
7 | 
8 | import logging
9 | import daiquiri
10 | 
11 | daiquiri.setup(level=logging.DEBUG)
12 | logger = daiquiri.getLogger(__name__)
13 | 
14 | from numpy.random import RandomState
15 | 
16 | prng = RandomState(1234567890)
17 | 
18 | from
matplotlib import pyplot as plt 19 | import cv2 20 | 21 | 22 | def plot_imgs(inputs, num, label): 23 | """Plot smallNORB images helper""" 24 | # fig = plt.figure() 25 | # plt.title('Show images') 26 | # r = np.floor(np.sqrt(len(inputs))).astype(int) 27 | # for i in range(r**2): 28 | # size = inputs[i].shape[1] 29 | # sample = inputs[i].flatten().reshape(size, size) 30 | # a = fig.add_subplot(r, r, i + 1) 31 | # a.imshow(sample, cmap='gray') 32 | # plt.show() 33 | inputs = (inputs).astype(np.uint8) 34 | for i in range(len(inputs)): 35 | size = inputs[i].shape[1] 36 | cv2.imwrite('%d' % num+'_%d' % i+label+'.jpg', inputs[i].flatten().reshape(size, size)) 37 | return 38 | 39 | 40 | def write_data_to_tfrecord(kind: str, chunkify=False): 41 | """Credit: https://github.com/shashanktyagi/DC-GAN-on-smallNORB-dataset/blob/master/src/model.py 42 | Original Version: shashanktyagi 43 | """ 44 | 45 | """Plan A: write dataset into one big tfrecord""" 46 | 47 | """Plan B: write dataset into manageable chuncks""" 48 | CHUNK = 24300 * 2 / 10 # create 10 chunks 49 | 50 | from time import time 51 | start = time() 52 | """Read data""" 53 | if kind == "train": 54 | fid_images = open('./smallNORB/smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat', 'rb') 55 | fid_labels = open('./smallNORB/smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat', 'rb') 56 | elif kind == "test": 57 | fid_images = open('./smallNORB/smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat', 'rb') 58 | fid_labels = open('./smallNORB/smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat', 'rb') 59 | else: 60 | logger.warning('Please choose either training or testing data to preprocess.') 61 | 62 | logger.debug('Read data ' + kind + ' finish.') 63 | 64 | """Preprocessing""" 65 | for i in range(6): 66 | a = fid_images.read(4) # header 67 | 68 | total_num_images = 24300 * 2 69 | 70 | for j in range(total_num_images // CHUNK if chunkify else 1): 71 | 72 | num_images = CHUNK if chunkify else total_num_images # 24300 * 2 73 | images = np.zeros((num_images, 96 * 96)) 74 | for idx in range(num_images): 75 | temp = fid_images.read(96 * 96) 76 | images[idx, :] = np.fromstring(temp, 'uint8') 77 | for i in range(5): 78 | a = fid_labels.read(4) # header 79 | labels = np.fromstring(fid_labels.read(num_images * np.dtype('int32').itemsize), 'int32') 80 | labels = np.repeat(labels, 2) 81 | 82 | logger.debug('Load data %d finish. Start filling chunk %d.' % (j, j)) 83 | 84 | # make dataset permuatation reproduceable 85 | perm = prng.permutation(num_images) 86 | images = images[perm] 87 | labels = labels[perm] 88 | 89 | """display image""" 90 | ''' 91 | if j == 0: 92 | plot_imgs(images[:10]) 93 | ''' 94 | 95 | """Write to tfrecord""" 96 | writer = tf.python_io.TFRecordWriter("./" + kind + "%d.tfrecords" % j) 97 | for i in range(num_images): 98 | if i % 100 == 0: 99 | logger.debug('Write ' + kind + ' images %d' % ((j + 1) * i)) 100 | img = images[i, :].tobytes() 101 | lab = labels[i].astype(np.int64) 102 | example = tf.train.Example(features=tf.train.Features(feature={ 103 | "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[lab])), 104 | 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img])) 105 | })) 106 | writer.write(example.SerializeToString()) # 序列化为字符串 107 | writer.close() 108 | 109 | # Should take less than a minute 110 | logger.info('Done writing ' + kind + '. 
Total time: %f' % (time() - start)) 111 | 112 | 113 | def tfrecord(): 114 | """Wrapper""" 115 | write_data_to_tfrecord(kind='train', chunkify=False) 116 | write_data_to_tfrecord(kind='test', chunkify=False) 117 | logger.info('Writing train & test to TFRecord done.') 118 | 119 | 120 | def read_norb_tfrecord(filenames, epochs: int): 121 | """Credit: http: // ycszen.github.io / 2016 / 08 / 17 / TensorFlow高效读取数据/ 122 | Original Version: Ycszen-物语 123 | """ 124 | 125 | assert isinstance(filenames, list) 126 | 127 | # 根据文件名生成一个队列 128 | filename_queue = tf.train.string_input_producer(filenames, num_epochs=epochs) 129 | reader = tf.TFRecordReader() 130 | _, serialized_example = reader.read(filename_queue) # 返回文件名和文件 131 | features = tf.parse_single_example(serialized_example, 132 | features={ 133 | 'label': tf.FixedLenFeature([], tf.int64), 134 | 'img_raw': tf.FixedLenFeature([], tf.string), 135 | }) 136 | img = tf.decode_raw(features['img_raw'], tf.float64) 137 | #logger.debug('Raw->img shape: {}'.format(img.get_shape())) 138 | img = tf.reshape(img, [96, 96, 1]) 139 | img = tf.cast(img, tf.float32) # * (1. / 255) # left unnormalized 140 | label = tf.cast(features['label'], tf.int32) 141 | # label = tf.one_hot(label, 5, dtype=tf.int32) # left dense label 142 | #logger.debug('Raw->img shape: {}, label shape: {}'.format(img.get_shape(), label.get_shape())) 143 | return img, label 144 | 145 | 146 | def test(is_train=True): 147 | """Instruction on how to read data from tfrecord""" 148 | 149 | # 1. use regular expression to find all files we want 150 | import re 151 | if is_train: 152 | CHUNK_RE = re.compile(r"train\d+\.tfrecords") 153 | else: 154 | CHUNK_RE = re.compile(r"test\d+\.tfrecords") 155 | 156 | processed_dir = './data' 157 | # 2. parse them into a list of file name 158 | chunk_files = [os.path.join(processed_dir, fname) 159 | for fname in os.listdir(processed_dir) 160 | if CHUNK_RE.match(fname)] 161 | # 3. pass argument into read method 162 | image, label = read_norb_tfrecord(chunk_files, 2) 163 | 164 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 
165 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 166 | 167 | image = tf.image.resize_images(image, [48, 48]) 168 | 169 | """Batch Norm""" 170 | params_shape = [image.get_shape()[-1]] 171 | beta = tf.get_variable( 172 | 'beta', params_shape, tf.float32, 173 | initializer=tf.constant_initializer(0.0, tf.float32)) 174 | gamma = tf.get_variable( 175 | 'gamma', params_shape, tf.float32, 176 | initializer=tf.constant_initializer(1.0, tf.float32)) 177 | mean, variance = tf.nn.moments(image, [0, 1, 2]) 178 | image = tf.nn.batch_normalization(image, mean, variance, beta, gamma, 0.001) 179 | 180 | image = tf.random_crop(image, [32, 32, 1]) 181 | 182 | batch_size = 8 183 | x, y = tf.train.shuffle_batch([image, label], batch_size=batch_size, capacity=batch_size * 64, 184 | min_after_dequeue=batch_size * 32, allow_smaller_final_batch=False) 185 | logger.debug('x shape: {}, y shape: {}'.format(x.get_shape(), y.get_shape())) 186 | 187 | # 初始化所有的op 188 | init = tf.global_variables_initializer() 189 | 190 | with tf.Session() as sess: 191 | sess.run(tf.local_variables_initializer()) 192 | sess.run(init) 193 | # 启动队列 194 | coord = tf.train.Coordinator() 195 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 196 | 197 | for i in range(2): 198 | val, l = sess.run([x, y]) 199 | # l = to_categorical(l, 12) 200 | print(val, l) 201 | coord.join() 202 | 203 | logger.debug('Test read tf record Succeed') 204 | 205 | 206 | parser = argparse.ArgumentParser() 207 | argh.add_commands(parser, [tfrecord, test]) 208 | 209 | if __name__ == "__main__": 210 | argh.dispatch(parser) 211 | -------------------------------------------------------------------------------- /eval_accuracy.py: -------------------------------------------------------------------------------- 1 | """ 2 | License: Apache-2.0 3 | Author: Suofei Zhang | Hang Yu 4 | E-mail: zhangsuofei at njupt.edu.cn | hangyu5 at illinois.edu 5 | """ 6 | 7 | import tensorflow as tf 8 | from config import cfg, get_coord_add, get_dataset_size_train, get_dataset_size_test, get_num_classes, get_create_inputs 9 | import time 10 | import os 11 | import capsnet_em as net 12 | import tensorflow.contrib.slim as slim 13 | 14 | import logging 15 | import daiquiri 16 | 17 | daiquiri.setup(level=logging.DEBUG) 18 | logger = daiquiri.getLogger(__name__) 19 | 20 | 21 | def main(args): 22 | """Get dataset hyperparameters.""" 23 | assert len(args) == 3 and isinstance(args[1], str) and isinstance(args[2], str) 24 | dataset_name = args[1] 25 | model_name = args[2] 26 | coord_add = get_coord_add(dataset_name) 27 | dataset_size_train = get_dataset_size_train(dataset_name) 28 | dataset_size_test = get_dataset_size_test(dataset_name) 29 | num_classes = get_num_classes(dataset_name) 30 | create_inputs = get_create_inputs( 31 | dataset_name, is_train=False, epochs=cfg.epoch) 32 | 33 | """Set reproduciable random seed""" 34 | tf.set_random_seed(1234) 35 | 36 | with tf.Graph().as_default(): 37 | num_batches_per_epoch_train = int(dataset_size_train / cfg.batch_size) 38 | num_batches_test = int(dataset_size_test / cfg.batch_size * 0.1) 39 | 40 | batch_x, batch_labels = create_inputs() 41 | batch_x = slim.batch_norm(batch_x, center=False, is_training=False, trainable=False) 42 | if model_name == "caps": 43 | output, _ = net.build_arch(batch_x, coord_add, 44 | is_train=False, num_classes=num_classes) 45 | elif model_name == "cnn_baseline": 46 | output = net.build_arch_baseline(batch_x, 47 | is_train=False, num_classes=num_classes) 48 | else: 49 | raise "Please select model 
from 'caps' or 'cnn_baseline' as the secondary argument of eval.py!" 50 | batch_acc = net.test_accuracy(output, batch_labels) 51 | saver = tf.train.Saver() 52 | 53 | step = 0 54 | 55 | summaries = [] 56 | summaries.append(tf.summary.scalar('accuracy', batch_acc)) 57 | summary_op = tf.summary.merge(summaries) 58 | 59 | with tf.Session(config=tf.ConfigProto( 60 | allow_soft_placement=True, log_device_placement=False)) as sess: 61 | sess.run(tf.local_variables_initializer()) 62 | sess.run(tf.global_variables_initializer()) 63 | 64 | coord = tf.train.Coordinator() 65 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 66 | if not os.path.exists(cfg.test_logdir + '/{}/{}/'.format(model_name, dataset_name)): 67 | os.makedirs(cfg.test_logdir + '/{}/{}/'.format(model_name, dataset_name)) 68 | summary_writer = tf.summary.FileWriter( 69 | cfg.test_logdir + '/{}/{}/'.format(model_name, dataset_name), graph=sess.graph) # graph=sess.graph, huge! 70 | 71 | files = os.listdir(cfg.logdir + '/{}/{}/'.format(model_name, dataset_name)) 72 | for epoch in range(1, cfg.epoch): 73 | # requires a regex to adapt the loss value in the file name here 74 | ckpt_re = ".ckpt-%d" % (num_batches_per_epoch_train * epoch) 75 | for __file in files: 76 | if __file.endswith(ckpt_re + ".index"): 77 | ckpt = os.path.join(cfg.logdir + '/{}/{}/'.format(model_name, dataset_name), __file[:-6]) 78 | # ckpt = os.path.join(cfg.logdir, "model.ckpt-%d" % (num_batches_per_epoch_train * epoch)) 79 | saver.restore(sess, ckpt) 80 | 81 | accuracy_sum = 0 82 | for i in range(num_batches_test): 83 | batch_acc_v, summary_str = sess.run([batch_acc, summary_op]) 84 | print('%d batches are tested.' % step) 85 | summary_writer.add_summary(summary_str, step) 86 | 87 | accuracy_sum += batch_acc_v 88 | 89 | step += 1 90 | 91 | ave_acc = accuracy_sum / num_batches_test 92 | print('the average accuracy is %f' % ave_acc) 93 | 94 | coord.join(threads) 95 | 96 | 97 | if __name__ == "__main__": 98 | tf.app.run() 99 | -------------------------------------------------------------------------------- /eval_recon_histogram.py: -------------------------------------------------------------------------------- 1 | """ 2 | License: Apache-2.0 3 | Author: Suofei Zhang | Hang Yu 4 | E-mail: zhangsuofei at njupt.edu.cn | hangyu5 at illinois.edu 5 | """ 6 | 7 | import tensorflow as tf 8 | from config import cfg, get_coord_add, get_dataset_size_train, get_dataset_size_test, get_num_classes, get_create_inputs 9 | import time 10 | import os 11 | import capsnet_em as net 12 | import tensorflow.contrib.slim as slim 13 | from data.smallNORB import plot_imgs 14 | 15 | import logging 16 | import daiquiri 17 | 18 | daiquiri.setup(level=logging.DEBUG) 19 | logger = daiquiri.getLogger(__name__) 20 | 21 | 22 | def main(args): 23 | """Get dataset hyperparameters.""" 24 | assert len(args) == 3 and isinstance(args[1], str) and isinstance(args[2], str) 25 | dataset_name = args[1] 26 | model_name = args[2] 27 | 28 | """Set reproduciable random seed""" 29 | tf.set_random_seed(1234) 30 | 31 | coord_add = get_coord_add(dataset_name) 32 | dataset_size_train = get_dataset_size_train(dataset_name) 33 | dataset_size_test = get_dataset_size_test(dataset_name) 34 | num_classes = get_num_classes(dataset_name) 35 | create_inputs = get_create_inputs( 36 | dataset_name, is_train=False, epochs=cfg.epoch) 37 | 38 | with tf.Graph().as_default(): 39 | num_batches_per_epoch_train = int(dataset_size_train / cfg.batch_size) 40 | num_batches_test = 2 # int(dataset_size_test / cfg.batch_size * 
0.1) 41 | 42 | batch_x, batch_labels = create_inputs() 43 | batch_squash = tf.divide(batch_x, 255.) 44 | batch_x_norm = slim.batch_norm(batch_x, center=False, is_training=False, trainable=False) 45 | output, pose_out = net.build_arch(batch_x_norm, coord_add, 46 | is_train=False, num_classes=num_classes) 47 | tf.logging.debug(pose_out.get_shape()) 48 | 49 | batch_acc = net.test_accuracy(output, batch_labels) 50 | m_op = tf.constant(0.9) 51 | loss, spread_loss, mse, recon_img_squash = net.spread_loss( 52 | output, pose_out, batch_squash, batch_labels, m_op) 53 | tf.summary.scalar('spread_loss', spread_loss) 54 | tf.summary.scalar('reconstruction_loss', mse) 55 | tf.summary.scalar('all_loss', loss) 56 | data_size = int(batch_x.get_shape()[1]) 57 | recon_img = tf.multiply(tf.reshape(recon_img_squash, shape=[ 58 | cfg.batch_size, data_size, data_size, 1]), 255.) 59 | orig_img = tf.reshape(batch_x, shape=[ 60 | cfg.batch_size, data_size, data_size, 1]) 61 | tf.summary.image('orig_image', orig_img) 62 | tf.summary.image('recon_image', recon_img) 63 | saver = tf.train.Saver() 64 | 65 | step = 0 66 | 67 | tf.summary.scalar('accuracy', batch_acc) 68 | summary_op = tf.summary.merge_all() 69 | 70 | with tf.Session(config=tf.ConfigProto( 71 | allow_soft_placement=True, log_device_placement=False)) as sess: 72 | sess.run(tf.local_variables_initializer()) 73 | sess.run(tf.global_variables_initializer()) 74 | 75 | coord = tf.train.Coordinator() 76 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 77 | if not os.path.exists(cfg.test_logdir + '/{}/{}/'.format(model_name, dataset_name)): 78 | os.makedirs(cfg.test_logdir + '/{}/{}/'.format(model_name, dataset_name)) 79 | summary_writer = tf.summary.FileWriter( 80 | cfg.test_logdir + '/{}/{}/'.format(model_name, dataset_name), graph=sess.graph) # graph=sess.graph, huge! 81 | 82 | files = os.listdir(cfg.logdir + '/{}/{}/'.format(model_name, dataset_name)) 83 | for epoch in range(45, 46): 84 | # requires a regex to adapt the loss value in the file name here 85 | ckpt_re = ".ckpt-%d" % (num_batches_per_epoch_train * epoch) 86 | for __file in files: 87 | if __file.endswith(ckpt_re + ".index"): 88 | ckpt = os.path.join( 89 | cfg.logdir + '/{}/{}/'.format(model_name, dataset_name), __file[:-6]) 90 | # ckpt = os.path.join(cfg.logdir, "model.ckpt-%d" % (num_batches_per_epoch_train * epoch)) 91 | saver.restore(sess, ckpt) 92 | 93 | accuracy_sum = 0 94 | for i in range(num_batches_test): 95 | batch_acc_v, summary_str, orig_image, recon_image = sess.run( 96 | [batch_acc, summary_op, orig_img, recon_img]) 97 | print('%d batches are tested.' 
% step) 98 | summary_writer.add_summary(summary_str, step) 99 | 100 | accuracy_sum += batch_acc_v 101 | 102 | step += 1 103 | # display original/reconstructed images in matplotlib 104 | plot_imgs(orig_image, i, 'ori') 105 | plot_imgs(recon_image, i, 'rec') 106 | 107 | ave_acc = accuracy_sum / num_batches_test 108 | print('the average accuracy is %f' % ave_acc) 109 | 110 | 111 | if __name__ == "__main__": 112 | tf.app.run() 113 | -------------------------------------------------------------------------------- /imgs/read me.txt: -------------------------------------------------------------------------------- 1 | https://github.com/ivankreso/resnet-tensorflow -------------------------------------------------------------------------------- /imgs/spread_loss_norb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyi1989/TSPORTNet/58e5c54b9b613a225cb4b3892bd8316b0b328897/imgs/spread_loss_norb.png -------------------------------------------------------------------------------- /imgs/test_accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyi1989/TSPORTNet/58e5c54b9b613a225cb4b3892bd8316b0b328897/imgs/test_accuracy.png -------------------------------------------------------------------------------- /imgs/test_accuracy_norb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyi1989/TSPORTNet/58e5c54b9b613a225cb4b3892bd8316b0b328897/imgs/test_accuracy_norb.png -------------------------------------------------------------------------------- /imgs/training_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyi1989/TSPORTNet/58e5c54b9b613a225cb4b3892bd8316b0b328897/imgs/training_loss.png -------------------------------------------------------------------------------- /index.md: -------------------------------------------------------------------------------- 1 | ## Welcome to GitHub Pages 2 | 3 | You can use the [editor on GitHub](https://github.com/liuyi1989/liuyi/edit/master/index.md) to maintain and preview the content for your website in Markdown files. 4 | 5 | Whenever you commit to this repository, GitHub Pages will run [Jekyll](https://jekyllrb.com/) to rebuild the pages in your site, from the content in your Markdown files. 6 | 7 | ### Markdown 8 | 9 | Markdown is a lightweight and easy-to-use syntax for styling your writing. It includes conventions for 10 | 11 | ```markdown 12 | Syntax highlighted code block 13 | 14 | # Personal Information 15 | ## Header 2 16 | ### Header 3 17 | 18 | - Bulleted 19 | - List 20 | 21 | 1. Numbered 22 | 2. List 23 | 24 | **Bold** and _Italic_ and `Code` text 25 | 26 | [Link](url) and ![Image](src) 27 | ``` 28 | 29 | For more details see [GitHub Flavored Markdown](https://guides.github.com/features/mastering-markdown/). 30 | 31 | ### Jekyll Themes 32 | 33 | Your Pages site will use the layout and styles from the Jekyll theme you have selected in your [repository settings](https://github.com/liuyi1989/liuyi/settings). The name of this theme is saved in the Jekyll `_config.yml` configuration file. 34 | 35 | ### Support or Contact 36 | 37 | Having trouble with Pages? Check out our [documentation](https://help.github.com/categories/github-pages-basics/) or [contact support](https://github.com/contact) and we’ll help you sort it out. 
38 | -------------------------------------------------------------------------------- /nets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /nets/alexnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a model definition for AlexNet. 16 | 17 | This work was first described in: 18 | ImageNet Classification with Deep Convolutional Neural Networks 19 | Alex Krizhevsky, Ilya Sutskever and Geoffrey E. Hinton 20 | 21 | and later refined in: 22 | One weird trick for parallelizing convolutional neural networks 23 | Alex Krizhevsky, 2014 24 | 25 | Here we provide the implementation proposed in "One weird trick" and not 26 | "ImageNet Classification", as per the paper, the LRN layers have been removed. 27 | 28 | Usage: 29 | with slim.arg_scope(alexnet.alexnet_v2_arg_scope()): 30 | outputs, end_points = alexnet.alexnet_v2(inputs) 31 | 32 | @@alexnet_v2 33 | """ 34 | 35 | from __future__ import absolute_import 36 | from __future__ import division 37 | from __future__ import print_function 38 | 39 | import tensorflow as tf 40 | 41 | slim = tf.contrib.slim 42 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 43 | 44 | 45 | def alexnet_v2_arg_scope(weight_decay=0.0005): 46 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 47 | activation_fn=tf.nn.relu, 48 | biases_initializer=tf.constant_initializer(0.1), 49 | weights_regularizer=slim.l2_regularizer(weight_decay)): 50 | with slim.arg_scope([slim.conv2d], padding='SAME'): 51 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: 52 | return arg_sc 53 | 54 | 55 | def alexnet_v2(inputs, 56 | num_classes=1000, 57 | is_training=True, 58 | dropout_keep_prob=0.5, 59 | spatial_squeeze=True, 60 | scope='alexnet_v2', 61 | global_pool=False): 62 | """AlexNet version 2. 63 | 64 | Described in: http://arxiv.org/pdf/1404.5997v2.pdf 65 | Parameters from: 66 | github.com/akrizhevsky/cuda-convnet2/blob/master/layers/ 67 | layers-imagenet-1gpu.cfg 68 | 69 | Note: All the fully_connected layers have been transformed to conv2d layers. 70 | To use in classification mode, resize input to 224x224 or set 71 | global_pool=True. To use in fully convolutional mode, set 72 | spatial_squeeze to false. 73 | The LRN layers have been removed and change the initializers from 74 | random_normal_initializer to xavier_initializer. 75 | 76 | Args: 77 | inputs: a tensor of size [batch_size, height, width, channels]. 78 | num_classes: the number of predicted classes. If 0 or None, the logits layer 79 | is omitted and the input features to the logits layer are returned instead. 
80 | is_training: whether or not the model is being trained. 81 | dropout_keep_prob: the probability that activations are kept in the dropout 82 | layers during training. 83 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 84 | logits. Useful to remove unnecessary dimensions for classification. 85 | scope: Optional scope for the variables. 86 | global_pool: Optional boolean flag. If True, the input to the classification 87 | layer is avgpooled to size 1x1, for any input size. (This is not part 88 | of the original AlexNet.) 89 | 90 | Returns: 91 | net: the output of the logits layer (if num_classes is a non-zero integer), 92 | or the non-dropped-out input to the logits layer (if num_classes is 0 93 | or None). 94 | end_points: a dict of tensors with intermediate activations. 95 | """ 96 | with tf.variable_scope(scope, 'alexnet_v2', [inputs]) as sc: 97 | end_points_collection = sc.original_name_scope + '_end_points' 98 | # Collect outputs for conv2d, fully_connected and max_pool2d. 99 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 100 | outputs_collections=[end_points_collection]): 101 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', 102 | scope='conv1') 103 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool1') 104 | net = slim.conv2d(net, 192, [5, 5], scope='conv2') 105 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool2') 106 | net = slim.conv2d(net, 384, [3, 3], scope='conv3') 107 | net = slim.conv2d(net, 384, [3, 3], scope='conv4') 108 | net = slim.conv2d(net, 256, [3, 3], scope='conv5') 109 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool5') 110 | 111 | # Use conv2d instead of fully_connected layers. 112 | with slim.arg_scope([slim.conv2d], 113 | weights_initializer=trunc_normal(0.005), 114 | biases_initializer=tf.constant_initializer(0.1)): 115 | net = slim.conv2d(net, 4096, [5, 5], padding='VALID', 116 | scope='fc6') 117 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 118 | scope='dropout6') 119 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 120 | # Convert end_points_collection into a end_point dict. 121 | end_points = slim.utils.convert_collection_to_dict( 122 | end_points_collection) 123 | if global_pool: 124 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool') 125 | end_points['global_pool'] = net 126 | if num_classes: 127 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 128 | scope='dropout7') 129 | net = slim.conv2d(net, num_classes, [1, 1], 130 | activation_fn=None, 131 | normalizer_fn=None, 132 | biases_initializer=tf.zeros_initializer(), 133 | scope='fc8') 134 | if spatial_squeeze: 135 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 136 | end_points[sc.name + '/fc8'] = net 137 | return net, end_points 138 | alexnet_v2.default_image_size = 224 139 | -------------------------------------------------------------------------------- /nets/alexnet_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for slim.nets.alexnet.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | from nets import alexnet 23 | 24 | slim = tf.contrib.slim 25 | 26 | 27 | class AlexnetV2Test(tf.test.TestCase): 28 | 29 | def testBuild(self): 30 | batch_size = 5 31 | height, width = 224, 224 32 | num_classes = 1000 33 | with self.test_session(): 34 | inputs = tf.random_uniform((batch_size, height, width, 3)) 35 | logits, _ = alexnet.alexnet_v2(inputs, num_classes) 36 | self.assertEquals(logits.op.name, 'alexnet_v2/fc8/squeezed') 37 | self.assertListEqual(logits.get_shape().as_list(), 38 | [batch_size, num_classes]) 39 | 40 | def testFullyConvolutional(self): 41 | batch_size = 1 42 | height, width = 300, 400 43 | num_classes = 1000 44 | with self.test_session(): 45 | inputs = tf.random_uniform((batch_size, height, width, 3)) 46 | logits, _ = alexnet.alexnet_v2(inputs, num_classes, spatial_squeeze=False) 47 | self.assertEquals(logits.op.name, 'alexnet_v2/fc8/BiasAdd') 48 | self.assertListEqual(logits.get_shape().as_list(), 49 | [batch_size, 4, 7, num_classes]) 50 | 51 | def testGlobalPool(self): 52 | batch_size = 1 53 | height, width = 256, 256 54 | num_classes = 1000 55 | with self.test_session(): 56 | inputs = tf.random_uniform((batch_size, height, width, 3)) 57 | logits, _ = alexnet.alexnet_v2(inputs, num_classes, spatial_squeeze=False, 58 | global_pool=True) 59 | self.assertEquals(logits.op.name, 'alexnet_v2/fc8/BiasAdd') 60 | self.assertListEqual(logits.get_shape().as_list(), 61 | [batch_size, 1, 1, num_classes]) 62 | 63 | def testEndPoints(self): 64 | batch_size = 5 65 | height, width = 224, 224 66 | num_classes = 1000 67 | with self.test_session(): 68 | inputs = tf.random_uniform((batch_size, height, width, 3)) 69 | _, end_points = alexnet.alexnet_v2(inputs, num_classes) 70 | expected_names = ['alexnet_v2/conv1', 71 | 'alexnet_v2/pool1', 72 | 'alexnet_v2/conv2', 73 | 'alexnet_v2/pool2', 74 | 'alexnet_v2/conv3', 75 | 'alexnet_v2/conv4', 76 | 'alexnet_v2/conv5', 77 | 'alexnet_v2/pool5', 78 | 'alexnet_v2/fc6', 79 | 'alexnet_v2/fc7', 80 | 'alexnet_v2/fc8' 81 | ] 82 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 83 | 84 | def testNoClasses(self): 85 | batch_size = 5 86 | height, width = 224, 224 87 | num_classes = None 88 | with self.test_session(): 89 | inputs = tf.random_uniform((batch_size, height, width, 3)) 90 | net, end_points = alexnet.alexnet_v2(inputs, num_classes) 91 | expected_names = ['alexnet_v2/conv1', 92 | 'alexnet_v2/pool1', 93 | 'alexnet_v2/conv2', 94 | 'alexnet_v2/pool2', 95 | 'alexnet_v2/conv3', 96 | 'alexnet_v2/conv4', 97 | 'alexnet_v2/conv5', 98 | 'alexnet_v2/pool5', 99 | 'alexnet_v2/fc6', 100 | 'alexnet_v2/fc7' 101 | ] 102 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 103 | self.assertTrue(net.op.name.startswith('alexnet_v2/fc7')) 104 | self.assertListEqual(net.get_shape().as_list(), 105 | 
[batch_size, 1, 1, 4096]) 106 | 107 | def testModelVariables(self): 108 | batch_size = 5 109 | height, width = 224, 224 110 | num_classes = 1000 111 | with self.test_session(): 112 | inputs = tf.random_uniform((batch_size, height, width, 3)) 113 | alexnet.alexnet_v2(inputs, num_classes) 114 | expected_names = ['alexnet_v2/conv1/weights', 115 | 'alexnet_v2/conv1/biases', 116 | 'alexnet_v2/conv2/weights', 117 | 'alexnet_v2/conv2/biases', 118 | 'alexnet_v2/conv3/weights', 119 | 'alexnet_v2/conv3/biases', 120 | 'alexnet_v2/conv4/weights', 121 | 'alexnet_v2/conv4/biases', 122 | 'alexnet_v2/conv5/weights', 123 | 'alexnet_v2/conv5/biases', 124 | 'alexnet_v2/fc6/weights', 125 | 'alexnet_v2/fc6/biases', 126 | 'alexnet_v2/fc7/weights', 127 | 'alexnet_v2/fc7/biases', 128 | 'alexnet_v2/fc8/weights', 129 | 'alexnet_v2/fc8/biases', 130 | ] 131 | model_variables = [v.op.name for v in slim.get_model_variables()] 132 | self.assertSetEqual(set(model_variables), set(expected_names)) 133 | 134 | def testEvaluation(self): 135 | batch_size = 2 136 | height, width = 224, 224 137 | num_classes = 1000 138 | with self.test_session(): 139 | eval_inputs = tf.random_uniform((batch_size, height, width, 3)) 140 | logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False) 141 | self.assertListEqual(logits.get_shape().as_list(), 142 | [batch_size, num_classes]) 143 | predictions = tf.argmax(logits, 1) 144 | self.assertListEqual(predictions.get_shape().as_list(), [batch_size]) 145 | 146 | def testTrainEvalWithReuse(self): 147 | train_batch_size = 2 148 | eval_batch_size = 1 149 | train_height, train_width = 224, 224 150 | eval_height, eval_width = 300, 400 151 | num_classes = 1000 152 | with self.test_session(): 153 | train_inputs = tf.random_uniform( 154 | (train_batch_size, train_height, train_width, 3)) 155 | logits, _ = alexnet.alexnet_v2(train_inputs) 156 | self.assertListEqual(logits.get_shape().as_list(), 157 | [train_batch_size, num_classes]) 158 | tf.get_variable_scope().reuse_variables() 159 | eval_inputs = tf.random_uniform( 160 | (eval_batch_size, eval_height, eval_width, 3)) 161 | logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False, 162 | spatial_squeeze=False) 163 | self.assertListEqual(logits.get_shape().as_list(), 164 | [eval_batch_size, 4, 7, num_classes]) 165 | logits = tf.reduce_mean(logits, [1, 2]) 166 | predictions = tf.argmax(logits, 1) 167 | self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size]) 168 | 169 | def testForward(self): 170 | batch_size = 1 171 | height, width = 224, 224 172 | with self.test_session() as sess: 173 | inputs = tf.random_uniform((batch_size, height, width, 3)) 174 | logits, _ = alexnet.alexnet_v2(inputs) 175 | sess.run(tf.global_variables_initializer()) 176 | output = sess.run(logits) 177 | self.assertTrue(output.any()) 178 | 179 | if __name__ == '__main__': 180 | tf.test.main() 181 | -------------------------------------------------------------------------------- /nets/cifarnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a variant of the CIFAR-10 model definition.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(stddev=stddev) 26 | 27 | 28 | def cifarnet(images, num_classes=10, is_training=False, 29 | dropout_keep_prob=0.5, 30 | prediction_fn=slim.softmax, 31 | scope='CifarNet'): 32 | """Creates a variant of the CifarNet model. 33 | 34 | Note that since the output is a set of 'logits', the values fall in the 35 | interval of (-infinity, infinity). Consequently, to convert the outputs to a 36 | probability distribution over the characters, one will need to convert them 37 | using the softmax function: 38 | 39 | logits = cifarnet.cifarnet(images, is_training=False) 40 | probabilities = tf.nn.softmax(logits) 41 | predictions = tf.argmax(logits, 1) 42 | 43 | Args: 44 | images: A batch of `Tensors` of size [batch_size, height, width, channels]. 45 | num_classes: the number of classes in the dataset. If 0 or None, the logits 46 | layer is omitted and the input features to the logits layer are returned 47 | instead. 48 | is_training: specifies whether or not we're currently training the model. 49 | This variable will determine the behaviour of the dropout layer. 50 | dropout_keep_prob: the percentage of activation values that are retained. 51 | prediction_fn: a function to get predictions out of logits. 52 | scope: Optional variable_scope. 53 | 54 | Returns: 55 | net: a 2D Tensor with the logits (pre-softmax activations) if num_classes 56 | is a non-zero integer, or the input to the logits layer if num_classes 57 | is 0 or None. 58 | end_points: a dictionary from components of the network to the corresponding 59 | activation. 
60 | """ 61 | end_points = {} 62 | 63 | with tf.variable_scope(scope, 'CifarNet', [images]): 64 | net = slim.conv2d(images, 64, [5, 5], scope='conv1') 65 | end_points['conv1'] = net 66 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool1') 67 | end_points['pool1'] = net 68 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm1') 69 | net = slim.conv2d(net, 64, [5, 5], scope='conv2') 70 | end_points['conv2'] = net 71 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm2') 72 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool2') 73 | end_points['pool2'] = net 74 | net = slim.flatten(net) 75 | end_points['Flatten'] = net 76 | net = slim.fully_connected(net, 384, scope='fc3') 77 | end_points['fc3'] = net 78 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 79 | scope='dropout3') 80 | net = slim.fully_connected(net, 192, scope='fc4') 81 | end_points['fc4'] = net 82 | if not num_classes: 83 | return net, end_points 84 | logits = slim.fully_connected(net, num_classes, 85 | biases_initializer=tf.zeros_initializer(), 86 | weights_initializer=trunc_normal(1/192.0), 87 | weights_regularizer=None, 88 | activation_fn=None, 89 | scope='logits') 90 | 91 | end_points['Logits'] = logits 92 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 93 | 94 | return logits, end_points 95 | cifarnet.default_image_size = 32 96 | 97 | 98 | def cifarnet_arg_scope(weight_decay=0.004): 99 | """Defines the default cifarnet argument scope. 100 | 101 | Args: 102 | weight_decay: The weight decay to use for regularizing the model. 103 | 104 | Returns: 105 | An `arg_scope` to use for the inception v3 model. 106 | """ 107 | with slim.arg_scope( 108 | [slim.conv2d], 109 | weights_initializer=tf.truncated_normal_initializer(stddev=5e-2), 110 | activation_fn=tf.nn.relu): 111 | with slim.arg_scope( 112 | [slim.fully_connected], 113 | biases_initializer=tf.constant_initializer(0.1), 114 | weights_initializer=trunc_normal(0.04), 115 | weights_regularizer=slim.l2_regularizer(weight_decay), 116 | activation_fn=tf.nn.relu) as sc: 117 | return sc 118 | -------------------------------------------------------------------------------- /nets/cyclegan.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Defines the CycleGAN generator and discriminator networks.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import numpy as np 21 | from six.moves import xrange # pylint: disable=redefined-builtin 22 | import tensorflow as tf 23 | 24 | layers = tf.contrib.layers 25 | 26 | 27 | def cyclegan_arg_scope(instance_norm_center=True, 28 | instance_norm_scale=True, 29 | instance_norm_epsilon=0.001, 30 | weights_init_stddev=0.02, 31 | weight_decay=0.0): 32 | """Returns a default argument scope for all generators and discriminators. 33 | 34 | Args: 35 | instance_norm_center: Whether instance normalization applies centering. 36 | instance_norm_scale: Whether instance normalization applies scaling. 37 | instance_norm_epsilon: Small float added to the variance in the instance 38 | normalization to avoid dividing by zero. 39 | weights_init_stddev: Standard deviation of the random values to initialize 40 | the convolution kernels with. 41 | weight_decay: Magnitude of weight decay applied to all convolution kernel 42 | variables of the generator. 43 | 44 | Returns: 45 | An arg-scope. 46 | """ 47 | instance_norm_params = { 48 | 'center': instance_norm_center, 49 | 'scale': instance_norm_scale, 50 | 'epsilon': instance_norm_epsilon, 51 | } 52 | 53 | weights_regularizer = None 54 | if weight_decay and weight_decay > 0.0: 55 | weights_regularizer = layers.l2_regularizer(weight_decay) 56 | 57 | with tf.contrib.framework.arg_scope( 58 | [layers.conv2d], 59 | normalizer_fn=layers.instance_norm, 60 | normalizer_params=instance_norm_params, 61 | weights_initializer=tf.random_normal_initializer(0, weights_init_stddev), 62 | weights_regularizer=weights_regularizer) as sc: 63 | return sc 64 | 65 | 66 | def cyclegan_upsample(net, num_outputs, stride, method='conv2d_transpose', 67 | pad_mode='REFLECT', align_corners=False): 68 | """Upsamples the given inputs. 69 | 70 | Args: 71 | net: A Tensor of size [batch_size, height, width, filters]. 72 | num_outputs: The number of output filters. 73 | stride: A list of 2 scalars or a 1x2 Tensor indicating the scale, 74 | relative to the inputs, of the output dimensions. For example, if kernel 75 | size is [2, 3], then the output height and width will be twice and three 76 | times the input size. 77 | method: The upsampling method: 'nn_upsample_conv', 'bilinear_upsample_conv', 78 | or 'conv2d_transpose'. 79 | pad_mode: mode for tf.pad, one of "CONSTANT", "REFLECT", or "SYMMETRIC". 80 | align_corners: option for method, 'bilinear_upsample_conv'. If true, the 81 | centers of the 4 corner pixels of the input and output tensors are 82 | aligned, preserving the values at the corner pixels. 83 | 84 | Returns: 85 | A Tensor which was upsampled using the specified method. 86 | 87 | Raises: 88 | ValueError: if `method` is not recognized. 89 | """ 90 | with tf.variable_scope('upconv'): 91 | net_shape = tf.shape(net) 92 | height = net_shape[1] 93 | width = net_shape[2] 94 | 95 | # Reflection pad by 1 in spatial dimensions (axes 1, 2 = h, w) to make a 3x3 96 | # 'valid' convolution produce an output with the same dimension as the 97 | # input. 
98 | spatial_pad_1 = np.array([[0, 0], [1, 1], [1, 1], [0, 0]]) 99 | 100 | if method == 'nn_upsample_conv': 101 | net = tf.image.resize_nearest_neighbor( 102 | net, [stride[0] * height, stride[1] * width]) 103 | net = tf.pad(net, spatial_pad_1, pad_mode) 104 | net = layers.conv2d(net, num_outputs, kernel_size=[3, 3], padding='valid') 105 | elif method == 'bilinear_upsample_conv': 106 | net = tf.image.resize_bilinear( 107 | net, [stride[0] * height, stride[1] * width], 108 | align_corners=align_corners) 109 | net = tf.pad(net, spatial_pad_1, pad_mode) 110 | net = layers.conv2d(net, num_outputs, kernel_size=[3, 3], padding='valid') 111 | elif method == 'conv2d_transpose': 112 | # This corrects 1 pixel offset for images with even width and height. 113 | # conv2d is left aligned and conv2d_transpose is right aligned for even 114 | # sized images (while doing 'SAME' padding). 115 | # Note: This doesn't reflect actual model in paper. 116 | net = layers.conv2d_transpose( 117 | net, num_outputs, kernel_size=[3, 3], stride=stride, padding='valid') 118 | net = net[:, 1:, 1:, :] 119 | else: 120 | raise ValueError('Unknown method: [%s]' % method) 121 | 122 | return net 123 | 124 | 125 | def _dynamic_or_static_shape(tensor): 126 | shape = tf.shape(tensor) 127 | static_shape = tf.contrib.util.constant_value(shape) 128 | return static_shape if static_shape is not None else shape 129 | 130 | 131 | def cyclegan_generator_resnet(images, 132 | arg_scope_fn=cyclegan_arg_scope, 133 | num_resnet_blocks=6, 134 | num_filters=64, 135 | upsample_fn=cyclegan_upsample, 136 | kernel_size=3, 137 | num_outputs=3, 138 | tanh_linear_slope=0.0, 139 | is_training=False): 140 | """Defines the cyclegan resnet network architecture. 141 | 142 | As closely as possible following 143 | https://github.com/junyanz/CycleGAN/blob/master/models/architectures.lua#L232 144 | 145 | FYI: This network requires input height and width to be divisible by 4 in 146 | order to generate an output with shape equal to input shape. Assertions will 147 | catch this if input dimensions are known at graph construction time, but 148 | there's no protection if unknown at graph construction time (you'll see an 149 | error). 150 | 151 | Args: 152 | images: Input image tensor of shape [batch_size, h, w, 3]. 153 | arg_scope_fn: Function to create the global arg_scope for the network. 154 | num_resnet_blocks: Number of ResNet blocks in the middle of the generator. 155 | num_filters: Number of filters of the first hidden layer. 156 | upsample_fn: Upsampling function for the decoder part of the generator. 157 | kernel_size: Size w or list/tuple [h, w] of the filter kernels for all inner 158 | layers. 159 | num_outputs: Number of output layers. Defaults to 3 for RGB. 160 | tanh_linear_slope: Slope of the linear function to add to the tanh over the 161 | logits. 162 | is_training: Whether the network is created in training mode or inference 163 | only mode. Not actually needed, just for compliance with other generator 164 | network functions. 165 | 166 | Returns: 167 | A `Tensor` representing the model output and a dictionary of model end 168 | points. 169 | 170 | Raises: 171 | ValueError: If the input height or width is known at graph construction time 172 | and not a multiple of 4. 
173 | """ 174 | # Neither dropout nor batch norm -> dont need is_training 175 | del is_training 176 | 177 | end_points = {} 178 | 179 | input_size = images.shape.as_list() 180 | height, width = input_size[1], input_size[2] 181 | if height and height % 4 != 0: 182 | raise ValueError('The input height must be a multiple of 4.') 183 | if width and width % 4 != 0: 184 | raise ValueError('The input width must be a multiple of 4.') 185 | 186 | if not isinstance(kernel_size, (list, tuple)): 187 | kernel_size = [kernel_size, kernel_size] 188 | 189 | kernel_height = kernel_size[0] 190 | kernel_width = kernel_size[1] 191 | pad_top = (kernel_height - 1) // 2 192 | pad_bottom = kernel_height // 2 193 | pad_left = (kernel_width - 1) // 2 194 | pad_right = kernel_width // 2 195 | paddings = np.array( 196 | [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], 197 | dtype=np.int32) 198 | spatial_pad_3 = np.array([[0, 0], [3, 3], [3, 3], [0, 0]]) 199 | 200 | with tf.contrib.framework.arg_scope(arg_scope_fn()): 201 | 202 | ########### 203 | # Encoder # 204 | ########### 205 | with tf.variable_scope('input'): 206 | # 7x7 input stage 207 | net = tf.pad(images, spatial_pad_3, 'REFLECT') 208 | net = layers.conv2d(net, num_filters, kernel_size=[7, 7], padding='VALID') 209 | end_points['encoder_0'] = net 210 | 211 | with tf.variable_scope('encoder'): 212 | with tf.contrib.framework.arg_scope( 213 | [layers.conv2d], 214 | kernel_size=kernel_size, 215 | stride=2, 216 | activation_fn=tf.nn.relu, 217 | padding='VALID'): 218 | 219 | net = tf.pad(net, paddings, 'REFLECT') 220 | net = layers.conv2d(net, num_filters * 2) 221 | end_points['encoder_1'] = net 222 | net = tf.pad(net, paddings, 'REFLECT') 223 | net = layers.conv2d(net, num_filters * 4) 224 | end_points['encoder_2'] = net 225 | 226 | ################### 227 | # Residual Blocks # 228 | ################### 229 | with tf.variable_scope('residual_blocks'): 230 | with tf.contrib.framework.arg_scope( 231 | [layers.conv2d], 232 | kernel_size=kernel_size, 233 | stride=1, 234 | activation_fn=tf.nn.relu, 235 | padding='VALID'): 236 | for block_id in xrange(num_resnet_blocks): 237 | with tf.variable_scope('block_{}'.format(block_id)): 238 | res_net = tf.pad(net, paddings, 'REFLECT') 239 | res_net = layers.conv2d(res_net, num_filters * 4) 240 | res_net = tf.pad(res_net, paddings, 'REFLECT') 241 | res_net = layers.conv2d(res_net, num_filters * 4, 242 | activation_fn=None) 243 | net += res_net 244 | 245 | end_points['resnet_block_%d' % block_id] = net 246 | 247 | ########### 248 | # Decoder # 249 | ########### 250 | with tf.variable_scope('decoder'): 251 | 252 | with tf.contrib.framework.arg_scope( 253 | [layers.conv2d], 254 | kernel_size=kernel_size, 255 | stride=1, 256 | activation_fn=tf.nn.relu): 257 | 258 | with tf.variable_scope('decoder1'): 259 | net = upsample_fn(net, num_outputs=num_filters * 2, stride=[2, 2]) 260 | end_points['decoder1'] = net 261 | 262 | with tf.variable_scope('decoder2'): 263 | net = upsample_fn(net, num_outputs=num_filters, stride=[2, 2]) 264 | end_points['decoder2'] = net 265 | 266 | with tf.variable_scope('output'): 267 | net = tf.pad(net, spatial_pad_3, 'REFLECT') 268 | logits = layers.conv2d( 269 | net, 270 | num_outputs, [7, 7], 271 | activation_fn=None, 272 | normalizer_fn=None, 273 | padding='valid') 274 | logits = tf.reshape(logits, _dynamic_or_static_shape(images)) 275 | 276 | end_points['logits'] = logits 277 | end_points['predictions'] = tf.tanh(logits) + logits * tanh_linear_slope 278 | 279 | return 
end_points['predictions'], end_points 280 | -------------------------------------------------------------------------------- /nets/cyclegan_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for tensorflow.contrib.slim.nets.cyclegan.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from nets import cyclegan 24 | 25 | 26 | # TODO(joelshor): Add a test to check generator endpoints. 27 | class CycleganTest(tf.test.TestCase): 28 | 29 | def test_generator_inference(self): 30 | """Check one inference step.""" 31 | img_batch = tf.zeros([2, 32, 32, 3]) 32 | model_output, _ = cyclegan.cyclegan_generator_resnet(img_batch) 33 | with self.test_session() as sess: 34 | sess.run(tf.global_variables_initializer()) 35 | sess.run(model_output) 36 | 37 | def _test_generator_graph_helper(self, shape): 38 | """Check that generator can take small and non-square inputs.""" 39 | output_imgs, _ = cyclegan.cyclegan_generator_resnet(tf.ones(shape)) 40 | self.assertAllEqual(shape, output_imgs.shape.as_list()) 41 | 42 | def test_generator_graph_small(self): 43 | self._test_generator_graph_helper([4, 32, 32, 3]) 44 | 45 | def test_generator_graph_medium(self): 46 | self._test_generator_graph_helper([3, 128, 128, 3]) 47 | 48 | def test_generator_graph_nonsquare(self): 49 | self._test_generator_graph_helper([2, 80, 400, 3]) 50 | 51 | def test_generator_unknown_batch_dim(self): 52 | """Check that generator can take unknown batch dimension inputs.""" 53 | img = tf.placeholder(tf.float32, shape=[None, 32, None, 3]) 54 | output_imgs, _ = cyclegan.cyclegan_generator_resnet(img) 55 | 56 | self.assertAllEqual([None, 32, None, 3], output_imgs.shape.as_list()) 57 | 58 | def _input_and_output_same_shape_helper(self, kernel_size): 59 | img_batch = tf.placeholder(tf.float32, shape=[None, 32, 32, 3]) 60 | output_img_batch, _ = cyclegan.cyclegan_generator_resnet( 61 | img_batch, kernel_size=kernel_size) 62 | 63 | self.assertAllEqual(img_batch.shape.as_list(), 64 | output_img_batch.shape.as_list()) 65 | 66 | def input_and_output_same_shape_kernel3(self): 67 | self._input_and_output_same_shape_helper(3) 68 | 69 | def input_and_output_same_shape_kernel4(self): 70 | self._input_and_output_same_shape_helper(4) 71 | 72 | def input_and_output_same_shape_kernel5(self): 73 | self._input_and_output_same_shape_helper(5) 74 | 75 | def input_and_output_same_shape_kernel6(self): 76 | self._input_and_output_same_shape_helper(6) 77 | 78 | def _error_if_height_not_multiple_of_four_helper(self, height): 79 | self.assertRaisesRegexp( 80 | ValueError, 81 | 'The input height must be a multiple of 4.', 82 | cyclegan.cyclegan_generator_resnet, 83 | 
tf.placeholder(tf.float32, shape=[None, height, 32, 3])) 84 | 85 | def test_error_if_height_not_multiple_of_four_height29(self): 86 | self._error_if_height_not_multiple_of_four_helper(29) 87 | 88 | def test_error_if_height_not_multiple_of_four_height30(self): 89 | self._error_if_height_not_multiple_of_four_helper(30) 90 | 91 | def test_error_if_height_not_multiple_of_four_height31(self): 92 | self._error_if_height_not_multiple_of_four_helper(31) 93 | 94 | def _error_if_width_not_multiple_of_four_helper(self, width): 95 | self.assertRaisesRegexp( 96 | ValueError, 97 | 'The input width must be a multiple of 4.', 98 | cyclegan.cyclegan_generator_resnet, 99 | tf.placeholder(tf.float32, shape=[None, 32, width, 3])) 100 | 101 | def test_error_if_width_not_multiple_of_four_width29(self): 102 | self._error_if_width_not_multiple_of_four_helper(29) 103 | 104 | def test_error_if_width_not_multiple_of_four_width30(self): 105 | self._error_if_width_not_multiple_of_four_helper(30) 106 | 107 | def test_error_if_width_not_multiple_of_four_width31(self): 108 | self._error_if_width_not_multiple_of_four_helper(31) 109 | 110 | 111 | if __name__ == '__main__': 112 | tf.test.main() 113 | -------------------------------------------------------------------------------- /nets/dcgan.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """DCGAN generator and discriminator from https://arxiv.org/abs/1511.06434.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | from math import log 21 | 22 | from six.moves import xrange # pylint: disable=redefined-builtin 23 | import tensorflow as tf 24 | 25 | slim = tf.contrib.slim 26 | 27 | 28 | def _validate_image_inputs(inputs): 29 | inputs.get_shape().assert_has_rank(4) 30 | inputs.get_shape()[1:3].assert_is_fully_defined() 31 | if inputs.get_shape()[1] != inputs.get_shape()[2]: 32 | raise ValueError('Input tensor does not have equal width and height: ', 33 | inputs.get_shape()[1:3]) 34 | width = inputs.get_shape().as_list()[1] 35 | if log(width, 2) != int(log(width, 2)): 36 | raise ValueError('Input tensor `width` is not a power of 2: ', width) 37 | 38 | 39 | # TODO(joelshor): Use fused batch norm by default. Investigate why some GAN 40 | # setups need the gradient of gradient FusedBatchNormGrad. 41 | def discriminator(inputs, 42 | depth=64, 43 | is_training=True, 44 | reuse=None, 45 | scope='Discriminator', 46 | fused_batch_norm=False): 47 | """Discriminator network for DCGAN. 48 | 49 | Construct discriminator network from inputs to the final endpoint. 50 | 51 | Args: 52 | inputs: A tensor of size [batch_size, height, width, channels]. Must be 53 | floating point. 54 | depth: Number of channels in first convolution layer. 
55 | is_training: Whether the network is for training or not. 56 | reuse: Whether or not the network variables should be reused. `scope` 57 | must be given to be reused. 58 | scope: Optional variable_scope. 59 | fused_batch_norm: If `True`, use a faster, fused implementation of 60 | batch norm. 61 | 62 | Returns: 63 | logits: The pre-softmax activations, a tensor of size [batch_size, 1] 64 | end_points: a dictionary from components of the network to their activation. 65 | 66 | Raises: 67 | ValueError: If the input image shape is not 4-dimensional, if the spatial 68 | dimensions aren't defined at graph construction time, if the spatial 69 | dimensions aren't square, or if the spatial dimensions aren't a power of 70 | two. 71 | """ 72 | 73 | normalizer_fn = slim.batch_norm 74 | normalizer_fn_args = { 75 | 'is_training': is_training, 76 | 'zero_debias_moving_mean': True, 77 | 'fused': fused_batch_norm, 78 | } 79 | 80 | _validate_image_inputs(inputs) 81 | inp_shape = inputs.get_shape().as_list()[1] 82 | 83 | end_points = {} 84 | with tf.variable_scope(scope, values=[inputs], reuse=reuse) as scope: 85 | with slim.arg_scope([normalizer_fn], **normalizer_fn_args): 86 | with slim.arg_scope([slim.conv2d], 87 | stride=2, 88 | kernel_size=4, 89 | activation_fn=tf.nn.leaky_relu): 90 | net = inputs 91 | for i in xrange(int(log(inp_shape, 2))): 92 | scope = 'conv%i' % (i + 1) 93 | current_depth = depth * 2**i 94 | normalizer_fn_ = None if i == 0 else normalizer_fn 95 | net = slim.conv2d( 96 | net, current_depth, normalizer_fn=normalizer_fn_, scope=scope) 97 | end_points[scope] = net 98 | 99 | logits = slim.conv2d(net, 1, kernel_size=1, stride=1, padding='VALID', 100 | normalizer_fn=None, activation_fn=None) 101 | logits = tf.reshape(logits, [-1, 1]) 102 | end_points['logits'] = logits 103 | 104 | return logits, end_points 105 | 106 | 107 | # TODO(joelshor): Use fused batch norm by default. Investigate why some GAN 108 | # setups need the gradient of gradient FusedBatchNormGrad. 109 | def generator(inputs, 110 | depth=64, 111 | final_size=32, 112 | num_outputs=3, 113 | is_training=True, 114 | reuse=None, 115 | scope='Generator', 116 | fused_batch_norm=False): 117 | """Generator network for DCGAN. 118 | 119 | Construct generator network from inputs to the final endpoint. 120 | 121 | Args: 122 | inputs: A tensor with any size N. [batch_size, N] 123 | depth: Number of channels in last deconvolution layer. 124 | final_size: The shape of the final output. 125 | num_outputs: Number of output features. For images, this is the number of 126 | channels. 127 | is_training: whether is training or not. 128 | reuse: Whether or not the network has its variables should be reused. scope 129 | must be given to be reused. 130 | scope: Optional variable_scope. 131 | fused_batch_norm: If `True`, use a faster, fused implementation of 132 | batch norm. 133 | 134 | Returns: 135 | logits: the pre-softmax activations, a tensor of size 136 | [batch_size, 32, 32, channels] 137 | end_points: a dictionary from components of the network to their activation. 138 | 139 | Raises: 140 | ValueError: If `inputs` is not 2-dimensional. 141 | ValueError: If `final_size` isn't a power of 2 or is less than 8. 
142 | """ 143 | normalizer_fn = slim.batch_norm 144 | normalizer_fn_args = { 145 | 'is_training': is_training, 146 | 'zero_debias_moving_mean': True, 147 | 'fused': fused_batch_norm, 148 | } 149 | 150 | inputs.get_shape().assert_has_rank(2) 151 | if log(final_size, 2) != int(log(final_size, 2)): 152 | raise ValueError('`final_size` (%i) must be a power of 2.' % final_size) 153 | if final_size < 8: 154 | raise ValueError('`final_size` (%i) must be greater than 8.' % final_size) 155 | 156 | end_points = {} 157 | num_layers = int(log(final_size, 2)) - 1 158 | with tf.variable_scope(scope, values=[inputs], reuse=reuse) as scope: 159 | with slim.arg_scope([normalizer_fn], **normalizer_fn_args): 160 | with slim.arg_scope([slim.conv2d_transpose], 161 | normalizer_fn=normalizer_fn, 162 | stride=2, 163 | kernel_size=4): 164 | net = tf.expand_dims(tf.expand_dims(inputs, 1), 1) 165 | 166 | # First upscaling is different because it takes the input vector. 167 | current_depth = depth * 2 ** (num_layers - 1) 168 | scope = 'deconv1' 169 | net = slim.conv2d_transpose( 170 | net, current_depth, stride=1, padding='VALID', scope=scope) 171 | end_points[scope] = net 172 | 173 | for i in xrange(2, num_layers): 174 | scope = 'deconv%i' % (i) 175 | current_depth = depth * 2 ** (num_layers - i) 176 | net = slim.conv2d_transpose(net, current_depth, scope=scope) 177 | end_points[scope] = net 178 | 179 | # Last layer has different normalizer and activation. 180 | scope = 'deconv%i' % (num_layers) 181 | net = slim.conv2d_transpose( 182 | net, depth, normalizer_fn=None, activation_fn=None, scope=scope) 183 | end_points[scope] = net 184 | 185 | # Convert to proper channels. 186 | scope = 'logits' 187 | logits = slim.conv2d( 188 | net, 189 | num_outputs, 190 | normalizer_fn=None, 191 | activation_fn=None, 192 | kernel_size=1, 193 | stride=1, 194 | padding='VALID', 195 | scope=scope) 196 | end_points[scope] = logits 197 | 198 | logits.get_shape().assert_has_rank(4) 199 | logits.get_shape().assert_is_compatible_with( 200 | [None, final_size, final_size, num_outputs]) 201 | 202 | return logits, end_points 203 | -------------------------------------------------------------------------------- /nets/dcgan_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for dcgan.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from six.moves import xrange # pylint: disable=redefined-builtin 22 | import tensorflow as tf 23 | 24 | from nets import dcgan 25 | 26 | 27 | class DCGANTest(tf.test.TestCase): 28 | 29 | def test_generator_run(self): 30 | tf.set_random_seed(1234) 31 | noise = tf.random_normal([100, 64]) 32 | image, _ = dcgan.generator(noise) 33 | with self.test_session() as sess: 34 | sess.run(tf.global_variables_initializer()) 35 | image.eval() 36 | 37 | def test_generator_graph(self): 38 | tf.set_random_seed(1234) 39 | # Check graph construction for a number of image size/depths and batch 40 | # sizes. 41 | for i, batch_size in zip(xrange(3, 7), xrange(3, 8)): 42 | tf.reset_default_graph() 43 | final_size = 2 ** i 44 | noise = tf.random_normal([batch_size, 64]) 45 | image, end_points = dcgan.generator( 46 | noise, 47 | depth=32, 48 | final_size=final_size) 49 | 50 | self.assertAllEqual([batch_size, final_size, final_size, 3], 51 | image.shape.as_list()) 52 | 53 | expected_names = ['deconv%i' % j for j in xrange(1, i)] + ['logits'] 54 | self.assertSetEqual(set(expected_names), set(end_points.keys())) 55 | 56 | # Check layer depths. 57 | for j in range(1, i): 58 | layer = end_points['deconv%i' % j] 59 | self.assertEqual(32 * 2**(i-j-1), layer.get_shape().as_list()[-1]) 60 | 61 | def test_generator_invalid_input(self): 62 | wrong_dim_input = tf.zeros([5, 32, 32]) 63 | with self.assertRaises(ValueError): 64 | dcgan.generator(wrong_dim_input) 65 | 66 | correct_input = tf.zeros([3, 2]) 67 | with self.assertRaisesRegexp(ValueError, 'must be a power of 2'): 68 | dcgan.generator(correct_input, final_size=30) 69 | 70 | with self.assertRaisesRegexp(ValueError, 'must be greater than 8'): 71 | dcgan.generator(correct_input, final_size=4) 72 | 73 | def test_discriminator_run(self): 74 | image = tf.random_uniform([5, 32, 32, 3], -1, 1) 75 | output, _ = dcgan.discriminator(image) 76 | with self.test_session() as sess: 77 | sess.run(tf.global_variables_initializer()) 78 | output.eval() 79 | 80 | def test_discriminator_graph(self): 81 | # Check graph construction for a number of image size/depths and batch 82 | # sizes. 83 | for i, batch_size in zip(xrange(1, 6), xrange(3, 8)): 84 | tf.reset_default_graph() 85 | img_w = 2 ** i 86 | image = tf.random_uniform([batch_size, img_w, img_w, 3], -1, 1) 87 | output, end_points = dcgan.discriminator( 88 | image, 89 | depth=32) 90 | 91 | self.assertAllEqual([batch_size, 1], output.get_shape().as_list()) 92 | 93 | expected_names = ['conv%i' % j for j in xrange(1, i+1)] + ['logits'] 94 | self.assertSetEqual(set(expected_names), set(end_points.keys())) 95 | 96 | # Check layer depths. 
97 | for j in range(1, i+1): 98 | layer = end_points['conv%i' % j] 99 | self.assertEqual(32 * 2**(j-1), layer.get_shape().as_list()[-1]) 100 | 101 | def test_discriminator_invalid_input(self): 102 | wrong_dim_img = tf.zeros([5, 32, 32]) 103 | with self.assertRaises(ValueError): 104 | dcgan.discriminator(wrong_dim_img) 105 | 106 | spatially_undefined_shape = tf.placeholder(tf.float32, [5, 32, None, 3]) 107 | with self.assertRaises(ValueError): 108 | dcgan.discriminator(spatially_undefined_shape) 109 | 110 | not_square = tf.zeros([5, 32, 16, 3]) 111 | with self.assertRaisesRegexp(ValueError, 'not have equal width and height'): 112 | dcgan.discriminator(not_square) 113 | 114 | not_power_2 = tf.zeros([5, 30, 30, 3]) 115 | with self.assertRaisesRegexp(ValueError, 'not a power of 2'): 116 | dcgan.discriminator(not_power_2) 117 | 118 | 119 | if __name__ == '__main__': 120 | tf.test.main() 121 | -------------------------------------------------------------------------------- /nets/inception.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Brings all inception models under one namespace.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # pylint: disable=unused-import 22 | from nets.inception_resnet_v2 import inception_resnet_v2 23 | from nets.inception_resnet_v2 import inception_resnet_v2_arg_scope 24 | from nets.inception_resnet_v2 import inception_resnet_v2_base 25 | from nets.inception_v1 import inception_v1 26 | from nets.inception_v1 import inception_v1_arg_scope 27 | from nets.inception_v1 import inception_v1_base 28 | from nets.inception_v2 import inception_v2 29 | from nets.inception_v2 import inception_v2_arg_scope 30 | from nets.inception_v2 import inception_v2_base 31 | from nets.inception_v3 import inception_v3 32 | from nets.inception_v3 import inception_v3_arg_scope 33 | from nets.inception_v3 import inception_v3_base 34 | from nets.inception_v4 import inception_v4 35 | from nets.inception_v4 import inception_v4_arg_scope 36 | from nets.inception_v4 import inception_v4_base 37 | # pylint: enable=unused-import 38 | -------------------------------------------------------------------------------- /nets/inception_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains common code shared by all inception models. 16 | 17 | Usage of arg scope: 18 | with slim.arg_scope(inception_arg_scope()): 19 | logits, end_points = inception.inception_v3(images, num_classes, 20 | is_training=is_training) 21 | 22 | """ 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import tensorflow as tf 28 | 29 | slim = tf.contrib.slim 30 | 31 | 32 | def inception_arg_scope(weight_decay=0.00004, 33 | use_batch_norm=True, 34 | batch_norm_decay=0.9997, 35 | batch_norm_epsilon=0.001, 36 | activation_fn=tf.nn.relu, 37 | batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS, 38 | batch_norm_scale=False): 39 | """Defines the default arg scope for inception models. 40 | 41 | Args: 42 | weight_decay: The weight decay to use for regularizing the model. 43 | use_batch_norm: "If `True`, batch_norm is applied after each convolution. 44 | batch_norm_decay: Decay for batch norm moving average. 45 | batch_norm_epsilon: Small float added to variance to avoid dividing by zero 46 | in batch norm. 47 | activation_fn: Activation function for conv2d. 48 | batch_norm_updates_collections: Collection for the update ops for 49 | batch norm. 50 | batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the 51 | activations in the batch normalization layer. 52 | 53 | Returns: 54 | An `arg_scope` to use for the inception models. 55 | """ 56 | batch_norm_params = { 57 | # Decay for the moving averages. 58 | 'decay': batch_norm_decay, 59 | # epsilon to prevent 0s in variance. 60 | 'epsilon': batch_norm_epsilon, 61 | # collection containing update_ops. 62 | 'updates_collections': batch_norm_updates_collections, 63 | # use fused batch norm if possible. 64 | 'fused': None, 65 | 'scale': batch_norm_scale, 66 | } 67 | if use_batch_norm: 68 | normalizer_fn = slim.batch_norm 69 | normalizer_params = batch_norm_params 70 | else: 71 | normalizer_fn = None 72 | normalizer_params = {} 73 | # Set weight_decay for weights in Conv and FC layers. 74 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 75 | weights_regularizer=slim.l2_regularizer(weight_decay)): 76 | with slim.arg_scope( 77 | [slim.conv2d], 78 | weights_initializer=slim.variance_scaling_initializer(), 79 | activation_fn=activation_fn, 80 | normalizer_fn=normalizer_fn, 81 | normalizer_params=normalizer_params) as sc: 82 | return sc 83 | -------------------------------------------------------------------------------- /nets/lenet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a variant of the LeNet model definition.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | def lenet(images, num_classes=10, is_training=False, 27 | dropout_keep_prob=0.5, 28 | prediction_fn=slim.softmax, 29 | scope='LeNet'): 30 | """Creates a variant of the LeNet model. 31 | 32 | Note that since the output is a set of 'logits', the values fall in the 33 | interval of (-infinity, infinity). Consequently, to convert the outputs to a 34 | probability distribution over the characters, one will need to convert them 35 | using the softmax function: 36 | 37 | logits = lenet.lenet(images, is_training=False) 38 | probabilities = tf.nn.softmax(logits) 39 | predictions = tf.argmax(logits, 1) 40 | 41 | Args: 42 | images: A batch of `Tensors` of size [batch_size, height, width, channels]. 43 | num_classes: the number of classes in the dataset. If 0 or None, the logits 44 | layer is omitted and the input features to the logits layer are returned 45 | instead. 46 | is_training: specifies whether or not we're currently training the model. 47 | This variable will determine the behaviour of the dropout layer. 48 | dropout_keep_prob: the percentage of activation values that are retained. 49 | prediction_fn: a function to get predictions out of logits. 50 | scope: Optional variable_scope. 51 | 52 | Returns: 53 | net: a 2D Tensor with the logits (pre-softmax activations) if num_classes 54 | is a non-zero integer, or the inon-dropped-out nput to the logits layer 55 | if num_classes is 0 or None. 56 | end_points: a dictionary from components of the network to the corresponding 57 | activation. 58 | """ 59 | end_points = {} 60 | 61 | with tf.variable_scope(scope, 'LeNet', [images]): 62 | net = end_points['conv1'] = slim.conv2d(images, 32, [5, 5], scope='conv1') 63 | net = end_points['pool1'] = slim.max_pool2d(net, [2, 2], 2, scope='pool1') 64 | net = end_points['conv2'] = slim.conv2d(net, 64, [5, 5], scope='conv2') 65 | net = end_points['pool2'] = slim.max_pool2d(net, [2, 2], 2, scope='pool2') 66 | net = slim.flatten(net) 67 | end_points['Flatten'] = net 68 | 69 | net = end_points['fc3'] = slim.fully_connected(net, 1024, scope='fc3') 70 | if not num_classes: 71 | return net, end_points 72 | net = end_points['dropout3'] = slim.dropout( 73 | net, dropout_keep_prob, is_training=is_training, scope='dropout3') 74 | logits = end_points['Logits'] = slim.fully_connected( 75 | net, num_classes, activation_fn=None, scope='fc4') 76 | 77 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 78 | 79 | return logits, end_points 80 | lenet.default_image_size = 28 81 | 82 | 83 | def lenet_arg_scope(weight_decay=0.0): 84 | """Defines the default lenet argument scope. 85 | 86 | Args: 87 | weight_decay: The weight decay to use for regularizing the model. 
88 | 89 | Returns: 90 | An `arg_scope` to use for the inception v3 model. 91 | """ 92 | with slim.arg_scope( 93 | [slim.conv2d, slim.fully_connected], 94 | weights_regularizer=slim.l2_regularizer(weight_decay), 95 | weights_initializer=tf.truncated_normal_initializer(stddev=0.1), 96 | activation_fn=tf.nn.relu) as sc: 97 | return sc 98 | -------------------------------------------------------------------------------- /nets/mobilenet/README.md: -------------------------------------------------------------------------------- 1 | # MobileNetV2 2 | This folder contains building code for MobileNetV2, based on 3 | [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381) 4 | 5 | # Performance 6 | ## Latency 7 | This is the timing of [MobileNetV1](../mobilenet_v1.md) vs MobileNetV2 using 8 | TF-Lite on the large core of Pixel 1 phone. 9 | 10 | ![mnet_v1_vs_v2_pixel1_latency.png](mnet_v1_vs_v2_pixel1_latency.png) 11 | 12 | ## MACs 13 | MACs, also sometimes known as MADDs - the number of multiply-accumulates needed 14 | to compute an inference on a single image is a common metric to measure the efficiency of the model. 15 | 16 | Below is the graph comparing V2 vs a few selected networks. The size 17 | of each blob represents the number of parameters. Note for [ShuffleNet](https://arxiv.org/abs/1707.01083) there 18 | are no published size numbers. We estimate it to be comparable to MobileNetV2 numbers. 19 | 20 | ![madds_top1_accuracy](madds_top1_accuracy.png) 21 | 22 | # Pretrained models 23 | ## Imagenet Checkpoints 24 | 25 | Classification Checkpoint | MACs (M)| Parameters (M)| Top 1 Accuracy| Top 5 Accuracy | Mobile CPU (ms) Pixel 1 26 | ---------------------------|---------|---------------|---------|----|------------- 27 | | [mobilenet_v2_1.4_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz) | 582 | 6.06 | 75.0 | 92.5 | 138.0 28 | | [mobilenet_v2_1.3_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.3_224.tgz) | 509 | 5.34 | 74.4 | 92.1 | 123.0 29 | | [mobilenet_v2_1.0_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz) | 300 | 3.47 | 71.8 | 91.0 | 73.8 30 | | [mobilenet_v2_1.0_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_192.tgz) | 221 | 3.47 | 70.7 | 90.1 | 55.1 31 | | [mobilenet_v2_1.0_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_160.tgz) | 154 | 3.47 | 68.8 | 89.0 | 40.2 32 | | [mobilenet_v2_1.0_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_128.tgz) | 99 | 3.47 | 65.3 | 86.9 | 27.6 33 | | [mobilenet_v2_1.0_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_96.tgz) | 56 | 3.47 | 60.3 | 83.2 | 17.6 34 | | [mobilenet_v2_0.75_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_224.tgz) | 209 | 2.61 | 69.8 | 89.6 | 55.8 35 | | [mobilenet_v2_0.75_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_192.tgz) | 153 | 2.61 | 68.7 | 88.9 | 41.6 36 | | [mobilenet_v2_0.75_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_160.tgz) | 107 | 2.61 | 66.4 | 87.3 | 30.4 37 | | [mobilenet_v2_0.75_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_128.tgz) | 69 | 2.61 | 63.2 | 85.3 | 21.9 38 | | [mobilenet_v2_0.75_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_96.tgz) | 39 | 2.61 | 58.8 | 81.6 | 14.2 39 | | 
[mobilenet_v2_0.5_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_224.tgz) | 97 | 1.95 | 65.4 | 86.4 | 28.7 40 | | [mobilenet_v2_0.5_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_192.tgz) | 71 | 1.95 | 63.9 | 85.4 | 21.1 41 | | [mobilenet_v2_0.5_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_160.tgz) | 50 | 1.95 | 61.0 | 83.2 | 14.9 42 | | [mobilenet_v2_0.5_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_128.tgz) | 32 | 1.95 | 57.7 | 80.8 | 9.9 43 | | [mobilenet_v2_0.5_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_96.tgz) | 18 | 1.95 | 51.2 | 75.8 | 6.4 44 | | [mobilenet_v2_0.35_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_224.tgz) | 59 | 1.66 | 60.3 | 82.9 | 19.7 45 | | [mobilenet_v2_0.35_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_192.tgz) | 43 | 1.66 | 58.2 | 81.2 | 14.6 46 | | [mobilenet_v2_0.35_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_160.tgz) | 30 | 1.66 | 55.7 | 79.1 | 10.5 47 | | [mobilenet_v2_0.35_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_128.tgz) | 20 | 1.66 | 50.8 | 75.0 | 6.9 48 | | [mobilenet_v2_0.35_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_96.tgz) | 11 | 1.66 | 45.5 | 70.4 | 4.5 49 | 50 | # Training 51 | The numbers above can be reproduced using slim's `train_image_classifier`. 52 | Below is the set of parameters that achieves 72.0% for full size MobileNetV2, after about 700K when trained on 8 GPU. 53 | If trained on a single GPU the full convergence is after 5.5M steps. Also note that learning rate and 54 | num_epochs_per_decay both need to be adjusted depending on how many GPUs are being 55 | used due to slim's internal averaging. 56 | 57 | ```bash 58 | --model_name="mobilenet_v2" 59 | --learning_rate=0.045 * NUM_GPUS #slim internally averages clones so we compensate 60 | --preprocessing_name="inception_v2" 61 | --label_smoothing=0.1 62 | --moving_average_decay=0.9999 63 | --batch_size= 96 64 | --num_clones = NUM_GPUS # you can use any number here between 1 and 8 depending on your hardware setup. 65 | --learning_rate_decay_factor=0.98 66 | --num_epochs_per_decay = 2.5 / NUM_GPUS # train_image_classifier does per clone epochs 67 | ``` 68 | 69 | # Example 70 | 71 | 72 | See this [ipython notebook](mobilenet_example.ipynb) or open and run the network directly in [Colaboratory](https://colab.research.google.com/github/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_example.ipynb). 
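For a quick sanity check outside the notebook, the sketch below (a minimal example, not the official eval pipeline) builds the V2 graph in inference mode and restores one of the checkpoints listed above; the checkpoint path and the [-1, 1] Inception-style input scaling are assumptions.

```python
import tensorflow as tf
from nets.mobilenet import mobilenet_v2

# Batch of 224x224 RGB images, assumed scaled to [-1, 1] (Inception-style).
images = tf.placeholder(tf.float32, [None, 224, 224, 3])

# Inference mode is the default; wrap the call in
# slim.arg_scope(mobilenet_v2.training_scope()) only when training.
logits, endpoints = mobilenet_v2.mobilenet(images, num_classes=1001)
top_prediction = tf.argmax(logits, axis=1)

saver = tf.train.Saver()
with tf.Session() as sess:
  # Placeholder path: point this at a checkpoint extracted from one of the
  # tarballs in the table above, e.g. mobilenet_v2_1.0_224.ckpt.
  saver.restore(sess, 'mobilenet_v2_1.0_224.ckpt')
  # predictions = sess.run(top_prediction, feed_dict={images: my_batch})
```

Changing `depth_multiplier` and the input resolution in the same snippet covers the smaller checkpoints in the table.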
73 | 74 | -------------------------------------------------------------------------------- /nets/mobilenet/madds_top1_accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyi1989/TSPORTNet/58e5c54b9b613a225cb4b3892bd8316b0b328897/nets/mobilenet/madds_top1_accuracy.png -------------------------------------------------------------------------------- /nets/mobilenet/mnet_v1_vs_v2_pixel1_latency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyi1989/TSPORTNet/58e5c54b9b613a225cb4b3892bd8316b0b328897/nets/mobilenet/mnet_v1_vs_v2_pixel1_latency.png -------------------------------------------------------------------------------- /nets/mobilenet/mobilenet_v2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Implementation of Mobilenet V2. 16 | 17 | Architecture: https://arxiv.org/abs/1801.04381 18 | 19 | The base model gives 72.2% accuracy on ImageNet, with 300MMadds, 20 | 3.4 M parameters. 21 | """ 22 | 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import copy 28 | import functools 29 | 30 | import tensorflow as tf 31 | 32 | from nets.mobilenet import conv_blocks as ops 33 | from nets.mobilenet import mobilenet as lib 34 | 35 | slim = tf.contrib.slim 36 | op = lib.op 37 | 38 | expand_input = ops.expand_input_by_factor 39 | 40 | # pyformat: disable 41 | # Architecture: https://arxiv.org/abs/1801.04381 42 | V2_DEF = dict( 43 | defaults={ 44 | # Note: these parameters of batch norm affect the architecture 45 | # that's why they are here and not in training_scope. 
46 | (slim.batch_norm,): {'center': True, 'scale': True}, 47 | (slim.conv2d, slim.fully_connected, slim.separable_conv2d): { 48 | 'normalizer_fn': slim.batch_norm, 'activation_fn': tf.nn.relu6 49 | }, 50 | (ops.expanded_conv,): { 51 | 'expansion_size': expand_input(6), 52 | 'split_expansion': 1, 53 | 'normalizer_fn': slim.batch_norm, 54 | 'residual': True 55 | }, 56 | (slim.conv2d, slim.separable_conv2d): {'padding': 'SAME'} 57 | }, 58 | spec=[ 59 | op(slim.conv2d, stride=2, num_outputs=32, kernel_size=[3, 3]), 60 | op(ops.expanded_conv, 61 | expansion_size=expand_input(1, divisible_by=1), 62 | num_outputs=16), 63 | op(ops.expanded_conv, stride=2, num_outputs=24), 64 | op(ops.expanded_conv, stride=1, num_outputs=24), 65 | op(ops.expanded_conv, stride=2, num_outputs=32), 66 | op(ops.expanded_conv, stride=1, num_outputs=32), 67 | op(ops.expanded_conv, stride=1, num_outputs=32), 68 | op(ops.expanded_conv, stride=2, num_outputs=64), 69 | op(ops.expanded_conv, stride=1, num_outputs=64), 70 | op(ops.expanded_conv, stride=1, num_outputs=64), 71 | op(ops.expanded_conv, stride=1, num_outputs=64), 72 | op(ops.expanded_conv, stride=1, num_outputs=96), 73 | op(ops.expanded_conv, stride=1, num_outputs=96), 74 | op(ops.expanded_conv, stride=1, num_outputs=96), 75 | op(ops.expanded_conv, stride=2, num_outputs=160), 76 | op(ops.expanded_conv, stride=1, num_outputs=160), 77 | op(ops.expanded_conv, stride=1, num_outputs=160), 78 | op(ops.expanded_conv, stride=1, num_outputs=320), 79 | op(slim.conv2d, stride=1, kernel_size=[1, 1], num_outputs=1280) 80 | ], 81 | ) 82 | # pyformat: enable 83 | 84 | 85 | @slim.add_arg_scope 86 | def mobilenet(input_tensor, 87 | num_classes=1001, 88 | depth_multiplier=1.0, 89 | scope='MobilenetV2', 90 | conv_defs=None, 91 | finegrain_classification_mode=False, 92 | min_depth=None, 93 | divisible_by=None, 94 | activation_fn=None, 95 | **kwargs): 96 | """Creates mobilenet V2 network. 97 | 98 | Inference mode is created by default. To create training use training_scope 99 | below. 100 | 101 | with tf.contrib.slim.arg_scope(mobilenet_v2.training_scope()): 102 | logits, endpoints = mobilenet_v2.mobilenet(input_tensor) 103 | 104 | Args: 105 | input_tensor: The input tensor 106 | num_classes: number of classes 107 | depth_multiplier: The multiplier applied to scale number of 108 | channels in each layer. Note: this is called depth multiplier in the 109 | paper but the name is kept for consistency with slim's model builder. 110 | scope: Scope of the operator 111 | conv_defs: Allows to override default conv def. 112 | finegrain_classification_mode: When set to True, the model 113 | will keep the last layer large even for small multipliers. Following 114 | https://arxiv.org/abs/1801.04381 115 | suggests that it improves performance for ImageNet-type of problems. 116 | *Note* ignored if final_endpoint makes the builder exit earlier. 117 | min_depth: If provided, will ensure that all layers will have that 118 | many channels after application of depth multiplier. 119 | divisible_by: If provided will ensure that all layers # channels 120 | will be divisible by this number. 121 | activation_fn: Activation function to use, defaults to tf.nn.relu6 if not 122 | specified. 123 | **kwargs: passed directly to mobilenet.mobilenet: 124 | prediction_fn- what prediction function to use. 125 | reuse-: whether to reuse variables (if reuse set to true, scope 126 | must be given). 
127 | Returns: 128 | logits/endpoints pair 129 | 130 | Raises: 131 | ValueError: On invalid arguments 132 | """ 133 | if conv_defs is None: 134 | conv_defs = V2_DEF 135 | if 'multiplier' in kwargs: 136 | raise ValueError('mobilenetv2 doesn\'t support generic ' 137 | 'multiplier parameter use "depth_multiplier" instead.') 138 | if finegrain_classification_mode: 139 | conv_defs = copy.deepcopy(conv_defs) 140 | if depth_multiplier < 1: 141 | conv_defs['spec'][-1].params['num_outputs'] /= depth_multiplier 142 | if activation_fn: 143 | conv_defs = copy.deepcopy(conv_defs) 144 | defaults = conv_defs['defaults'] 145 | conv_defaults = ( 146 | defaults[(slim.conv2d, slim.fully_connected, slim.separable_conv2d)]) 147 | conv_defaults['activation_fn'] = activation_fn 148 | 149 | depth_args = {} 150 | # NB: do not set depth_args unless they are provided to avoid overriding 151 | # whatever default depth_multiplier might have thanks to arg_scope. 152 | if min_depth is not None: 153 | depth_args['min_depth'] = min_depth 154 | if divisible_by is not None: 155 | depth_args['divisible_by'] = divisible_by 156 | 157 | with slim.arg_scope((lib.depth_multiplier,), **depth_args): 158 | return lib.mobilenet( 159 | input_tensor, 160 | num_classes=num_classes, 161 | conv_defs=conv_defs, 162 | scope=scope, 163 | multiplier=depth_multiplier, 164 | **kwargs) 165 | 166 | mobilenet.default_image_size = 224 167 | 168 | 169 | def wrapped_partial(func, *args, **kwargs): 170 | partial_func = functools.partial(func, *args, **kwargs) 171 | functools.update_wrapper(partial_func, func) 172 | return partial_func 173 | 174 | 175 | # Wrappers for mobilenet v2 with depth-multipliers. Be noticed that 176 | # 'finegrain_classification_mode' is set to True, which means the embedding 177 | # layer will not be shrinked when given a depth-multiplier < 1.0. 178 | mobilenet_v2_140 = wrapped_partial(mobilenet, depth_multiplier=1.4) 179 | mobilenet_v2_050 = wrapped_partial(mobilenet, depth_multiplier=0.50, 180 | finegrain_classification_mode=True) 181 | mobilenet_v2_035 = wrapped_partial(mobilenet, depth_multiplier=0.35, 182 | finegrain_classification_mode=True) 183 | 184 | 185 | @slim.add_arg_scope 186 | def mobilenet_base(input_tensor, depth_multiplier=1.0, **kwargs): 187 | """Creates base of the mobilenet (no pooling and no logits) .""" 188 | return mobilenet(input_tensor, 189 | depth_multiplier=depth_multiplier, 190 | base_only=True, **kwargs) 191 | 192 | 193 | def training_scope(**kwargs): 194 | """Defines MobilenetV2 training scope. 195 | 196 | Usage: 197 | with tf.contrib.slim.arg_scope(mobilenet_v2.training_scope()): 198 | logits, endpoints = mobilenet_v2.mobilenet(input_tensor) 199 | 200 | with slim. 201 | 202 | Args: 203 | **kwargs: Passed to mobilenet.training_scope. The following parameters 204 | are supported: 205 | weight_decay- The weight decay to use for regularizing the model. 206 | stddev- Standard deviation for initialization, if negative uses xavier. 207 | dropout_keep_prob- dropout keep probability 208 | bn_decay- decay for the batch norm moving averages. 209 | 210 | Returns: 211 | An `arg_scope` to use for the mobilenet v2 model. 212 | """ 213 | return lib.training_scope(**kwargs) 214 | 215 | 216 | __all__ = ['training_scope', 'mobilenet_base', 'mobilenet', 'V2_DEF'] 217 | -------------------------------------------------------------------------------- /nets/mobilenet/mobilenet_v2_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. 
All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for mobilenet_v2.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import copy 21 | import tensorflow as tf 22 | from nets.mobilenet import conv_blocks as ops 23 | from nets.mobilenet import mobilenet 24 | from nets.mobilenet import mobilenet_v2 25 | 26 | 27 | slim = tf.contrib.slim 28 | 29 | 30 | def find_ops(optype): 31 | """Find ops of a given type in graphdef or a graph. 32 | 33 | Args: 34 | optype: operation type (e.g. Conv2D) 35 | Returns: 36 | List of operations. 37 | """ 38 | gd = tf.get_default_graph() 39 | return [var for var in gd.get_operations() if var.type == optype] 40 | 41 | 42 | class MobilenetV2Test(tf.test.TestCase): 43 | 44 | def setUp(self): 45 | tf.reset_default_graph() 46 | 47 | def testCreation(self): 48 | spec = dict(mobilenet_v2.V2_DEF) 49 | _, ep = mobilenet.mobilenet( 50 | tf.placeholder(tf.float32, (10, 224, 224, 16)), conv_defs=spec) 51 | num_convs = len(find_ops('Conv2D')) 52 | 53 | # This is mostly a sanity test. No deep reason for these particular 54 | # constants. 55 | # 56 | # All but first 2 and last one have two convolutions, and there is one 57 | # extra conv that is not in the spec. (logits) 58 | self.assertEqual(num_convs, len(spec['spec']) * 2 - 2) 59 | # Check that depthwise are exposed. 60 | for i in range(2, 17): 61 | self.assertIn('layer_%d/depthwise_output' % i, ep) 62 | 63 | def testCreationNoClasses(self): 64 | spec = copy.deepcopy(mobilenet_v2.V2_DEF) 65 | net, ep = mobilenet.mobilenet( 66 | tf.placeholder(tf.float32, (10, 224, 224, 16)), conv_defs=spec, 67 | num_classes=None) 68 | self.assertIs(net, ep['global_pool']) 69 | 70 | def testImageSizes(self): 71 | for input_size, output_size in [(224, 7), (192, 6), (160, 5), 72 | (128, 4), (96, 3)]: 73 | tf.reset_default_graph() 74 | _, ep = mobilenet_v2.mobilenet( 75 | tf.placeholder(tf.float32, (10, input_size, input_size, 3))) 76 | 77 | self.assertEqual(ep['layer_18/output'].get_shape().as_list()[1:3], 78 | [output_size] * 2) 79 | 80 | def testWithSplits(self): 81 | spec = copy.deepcopy(mobilenet_v2.V2_DEF) 82 | spec['overrides'] = { 83 | (ops.expanded_conv,): dict(split_expansion=2), 84 | } 85 | _, _ = mobilenet.mobilenet( 86 | tf.placeholder(tf.float32, (10, 224, 224, 16)), conv_defs=spec) 87 | num_convs = len(find_ops('Conv2D')) 88 | # All but 3 op has 3 conv operatore, the remainign 3 have one 89 | # and there is one unaccounted. 
90 | self.assertEqual(num_convs, len(spec['spec']) * 3 - 5) 91 | 92 | def testWithOutputStride8(self): 93 | out, _ = mobilenet.mobilenet_base( 94 | tf.placeholder(tf.float32, (10, 224, 224, 16)), 95 | conv_defs=mobilenet_v2.V2_DEF, 96 | output_stride=8, 97 | scope='MobilenetV2') 98 | self.assertEqual(out.get_shape().as_list()[1:3], [28, 28]) 99 | 100 | def testDivisibleBy(self): 101 | tf.reset_default_graph() 102 | mobilenet_v2.mobilenet( 103 | tf.placeholder(tf.float32, (10, 224, 224, 16)), 104 | conv_defs=mobilenet_v2.V2_DEF, 105 | divisible_by=16, 106 | min_depth=32) 107 | s = [op.outputs[0].get_shape().as_list()[-1] for op in find_ops('Conv2D')] 108 | s = set(s) 109 | self.assertSameElements([32, 64, 96, 160, 192, 320, 384, 576, 960, 1280, 110 | 1001], s) 111 | 112 | def testDivisibleByWithArgScope(self): 113 | tf.reset_default_graph() 114 | # Verifies that depth_multiplier arg scope actually works 115 | # if no default min_depth is provided. 116 | with slim.arg_scope((mobilenet.depth_multiplier,), min_depth=32): 117 | mobilenet_v2.mobilenet( 118 | tf.placeholder(tf.float32, (10, 224, 224, 2)), 119 | conv_defs=mobilenet_v2.V2_DEF, depth_multiplier=0.1) 120 | s = [op.outputs[0].get_shape().as_list()[-1] for op in find_ops('Conv2D')] 121 | s = set(s) 122 | self.assertSameElements(s, [32, 192, 128, 1001]) 123 | 124 | def testFineGrained(self): 125 | tf.reset_default_graph() 126 | # Verifies that depth_multiplier arg scope actually works 127 | # if no default min_depth is provided. 128 | 129 | mobilenet_v2.mobilenet( 130 | tf.placeholder(tf.float32, (10, 224, 224, 2)), 131 | conv_defs=mobilenet_v2.V2_DEF, depth_multiplier=0.01, 132 | finegrain_classification_mode=True) 133 | s = [op.outputs[0].get_shape().as_list()[-1] for op in find_ops('Conv2D')] 134 | s = set(s) 135 | # All convolutions will be 8->48, except for the last one. 136 | self.assertSameElements(s, [8, 48, 1001, 1280]) 137 | 138 | def testMobilenetBase(self): 139 | tf.reset_default_graph() 140 | # Verifies that mobilenet_base returns pre-pooling layer. 
141 | with slim.arg_scope((mobilenet.depth_multiplier,), min_depth=32): 142 | net, _ = mobilenet_v2.mobilenet_base( 143 | tf.placeholder(tf.float32, (10, 224, 224, 16)), 144 | conv_defs=mobilenet_v2.V2_DEF, depth_multiplier=0.1) 145 | self.assertEqual(net.get_shape().as_list(), [10, 7, 7, 128]) 146 | 147 | def testWithOutputStride16(self): 148 | tf.reset_default_graph() 149 | out, _ = mobilenet.mobilenet_base( 150 | tf.placeholder(tf.float32, (10, 224, 224, 16)), 151 | conv_defs=mobilenet_v2.V2_DEF, 152 | output_stride=16) 153 | self.assertEqual(out.get_shape().as_list()[1:3], [14, 14]) 154 | 155 | def testWithOutputStride8AndExplicitPadding(self): 156 | tf.reset_default_graph() 157 | out, _ = mobilenet.mobilenet_base( 158 | tf.placeholder(tf.float32, (10, 224, 224, 16)), 159 | conv_defs=mobilenet_v2.V2_DEF, 160 | output_stride=8, 161 | use_explicit_padding=True, 162 | scope='MobilenetV2') 163 | self.assertEqual(out.get_shape().as_list()[1:3], [28, 28]) 164 | 165 | def testWithOutputStride16AndExplicitPadding(self): 166 | tf.reset_default_graph() 167 | out, _ = mobilenet.mobilenet_base( 168 | tf.placeholder(tf.float32, (10, 224, 224, 16)), 169 | conv_defs=mobilenet_v2.V2_DEF, 170 | output_stride=16, 171 | use_explicit_padding=True) 172 | self.assertEqual(out.get_shape().as_list()[1:3], [14, 14]) 173 | 174 | def testBatchNormScopeDoesNotHaveIsTrainingWhenItsSetToNone(self): 175 | sc = mobilenet.training_scope(is_training=None) 176 | self.assertNotIn('is_training', sc[slim.arg_scope_func_key( 177 | slim.batch_norm)]) 178 | 179 | def testBatchNormScopeDoesHasIsTrainingWhenItsNotNone(self): 180 | sc = mobilenet.training_scope(is_training=False) 181 | self.assertIn('is_training', sc[slim.arg_scope_func_key(slim.batch_norm)]) 182 | sc = mobilenet.training_scope(is_training=True) 183 | self.assertIn('is_training', sc[slim.arg_scope_func_key(slim.batch_norm)]) 184 | sc = mobilenet.training_scope() 185 | self.assertIn('is_training', sc[slim.arg_scope_func_key(slim.batch_norm)]) 186 | 187 | 188 | if __name__ == '__main__': 189 | tf.test.main() 190 | -------------------------------------------------------------------------------- /nets/mobilenet_v1.md: -------------------------------------------------------------------------------- 1 | # Mobilenet_v2 2 | For Mobilenet V2 see this file [mobilenet/README.md] 3 | 4 | # MobileNet_v1 5 | 6 | [MobileNets](https://arxiv.org/abs/1704.04861) are small, low-latency, low-power models parameterized to meet the resource constraints of a variety of use cases. They can be built upon for classification, detection, embeddings and segmentation similar to how other popular large scale models, such as Inception, are used. MobileNets can be run efficiently on mobile devices with [TensorFlow Mobile](https://www.tensorflow.org/mobile/). 7 | 8 | MobileNets trade off between latency, size and accuracy while comparing favorably with popular models from the literature. 9 | 10 | ![alt text](mobilenet_v1.png "MobileNet Graph") 11 | 12 | # Pre-trained Models 13 | 14 | Choose the right MobileNet model to fit your latency and size budget. The size of the network in memory and on disk is proportional to the number of parameters. The latency and power usage of the network scales with the number of Multiply-Accumulates (MACs) which measures the number of fused Multiplication and Addition operations. These MobileNet models have been trained on the 15 | [ILSVRC-2012-CLS](http://www.image-net.org/challenges/LSVRC/2012/) 16 | image classification dataset. 
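Each checkpoint name in the table below encodes a depth multiplier and an input resolution (for example, MobileNet_v1_0.75_160 uses a 0.75 depth multiplier at 160x160 input). As a minimal sketch of how such a variant is built with the TF-Slim code in this directory (mirroring the pattern used by `mobilenet_v1_eval.py` further down; the placeholder input and the 1001-class head are assumptions that follow the defaults of those scripts):

```python
import tensorflow as tf
from nets import mobilenet_v1

slim = tf.contrib.slim

# One image at the 160x160 resolution paired with the 0.75 depth multiplier.
images = tf.placeholder(tf.float32, (1, 160, 160, 3))

with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(is_training=False)):
  logits, end_points = mobilenet_v1.mobilenet_v1(
      images,
      num_classes=1001,
      depth_multiplier=0.75,
      is_training=False)
```

The released checkpoints listed below can then be restored into these variables with a standard `tf.train.Saver`, as done in `mobilenet_v1_eval.py`.
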
Accuracies were computed by evaluating using a single image crop. 17 | 18 | Model | Million MACs | Million Parameters | Top-1 Accuracy| Top-5 Accuracy | 19 | :----:|:------------:|:----------:|:-------:|:-------:| 20 | [MobileNet_v1_1.0_224](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz)|569|4.24|70.9|89.9| 21 | [MobileNet_v1_1.0_192](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_192.tgz)|418|4.24|70.0|89.2| 22 | [MobileNet_v1_1.0_160](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_160.tgz)|291|4.24|68.0|87.7| 23 | [MobileNet_v1_1.0_128](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_128.tgz)|186|4.24|65.2|85.8| 24 | [MobileNet_v1_0.75_224](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_224.tgz)|317|2.59|68.4|88.2| 25 | [MobileNet_v1_0.75_192](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_192.tgz)|233|2.59|67.2|87.3| 26 | [MobileNet_v1_0.75_160](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_160.tgz)|162|2.59|65.3|86.0| 27 | [MobileNet_v1_0.75_128](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_128.tgz)|104|2.59|62.1|83.9| 28 | [MobileNet_v1_0.50_224](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_224.tgz)|150|1.34|63.3|84.9| 29 | [MobileNet_v1_0.50_192](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_192.tgz)|110|1.34|61.7|83.6| 30 | [MobileNet_v1_0.50_160](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_160.tgz)|77|1.34|59.1|81.9| 31 | [MobileNet_v1_0.50_128](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_128.tgz)|49|1.34|56.3|79.4| 32 | [MobileNet_v1_0.25_224](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_224.tgz)|41|0.47|49.8|74.2| 33 | [MobileNet_v1_0.25_192](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_192.tgz)|34|0.47|47.7|72.3| 34 | [MobileNet_v1_0.25_160](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_160.tgz)|21|0.47|45.5|70.3| 35 | [MobileNet_v1_0.25_128](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_128.tgz)|14|0.47|41.5|66.3| 36 | [MobileNet_v1_1.0_224_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz)|569|4.24|70.1|88.9| 37 | [MobileNet_v1_1.0_192_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_192_quant.tgz)|418|4.24|69.2|88.3| 38 | [MobileNet_v1_1.0_160_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_160_quant.tgz)|291|4.24|67.2|86.7| 39 | [MobileNet_v1_1.0_128_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_128_quant.tgz)|186|4.24|63.4|84.2| 40 | [MobileNet_v1_0.75_224_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_224_quant.tgz)|317|2.59|66.8|87.0| 41 | [MobileNet_v1_0.75_192_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_192_quant.tgz)|233|2.59|66.1|86.4| 42 | [MobileNet_v1_0.75_160_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_160_quant.tgz)|162|2.59|62.3|83.8| 43 | 
[MobileNet_v1_0.75_128_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_128_quant.tgz)|104|2.59|55.8|78.8| 44 | [MobileNet_v1_0.50_224_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_224_quant.tgz)|150|1.34|60.7|83.2| 45 | [MobileNet_v1_0.50_192_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_192_quant.tgz)|110|1.34|60.0|82.2| 46 | [MobileNet_v1_0.50_160_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_160_quant.tgz)|77|1.34|57.7|80.4| 47 | [MobileNet_v1_0.50_128_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_128_quant.tgz)|49|1.34|54.5|77.7| 48 | [MobileNet_v1_0.25_224_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_224_quant.tgz)|41|0.47|48.0|72.8| 49 | [MobileNet_v1_0.25_192_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_192_quant.tgz)|34|0.47|46.0|71.2| 50 | [MobileNet_v1_0.25_160_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_160_quant.tgz)|21|0.47|43.4|68.5| 51 | [MobileNet_v1_0.25_128_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_128_quant.tgz)|14|0.47|39.5|64.4| 52 | 53 | Revisions to models: 54 | * July 12, 2018: Update to TFLite models that fixes an accuracy issue resolved by making conversion support weights with narrow_range. We now report validation on the actual TensorFlow Lite model rather than the emulated quantization number of TensorFlow. 55 | * August 2, 2018: Update to TFLite models that fixes an accuracy issue resolved by making sure the numerics of quantization match TF quantized training accurately. 56 | 57 | The linked model tar files contain the following: 58 | * Trained model checkpoints 59 | * Eval graph text protos (to be easily viewed) 60 | * Frozen trained models 61 | * Info file containing input and output information 62 | * Converted [TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/) flatbuffer model 63 | 64 | Note that quantized model GraphDefs are still float models, they just have FakeQuantization 65 | operation embedded to simulate quantization. These are converted by [TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/) 66 | to be fully quantized. The final effect of quantization can be seen by comparing the frozen fake 67 | quantized graph to the size of the TFLite flatbuffer, i.e. The TFLite flatbuffer is about 1/4 68 | the size. 69 | For more information on the quantization techniques used here, see 70 | [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/quantize). 71 | 72 | Here is an example of how to download the MobileNet_v1_1.0_224 checkpoint: 73 | 74 | ```shell 75 | $ CHECKPOINT_DIR=/tmp/checkpoints 76 | $ mkdir ${CHECKPOINT_DIR} 77 | $ wget http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz 78 | $ tar -xvf mobilenet_v1_1.0_224.tgz 79 | $ mv mobilenet_v1_1.0_224.ckpt.* ${CHECKPOINT_DIR} 80 | ``` 81 | 82 | # MobileNet V1 scripts 83 | 84 | This package contains scripts for training floating point and eight-bit fixed 85 | point TensorFlow models. 86 | 87 | Quantization tools used are described in [contrib/quantize](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/quantize). 
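Concretely, both scripts below build the float model first and, when the `--quantize` flag is set, apply a single graph rewrite from `tf.contrib.quantize`. A condensed sketch of that pattern (the flag values mirror `mobilenet_v1_train.py` and `mobilenet_v1_eval.py`; the model-building step is elided):

```python
import tensorflow as tf

quantize = True      # mirrors the --quantize flag in the scripts below
is_training = True   # True in mobilenet_v1_train.py, False in mobilenet_v1_eval.py

# ... build the float mobilenet_v1 graph into the default graph here ...

if quantize:
  if is_training:
    # Rewrites the default graph in place: inserts fake-quantization ops and
    # folds batch norms. quant_delay postpones quantization so the float
    # weights can settle when training from scratch; the scripts use 0 when
    # fine-tuning from a float checkpoint.
    tf.contrib.quantize.create_training_graph(quant_delay=250000)
  else:
    tf.contrib.quantize.create_eval_graph()
```
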
88 | 89 | Conversion to fully quantized models for mobile can be done through [TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/). 90 | 91 | ## Usage 92 | 93 | ### Build for GPU 94 | 95 | ``` 96 | $ bazel build -c opt --config=cuda mobilenet_v1_{eval,train} 97 | ``` 98 | 99 | ### Running 100 | 101 | #### Float Training and Eval 102 | 103 | Train: 104 | 105 | ``` 106 | $ ./bazel-bin/mobilenet_v1_train --dataset_dir "path/to/dataset" --checkpoint_dir "path/to/checkpoints" 107 | ``` 108 | 109 | Eval: 110 | 111 | ``` 112 | $ ./bazel-bin/mobilenet_v1_eval --dataset_dir "path/to/dataset" --checkpoint_dir "path/to/checkpoints" 113 | ``` 114 | 115 | #### Quantized Training and Eval 116 | 117 | Train from preexisting float checkpoint: 118 | 119 | ``` 120 | $ ./bazel-bin/mobilenet_v1_train --dataset_dir "path/to/dataset" --checkpoint_dir "path/to/checkpoints" \ 121 | --quantize=True --fine_tune_checkpoint=float/checkpoint/path 122 | ``` 123 | 124 | Train from scratch: 125 | 126 | ``` 127 | $ ./bazel-bin/mobilenet_v1_train --dataset_dir "path/to/dataset" --checkpoint_dir "path/to/checkpoints" --quantize=True 128 | ``` 129 | 130 | Eval: 131 | 132 | ``` 133 | $ ./bazel-bin/mobilenet_v1_eval --dataset_dir "path/to/dataset" --checkpoint_dir "path/to/checkpoints" --quantize=True 134 | ``` 135 | 136 | The resulting float and quantized models can be run on-device via [TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/). 137 | -------------------------------------------------------------------------------- /nets/mobilenet_v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyi1989/TSPORTNet/58e5c54b9b613a225cb4b3892bd8316b0b328897/nets/mobilenet_v1.png -------------------------------------------------------------------------------- /nets/mobilenet_v1_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Validate mobilenet_v1 with options for quantization.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import math 22 | import tensorflow as tf 23 | 24 | from datasets import dataset_factory 25 | from nets import mobilenet_v1 26 | from preprocessing import preprocessing_factory 27 | 28 | slim = tf.contrib.slim 29 | 30 | flags = tf.app.flags 31 | 32 | flags.DEFINE_string('master', '', 'Session master') 33 | flags.DEFINE_integer('batch_size', 250, 'Batch size') 34 | flags.DEFINE_integer('num_classes', 1001, 'Number of classes to distinguish') 35 | flags.DEFINE_integer('num_examples', 50000, 'Number of examples to evaluate') 36 | flags.DEFINE_integer('image_size', 224, 'Input image resolution') 37 | flags.DEFINE_float('depth_multiplier', 1.0, 'Depth multiplier for mobilenet') 38 | flags.DEFINE_bool('quantize', False, 'Quantize training') 39 | flags.DEFINE_string('checkpoint_dir', '', 'The directory for checkpoints') 40 | flags.DEFINE_string('eval_dir', '', 'Directory for writing eval event logs') 41 | flags.DEFINE_string('dataset_dir', '', 'Location of dataset') 42 | 43 | FLAGS = flags.FLAGS 44 | 45 | 46 | def imagenet_input(is_training): 47 | """Data reader for imagenet. 48 | 49 | Reads in imagenet data and performs pre-processing on the images. 50 | 51 | Args: 52 | is_training: bool specifying if train or validation dataset is needed. 53 | Returns: 54 | A batch of images and labels. 55 | """ 56 | if is_training: 57 | dataset = dataset_factory.get_dataset('imagenet', 'train', 58 | FLAGS.dataset_dir) 59 | else: 60 | dataset = dataset_factory.get_dataset('imagenet', 'validation', 61 | FLAGS.dataset_dir) 62 | 63 | provider = slim.dataset_data_provider.DatasetDataProvider( 64 | dataset, 65 | shuffle=is_training, 66 | common_queue_capacity=2 * FLAGS.batch_size, 67 | common_queue_min=FLAGS.batch_size) 68 | [image, label] = provider.get(['image', 'label']) 69 | 70 | image_preprocessing_fn = preprocessing_factory.get_preprocessing( 71 | 'mobilenet_v1', is_training=is_training) 72 | 73 | image = image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size) 74 | 75 | images, labels = tf.train.batch( 76 | tensors=[image, label], 77 | batch_size=FLAGS.batch_size, 78 | num_threads=4, 79 | capacity=5 * FLAGS.batch_size) 80 | return images, labels 81 | 82 | 83 | def metrics(logits, labels): 84 | """Specify the metrics for eval. 85 | 86 | Args: 87 | logits: Logits output from the graph. 88 | labels: Ground truth labels for inputs. 89 | 90 | Returns: 91 | Eval Op for the graph. 92 | """ 93 | labels = tf.squeeze(labels) 94 | names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({ 95 | 'Accuracy': tf.metrics.accuracy(tf.argmax(logits, 1), labels), 96 | 'Recall_5': tf.metrics.recall_at_k(labels, logits, 5), 97 | }) 98 | for name, value in names_to_values.iteritems(): 99 | slim.summaries.add_scalar_summary( 100 | value, name, prefix='eval', print_summary=True) 101 | return names_to_updates.values() 102 | 103 | 104 | def build_model(): 105 | """Build the mobilenet_v1 model for evaluation. 106 | 107 | Returns: 108 | g: graph with rewrites after insertion of quantization ops and batch norm 109 | folding. 110 | eval_ops: eval ops for inference. 111 | variables_to_restore: List of variables to restore from checkpoint. 
112 | """ 113 | g = tf.Graph() 114 | with g.as_default(): 115 | inputs, labels = imagenet_input(is_training=False) 116 | 117 | scope = mobilenet_v1.mobilenet_v1_arg_scope( 118 | is_training=False, weight_decay=0.0) 119 | with slim.arg_scope(scope): 120 | logits, _ = mobilenet_v1.mobilenet_v1( 121 | inputs, 122 | is_training=False, 123 | depth_multiplier=FLAGS.depth_multiplier, 124 | num_classes=FLAGS.num_classes) 125 | 126 | if FLAGS.quantize: 127 | tf.contrib.quantize.create_eval_graph() 128 | 129 | eval_ops = metrics(logits, labels) 130 | 131 | return g, eval_ops 132 | 133 | 134 | def eval_model(): 135 | """Evaluates mobilenet_v1.""" 136 | g, eval_ops = build_model() 137 | with g.as_default(): 138 | num_batches = math.ceil(FLAGS.num_examples / float(FLAGS.batch_size)) 139 | slim.evaluation.evaluate_once( 140 | FLAGS.master, 141 | FLAGS.checkpoint_dir, 142 | logdir=FLAGS.eval_dir, 143 | num_evals=num_batches, 144 | eval_op=eval_ops) 145 | 146 | 147 | def main(unused_arg): 148 | eval_model() 149 | 150 | 151 | if __name__ == '__main__': 152 | tf.app.run(main) 153 | -------------------------------------------------------------------------------- /nets/mobilenet_v1_train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Build and train mobilenet_v1 with options for quantization.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from datasets import dataset_factory 24 | from nets import mobilenet_v1 25 | from preprocessing import preprocessing_factory 26 | 27 | slim = tf.contrib.slim 28 | 29 | flags = tf.app.flags 30 | 31 | flags.DEFINE_string('master', '', 'Session master') 32 | flags.DEFINE_integer('task', 0, 'Task') 33 | flags.DEFINE_integer('ps_tasks', 0, 'Number of ps') 34 | flags.DEFINE_integer('batch_size', 64, 'Batch size') 35 | flags.DEFINE_integer('num_classes', 1001, 'Number of classes to distinguish') 36 | flags.DEFINE_integer('number_of_steps', None, 37 | 'Number of training steps to perform before stopping') 38 | flags.DEFINE_integer('image_size', 224, 'Input image resolution') 39 | flags.DEFINE_float('depth_multiplier', 1.0, 'Depth multiplier for mobilenet') 40 | flags.DEFINE_bool('quantize', False, 'Quantize training') 41 | flags.DEFINE_string('fine_tune_checkpoint', '', 42 | 'Checkpoint from which to start finetuning.') 43 | flags.DEFINE_string('checkpoint_dir', '', 44 | 'Directory for writing training checkpoints and logs') 45 | flags.DEFINE_string('dataset_dir', '', 'Location of dataset') 46 | flags.DEFINE_integer('log_every_n_steps', 100, 'Number of steps per log') 47 | flags.DEFINE_integer('save_summaries_secs', 100, 48 | 'How often to save summaries, secs') 49 | flags.DEFINE_integer('save_interval_secs', 100, 50 | 'How often to save checkpoints, secs') 51 | 52 | FLAGS = flags.FLAGS 53 | 54 | _LEARNING_RATE_DECAY_FACTOR = 0.94 55 | 56 | 57 | def get_learning_rate(): 58 | if FLAGS.fine_tune_checkpoint: 59 | # If we are fine tuning a checkpoint we need to start at a lower learning 60 | # rate since we are farther along on training. 61 | return 1e-4 62 | else: 63 | return 0.045 64 | 65 | 66 | def get_quant_delay(): 67 | if FLAGS.fine_tune_checkpoint: 68 | # We can start quantizing immediately if we are finetuning. 69 | return 0 70 | else: 71 | # We need to wait for the model to train a bit before we quantize if we are 72 | # training from scratch. 73 | return 250000 74 | 75 | 76 | def imagenet_input(is_training): 77 | """Data reader for imagenet. 78 | 79 | Reads in imagenet data and performs pre-processing on the images. 80 | 81 | Args: 82 | is_training: bool specifying if train or validation dataset is needed. 83 | Returns: 84 | A batch of images and labels. 
85 | """ 86 | if is_training: 87 | dataset = dataset_factory.get_dataset('imagenet', 'train', 88 | FLAGS.dataset_dir) 89 | else: 90 | dataset = dataset_factory.get_dataset('imagenet', 'validation', 91 | FLAGS.dataset_dir) 92 | 93 | provider = slim.dataset_data_provider.DatasetDataProvider( 94 | dataset, 95 | shuffle=is_training, 96 | common_queue_capacity=2 * FLAGS.batch_size, 97 | common_queue_min=FLAGS.batch_size) 98 | [image, label] = provider.get(['image', 'label']) 99 | 100 | image_preprocessing_fn = preprocessing_factory.get_preprocessing( 101 | 'mobilenet_v1', is_training=is_training) 102 | 103 | image = image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size) 104 | 105 | images, labels = tf.train.batch( 106 | [image, label], 107 | batch_size=FLAGS.batch_size, 108 | num_threads=4, 109 | capacity=5 * FLAGS.batch_size) 110 | labels = slim.one_hot_encoding(labels, FLAGS.num_classes) 111 | return images, labels 112 | 113 | 114 | def build_model(): 115 | """Builds graph for model to train with rewrites for quantization. 116 | 117 | Returns: 118 | g: Graph with fake quantization ops and batch norm folding suitable for 119 | training quantized weights. 120 | train_tensor: Train op for execution during training. 121 | """ 122 | g = tf.Graph() 123 | with g.as_default(), tf.device( 124 | tf.train.replica_device_setter(FLAGS.ps_tasks)): 125 | inputs, labels = imagenet_input(is_training=True) 126 | with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(is_training=True)): 127 | logits, _ = mobilenet_v1.mobilenet_v1( 128 | inputs, 129 | is_training=True, 130 | depth_multiplier=FLAGS.depth_multiplier, 131 | num_classes=FLAGS.num_classes) 132 | 133 | tf.losses.softmax_cross_entropy(labels, logits) 134 | 135 | # Call rewriter to produce graph with fake quant ops and folded batch norms 136 | # quant_delay delays start of quantization till quant_delay steps, allowing 137 | # for better model accuracy. 138 | if FLAGS.quantize: 139 | tf.contrib.quantize.create_training_graph(quant_delay=get_quant_delay()) 140 | 141 | total_loss = tf.losses.get_total_loss(name='total_loss') 142 | # Configure the learning rate using an exponential decay. 143 | num_epochs_per_decay = 2.5 144 | imagenet_size = 1271167 145 | decay_steps = int(imagenet_size / FLAGS.batch_size * num_epochs_per_decay) 146 | 147 | learning_rate = tf.train.exponential_decay( 148 | get_learning_rate(), 149 | tf.train.get_or_create_global_step(), 150 | decay_steps, 151 | _LEARNING_RATE_DECAY_FACTOR, 152 | staircase=True) 153 | opt = tf.train.GradientDescentOptimizer(learning_rate) 154 | 155 | train_tensor = slim.learning.create_train_op( 156 | total_loss, 157 | optimizer=opt) 158 | 159 | slim.summaries.add_scalar_summary(total_loss, 'total_loss', 'losses') 160 | slim.summaries.add_scalar_summary(learning_rate, 'learning_rate', 'training') 161 | return g, train_tensor 162 | 163 | 164 | def get_checkpoint_init_fn(): 165 | """Returns the checkpoint init_fn if the checkpoint is provided.""" 166 | if FLAGS.fine_tune_checkpoint: 167 | variables_to_restore = slim.get_variables_to_restore() 168 | global_step_reset = tf.assign(tf.train.get_or_create_global_step(), 0) 169 | # When restoring from a floating point model, the min/max values for 170 | # quantized weights and activations are not present. 
171 | # We instruct slim to ignore variables that are missing during restoration 172 | # by setting ignore_missing_vars=True 173 | slim_init_fn = slim.assign_from_checkpoint_fn( 174 | FLAGS.fine_tune_checkpoint, 175 | variables_to_restore, 176 | ignore_missing_vars=True) 177 | 178 | def init_fn(sess): 179 | slim_init_fn(sess) 180 | # If we are restoring from a floating point model, we need to initialize 181 | # the global step to zero for the exponential decay to result in 182 | # reasonable learning rates. 183 | sess.run(global_step_reset) 184 | return init_fn 185 | else: 186 | return None 187 | 188 | 189 | def train_model(): 190 | """Trains mobilenet_v1.""" 191 | g, train_tensor = build_model() 192 | with g.as_default(): 193 | slim.learning.train( 194 | train_tensor, 195 | FLAGS.checkpoint_dir, 196 | is_chief=(FLAGS.task == 0), 197 | master=FLAGS.master, 198 | log_every_n_steps=FLAGS.log_every_n_steps, 199 | graph=g, 200 | number_of_steps=FLAGS.number_of_steps, 201 | save_summaries_secs=FLAGS.save_summaries_secs, 202 | save_interval_secs=FLAGS.save_interval_secs, 203 | init_fn=get_checkpoint_init_fn(), 204 | global_step=tf.train.get_global_step()) 205 | 206 | 207 | def main(unused_arg): 208 | train_model() 209 | 210 | 211 | if __name__ == '__main__': 212 | tf.app.run(main) 213 | -------------------------------------------------------------------------------- /nets/nasnet/README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow-Slim NASNet-A Implementation/Checkpoints 2 | This directory contains the code for the NASNet-A model from the paper 3 | [Learning Transferable Architectures for Scalable Image Recognition](https://arxiv.org/abs/1707.07012) by Zoph et al. 4 | In nasnet.py there are three different configurations of NASNet-A that are implementented. One of the models is the NASNet-A built for CIFAR-10 and the 5 | other two are variants of NASNet-A trained on ImageNet, which are listed below. 6 | 7 | # Pre-Trained Models 8 | Two NASNet-A checkpoints are available that have been trained on the 9 | [ILSVRC-2012-CLS](http://www.image-net.org/challenges/LSVRC/2012/) 10 | image classification dataset. Accuracies were computed by evaluating using a single image crop. 11 | 12 | Model Checkpoint | Million MACs | Million Parameters | Top-1 Accuracy| Top-5 Accuracy | 13 | :----:|:------------:|:----------:|:-------:|:-------:| 14 | [NASNet-A_Mobile_224](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz)|564|5.3|74.0|91.6| 15 | [NASNet-A_Large_331](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_large_04_10_2017.tar.gz)|23800|88.9|82.7|96.2| 16 | 17 | 18 | Here is an example of how to download the NASNet-A_Mobile_224 checkpoint. The way to download the NASNet-A_Large_331 is the same. 19 | 20 | ```shell 21 | CHECKPOINT_DIR=/tmp/checkpoints 22 | mkdir ${CHECKPOINT_DIR} 23 | cd ${CHECKPOINT_DIR} 24 | wget https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz 25 | tar -xvf nasnet-a_mobile_04_10_2017.tar.gz 26 | rm nasnet-a_mobile_04_10_2017.tar.gz 27 | ``` 28 | More information on integrating NASNet Models into your project can be found at the [TF-Slim Image Classification Library](https://github.com/tensorflow/models/blob/master/research/slim/README.md). 29 | 30 | To get started running models on-device go to [TensorFlow Mobile](https://www.tensorflow.org/mobile/). 
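If you prefer to build the network directly in Python rather than go through the evaluation script shown in the next section, a minimal TF-Slim sketch for the Mobile checkpoint looks like the following. The builder and arg-scope names are the ones registered in `nets_factory.py`; the checkpoint path assumes the download above, and input preprocessing to 224x224 is left out.

```python
import tensorflow as tf
from nets.nasnet import nasnet

slim = tf.contrib.slim

# 224x224 is the input size of the NASNet-A Mobile model.
images = tf.placeholder(tf.float32, (1, 224, 224, 3))

with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
  logits, end_points = nasnet.build_nasnet_mobile(
      images, num_classes=1001, is_training=False)

saver = tf.train.Saver()
with tf.Session() as sess:
  saver.restore(sess, '/tmp/checkpoints/model.ckpt')
  # Feed a preprocessed image batch and read end_points['Predictions'].
```
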
31 | 32 | ## Sample Commands for using NASNet-A Mobile and Large Checkpoints for Inference 33 | ------- 34 | Run eval with the NASNet-A mobile ImageNet model 35 | 36 | ```shell 37 | DATASET_DIR=/tmp/imagenet 38 | EVAL_DIR=/tmp/tfmodel/eval 39 | CHECKPOINT_DIR=/tmp/checkpoints/model.ckpt 40 | python tensorflow_models/research/slim/eval_image_classifier \ 41 | --checkpoint_path=${CHECKPOINT_DIR} \ 42 | --eval_dir=${EVAL_DIR} \ 43 | --dataset_dir=${DATASET_DIR} \ 44 | --dataset_name=imagenet \ 45 | --dataset_split_name=validation \ 46 | --model_name=nasnet_mobile \ 47 | --eval_image_size=224 48 | ``` 49 | 50 | Run eval with the NASNet-A large ImageNet model 51 | 52 | ```shell 53 | DATASET_DIR=/tmp/imagenet 54 | EVAL_DIR=/tmp/tfmodel/eval 55 | CHECKPOINT_DIR=/tmp/checkpoints/model.ckpt 56 | python tensorflow_models/research/slim/eval_image_classifier \ 57 | --checkpoint_path=${CHECKPOINT_DIR} \ 58 | --eval_dir=${EVAL_DIR} \ 59 | --dataset_dir=${DATASET_DIR} \ 60 | --dataset_name=imagenet \ 61 | --dataset_split_name=validation \ 62 | --model_name=nasnet_large \ 63 | --eval_image_size=331 64 | ``` 65 | -------------------------------------------------------------------------------- /nets/nasnet/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /nets/nasnet/nasnet_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for slim.nets.nasnet.nasnet_utils.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from nets.nasnet import nasnet_utils 24 | 25 | 26 | class NasnetUtilsTest(tf.test.TestCase): 27 | 28 | def testCalcReductionLayers(self): 29 | num_cells = 18 30 | num_reduction_layers = 2 31 | reduction_layers = nasnet_utils.calc_reduction_layers( 32 | num_cells, num_reduction_layers) 33 | self.assertEqual(len(reduction_layers), 2) 34 | self.assertEqual(reduction_layers[0], 6) 35 | self.assertEqual(reduction_layers[1], 12) 36 | 37 | def testGetChannelIndex(self): 38 | data_formats = ['NHWC', 'NCHW'] 39 | for data_format in data_formats: 40 | index = nasnet_utils.get_channel_index(data_format) 41 | correct_index = 3 if data_format == 'NHWC' else 1 42 | self.assertEqual(index, correct_index) 43 | 44 | def testGetChannelDim(self): 45 | data_formats = ['NHWC', 'NCHW'] 46 | shape = [10, 20, 30, 40] 47 | for data_format in data_formats: 48 | dim = nasnet_utils.get_channel_dim(shape, data_format) 49 | correct_dim = shape[3] if data_format == 'NHWC' else shape[1] 50 | self.assertEqual(dim, correct_dim) 51 | 52 | def testGlobalAvgPool(self): 53 | data_formats = ['NHWC', 'NCHW'] 54 | inputs = tf.placeholder(tf.float32, (5, 10, 20, 10)) 55 | for data_format in data_formats: 56 | output = nasnet_utils.global_avg_pool( 57 | inputs, data_format) 58 | self.assertEqual(output.shape, [5, 10]) 59 | 60 | 61 | if __name__ == '__main__': 62 | tf.test.main() 63 | -------------------------------------------------------------------------------- /nets/nasnet/pnasnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains the definition for the PNASNet classification networks. 
16 | 17 | Paper: https://arxiv.org/abs/1712.00559 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import copy 25 | import tensorflow as tf 26 | 27 | from nets.nasnet import nasnet 28 | from nets.nasnet import nasnet_utils 29 | 30 | arg_scope = tf.contrib.framework.arg_scope 31 | slim = tf.contrib.slim 32 | 33 | 34 | def large_imagenet_config(): 35 | """Large ImageNet configuration based on PNASNet-5.""" 36 | return tf.contrib.training.HParams( 37 | stem_multiplier=3.0, 38 | dense_dropout_keep_prob=0.5, 39 | num_cells=12, 40 | filter_scaling_rate=2.0, 41 | num_conv_filters=216, 42 | drop_path_keep_prob=0.6, 43 | use_aux_head=1, 44 | num_reduction_layers=2, 45 | data_format='NHWC', 46 | skip_reduction_layer_input=1, 47 | total_training_steps=250000, 48 | use_bounded_activation=False, 49 | ) 50 | 51 | 52 | def mobile_imagenet_config(): 53 | """Mobile ImageNet configuration based on PNASNet-5.""" 54 | return tf.contrib.training.HParams( 55 | stem_multiplier=1.0, 56 | dense_dropout_keep_prob=0.5, 57 | num_cells=9, 58 | filter_scaling_rate=2.0, 59 | num_conv_filters=54, 60 | drop_path_keep_prob=1.0, 61 | use_aux_head=1, 62 | num_reduction_layers=2, 63 | data_format='NHWC', 64 | skip_reduction_layer_input=1, 65 | total_training_steps=250000, 66 | use_bounded_activation=False, 67 | ) 68 | 69 | 70 | def pnasnet_large_arg_scope(weight_decay=4e-5, batch_norm_decay=0.9997, 71 | batch_norm_epsilon=0.001): 72 | """Default arg scope for the PNASNet Large ImageNet model.""" 73 | return nasnet.nasnet_large_arg_scope( 74 | weight_decay, batch_norm_decay, batch_norm_epsilon) 75 | 76 | 77 | def pnasnet_mobile_arg_scope(weight_decay=4e-5, 78 | batch_norm_decay=0.9997, 79 | batch_norm_epsilon=0.001): 80 | """Default arg scope for the PNASNet Mobile ImageNet model.""" 81 | return nasnet.nasnet_mobile_arg_scope(weight_decay, batch_norm_decay, 82 | batch_norm_epsilon) 83 | 84 | 85 | def _build_pnasnet_base(images, 86 | normal_cell, 87 | num_classes, 88 | hparams, 89 | is_training, 90 | final_endpoint=None): 91 | """Constructs a PNASNet image model.""" 92 | 93 | end_points = {} 94 | 95 | def add_and_check_endpoint(endpoint_name, net): 96 | end_points[endpoint_name] = net 97 | return final_endpoint and (endpoint_name == final_endpoint) 98 | 99 | # Find where to place the reduction cells or stride normal cells 100 | reduction_indices = nasnet_utils.calc_reduction_layers( 101 | hparams.num_cells, hparams.num_reduction_layers) 102 | 103 | # pylint: disable=protected-access 104 | stem = lambda: nasnet._imagenet_stem(images, hparams, normal_cell) 105 | # pylint: enable=protected-access 106 | net, cell_outputs = stem() 107 | if add_and_check_endpoint('Stem', net): 108 | return net, end_points 109 | 110 | # Setup for building in the auxiliary head. 
111 | aux_head_cell_idxes = [] 112 | if len(reduction_indices) >= 2: 113 | aux_head_cell_idxes.append(reduction_indices[1] - 1) 114 | 115 | # Run the cells 116 | filter_scaling = 1.0 117 | # true_cell_num accounts for the stem cells 118 | true_cell_num = 2 119 | activation_fn = tf.nn.relu6 if hparams.use_bounded_activation else tf.nn.relu 120 | for cell_num in range(hparams.num_cells): 121 | is_reduction = cell_num in reduction_indices 122 | stride = 2 if is_reduction else 1 123 | if is_reduction: filter_scaling *= hparams.filter_scaling_rate 124 | if hparams.skip_reduction_layer_input or not is_reduction: 125 | prev_layer = cell_outputs[-2] 126 | net = normal_cell( 127 | net, 128 | scope='cell_{}'.format(cell_num), 129 | filter_scaling=filter_scaling, 130 | stride=stride, 131 | prev_layer=prev_layer, 132 | cell_num=true_cell_num) 133 | if add_and_check_endpoint('Cell_{}'.format(cell_num), net): 134 | return net, end_points 135 | true_cell_num += 1 136 | cell_outputs.append(net) 137 | 138 | if (hparams.use_aux_head and cell_num in aux_head_cell_idxes and 139 | num_classes and is_training): 140 | aux_net = activation_fn(net) 141 | # pylint: disable=protected-access 142 | nasnet._build_aux_head(aux_net, end_points, num_classes, hparams, 143 | scope='aux_{}'.format(cell_num)) 144 | # pylint: enable=protected-access 145 | 146 | # Final softmax layer 147 | with tf.variable_scope('final_layer'): 148 | net = activation_fn(net) 149 | net = nasnet_utils.global_avg_pool(net) 150 | if add_and_check_endpoint('global_pool', net) or not num_classes: 151 | return net, end_points 152 | net = slim.dropout(net, hparams.dense_dropout_keep_prob, scope='dropout') 153 | logits = slim.fully_connected(net, num_classes) 154 | 155 | if add_and_check_endpoint('Logits', logits): 156 | return net, end_points 157 | 158 | predictions = tf.nn.softmax(logits, name='predictions') 159 | if add_and_check_endpoint('Predictions', predictions): 160 | return net, end_points 161 | return logits, end_points 162 | 163 | 164 | def build_pnasnet_large(images, 165 | num_classes, 166 | is_training=True, 167 | final_endpoint=None, 168 | config=None): 169 | """Build PNASNet Large model for the ImageNet Dataset.""" 170 | hparams = copy.deepcopy(config) if config else large_imagenet_config() 171 | # pylint: disable=protected-access 172 | nasnet._update_hparams(hparams, is_training) 173 | # pylint: enable=protected-access 174 | 175 | if tf.test.is_gpu_available() and hparams.data_format == 'NHWC': 176 | tf.logging.info('A GPU is available on the machine, consider using NCHW ' 177 | 'data format for increased speed on GPU.') 178 | 179 | if hparams.data_format == 'NCHW': 180 | images = tf.transpose(images, [0, 3, 1, 2]) 181 | 182 | # Calculate the total number of cells in the network. 183 | # There is no distinction between reduction and normal cells in PNAS so the 184 | # total number of cells is equal to the number normal cells plus the number 185 | # of stem cells (two by default). 
186 | total_num_cells = hparams.num_cells + 2 187 | 188 | normal_cell = PNasNetNormalCell(hparams.num_conv_filters, 189 | hparams.drop_path_keep_prob, total_num_cells, 190 | hparams.total_training_steps, 191 | hparams.use_bounded_activation) 192 | with arg_scope( 193 | [slim.dropout, nasnet_utils.drop_path, slim.batch_norm], 194 | is_training=is_training): 195 | with arg_scope([slim.avg_pool2d, slim.max_pool2d, slim.conv2d, 196 | slim.batch_norm, slim.separable_conv2d, 197 | nasnet_utils.factorized_reduction, 198 | nasnet_utils.global_avg_pool, 199 | nasnet_utils.get_channel_index, 200 | nasnet_utils.get_channel_dim], 201 | data_format=hparams.data_format): 202 | return _build_pnasnet_base( 203 | images, 204 | normal_cell=normal_cell, 205 | num_classes=num_classes, 206 | hparams=hparams, 207 | is_training=is_training, 208 | final_endpoint=final_endpoint) 209 | build_pnasnet_large.default_image_size = 331 210 | 211 | 212 | def build_pnasnet_mobile(images, 213 | num_classes, 214 | is_training=True, 215 | final_endpoint=None, 216 | config=None): 217 | """Build PNASNet Mobile model for the ImageNet Dataset.""" 218 | hparams = copy.deepcopy(config) if config else mobile_imagenet_config() 219 | # pylint: disable=protected-access 220 | nasnet._update_hparams(hparams, is_training) 221 | # pylint: enable=protected-access 222 | 223 | if tf.test.is_gpu_available() and hparams.data_format == 'NHWC': 224 | tf.logging.info('A GPU is available on the machine, consider using NCHW ' 225 | 'data format for increased speed on GPU.') 226 | 227 | if hparams.data_format == 'NCHW': 228 | images = tf.transpose(images, [0, 3, 1, 2]) 229 | 230 | # Calculate the total number of cells in the network. 231 | # There is no distinction between reduction and normal cells in PNAS so the 232 | # total number of cells is equal to the number normal cells plus the number 233 | # of stem cells (two by default). 234 | total_num_cells = hparams.num_cells + 2 235 | 236 | normal_cell = PNasNetNormalCell(hparams.num_conv_filters, 237 | hparams.drop_path_keep_prob, total_num_cells, 238 | hparams.total_training_steps, 239 | hparams.use_bounded_activation) 240 | with arg_scope( 241 | [slim.dropout, nasnet_utils.drop_path, slim.batch_norm], 242 | is_training=is_training): 243 | with arg_scope( 244 | [ 245 | slim.avg_pool2d, slim.max_pool2d, slim.conv2d, slim.batch_norm, 246 | slim.separable_conv2d, nasnet_utils.factorized_reduction, 247 | nasnet_utils.global_avg_pool, nasnet_utils.get_channel_index, 248 | nasnet_utils.get_channel_dim 249 | ], 250 | data_format=hparams.data_format): 251 | return _build_pnasnet_base( 252 | images, 253 | normal_cell=normal_cell, 254 | num_classes=num_classes, 255 | hparams=hparams, 256 | is_training=is_training, 257 | final_endpoint=final_endpoint) 258 | 259 | 260 | build_pnasnet_mobile.default_image_size = 224 261 | 262 | 263 | class PNasNetNormalCell(nasnet_utils.NasNetABaseCell): 264 | """PNASNet Normal Cell.""" 265 | 266 | def __init__(self, num_conv_filters, drop_path_keep_prob, total_num_cells, 267 | total_training_steps, use_bounded_activation=False): 268 | # Configuration for the PNASNet-5 model. 
269 | operations = [ 270 | 'separable_5x5_2', 'max_pool_3x3', 'separable_7x7_2', 'max_pool_3x3', 271 | 'separable_5x5_2', 'separable_3x3_2', 'separable_3x3_2', 'max_pool_3x3', 272 | 'separable_3x3_2', 'none' 273 | ] 274 | used_hiddenstates = [1, 1, 0, 0, 0, 0, 0] 275 | hiddenstate_indices = [1, 1, 0, 0, 0, 0, 4, 0, 1, 0] 276 | 277 | super(PNasNetNormalCell, self).__init__( 278 | num_conv_filters, operations, used_hiddenstates, hiddenstate_indices, 279 | drop_path_keep_prob, total_num_cells, total_training_steps, 280 | use_bounded_activation) 281 | -------------------------------------------------------------------------------- /nets/nets_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a factory for building various models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import functools 21 | 22 | import tensorflow as tf 23 | 24 | from nets import alexnet 25 | from nets import cifarnet 26 | from nets import inception 27 | from nets import lenet 28 | from nets import mobilenet_v1 29 | from nets import overfeat 30 | from nets import resnet_v1 31 | from nets import resnet_v2 32 | from nets import vgg 33 | from nets.mobilenet import mobilenet_v2 34 | from nets.nasnet import nasnet 35 | from nets.nasnet import pnasnet 36 | 37 | slim = tf.contrib.slim 38 | 39 | networks_map = {'alexnet_v2': alexnet.alexnet_v2, 40 | 'cifarnet': cifarnet.cifarnet, 41 | 'overfeat': overfeat.overfeat, 42 | 'vgg_a': vgg.vgg_a, 43 | 'vgg_16': vgg.vgg_16, 44 | 'vgg_19': vgg.vgg_19, 45 | 'inception_v1': inception.inception_v1, 46 | 'inception_v2': inception.inception_v2, 47 | 'inception_v3': inception.inception_v3, 48 | 'inception_v4': inception.inception_v4, 49 | 'inception_resnet_v2': inception.inception_resnet_v2, 50 | 'lenet': lenet.lenet, 51 | 'resnet_v1_50': resnet_v1.resnet_v1_50, 52 | 'resnet_v1_101': resnet_v1.resnet_v1_101, 53 | 'resnet_v1_152': resnet_v1.resnet_v1_152, 54 | 'resnet_v1_200': resnet_v1.resnet_v1_200, 55 | 'resnet_v2_50': resnet_v2.resnet_v2_50, 56 | 'resnet_v2_101': resnet_v2.resnet_v2_101, 57 | 'resnet_v2_152': resnet_v2.resnet_v2_152, 58 | 'resnet_v2_200': resnet_v2.resnet_v2_200, 59 | 'mobilenet_v1': mobilenet_v1.mobilenet_v1, 60 | 'mobilenet_v1_075': mobilenet_v1.mobilenet_v1_075, 61 | 'mobilenet_v1_050': mobilenet_v1.mobilenet_v1_050, 62 | 'mobilenet_v1_025': mobilenet_v1.mobilenet_v1_025, 63 | 'mobilenet_v2': mobilenet_v2.mobilenet, 64 | 'mobilenet_v2_140': mobilenet_v2.mobilenet_v2_140, 65 | 'mobilenet_v2_035': mobilenet_v2.mobilenet_v2_035, 66 | 'nasnet_cifar': nasnet.build_nasnet_cifar, 67 | 'nasnet_mobile': nasnet.build_nasnet_mobile, 68 | 'nasnet_large': nasnet.build_nasnet_large, 69 | 'pnasnet_large': 
pnasnet.build_pnasnet_large, 70 | 'pnasnet_mobile': pnasnet.build_pnasnet_mobile, 71 | } 72 | 73 | arg_scopes_map = {'alexnet_v2': alexnet.alexnet_v2_arg_scope, 74 | 'cifarnet': cifarnet.cifarnet_arg_scope, 75 | 'overfeat': overfeat.overfeat_arg_scope, 76 | 'vgg_a': vgg.vgg_arg_scope, 77 | 'vgg_16': vgg.vgg_arg_scope, 78 | 'vgg_19': vgg.vgg_arg_scope, 79 | 'inception_v1': inception.inception_v3_arg_scope, 80 | 'inception_v2': inception.inception_v3_arg_scope, 81 | 'inception_v3': inception.inception_v3_arg_scope, 82 | 'inception_v4': inception.inception_v4_arg_scope, 83 | 'inception_resnet_v2': 84 | inception.inception_resnet_v2_arg_scope, 85 | 'lenet': lenet.lenet_arg_scope, 86 | 'resnet_v1_50': resnet_v1.resnet_arg_scope, 87 | 'resnet_v1_101': resnet_v1.resnet_arg_scope, 88 | 'resnet_v1_152': resnet_v1.resnet_arg_scope, 89 | 'resnet_v1_200': resnet_v1.resnet_arg_scope, 90 | 'resnet_v2_50': resnet_v2.resnet_arg_scope, 91 | 'resnet_v2_101': resnet_v2.resnet_arg_scope, 92 | 'resnet_v2_152': resnet_v2.resnet_arg_scope, 93 | 'resnet_v2_200': resnet_v2.resnet_arg_scope, 94 | 'mobilenet_v1': mobilenet_v1.mobilenet_v1_arg_scope, 95 | 'mobilenet_v1_075': mobilenet_v1.mobilenet_v1_arg_scope, 96 | 'mobilenet_v1_050': mobilenet_v1.mobilenet_v1_arg_scope, 97 | 'mobilenet_v1_025': mobilenet_v1.mobilenet_v1_arg_scope, 98 | 'mobilenet_v2': mobilenet_v2.training_scope, 99 | 'mobilenet_v2_035': mobilenet_v2.training_scope, 100 | 'mobilenet_v2_140': mobilenet_v2.training_scope, 101 | 'nasnet_cifar': nasnet.nasnet_cifar_arg_scope, 102 | 'nasnet_mobile': nasnet.nasnet_mobile_arg_scope, 103 | 'nasnet_large': nasnet.nasnet_large_arg_scope, 104 | 'pnasnet_large': pnasnet.pnasnet_large_arg_scope, 105 | 'pnasnet_mobile': pnasnet.pnasnet_mobile_arg_scope, 106 | } 107 | 108 | 109 | def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False): 110 | """Returns a network_fn such as `logits, end_points = network_fn(images)`. 111 | 112 | Args: 113 | name: The name of the network. 114 | num_classes: The number of classes to use for classification. If 0 or None, 115 | the logits layer is omitted and its input features are returned instead. 116 | weight_decay: The l2 coefficient for the model weights. 117 | is_training: `True` if the model is being used for training and `False` 118 | otherwise. 119 | 120 | Returns: 121 | network_fn: A function that applies the model to a batch of images. It has 122 | the following signature: 123 | net, end_points = network_fn(images) 124 | The `images` input is a tensor of shape [batch_size, height, width, 3] 125 | with height = width = network_fn.default_image_size. (The permissibility 126 | and treatment of other sizes depends on the network_fn.) 127 | The returned `end_points` are a dictionary of intermediate activations. 128 | The returned `net` is the topmost layer, depending on `num_classes`: 129 | If `num_classes` was a non-zero integer, `net` is a logits tensor 130 | of shape [batch_size, num_classes]. 131 | If `num_classes` was 0 or `None`, `net` is a tensor with the input 132 | to the logits layer of shape [batch_size, 1, 1, num_features] or 133 | [batch_size, num_features]. Dropout has not been applied to this 134 | (even if the network's original classification does); it remains for 135 | the caller to do this or not. 136 | 137 | Raises: 138 | ValueError: If network `name` is not recognized. 
139 | """ 140 | if name not in networks_map: 141 | raise ValueError('Name of network unknown %s' % name) 142 | func = networks_map[name] 143 | @functools.wraps(func) 144 | def network_fn(images, **kwargs): 145 | arg_scope = arg_scopes_map[name](weight_decay=weight_decay) 146 | with slim.arg_scope(arg_scope): 147 | return func(images, num_classes, is_training=is_training, **kwargs) 148 | if hasattr(func, 'default_image_size'): 149 | network_fn.default_image_size = func.default_image_size 150 | 151 | return network_fn 152 | -------------------------------------------------------------------------------- /nets/nets_factory_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for slim.inception.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | 23 | import tensorflow as tf 24 | 25 | from nets import nets_factory 26 | 27 | 28 | class NetworksTest(tf.test.TestCase): 29 | 30 | def testGetNetworkFnFirstHalf(self): 31 | batch_size = 5 32 | num_classes = 1000 33 | for net in list(nets_factory.networks_map.keys())[:10]: 34 | with tf.Graph().as_default() as g, self.test_session(g): 35 | net_fn = nets_factory.get_network_fn(net, num_classes) 36 | # Most networks use 224 as their default_image_size 37 | image_size = getattr(net_fn, 'default_image_size', 224) 38 | inputs = tf.random_uniform((batch_size, image_size, image_size, 3)) 39 | logits, end_points = net_fn(inputs) 40 | self.assertTrue(isinstance(logits, tf.Tensor)) 41 | self.assertTrue(isinstance(end_points, dict)) 42 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 43 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 44 | 45 | def testGetNetworkFnSecondHalf(self): 46 | batch_size = 5 47 | num_classes = 1000 48 | for net in list(nets_factory.networks_map.keys())[10:]: 49 | with tf.Graph().as_default() as g, self.test_session(g): 50 | net_fn = nets_factory.get_network_fn(net, num_classes) 51 | # Most networks use 224 as their default_image_size 52 | image_size = getattr(net_fn, 'default_image_size', 224) 53 | inputs = tf.random_uniform((batch_size, image_size, image_size, 3)) 54 | logits, end_points = net_fn(inputs) 55 | self.assertTrue(isinstance(logits, tf.Tensor)) 56 | self.assertTrue(isinstance(end_points, dict)) 57 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 58 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 59 | 60 | if __name__ == '__main__': 61 | tf.test.main() 62 | -------------------------------------------------------------------------------- /nets/overfeat.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains the model definition for the OverFeat network. 16 | 17 | The definition for the network was obtained from: 18 | OverFeat: Integrated Recognition, Localization and Detection using 19 | Convolutional Networks 20 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and 21 | Yann LeCun, 2014 22 | http://arxiv.org/abs/1312.6229 23 | 24 | Usage: 25 | with slim.arg_scope(overfeat.overfeat_arg_scope()): 26 | outputs, end_points = overfeat.overfeat(inputs) 27 | 28 | @@overfeat 29 | """ 30 | from __future__ import absolute_import 31 | from __future__ import division 32 | from __future__ import print_function 33 | 34 | import tensorflow as tf 35 | 36 | slim = tf.contrib.slim 37 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 38 | 39 | 40 | def overfeat_arg_scope(weight_decay=0.0005): 41 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 42 | activation_fn=tf.nn.relu, 43 | weights_regularizer=slim.l2_regularizer(weight_decay), 44 | biases_initializer=tf.zeros_initializer()): 45 | with slim.arg_scope([slim.conv2d], padding='SAME'): 46 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: 47 | return arg_sc 48 | 49 | 50 | def overfeat(inputs, 51 | num_classes=1000, 52 | is_training=True, 53 | dropout_keep_prob=0.5, 54 | spatial_squeeze=True, 55 | scope='overfeat', 56 | global_pool=False): 57 | """Contains the model definition for the OverFeat network. 58 | 59 | The definition for the network was obtained from: 60 | OverFeat: Integrated Recognition, Localization and Detection using 61 | Convolutional Networks 62 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and 63 | Yann LeCun, 2014 64 | http://arxiv.org/abs/1312.6229 65 | 66 | Note: All the fully_connected layers have been transformed to conv2d layers. 67 | To use in classification mode, resize input to 231x231. To use in fully 68 | convolutional mode, set spatial_squeeze to false. 69 | 70 | Args: 71 | inputs: a tensor of size [batch_size, height, width, channels]. 72 | num_classes: number of predicted classes. If 0 or None, the logits layer is 73 | omitted and the input features to the logits layer are returned instead. 74 | is_training: whether or not the model is being trained. 75 | dropout_keep_prob: the probability that activations are kept in the dropout 76 | layers during training. 77 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 78 | outputs. Useful to remove unnecessary dimensions for classification. 79 | scope: Optional scope for the variables. 80 | global_pool: Optional boolean flag. If True, the input to the classification 81 | layer is avgpooled to size 1x1, for any input size. (This is not part 82 | of the original OverFeat.) 
83 | 84 | Returns: 85 | net: the output of the logits layer (if num_classes is a non-zero integer), 86 | or the non-dropped-out input to the logits layer (if num_classes is 0 or 87 | None). 88 | end_points: a dict of tensors with intermediate activations. 89 | """ 90 | with tf.variable_scope(scope, 'overfeat', [inputs]) as sc: 91 | end_points_collection = sc.original_name_scope + '_end_points' 92 | # Collect outputs for conv2d, fully_connected and max_pool2d 93 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 94 | outputs_collections=end_points_collection): 95 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', 96 | scope='conv1') 97 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 98 | net = slim.conv2d(net, 256, [5, 5], padding='VALID', scope='conv2') 99 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 100 | net = slim.conv2d(net, 512, [3, 3], scope='conv3') 101 | net = slim.conv2d(net, 1024, [3, 3], scope='conv4') 102 | net = slim.conv2d(net, 1024, [3, 3], scope='conv5') 103 | net = slim.max_pool2d(net, [2, 2], scope='pool5') 104 | 105 | # Use conv2d instead of fully_connected layers. 106 | with slim.arg_scope([slim.conv2d], 107 | weights_initializer=trunc_normal(0.005), 108 | biases_initializer=tf.constant_initializer(0.1)): 109 | net = slim.conv2d(net, 3072, [6, 6], padding='VALID', scope='fc6') 110 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 111 | scope='dropout6') 112 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 113 | # Convert end_points_collection into a end_point dict. 114 | end_points = slim.utils.convert_collection_to_dict( 115 | end_points_collection) 116 | if global_pool: 117 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool') 118 | end_points['global_pool'] = net 119 | if num_classes: 120 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 121 | scope='dropout7') 122 | net = slim.conv2d(net, num_classes, [1, 1], 123 | activation_fn=None, 124 | normalizer_fn=None, 125 | biases_initializer=tf.zeros_initializer(), 126 | scope='fc8') 127 | if spatial_squeeze: 128 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 129 | end_points[sc.name + '/fc8'] = net 130 | return net, end_points 131 | overfeat.default_image_size = 231 132 | -------------------------------------------------------------------------------- /nets/overfeat_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for slim.nets.overfeat.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | from nets import overfeat 23 | 24 | slim = tf.contrib.slim 25 | 26 | 27 | class OverFeatTest(tf.test.TestCase): 28 | 29 | def testBuild(self): 30 | batch_size = 5 31 | height, width = 231, 231 32 | num_classes = 1000 33 | with self.test_session(): 34 | inputs = tf.random_uniform((batch_size, height, width, 3)) 35 | logits, _ = overfeat.overfeat(inputs, num_classes) 36 | self.assertEquals(logits.op.name, 'overfeat/fc8/squeezed') 37 | self.assertListEqual(logits.get_shape().as_list(), 38 | [batch_size, num_classes]) 39 | 40 | def testFullyConvolutional(self): 41 | batch_size = 1 42 | height, width = 281, 281 43 | num_classes = 1000 44 | with self.test_session(): 45 | inputs = tf.random_uniform((batch_size, height, width, 3)) 46 | logits, _ = overfeat.overfeat(inputs, num_classes, spatial_squeeze=False) 47 | self.assertEquals(logits.op.name, 'overfeat/fc8/BiasAdd') 48 | self.assertListEqual(logits.get_shape().as_list(), 49 | [batch_size, 2, 2, num_classes]) 50 | 51 | def testGlobalPool(self): 52 | batch_size = 1 53 | height, width = 281, 281 54 | num_classes = 1000 55 | with self.test_session(): 56 | inputs = tf.random_uniform((batch_size, height, width, 3)) 57 | logits, _ = overfeat.overfeat(inputs, num_classes, spatial_squeeze=False, 58 | global_pool=True) 59 | self.assertEquals(logits.op.name, 'overfeat/fc8/BiasAdd') 60 | self.assertListEqual(logits.get_shape().as_list(), 61 | [batch_size, 1, 1, num_classes]) 62 | 63 | def testEndPoints(self): 64 | batch_size = 5 65 | height, width = 231, 231 66 | num_classes = 1000 67 | with self.test_session(): 68 | inputs = tf.random_uniform((batch_size, height, width, 3)) 69 | _, end_points = overfeat.overfeat(inputs, num_classes) 70 | expected_names = ['overfeat/conv1', 71 | 'overfeat/pool1', 72 | 'overfeat/conv2', 73 | 'overfeat/pool2', 74 | 'overfeat/conv3', 75 | 'overfeat/conv4', 76 | 'overfeat/conv5', 77 | 'overfeat/pool5', 78 | 'overfeat/fc6', 79 | 'overfeat/fc7', 80 | 'overfeat/fc8' 81 | ] 82 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 83 | 84 | def testNoClasses(self): 85 | batch_size = 5 86 | height, width = 231, 231 87 | num_classes = None 88 | with self.test_session(): 89 | inputs = tf.random_uniform((batch_size, height, width, 3)) 90 | net, end_points = overfeat.overfeat(inputs, num_classes) 91 | expected_names = ['overfeat/conv1', 92 | 'overfeat/pool1', 93 | 'overfeat/conv2', 94 | 'overfeat/pool2', 95 | 'overfeat/conv3', 96 | 'overfeat/conv4', 97 | 'overfeat/conv5', 98 | 'overfeat/pool5', 99 | 'overfeat/fc6', 100 | 'overfeat/fc7' 101 | ] 102 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 103 | self.assertTrue(net.op.name.startswith('overfeat/fc7')) 104 | 105 | def testModelVariables(self): 106 | batch_size = 5 107 | height, width = 231, 231 108 | num_classes = 1000 109 | with self.test_session(): 110 | inputs = tf.random_uniform((batch_size, height, width, 3)) 111 | overfeat.overfeat(inputs, num_classes) 112 | expected_names = ['overfeat/conv1/weights', 113 | 'overfeat/conv1/biases', 114 | 'overfeat/conv2/weights', 115 | 'overfeat/conv2/biases', 116 | 'overfeat/conv3/weights', 117 | 'overfeat/conv3/biases', 118 | 'overfeat/conv4/weights', 119 | 'overfeat/conv4/biases', 120 | 'overfeat/conv5/weights', 121 | 
'overfeat/conv5/biases', 122 | 'overfeat/fc6/weights', 123 | 'overfeat/fc6/biases', 124 | 'overfeat/fc7/weights', 125 | 'overfeat/fc7/biases', 126 | 'overfeat/fc8/weights', 127 | 'overfeat/fc8/biases', 128 | ] 129 | model_variables = [v.op.name for v in slim.get_model_variables()] 130 | self.assertSetEqual(set(model_variables), set(expected_names)) 131 | 132 | def testEvaluation(self): 133 | batch_size = 2 134 | height, width = 231, 231 135 | num_classes = 1000 136 | with self.test_session(): 137 | eval_inputs = tf.random_uniform((batch_size, height, width, 3)) 138 | logits, _ = overfeat.overfeat(eval_inputs, is_training=False) 139 | self.assertListEqual(logits.get_shape().as_list(), 140 | [batch_size, num_classes]) 141 | predictions = tf.argmax(logits, 1) 142 | self.assertListEqual(predictions.get_shape().as_list(), [batch_size]) 143 | 144 | def testTrainEvalWithReuse(self): 145 | train_batch_size = 2 146 | eval_batch_size = 1 147 | train_height, train_width = 231, 231 148 | eval_height, eval_width = 281, 281 149 | num_classes = 1000 150 | with self.test_session(): 151 | train_inputs = tf.random_uniform( 152 | (train_batch_size, train_height, train_width, 3)) 153 | logits, _ = overfeat.overfeat(train_inputs) 154 | self.assertListEqual(logits.get_shape().as_list(), 155 | [train_batch_size, num_classes]) 156 | tf.get_variable_scope().reuse_variables() 157 | eval_inputs = tf.random_uniform( 158 | (eval_batch_size, eval_height, eval_width, 3)) 159 | logits, _ = overfeat.overfeat(eval_inputs, is_training=False, 160 | spatial_squeeze=False) 161 | self.assertListEqual(logits.get_shape().as_list(), 162 | [eval_batch_size, 2, 2, num_classes]) 163 | logits = tf.reduce_mean(logits, [1, 2]) 164 | predictions = tf.argmax(logits, 1) 165 | self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size]) 166 | 167 | def testForward(self): 168 | batch_size = 1 169 | height, width = 231, 231 170 | with self.test_session() as sess: 171 | inputs = tf.random_uniform((batch_size, height, width, 3)) 172 | logits, _ = overfeat.overfeat(inputs) 173 | sess.run(tf.global_variables_initializer()) 174 | output = sess.run(logits) 175 | self.assertTrue(output.any()) 176 | 177 | if __name__ == '__main__': 178 | tf.test.main() 179 | -------------------------------------------------------------------------------- /nets/pix2pix.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================= 15 | """Implementation of the Image-to-Image Translation model. 16 | 17 | This network represents a port of the following work: 18 | 19 | Image-to-Image Translation with Conditional Adversarial Networks 20 | Phillip Isola, Jun-Yan Zhu, Tinghui Zhou and Alexei A. 
Efros 21 | Arxiv, 2017 22 | https://phillipi.github.io/pix2pix/ 23 | 24 | A reference implementation written in Lua can be found at: 25 | https://github.com/phillipi/pix2pix/blob/master/models.lua 26 | """ 27 | 28 | from __future__ import absolute_import 29 | from __future__ import division 30 | from __future__ import print_function 31 | 32 | import collections 33 | import functools 34 | 35 | import tensorflow as tf 36 | 37 | layers = tf.contrib.layers 38 | 39 | 40 | def pix2pix_arg_scope(): 41 | """Returns a default argument scope for isola_net. 42 | 43 | Returns: 44 | An arg scope. 45 | """ 46 | # These parameters come from the online port, which don't necessarily match 47 | # those in the paper. 48 | # TODO(nsilberman): confirm these values with Philip. 49 | instance_norm_params = { 50 | 'center': True, 51 | 'scale': True, 52 | 'epsilon': 0.00001, 53 | } 54 | 55 | with tf.contrib.framework.arg_scope( 56 | [layers.conv2d, layers.conv2d_transpose], 57 | normalizer_fn=layers.instance_norm, 58 | normalizer_params=instance_norm_params, 59 | weights_initializer=tf.random_normal_initializer(0, 0.02)) as sc: 60 | return sc 61 | 62 | 63 | def upsample(net, num_outputs, kernel_size, method='nn_upsample_conv'): 64 | """Upsamples the given inputs. 65 | 66 | Args: 67 | net: A `Tensor` of size [batch_size, height, width, filters]. 68 | num_outputs: The number of output filters. 69 | kernel_size: A list of 2 scalars or a 1x2 `Tensor` indicating the scale, 70 | relative to the inputs, of the output dimensions. For example, if kernel 71 | size is [2, 3], then the output height and width will be twice and three 72 | times the input size. 73 | method: The upsampling method. 74 | 75 | Returns: 76 | An `Tensor` which was upsampled using the specified method. 77 | 78 | Raises: 79 | ValueError: if `method` is not recognized. 80 | """ 81 | net_shape = tf.shape(net) 82 | height = net_shape[1] 83 | width = net_shape[2] 84 | 85 | if method == 'nn_upsample_conv': 86 | net = tf.image.resize_nearest_neighbor( 87 | net, [kernel_size[0] * height, kernel_size[1] * width]) 88 | net = layers.conv2d(net, num_outputs, [4, 4], activation_fn=None) 89 | elif method == 'conv2d_transpose': 90 | net = layers.conv2d_transpose( 91 | net, num_outputs, [4, 4], stride=kernel_size, activation_fn=None) 92 | else: 93 | raise ValueError('Unknown method: [%s]' % method) 94 | 95 | return net 96 | 97 | 98 | class Block( 99 | collections.namedtuple('Block', ['num_filters', 'decoder_keep_prob'])): 100 | """Represents a single block of encoder and decoder processing. 101 | 102 | The Image-to-Image translation paper works a bit differently than the original 103 | U-Net model. In particular, each block represents a single operation in the 104 | encoder which is concatenated with the corresponding decoder representation. 105 | A dropout layer follows the concatenation and convolution of the concatenated 106 | features. 107 | """ 108 | pass 109 | 110 | 111 | def _default_generator_blocks(): 112 | """Returns the default generator block definitions. 113 | 114 | Returns: 115 | A list of generator blocks. 116 | """ 117 | return [ 118 | Block(64, 0.5), 119 | Block(128, 0.5), 120 | Block(256, 0.5), 121 | Block(512, 0), 122 | Block(512, 0), 123 | Block(512, 0), 124 | Block(512, 0), 125 | ] 126 | 127 | 128 | def pix2pix_generator(net, 129 | num_outputs, 130 | blocks=None, 131 | upsample_method='nn_upsample_conv', 132 | is_training=False): # pylint: disable=unused-argument 133 | """Defines the network architecture. 
134 | 135 | Args: 136 | net: A `Tensor` of size [batch, height, width, channels]. Note that the 137 | generator currently requires square inputs (e.g. height=width). 138 | num_outputs: The number of (per-pixel) outputs. 139 | blocks: A list of generator blocks or `None` to use the default generator 140 | definition. 141 | upsample_method: The method of upsampling images, one of 'nn_upsample_conv' 142 | or 'conv2d_transpose' 143 | is_training: Whether or not we're in training or testing mode. 144 | 145 | Returns: 146 | A `Tensor` representing the model output and a dictionary of model end 147 | points. 148 | 149 | Raises: 150 | ValueError: if the input heights do not match their widths. 151 | """ 152 | end_points = {} 153 | 154 | blocks = blocks or _default_generator_blocks() 155 | 156 | input_size = net.get_shape().as_list() 157 | 158 | input_size[3] = num_outputs 159 | 160 | upsample_fn = functools.partial(upsample, method=upsample_method) 161 | 162 | encoder_activations = [] 163 | 164 | ########### 165 | # Encoder # 166 | ########### 167 | with tf.variable_scope('encoder'): 168 | with tf.contrib.framework.arg_scope( 169 | [layers.conv2d], 170 | kernel_size=[4, 4], 171 | stride=2, 172 | activation_fn=tf.nn.leaky_relu): 173 | 174 | for block_id, block in enumerate(blocks): 175 | # No normalizer for the first encoder layers as per 'Image-to-Image', 176 | # Section 5.1.1 177 | if block_id == 0: 178 | # First layer doesn't use normalizer_fn 179 | net = layers.conv2d(net, block.num_filters, normalizer_fn=None) 180 | elif block_id < len(blocks) - 1: 181 | net = layers.conv2d(net, block.num_filters) 182 | else: 183 | # Last layer doesn't use activation_fn nor normalizer_fn 184 | net = layers.conv2d( 185 | net, block.num_filters, activation_fn=None, normalizer_fn=None) 186 | 187 | encoder_activations.append(net) 188 | end_points['encoder%d' % block_id] = net 189 | 190 | ########### 191 | # Decoder # 192 | ########### 193 | reversed_blocks = list(blocks) 194 | reversed_blocks.reverse() 195 | 196 | with tf.variable_scope('decoder'): 197 | # Dropout is used at both train and test time as per 'Image-to-Image', 198 | # Section 2.1 (last paragraph). 199 | with tf.contrib.framework.arg_scope([layers.dropout], is_training=True): 200 | 201 | for block_id, block in enumerate(reversed_blocks): 202 | if block_id > 0: 203 | net = tf.concat([net, encoder_activations[-block_id - 1]], axis=3) 204 | 205 | # The Relu comes BEFORE the upsample op: 206 | net = tf.nn.relu(net) 207 | net = upsample_fn(net, block.num_filters, [2, 2]) 208 | if block.decoder_keep_prob > 0: 209 | net = layers.dropout(net, keep_prob=block.decoder_keep_prob) 210 | end_points['decoder%d' % block_id] = net 211 | 212 | with tf.variable_scope('output'): 213 | # Explicitly set the normalizer_fn to None to override any default value 214 | # that may come from an arg_scope, such as pix2pix_arg_scope. 215 | logits = layers.conv2d( 216 | net, num_outputs, [4, 4], activation_fn=None, normalizer_fn=None) 217 | logits = tf.reshape(logits, input_size) 218 | 219 | end_points['logits'] = logits 220 | end_points['predictions'] = tf.tanh(logits) 221 | 222 | return logits, end_points 223 | 224 | 225 | def pix2pix_discriminator(net, num_filters, padding=2, pad_mode='REFLECT', 226 | activation_fn=tf.nn.leaky_relu, is_training=False): 227 | """Creates the Image2Image Translation Discriminator. 228 | 229 | Args: 230 | net: A `Tensor` of size [batch_size, height, width, channels] representing 231 | the input. 
232 | num_filters: A list of the filters in the discriminator. The length of the 233 | list determines the number of layers in the discriminator. 234 | padding: Amount of reflection padding applied before each convolution. 235 | pad_mode: mode for tf.pad, one of "CONSTANT", "REFLECT", or "SYMMETRIC". 236 | activation_fn: activation fn for layers.conv2d. 237 | is_training: Whether or not the model is training or testing. 238 | 239 | Returns: 240 | A logits `Tensor` of size [batch_size, N, N, 1] where N is the number of 241 | 'patches' we're attempting to discriminate and a dictionary of model end 242 | points. 243 | """ 244 | del is_training 245 | end_points = {} 246 | 247 | num_layers = len(num_filters) 248 | 249 | def padded(net, scope): 250 | if padding: 251 | with tf.variable_scope(scope): 252 | spatial_pad = tf.constant( 253 | [[0, 0], [padding, padding], [padding, padding], [0, 0]], 254 | dtype=tf.int32) 255 | return tf.pad(net, spatial_pad, pad_mode) 256 | else: 257 | return net 258 | 259 | with tf.contrib.framework.arg_scope( 260 | [layers.conv2d], 261 | kernel_size=[4, 4], 262 | stride=2, 263 | padding='valid', 264 | activation_fn=activation_fn): 265 | 266 | # No normalization on the input layer. 267 | net = layers.conv2d( 268 | padded(net, 'conv0'), num_filters[0], normalizer_fn=None, scope='conv0') 269 | 270 | end_points['conv0'] = net 271 | 272 | for i in range(1, num_layers - 1): 273 | net = layers.conv2d( 274 | padded(net, 'conv%d' % i), num_filters[i], scope='conv%d' % i) 275 | end_points['conv%d' % i] = net 276 | 277 | # Stride 1 on the last layer. 278 | net = layers.conv2d( 279 | padded(net, 'conv%d' % (num_layers - 1)), 280 | num_filters[-1], 281 | stride=1, 282 | scope='conv%d' % (num_layers - 1)) 283 | end_points['conv%d' % (num_layers - 1)] = net 284 | 285 | # 1-dim logits, stride 1, no activation, no normalization. 286 | logits = layers.conv2d( 287 | padded(net, 'conv%d' % num_layers), 288 | 1, 289 | stride=1, 290 | activation_fn=None, 291 | normalizer_fn=None, 292 | scope='conv%d' % num_layers) 293 | end_points['logits'] = logits 294 | end_points['predictions'] = tf.sigmoid(logits) 295 | return logits, end_points 296 | -------------------------------------------------------------------------------- /nets/pix2pix_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================= 15 | """Tests for pix2pix.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | from nets import pix2pix 23 | 24 | 25 | class GeneratorTest(tf.test.TestCase): 26 | 27 | def _reduced_default_blocks(self): 28 | """Returns the default blocks, scaled down to make test run faster.""" 29 | return [pix2pix.Block(b.num_filters // 32, b.decoder_keep_prob) 30 | for b in pix2pix._default_generator_blocks()] 31 | 32 | def test_output_size_nn_upsample_conv(self): 33 | batch_size = 2 34 | height, width = 256, 256 35 | num_outputs = 4 36 | 37 | images = tf.ones((batch_size, height, width, 3)) 38 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 39 | logits, _ = pix2pix.pix2pix_generator( 40 | images, num_outputs, blocks=self._reduced_default_blocks(), 41 | upsample_method='nn_upsample_conv') 42 | 43 | with self.test_session() as session: 44 | session.run(tf.global_variables_initializer()) 45 | np_outputs = session.run(logits) 46 | self.assertListEqual([batch_size, height, width, num_outputs], 47 | list(np_outputs.shape)) 48 | 49 | def test_output_size_conv2d_transpose(self): 50 | batch_size = 2 51 | height, width = 256, 256 52 | num_outputs = 4 53 | 54 | images = tf.ones((batch_size, height, width, 3)) 55 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 56 | logits, _ = pix2pix.pix2pix_generator( 57 | images, num_outputs, blocks=self._reduced_default_blocks(), 58 | upsample_method='conv2d_transpose') 59 | 60 | with self.test_session() as session: 61 | session.run(tf.global_variables_initializer()) 62 | np_outputs = session.run(logits) 63 | self.assertListEqual([batch_size, height, width, num_outputs], 64 | list(np_outputs.shape)) 65 | 66 | def test_block_number_dictates_number_of_layers(self): 67 | batch_size = 2 68 | height, width = 256, 256 69 | num_outputs = 4 70 | 71 | images = tf.ones((batch_size, height, width, 3)) 72 | blocks = [ 73 | pix2pix.Block(64, 0.5), 74 | pix2pix.Block(128, 0), 75 | ] 76 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 77 | _, end_points = pix2pix.pix2pix_generator( 78 | images, num_outputs, blocks) 79 | 80 | num_encoder_layers = 0 81 | num_decoder_layers = 0 82 | for end_point in end_points: 83 | if end_point.startswith('encoder'): 84 | num_encoder_layers += 1 85 | elif end_point.startswith('decoder'): 86 | num_decoder_layers += 1 87 | 88 | self.assertEqual(num_encoder_layers, len(blocks)) 89 | self.assertEqual(num_decoder_layers, len(blocks)) 90 | 91 | 92 | class DiscriminatorTest(tf.test.TestCase): 93 | 94 | def _layer_output_size(self, input_size, kernel_size=4, stride=2, pad=2): 95 | return (input_size + pad * 2 - kernel_size) // stride + 1 96 | 97 | def test_four_layers(self): 98 | batch_size = 2 99 | input_size = 256 100 | 101 | output_size = self._layer_output_size(input_size) 102 | output_size = self._layer_output_size(output_size) 103 | output_size = self._layer_output_size(output_size) 104 | output_size = self._layer_output_size(output_size, stride=1) 105 | output_size = self._layer_output_size(output_size, stride=1) 106 | 107 | images = tf.ones((batch_size, input_size, input_size, 3)) 108 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 109 | logits, end_points = pix2pix.pix2pix_discriminator( 110 | images, num_filters=[64, 128, 256, 512]) 111 | self.assertListEqual([batch_size, 
output_size, output_size, 1], 112 | logits.shape.as_list()) 113 | self.assertListEqual([batch_size, output_size, output_size, 1], 114 | end_points['predictions'].shape.as_list()) 115 | 116 | def test_four_layers_no_padding(self): 117 | batch_size = 2 118 | input_size = 256 119 | 120 | output_size = self._layer_output_size(input_size, pad=0) 121 | output_size = self._layer_output_size(output_size, pad=0) 122 | output_size = self._layer_output_size(output_size, pad=0) 123 | output_size = self._layer_output_size(output_size, stride=1, pad=0) 124 | output_size = self._layer_output_size(output_size, stride=1, pad=0) 125 | 126 | images = tf.ones((batch_size, input_size, input_size, 3)) 127 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 128 | logits, end_points = pix2pix.pix2pix_discriminator( 129 | images, num_filters=[64, 128, 256, 512], padding=0) 130 | self.assertListEqual([batch_size, output_size, output_size, 1], 131 | logits.shape.as_list()) 132 | self.assertListEqual([batch_size, output_size, output_size, 1], 133 | end_points['predictions'].shape.as_list()) 134 | 135 | def test_four_layers_wrong_padding(self): 136 | batch_size = 2 137 | input_size = 256 138 | 139 | images = tf.ones((batch_size, input_size, input_size, 3)) 140 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 141 | with self.assertRaises(TypeError): 142 | pix2pix.pix2pix_discriminator( 143 | images, num_filters=[64, 128, 256, 512], padding=1.5) 144 | 145 | def test_four_layers_negative_padding(self): 146 | batch_size = 2 147 | input_size = 256 148 | 149 | images = tf.ones((batch_size, input_size, input_size, 3)) 150 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 151 | with self.assertRaises(ValueError): 152 | pix2pix.pix2pix_discriminator( 153 | images, num_filters=[64, 128, 256, 512], padding=-1) 154 | 155 | if __name__ == '__main__': 156 | tf.test.main() 157 | --------------------------------------------------------------------------------
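
A minimal usage sketch for the pix2pix generator and discriminator defined in nets/pix2pix.py above, assuming TensorFlow 1.x with tf.contrib available (the same environment pix2pix_test.py targets). The 256x256 input, the reduced block list, and the discriminator filter counts are illustrative choices, not values required by the code:

import tensorflow as tf
from nets import pix2pix

# Dummy square RGB batch; the generator requires height == width.
images = tf.ones((1, 256, 256, 3))

# A shortened block list keeps the example light; the module's default
# _default_generator_blocks() would be used if blocks were omitted.
blocks = [
    pix2pix.Block(16, 0.5),
    pix2pix.Block(32, 0.5),
    pix2pix.Block(64, 0),
]

with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()):
    # Generator: encoder/decoder U-Net producing a same-size 3-channel output.
    generated, gen_end_points = pix2pix.pix2pix_generator(
        images, num_outputs=3, blocks=blocks,
        upsample_method='nn_upsample_conv')
    # Discriminator: an N x N grid of per-patch logits over the input image.
    patch_logits, disc_end_points = pix2pix.pix2pix_discriminator(
        images, num_filters=[64, 128, 256, 512])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out, patches = sess.run([generated, patch_logits])
    print(out.shape, patches.shape)  # (1, 256, 256, 3) and (1, N, N, 1)

The returned end-point dictionaries expose the intermediate encoder, decoder, and conv activations, which is what the tests above count and check shapes against.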