├── adanet.png
├── resnet
│   ├── flip_gradient.py
│   ├── augmentors.py
│   ├── resnet_model.py
│   ├── adanet-resnet.py
│   ├── ilsvrcsemi.py
│   └── imagenet_utils.py
├── convlarge
│   ├── flip_gradient.py
│   ├── cnn.py
│   ├── dataset_utils.py
│   ├── dataset_utils_cifar.py
│   ├── test_svhn.py
│   ├── test_cifar.py
│   ├── layers.py
│   ├── svhn.py
│   ├── cifar10.py
│   ├── train_svhn.py
│   └── train_cifar.py
├── LICENSE
├── .gitignore
└── README.md
--------------------------------------------------------------------------------
/adanet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qinenergy/adanet/HEAD/adanet.png
--------------------------------------------------------------------------------
/resnet/flip_gradient.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | 
5 | import tensorflow as tf
6 | from tensorflow.python.framework import ops
7 | 
8 | 
9 | class FlipGradientBuilder(object):
10 |     def __init__(self):
11 |         self.num_calls = 0
12 | 
13 |     def __call__(self, x, l=1.0):
14 |         grad_name = "FlipGradient%d" % self.num_calls
15 |         @ops.RegisterGradient(grad_name)
16 |         def _flip_gradients(op, grad):
17 |             return [tf.negative(grad) * l]
18 | 
19 |         g = tf.get_default_graph()
20 |         with g.gradient_override_map({"Identity": grad_name}):
21 |             y = tf.identity(x)
22 | 
23 |         self.num_calls += 1
24 |         return y
25 | 
26 | flip_gradient = FlipGradientBuilder()
27 | 
--------------------------------------------------------------------------------
/convlarge/flip_gradient.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | 
5 | import tensorflow as tf
6 | from tensorflow.python.framework import ops
7 | 
8 | 
9 | class FlipGradientBuilder(object):
10 |     def __init__(self):
11 |         self.num_calls = 0
12 | 
13 |     def __call__(self, x, l=1.0):
14 |         grad_name = "FlipGradient%d" % self.num_calls
15 |         @ops.RegisterGradient(grad_name)
16 |         def _flip_gradients(op, grad):
17 |             return [tf.negative(grad) * l]
18 | 
19 |         g = tf.get_default_graph()
20 |         with g.gradient_override_map({"Identity": grad_name}):
21 |             y = tf.identity(x)
22 | 
23 |         self.num_calls += 1
24 |         return y
25 | 
26 | flip_gradient = FlipGradientBuilder()
27 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Qin Wang
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 | 
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | 
50 | # Translations
51 | *.mo
52 | *.pot
53 | 
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 | 
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 | 
63 | # Scrapy stuff:
64 | .scrapy
65 | 
66 | # Sphinx documentation
67 | docs/_build/
68 | 
69 | # PyBuilder
70 | target/
71 | 
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 | 
75 | # pyenv
76 | .python-version
77 | 
78 | # celery beat schedule file
79 | celerybeat-schedule
80 | 
81 | # SageMath parsed files
82 | *.sage.py
83 | 
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 | 
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 | 
97 | # Rope project settings
98 | .ropeproject
99 | 
100 | # mkdocs documentation
101 | /site
102 | 
103 | # mypy
104 | .mypy_cache/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ADA-Net
2 | TensorFlow implementation of:
3 | 
4 | [Semi-Supervised Learning by Augmented Distribution Alignment](https://arxiv.org/abs/1905.08171), Qin Wang, Wen Li, Luc Van Gool (ICCV 2019 Oral)
5 | 
6 | [Thesis: Distribution Aligned Semi-Supervised Learning](https://github.com/qinenergy/adanet/releases/download/0.1/QinThesis.pdf), August 2018, ETH Zurich
7 | ![](adanet.png)
8 | 
9 | 
10 | ### Requirements
11 | ```
12 | pip3 install tensorflow-gpu==1.13.1
13 | pip3 install tensorpack==0.9.1
14 | pip3 install scipy==1.2.1
15 | ```
16 | ### Train and Eval ADA-Net on ConvLarge
17 | #### Prepare dataset
18 | ```
19 | cd convlarge
20 | python3 cifar10.py --data_dir=./dataset/cifar10/ --dataset_seed=1
21 | ```
22 | 
23 | #### Train and Eval ADA-Net on Cifar10 ConvLarge
24 | 
25 | ```
26 | CUDA_VISIBLE_DEVICES=0 python3 train_cifar.py --dataset=cifar10 --data_dir=./dataset/cifar10/ --log_dir=./log/cifar10aug/ --num_epochs=2000 --epoch_decay_start=1500 --aug_flip=True --aug_trans=True --dataset_seed=1
27 | CUDA_VISIBLE_DEVICES=0 python3 test_cifar.py --dataset=cifar10 --data_dir=./dataset/cifar10/ --log_dir=./log/cifar10aug/ --dataset_seed=1
28 | ```
29 | 
30 | Here are the error rates we get using the above scripts:
31 | 
32 | | Data Split Seed 1 | Seed 2 | Seed 3 | Reported |
33 | | -------- | -------- | -------- | -------- |
34 | | 8.61% | 8.89% | 8.65% | 8.72±0.12% |
35 | 
36 | The dataset split seed controls the split between labeled and unlabeled samples. It does not affect the test set.
37 | 
38 | 
39 | ### Train and Eval ADA-Net on ImageNet ResNet
40 | Download our ImageNet labeled/unlabeled split from [this link](https://github.com/qinenergy/adanet/releases/download/0.1/imagenet_split.zip) and put the files in `./resnet`.
41 | 
42 | ```
43 | cd resnet
44 | python3 ./adanet-resnet.py --data /path/to/imagenet -d 18 --mode resnet --batch 256 --gpu 0,1,2,3
45 | ```
46 | 
47 | 
48 | ### Acknowledgement
49 | + ConvLarge code is based on Takeru Miyato's [tf implementation](https://github.com/takerum/vat_tf).
50 | + ResNet code is based on [Tensorpack](https://github.com/tensorpack/tensorpack/tree/master/examples/ResNet)'s supervised ImageNet training scripts.
51 | 
52 | ### License
53 | MIT
54 | 
55 | ### Citing this work
56 | ```
57 | @article{wang2019semi,
58 |   title={Semi-Supervised Learning by Augmented Distribution Alignment},
59 |   author={Wang, Qin and Li, Wen and Van Gool, Luc},
60 |   journal={arXiv preprint arXiv:1905.08171},
61 |   year={2019}
62 | }
63 | ```
64 | 
65 | ### Reproduce Figure 4
66 | To reproduce Figure 4 in the paper, we provide the plot script and extracted features [here](https://github.com/qinenergy/adanet/releases/download/0.1/Figure4-reproduce.zip). Note that we use sklearn==0.20.1 for the t-SNE computation.
67 | 
--------------------------------------------------------------------------------
/convlarge/cnn.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy
3 | import sys, os
4 | import layers as L
5 | 
6 | FLAGS = tf.app.flags.FLAGS
7 | tf.app.flags.DEFINE_float('keep_prob_hidden', 0.5, "dropout keep probability")
8 | tf.app.flags.DEFINE_float('lrelu_a', 0.1, "lrelu slope")
9 | tf.app.flags.DEFINE_boolean('top_bn', False, "")
10 | 
11 | 
12 | def logit(x, is_training=True, update_batch_stats=True, stochastic=True, seed=1234):
13 |     h = x
14 | 
15 |     rng = numpy.random.RandomState(seed)
16 | 
17 |     h = L.conv(h, ksize=3, stride=1, f_in=3, f_out=128, seed=rng.randint(123456), name='c1')
18 |     h = L.lrelu(L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b1'), FLAGS.lrelu_a)
19 |     h = L.conv(h, ksize=3, stride=1, f_in=128, f_out=128, seed=rng.randint(123456), name='c2')
20 |     h = L.lrelu(L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b2'), FLAGS.lrelu_a)
21 |     h = L.conv(h, ksize=3, stride=1, f_in=128, f_out=128, seed=rng.randint(123456), name='c3')
22 |     h = L.lrelu(L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b3'), FLAGS.lrelu_a)
23 | 
24 |     h = L.max_pool(h, ksize=2, stride=2)
25 |     h = tf.nn.dropout(h, keep_prob=FLAGS.keep_prob_hidden, seed=rng.randint(123456)) if stochastic else h
26 | 
27 |     h = L.conv(h, ksize=3, stride=1, f_in=128, f_out=256, seed=rng.randint(123456), name='c4')
28 |     h = L.lrelu(L.bn(h, 256, is_training=is_training, update_batch_stats=update_batch_stats, name='b4'), FLAGS.lrelu_a)
29 |     h = L.conv(h, ksize=3, stride=1, f_in=256, f_out=256, seed=rng.randint(123456), name='c5')
30 |     h = L.lrelu(L.bn(h, 256, is_training=is_training, update_batch_stats=update_batch_stats, name='b5'), FLAGS.lrelu_a)
31 |     h = L.conv(h, ksize=3, stride=1, f_in=256, f_out=256, seed=rng.randint(123456), name='c6')
32 |     h = L.lrelu(L.bn(h, 256, is_training=is_training,
update_batch_stats=update_batch_stats, name='b6'), FLAGS.lrelu_a) 33 | 34 | h = L.max_pool(h, ksize=2, stride=2) 35 | h = tf.nn.dropout(h, keep_prob=FLAGS.keep_prob_hidden, seed=rng.randint(123456)) if stochastic else h 36 | 37 | h = L.conv(h, ksize=3, stride=1, f_in=256, f_out=512, seed=rng.randint(123456), padding="VALID", name='c7') 38 | h = L.lrelu(L.bn(h, 512, is_training=is_training, update_batch_stats=update_batch_stats, name='b7'), FLAGS.lrelu_a) 39 | h = L.conv(h, ksize=1, stride=1, f_in=512, f_out=256, seed=rng.randint(123456), name='c8') 40 | h = L.lrelu(L.bn(h, 256, is_training=is_training, update_batch_stats=update_batch_stats, name='b8'), FLAGS.lrelu_a) 41 | h = L.conv(h, ksize=1, stride=1, f_in=256, f_out=128, seed=rng.randint(123456), name='c9') 42 | h = L.lrelu(L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b9'), FLAGS.lrelu_a) 43 | 44 | h1 = tf.reduce_mean(h, reduction_indices=[1, 2]) # Features to be aligned 45 | h = L.fc(h1, 128, 10, seed=rng.randint(123456), name='fc') 46 | 47 | if FLAGS.top_bn: 48 | h = L.bn(h, 10, is_training=is_training, 49 | update_batch_stats=update_batch_stats, name='bfc') 50 | 51 | return h, h1 52 | -------------------------------------------------------------------------------- /resnet/augmentors.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: augmentors.py 4 | 5 | import numpy as np 6 | import cv2 7 | from tensorpack.dataflow import imgaug 8 | 9 | 10 | __all__ = ['fbresnet_augmentor', 'inference_augmentor', 11 | 'resizeAndLighting_augmentor'] 12 | 13 | 14 | class GoogleNetResize(imgaug.ImageAugmentor): 15 | """ 16 | crop 8%~100% of the original image 17 | See `Going Deeper with Convolutions` by Google. 
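    Each attempt samples a target area uniformly in [crop_area_fraction, 1.0]
    of the source area and an aspect ratio in [aspect_ratio_low,
    aspect_ratio_high]; after 10 failed attempts, _augment falls back to
    resize-shortest-edge followed by a center crop.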
18 |     """
19 |     def __init__(self, crop_area_fraction=0.08,
20 |                  aspect_ratio_low=0.75, aspect_ratio_high=1.333,
21 |                  target_shape=224):
22 |         self._init(locals())
23 | 
24 |     def _augment(self, img, _):
25 |         h, w = img.shape[:2]
26 |         area = h * w
27 |         for _ in range(10):
28 |             targetArea = self.rng.uniform(self.crop_area_fraction, 1.0) * area
29 |             aspectR = self.rng.uniform(self.aspect_ratio_low, self.aspect_ratio_high)
30 |             ww = int(np.sqrt(targetArea * aspectR) + 0.5)
31 |             hh = int(np.sqrt(targetArea / aspectR) + 0.5)
32 |             if self.rng.uniform() < 0.5:
33 |                 ww, hh = hh, ww
34 |             if hh <= h and ww <= w:
35 |                 x1 = 0 if w == ww else self.rng.randint(0, w - ww)
36 |                 y1 = 0 if h == hh else self.rng.randint(0, h - hh)
37 |                 out = img[y1:y1 + hh, x1:x1 + ww]
38 |                 out = cv2.resize(out, (self.target_shape, self.target_shape), interpolation=cv2.INTER_CUBIC)
39 |                 return out
40 |         out = imgaug.ResizeShortestEdge(self.target_shape, interp=cv2.INTER_CUBIC).augment(img)
41 |         out = imgaug.CenterCrop(self.target_shape).augment(out)
42 |         return out
43 | 
44 | 
45 | def inference_augmentor():
46 |     return [
47 |         imgaug.ResizeShortestEdge(256, cv2.INTER_CUBIC),
48 |         imgaug.CenterCrop((224, 224))
49 |     ]
50 | 
51 | 
52 | def fbresnet_augmentor():
53 |     # assume BGR input
54 |     augmentors = [
55 |         GoogleNetResize(),
56 |         imgaug.RandomOrderAug(
57 |             [imgaug.BrightnessScale((0.6, 1.4), clip=False),
58 |              imgaug.Contrast((0.6, 1.4), clip=False),
59 |              imgaug.Saturation(0.4, rgb=False),
60 |              # rgb->bgr conversion for the constants copied from fb.resnet.torch
61 |              imgaug.Lighting(0.1,
62 |                              eigval=np.asarray(
63 |                                  [0.2175, 0.0188, 0.0045][::-1]) * 255.0,
64 |                              eigvec=np.array(
65 |                                  [[-0.5675, 0.7192, 0.4009],
66 |                                   [-0.5808, -0.0045, -0.8140],
67 |                                   [-0.5836, -0.6948, 0.4203]],
68 |                                  dtype='float32')[::-1, ::-1]
69 |                              )]),
70 |         imgaug.Flip(horiz=True),
71 |     ]
72 |     return augmentors
73 | 
74 | 
75 | def resizeAndLighting_augmentor():
76 |     # assume BGR input
77 |     augmentors = [
78 |         GoogleNetResize(),
79 |         imgaug.Lighting(0.1,
80 |                         eigval=np.asarray(
81 |                             [0.2175, 0.0188, 0.0045][::-1]) * 255.0,
82 |                         eigvec=np.array(
83 |                             [[-0.5675, 0.7192, 0.4009],
84 |                              [-0.5808, -0.0045, -0.8140],
85 |                              [-0.5836, -0.6948, 0.4203]],
86 |                             dtype='float32')[::-1, ::-1]),
87 |         imgaug.Flip(horiz=True),
88 |     ]
89 |     return augmentors
90 | 
91 | 
92 | def resizeOnly_augmentor():
93 |     # assume BGR input
94 |     augmentors = [
95 |         GoogleNetResize(),
96 |         imgaug.Lighting(0.1,
97 |                         eigval=np.asarray(
98 |                             [0.2175, 0.0188, 0.0045][::-1]) * 255.0,
99 |                         eigvec=np.array(
100 |                             [[-0.5675, 0.7192, 0.4009],
101 |                              [-0.5808, -0.0045, -0.8140],
102 |                              [-0.5836, -0.6948, 0.4203]],
103 |                             dtype='float32')[::-1, ::-1]),
104 |         imgaug.Flip(horiz=True),
105 |     ]
106 |     return augmentors
107 | 
--------------------------------------------------------------------------------
/convlarge/dataset_utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import os, sys, pickle
3 | import numpy as np
4 | from scipy import linalg
5 | 
6 | FLAGS = tf.app.flags.FLAGS
7 | tf.app.flags.DEFINE_bool('aug_trans', False, "")
8 | tf.app.flags.DEFINE_bool('aug_flip', False, "")
9 | 
10 | def unpickle(file):
11 |     fp = open(file, 'rb')
12 |     if sys.version_info.major == 2:
13 |         data = pickle.load(fp)
14 |     elif sys.version_info.major == 3:
15 |         data = pickle.load(fp, encoding='latin-1')
16 |     fp.close()
17 |     return data
18 | 
19 | 
20 | def ZCA(data, reg=1e-6):
21 |     mean = np.mean(data, axis=0)
22 |     mdata = data - mean
23 |     sigma = np.dot(mdata.T, mdata) /
mdata.shape[0] 24 | U, S, V = linalg.svd(sigma) 25 | components = np.dot(np.dot(U, np.diag(1 / np.sqrt(S) + reg)), U.T) 26 | whiten = np.dot(data - mean, components.T) 27 | return components, mean, whiten 28 | 29 | 30 | def _int64_feature(value): 31 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 32 | 33 | 34 | def _bytes_feature(value): 35 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 36 | 37 | 38 | def convert_images_and_labels(images, labels, filepath): 39 | num_examples = labels.shape[0] 40 | if images.shape[0] != num_examples: 41 | raise ValueError("Images size %d does not match label size %d." % 42 | (images.shape[0], num_examples)) 43 | print('Writing', filepath) 44 | writer = tf.python_io.TFRecordWriter(filepath) 45 | for index in range(num_examples): 46 | image = images[index].tolist() 47 | image_feature = tf.train.Feature(float_list=tf.train.FloatList(value=image)) 48 | example = tf.train.Example(features=tf.train.Features(feature={ 49 | 'height': _int64_feature(32), 50 | 'width': _int64_feature(32), 51 | 'depth': _int64_feature(3), 52 | 'label': _int64_feature(int(labels[index])), 53 | 'image': image_feature})) 54 | writer.write(example.SerializeToString()) 55 | writer.close() 56 | 57 | 58 | def read(filename_queue): 59 | reader = tf.TFRecordReader() 60 | _, serialized_example = reader.read(filename_queue) 61 | features = tf.parse_single_example( 62 | serialized_example, 63 | # Defaults are not specified since both keys are required. 64 | features={ 65 | 'image': tf.FixedLenFeature([3072], tf.float32), 66 | 'label': tf.FixedLenFeature([], tf.int64), 67 | }) 68 | 69 | # Convert label from a scalar uint8 tensor to an int32 scalar. 70 | image = features['image'] 71 | image = tf.reshape(image, [32, 32, 3]) 72 | label = tf.one_hot(tf.cast(features['label'], tf.int32), 10) 73 | return image, label 74 | 75 | 76 | def generate_batch( 77 | example, 78 | min_queue_examples, 79 | batch_size, shuffle): 80 | """ 81 | Arg: 82 | list of tensors. 
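        min_queue_examples: minimum number of examples kept queued
            (the min_after_dequeue for shuffling).
        batch_size: number of examples per returned batch.
        shuffle: use tf.train.shuffle_batch if True, else tf.train.batch.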
83 | """ 84 | num_preprocess_threads = 1 85 | 86 | if shuffle: 87 | ret = tf.train.shuffle_batch( 88 | example, 89 | batch_size=batch_size, 90 | num_threads=num_preprocess_threads, 91 | capacity=min_queue_examples + 3 * batch_size, 92 | min_after_dequeue=min_queue_examples) 93 | else: 94 | ret = tf.train.batch( 95 | example, 96 | batch_size=batch_size, 97 | num_threads=num_preprocess_threads, 98 | allow_smaller_final_batch=True, 99 | capacity=min_queue_examples + 3 * batch_size) 100 | 101 | return ret 102 | 103 | 104 | def transform(image): 105 | image = tf.reshape(image, [32, 32, 3]) 106 | if FLAGS.aug_trans or FLAGS.aug_flip: 107 | print("augmentation") 108 | if FLAGS.aug_trans: 109 | image = tf.pad(image, [[2, 2], [2, 2], [0, 0]]) 110 | image = tf.random_crop(image, [32, 32, 3]) 111 | if FLAGS.aug_flip: 112 | image = tf.image.random_flip_left_right(image) 113 | return image 114 | 115 | 116 | def generate_filename_queue(filenames, data_dir, num_epochs=None): 117 | print("filenames in queue:", filenames) 118 | for i in range(len(filenames)): 119 | filenames[i] = os.path.join(data_dir, filenames[i]) 120 | return tf.train.string_input_producer(filenames, num_epochs=num_epochs) 121 | 122 | 123 | -------------------------------------------------------------------------------- /convlarge/dataset_utils_cifar.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os, sys, pickle 3 | import numpy as np 4 | from scipy import linalg 5 | 6 | FLAGS = tf.app.flags.FLAGS 7 | tf.app.flags.DEFINE_bool('aug_trans', False, "") 8 | tf.app.flags.DEFINE_bool('aug_flip', False, "") 9 | 10 | def unpickle(file): 11 | fp = open(file, 'rb') 12 | if sys.version_info.major == 2: 13 | data = pickle.load(fp) 14 | elif sys.version_info.major == 3: 15 | data = pickle.load(fp, encoding='latin-1') 16 | fp.close() 17 | return data 18 | 19 | 20 | def ZCA(data, reg=1e-6): 21 | mean = np.mean(data, axis=0) 22 | mdata = data - mean 23 | sigma = np.dot(mdata.T, mdata) / mdata.shape[0] 24 | U, S, V = linalg.svd(sigma) 25 | components = np.dot(np.dot(U, np.diag(1 / np.sqrt(S) + reg)), U.T) 26 | whiten = np.dot(data - mean, components.T) 27 | return components, mean, whiten 28 | 29 | 30 | def _int64_feature(value): 31 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 32 | 33 | 34 | def _bytes_feature(value): 35 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 36 | 37 | 38 | def convert_images_and_labels(images, labels, filepath): 39 | num_examples = labels.shape[0] 40 | if images.shape[0] != num_examples: 41 | raise ValueError("Images size %d does not match label size %d." % 42 | (images.shape[0], num_examples)) 43 | print('Writing', filepath) 44 | writer = tf.python_io.TFRecordWriter(filepath) 45 | for index in range(num_examples): 46 | image = images[index].tolist() 47 | image_feature = tf.train.Feature(float_list=tf.train.FloatList(value=image)) 48 | example = tf.train.Example(features=tf.train.Features(feature={ 49 | 'height': _int64_feature(32), 50 | 'width': _int64_feature(32), 51 | 'depth': _int64_feature(3), 52 | 'label': _int64_feature(int(labels[index])), 53 | 'image': image_feature})) 54 | writer.write(example.SerializeToString()) 55 | writer.close() 56 | 57 | 58 | def read(filename_queue): 59 | reader = tf.TFRecordReader() 60 | _, serialized_example = reader.read(filename_queue) 61 | features = tf.parse_single_example( 62 | serialized_example, 63 | # Defaults are not specified since both keys are required. 
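        # Each record holds a flattened 32x32x3 float image (3072 values) and
        # an int64 class label, as written by convert_images_and_labels above.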
64 | features={ 65 | 'image': tf.FixedLenFeature([3072], tf.float32), 66 | 'label': tf.FixedLenFeature([], tf.int64), 67 | }) 68 | 69 | # Convert label from a scalar uint8 tensor to an int32 scalar. 70 | image = features['image'] 71 | image = tf.reshape(image, [32, 32, 3]) 72 | label = tf.one_hot(tf.cast(features['label'], tf.int32), 10) 73 | return image, label 74 | 75 | 76 | def generate_batch( 77 | example, 78 | min_queue_examples, 79 | batch_size, shuffle): 80 | """ 81 | Arg: 82 | list of tensors. 83 | """ 84 | num_preprocess_threads = 1 85 | 86 | if shuffle: 87 | ret = tf.train.shuffle_batch( 88 | example, 89 | batch_size=batch_size, 90 | num_threads=num_preprocess_threads, 91 | capacity=min_queue_examples + 3 * batch_size, 92 | min_after_dequeue=min_queue_examples) 93 | else: 94 | ret = tf.train.batch( 95 | example, 96 | batch_size=batch_size, 97 | num_threads=num_preprocess_threads, 98 | allow_smaller_final_batch=True, 99 | capacity=min_queue_examples + 3 * batch_size) 100 | 101 | return ret 102 | 103 | 104 | def transform(image): 105 | image = tf.reshape(image, [32, 32, 3]) 106 | if FLAGS.aug_trans or FLAGS.aug_flip: 107 | print("augmentation") 108 | if FLAGS.aug_trans: 109 | image = tf.pad(image, [[4, 4], [4, 4], [0, 0]]) 110 | image = tf.random_crop(image, [32, 32, 3]) 111 | if FLAGS.aug_flip: 112 | image = tf.image.random_flip_left_right(image) 113 | return image 114 | 115 | 116 | def generate_filename_queue(filenames, data_dir, num_epochs=None): 117 | print("filenames in queue:", filenames) 118 | for i in range(len(filenames)): 119 | filenames[i] = os.path.join(data_dir, filenames[i]) 120 | return tf.train.string_input_producer(filenames, num_epochs=num_epochs) 121 | 122 | 123 | -------------------------------------------------------------------------------- /convlarge/test_svhn.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy 4 | import tensorflow as tf 5 | 6 | import layers as L 7 | import cnn 8 | 9 | FLAGS = tf.app.flags.FLAGS 10 | 11 | tf.app.flags.DEFINE_string('device', '/gpu:0', "device") 12 | 13 | tf.app.flags.DEFINE_string('dataset', 'cifar10', "{cifar10, svhn}") 14 | 15 | tf.app.flags.DEFINE_string('log_dir', "", "log_dir") 16 | tf.app.flags.DEFINE_bool('validation', False, "") 17 | 18 | tf.app.flags.DEFINE_integer('finetune_batch_size', 100, "the number of examples in a batch") 19 | tf.app.flags.DEFINE_integer('finetune_iter', 100, "the number of iteration for finetuning of BN stats") 20 | tf.app.flags.DEFINE_integer('eval_batch_size', 500, "the number of examples in a batch") 21 | 22 | 23 | from svhn import inputs, unlabeled_inputs 24 | 25 | def logit(x, is_training=True, update_batch_stats=True, stochastic=True, seed=1234): 26 | return cnn.logit(x, is_training=is_training, 27 | update_batch_stats=update_batch_stats, 28 | stochastic=stochastic, 29 | seed=seed)[0] 30 | 31 | 32 | def forward(x, is_training=True, update_batch_stats=True, seed=1234): 33 | if is_training: 34 | return logit(x, is_training=True, 35 | update_batch_stats=update_batch_stats, 36 | stochastic=True, seed=seed) 37 | else: 38 | return logit(x, is_training=False, 39 | update_batch_stats=update_batch_stats, 40 | stochastic=False, seed=seed) 41 | 42 | 43 | def build_finetune_graph(x): 44 | logit = forward(x, is_training=True, update_batch_stats=True) 45 | with tf.control_dependencies([logit]): 46 | finetune_op = tf.no_op() 47 | return finetune_op 48 | 49 | 50 | def build_eval_graph(x, y): 51 | logit = forward(x, 
is_training=False, update_batch_stats=False) 52 | n_corrects = tf.cast(tf.equal(tf.argmax(logit, 1), tf.argmax(y,1)), tf.int32) 53 | return tf.reduce_sum(n_corrects), tf.shape(n_corrects)[0] 54 | 55 | 56 | def main(_): 57 | with tf.Graph().as_default() as g: 58 | with tf.device("/cpu:0"): 59 | images_eval_train, _ = inputs(batch_size=FLAGS.finetune_batch_size, 60 | validation=FLAGS.validation, 61 | shuffle=True) 62 | images_eval_test, labels_eval_test = inputs(batch_size=FLAGS.eval_batch_size, 63 | train=False, 64 | validation=FLAGS.validation, 65 | shuffle=False, num_epochs=1) 66 | 67 | with tf.device(FLAGS.device): 68 | with tf.variable_scope("CNN") as scope: 69 | # Build graph of finetuning BN stats 70 | finetune_op = build_finetune_graph(images_eval_train) 71 | scope.reuse_variables() 72 | # Build eval graph 73 | n_correct, m = build_eval_graph(images_eval_test, labels_eval_test) 74 | 75 | init_op = tf.global_variables_initializer() 76 | saver = tf.train.Saver(tf.global_variables()) 77 | sess = tf.Session() 78 | sess.run(init_op) 79 | ckpt = tf.train.get_checkpoint_state(FLAGS.log_dir) 80 | print("Checkpoints:", ckpt) 81 | if ckpt and ckpt.model_checkpoint_path: 82 | saver.restore(sess, ckpt.model_checkpoint_path) 83 | sess.run(tf.local_variables_initializer()) 84 | coord = tf.train.Coordinator() 85 | tf.train.start_queue_runners(sess=sess, coord=coord) 86 | print("Finetuning...") 87 | for _ in range(FLAGS.finetune_iter): 88 | sess.run(finetune_op) 89 | 90 | sum_correct_examples= 0 91 | sum_m = 0 92 | try: 93 | while not coord.should_stop(): 94 | _n_correct, _m = sess.run([n_correct, m]) 95 | sum_correct_examples += _n_correct 96 | sum_m += _m 97 | except tf.errors.OutOfRangeError: 98 | print('Done evaluation -- epoch limit reached') 99 | finally: 100 | # When done, ask the threads to stop. 
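        # The eval input pipeline was created with num_epochs=1, so a full
        # pass over the test set ends in the OutOfRangeError caught above.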
101 | coord.request_stop() 102 | print("Test: num_test_examples:{}, num_correct_examples:{}, accuracy:{}".format( 103 | sum_m, sum_correct_examples, sum_correct_examples/float(sum_m))) 104 | 105 | 106 | if __name__ == "__main__": 107 | tf.app.run() 108 | -------------------------------------------------------------------------------- /convlarge/test_cifar.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy 4 | import tensorflow as tf 5 | 6 | import layers as L 7 | import cnn 8 | 9 | FLAGS = tf.app.flags.FLAGS 10 | 11 | tf.app.flags.DEFINE_string('device', '/gpu:0', "device") 12 | 13 | tf.app.flags.DEFINE_string('dataset', 'cifar10', "{cifar10, svhn}") 14 | 15 | tf.app.flags.DEFINE_string('log_dir', "", "log_dir") 16 | tf.app.flags.DEFINE_bool('validation', False, "") 17 | 18 | tf.app.flags.DEFINE_integer('finetune_batch_size', 100, "the number of examples in a batch") 19 | tf.app.flags.DEFINE_integer('finetune_iter', 100, "the number of iteration for finetuning of BN stats") 20 | tf.app.flags.DEFINE_integer('eval_batch_size', 500, "the number of examples in a batch") 21 | 22 | 23 | from cifar10 import inputs, unlabeled_inputs 24 | 25 | def logit(x, is_training=True, update_batch_stats=True, stochastic=True, seed=1234): 26 | return cnn.logit(x, is_training=is_training, 27 | update_batch_stats=update_batch_stats, 28 | stochastic=stochastic, 29 | seed=seed)[0] 30 | 31 | 32 | def forward(x, is_training=True, update_batch_stats=True, seed=1234): 33 | if is_training: 34 | return logit(x, is_training=True, 35 | update_batch_stats=update_batch_stats, 36 | stochastic=True, seed=seed) 37 | else: 38 | return logit(x, is_training=False, 39 | update_batch_stats=update_batch_stats, 40 | stochastic=False, seed=seed) 41 | 42 | 43 | def build_finetune_graph(x): 44 | logit = forward(x, is_training=True, update_batch_stats=True) 45 | with tf.control_dependencies([logit]): 46 | finetune_op = tf.no_op() 47 | return finetune_op 48 | 49 | 50 | def build_eval_graph(x, y): 51 | logit = forward(x, is_training=False, update_batch_stats=False) 52 | n_corrects = tf.cast(tf.equal(tf.argmax(logit, 1), tf.argmax(y,1)), tf.int32) 53 | return tf.reduce_sum(n_corrects), tf.shape(n_corrects)[0] 54 | 55 | 56 | def main(_): 57 | with tf.Graph().as_default() as g: 58 | with tf.device("/cpu:0"): 59 | images_eval_train, _ = inputs(batch_size=FLAGS.finetune_batch_size, 60 | validation=FLAGS.validation, 61 | shuffle=True) 62 | images_eval_test, labels_eval_test = inputs(batch_size=FLAGS.eval_batch_size, 63 | train=False, 64 | validation=FLAGS.validation, 65 | shuffle=False, num_epochs=1) 66 | 67 | with tf.device(FLAGS.device): 68 | with tf.variable_scope("CNN") as scope: 69 | # Build graph of finetuning BN stats 70 | finetune_op = build_finetune_graph(images_eval_train) 71 | scope.reuse_variables() 72 | # Build eval graph 73 | n_correct, m = build_eval_graph(images_eval_test, labels_eval_test) 74 | 75 | init_op = tf.global_variables_initializer() 76 | saver = tf.train.Saver(tf.global_variables()) 77 | sess = tf.Session() 78 | sess.run(init_op) 79 | ckpt = tf.train.get_checkpoint_state(FLAGS.log_dir) 80 | print("Checkpoints:", ckpt) 81 | if ckpt and ckpt.model_checkpoint_path: 82 | saver.restore(sess, ckpt.model_checkpoint_path) 83 | sess.run(tf.local_variables_initializer()) 84 | coord = tf.train.Coordinator() 85 | tf.train.start_queue_runners(sess=sess, coord=coord) 86 | print("Finetuning...") 87 | for _ in range(FLAGS.finetune_iter): 88 | sess.run(finetune_op) 
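        # These training-mode forward passes refresh the BN moving averages
        # (update_batch_stats=True) used by the eval graph; no weights change.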
89 | 90 | sum_correct_examples= 0 91 | sum_m = 0 92 | try: 93 | while not coord.should_stop(): 94 | _n_correct, _m = sess.run([n_correct, m]) 95 | sum_correct_examples += _n_correct 96 | sum_m += _m 97 | except tf.errors.OutOfRangeError: 98 | print('Done evaluation -- epoch limit reached') 99 | finally: 100 | # When done, ask the threads to stop. 101 | coord.request_stop() 102 | print("Test: num_test_examples:{}, num_correct_examples:{}, accuracy:{}".format( 103 | sum_m, sum_correct_examples, sum_correct_examples/float(sum_m))) 104 | 105 | 106 | if __name__ == "__main__": 107 | tf.app.run() 108 | -------------------------------------------------------------------------------- /convlarge/layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy 3 | import sys, os 4 | 5 | 6 | FLAGS = tf.app.flags.FLAGS 7 | tf.app.flags.DEFINE_float('bn_stats_decay_factor', 0.99, 8 | "moving average decay factor for stats on batch normalization") 9 | 10 | 11 | def lrelu(x, a=0.1): 12 | if a < 1e-16: 13 | return tf.nn.relu(x) 14 | else: 15 | return tf.maximum(x, a * x) 16 | 17 | 18 | def bn(x, dim, is_training=True, update_batch_stats=True, collections=None, name="bn"): 19 | params_shape = (dim,) 20 | n = tf.to_float(tf.reduce_prod(tf.shape(x)[:-1])) 21 | axis = list(range(int(tf.shape(x).get_shape().as_list()[0]) - 1)) 22 | mean = tf.reduce_mean(x, axis) 23 | var = tf.reduce_mean(tf.pow(x - mean, 2.0), axis) 24 | avg_mean = tf.get_variable( 25 | name=name + "_mean", 26 | shape=params_shape, 27 | initializer=tf.constant_initializer(0.0), 28 | collections=collections, 29 | trainable=False 30 | ) 31 | 32 | avg_var = tf.get_variable( 33 | name=name + "_var", 34 | shape=params_shape, 35 | initializer=tf.constant_initializer(1.0), 36 | collections=collections, 37 | trainable=False 38 | ) 39 | 40 | gamma = tf.get_variable( 41 | name=name + "_gamma", 42 | shape=params_shape, 43 | initializer=tf.constant_initializer(1.0), 44 | collections=collections 45 | ) 46 | 47 | beta = tf.get_variable( 48 | name=name + "_beta", 49 | shape=params_shape, 50 | initializer=tf.constant_initializer(0.0), 51 | collections=collections, 52 | ) 53 | 54 | if is_training: 55 | avg_mean_assign_op = tf.no_op() 56 | avg_var_assign_op = tf.no_op() 57 | if update_batch_stats: 58 | avg_mean_assign_op = tf.assign( 59 | avg_mean, 60 | FLAGS.bn_stats_decay_factor * avg_mean + (1 - FLAGS.bn_stats_decay_factor) * mean) 61 | avg_var_assign_op = tf.assign( 62 | avg_var, 63 | FLAGS.bn_stats_decay_factor * avg_var + (n / (n - 1)) 64 | * (1 - FLAGS.bn_stats_decay_factor) * var) 65 | 66 | with tf.control_dependencies([avg_mean_assign_op, avg_var_assign_op]): 67 | z = (x - mean) / tf.sqrt(1e-6 + var) 68 | else: 69 | z = (x - avg_mean) / tf.sqrt(1e-6 + avg_var) 70 | 71 | return gamma * z + beta 72 | 73 | 74 | def fc(x, dim_in, dim_out, seed=None, name='fc'): 75 | num_units_in = dim_in 76 | num_units_out = dim_out 77 | weights_initializer = tf.contrib.layers.variance_scaling_initializer(seed=seed) 78 | 79 | weights = tf.get_variable(name + '_W', 80 | shape=[num_units_in, num_units_out], 81 | initializer=weights_initializer) 82 | biases = tf.get_variable(name + '_b', 83 | shape=[num_units_out], 84 | initializer=tf.constant_initializer(0.0)) 85 | x = tf.nn.xw_plus_b(x, weights, biases) 86 | return x 87 | 88 | 89 | def conv(x, ksize, stride, f_in, f_out, padding='SAME', use_bias=False, seed=None, name='conv'): 90 | shape = [ksize, ksize, f_in, f_out] 91 | initializer = 
tf.contrib.layers.variance_scaling_initializer(seed=seed) 92 | weights = tf.get_variable(name + '_W', 93 | shape=shape, 94 | dtype='float', 95 | initializer=initializer) 96 | x = tf.nn.conv2d(x, weights, [1, stride, stride, 1], padding=padding) 97 | 98 | if use_bias: 99 | bias = tf.get_variable(name + '_b', 100 | shape=[f_out], 101 | dtype='float', 102 | initializer=tf.zeros_initializer) 103 | return tf.nn.bias_add(x, bias) 104 | else: 105 | return x 106 | 107 | 108 | def avg_pool(x, ksize=2, stride=2): 109 | return tf.nn.avg_pool(x, 110 | ksize=[1, ksize, ksize, 1], 111 | strides=[1, stride, stride, 1], 112 | padding='SAME') 113 | 114 | 115 | def max_pool(x, ksize=2, stride=2): 116 | return tf.nn.max_pool(x, 117 | ksize=[1, ksize, ksize, 1], 118 | strides=[1, stride, stride, 1], 119 | padding='SAME') 120 | 121 | 122 | def ce_loss(logit, y): 123 | return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logit, labels=y)) 124 | 125 | 126 | def accuracy(logit, y): 127 | pred = tf.argmax(logit, 1) 128 | true = tf.argmax(y, 1) 129 | return tf.reduce_mean(tf.to_float(tf.equal(pred, true))) 130 | 131 | 132 | def logsoftmax(x): 133 | xdev = x - tf.reduce_max(x, 1, keep_dims=True) 134 | lsm = xdev - tf.log(tf.reduce_sum(tf.exp(xdev), 1, keep_dims=True)) 135 | return lsm 136 | 137 | 138 | def kl_divergence_with_logit(q_logit, p_logit): 139 | q = tf.nn.softmax(q_logit) 140 | qlogq = tf.reduce_mean(tf.reduce_sum(q * logsoftmax(q_logit), 1)) 141 | qlogp = tf.reduce_mean(tf.reduce_sum(q * logsoftmax(p_logit), 1)) 142 | return qlogq - qlogp 143 | 144 | 145 | def entropy_y_x(logit): 146 | p = tf.nn.softmax(logit) 147 | return -tf.reduce_mean(tf.reduce_sum(p * logsoftmax(logit), 1)) 148 | -------------------------------------------------------------------------------- /resnet/resnet_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: resnet_model.py 3 | 4 | import tensorflow as tf 5 | 6 | from tensorpack.tfutils.argscope import argscope, get_arg_scope 7 | from tensorpack.models import ( 8 | Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm, BNReLU, FullyConnected) 9 | 10 | 11 | def resnet_shortcut(l, n_out, stride, activation=tf.identity): 12 | data_format = get_arg_scope()['Conv2D']['data_format'] 13 | n_in = l.get_shape().as_list()[1 if data_format in ['NCHW', 'channels_first'] else 3] 14 | if n_in != n_out: # change dimension when channel is not the same 15 | return Conv2D('convshortcut', l, n_out, 1, strides=stride, activation=activation) 16 | else: 17 | return l 18 | 19 | 20 | def apply_preactivation(l, preact): 21 | if preact == 'bnrelu': 22 | shortcut = l # preserve identity mapping 23 | l = BNReLU('preact', l) 24 | else: 25 | shortcut = l 26 | return l, shortcut 27 | 28 | 29 | def get_bn(zero_init=False): 30 | """ 31 | Zero init gamma is good for resnet. See https://arxiv.org/abs/1706.02677. 
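    With gamma initialized to zero, the residual branch initially outputs
    zero, so each block starts out as an identity mapping, which eases the
    optimization of very deep networks.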
32 | """ 33 | if zero_init: 34 | return lambda x, name=None: BatchNorm('bn', x, gamma_initializer=tf.zeros_initializer()) 35 | else: 36 | return lambda x, name=None: BatchNorm('bn', x) 37 | 38 | 39 | def preresnet_basicblock(l, ch_out, stride, preact): 40 | l, shortcut = apply_preactivation(l, preact) 41 | l = Conv2D('conv1', l, ch_out, 3, strides=stride, activation=BNReLU) 42 | l = Conv2D('conv2', l, ch_out, 3) 43 | return l + resnet_shortcut(shortcut, ch_out, stride) 44 | 45 | 46 | def preresnet_bottleneck(l, ch_out, stride, preact): 47 | # stride is applied on the second conv, following fb.resnet.torch 48 | l, shortcut = apply_preactivation(l, preact) 49 | l = Conv2D('conv1', l, ch_out, 1, activation=BNReLU) 50 | l = Conv2D('conv2', l, ch_out, 3, strides=stride, activation=BNReLU) 51 | l = Conv2D('conv3', l, ch_out * 4, 1) 52 | return l + resnet_shortcut(shortcut, ch_out * 4, stride) 53 | 54 | 55 | def preresnet_group(name, l, block_func, features, count, stride): 56 | with tf.variable_scope(name): 57 | for i in range(0, count): 58 | with tf.variable_scope('block{}'.format(i)): 59 | # first block doesn't need activation 60 | l = block_func(l, features, 61 | stride if i == 0 else 1, 62 | 'no_preact' if i == 0 else 'bnrelu') 63 | # end of each group need an extra activation 64 | l = BNReLU('bnlast', l) 65 | return l 66 | 67 | 68 | def resnet_basicblock(l, ch_out, stride): 69 | shortcut = l 70 | l = Conv2D('conv1', l, ch_out, 3, strides=stride, activation=BNReLU) 71 | l = Conv2D('conv2', l, ch_out, 3, activation=get_bn(zero_init=True)) 72 | out = l + resnet_shortcut(shortcut, ch_out, stride, activation=get_bn(zero_init=False)) 73 | return tf.nn.relu(out) 74 | 75 | 76 | def resnet_bottleneck(l, ch_out, stride, stride_first=False): 77 | """ 78 | stride_first: original resnet put stride on first conv. fb.resnet.torch put stride on second conv. 
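    Putting the stride on the 3x3 conv (stride_first=False) is the variant
    sometimes called ResNet v1.5: a strided 1x1 conv would skip three
    quarters of the input positions outright.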
79 | """ 80 | shortcut = l 81 | l = Conv2D('conv1', l, ch_out, 1, strides=stride if stride_first else 1, activation=BNReLU) 82 | l = Conv2D('conv2', l, ch_out, 3, strides=1 if stride_first else stride, activation=BNReLU) 83 | l = Conv2D('conv3', l, ch_out * 4, 1, activation=get_bn(zero_init=True)) 84 | out = l + resnet_shortcut(shortcut, ch_out * 4, stride, activation=get_bn(zero_init=False)) 85 | return tf.nn.relu(out) 86 | 87 | 88 | def se_resnet_bottleneck(l, ch_out, stride): 89 | shortcut = l 90 | l = Conv2D('conv1', l, ch_out, 1, activation=BNReLU) 91 | l = Conv2D('conv2', l, ch_out, 3, strides=stride, activation=BNReLU) 92 | l = Conv2D('conv3', l, ch_out * 4, 1, activation=get_bn(zero_init=True)) 93 | 94 | squeeze = GlobalAvgPooling('gap', l) 95 | squeeze = FullyConnected('fc1', squeeze, ch_out // 4, activation=tf.nn.relu) 96 | squeeze = FullyConnected('fc2', squeeze, ch_out * 4, activation=tf.nn.sigmoid) 97 | data_format = get_arg_scope()['Conv2D']['data_format'] 98 | ch_ax = 1 if data_format in ['NCHW', 'channels_first'] else 3 99 | shape = [-1, 1, 1, 1] 100 | shape[ch_ax] = ch_out * 4 101 | l = l * tf.reshape(squeeze, shape) 102 | out = l + resnet_shortcut(shortcut, ch_out * 4, stride, activation=get_bn(zero_init=False)) 103 | return tf.nn.relu(out) 104 | 105 | 106 | def resnet_group(name, l, block_func, features, count, stride): 107 | with tf.variable_scope(name): 108 | for i in range(0, count): 109 | with tf.variable_scope('block{}'.format(i)): 110 | l = block_func(l, features, stride if i == 0 else 1) 111 | return l 112 | 113 | 114 | def resnet_backbone(image, num_blocks, group_func, block_func): 115 | with argscope(Conv2D, use_bias=False, 116 | kernel_initializer=tf.variance_scaling_initializer(scale=2.0, mode='fan_out')): 117 | # Note that this pads the image by [2, 3] instead of [3, 2]. 118 | # Similar things happen in later stride=2 layers as well. 
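        # TensorFlow's 'SAME' padding places the extra pixel on the
        # bottom/right, which is where the [2, 3] asymmetry comes from.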
119 | l = Conv2D('conv0', image, 64, 7, strides=2, activation=BNReLU) 120 | l = MaxPooling('pool0', l, pool_size=3, strides=2, padding='SAME') 121 | l = group_func('group0', l, block_func, 64, num_blocks[0], 1) 122 | l = group_func('group1', l, block_func, 128, num_blocks[1], 2) 123 | l = group_func('group2', l, block_func, 256, num_blocks[2], 2) 124 | l = group_func('group3', l, block_func, 512, num_blocks[3], 2) 125 | l = GlobalAvgPooling('gap', l) 126 | logits = FullyConnected('linear', l, 1000, 127 | kernel_initializer=tf.random_normal_initializer(stddev=0.01)) 128 | return logits, l 129 | -------------------------------------------------------------------------------- /resnet/adanet-resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: imagenet-resnet.py 4 | 5 | import argparse 6 | import os 7 | 8 | from tensorpack import logger, QueueInput 9 | from tensorpack.models import * 10 | from tensorpack.callbacks import * 11 | from tensorpack.train import ( 12 | AutoResumeTrainConfig, SyncMultiGPUTrainerReplicated, launch_train_with_config) 13 | from tensorpack.dataflow import FakeData 14 | from tensorpack.tfutils import argscope, get_model_loader 15 | from tensorpack.utils.gpu import get_num_gpu 16 | 17 | from imagenet_utils import ( 18 | fbresnet_augmentor, get_imagenet_dataflow, ImageNetModel, 19 | eval_on_ILSVRC12) 20 | from resnet_model import ( 21 | preresnet_group, preresnet_basicblock, preresnet_bottleneck, 22 | resnet_group, resnet_basicblock, resnet_bottleneck, se_resnet_bottleneck, 23 | resnet_backbone) 24 | 25 | 26 | class Model(ImageNetModel): 27 | def __init__(self, depth, mode='resnet'): 28 | if mode == 'se': 29 | assert depth >= 50 30 | 31 | self.mode = mode 32 | basicblock = preresnet_basicblock if mode == 'preact' else resnet_basicblock 33 | bottleneck = { 34 | 'resnet': resnet_bottleneck, 35 | 'preact': preresnet_bottleneck, 36 | 'se': se_resnet_bottleneck}[mode] 37 | self.num_blocks, self.block_func = { 38 | 18: ([2, 2, 2, 2], basicblock), 39 | 34: ([3, 4, 6, 3], basicblock), 40 | 50: ([3, 4, 6, 3], bottleneck), 41 | 101: ([3, 4, 23, 3], bottleneck), 42 | 152: ([3, 8, 36, 3], bottleneck) 43 | }[depth] 44 | 45 | def get_logits(self, image): 46 | with argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format=self.data_format): 47 | return resnet_backbone( 48 | image, self.num_blocks, 49 | preresnet_group if self.mode == 'preact' else resnet_group, self.block_func) 50 | 51 | 52 | def get_data(name, batch): 53 | isTrain = name == 'train' 54 | augmentors = fbresnet_augmentor(isTrain) 55 | return get_imagenet_dataflow( 56 | args.data, name, batch, augmentors) 57 | 58 | 59 | def get_config(model, fake=False): 60 | nr_tower = max(get_num_gpu(), 1) 61 | assert args.batch % nr_tower == 0 62 | batch = args.batch // nr_tower 63 | 64 | logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, batch)) 65 | if batch < 32 or batch > 64: 66 | logger.warn("Batch size per tower not in [32, 64]. 
This probably will lead to worse accuracy than reported.") 67 | if fake: 68 | data = QueueInput(FakeData( 69 | [[batch, 224, 224, 3], [batch],[batch, 224, 224, 3], [batch]], 1000, random=False, dtype='uint8')) 70 | callbacks = [] 71 | else: 72 | data = QueueInput(get_data('train', batch)) 73 | 74 | START_LR = 0.1 75 | BASE_LR = START_LR * (args.batch / 256.0) 76 | callbacks = [ 77 | ModelSaver(), 78 | EstimatedTimeLeft(), 79 | ScheduledHyperParamSetter( 80 | 'learning_rate', [ 81 | (0, min(START_LR, BASE_LR)), (30, BASE_LR * 1e-1), (45, BASE_LR * 1e-2), 82 | (55, BASE_LR * 1e-3)]), 83 | ] 84 | if BASE_LR > START_LR: 85 | callbacks.append( 86 | ScheduledHyperParamSetter( 87 | 'learning_rate', [(0, START_LR), (5, BASE_LR)], interp='linear')) 88 | 89 | infs = [ClassificationError('wrong-top1', 'val-error-top1'), 90 | ClassificationError('wrong-top5', 'val-error-top5')] 91 | dataset_val = get_data('val', batch) 92 | if nr_tower == 1: 93 | # single-GPU inference with queue prefetch 94 | callbacks.append(InferenceRunner(QueueInput(dataset_val), infs)) 95 | else: 96 | # multi-GPU inference (with mandatory queue prefetch) 97 | callbacks.append(DataParallelInferenceRunner( 98 | dataset_val, infs, list(range(nr_tower)))) 99 | 100 | return AutoResumeTrainConfig( 101 | model=model, 102 | data=data, 103 | callbacks=callbacks, 104 | steps_per_epoch=100 if args.fake else 1280000 // args.batch, 105 | max_epoch=60, 106 | ) 107 | 108 | 109 | if __name__ == '__main__': 110 | parser = argparse.ArgumentParser() 111 | parser.add_argument('--gpu', help='comma separated list of GPU(s) to use. Default to use all available ones') 112 | parser.add_argument('--data', help='ILSVRC dataset dir') 113 | parser.add_argument('--load', help='load a model for training or evaluation') 114 | parser.add_argument('--fake', help='use FakeData to debug or benchmark this model', action='store_true') 115 | parser.add_argument('--data-format', help='image data format', 116 | default='NCHW', choices=['NCHW', 'NHWC']) 117 | parser.add_argument('-d', '--depth', help='ResNet depth', 118 | type=int, default=50, choices=[18, 34, 50, 101, 152]) 119 | parser.add_argument('--eval', action='store_true', help='run offline evaluation instead of training') 120 | parser.add_argument('--batch', default=256, type=int, 121 | help="total batch size. " 122 | "Note that it's best to keep per-GPU batch size in [32, 64] to obtain the best accuracy." 
123 | "Pretrained models listed in README were trained with batch=32x8.") 124 | parser.add_argument('--mode', choices=['resnet', 'preact', 'se'], 125 | help='variants of resnet to use', default='resnet') 126 | args = parser.parse_args() 127 | 128 | if args.gpu: 129 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 130 | 131 | model = Model(args.depth, args.mode) 132 | model.data_format = args.data_format 133 | if args.eval: 134 | batch = 128 # something that can run on one gpu 135 | ds = get_data('val', batch) 136 | eval_on_ILSVRC12(model, get_model_loader(args.load), ds) 137 | else: 138 | if args.fake: 139 | logger.set_logger_dir(os.path.join('train_log', 'tmp'), 'd') 140 | else: 141 | logger.set_logger_dir( 142 | os.path.join('train_log', 'imagenet-{}-d{}-batch{}'.format(args.mode, args.depth, args.batch))) 143 | 144 | config = get_config(model, fake=args.fake) 145 | if args.load: 146 | config.session_init = get_model_loader(args.load) 147 | trainer = SyncMultiGPUTrainerReplicated(max(get_num_gpu(), 1)) 148 | launch_train_with_config(config, trainer) 149 | -------------------------------------------------------------------------------- /convlarge/svhn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import sys 7 | from scipy.io import loadmat 8 | 9 | import numpy as np 10 | from scipy import linalg 11 | import glob 12 | import pickle 13 | 14 | from six.moves import xrange # pylint: disable=redefined-builtin 15 | from six.moves import urllib 16 | 17 | import tensorflow as tf 18 | from dataset_utils import * 19 | 20 | DATA_URL_TRAIN = 'http://ufldl.stanford.edu/housenumbers/train_32x32.mat' 21 | DATA_URL_TEST = 'http://ufldl.stanford.edu/housenumbers/test_32x32.mat' 22 | 23 | FLAGS = tf.app.flags.FLAGS 24 | tf.app.flags.DEFINE_string('data_dir', './dataset/svhn', "") 25 | tf.app.flags.DEFINE_integer('num_labeled_examples', 1000, "The number of labeled examples") 26 | tf.app.flags.DEFINE_integer('num_valid_examples', 1000, "The number of validation examples") 27 | tf.app.flags.DEFINE_integer('dataset_seed', 1, "dataset seed") 28 | 29 | NUM_EXAMPLES_TRAIN = 73257 30 | NUM_EXAMPLES_TEST = 26032 31 | 32 | 33 | def maybe_download_and_extract(): 34 | if not os.path.exists(FLAGS.data_dir): 35 | os.makedirs(FLAGS.data_dir) 36 | filepath_train_mat = os.path.join(FLAGS.data_dir, 'train_32x32.mat') 37 | filepath_test_mat = os.path.join(FLAGS.data_dir, 'test_32x32.mat') 38 | if not os.path.exists(filepath_train_mat) or not os.path.exists(filepath_test_mat): 39 | def _progress(count, block_size, total_size): 40 | sys.stdout.write('\r>> Downloading %.1f%%' % (float(count * block_size) / float(total_size) * 100.0)) 41 | sys.stdout.flush() 42 | 43 | urllib.request.urlretrieve(DATA_URL_TRAIN, filepath_train_mat, _progress) 44 | urllib.request.urlretrieve(DATA_URL_TEST, filepath_test_mat, _progress) 45 | 46 | # Training set 47 | print("Loading training data...") 48 | print("Preprocessing training data...") 49 | train_data = loadmat(FLAGS.data_dir + '/train_32x32.mat') 50 | train_x = (-127.5 + train_data['X']) / 255. 
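    # Shift/scale uint8 pixels from [0, 255] into [-0.5, 0.5].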
51 | train_x = train_x.transpose((3, 0, 1, 2)) 52 | train_x = train_x.reshape([train_x.shape[0], -1]) 53 | train_y = train_data['y'].flatten().astype(np.int32) 54 | train_y[train_y == 10] = 0 55 | 56 | # Test set 57 | print("Loading test data...") 58 | test_data = loadmat(FLAGS.data_dir + '/test_32x32.mat') 59 | test_x = (-127.5 + test_data['X']) / 255. 60 | test_x = test_x.transpose((3, 0, 1, 2)) 61 | test_x = test_x.reshape((test_x.shape[0], -1)) 62 | test_y = test_data['y'].flatten().astype(np.int32) 63 | test_y[test_y == 10] = 0 64 | 65 | np.save('{}/train_images'.format(FLAGS.data_dir), train_x) 66 | np.save('{}/train_labels'.format(FLAGS.data_dir), train_y) 67 | np.save('{}/test_images'.format(FLAGS.data_dir), test_x) 68 | np.save('{}/test_labels'.format(FLAGS.data_dir), test_y) 69 | 70 | 71 | def load_svhn(): 72 | maybe_download_and_extract() 73 | train_images = np.load('{}/train_images.npy'.format(FLAGS.data_dir)).astype(np.float32) 74 | train_labels = np.load('{}/train_labels.npy'.format(FLAGS.data_dir)).astype(np.float32) 75 | test_images = np.load('{}/test_images.npy'.format(FLAGS.data_dir)).astype(np.float32) 76 | test_labels = np.load('{}/test_labels.npy'.format(FLAGS.data_dir)).astype(np.float32) 77 | return (train_images, train_labels), (test_images, test_labels) 78 | 79 | 80 | def prepare_dataset(): 81 | (train_images, train_labels), (test_images, test_labels) = load_svhn() 82 | dirpath = os.path.join(FLAGS.data_dir, 'seed' + str(FLAGS.dataset_seed)) 83 | if not os.path.exists(dirpath): 84 | os.makedirs(dirpath) 85 | 86 | rng = np.random.RandomState(FLAGS.dataset_seed) 87 | rand_ix = rng.permutation(NUM_EXAMPLES_TRAIN) 88 | print(rand_ix) 89 | _train_images, _train_labels = train_images[rand_ix], train_labels[rand_ix] 90 | 91 | labeled_ind = np.arange(FLAGS.num_labeled_examples) 92 | labeled_train_images, labeled_train_labels = _train_images[labeled_ind], _train_labels[labeled_ind] 93 | _train_images = np.delete(_train_images, labeled_ind, 0) 94 | _train_labels = np.delete(_train_labels, labeled_ind, 0) 95 | convert_images_and_labels(labeled_train_images, 96 | labeled_train_labels, 97 | os.path.join(dirpath, 'labeled_train.tfrecords')) 98 | convert_images_and_labels(train_images, train_labels, 99 | os.path.join(dirpath, 'unlabeled_train.tfrecords')) 100 | convert_images_and_labels(test_images, 101 | test_labels, 102 | os.path.join(dirpath, 'test.tfrecords')) 103 | 104 | # Construct dataset for validation 105 | train_images_valid, train_labels_valid = labeled_train_images, labeled_train_labels 106 | test_images_valid, test_labels_valid = \ 107 | _train_images[:FLAGS.num_valid_examples], _train_labels[:FLAGS.num_valid_examples] 108 | unlabeled_train_images_valid = np.concatenate( 109 | (train_images_valid, _train_images[FLAGS.num_valid_examples:]), axis=0) 110 | unlabeled_train_labels_valid = np.concatenate( 111 | (train_labels_valid, _train_labels[FLAGS.num_valid_examples:]), axis=0) 112 | convert_images_and_labels(train_images_valid, 113 | train_labels_valid, 114 | os.path.join(dirpath, 'labeled_train_val.tfrecords')) 115 | convert_images_and_labels(unlabeled_train_images_valid, 116 | unlabeled_train_labels_valid, 117 | os.path.join(dirpath, 'unlabeled_train_val.tfrecords')) 118 | convert_images_and_labels(test_images_valid, 119 | test_labels_valid, 120 | os.path.join(dirpath, 'test_val.tfrecords')) 121 | 122 | 123 | def inputs(batch_size=100, 124 | train=True, validation=False, 125 | shuffle=True, num_epochs=None): 126 | if validation: 127 | if train: 128 | filenames = 
['labeled_train_val.tfrecords'] 129 | num_examples = FLAGS.num_labeled_examples 130 | else: 131 | filenames = ['test_val.tfrecords'] 132 | num_examples = FLAGS.num_valid_examples 133 | else: 134 | if train: 135 | filenames = ['labeled_train.tfrecords'] 136 | num_examples = FLAGS.num_labeled_examples 137 | else: 138 | filenames = ['test.tfrecords'] 139 | num_examples = NUM_EXAMPLES_TEST 140 | 141 | filenames = [os.path.join('seed' + str(FLAGS.dataset_seed), filename) for filename in filenames] 142 | filename_queue = generate_filename_queue(filenames, FLAGS.data_dir, num_epochs) 143 | image, label = read(filename_queue) 144 | image = transform(tf.cast(image, tf.float32)) if train else image 145 | return generate_batch([image, label], num_examples, batch_size, shuffle) 146 | 147 | 148 | def unlabeled_inputs(batch_size=100, 149 | validation=False, 150 | shuffle=True): 151 | if validation: 152 | filenames = ['unlabeled_train_val.tfrecords'] 153 | num_examples = NUM_EXAMPLES_TRAIN - FLAGS.num_valid_examples 154 | else: 155 | filenames = ['unlabeled_train.tfrecords'] 156 | num_examples = NUM_EXAMPLES_TRAIN 157 | 158 | filenames = [os.path.join('seed' + str(FLAGS.dataset_seed), filename) for filename in filenames] 159 | filename_queue = generate_filename_queue(filenames, data_dir=FLAGS.data_dir) 160 | image, label = read(filename_queue) 161 | image = transform(tf.cast(image, tf.float32)) 162 | return generate_batch([image], num_examples, batch_size, shuffle) 163 | 164 | 165 | def main(argv): 166 | prepare_dataset() 167 | 168 | 169 | if __name__ == "__main__": 170 | tf.app.run() 171 | -------------------------------------------------------------------------------- /convlarge/cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | """Routine for decoding the CIFAR-10 binary file format.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import os 21 | import sys 22 | import tarfile 23 | 24 | import numpy as np 25 | from scipy import linalg 26 | import glob 27 | import pickle 28 | 29 | from six.moves import xrange # pylint: disable=redefined-builtin 30 | from six.moves import urllib 31 | 32 | import tensorflow as tf 33 | 34 | from dataset_utils_cifar import * 35 | 36 | DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' 37 | 38 | FLAGS = tf.app.flags.FLAGS 39 | tf.app.flags.DEFINE_string('data_dir', './dataset/cifar10', 40 | 'where to store the dataset') 41 | tf.app.flags.DEFINE_integer('num_labeled_examples', 4000, "The number of labeled examples") 42 | tf.app.flags.DEFINE_integer('num_valid_examples', 1000, "The number of validation examples") 43 | tf.app.flags.DEFINE_integer('dataset_seed', 1, "dataset seed") 44 | 45 | # Process images of this size. Note that this differs from the original CIFAR 46 | # image size of 32 x 32. If one alters this number, then the entire model 47 | # architecture will change and any model would need to be retrained. 48 | IMAGE_SIZE = 32 49 | 50 | # Global constants describing the CIFAR-10 data set. 51 | NUM_CLASSES = 10 52 | NUM_EXAMPLES_TRAIN = 50000 53 | NUM_EXAMPLES_TEST = 10000 54 | 55 | def load_cifar10(): 56 | """Download and extract the tarball from Alex's website.""" 57 | dest_directory = FLAGS.data_dir 58 | if not os.path.exists(dest_directory): 59 | os.makedirs(dest_directory) 60 | filename = DATA_URL.split('/')[-1] 61 | filepath = os.path.join(dest_directory, filename) 62 | if not os.path.exists(filepath): 63 | def _progress(count, block_size, total_size): 64 | sys.stdout.write('\r>> Downloading %s %.1f%%' % 65 | (filename, float(count * block_size) / 66 | float(total_size) * 100.0)) 67 | sys.stdout.flush() 68 | 69 | filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) 70 | print() 71 | statinfo = os.stat(filepath) 72 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 73 | tarfile.open(filepath, 'r:gz').extractall(dest_directory) 74 | 75 | # Training set 76 | print("Loading training data...") 77 | train_images = np.zeros((NUM_EXAMPLES_TRAIN, 3 * 32 * 32), dtype=np.float32) 78 | train_labels = [] 79 | for i, data_fn in enumerate( 80 | sorted(glob.glob(FLAGS.data_dir + '/cifar-10-batches-py/data_batch*'))): 81 | batch = unpickle(data_fn) 82 | train_images[i * 10000:(i + 1) * 10000] = batch['data'] 83 | train_labels.extend(batch['labels']) 84 | train_images = (train_images - 127.5) / 255. 85 | train_labels = np.asarray(train_labels, dtype=np.int64) 86 | 87 | rand_ix = np.random.permutation(NUM_EXAMPLES_TRAIN) 88 | train_images = train_images[rand_ix] 89 | train_labels = train_labels[rand_ix] 90 | 91 | print("Loading test data...") 92 | test = unpickle(FLAGS.data_dir + '/cifar-10-batches-py/test_batch') 93 | test_images = test['data'].astype(np.float32) 94 | test_images = (test_images - 127.5) / 255. 
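    # Same shift into [-0.5, 0.5] as the training images, keeping train and
    # test preprocessing identical.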
95 | test_labels = np.asarray(test['labels'], dtype=np.int64) 96 | 97 | """ 98 | print("Apply ZCA whitening") 99 | components, mean, train_images = ZCA(train_images) 100 | np.save('{}/components'.format(FLAGS.data_dir), components) 101 | np.save('{}/mean'.format(FLAGS.data_dir), mean) 102 | test_images = np.dot(test_images - mean, components.T) 103 | """ 104 | 105 | train_images = train_images.reshape( 106 | (NUM_EXAMPLES_TRAIN, 3, 32, 32)).transpose((0, 2, 3, 1)).reshape((NUM_EXAMPLES_TRAIN, -1)) 107 | test_images = test_images.reshape( 108 | (NUM_EXAMPLES_TEST, 3, 32, 32)).transpose((0, 2, 3, 1)).reshape((NUM_EXAMPLES_TEST, -1)) 109 | return (train_images, train_labels), (test_images, test_labels) 110 | 111 | 112 | def prepare_dataset(): 113 | (train_images, train_labels), (test_images, test_labels) = load_cifar10() 114 | dirpath = os.path.join(FLAGS.data_dir, 'seed' + str(FLAGS.dataset_seed)) 115 | if not os.path.exists(dirpath): 116 | os.makedirs(dirpath) 117 | 118 | rng = np.random.RandomState(FLAGS.dataset_seed) 119 | rand_ix = rng.permutation(NUM_EXAMPLES_TRAIN) 120 | _train_images, _train_labels = train_images[rand_ix], train_labels[rand_ix] 121 | 122 | examples_per_class = int(FLAGS.num_labeled_examples / 10) 123 | labeled_train_images = np.zeros((FLAGS.num_labeled_examples, 3072), dtype=np.float32) 124 | labeled_train_labels = np.zeros((FLAGS.num_labeled_examples), dtype=np.int64) 125 | for i in xrange(10): 126 | ind = np.where(_train_labels == i)[0] 127 | labeled_train_images[i * examples_per_class:(i + 1) * examples_per_class] \ 128 | = _train_images[ind[0:examples_per_class]] 129 | labeled_train_labels[i * examples_per_class:(i + 1) * examples_per_class] \ 130 | = _train_labels[ind[0:examples_per_class]] 131 | _train_images = np.delete(_train_images, 132 | ind[0:examples_per_class], 0) 133 | _train_labels = np.delete(_train_labels, 134 | ind[0:examples_per_class]) 135 | 136 | rand_ix_labeled = rng.permutation(FLAGS.num_labeled_examples) 137 | labeled_train_images, labeled_train_labels = \ 138 | labeled_train_images[rand_ix_labeled], labeled_train_labels[rand_ix_labeled] 139 | 140 | convert_images_and_labels(labeled_train_images, 141 | labeled_train_labels, 142 | os.path.join(dirpath, 'labeled_train.tfrecords')) 143 | convert_images_and_labels(train_images, train_labels, 144 | os.path.join(dirpath, 'unlabeled_train.tfrecords')) 145 | convert_images_and_labels(test_images, 146 | test_labels, 147 | os.path.join(dirpath, 'test.tfrecords')) 148 | 149 | # Construct dataset for validation 150 | train_images_valid, train_labels_valid = \ 151 | labeled_train_images[FLAGS.num_valid_examples:], labeled_train_labels[FLAGS.num_valid_examples:] 152 | test_images_valid, test_labels_valid = \ 153 | labeled_train_images[:FLAGS.num_valid_examples], labeled_train_labels[:FLAGS.num_valid_examples] 154 | unlabeled_train_images_valid = np.concatenate( 155 | (train_images_valid, _train_images), axis=0) 156 | unlabeled_train_labels_valid = np.concatenate( 157 | (train_labels_valid, _train_labels), axis=0) 158 | convert_images_and_labels(train_images_valid, 159 | train_labels_valid, 160 | os.path.join(dirpath, 'labeled_train_val.tfrecords')) 161 | convert_images_and_labels(unlabeled_train_images_valid, 162 | unlabeled_train_labels_valid, 163 | os.path.join(dirpath, 'unlabeled_train_val.tfrecords')) 164 | convert_images_and_labels(test_images_valid, 165 | test_labels_valid, 166 | os.path.join(dirpath, 'test_val.tfrecords')) 167 | 168 | 169 | def inputs(batch_size=100, 170 | train=True, 
validation=False, 171 | shuffle=True, num_epochs=None): 172 | if validation: 173 | if train: 174 | filenames = ['labeled_train_val.tfrecords'] 175 | num_examples = FLAGS.num_labeled_examples - FLAGS.num_valid_examples 176 | else: 177 | filenames = ['test_val.tfrecords'] 178 | num_examples = FLAGS.num_valid_examples 179 | else: 180 | if train: 181 | filenames = ['labeled_train.tfrecords'] 182 | num_examples = FLAGS.num_labeled_examples 183 | else: 184 | filenames = ['test.tfrecords'] 185 | num_examples = NUM_EXAMPLES_TEST 186 | 187 | filenames = [os.path.join('seed' + str(FLAGS.dataset_seed), filename) for filename in filenames] 188 | 189 | filename_queue = generate_filename_queue(filenames, FLAGS.data_dir, num_epochs) 190 | image, label = read(filename_queue) 191 | image = transform(tf.cast(image, tf.float32)) if train else image 192 | return generate_batch([image, label], num_examples, batch_size, shuffle) 193 | 194 | 195 | def unlabeled_inputs(batch_size=100, 196 | validation=False, 197 | shuffle=True): 198 | if validation: 199 | filenames = ['unlabeled_train_val.tfrecords'] 200 | num_examples = NUM_EXAMPLES_TRAIN - FLAGS.num_valid_examples 201 | else: 202 | filenames = ['unlabeled_train.tfrecords'] 203 | num_examples = NUM_EXAMPLES_TRAIN 204 | 205 | filenames = [os.path.join('seed' + str(FLAGS.dataset_seed), filename) for filename in filenames] 206 | filename_queue = generate_filename_queue(filenames, FLAGS.data_dir) 207 | image, label = read(filename_queue) 208 | image = transform(tf.cast(image, tf.float32)) 209 | return generate_batch([image], num_examples, batch_size, shuffle) 210 | 211 | 212 | def main(argv): 213 | prepare_dataset() 214 | 215 | 216 | if __name__ == "__main__": 217 | tf.app.run() 218 | -------------------------------------------------------------------------------- /convlarge/train_svhn.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy 4 | import tensorflow as tf 5 | 6 | import layers as L 7 | import cnn 8 | 9 | from flip_gradient import flip_gradient 10 | from svhn import inputs, unlabeled_inputs 11 | 12 | FLAGS = tf.app.flags.FLAGS 13 | 14 | tf.app.flags.DEFINE_string('device', '/gpu:0', "device") 15 | tf.app.flags.DEFINE_string('dataset', 'svhn', "{cifar10, svhn}") 16 | tf.app.flags.DEFINE_string('log_dir', "", "log_dir") 17 | tf.app.flags.DEFINE_integer('seed', 1, "initial random seed") 18 | tf.app.flags.DEFINE_bool('validation', False, "") 19 | tf.app.flags.DEFINE_bool('one_hot', False, "") 20 | tf.app.flags.DEFINE_integer('batch_size', 128, "the number of examples in a batch") 21 | tf.app.flags.DEFINE_integer('ul_batch_size', 128, "the number of unlabeled examples in a batch") 22 | tf.app.flags.DEFINE_integer('eval_batch_size', 100, "the number of eval examples in a batch") 23 | tf.app.flags.DEFINE_integer('eval_freq', 5, "") 24 | tf.app.flags.DEFINE_integer('num_epochs', 120, "the number of epochs for training") 25 | tf.app.flags.DEFINE_integer('epoch_decay_start', 80, "epoch of starting learning rate decay") 26 | tf.app.flags.DEFINE_integer('num_iter_per_epoch', 400, "the number of updates per epoch") 27 | tf.app.flags.DEFINE_float('learning_rate', 0.001, "initial learning rate") 28 | tf.app.flags.DEFINE_float('mom1', 0.9, "initial momentum rate") 29 | tf.app.flags.DEFINE_float('mom2', 0.5, "momentum rate after epoch_decay_start") 30 | 31 | 32 | 33 | 34 | NUM_EVAL_EXAMPLES = 5000 35 | 36 | def logit(x, is_training=True, update_batch_stats=True, stochastic=True, seed=1234): 37 | return 
cnn.logit(x, is_training=is_training, 38 | update_batch_stats=update_batch_stats, 39 | stochastic=stochastic, 40 | seed=seed)[0] 41 | 42 | 43 | def forward(x, is_training=True, update_batch_stats=True, seed=1234): 44 | if is_training: 45 | return logit(x, is_training=True, 46 | update_batch_stats=update_batch_stats, 47 | stochastic=True, seed=seed) 48 | else: 49 | return logit(x, is_training=False, 50 | update_batch_stats=update_batch_stats, 51 | stochastic=False, seed=seed) 52 | 53 | def build_training_graph(x1, y1, x2, lr, mom): 54 | global_step = tf.get_variable( 55 | name="global_step", 56 | shape=[], 57 | dtype=tf.float32, 58 | initializer=tf.constant_initializer(0.0), 59 | trainable=False, 60 | ) 61 | k = 1. * global_step / (FLAGS.num_iter_per_epoch * FLAGS.num_epochs) 62 | # lp schedule from GRL 63 | lp = (2. / (1. + tf.exp(-10. * k)) - 1) 64 | 65 | # Interpolation 66 | y2_logit, _ = cnn.logit(x2, is_training=False, update_batch_stats=False, stochastic=False) 67 | if FLAGS.one_hot: 68 | y2 = tf.stop_gradient(tf.cast(tf.one_hot(tf.argmax(y2_logit, -1), 10), tf.float32)) 69 | else: 70 | y2 = tf.stop_gradient(tf.nn.softmax(y2_logit)) 71 | 72 | dist_beta = tf.distributions.Beta(0.1, 0.1) 73 | lmb = dist_beta.sample(tf.shape(x1)[0]) 74 | lmb_x = tf.reshape(lmb, [-1, 1, 1, 1]) 75 | lmb_y = tf.reshape(lmb, [-1, 1]) 76 | x = x1 * lmb_x + x2 * (1. - lmb_x) 77 | y = y1 * lmb_y + y2 * (1. - lmb_y) 78 | 79 | label_dm = tf.concat([tf.reshape(lmb, [-1, 1]), tf.reshape(1. - lmb, [-1, 1])], axis=1) 80 | 81 | # Calculate the feats and logits on interpolated samples 82 | with tf.variable_scope(tf.get_variable_scope(), reuse=True): 83 | logit, net = cnn.logit(x, is_training=True, update_batch_stats=True) 84 | 85 | # Alignment Loss 86 | net_ = flip_gradient(net, lp) 87 | logitsdm = tf.layers.dense(net_, 1024, activation=tf.nn.relu, name='linear_dm1') 88 | logitsdm = tf.layers.dense(logitsdm, 1024, activation=tf.nn.relu, name='linear_dm2') 89 | logits_dm = tf.layers.dense(logitsdm, 2, name="logits_dm") 90 | dm_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=label_dm, logits=logits_dm)) 91 | additional_loss = dm_loss 92 | 93 | nll_loss = tf.reduce_mean(lmb*tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logit)) 94 | 95 | loss = nll_loss + additional_loss 96 | 97 | opt = tf.train.AdamOptimizer(learning_rate=lr, beta1=mom) 98 | tvars = tf.trainable_variables() 99 | grads_and_vars = opt.compute_gradients(loss, tvars) 100 | train_op = opt.apply_gradients(grads_and_vars, global_step=global_step) 101 | return loss, train_op, global_step 102 | 103 | 104 | def build_eval_graph(x, y, ul_x): 105 | losses = {} 106 | logit = forward(x, is_training=False, update_batch_stats=False) 107 | nll_loss = L.ce_loss(logit, y) 108 | losses['NLL'] = nll_loss 109 | acc = L.accuracy(logit, y) 110 | losses['Acc'] = acc 111 | return losses 112 | 113 | 114 | def main(_): 115 | numpy.random.seed(seed=FLAGS.seed) 116 | tf.set_random_seed(numpy.random.randint(1234)) 117 | with tf.Graph().as_default() as g: 118 | with tf.device("/cpu:0"): 119 | images, labels = inputs(batch_size=FLAGS.batch_size, 120 | train=True, 121 | validation=FLAGS.validation, 122 | shuffle=True) 123 | ul_images = unlabeled_inputs(batch_size=FLAGS.ul_batch_size, 124 | validation=FLAGS.validation, 125 | shuffle=True) 126 | 127 | images_eval_train, labels_eval_train = inputs(batch_size=FLAGS.eval_batch_size, 128 | train=True, 129 | validation=FLAGS.validation, 130 | shuffle=True) 131 | ul_images_eval_train = 
unlabeled_inputs(batch_size=FLAGS.eval_batch_size, 132 | validation=FLAGS.validation, 133 | shuffle=True) 134 | 135 | images_eval_test, labels_eval_test = inputs(batch_size=FLAGS.eval_batch_size, 136 | train=False, 137 | validation=FLAGS.validation, 138 | shuffle=True) 139 | 140 | with tf.device(FLAGS.device): 141 | lr = tf.placeholder(tf.float32, shape=[], name="learning_rate") 142 | mom = tf.placeholder(tf.float32, shape=[], name="momentum") 143 | with tf.variable_scope("CNN") as scope: 144 | # Build training graph 145 | loss, train_op, global_step = build_training_graph(images, labels, ul_images, lr, mom) 146 | scope.reuse_variables() 147 | # Build eval graph 148 | losses_eval_train = build_eval_graph(images_eval_train, labels_eval_train, ul_images_eval_train) 149 | losses_eval_test = build_eval_graph(images_eval_test, labels_eval_test, images_eval_test) 150 | 151 | init_op = tf.global_variables_initializer() 152 | 153 | if not FLAGS.log_dir: 154 | logdir = None 155 | writer_train = None 156 | writer_test = None 157 | else: 158 | logdir = FLAGS.log_dir 159 | writer_train = tf.summary.FileWriter(FLAGS.log_dir + "/train", g) 160 | writer_test = tf.summary.FileWriter(FLAGS.log_dir + "/test", g) 161 | 162 | saver = tf.train.Saver(tf.global_variables()) 163 | sv = tf.train.Supervisor( 164 | is_chief=True, 165 | logdir=logdir, 166 | init_op=init_op, 167 | init_feed_dict={lr: FLAGS.learning_rate, mom: FLAGS.mom1}, 168 | saver=saver, 169 | global_step=global_step, 170 | summary_op=None, 171 | summary_writer=None, 172 | save_model_secs=150, recovery_wait_secs=0) 173 | 174 | print("Training...") 175 | with sv.managed_session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: 176 | for ep in range(FLAGS.num_epochs): 177 | if sv.should_stop(): 178 | break 179 | 180 | if ep < FLAGS.epoch_decay_start: 181 | feed_dict = {lr: FLAGS.learning_rate, mom: FLAGS.mom1} 182 | else: 183 | decayed_lr = ((FLAGS.num_epochs - ep) / float( 184 | FLAGS.num_epochs - FLAGS.epoch_decay_start)) * FLAGS.learning_rate 185 | feed_dict = {lr: decayed_lr, mom: FLAGS.mom2} 186 | 187 | sum_loss = 0 188 | start = time.time() 189 | for i in range(FLAGS.num_iter_per_epoch): 190 | _, batch_loss, _ = sess.run([train_op, loss, global_step], 191 | feed_dict=feed_dict) 192 | sum_loss += batch_loss 193 | end = time.time() 194 | print("Epoch:", ep, "CE_loss_train:", sum_loss / FLAGS.num_iter_per_epoch, "elapsed_time:", end - start) 195 | 196 | if (ep + 1) % FLAGS.eval_freq == 0 or ep + 1 == FLAGS.num_epochs: 197 | # Eval on training data 198 | act_values_dict = {} 199 | for key, _ in losses_eval_train.items(): 200 | act_values_dict[key] = 0 201 | n_iter_per_epoch = NUM_EVAL_EXAMPLES // FLAGS.eval_batch_size 202 | for i in range(n_iter_per_epoch): 203 | values = list(losses_eval_train.values()) 204 | act_values = sess.run(values) 205 | for key, value in zip(list(act_values_dict.keys()), act_values): 206 | act_values_dict[key] += value 207 | summary = tf.Summary() 208 | current_global_step = sess.run(global_step) 209 | for key, value in act_values_dict.items(): 210 | print("train-" + key, value / n_iter_per_epoch) 211 | summary.value.add(tag=key, simple_value=value / n_iter_per_epoch) 212 | if writer_train is not None: 213 | writer_train.add_summary(summary, current_global_step) 214 | 215 | # Eval on test data 216 | act_values_dict = {} 217 | for key, _ in losses_eval_test.items(): 218 | act_values_dict[key] = 0 219 | n_iter_per_epoch = NUM_EVAL_EXAMPLES // FLAGS.eval_batch_size 220 | for i in range(n_iter_per_epoch): 221 | values 
= list(losses_eval_test.values()) 222 | act_values = sess.run(values) 223 | for key, value in zip(list(act_values_dict.keys()), act_values): 224 | act_values_dict[key] += value 225 | summary = tf.Summary() 226 | current_global_step = sess.run(global_step) 227 | for key, value in act_values_dict.items(): 228 | print("test-" + key, value / n_iter_per_epoch) 229 | summary.value.add(tag=key, simple_value=value / n_iter_per_epoch) 230 | if writer_test is not None: 231 | writer_test.add_summary(summary, current_global_step) 232 | 233 | saver.save(sess, sv.save_path, global_step=global_step) 234 | sv.stop() 235 | 236 | 237 | if __name__ == "__main__": 238 | tf.app.run() 239 | -------------------------------------------------------------------------------- /convlarge/train_cifar.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy 4 | import tensorflow as tf 5 | 6 | import layers as L 7 | import cnn 8 | 9 | from flip_gradient import flip_gradient 10 | from cifar10 import inputs, unlabeled_inputs 11 | 12 | FLAGS = tf.app.flags.FLAGS 13 | 14 | tf.app.flags.DEFINE_string('device', '/gpu:0', "device") 15 | tf.app.flags.DEFINE_string('dataset', 'cifar10', "{cifar10, svhn}") 16 | tf.app.flags.DEFINE_string('log_dir', "", "log_dir") 17 | tf.app.flags.DEFINE_integer('seed', 1, "initial random seed") 18 | tf.app.flags.DEFINE_bool('validation', False, "") 19 | tf.app.flags.DEFINE_bool('one_hot', False, "") 20 | tf.app.flags.DEFINE_integer('batch_size', 100, "the number of examples in a batch") 21 | tf.app.flags.DEFINE_integer('ul_batch_size', 100, "the number of unlabeled examples in a batch") 22 | tf.app.flags.DEFINE_integer('eval_batch_size', 100, "the number of eval examples in a batch") 23 | tf.app.flags.DEFINE_integer('eval_freq', 5, "") 24 | tf.app.flags.DEFINE_integer('num_epochs', 120, "the number of epochs for training") 25 | tf.app.flags.DEFINE_integer('epoch_decay_start', 80, "epoch of starting learning rate decay") 26 | tf.app.flags.DEFINE_integer('num_iter_per_epoch', int(400*128/100), "the number of updates per epoch") 27 | tf.app.flags.DEFINE_float('learning_rate', 0.001, "initial learning rate") 28 | tf.app.flags.DEFINE_float('mom1', 0.9, "initial momentum rate") 29 | tf.app.flags.DEFINE_float('mom2', 0.5, "momentum rate after epoch_decay_start") 30 | 31 | 32 | 33 | 34 | NUM_EVAL_EXAMPLES = 5000 35 | 36 | def logit(x, is_training=True, update_batch_stats=True, stochastic=True, seed=1234): 37 | return cnn.logit(x, is_training=is_training, 38 | update_batch_stats=update_batch_stats, 39 | stochastic=stochastic, 40 | seed=seed)[0] 41 | 42 | 43 | def forward(x, is_training=True, update_batch_stats=True, seed=1234): 44 | if is_training: 45 | return logit(x, is_training=True, 46 | update_batch_stats=update_batch_stats, 47 | stochastic=True, seed=seed) 48 | else: 49 | return logit(x, is_training=False, 50 | update_batch_stats=update_batch_stats, 51 | stochastic=False, seed=seed) 52 | 53 | def build_training_graph(x1, y1, x2, lr, mom): 54 | global_step = tf.get_variable( 55 | name="global_step", 56 | shape=[], 57 | dtype=tf.float32, 58 | initializer=tf.constant_initializer(0.0), 59 | trainable=False, 60 | ) 61 | k = 1. * global_step / (FLAGS.num_iter_per_epoch * FLAGS.num_epochs) 62 | # lp schedule from GRL 63 | lp = 1. * (2. / (1. + tf.exp(-10. 
* k)) - 1) 64 | 65 | # Interpolation 66 | y2_logit, _ = cnn.logit(x2, is_training=False, update_batch_stats=False, stochastic=False) 67 | if FLAGS.one_hot: 68 | y2 = tf.stop_gradient(tf.cast(tf.one_hot(tf.argmax(y2_logit, -1), 10), tf.float32)) 69 | else: 70 | y2 = tf.stop_gradient(tf.nn.softmax(y2_logit)) 71 | 72 | dist_beta = tf.distributions.Beta(1.0, 1.0) 73 | lmb = dist_beta.sample(tf.shape(x1)[0]) 74 | lmb_x = tf.reshape(lmb, [-1, 1, 1, 1]) 75 | lmb_y = tf.reshape(lmb, [-1, 1]) 76 | x = x1 * lmb_x + x2 * (1. - lmb_x) 77 | y = y1 * lmb_y + y2 * (1. - lmb_y) 78 | 79 | label_dm = tf.concat([tf.reshape(lmb, [-1, 1]), tf.reshape(1. - lmb, [-1, 1])], axis=1) 80 | 81 | # Calculate the feats and logits on interpolated samples 82 | with tf.variable_scope(tf.get_variable_scope(), reuse=True): 83 | logit, net = cnn.logit(x, is_training=True, update_batch_stats=True) 84 | 85 | # Alignment Loss 86 | net_ = flip_gradient(net, lp) 87 | logitsdm = tf.layers.dense(net_, 1024, activation=tf.nn.relu, name='linear_dm1') 88 | logitsdm = tf.layers.dense(logitsdm, 1024, activation=tf.nn.relu, name='linear_dm2') 89 | logits_dm = tf.layers.dense(logitsdm, 2, name="logits_dm") 90 | dm_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=label_dm, logits=logits_dm)) 91 | additional_loss = dm_loss 92 | 93 | nll_loss = tf.reduce_mean(lmb*tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logit)) 94 | 95 | loss = nll_loss + additional_loss 96 | 97 | opt = tf.train.AdamOptimizer(learning_rate=lr, beta1=mom) 98 | tvars = tf.trainable_variables() 99 | grads_and_vars = opt.compute_gradients(loss, tvars) 100 | train_op = opt.apply_gradients(grads_and_vars, global_step=global_step) 101 | return loss, train_op, global_step 102 | 103 | 104 | def build_eval_graph(x, y, ul_x): 105 | losses = {} 106 | logit = forward(x, is_training=False, update_batch_stats=False) 107 | nll_loss = L.ce_loss(logit, y) 108 | losses['NLL'] = nll_loss 109 | acc = L.accuracy(logit, y) 110 | losses['Acc'] = acc 111 | return losses 112 | 113 | 114 | def main(_): 115 | numpy.random.seed(seed=FLAGS.seed) 116 | tf.set_random_seed(numpy.random.randint(1234)) 117 | with tf.Graph().as_default() as g: 118 | with tf.device("/cpu:0"): 119 | images, labels = inputs(batch_size=FLAGS.batch_size, 120 | train=True, 121 | validation=FLAGS.validation, 122 | shuffle=True) 123 | ul_images = unlabeled_inputs(batch_size=FLAGS.ul_batch_size, 124 | validation=FLAGS.validation, 125 | shuffle=True) 126 | 127 | images_eval_train, labels_eval_train = inputs(batch_size=FLAGS.eval_batch_size, 128 | train=True, 129 | validation=FLAGS.validation, 130 | shuffle=True) 131 | ul_images_eval_train = unlabeled_inputs(batch_size=FLAGS.eval_batch_size, 132 | validation=FLAGS.validation, 133 | shuffle=True) 134 | 135 | images_eval_test, labels_eval_test = inputs(batch_size=FLAGS.eval_batch_size, 136 | train=False, 137 | validation=FLAGS.validation, 138 | shuffle=True) 139 | 140 | with tf.device(FLAGS.device): 141 | lr = tf.placeholder(tf.float32, shape=[], name="learning_rate") 142 | mom = tf.placeholder(tf.float32, shape=[], name="momentum") 143 | with tf.variable_scope("CNN") as scope: 144 | # Build training graph 145 | loss, train_op, global_step = build_training_graph(images, labels, ul_images, lr, mom) 146 | scope.reuse_variables() 147 | # Build eval graph 148 | losses_eval_train = build_eval_graph(images_eval_train, labels_eval_train, ul_images_eval_train) 149 | losses_eval_test = build_eval_graph(images_eval_test, labels_eval_test, images_eval_test) 150 | 
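        # Note on the schedule used in the epoch loop below: the learning rate is held at FLAGS.learning_rate until epoch_decay_start, then decays linearly to zero as lr(ep) = learning_rate * (num_epochs - ep) / (num_epochs - epoch_decay_start); Adam's beta1 switches from mom1 to mom2 at the same point.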
151 | init_op = tf.global_variables_initializer() 152 | 153 | if not FLAGS.log_dir: 154 | logdir = None 155 | writer_train = None 156 | writer_test = None 157 | else: 158 | logdir = FLAGS.log_dir 159 | writer_train = tf.summary.FileWriter(FLAGS.log_dir + "/train", g) 160 | writer_test = tf.summary.FileWriter(FLAGS.log_dir + "/test", g) 161 | 162 | saver = tf.train.Saver(tf.global_variables()) 163 | sv = tf.train.Supervisor( 164 | is_chief=True, 165 | logdir=logdir, 166 | init_op=init_op, 167 | init_feed_dict={lr: FLAGS.learning_rate, mom: FLAGS.mom1}, 168 | saver=saver, 169 | global_step=global_step, 170 | summary_op=None, 171 | summary_writer=None, 172 | save_model_secs=150, recovery_wait_secs=0) 173 | 174 | print("Training...") 175 | with sv.managed_session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: 176 | for ep in range(FLAGS.num_epochs): 177 | if sv.should_stop(): 178 | break 179 | 180 | if ep < FLAGS.epoch_decay_start: 181 | feed_dict = {lr: FLAGS.learning_rate, mom: FLAGS.mom1} 182 | else: 183 | decayed_lr = ((FLAGS.num_epochs - ep) / float( 184 | FLAGS.num_epochs - FLAGS.epoch_decay_start)) * FLAGS.learning_rate 185 | feed_dict = {lr: decayed_lr, mom: FLAGS.mom2} 186 | 187 | sum_loss = 0 188 | start = time.time() 189 | for i in range(FLAGS.num_iter_per_epoch): 190 | _, batch_loss, _ = sess.run([train_op, loss, global_step], 191 | feed_dict=feed_dict) 192 | sum_loss += batch_loss 193 | end = time.time() 194 | print("Epoch:", ep, "CE_loss_train:", sum_loss / FLAGS.num_iter_per_epoch, "elapsed_time:", end - start) 195 | 196 | if (ep + 1) % FLAGS.eval_freq == 0 or ep + 1 == FLAGS.num_epochs: 197 | # Eval on training data 198 | act_values_dict = {} 199 | for key, _ in losses_eval_train.items(): 200 | act_values_dict[key] = 0 201 | n_iter_per_epoch = NUM_EVAL_EXAMPLES // FLAGS.eval_batch_size 202 | for i in range(n_iter_per_epoch): 203 | values = list(losses_eval_train.values()) 204 | act_values = sess.run(values) 205 | for key, value in zip(list(act_values_dict.keys()), act_values): 206 | act_values_dict[key] += value 207 | summary = tf.Summary() 208 | current_global_step = sess.run(global_step) 209 | for key, value in act_values_dict.items(): 210 | print("train-" + key, value / n_iter_per_epoch) 211 | summary.value.add(tag=key, simple_value=value / n_iter_per_epoch) 212 | if writer_train is not None: 213 | writer_train.add_summary(summary, current_global_step) 214 | 215 | # Eval on test data 216 | act_values_dict = {} 217 | for key, _ in losses_eval_test.items(): 218 | act_values_dict[key] = 0 219 | n_iter_per_epoch = NUM_EVAL_EXAMPLES // FLAGS.eval_batch_size 220 | for i in range(n_iter_per_epoch): 221 | values = list(losses_eval_test.values()) 222 | act_values = sess.run(values) 223 | for key, value in zip(list(act_values_dict.keys()), act_values): 224 | act_values_dict[key] += value 225 | summary = tf.Summary() 226 | current_global_step = sess.run(global_step) 227 | for key, value in act_values_dict.items(): 228 | print("test-" + key, value / n_iter_per_epoch) 229 | summary.value.add(tag=key, simple_value=value / n_iter_per_epoch) 230 | if writer_test is not None: 231 | writer_test.add_summary(summary, current_global_step) 232 | 233 | saver.save(sess, sv.save_path, global_step=global_step) 234 | sv.stop() 235 | 236 | 237 | if __name__ == "__main__": 238 | tf.app.run() 239 | -------------------------------------------------------------------------------- /resnet/ilsvrcsemi.py: -------------------------------------------------------------------------------- 1 | # -*- 
coding: utf-8 -*- 2 | # File: ilsvrcsemi.py 3 | 4 | import numpy as np 5 | import os 6 | import tarfile 7 | import tqdm 8 | 9 | from tensorpack.utils import logger 10 | from tensorpack.utils.fs import download, get_dataset_path, mkdir_p 11 | from tensorpack.utils.loadcaffe import get_caffe_pb 12 | from tensorpack.utils.timer import timed_operation 13 | from tensorpack.dataflow.base import RNGDataFlow 14 | 15 | __all__ = ['ILSVRCMeta', 'ILSVRC12', 'ILSVRC12Files'] 16 | 17 | CAFFE_ILSVRC12_URL = ("http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz", 17858008) 18 | 19 | 20 | class ILSVRCMeta(object): 21 | """ 22 | Provide methods to access metadata for ILSVRC dataset. 23 | """ 24 | 25 | def __init__(self, dir=None): 26 | if dir is None: 27 | dir = get_dataset_path('ilsvrc_metadata') 28 | self.dir = os.path.expanduser(dir) 29 | mkdir_p(self.dir) 30 | f = os.path.join(self.dir, 'synsets.txt') 31 | if not os.path.isfile(f): 32 | self._download_caffe_meta() 33 | self.caffepb = None 34 | 35 | def get_synset_words_1000(self): 36 | """ 37 | Returns: 38 | dict: {cls_number: cls_name} 39 | """ 40 | fname = os.path.join(self.dir, 'synset_words.txt') 41 | assert os.path.isfile(fname) 42 | lines = [x.strip() for x in open(fname).readlines()] 43 | return dict(enumerate(lines)) 44 | 45 | def get_synset_1000(self): 46 | """ 47 | Returns: 48 | dict: {cls_number: synset_id} 49 | """ 50 | fname = os.path.join(self.dir, 'synsets.txt') 51 | assert os.path.isfile(fname) 52 | lines = [x.strip() for x in open(fname).readlines()] 53 | return dict(enumerate(lines)) 54 | 55 | def _download_caffe_meta(self): 56 | fpath = download(CAFFE_ILSVRC12_URL[0], self.dir, expect_size=CAFFE_ILSVRC12_URL[1]) 57 | tarfile.open(fpath, 'r:gz').extractall(self.dir) 58 | 59 | def get_image_list(self, name, dir_structure='original', labeled=True): 60 | """ 61 | Args: 62 | name (str): 'train' or 'val' or 'test' 63 | dir_structure (str): same as in :meth:`ILSVRC12.__init__()`. 64 | Returns: 65 | list: list of (image filename, label) 66 | """ 67 | assert name in ['train', 'val', 'test'] 68 | assert dir_structure in ['original', 'train'] 69 | add_label_to_fname = (name != 'train' and dir_structure != 'original') 70 | if add_label_to_fname: 71 | synset = self.get_synset_1000() 72 | 73 | if labeled: 74 | print("Read labeled training") 75 | fname = name + '_labeled.txt' 76 | else: 77 | print("Read unlabeled training") 78 | fname = name + '_unlabeled.txt' 79 | assert os.path.isfile(fname), fname 80 | with open(fname) as f: 81 | lines = f.readlines() 82 | if labeled: 83 | lines = lines * 9  # replicate the labeled list, presumably to balance it against the much larger unlabeled list 84 | np.random.shuffle(lines) 85 | ret = [] 86 | for line in lines: 87 | name, cls = line.strip().split() 88 | cls = int(cls) 89 | 90 | if add_label_to_fname: 91 | name = os.path.join(synset[cls], name) 92 | 93 | ret.append((name.strip(), cls)) 94 | assert len(ret), fname 95 | print(len(ret)) 96 | return ret 97 | 98 | def get_per_pixel_mean(self, size=None): 99 | """ 100 | Args: 101 | size (tuple): image size in (h, w). Defaults to (256, 256). 102 | Returns: 103 | np.ndarray: per-pixel mean of shape (h, w, 3 (BGR)) in range [0, 255]. 
104 | """ 105 | if self.caffepb is None: 106 | self.caffepb = get_caffe_pb() 107 | obj = self.caffepb.BlobProto() 108 | 109 | mean_file = os.path.join(self.dir, 'imagenet_mean.binaryproto') 110 | with open(mean_file, 'rb') as f: 111 | obj.ParseFromString(f.read()) 112 | arr = np.array(obj.data).reshape((3, 256, 256)).astype('float32') 113 | arr = np.transpose(arr, [1, 2, 0]) 114 | if size is not None: 115 | arr = cv2.resize(arr, size[::-1]) 116 | return arr 117 | 118 | @staticmethod 119 | def guess_dir_structure(dir): 120 | """ 121 | Return the directory structure of "dir". 122 | 123 | Args: 124 | dir(str): something like '/path/to/imagenet/val' 125 | 126 | Returns: 127 | either 'train' or 'original' 128 | """ 129 | subdir = os.listdir(dir)[0] 130 | # find a subdir starting with 'n' 131 | if subdir.startswith('n') and \ 132 | os.path.isdir(os.path.join(dir, subdir)): 133 | dir_structure = 'train' 134 | else: 135 | dir_structure = 'original' 136 | logger.info( 137 | "[ILSVRC12] Assuming directory {} has '{}' structure.".format( 138 | dir, dir_structure)) 139 | return dir_structure 140 | 141 | 142 | class ILSVRC12Files(RNGDataFlow): 143 | """ 144 | Same as :class:`ILSVRC12`, but produces filenames of the images instead of nparrays. 145 | This could be useful when ``cv2.imread`` is a bottleneck and you want to 146 | decode it in smarter ways (e.g. in parallel). 147 | """ 148 | def __init__(self, dir, name, meta_dir=None, 149 | shuffle=None, dir_structure=None, labeled=True): 150 | """ 151 | Same as in :class:`ILSVRC12`. 152 | """ 153 | assert name in ['train', 'test', 'val'], name 154 | dir = os.path.expanduser(dir) 155 | assert os.path.isdir(dir), dir 156 | self.full_dir = os.path.join(dir, name) 157 | self.name = name 158 | assert os.path.isdir(self.full_dir), self.full_dir 159 | assert meta_dir is None or os.path.isdir(meta_dir), meta_dir 160 | if shuffle is None: 161 | shuffle = name == 'train' 162 | self.shuffle = shuffle 163 | 164 | if name == 'train': 165 | dir_structure = 'train' 166 | if dir_structure is None: 167 | dir_structure = ILSVRCMeta.guess_dir_structure(self.full_dir) 168 | 169 | meta = ILSVRCMeta(meta_dir) 170 | self.imglist = meta.get_image_list(name, dir_structure, labeled=labeled) 171 | 172 | for fname, _ in self.imglist[:10]: 173 | fname = os.path.join(self.full_dir, fname) 174 | assert os.path.isfile(fname), fname 175 | 176 | def __len__(self): 177 | return len(self.imglist) 178 | 179 | def __iter__(self): 180 | idxs = np.arange(len(self.imglist)) 181 | if self.shuffle: 182 | self.rng.shuffle(idxs) 183 | for k in idxs: 184 | fname, label = self.imglist[k] 185 | fname = os.path.join(self.full_dir, fname) 186 | yield [fname, label] 187 | 188 | 189 | class ILSVRC12(ILSVRC12Files): 190 | """ 191 | Produces uint8 ILSVRC12 images of shape [h, w, 3(BGR)], and a label between [0, 999]. 192 | """ 193 | def __init__(self, dir, name, meta_dir=None, 194 | shuffle=None, dir_structure=None, labeled=True): 195 | """ 196 | Args: 197 | dir (str): A directory containing a subdir named ``name``, 198 | containing the images in a structure described below. 199 | name (str): One of 'train' or 'val' or 'test'. 200 | shuffle (bool): shuffle the dataset. 201 | Defaults to True if name=='train'. 202 | dir_structure (str): One of 'original' or 'train'. 203 | The directory structure for the 'val' directory. 204 | 'original' means the original decompressed directory, which only has list of image files (as below). 
205 | If set to 'train', it expects the same two-level directory structure similar to 'dir/train/'. 206 | By default, it tries to automatically detect the structure. 207 | You probably do not need to care about this option because 'original' is what people usually have. 208 | 209 | Example: 210 | 211 | When `dir_structure=='original'`, `dir` should have the following structure: 212 | 213 | .. code-block:: none 214 | 215 | dir/ 216 | train/ 217 | n02134418/ 218 | n02134418_198.JPEG 219 | ... 220 | ... 221 | val/ 222 | ILSVRC2012_val_00000001.JPEG 223 | ... 224 | test/ 225 | ILSVRC2012_test_00000001.JPEG 226 | ... 227 | 228 | With the downloaded ILSVRC12_img_*.tar, you can use the following 229 | command to build the above structure: 230 | 231 | .. code-block:: none 232 | 233 | mkdir val && tar xvf ILSVRC12_img_val.tar -C val 234 | mkdir test && tar xvf ILSVRC12_img_test.tar -C test 235 | mkdir train && tar xvf ILSVRC12_img_train.tar -C train && cd train 236 | find -type f -name '*.tar' | parallel -P 10 'echo {} && mkdir -p {/.} && tar xf {} -C {/.}' 237 | 238 | When `dir_structure=='train'`, `dir` should have the following structure: 239 | 240 | .. code-block:: none 241 | 242 | dir/ 243 | train/ 244 | n02134418/ 245 | n02134418_198.JPEG 246 | ... 247 | ... 248 | val/ 249 | n01440764/ 250 | ILSVRC2012_val_00000293.JPEG 251 | ... 252 | ... 253 | test/ 254 | ILSVRC2012_test_00000001.JPEG 255 | ... 256 | """ 257 | super(ILSVRC12, self).__init__( 258 | dir, name, meta_dir, shuffle, dir_structure, labeled=labeled) 259 | """ 260 | There are some CMYK / png images, but cv2 seems robust to them. 261 | https://github.com/tensorflow/models/blob/c0cd713f59cfe44fa049b3120c417cc4079c17e3/research/inception/inception/data/build_imagenet_data.py#L264-L300 262 | """ 263 | def __iter__(self): 264 | for fname, label in super(ILSVRC12, self).__iter__(): 265 | im = cv2.imread(fname, cv2.IMREAD_COLOR) 266 | assert im is not None, fname 267 | yield [im, label] 268 | 269 | @staticmethod 270 | def get_training_bbox(bbox_dir, imglist): 271 | import xml.etree.ElementTree as ET 272 | ret = [] 273 | 274 | def parse_bbox(fname): 275 | root = ET.parse(fname).getroot() 276 | size = root.find('size').getchildren() 277 | size = map(int, [size[0].text, size[1].text]) 278 | 279 | box = root.find('object').find('bndbox').getchildren() 280 | box = map(lambda x: float(x.text), box) 281 | return np.asarray(box, dtype='float32') 282 | 283 | with timed_operation('Loading Bounding Boxes ...'): 284 | cnt = 0 285 | for k in tqdm.trange(len(imglist)): 286 | fname = imglist[k][0] 287 | fname = fname[:-4] + 'xml' 288 | fname = os.path.join(bbox_dir, fname) 289 | try: 290 | ret.append(parse_bbox(fname)) 291 | cnt += 1 292 | except Exception: 293 | ret.append(None) 294 | logger.info("{}/{} images have bounding box.".format(cnt, len(imglist))) 295 | return ret 296 | 297 | 298 | try: 299 | import cv2 300 | except ImportError: 301 | from tensorpack.utils.develop import create_dummy_class 302 | ILSVRC12 = create_dummy_class('ILSVRC12', 'cv2') # noqa 303 | 304 | if __name__ == '__main__': 305 | meta = ILSVRCMeta() 306 | # print(meta.get_synset_words_1000()) 307 | 308 | ds = ILSVRC12('/home/wyx/data/fake_ilsvrc/', 'train', shuffle=False) 309 | ds.reset_state() 310 | 311 | for k in ds: 312 | from IPython import embed 313 | embed() 314 | break 315 | -------------------------------------------------------------------------------- /resnet/imagenet_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: 
utf-8 -*- 2 | # File: imagenet_utils.py 3 | 4 | 5 | import cv2 6 | import numpy as np 7 | import tqdm 8 | import multiprocessing 9 | import tensorflow as tf 10 | from abc import abstractmethod 11 | 12 | from tensorpack import * 13 | from tensorpack import ModelDesc 14 | from tensorpack.input_source import QueueInput, StagingInput 15 | from tensorpack.dataflow import ( 16 | JoinData, imgaug, dataset, AugmentImageComponent, PrefetchDataZMQ, 17 | BatchData, MultiThreadMapData) 18 | from tensorpack.predict import PredictConfig, FeedfreePredictor 19 | from tensorpack.utils.stats import RatioCounter 20 | from tensorpack.models import regularize_cost 21 | from tensorpack.tfutils.summary import add_moving_summary 22 | from tensorpack.tfutils.common import get_global_step_var 23 | from tensorpack.utils import logger 24 | import ilsvrcsemi 25 | from flip_gradient import flip_gradient 26 | 27 | class GoogleNetResize(imgaug.ImageAugmentor): 28 | """ 29 | crop 8%~100% of the original image 30 | See `Going Deeper with Convolutions` by Google. 31 | """ 32 | def __init__(self, crop_area_fraction=0.08, 33 | aspect_ratio_low=0.75, aspect_ratio_high=1.333, 34 | target_shape=224): 35 | self._init(locals()) 36 | 37 | def _augment(self, img, _): 38 | h, w = img.shape[:2] 39 | area = h * w 40 | for _ in range(10): 41 | targetArea = self.rng.uniform(self.crop_area_fraction, 1.0) * area 42 | aspectR = self.rng.uniform(self.aspect_ratio_low, self.aspect_ratio_high) 43 | ww = int(np.sqrt(targetArea * aspectR) + 0.5) 44 | hh = int(np.sqrt(targetArea / aspectR) + 0.5) 45 | if self.rng.uniform() < 0.5: 46 | ww, hh = hh, ww 47 | if hh <= h and ww <= w: 48 | x1 = 0 if w == ww else self.rng.randint(0, w - ww) 49 | y1 = 0 if h == hh else self.rng.randint(0, h - hh) 50 | out = img[y1:y1 + hh, x1:x1 + ww] 51 | out = cv2.resize(out, (self.target_shape, self.target_shape), interpolation=cv2.INTER_CUBIC) 52 | return out 53 | out = imgaug.ResizeShortestEdge(self.target_shape, interp=cv2.INTER_CUBIC).augment(img) 54 | out = imgaug.CenterCrop(self.target_shape).augment(out) 55 | return out 56 | 57 | 58 | def fbresnet_augmentor(isTrain): 59 | """ 60 | Augmentor used in fb.resnet.torch, for BGR images in range [0,255]. 61 | """ 62 | if isTrain: 63 | augmentors = [ 64 | GoogleNetResize(), 65 | # It's OK to remove the following augs if your CPU is not fast enough. 66 | # Removing brightness/contrast/saturation does not have a significant effect on accuracy. 67 | # Removing lighting leads to a tiny drop in accuracy. 
68 | 69 | imgaug.RandomOrderAug( 70 | [# We removed the following augmentation 71 | #imgaug.BrightnessScale((0.6, 1.4), clip=False), 72 | #imgaug.Contrast((0.6, 1.4), clip=False), 73 | #imgaug.Saturation(0.4, rgb=False), 74 | #rgb-bgr conversion for the constants copied from fb.resnet.torch 75 | imgaug.Lighting(0.1, 76 | eigval=np.asarray( 77 | [0.2175, 0.0188, 0.0045][::-1]) * 255.0, 78 | eigvec=np.array( 79 | [[-0.5675, 0.7192, 0.4009], 80 | [-0.5808, -0.0045, -0.8140], 81 | [-0.5836, -0.6948, 0.4203]], 82 | dtype='float32')[::-1, ::-1] 83 | )]), 84 | imgaug.Flip(horiz=True), 85 | ] 86 | else: 87 | augmentors = [ 88 | imgaug.ResizeShortestEdge(256, cv2.INTER_CUBIC), 89 | imgaug.CenterCrop((224, 224)), 90 | ] 91 | return augmentors 92 | 93 | 94 | def get_imagenet_dataflow( 95 | datadir, name, batch_size, 96 | augmentors, parallel=None): 97 | """ 98 | See explanations in the tutorial: 99 | http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html 100 | """ 101 | assert name in ['train', 'val', 'test'] 102 | assert datadir is not None 103 | assert isinstance(augmentors, list) 104 | isTrain = name == 'train' 105 | if parallel is None: 106 | parallel = min(40, multiprocessing.cpu_count() // 2) # assuming hyperthreading 107 | if isTrain: 108 | ds1 = ilsvrcsemi.ILSVRC12(datadir, name, shuffle=True, labeled=True) 109 | ds2 = ilsvrcsemi.ILSVRC12(datadir, name, shuffle=True, labeled=False) 110 | ds1 = AugmentImageComponent(ds1, augmentors, copy=False) 111 | ds2 = AugmentImageComponent(ds2, augmentors, copy=False) 112 | ds = JoinData([ds1, ds2]) 113 | 114 | if parallel < 16: 115 | logger.warn("DataFlow may become the bottleneck when too few processes are used.") 116 | ds = PrefetchDataZMQ(ds, parallel) 117 | ds = BatchData(ds, batch_size, remainder=False) 118 | else: 119 | ds = dataset.ILSVRC12Files(datadir, name, shuffle=False) 120 | aug = imgaug.AugmentorList(augmentors) 121 | 122 | def mapf(dp): 123 | fname, cls = dp 124 | im = cv2.imread(fname, cv2.IMREAD_COLOR) 125 | im = aug.augment(im) 126 | return im, cls, im, cls 127 | ds = MultiThreadMapData(ds, parallel, mapf, buffer_size=2000, strict=True) 128 | ds = BatchData(ds, batch_size, remainder=True) 129 | ds = PrefetchDataZMQ(ds, 1) 130 | return ds 131 | 132 | 133 | def eval_on_ILSVRC12(model, sessinit, dataflow): 134 | pred_config = PredictConfig( 135 | model=model, 136 | session_init=sessinit, 137 | input_names=['input', 'label', 'input2', 'label2'], 138 | output_names=['wrong-top1', 'wrong-top5'] 139 | ) 140 | acc1, acc5 = RatioCounter(), RatioCounter() 141 | 142 | # This does not have a visible improvement over naive predictor, 143 | # but will have an improvement if image_dtype is set to float32. 144 | pred = FeedfreePredictor(pred_config, StagingInput(QueueInput(dataflow), device='/gpu:0')) 145 | for _ in tqdm.trange(dataflow.size()): 146 | top1, top5 = pred() 147 | batch_size = top1.shape[0] 148 | acc1.feed(top1.sum(), batch_size) 149 | acc5.feed(top5.sum(), batch_size) 150 | 151 | print("Top1 Error: {}".format(acc1.ratio)) 152 | print("Top5 Error: {}".format(acc5.ratio)) 153 | 154 | 155 | class ImageNetModel(ModelDesc): 156 | image_shape = 224 157 | 158 | """ 159 | uint8 instead of float32 is used as input type to reduce copy overhead. 160 | It might hurt the performance a little bit. 161 | The pretrained models were trained with float32. 162 | """ 163 | image_dtype = tf.uint8 164 | 165 | """ 166 | Either 'NCHW' or 'NHWC' 167 | """ 168 | data_format = 'NCHW' 169 | 170 | """ 171 | Whether the image is BGR or RGB. 
If using DataFlow, then it should be BGR. 172 | """ 173 | image_bgr = True 174 | 175 | weight_decay = 1e-4 176 | 177 | """ 178 | To apply on normalization parameters, use '.*/W|.*/gamma|.*/beta' 179 | """ 180 | weight_decay_pattern = '.*/W' 181 | 182 | """ 183 | Scale the loss, for whatever reasons (e.g., gradient averaging, fp16 training, etc) 184 | """ 185 | loss_scale = 1. 186 | 187 | """ 188 | Label smoothing (See tf.losses.softmax_cross_entropy) 189 | """ 190 | label_smoothing = 0. 191 | 192 | def inputs(self): 193 | return [tf.placeholder(self.image_dtype, [None, self.image_shape, self.image_shape, 3], 'input'), 194 | tf.placeholder(tf.int32, [None], 'label'), 195 | tf.placeholder(self.image_dtype, [None, self.image_shape, self.image_shape, 3], 'input2'), 196 | tf.placeholder(tf.int32, [None], 'label2')] 197 | def build_graph(self, image1, label1, image2, _): 198 | image1 = self.image_preprocess(image1) 199 | image2 = self.image_preprocess(image2) 200 | is_training = get_current_tower_context().is_training 201 | 202 | # Shuffle unlabeled data within batch 203 | if is_training: 204 | image2 = tf.random_shuffle(image2) 205 | 206 | assert self.data_format in ['NCHW', 'NHWC'] 207 | if self.data_format == 'NCHW': 208 | image1 = tf.transpose(image1, [0, 3, 1, 2]) 209 | image2 = tf.transpose(image2, [0, 3, 1, 2]) 210 | 211 | # Pseudo-label the unlabeled batch 212 | logits2, _ = self.get_logits(image2) 213 | label2 = tf.nn.softmax(logits2) 214 | 215 | # Update this constant if you modify the training schedule or batch size (here: 60 epochs, 1,280,000 training images, batch size 256) 216 | k = tf.cast(get_global_step_var(), tf.float32) / (60 * 1280000 / 256) 217 | 218 | # Sample lambda 219 | dist_beta = tf.distributions.Beta(1.0, 1.0) 220 | lmb = dist_beta.sample(tf.shape(image1)[0]) 221 | lmb_x = tf.reshape(lmb, [-1, 1, 1, 1]) 222 | lmb_y = tf.reshape(lmb, [-1, 1]) 223 | 224 | # Interpolation 225 | label_ori = label1 226 | if is_training: 227 | image = tf.to_float(image1) * lmb_x + tf.to_float(image2) * (1. - lmb_x) 228 | label = tf.stop_gradient(tf.to_float(tf.one_hot(label1, 1000)) * lmb_y + tf.to_float(label2) * (1. - lmb_y)) 229 | else: 230 | image = image1 231 | label = tf.to_float(tf.one_hot(label1, 1000)) 232 | 233 | # Calculate feats and logits for interpolated samples 234 | with tf.variable_scope(tf.get_variable_scope(), reuse=True): 235 | logits, features = self.get_logits(image) 236 | 237 | # Classification Loss and error 238 | loss = ImageNetModel.compute_loss_and_error( 239 | logits, label, label_smoothing=self.label_smoothing, lmb=lmb, label_ori=label_ori) 240 | 241 | # Distribution Alignment 242 | lp = 2. / (1. + tf.exp(-10. * k)) - 1 243 | net_ = flip_gradient(features, lp) 244 | fc1 = FullyConnected('linear_1', net_, 1024, nl=tf.nn.relu) 245 | fc2 = FullyConnected('linear_2', fc1, 1024, nl=tf.nn.relu) 246 | domain_logits = FullyConnected("logits_dm", fc2, 2) 247 | label_dm = tf.concat([tf.reshape(lmb, [-1, 1]), tf.reshape(1. 
- lmb, [-1, 1])], axis=1) 248 | da_cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=label_dm, logits=domain_logits)) 249 | 250 | # Final Loss 251 | loss += da_cost 252 | 253 | 254 | if self.weight_decay > 0: 255 | wd_loss = regularize_cost(self.weight_decay_pattern, 256 | tf.contrib.layers.l2_regularizer(self.weight_decay), 257 | name='l2_regularize_loss') 258 | add_moving_summary(loss, wd_loss) 259 | total_cost = tf.add_n([loss, wd_loss], name='cost') 260 | else: 261 | total_cost = tf.identity(loss, name='cost') 262 | add_moving_summary(total_cost) 263 | 264 | if self.loss_scale != 1.: 265 | logger.info("Scaling the total loss by {} ...".format(self.loss_scale)) 266 | return total_cost * self.loss_scale 267 | else: 268 | return total_cost 269 | 270 | @abstractmethod 271 | def get_logits(self, image): 272 | """ 273 | Args: 274 | image: 4D tensor of ``self.input_shape`` in ``self.data_format`` 275 | 276 | Returns: 277 | Nx#class logits 278 | """ 279 | 280 | def optimizer(self): 281 | lr = tf.get_variable('learning_rate', initializer=0.1, trainable=False) 282 | tf.summary.scalar('learning_rate-summary', lr) 283 | return tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True) 284 | 285 | def image_preprocess(self, image): 286 | with tf.name_scope('image_preprocess'): 287 | if image.dtype.base_dtype != tf.float32: 288 | image = tf.cast(image, tf.float32) 289 | mean = [0.485, 0.456, 0.406] # rgb 290 | std = [0.229, 0.224, 0.225] 291 | if self.image_bgr: 292 | mean = mean[::-1] 293 | std = std[::-1] 294 | image_mean = tf.constant(mean, dtype=tf.float32) * 255. 295 | image_std = tf.constant(std, dtype=tf.float32) * 255. 296 | image = (image - image_mean) / image_std 297 | return image 298 | 299 | @staticmethod 300 | def compute_loss_and_error(logits, label, label_smoothing=0., lmb=1., label_ori=-1): 301 | loss = lmb * tf.losses.softmax_cross_entropy( 302 | label, logits, label_smoothing=label_smoothing) 303 | loss = tf.reduce_mean(loss, name='xentropy-loss') 304 | 305 | def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'): 306 | with tf.name_scope('prediction_incorrect'): 307 | x = tf.logical_not(tf.nn.in_top_k(logits, label, topk)) 308 | return tf.cast(x, tf.float32, name=name) 309 | 310 | wrong = prediction_incorrect(logits, label_ori, 1, name='wrong-top1') 311 | add_moving_summary(tf.reduce_mean(wrong, name='train-error-top1')) 312 | 313 | wrong = prediction_incorrect(logits, label_ori, 5, name='wrong-top5') 314 | add_moving_summary(tf.reduce_mean(wrong, name='train-error-top5')) 315 | return loss 316 | 317 | 318 | if __name__ == '__main__': 319 | import argparse 320 | from tensorpack.dataflow import TestDataSpeed 321 | parser = argparse.ArgumentParser() 322 | parser.add_argument('--data', required=True) 323 | parser.add_argument('--batch', type=int, default=32) 324 | parser.add_argument('--aug', choices=['train', 'val'], default='val') 325 | args = parser.parse_args() 326 | 327 | if args.aug == 'val': 328 | augs = fbresnet_augmentor(False) 329 | elif args.aug == 'train': 330 | augs = fbresnet_augmentor(True) 331 | df = get_imagenet_dataflow( 332 | args.data, 'train', args.batch, augs) 333 | # For the val augmentor, you should get >100 it/s (i.e. 3k im/s) here on a decent E5 server. 334 | TestDataSpeed(df).start() 335 | --------------------------------------------------------------------------------
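For reference, below is a minimal NumPy sketch of the two ingredients shared by `train_svhn.py`, `train_cifar.py`, and `imagenet_utils.py`: the cross-set interpolation between a labeled batch and a pseudo-labeled unlabeled batch, and the GRL coefficient schedule `2 / (1 + exp(-10k)) - 1`. The helper names (`grl_coefficient`, `interpolate_batches`) are illustrative only and do not exist in this repository; in the scripts above, `alpha` is 0.1 for SVHN and 1.0 for CIFAR-10 and ImageNet.
```
import numpy as np

def grl_coefficient(step, total_steps):
    # GRL schedule used in the trainers above: ramps from 0 to 1 over training.
    k = float(step) / float(total_steps)
    return 2. / (1. + np.exp(-10. * k)) - 1.

def interpolate_batches(x1, y1, x2, y2, alpha=1.0, rng=None):
    # Sample a per-example lambda ~ Beta(alpha, alpha), mix images and
    # (pseudo-)labels, and build the soft domain labels [lmb, 1 - lmb]
    # that the domain classifier behind the gradient-reversal layer is
    # trained on.
    rng = rng or np.random.RandomState(0)
    lmb = rng.beta(alpha, alpha, size=x1.shape[0])
    lmb_x = lmb.reshape([-1, 1, 1, 1])  # broadcast over H, W, C
    lmb_y = lmb.reshape([-1, 1])        # broadcast over classes
    x = x1 * lmb_x + x2 * (1. - lmb_x)
    y = y1 * lmb_y + y2 * (1. - lmb_y)
    label_dm = np.stack([lmb, 1. - lmb], axis=1)
    return x, y, label_dm

# Toy usage: 4 CIFAR-sized images, 10 classes.
x1 = np.random.rand(4, 32, 32, 3)
y1 = np.eye(10)[np.random.randint(0, 10, size=4)]  # one-hot labels
x2 = np.random.rand(4, 32, 32, 3)
y2 = np.full((4, 10), 0.1)  # stand-in softmax pseudo-labels
x, y, label_dm = interpolate_batches(x1, y1, x2, y2)
print(x.shape, y.shape, label_dm.shape,
      grl_coefficient(100, 120 * 400))  # e.g. num_epochs * num_iter_per_epoch
```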