├── .gitignore
├── README.md
├── convert_data.py
├── datasets
│   ├── __init__.py
│   ├── aircraft.py
│   ├── bird.py
│   ├── car.py
│   ├── convert_aircraft.py
│   ├── convert_bird.py
│   ├── convert_car.py
│   ├── convert_dog.py
│   ├── dataset_factory.py
│   ├── dataset_utils.py
│   └── dog.py
├── deployment
│   ├── __init__.py
│   ├── model_deploy.py
│   └── model_deploy_test.py
├── eval_labels.pkl
├── eval_logits.pkl
├── eval_sample.py
├── eval_sample_aircraft.sh
├── eval_sample_bird.sh
├── eval_sample_dog.sh
├── nets
│   ├── __init__.py
│   ├── alexnet.py
│   ├── alexnet_test.py
│   ├── cifarnet.py
│   ├── cyclegan.py
│   ├── cyclegan_test.py
│   ├── dcgan.py
│   ├── dcgan_test.py
│   ├── inception.py
│   ├── inception_resnet_v2.py
│   ├── inception_resnet_v2_test.py
│   ├── inception_utils.py
│   ├── inception_v1.py
│   ├── inception_v1_test.py
│   ├── inception_v2.py
│   ├── inception_v2_test.py
│   ├── inception_v3.py
│   ├── inception_v3_bap.py
│   ├── inception_v3_test.py
│   ├── inception_v3_topk.py
│   ├── inception_v4.py
│   ├── inception_v4_test.py
│   ├── lenet.py
│   ├── mobilenet
│   │   ├── README.md
│   │   ├── conv_blocks.py
│   │   ├── madds_top1_accuracy.png
│   │   ├── mnet_v1_vs_v2_pixel1_latency.png
│   │   ├── mobilenet.py
│   │   ├── mobilenet_example.ipynb
│   │   ├── mobilenet_v2.py
│   │   └── mobilenet_v2_test.py
│   ├── mobilenet_v1.md
│   ├── mobilenet_v1.png
│   ├── mobilenet_v1.py
│   ├── mobilenet_v1_eval.py
│   ├── mobilenet_v1_test.py
│   ├── mobilenet_v1_train.py
│   ├── nasnet
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── nasnet.py
│   │   ├── nasnet_test.py
│   │   ├── nasnet_utils.py
│   │   ├── nasnet_utils_test.py
│   │   ├── pnasnet.py
│   │   └── pnasnet_test.py
│   ├── nets_factory.py
│   ├── nets_factory_test.py
│   ├── overfeat.py
│   ├── overfeat_test.py
│   ├── pix2pix.py
│   ├── pix2pix_test.py
│   ├── resnet_utils.py
│   ├── resnet_v1.py
│   ├── resnet_v1_test.py
│   ├── resnet_v2.py
│   ├── resnet_v2_test.py
│   ├── vgg.py
│   └── vgg_test.py
├── num_bboxes.pkl
├── preprocessing
│   ├── __init__.py
│   ├── cifarnet_preprocessing.py
│   ├── inception_preprocessing.py
│   ├── lenet_preprocessing.py
│   ├── preprocessing_factory.py
│   └── vgg_preprocessing.py
├── requirements.txt
├── setup.py
├── train_sample.py
├── train_sample_aircraft.sh
├── train_sample_bird.sh
├── train_sample_dog.sh
└── utils
    ├── lstm.py
    └── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Datasets
Bird/
Dog/
Aircraft/
# Dictionary
data/
# Pretrained models
pre_trained/
# Virtual Environment
venv/
.idea/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Fine-Grained Visual Classification using Self Assessment Classifier

## Prerequisites

- Python 3.7
- CUDA 11.0

```
pip install -r requirements.txt
```

## Datasets

- Download the CUB-200-2011 dataset (tfrecords) at [link](https://vision.aioz.io/f/da1b0ade1c4e4acfbd08/?dl=1) and extract them into the `Bird/Data` folder.
- Download the FGVC AIRCRAFT dataset (tfrecords) at [link](https://vision.aioz.io/f/cd9a1940688c4795bbdc/?dl=1) and extract them into the `Aircraft/Data` folder.
- Download the STANFORD DOGS dataset at [link](http://vision.stanford.edu/aditya86/ImageNetDogs/), then convert it into *tfrecord* format and put the result into the `Dog/Data` folder (see the conversion example below).

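`convert_data.py` in this repository wraps the per-dataset converters. For example, assuming the extracted Stanford Dogs files (`Images/`, `train_list.mat`, `test_list.mat`) sit in `Dog/Data`, the tfrecords can be built with:

```
python convert_data.py --dataset_name=Dog --dataset_dir=./Dog/Data
```
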
#### Dictionary

- Download the data dictionary at [link](https://vision.aioz.io/f/5e39ee074cdd446ca9b2/?dl=1) and extract it into the `data` folder.

## Training

Please download the pretrained WS_DAN backbone at [link](https://vision.aioz.io/f/be3af5363b9a425cbc7f/?dl=1) and extract it into the `pre_trained` folder.

- To train our method on the CUB-200-2011 dataset, please run:
```
bash train_sample_bird.sh
```
- To train our method on the FGVC AIRCRAFT dataset, please run:
```
bash train_sample_aircraft.sh
```
- To train our method on the STANFORD DOGS dataset, please run:
```
bash train_sample_dog.sh
```

## Testing

#### Evaluate

- To evaluate our method on the CUB-200-2011 dataset, please run:
```
bash eval_sample_bird.sh
```
- To evaluate our method on the FGVC AIRCRAFT dataset, please run:
```
bash eval_sample_aircraft.sh
```
- To evaluate our method on the STANFORD DOGS dataset, please run:
```
bash eval_sample_dog.sh
```

#### Pretrained model

We provide the pretrained model of SAC integrated into WS_DAN on the CUB-200-2011 dataset.
- Download our pretrained weights at [link](https://vision.aioz.io/f/97aa30aca9a74ca58bd8/?dl=1) and extract them into the `Bird/SAC/TRAIN/Bird` folder.

## Citation

If you use this code as part of any published research, we'd really appreciate it if you could cite the following paper:

```
Updating
```

## License

MIT License

--------------------------------------------------------------------------------
/convert_data.py:
--------------------------------------------------------------------------------
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Converts a particular dataset to TFRecords.

Usage:
```shell

$ python convert_data.py \
    --dataset_name=Bird \
    --dataset_dir=./Bird/Data

$ python convert_data.py \
    --dataset_name=Aircraft \
    --dataset_dir=./Aircraft/Data

$ python convert_data.py \
    --dataset_name=Dog \
    --dataset_dir=./Dog/Data
```
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from datasets import convert_bird
from datasets import convert_aircraft
from datasets import convert_dog
from datasets import convert_car

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string(
    'dataset_name',
    'Bird',
    'The name of the dataset to convert, one of "Bird", "Aircraft", "Car", "Dog".')

tf.app.flags.DEFINE_string(
    'dataset_dir',
    './Bird/Data',
    'The directory where the output TFRecords and temporary files are saved.')

def main(_):
  if not FLAGS.dataset_name:
    raise ValueError('You must supply the dataset name with --dataset_name')
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  if FLAGS.dataset_name == 'Bird':
    convert_bird.run(FLAGS.dataset_dir)
  elif FLAGS.dataset_name == 'Aircraft':
    convert_aircraft.run(FLAGS.dataset_dir)
  elif FLAGS.dataset_name == 'Car':
    convert_car.run(FLAGS.dataset_dir)
  elif FLAGS.dataset_name == 'Dog':
    convert_dog.run(FLAGS.dataset_dir)
  else:
    raise ValueError(
        'dataset_name [%s] was not recognized.' % FLAGS.dataset_name)


if __name__ == '__main__':
  tf.app.run()

--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/datasets/__init__.py
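
All of the converters below write sharded TFRecord files via `dataset_utils.image_to_tfexample`. A converted split can be sanity-checked by iterating one shard with the TF 1.x record iterator; this snippet is illustrative only, and the shard path is an example (names follow `<Name>_<split>_%05d-of-%05d.tfrecord`):

```python
import tensorflow as tf

# Example shard written by the converters; adjust the path as needed.
path = './Bird/Data/tfrecords/Bird_train_00000-of-00005.tfrecord'
for i, record in enumerate(tf.python_io.tf_record_iterator(path)):
  example = tf.train.Example()
  example.ParseFromString(record)
  label = example.features.feature['image/class/label'].int64_list.value[0]
  print('record %d: label %d' % (i, label))
  if i >= 2:  # peek at the first three records only
    break
```
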
--------------------------------------------------------------------------------
/datasets/aircraft.py:
--------------------------------------------------------------------------------
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides data for the FGVC Aircraft dataset.

The dataset scripts used to create the dataset can be found at:
tensorflow/models/research/slim/datasets/download_and_convert_fgvc.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tensorflow as tf

from datasets import dataset_utils

slim = tf.contrib.slim

_FILE_PATTERN = 'Aircraft_%s_*.tfrecord'

SPLITS_TO_SIZES = {'train': 6667, 'test': 3333}

_NUM_CLASSES = 100

_ITEMS_TO_DESCRIPTIONS = {
    'image': 'A color image of varying size.',
    'label': 'A single integer between 0 and 99',
}


def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
  """Gets a dataset tuple with instructions for reading fgvc.

  Args:
    split_name: A train/test split name.
    dataset_dir: The base directory of the dataset sources.
    file_pattern: The file pattern to use when matching the dataset sources.
      It is assumed that the pattern contains a '%s' string so that the split
      name can be inserted.
    reader: The TensorFlow reader type.

  Returns:
    A `Dataset` namedtuple.

  Raises:
    ValueError: if `split_name` is not a valid train/test split.
60 | """ 61 | if split_name not in SPLITS_TO_SIZES: 62 | raise ValueError('split name %s was not recognized.' % split_name) 63 | 64 | if not file_pattern: 65 | file_pattern = _FILE_PATTERN 66 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name) 67 | 68 | # Allowing None in the signature so that dataset_factory can use the default. 69 | if reader is None: 70 | reader = tf.TFRecordReader 71 | 72 | keys_to_features = { 73 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 74 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'), 75 | 'image/class/label': tf.FixedLenFeature( 76 | [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)), 77 | } 78 | 79 | items_to_handlers = { 80 | 'image': slim.tfexample_decoder.Image(), 81 | 'label': slim.tfexample_decoder.Tensor('image/class/label'), 82 | } 83 | 84 | decoder = slim.tfexample_decoder.TFExampleDecoder( 85 | keys_to_features, items_to_handlers) 86 | 87 | labels_to_names = None 88 | if dataset_utils.has_labels(dataset_dir): 89 | labels_to_names = dataset_utils.read_label_file(dataset_dir) 90 | 91 | return slim.dataset.Dataset( 92 | data_sources=file_pattern, 93 | reader=reader, 94 | decoder=decoder, 95 | num_samples=SPLITS_TO_SIZES[split_name], 96 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, 97 | num_classes=_NUM_CLASSES, 98 | labels_to_names=labels_to_names) 99 | -------------------------------------------------------------------------------- /datasets/bird.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides data for the cub dataset. 16 | 17 | The dataset scripts used to create the dataset can be found at: 18 | tensorflow/models/research/slim/datasets/download_and_convert_cub.py 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import os 26 | import tensorflow as tf 27 | 28 | from datasets import dataset_utils 29 | 30 | slim = tf.contrib.slim 31 | 32 | _FILE_PATTERN = 'Bird_%s_*.tfrecord' 33 | 34 | SPLITS_TO_SIZES = {'train': 5994, 'test': 5794} 35 | 36 | _NUM_CLASSES = 200 37 | 38 | _ITEMS_TO_DESCRIPTIONS = { 39 | 'image': 'A color image of varying size.', 40 | 'label': 'A single integer between 0 and 4', 41 | } 42 | 43 | 44 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None): 45 | """Gets a dataset tuple with instructions for reading cub. 46 | 47 | Args: 48 | split_name: A train/validation split name. 49 | dataset_dir: The base directory of the dataset sources. 50 | file_pattern: The file pattern to use when matching the dataset sources. 51 | It is assumed that the pattern contains a '%s' string so that the split 52 | name can be inserted. 53 | reader: The TensorFlow reader type. 

  Returns:
    A `Dataset` namedtuple.

  Raises:
    ValueError: if `split_name` is not a valid train/test split.
  """
  if split_name not in SPLITS_TO_SIZES:
    raise ValueError('split name %s was not recognized.' % split_name)

  if not file_pattern:
    file_pattern = _FILE_PATTERN
  file_pattern = os.path.join(dataset_dir, file_pattern % split_name)

  # Allowing None in the signature so that dataset_factory can use the default.
  if reader is None:
    reader = tf.TFRecordReader

  keys_to_features = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
      'image/class/label': tf.FixedLenFeature(
          [], tf.int64, default_value=tf.zeros([], dtype=tf.int64))
  }

  items_to_handlers = {
      'image': slim.tfexample_decoder.Image(),
      'label': slim.tfexample_decoder.Tensor('image/class/label')
  }

  decoder = slim.tfexample_decoder.TFExampleDecoder(
      keys_to_features, items_to_handlers)

  labels_to_names = None
  if dataset_utils.has_labels(dataset_dir):
    labels_to_names = dataset_utils.read_label_file(dataset_dir)

  return slim.dataset.Dataset(
      data_sources=file_pattern,
      reader=reader,
      decoder=decoder,
      num_samples=SPLITS_TO_SIZES[split_name],
      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
      num_classes=_NUM_CLASSES,
      labels_to_names=labels_to_names)
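
These dataset modules are consumed through TF-slim's data provider. A minimal sketch of reading the Bird training split (the tfrecords directory is an assumption):

```python
import tensorflow as tf
from datasets import bird

slim = tf.contrib.slim

# Directory holding the Bird_train_*.tfrecord shards (assumed location).
dataset = bird.get_split('train', './Bird/Data/tfrecords')
provider = slim.dataset_data_provider.DatasetDataProvider(
    dataset, num_readers=4, shuffle=True)
# 'image' and 'label' are the items registered in items_to_handlers.
image, label = provider.get(['image', 'label'])
```
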
--------------------------------------------------------------------------------
/datasets/car.py:
--------------------------------------------------------------------------------
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides data for the Stanford Cars dataset.

The dataset scripts used to create the dataset can be found at:
tensorflow/models/research/slim/datasets/download_and_convert_fgvc.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tensorflow as tf

from datasets import dataset_utils

slim = tf.contrib.slim

_FILE_PATTERN = 'Car_%s_*.tfrecord'

SPLITS_TO_SIZES = {'train': 8144, 'test': 8041}

_NUM_CLASSES = 196

_ITEMS_TO_DESCRIPTIONS = {
    'image': 'A color image of varying size.',
    'label': 'A single integer between 0 and 195',
}


def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
  """Gets a dataset tuple with instructions for reading cars.

  Args:
    split_name: A train/test split name.
    dataset_dir: The base directory of the dataset sources.
    file_pattern: The file pattern to use when matching the dataset sources.
      It is assumed that the pattern contains a '%s' string so that the split
      name can be inserted.
    reader: The TensorFlow reader type.

  Returns:
    A `Dataset` namedtuple.

  Raises:
    ValueError: if `split_name` is not a valid train/test split.
  """
  if split_name not in SPLITS_TO_SIZES:
    raise ValueError('split name %s was not recognized.' % split_name)

  if not file_pattern:
    file_pattern = _FILE_PATTERN
  file_pattern = os.path.join(dataset_dir, file_pattern % split_name)

  # Allowing None in the signature so that dataset_factory can use the default.
  if reader is None:
    reader = tf.TFRecordReader

  keys_to_features = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
      'image/class/label': tf.FixedLenFeature(
          [], tf.int64, default_value=tf.zeros([], dtype=tf.int64))
  }

  items_to_handlers = {
      'image': slim.tfexample_decoder.Image(),
      'label': slim.tfexample_decoder.Tensor('image/class/label'),
  }

  decoder = slim.tfexample_decoder.TFExampleDecoder(
      keys_to_features, items_to_handlers)

  labels_to_names = None
  if dataset_utils.has_labels(dataset_dir):
    labels_to_names = dataset_utils.read_label_file(dataset_dir)

  return slim.dataset.Dataset(
      data_sources=file_pattern,
      reader=reader,
      decoder=decoder,
      num_samples=SPLITS_TO_SIZES[split_name],
      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
      num_classes=_NUM_CLASSES,
      labels_to_names=labels_to_names)

--------------------------------------------------------------------------------
/datasets/convert_aircraft.py:
--------------------------------------------------------------------------------
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Converts FGVC Aircraft data to TFRecords of TF-Example protos.

This module reads the files that make up the FGVC Aircraft data and creates
two TFRecord datasets: one for train and one for test. Each TFRecord dataset
is comprised of a set of TF-Example protocol buffers, each of which contains
a single image and label.

The script should take about a minute to run.

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os
import random
import sys
import numpy as np

import tensorflow as tf

from datasets import dataset_utils

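# Note on shard sizes: each split below is written as _NUM_SHARDS TFRecord
# files, so the 6,667 trainval images (see SPLITS_TO_SIZES in
# datasets/aircraft.py) yield ceil(6667 / 5) = 1334 examples per shard, with
# the remainder (1331) in the final shard.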

# Seed for repeatability.
_RANDOM_SEED = 0

# The number of shards per dataset split.
_NUM_SHARDS = 5


class ImageReader(object):
  """Helper class that provides TensorFlow image coding utilities."""

  def __init__(self):
    # Initializes function that decodes RGB JPEG data.
    self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
    self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

  def read_image_dims(self, sess, image_data):
    image = self.decode_jpeg(sess, image_data)
    return image.shape[0], image.shape[1]

  def decode_jpeg(self, sess, image_data):
    image = sess.run(self._decode_jpeg,
                     feed_dict={self._decode_jpeg_data: image_data})
    assert len(image.shape) == 3
    assert image.shape[2] == 3
    return image


def _get_filenames_and_classes(dataset_dir):
  """Returns a list of filenames and inferred class names.

  Args:
    dataset_dir: A directory containing a set of subdirectories representing
      class names. Each subdirectory should contain PNG or JPG encoded images.

  Returns:
    A list of image file paths, relative to `dataset_dir` and the list of
    subdirectories, representing class names.
  """
  images_root = os.path.join(dataset_dir, 'images')
  directories = []
  class_names = []
  for filename in os.listdir(images_root):
    path = os.path.join(images_root, filename)
    if os.path.isdir(path):
      directories.append(path)
      class_names.append(filename)

  photo_filenames = []
  for directory in directories:
    for filename in os.listdir(directory):
      path = os.path.join(directory, filename)
      photo_filenames.append(path)

  return photo_filenames, sorted(class_names)


def _get_dataset_filename(dataset_dir, split_name, shard_id):
  output_filename = 'Aircraft_%s_%05d-of-%05d.tfrecord' % (
      split_name, shard_id, _NUM_SHARDS)
  if not os.path.exists(os.path.join(dataset_dir, 'tfrecords')):
    os.makedirs(os.path.join(dataset_dir, 'tfrecords'))
  return os.path.join(dataset_dir, 'tfrecords', output_filename)

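
# The shards produced below land at
# <dataset_dir>/tfrecords/Aircraft_<split>_0000K-of-00005.tfrecord,
# following the naming scheme of _get_dataset_filename() above.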
114 | """ 115 | assert split_name in ['train', 'test'] 116 | 117 | num_per_shard = int(math.ceil(len(datasets) / float(_NUM_SHARDS))) 118 | 119 | with tf.Graph().as_default(): 120 | image_reader = ImageReader() 121 | 122 | config = tf.ConfigProto( 123 | allow_soft_placement=True, 124 | log_device_placement=False) 125 | config.gpu_options.allow_growth = True 126 | # config.gpu_options.visible_device_list='1' 127 | with tf.Session(config=config) as sess: 128 | for shard_id in range(_NUM_SHARDS): 129 | output_filename = _get_dataset_filename( 130 | dataset_dir, split_name, shard_id) 131 | 132 | with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer: 133 | start_ndx = shard_id * num_per_shard 134 | end_ndx = min((shard_id+1) * num_per_shard, len(datasets)) 135 | for i in range(start_ndx, end_ndx): 136 | sys.stdout.write('\r>> Converting image %d/%d shard %d' % ( 137 | i+1, len(datasets), shard_id)) 138 | sys.stdout.flush() 139 | 140 | # Read the filename: 141 | image_data = tf.gfile.FastGFile(datasets[i]['filename'], 'rb').read() 142 | height, width = image_reader.read_image_dims(sess, image_data) 143 | 144 | class_id = datasets[i]['label'] 145 | 146 | example = dataset_utils.image_to_tfexample( 147 | image_data, b'jpg', height, width, class_id) 148 | tfrecord_writer.write(example.SerializeToString()) 149 | 150 | sys.stdout.write('\n') 151 | sys.stdout.flush() 152 | 153 | 154 | def generate_datasets(data_root): 155 | train_info = np.loadtxt(os.path.join(data_root, 'fgvc-aircraft-2013b/data', 'images_variant_trainval.txt'), str) 156 | test_info = np.loadtxt(os.path.join(data_root, 'fgvc-aircraft-2013b/data', 'images_variant_test.txt'), str) 157 | category_info = np.loadtxt(os.path.join(data_root, 'fgvc-aircraft-2013b/data', 'variants.txt'), str) 158 | 159 | train_dataset = [] 160 | test_dataset = [] 161 | for index in range(len(train_info)): 162 | images_file = os.path.join(data_root, 'fgvc-aircraft-2013b/data/images', train_info[index, 0] + '.jpg') 163 | category = train_info[index, 1] 164 | label = np.where(category_info == category)[0][0] 165 | 166 | example = {} 167 | example['filename'] = images_file 168 | example['label'] = int(label) 169 | train_dataset.append(example) 170 | 171 | for index in range(len(test_info)): 172 | images_file = os.path.join(data_root, 'fgvc-aircraft-2013b/data/images', test_info[index, 0] + '.jpg') 173 | category = test_info[index, 1] 174 | label = np.where(category_info == category)[0][0] 175 | 176 | example = {} 177 | example['filename'] = images_file 178 | example['label'] = int(label) 179 | test_dataset.append(example) 180 | 181 | return train_dataset, test_dataset 182 | 183 | 184 | def run(dataset_dir): 185 | """Runs the download and conversion operation. 186 | 187 | Args: 188 | dataset_dir: The dataset directory where the dataset is stored. 189 | """ 190 | if not tf.gfile.Exists(dataset_dir): 191 | tf.gfile.MakeDirs(dataset_dir) 192 | 193 | # Divide into train and test: 194 | random.seed(_RANDOM_SEED) 195 | 196 | train_dataset, test_dataset = generate_datasets(dataset_dir) 197 | 198 | random.shuffle(train_dataset) 199 | random.shuffle(test_dataset) 200 | 201 | # First, convert the training and test sets. 
  _convert_dataset('train', train_dataset, dataset_dir)
  _convert_dataset('test', test_dataset, dataset_dir)

  # _clean_up_temporary_files(dataset_dir)
  print('\nFinished converting the fgvc dataset!')

--------------------------------------------------------------------------------
/datasets/convert_bird.py:
--------------------------------------------------------------------------------
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Converts Bird (CUB-200-2011) data to TFRecords of TF-Example protos.

This module reads the files that make up the Bird data and creates two
TFRecord datasets: one for train and one for test. Each TFRecord dataset is
comprised of a set of TF-Example protocol buffers, each of which contains a
single image and label.

The script should take about a minute to run.

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os
import random
import sys
import numpy as np

import tensorflow as tf

from datasets import dataset_utils

# Seed for repeatability.
_RANDOM_SEED = 0
# The number of shards per dataset split.
_NUM_SHARDS = 5


class ImageReader(object):
  """Helper class that provides TensorFlow image coding utilities."""

  def __init__(self):
    # Initializes function that decodes RGB JPEG data.
    self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
    self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

  def read_image_dims(self, sess, image_data):
    image = self.decode_jpeg(sess, image_data)
    return image.shape[0], image.shape[1]

  def decode_jpeg(self, sess, image_data):
    image = sess.run(self._decode_jpeg,
                     feed_dict={self._decode_jpeg_data: image_data})
    assert len(image.shape) == 3
    assert image.shape[2] == 3
    return image


def _get_filenames_and_classes(dataset_dir):
  """Returns a list of filenames and inferred class names.

  Args:
    dataset_dir: A directory containing a set of subdirectories representing
      class names. Each subdirectory should contain PNG or JPG encoded images.

  Returns:
    A list of image file paths, relative to `dataset_dir` and the list of
    subdirectories, representing class names.
76 | """ 77 | Bird_root = os.path.join(dataset_dir, 'images') 78 | directories = [] 79 | class_names = [] 80 | for filename in os.listdir(Bird_root): 81 | path = os.path.join(Bird_root, filename) 82 | if os.path.isdir(path): 83 | directories.append(path) 84 | class_names.append(filename) 85 | 86 | photo_filenames = [] 87 | for directory in directories: 88 | for filename in os.listdir(directory): 89 | path = os.path.join(directory, filename) 90 | photo_filenames.append(path) 91 | 92 | return photo_filenames, sorted(class_names) 93 | 94 | 95 | def _get_dataset_filename(dataset_dir, split_name, shard_id): 96 | output_filename = 'Bird_%s_%05d-of-%05d.tfrecord' % ( 97 | split_name, shard_id, _NUM_SHARDS) 98 | if not os.path.exists(os.path.join(dataset_dir, 'tfrecords')): 99 | os.makedirs(os.path.join(dataset_dir, 'tfrecords')) 100 | return os.path.join(dataset_dir, 'tfrecords', output_filename) 101 | 102 | 103 | def _convert_dataset(split_name, dataset, dataset_dir): 104 | """Converts the given filenames to a TFRecord dataset. 105 | 106 | Args: 107 | split_name: The name of the dataset, either 'train' or 'testing'. 108 | filenames: A list of absolute paths to png or jpg images. 109 | class_names_to_ids: A dictionary from class names (strings) to ids 110 | (integers). 111 | dataset_dir: The directory where the converted datasets are stored. 112 | """ 113 | assert split_name in ['train', 'test'] 114 | 115 | num_per_shard = int(math.ceil(len(dataset) / float(_NUM_SHARDS))) 116 | 117 | with tf.Graph().as_default(): 118 | image_reader = ImageReader() 119 | 120 | config = tf.ConfigProto( 121 | allow_soft_placement=True, 122 | log_device_placement=False) 123 | config.gpu_options.allow_growth = True 124 | config.gpu_options.per_process_gpu_memory_fraction = 1.0 125 | 126 | with tf.Session(config=config) as sess: 127 | 128 | for shard_id in range(_NUM_SHARDS): 129 | output_filename = _get_dataset_filename( 130 | dataset_dir, split_name, shard_id) 131 | 132 | with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer: 133 | start_ndx = shard_id * num_per_shard 134 | end_ndx = min((shard_id+1) * num_per_shard, len(dataset)) 135 | for i in range(start_ndx, end_ndx): 136 | sys.stdout.write('\r>> Converting %s image %d/%d shard %d' % (split_name, 137 | i+1, len(dataset), shard_id)) 138 | sys.stdout.flush() 139 | 140 | # Read the filename: 141 | image_data = tf.gfile.FastGFile(dataset[i]['filename'], 'rb').read() 142 | height, width = image_reader.read_image_dims(sess, image_data) 143 | 144 | label = dataset[i]['label'] 145 | 146 | example = dataset_utils.image_to_tfexample( 147 | image_data, b'jpg', height, width, label) 148 | tfrecord_writer.write(example.SerializeToString()) 149 | 150 | sys.stdout.write('\n') 151 | sys.stdout.flush() 152 | 153 | 154 | def generate_datasets(data_root): 155 | train_test = np.loadtxt(os.path.join(data_root, 'train_test_split.txt'), int) 156 | images_files = np.loadtxt(os.path.join(data_root, 'images.txt'), str) 157 | labels = np.loadtxt(os.path.join(data_root, 'image_class_labels.txt'), int) - 1 158 | # parts = np.loadtxt(os.path.join(data_root, 'parts', 'part_locs.txt'), float) 159 | # parts = np.reshape(parts, [-1, 15, parts.shape[-1]]) 160 | # 161 | # bboxes = np.loadtxt(os.path.join(data_root, 'bounding_boxes.txt'), float) 162 | 163 | train_dataset = [] 164 | test_dataset = [] 165 | 166 | # train_index = 0 167 | # eval_index = 0 168 | for index in range(len(images_files)): 169 | images_file = images_files[index, 1] 170 | is_training = train_test[index, 1] 171 | 
    label = labels[index, 1]
    # part = np.reshape(parts[index, :, 2:4], [-1]).tolist()
    # exist = parts[index, :, 4].astype(np.int64).tolist()
    # bbox = bboxes[index, 1:].tolist()

    example = {}
    example['filename'] = os.path.join(data_root, 'images', images_file)
    example['label'] = label
    # example['part'] = part
    # example['exist'] = exist
    # example['bbox'] = bbox

    if is_training:
      train_dataset.append(example)
    else:
      test_dataset.append(example)

  return train_dataset, test_dataset


def run(dataset_dir):
  """Runs the conversion operation.

  Args:
    dataset_dir: The dataset directory where the dataset is stored.
  """
  if not tf.gfile.Exists(dataset_dir):
    tf.gfile.MakeDirs(dataset_dir)

  # Divide into train and test:
  random.seed(_RANDOM_SEED)

  train_dataset, test_dataset = generate_datasets(dataset_dir)

  random.shuffle(train_dataset)
  random.shuffle(test_dataset)

  # First, convert the training and testing sets.
  _convert_dataset('train', train_dataset, dataset_dir)
  _convert_dataset('test', test_dataset, dataset_dir)

  # Finally, write the labels file:
  # labels_to_class_names = dict(zip(range(len(class_names)), class_names))
  # dataset_utils.write_label_file(labels_to_class_names, dataset_dir)

  # _clean_up_temporary_files(dataset_dir)
  print('\nFinished converting the Bird dataset!')
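
`convert_car.py` below reads the Stanford Cars labels from the devkit's MATLAB files. A quick, illustrative check that the annotations are in place before converting (the `./Car/Data` root is an assumption; the field indexing mirrors `generate_datasets` below):

```python
import os
import scipy.io as sio

data_root = './Car/Data'  # assumed dataset root containing devkit/
annos = sio.loadmat(os.path.join(data_root, 'devkit', 'cars_train_annos.mat'))
annotations = annos['annotations'][0]
print('%d training annotations' % len(annotations))
# 'fname' and 'class' are the two fields the converter consumes.
print(str(annotations['fname'][0][0]), int(annotations['class'][0][0][0]))
```
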
--------------------------------------------------------------------------------
/datasets/convert_car.py:
--------------------------------------------------------------------------------
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Converts Stanford Cars data to TFRecords of TF-Example protos.

This module reads the files that make up the Stanford Cars data and creates
two TFRecord datasets: one for train and one for test. Each TFRecord dataset
is comprised of a set of TF-Example protocol buffers, each of which contains
a single image and label.

The script should take about a minute to run.

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os
import random
import sys
import tensorflow as tf
import scipy.io as sio

from datasets import dataset_utils


# Seed for repeatability.
_RANDOM_SEED = 0

# The number of shards per dataset split.
_NUM_SHARDS = 5


class ImageReader(object):
  """Helper class that provides TensorFlow image coding utilities."""

  def __init__(self):
    # Initializes function that decodes RGB JPEG data.
    self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
    self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

  def read_image_dims(self, sess, image_data):
    image = self.decode_jpeg(sess, image_data)
    return image.shape[0], image.shape[1]

  def decode_jpeg(self, sess, image_data):
    image = sess.run(self._decode_jpeg,
                     feed_dict={self._decode_jpeg_data: image_data})
    assert len(image.shape) == 3
    assert image.shape[2] == 3
    return image


def _get_filenames_and_classes(dataset_dir):
  """Returns a list of filenames and inferred class names.

  Args:
    dataset_dir: A directory containing a set of subdirectories representing
      class names. Each subdirectory should contain PNG or JPG encoded images.

  Returns:
    A list of image file paths, relative to `dataset_dir` and the list of
    subdirectories, representing class names.
  """
  images_root = os.path.join(dataset_dir, 'images')
  directories = []
  class_names = []
  for filename in os.listdir(images_root):
    path = os.path.join(images_root, filename)
    if os.path.isdir(path):
      directories.append(path)
      class_names.append(filename)

  photo_filenames = []
  for directory in directories:
    for filename in os.listdir(directory):
      path = os.path.join(directory, filename)
      photo_filenames.append(path)

  return photo_filenames, sorted(class_names)


def _get_dataset_filename(dataset_dir, split_name, shard_id):
  output_filename = 'Car_%s_%05d-of-%05d.tfrecord' % (
      split_name, shard_id, _NUM_SHARDS)
  if not os.path.exists(os.path.join(dataset_dir, 'tfrecords')):
    os.makedirs(os.path.join(dataset_dir, 'tfrecords'))
  return os.path.join(dataset_dir, 'tfrecords', output_filename)


def _convert_dataset(split_name, datasets, dataset_dir):
  """Converts the given examples to a TFRecord dataset.

  Args:
    split_name: The name of the split, either 'train' or 'test'.
    datasets: A list of example dicts, each with a 'filename' key (absolute
      path to a jpg image) and an integer 'label' key.
    dataset_dir: The directory where the converted datasets are stored.
113 | """ 114 | assert split_name in ['train', 'test'] 115 | 116 | num_per_shard = int(math.ceil(len(datasets) / float(_NUM_SHARDS))) 117 | 118 | with tf.Graph().as_default(): 119 | image_reader = ImageReader() 120 | 121 | config = tf.ConfigProto( 122 | allow_soft_placement=True, 123 | log_device_placement=False) 124 | config.gpu_options.allow_growth = True 125 | # config.gpu_options.visible_device_list='3' 126 | with tf.Session(config=config) as sess: 127 | 128 | for shard_id in range(_NUM_SHARDS): 129 | output_filename = _get_dataset_filename( 130 | dataset_dir, split_name, shard_id) 131 | 132 | with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer: 133 | start_ndx = shard_id * num_per_shard 134 | end_ndx = min((shard_id+1) * num_per_shard, len(datasets)) 135 | for i in range(start_ndx, end_ndx): 136 | sys.stdout.write('\r>> Converting image %d/%d shard %d' % ( 137 | i+1, len(datasets), shard_id)) 138 | sys.stdout.flush() 139 | 140 | # Read the filename: 141 | image_data = tf.gfile.FastGFile(datasets[i]['filename'], 'rb').read() 142 | height, width = image_reader.read_image_dims(sess, image_data) 143 | 144 | class_id = datasets[i]['label'] 145 | 146 | example = dataset_utils.image_to_tfexample( 147 | image_data, b'jpg', height, width, class_id) 148 | tfrecord_writer.write(example.SerializeToString()) 149 | 150 | sys.stdout.write('\n') 151 | sys.stdout.flush() 152 | 153 | 154 | def generate_datasets(data_root): 155 | train_info = sio.loadmat(os.path.join(data_root, 'devkit', 'cars_train_annos.mat'))['annotations'][0] 156 | test_info = sio.loadmat(os.path.join(data_root, 'devkit', 'cars_test_annos_withlabels.mat'))['annotations'][0] 157 | 158 | train_dataset = [] 159 | test_dataset = [] 160 | 161 | for index in range(len(train_info)): 162 | images_file = str(train_info['fname'][index][0]) 163 | label = train_info['class'][index][0][0] - 1 164 | 165 | example = {} 166 | example['filename'] = os.path.join(data_root, 'cars_train', images_file) 167 | example['label'] = int(label) 168 | train_dataset.append(example) 169 | 170 | for index in range(len(test_info)): 171 | images_file = str(test_info['fname'][index][0]) 172 | label = test_info['class'][index][0][0] - 1 173 | 174 | example = {} 175 | example['filename'] = os.path.join(data_root, 'cars_test', images_file) 176 | example['label'] = int(label) 177 | test_dataset.append(example) 178 | return train_dataset, test_dataset 179 | 180 | 181 | def run(dataset_dir): 182 | """Runs the download and conversion operation. 183 | 184 | Args: 185 | dataset_dir: The dataset directory where the dataset is stored. 186 | """ 187 | if not tf.gfile.Exists(dataset_dir): 188 | tf.gfile.MakeDirs(dataset_dir) 189 | 190 | # Divide into train and test: 191 | random.seed(_RANDOM_SEED) 192 | 193 | train_dataset, test_dataset = generate_datasets(dataset_dir) 194 | 195 | random.shuffle(train_dataset) 196 | random.shuffle(test_dataset) 197 | 198 | # First, convert the training and test sets. 199 | _convert_dataset('train', train_dataset, dataset_dir) 200 | _convert_dataset('test', test_dataset, dataset_dir) 201 | # _clean_up_temporary_files(dataset_dir) 202 | print('\nFinished converting the dog dataset!') 203 | -------------------------------------------------------------------------------- /datasets/convert_dog.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
--------------------------------------------------------------------------------
/datasets/convert_dog.py:
--------------------------------------------------------------------------------
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Converts Stanford Dogs data to TFRecords of TF-Example protos.

This module reads the files that make up the Stanford Dogs data and creates
two TFRecord datasets: one for train and one for test. Each TFRecord dataset
is comprised of a set of TF-Example protocol buffers, each of which contains
a single image and label.

The script should take about a minute to run.

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os
import random
import sys
import numpy as np

import tensorflow as tf
import scipy.io as sio

from datasets import dataset_utils


# Seed for repeatability.
_RANDOM_SEED = 0

# The number of shards per dataset split.
_NUM_SHARDS = 5


class ImageReader(object):
  """Helper class that provides TensorFlow image coding utilities."""

  def __init__(self):
    # Initializes function that decodes RGB JPEG data.
    self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
    self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

  def read_image_dims(self, sess, image_data):
    image = self.decode_jpeg(sess, image_data)
    return image.shape[0], image.shape[1]

  def decode_jpeg(self, sess, image_data):
    image = sess.run(self._decode_jpeg,
                     feed_dict={self._decode_jpeg_data: image_data})
    assert len(image.shape) == 3
    assert image.shape[2] == 3
    return image


def _get_filenames_and_classes(dataset_dir):
  """Returns a list of filenames and inferred class names.

  Args:
    dataset_dir: A directory containing a set of subdirectories representing
      class names. Each subdirectory should contain PNG or JPG encoded images.

  Returns:
    A list of image file paths, relative to `dataset_dir` and the list of
    subdirectories, representing class names.
79 | """ 80 | images_root = os.path.join(dataset_dir, 'images') 81 | directories = [] 82 | class_names = [] 83 | for filename in os.listdir(images_root): 84 | path = os.path.join(images_root, filename) 85 | if os.path.isdir(path): 86 | directories.append(path) 87 | class_names.append(filename) 88 | 89 | photo_filenames = [] 90 | for directory in directories: 91 | for filename in os.listdir(directory): 92 | path = os.path.join(directory, filename) 93 | photo_filenames.append(path) 94 | 95 | return photo_filenames, sorted(class_names) 96 | 97 | 98 | def _get_dataset_filename(dataset_dir, split_name, shard_id): 99 | output_filename = 'Dog_%s_%05d-of-%05d.tfrecord' % ( 100 | split_name, shard_id, _NUM_SHARDS) 101 | if not os.path.exists(os.path.join(dataset_dir, 'tfrecords')): 102 | os.makedirs(os.path.join(dataset_dir, 'tfrecords')) 103 | return os.path.join(dataset_dir, 'tfrecords', output_filename) 104 | 105 | 106 | def _convert_dataset(split_name, datasets, dataset_dir): 107 | """Converts the given filenames to a TFRecord dataset. 108 | 109 | Args: 110 | split_name: The name of the dataset, either 'train' or 'validation'. 111 | filenames: A list of absolute paths to png or jpg images. 112 | class_names_to_ids: A dictionary from class names (strings) to ids 113 | (integers). 114 | dataset_dir: The directory where the converted datasets are stored. 115 | """ 116 | assert split_name in ['train', 'test'] 117 | 118 | num_per_shard = int(math.ceil(len(datasets) / float(_NUM_SHARDS))) 119 | 120 | with tf.Graph().as_default(): 121 | image_reader = ImageReader() 122 | 123 | config = tf.ConfigProto( 124 | allow_soft_placement=True, 125 | log_device_placement=False) 126 | config.gpu_options.allow_growth = True 127 | config.gpu_options.per_process_gpu_memory_fraction = 1.0 128 | config.gpu_options.visible_device_list='1' 129 | with tf.Session(config=config) as sess: 130 | 131 | for shard_id in range(_NUM_SHARDS): 132 | output_filename = _get_dataset_filename( 133 | dataset_dir, split_name, shard_id) 134 | 135 | with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer: 136 | start_ndx = shard_id * num_per_shard 137 | end_ndx = min((shard_id+1) * num_per_shard, len(datasets)) 138 | for i in range(start_ndx, end_ndx): 139 | sys.stdout.write('\r>> Converting image %d/%d shard %d' % ( 140 | i+1, len(datasets), shard_id)) 141 | sys.stdout.flush() 142 | 143 | # Read the filename: 144 | image_data = tf.gfile.FastGFile(datasets[i]['filename'], 'rb').read() 145 | height, width = image_reader.read_image_dims(sess, image_data) 146 | class_id = datasets[i]['label'] 147 | example = dataset_utils.image_to_tfexample( 148 | image_data, b'jpg', height, width, class_id) 149 | tfrecord_writer.write(example.SerializeToString()) 150 | 151 | sys.stdout.write('\n') 152 | sys.stdout.flush() 153 | 154 | 155 | def generate_datasets(data_root): 156 | train_info = sio.loadmat(os.path.join(data_root, 'train_list.mat'))['file_list'] 157 | test_info = sio.loadmat(os.path.join(data_root, 'test_list.mat'))['file_list'] 158 | 159 | class_names = os.listdir(os.path.join(data_root, 'Images')) 160 | class_names.sort() 161 | 162 | train_dataset = [] 163 | test_dataset = [] 164 | 165 | for index in range(len(train_info)): 166 | images_file = str(train_info[index][0][0]) 167 | label_name = images_file.split('/')[0] 168 | label = class_names.index(label_name) 169 | 170 | example = {} 171 | example['filename'] = os.path.join(data_root, 'Images', images_file) 172 | example['label'] = int(label) 173 | train_dataset.append(example) 
174 | 
175 |   for index in range(len(test_info)):
176 |     images_file = str(test_info[index][0][0])
177 |     label_name = images_file.split('/')[0]
178 |     label = class_names.index(label_name)
179 | 
180 |     example = {}
181 |     example['filename'] = os.path.join(data_root, 'Images', images_file)
182 |     example['label'] = int(label)
183 |     test_dataset.append(example)
184 | 
185 |   return train_dataset, test_dataset
186 | 
187 | 
188 | def run(dataset_dir):
189 |   """Runs the conversion operation.
190 | 
191 |   Args:
192 |     dataset_dir: The dataset directory where the dataset is stored.
193 |   """
194 |   if not tf.gfile.Exists(dataset_dir):
195 |     tf.gfile.MakeDirs(dataset_dir)
196 | 
197 |   # Divide into train and test:
198 |   random.seed(_RANDOM_SEED)
199 | 
200 |   train_dataset, test_dataset = generate_datasets(dataset_dir)
201 | 
202 |   random.shuffle(train_dataset)
203 |   random.shuffle(test_dataset)
204 | 
205 |   # Convert the training and test sets.
206 |   _convert_dataset('train', train_dataset,
207 |                    dataset_dir)
208 |   _convert_dataset('test', test_dataset,
209 |                    dataset_dir)
210 | 
211 |   # _clean_up_temporary_files(dataset_dir)
212 |   print('\nFinished converting the dog dataset!')
213 | 
--------------------------------------------------------------------------------
/datasets/dataset_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A factory-pattern class which returns classification image/label pairs."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | from datasets import bird
22 | from datasets import aircraft
23 | from datasets import dog
24 | from datasets import car
25 | 
26 | datasets_map = {
27 |     'Bird': bird,
28 |     'Aircraft': aircraft,
29 |     'Dog': dog,
30 |     'Car': car,
31 | }
32 | 
33 | 
34 | def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None):
35 |   """Given a dataset name and a split_name returns a Dataset.
36 | 
37 |   Args:
38 |     name: String, the name of the dataset.
39 |     split_name: A train/test split name.
40 |     dataset_dir: The directory where the dataset files are stored.
41 |     file_pattern: The file pattern to use for matching the dataset source files.
42 |     reader: The subclass of tf.ReaderBase. If left as `None`, then the default
43 |       reader defined by each dataset is used.
44 | 
45 |   Returns:
46 |     A `Dataset` instance.
47 | 
48 |   Raises:
49 |     ValueError: If the dataset `name` is unknown.
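  Example (illustrative; the tfrecords path is hypothetical, and
  slim = tf.contrib.slim is assumed):
    dataset = get_dataset('Dog', 'train', './Dog/Data/tfrecords')
    provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
    image, label = provider.get(['image', 'label'])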
50 |   """
51 |   if name not in datasets_map:
52 |     raise ValueError('Unknown dataset name: %s' % name)
53 |   return datasets_map[name].get_split(
54 |       split_name,
55 |       dataset_dir,
56 |       file_pattern,
57 |       reader)
58 | 
--------------------------------------------------------------------------------
/datasets/dataset_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains utilities for downloading and converting datasets."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 | 
20 | import os
21 | import sys
22 | import tarfile
23 | 
24 | from six.moves import urllib
25 | import tensorflow as tf
26 | 
27 | LABELS_FILENAME = 'labels.txt'
28 | 
29 | 
30 | def int64_feature(values):
31 |   """Returns a TF-Feature of int64s.
32 | 
33 |   Args:
34 |     values: A scalar or list of values.
35 | 
36 |   Returns:
37 |     A TF-Feature.
38 |   """
39 |   if not isinstance(values, (tuple, list)):
40 |     values = [values]
41 |   return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
42 | 
43 | 
44 | def bytes_feature(values):
45 |   """Returns a TF-Feature of bytes.
46 | 
47 |   Args:
48 |     values: A string.
49 | 
50 |   Returns:
51 |     A TF-Feature.
52 |   """
53 |   return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
54 | 
55 | 
56 | def float_feature(values):
57 |   """Returns a TF-Feature of floats.
58 | 
59 |   Args:
60 |     values: A scalar or list of values.
61 | 
62 |   Returns:
63 |     A TF-Feature.
64 |   """
65 |   if not isinstance(values, (tuple, list)):
66 |     values = [values]
67 |   return tf.train.Feature(float_list=tf.train.FloatList(value=values))
68 | 
69 | 
70 | def image_to_tfexample(image_data, image_format, height, width, label):
71 |   return tf.train.Example(features=tf.train.Features(feature={
72 |       'image/encoded': bytes_feature(image_data),
73 |       'image/format': bytes_feature(image_format),
74 |       'image/height': int64_feature(height),
75 |       'image/width': int64_feature(width),
76 |       'image/class/label': int64_feature(label),
77 |   }))
78 | 
79 | 
# Note: example_to_tfexample below is identical to image_to_tfexample above.
80 | def example_to_tfexample(image_data, image_format, height, width, label):
81 |   return tf.train.Example(features=tf.train.Features(feature={
82 |       'image/encoded': bytes_feature(image_data),
83 |       'image/format': bytes_feature(image_format),
84 |       'image/height': int64_feature(height),
85 |       'image/width': int64_feature(width),
86 |       'image/class/label': int64_feature(label)
87 |   }))
88 | 
89 | 
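# Illustrative sketch (not part of this module): serializing one image file
# into the TF-Example schema defined above; the path, sizes and label are
# hypothetical.
#
#   image_data = tf.gfile.FastGFile('/tmp/example.jpg', 'rb').read()
#   example = image_to_tfexample(image_data, b'jpg', height=224, width=224, label=3)
#   with tf.python_io.TFRecordWriter('/tmp/sample.tfrecord') as writer:
#     writer.write(example.SerializeToString())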
90 | def download_and_uncompress_tarball(tarball_url, dataset_dir):
91 |   """Downloads the `tarball_url` and uncompresses it locally.
92 | 
93 |   Args:
94 |     tarball_url: The URL of a tarball file.
95 |     dataset_dir: The directory where the temporary files are stored.
96 |   """
97 |   filename = tarball_url.split('/')[-1]
98 |   filepath = os.path.join(dataset_dir, filename)
99 | 
100 |   def _progress(count, block_size, total_size):
101 |     sys.stdout.write('\r>> Downloading %s %.1f%%' % (
102 |         filename, float(count * block_size) / float(total_size) * 100.0))
103 |     sys.stdout.flush()
104 |   filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress)
105 |   print()
106 |   statinfo = os.stat(filepath)
107 |   print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
108 |   tarfile.open(filepath, 'r:gz').extractall(dataset_dir)
109 | 
110 | 
111 | def write_label_file(labels_to_class_names, dataset_dir,
112 |                      filename=LABELS_FILENAME):
113 |   """Writes a file with the list of class names.
114 | 
115 |   Args:
116 |     labels_to_class_names: A map of (integer) labels to class names.
117 |     dataset_dir: The directory in which the labels file should be written.
118 |     filename: The filename where the class names are written.
119 |   """
120 |   labels_filename = os.path.join(dataset_dir, filename)
121 |   with tf.gfile.Open(labels_filename, 'w') as f:
122 |     for label in labels_to_class_names:
123 |       class_name = labels_to_class_names[label]
124 |       f.write('%d:%s\n' % (label, class_name))
125 | 
126 | 
127 | def has_labels(dataset_dir, filename=LABELS_FILENAME):
128 |   """Specifies whether or not the dataset directory contains a label map file.
129 | 
130 |   Args:
131 |     dataset_dir: The directory in which the labels file is found.
132 |     filename: The filename where the class names are written.
133 | 
134 |   Returns:
135 |     `True` if the labels file exists and `False` otherwise.
136 |   """
137 |   return tf.gfile.Exists(os.path.join(dataset_dir, filename))
138 | 
139 | 
140 | def read_label_file(dataset_dir, filename=LABELS_FILENAME):
141 |   """Reads the labels file and returns a mapping from ID to class name.
142 | 
143 |   Args:
144 |     dataset_dir: The directory in which the labels file is found.
145 |     filename: The filename where the class names are written.
146 | 
147 |   Returns:
148 |     A map from a label (integer) to class name.
149 |   """
150 |   labels_filename = os.path.join(dataset_dir, filename)
151 |   with tf.gfile.Open(labels_filename, 'rb') as f:
152 |     lines = f.read().decode()
153 |   lines = lines.split('\n')
154 |   lines = filter(None, lines)
155 | 
156 |   labels_to_class_names = {}
157 |   for line in lines:
158 |     index = line.index(':')
159 |     labels_to_class_names[int(line[:index])] = line[index+1:]
160 |   return labels_to_class_names
161 | 
--------------------------------------------------------------------------------
/datasets/dog.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Provides data for the Stanford Dogs dataset.
16 | 
17 | The dataset script used to create the dataset can be found in this
18 | repository at: datasets/convert_dog.py
19 | """
20 | 
21 | from __future__ import absolute_import
22 | from __future__ import division
23 | from __future__ import print_function
24 | 
25 | import os
26 | import tensorflow as tf
27 | 
28 | from datasets import dataset_utils
29 | 
30 | slim = tf.contrib.slim
31 | 
32 | _FILE_PATTERN = 'Dog_%s_*.tfrecord'
33 | 
34 | SPLITS_TO_SIZES = {'train': 12000, 'test': 8580}
35 | 
36 | _NUM_CLASSES = 120
37 | 
38 | _ITEMS_TO_DESCRIPTIONS = {
39 |     'image': 'A color image of varying size.',
40 |     'label': 'A single integer between 0 and 119',
41 | }
42 | 
43 | 
44 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
45 |   """Gets a dataset tuple with instructions for reading Stanford Dogs.
46 | 
47 |   Args:
48 |     split_name: A train/test split name.
49 |     dataset_dir: The base directory of the dataset sources.
50 |     file_pattern: The file pattern to use when matching the dataset sources.
51 |       It is assumed that the pattern contains a '%s' string so that the split
52 |       name can be inserted.
53 |     reader: The TensorFlow reader type.
54 | 
55 |   Returns:
56 |     A `Dataset` namedtuple.
57 | 
58 |   Raises:
59 |     ValueError: if `split_name` is not a valid train/test split.
60 |   """
61 |   if split_name not in SPLITS_TO_SIZES:
62 |     raise ValueError('split name %s was not recognized.' % split_name)
63 | 
64 |   if not file_pattern:
65 |     file_pattern = _FILE_PATTERN
66 |   file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
67 | 
68 |   # Allowing None in the signature so that dataset_factory can use the default.
69 |   if reader is None:
70 |     reader = tf.TFRecordReader
71 | 
72 |   keys_to_features = {
73 |       'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
74 |       'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
75 |       'image/class/label': tf.FixedLenFeature(
76 |           [], tf.int64, default_value=tf.zeros([], dtype=tf.int64))
77 |   }
78 | 
79 |   items_to_handlers = {
80 |       'image': slim.tfexample_decoder.Image(),
81 |       'label': slim.tfexample_decoder.Tensor('image/class/label')
82 |   }
83 | 
84 |   decoder = slim.tfexample_decoder.TFExampleDecoder(
85 |       keys_to_features, items_to_handlers)
86 | 
87 |   labels_to_names = None
88 |   if dataset_utils.has_labels(dataset_dir):
89 |     labels_to_names = dataset_utils.read_label_file(dataset_dir)
90 | 
91 |   return slim.dataset.Dataset(
92 |       data_sources=file_pattern,
93 |       reader=reader,
94 |       decoder=decoder,
95 |       num_samples=SPLITS_TO_SIZES[split_name],
96 |       items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
97 |       num_classes=_NUM_CLASSES,
98 |       labels_to_names=labels_to_names)
99 | 
--------------------------------------------------------------------------------
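# Illustrative check (not part of the repository): read one record back from
# the converted tfrecords to confirm the schema consumed by datasets/dog.py.
# The shard path follows convert_dog's naming but is hypothetical here.
#
#   import tensorflow as tf
#   path = 'Dog/Data/tfrecords/Dog_train_00000-of-00005.tfrecord'
#   for record in tf.python_io.tf_record_iterator(path):
#     example = tf.train.Example()
#     example.ParseFromString(record)
#     print(example.features.feature['image/class/label'].int64_list.value[0])
#     break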
/deployment/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/deployment/__init__.py
--------------------------------------------------------------------------------
/eval_labels.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/eval_labels.pkl
--------------------------------------------------------------------------------
/eval_logits.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/eval_logits.pkl
--------------------------------------------------------------------------------
/eval_sample_aircraft.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | DATASET="Aircraft"
4 | TRAIN_DIR="./$DATASET/SAC/TRAIN/Aircraft"
5 | TEST_DIR="./$DATASET/SAC/TEST/Aircraft"
6 | 
7 | python eval_sample.py --checkpoint_path=$TRAIN_DIR \
8 |     --dataset_name=$DATASET \
9 |     --dataset_split_name='test' \
10 |     --dataset_dir="./$DATASET/Data/tfrecords" \
11 |     --eval_dir=$TEST_DIR \
12 |     --model_name='inception_v3_topk' \
13 |     --batch_size=16 \
14 |     --eval_image_size=448\
15 |     --gpus="0"\
16 |     --num_classes=100\
17 |     --feature_maps="Mixed_6e"\
18 |     --attention_maps="Mixed_7a_b0"\
19 |     --num_parts=32
20 | 
--------------------------------------------------------------------------------
/eval_sample_bird.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | DATASET="Bird"
4 | WEIGHT_DIR="./$DATASET/SAC/TRAIN/Bird"
5 | TEST_DIR="./$DATASET/SAC/TEST/Bird"
6 | 
7 | python eval_sample.py --checkpoint_path=$WEIGHT_DIR \
8 |     --dataset_name=$DATASET \
9 |     --dataset_split_name='test' \
10 |     --dataset_dir="./$DATASET/Data/tfrecords" \
11 |     --eval_dir=$TEST_DIR \
12 |     --model_name='inception_v3_topk' \
13 |     --batch_size=16 \
14 |     --eval_image_size=448\
15 |     --gpus="0"\
16 |     --num_classes=200\
17 |     --feature_maps="Mixed_6e"\
18 |     --attention_maps="Mixed_7a_b0"\
19 |     --num_parts=32
20 | 
--------------------------------------------------------------------------------
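# The three eval scripts share the same structure; they differ in num_classes
# (Aircraft 100, Bird 200, Dog 120), in the endpoint names passed as
# feature_maps/attention_maps (Mixed_6e / Mixed_7a_b0 above vs. Mixed_7c for
# Dog), and in the extra --ignore_missing_vars=True flag used for Dog.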
/eval_sample_dog.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | DATASET="Dog"
4 | TRAIN_DIR="./$DATASET/SAC/TRAIN/Dog"
5 | TEST_DIR="./$DATASET/SAC/TEST/Dog"
6 | 
7 | python eval_sample.py --checkpoint_path=$TRAIN_DIR \
8 |     --dataset_name=$DATASET \
9 |     --dataset_split_name='test' \
10 |     --dataset_dir="./$DATASET/Data/tfrecords" \
11 |     --eval_dir=$TEST_DIR \
12 |     --model_name='inception_v3_topk' \
13 |     --batch_size=16 \
14 |     --eval_image_size=448\
15 |     --gpus="1"\
16 |     --num_classes=120\
17 |     --feature_maps="Mixed_7c"\
18 |     --attention_maps="Mixed_7c"\
19 |     --num_parts=32\
20 |     --ignore_missing_vars=True
21 | 
--------------------------------------------------------------------------------
/nets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/nets/__init__.py
--------------------------------------------------------------------------------
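# Illustrative sketch (assumes nets/nets_factory.py follows the standard
# TF-Slim get_network_fn interface and registers the 'inception_v3_topk'
# model name used by the scripts above; any extra outputs of that variant
# are not shown):
#
#   import tensorflow as tf
#   from nets import nets_factory
#   network_fn = nets_factory.get_network_fn(
#       'inception_v3_topk', num_classes=120, is_training=False)
#   images = tf.placeholder(tf.float32, [None, 448, 448, 3])
#   logits, end_points = network_fn(images)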
/nets/alexnet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains a model definition for AlexNet.
16 | 
17 | This work was first described in:
18 |   ImageNet Classification with Deep Convolutional Neural Networks
19 |   Alex Krizhevsky, Ilya Sutskever and Geoffrey E. Hinton
20 | 
21 | and later refined in:
22 |   One weird trick for parallelizing convolutional neural networks
23 |   Alex Krizhevsky, 2014
24 | 
25 | Here we provide the implementation proposed in "One weird trick" and not
26 | "ImageNet Classification"; as in that paper, the LRN layers have been removed.
27 | 
28 | Usage:
29 |   with slim.arg_scope(alexnet.alexnet_v2_arg_scope()):
30 |     outputs, end_points = alexnet.alexnet_v2(inputs)
31 | 
32 | @@alexnet_v2
33 | """
34 | 
35 | from __future__ import absolute_import
36 | from __future__ import division
37 | from __future__ import print_function
38 | 
39 | import tensorflow as tf
40 | 
41 | slim = tf.contrib.slim
42 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
43 | 
44 | 
45 | def alexnet_v2_arg_scope(weight_decay=0.0005):
46 |   with slim.arg_scope([slim.conv2d, slim.fully_connected],
47 |                       activation_fn=tf.nn.relu,
48 |                       biases_initializer=tf.constant_initializer(0.1),
49 |                       weights_regularizer=slim.l2_regularizer(weight_decay)):
50 |     with slim.arg_scope([slim.conv2d], padding='SAME'):
51 |       with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc:
52 |         return arg_sc
53 | 
54 | 
55 | def alexnet_v2(inputs,
56 |                num_classes=1000,
57 |                is_training=True,
58 |                dropout_keep_prob=0.5,
59 |                spatial_squeeze=True,
60 |                scope='alexnet_v2',
61 |                global_pool=False):
62 |   """AlexNet version 2.
63 | 
64 |   Described in: http://arxiv.org/pdf/1404.5997v2.pdf
65 |   Parameters from:
66 |   github.com/akrizhevsky/cuda-convnet2/blob/master/layers/
67 |   layers-imagenet-1gpu.cfg
68 | 
69 |   Note: All the fully_connected layers have been transformed to conv2d layers.
70 |         To use in classification mode, resize input to 224x224 or set
71 |         global_pool=True. To use in fully convolutional mode, set
72 |         spatial_squeeze to false.
73 |         The LRN layers have been removed and the initializers have been changed
74 |         from random_normal_initializer to xavier_initializer.
75 | 
76 |   Args:
77 |     inputs: a tensor of size [batch_size, height, width, channels].
78 |     num_classes: the number of predicted classes. If 0 or None, the logits layer
79 |       is omitted and the input features to the logits layer are returned instead.
80 |     is_training: whether or not the model is being trained.
81 |     dropout_keep_prob: the probability that activations are kept in the dropout
82 |       layers during training.
83 |     spatial_squeeze: whether or not to squeeze the spatial dimensions of the
84 |       logits. Useful to remove unnecessary dimensions for classification.
85 | scope: Optional scope for the variables. 86 | global_pool: Optional boolean flag. If True, the input to the classification 87 | layer is avgpooled to size 1x1, for any input size. (This is not part 88 | of the original AlexNet.) 89 | 90 | Returns: 91 | net: the output of the logits layer (if num_classes is a non-zero integer), 92 | or the non-dropped-out input to the logits layer (if num_classes is 0 93 | or None). 94 | end_points: a dict of tensors with intermediate activations. 95 | """ 96 | with tf.variable_scope(scope, 'alexnet_v2', [inputs]) as sc: 97 | end_points_collection = sc.original_name_scope + '_end_points' 98 | # Collect outputs for conv2d, fully_connected and max_pool2d. 99 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 100 | outputs_collections=[end_points_collection]): 101 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', 102 | scope='conv1') 103 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool1') 104 | net = slim.conv2d(net, 192, [5, 5], scope='conv2') 105 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool2') 106 | net = slim.conv2d(net, 384, [3, 3], scope='conv3') 107 | net = slim.conv2d(net, 384, [3, 3], scope='conv4') 108 | net = slim.conv2d(net, 256, [3, 3], scope='conv5') 109 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool5') 110 | 111 | # Use conv2d instead of fully_connected layers. 112 | with slim.arg_scope([slim.conv2d], 113 | weights_initializer=trunc_normal(0.005), 114 | biases_initializer=tf.constant_initializer(0.1)): 115 | net = slim.conv2d(net, 4096, [5, 5], padding='VALID', 116 | scope='fc6') 117 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 118 | scope='dropout6') 119 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 120 | # Convert end_points_collection into a end_point dict. 121 | end_points = slim.utils.convert_collection_to_dict( 122 | end_points_collection) 123 | if global_pool: 124 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool') 125 | end_points['global_pool'] = net 126 | if num_classes: 127 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 128 | scope='dropout7') 129 | net = slim.conv2d(net, num_classes, [1, 1], 130 | activation_fn=None, 131 | normalizer_fn=None, 132 | biases_initializer=tf.zeros_initializer(), 133 | scope='fc8') 134 | if spatial_squeeze: 135 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 136 | end_points[sc.name + '/fc8'] = net 137 | return net, end_points 138 | alexnet_v2.default_image_size = 224 139 | -------------------------------------------------------------------------------- /nets/alexnet_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for slim.nets.alexnet.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | from nets import alexnet 23 | 24 | slim = tf.contrib.slim 25 | 26 | 27 | class AlexnetV2Test(tf.test.TestCase): 28 | 29 | def testBuild(self): 30 | batch_size = 5 31 | height, width = 224, 224 32 | num_classes = 1000 33 | with self.test_session(): 34 | inputs = tf.random_uniform((batch_size, height, width, 3)) 35 | logits, _ = alexnet.alexnet_v2(inputs, num_classes) 36 | self.assertEquals(logits.op.name, 'alexnet_v2/fc8/squeezed') 37 | self.assertListEqual(logits.get_shape().as_list(), 38 | [batch_size, num_classes]) 39 | 40 | def testFullyConvolutional(self): 41 | batch_size = 1 42 | height, width = 300, 400 43 | num_classes = 1000 44 | with self.test_session(): 45 | inputs = tf.random_uniform((batch_size, height, width, 3)) 46 | logits, _ = alexnet.alexnet_v2(inputs, num_classes, spatial_squeeze=False) 47 | self.assertEquals(logits.op.name, 'alexnet_v2/fc8/BiasAdd') 48 | self.assertListEqual(logits.get_shape().as_list(), 49 | [batch_size, 4, 7, num_classes]) 50 | 51 | def testGlobalPool(self): 52 | batch_size = 1 53 | height, width = 256, 256 54 | num_classes = 1000 55 | with self.test_session(): 56 | inputs = tf.random_uniform((batch_size, height, width, 3)) 57 | logits, _ = alexnet.alexnet_v2(inputs, num_classes, spatial_squeeze=False, 58 | global_pool=True) 59 | self.assertEquals(logits.op.name, 'alexnet_v2/fc8/BiasAdd') 60 | self.assertListEqual(logits.get_shape().as_list(), 61 | [batch_size, 1, 1, num_classes]) 62 | 63 | def testEndPoints(self): 64 | batch_size = 5 65 | height, width = 224, 224 66 | num_classes = 1000 67 | with self.test_session(): 68 | inputs = tf.random_uniform((batch_size, height, width, 3)) 69 | _, end_points = alexnet.alexnet_v2(inputs, num_classes) 70 | expected_names = ['alexnet_v2/conv1', 71 | 'alexnet_v2/pool1', 72 | 'alexnet_v2/conv2', 73 | 'alexnet_v2/pool2', 74 | 'alexnet_v2/conv3', 75 | 'alexnet_v2/conv4', 76 | 'alexnet_v2/conv5', 77 | 'alexnet_v2/pool5', 78 | 'alexnet_v2/fc6', 79 | 'alexnet_v2/fc7', 80 | 'alexnet_v2/fc8' 81 | ] 82 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 83 | 84 | def testNoClasses(self): 85 | batch_size = 5 86 | height, width = 224, 224 87 | num_classes = None 88 | with self.test_session(): 89 | inputs = tf.random_uniform((batch_size, height, width, 3)) 90 | net, end_points = alexnet.alexnet_v2(inputs, num_classes) 91 | expected_names = ['alexnet_v2/conv1', 92 | 'alexnet_v2/pool1', 93 | 'alexnet_v2/conv2', 94 | 'alexnet_v2/pool2', 95 | 'alexnet_v2/conv3', 96 | 'alexnet_v2/conv4', 97 | 'alexnet_v2/conv5', 98 | 'alexnet_v2/pool5', 99 | 'alexnet_v2/fc6', 100 | 'alexnet_v2/fc7' 101 | ] 102 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 103 | self.assertTrue(net.op.name.startswith('alexnet_v2/fc7')) 104 | self.assertListEqual(net.get_shape().as_list(), 105 | [batch_size, 1, 1, 4096]) 106 | 107 | def testModelVariables(self): 108 | batch_size = 5 109 | height, width = 224, 224 110 | num_classes = 1000 111 | with self.test_session(): 112 | inputs = tf.random_uniform((batch_size, height, width, 3)) 113 | alexnet.alexnet_v2(inputs, num_classes) 114 | expected_names = ['alexnet_v2/conv1/weights', 115 | 'alexnet_v2/conv1/biases', 116 | 'alexnet_v2/conv2/weights', 117 | 'alexnet_v2/conv2/biases', 118 | 
'alexnet_v2/conv3/weights', 119 | 'alexnet_v2/conv3/biases', 120 | 'alexnet_v2/conv4/weights', 121 | 'alexnet_v2/conv4/biases', 122 | 'alexnet_v2/conv5/weights', 123 | 'alexnet_v2/conv5/biases', 124 | 'alexnet_v2/fc6/weights', 125 | 'alexnet_v2/fc6/biases', 126 | 'alexnet_v2/fc7/weights', 127 | 'alexnet_v2/fc7/biases', 128 | 'alexnet_v2/fc8/weights', 129 | 'alexnet_v2/fc8/biases', 130 | ] 131 | model_variables = [v.op.name for v in slim.get_model_variables()] 132 | self.assertSetEqual(set(model_variables), set(expected_names)) 133 | 134 | def testEvaluation(self): 135 | batch_size = 2 136 | height, width = 224, 224 137 | num_classes = 1000 138 | with self.test_session(): 139 | eval_inputs = tf.random_uniform((batch_size, height, width, 3)) 140 | logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False) 141 | self.assertListEqual(logits.get_shape().as_list(), 142 | [batch_size, num_classes]) 143 | predictions = tf.argmax(logits, 1) 144 | self.assertListEqual(predictions.get_shape().as_list(), [batch_size]) 145 | 146 | def testTrainEvalWithReuse(self): 147 | train_batch_size = 2 148 | eval_batch_size = 1 149 | train_height, train_width = 224, 224 150 | eval_height, eval_width = 300, 400 151 | num_classes = 1000 152 | with self.test_session(): 153 | train_inputs = tf.random_uniform( 154 | (train_batch_size, train_height, train_width, 3)) 155 | logits, _ = alexnet.alexnet_v2(train_inputs) 156 | self.assertListEqual(logits.get_shape().as_list(), 157 | [train_batch_size, num_classes]) 158 | tf.get_variable_scope().reuse_variables() 159 | eval_inputs = tf.random_uniform( 160 | (eval_batch_size, eval_height, eval_width, 3)) 161 | logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False, 162 | spatial_squeeze=False) 163 | self.assertListEqual(logits.get_shape().as_list(), 164 | [eval_batch_size, 4, 7, num_classes]) 165 | logits = tf.reduce_mean(logits, [1, 2]) 166 | predictions = tf.argmax(logits, 1) 167 | self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size]) 168 | 169 | def testForward(self): 170 | batch_size = 1 171 | height, width = 224, 224 172 | with self.test_session() as sess: 173 | inputs = tf.random_uniform((batch_size, height, width, 3)) 174 | logits, _ = alexnet.alexnet_v2(inputs) 175 | sess.run(tf.global_variables_initializer()) 176 | output = sess.run(logits) 177 | self.assertTrue(output.any()) 178 | 179 | if __name__ == '__main__': 180 | tf.test.main() 181 | -------------------------------------------------------------------------------- /nets/cifarnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ==============================================================================
15 | """Contains a variant of the CIFAR-10 model definition."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | import tensorflow as tf
22 | 
23 | slim = tf.contrib.slim
24 | 
25 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(stddev=stddev)
26 | 
27 | 
28 | def cifarnet(images, num_classes=10, is_training=False,
29 |              dropout_keep_prob=0.5,
30 |              prediction_fn=slim.softmax,
31 |              scope='CifarNet'):
32 |   """Creates a variant of the CifarNet model.
33 | 
34 |   Note that since the output is a set of 'logits', the values fall in the
35 |   interval of (-infinity, infinity). Consequently, to convert the outputs to a
36 |   probability distribution over the classes, one will need to convert them
37 |   using the softmax function:
38 | 
39 |     logits, _ = cifarnet.cifarnet(images, is_training=False)
40 |     probabilities = tf.nn.softmax(logits)
41 |     predictions = tf.argmax(logits, 1)
42 | 
43 |   Args:
44 |     images: A batch of `Tensors` of size [batch_size, height, width, channels].
45 |     num_classes: the number of classes in the dataset. If 0 or None, the logits
46 |       layer is omitted and the input features to the logits layer are returned
47 |       instead.
48 |     is_training: specifies whether or not we're currently training the model.
49 |       This variable will determine the behaviour of the dropout layer.
50 |     dropout_keep_prob: the probability that activation values are retained.
51 |     prediction_fn: a function to get predictions out of logits.
52 |     scope: Optional variable_scope.
53 | 
54 |   Returns:
55 |     net: a 2D Tensor with the logits (pre-softmax activations) if num_classes
56 |       is a non-zero integer, or the input to the logits layer if num_classes
57 |       is 0 or None.
58 |     end_points: a dictionary from components of the network to the corresponding
59 |       activation.
60 |   """
61 |   end_points = {}
62 | 
63 |   with tf.variable_scope(scope, 'CifarNet', [images]):
64 |     net = slim.conv2d(images, 64, [5, 5], scope='conv1')
65 |     end_points['conv1'] = net
66 |     net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
67 |     end_points['pool1'] = net
68 |     net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm1')
69 |     net = slim.conv2d(net, 64, [5, 5], scope='conv2')
70 |     end_points['conv2'] = net
71 |     net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm2')
72 |     net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
73 |     end_points['pool2'] = net
74 |     net = slim.flatten(net)
75 |     end_points['Flatten'] = net
76 |     net = slim.fully_connected(net, 384, scope='fc3')
77 |     end_points['fc3'] = net
78 |     net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
79 |                        scope='dropout3')
80 |     net = slim.fully_connected(net, 192, scope='fc4')
81 |     end_points['fc4'] = net
82 |     if not num_classes:
83 |       return net, end_points
84 |     logits = slim.fully_connected(net, num_classes,
85 |                                   biases_initializer=tf.zeros_initializer(),
86 |                                   weights_initializer=trunc_normal(1/192.0),
87 |                                   weights_regularizer=None,
88 |                                   activation_fn=None,
89 |                                   scope='logits')
90 | 
91 |     end_points['Logits'] = logits
92 |     end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
93 | 
94 |   return logits, end_points
95 | cifarnet.default_image_size = 32
96 | 
97 | 
98 | def cifarnet_arg_scope(weight_decay=0.004):
99 |   """Defines the default cifarnet argument scope.
100 | 
101 |   Args:
102 |     weight_decay: The weight decay to use for regularizing the model.
103 | 
104 |   Returns:
105 |     An `arg_scope` to use for the cifarnet model.
106 |   """
107 |   with slim.arg_scope(
108 |       [slim.conv2d],
109 |       weights_initializer=tf.truncated_normal_initializer(stddev=5e-2),
110 |       activation_fn=tf.nn.relu):
111 |     with slim.arg_scope(
112 |         [slim.fully_connected],
113 |         biases_initializer=tf.constant_initializer(0.1),
114 |         weights_initializer=trunc_normal(0.04),
115 |         weights_regularizer=slim.l2_regularizer(weight_decay),
116 |         activation_fn=tf.nn.relu) as sc:
117 |       return sc
118 | 
--------------------------------------------------------------------------------
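# Illustrative usage sketch (not part of this file): building CifarNet under
# its default argument scope; shapes follow cifarnet.default_image_size above.
#
#   import tensorflow as tf
#   from nets import cifarnet
#   slim = tf.contrib.slim
#   images = tf.placeholder(tf.float32, [None, 32, 32, 3])
#   with slim.arg_scope(cifarnet.cifarnet_arg_scope(weight_decay=0.004)):
#     logits, end_points = cifarnet.cifarnet(images, num_classes=10,
#                                            is_training=True)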
/nets/cyclegan_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for tensorflow.contrib.slim.nets.cyclegan."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | import tensorflow as tf
22 | 
23 | from nets import cyclegan
24 | 
25 | 
26 | # TODO(joelshor): Add a test to check generator endpoints.
27 | class CycleganTest(tf.test.TestCase):
28 | 
29 |   def test_generator_inference(self):
30 |     """Check one inference step."""
31 |     img_batch = tf.zeros([2, 32, 32, 3])
32 |     model_output, _ = cyclegan.cyclegan_generator_resnet(img_batch)
33 |     with self.test_session() as sess:
34 |       sess.run(tf.global_variables_initializer())
35 |       sess.run(model_output)
36 | 
37 |   def _test_generator_graph_helper(self, shape):
38 |     """Check that generator can take small and non-square inputs."""
39 |     output_imgs, _ = cyclegan.cyclegan_generator_resnet(tf.ones(shape))
40 |     self.assertAllEqual(shape, output_imgs.shape.as_list())
41 | 
42 |   def test_generator_graph_small(self):
43 |     self._test_generator_graph_helper([4, 32, 32, 3])
44 | 
45 |   def test_generator_graph_medium(self):
46 |     self._test_generator_graph_helper([3, 128, 128, 3])
47 | 
48 |   def test_generator_graph_nonsquare(self):
49 |     self._test_generator_graph_helper([2, 80, 400, 3])
50 | 
51 |   def test_generator_unknown_batch_dim(self):
52 |     """Check that generator can take unknown batch dimension inputs."""
53 |     img = tf.placeholder(tf.float32, shape=[None, 32, None, 3])
54 |     output_imgs, _ = cyclegan.cyclegan_generator_resnet(img)
55 | 
56 |     self.assertAllEqual([None, 32, None, 3], output_imgs.shape.as_list())
57 | 
58 |   def _input_and_output_same_shape_helper(self, kernel_size):
59 |     img_batch = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])
60 |     output_img_batch, _ = cyclegan.cyclegan_generator_resnet(
61 |         img_batch, kernel_size=kernel_size)
62 | 
63 |     self.assertAllEqual(img_batch.shape.as_list(),
64 |                         output_img_batch.shape.as_list())
65 | 
66 |   def test_input_and_output_same_shape_kernel3(self):
67 |     self._input_and_output_same_shape_helper(3)
68 | 
69 |   def test_input_and_output_same_shape_kernel4(self):
70 |     self._input_and_output_same_shape_helper(4)
71 | 
72 |   def test_input_and_output_same_shape_kernel5(self):
73 |     self._input_and_output_same_shape_helper(5)
74 | 
75 |   def test_input_and_output_same_shape_kernel6(self):
76 |     self._input_and_output_same_shape_helper(6)
77 | 
78 |   def _error_if_height_not_multiple_of_four_helper(self, height):
79 |     self.assertRaisesRegexp(
80 |         ValueError,
81 |         'The input height must be a multiple of 4.',
82 |         cyclegan.cyclegan_generator_resnet,
83 |         tf.placeholder(tf.float32, shape=[None, height, 32, 3]))
84 | 
85 |   def test_error_if_height_not_multiple_of_four_height29(self):
86 |     self._error_if_height_not_multiple_of_four_helper(29)
87 | 
88 |   def test_error_if_height_not_multiple_of_four_height30(self):
89 |     self._error_if_height_not_multiple_of_four_helper(30)
90 | 
91 |   def test_error_if_height_not_multiple_of_four_height31(self):
92 |     self._error_if_height_not_multiple_of_four_helper(31)
93 | 
94 |   def _error_if_width_not_multiple_of_four_helper(self, width):
95 |     self.assertRaisesRegexp(
96 |         ValueError,
97 |         'The input width must be a multiple of 4.',
98 |         cyclegan.cyclegan_generator_resnet,
99 |         tf.placeholder(tf.float32, shape=[None, 32, width, 3]))
100 | 
101 |   def test_error_if_width_not_multiple_of_four_width29(self):
102 |     self._error_if_width_not_multiple_of_four_helper(29)
103 | 
104 |   def test_error_if_width_not_multiple_of_four_width30(self):
105 |     self._error_if_width_not_multiple_of_four_helper(30)
106 | 
107 |   def test_error_if_width_not_multiple_of_four_width31(self):
108 |     self._error_if_width_not_multiple_of_four_helper(31)
109 | 
110 | 
111 | if __name__ == '__main__':
112 |   tf.test.main()
113 | 
--------------------------------------------------------------------------------
/nets/dcgan.py:
-------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """DCGAN generator and discriminator from https://arxiv.org/abs/1511.06434.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | from math import log 21 | 22 | from six.moves import xrange # pylint: disable=redefined-builtin 23 | import tensorflow as tf 24 | 25 | slim = tf.contrib.slim 26 | 27 | 28 | def _validate_image_inputs(inputs): 29 | inputs.get_shape().assert_has_rank(4) 30 | inputs.get_shape()[1:3].assert_is_fully_defined() 31 | if inputs.get_shape()[1] != inputs.get_shape()[2]: 32 | raise ValueError('Input tensor does not have equal width and height: ', 33 | inputs.get_shape()[1:3]) 34 | width = inputs.get_shape().as_list()[1] 35 | if log(width, 2) != int(log(width, 2)): 36 | raise ValueError('Input tensor `width` is not a power of 2: ', width) 37 | 38 | 39 | # TODO(joelshor): Use fused batch norm by default. Investigate why some GAN 40 | # setups need the gradient of gradient FusedBatchNormGrad. 41 | def discriminator(inputs, 42 | depth=64, 43 | is_training=True, 44 | reuse=None, 45 | scope='Discriminator', 46 | fused_batch_norm=False): 47 | """Discriminator network for DCGAN. 48 | 49 | Construct discriminator network from inputs to the final endpoint. 50 | 51 | Args: 52 | inputs: A tensor of size [batch_size, height, width, channels]. Must be 53 | floating point. 54 | depth: Number of channels in first convolution layer. 55 | is_training: Whether the network is for training or not. 56 | reuse: Whether or not the network variables should be reused. `scope` 57 | must be given to be reused. 58 | scope: Optional variable_scope. 59 | fused_batch_norm: If `True`, use a faster, fused implementation of 60 | batch norm. 61 | 62 | Returns: 63 | logits: The pre-softmax activations, a tensor of size [batch_size, 1] 64 | end_points: a dictionary from components of the network to their activation. 65 | 66 | Raises: 67 | ValueError: If the input image shape is not 4-dimensional, if the spatial 68 | dimensions aren't defined at graph construction time, if the spatial 69 | dimensions aren't square, or if the spatial dimensions aren't a power of 70 | two. 
71 | """ 72 | 73 | normalizer_fn = slim.batch_norm 74 | normalizer_fn_args = { 75 | 'is_training': is_training, 76 | 'zero_debias_moving_mean': True, 77 | 'fused': fused_batch_norm, 78 | } 79 | 80 | _validate_image_inputs(inputs) 81 | inp_shape = inputs.get_shape().as_list()[1] 82 | 83 | end_points = {} 84 | with tf.variable_scope(scope, values=[inputs], reuse=reuse) as scope: 85 | with slim.arg_scope([normalizer_fn], **normalizer_fn_args): 86 | with slim.arg_scope([slim.conv2d], 87 | stride=2, 88 | kernel_size=4, 89 | activation_fn=tf.nn.leaky_relu): 90 | net = inputs 91 | for i in xrange(int(log(inp_shape, 2))): 92 | scope = 'conv%i' % (i + 1) 93 | current_depth = depth * 2**i 94 | normalizer_fn_ = None if i == 0 else normalizer_fn 95 | net = slim.conv2d( 96 | net, current_depth, normalizer_fn=normalizer_fn_, scope=scope) 97 | end_points[scope] = net 98 | 99 | logits = slim.conv2d(net, 1, kernel_size=1, stride=1, padding='VALID', 100 | normalizer_fn=None, activation_fn=None) 101 | logits = tf.reshape(logits, [-1, 1]) 102 | end_points['logits'] = logits 103 | 104 | return logits, end_points 105 | 106 | 107 | # TODO(joelshor): Use fused batch norm by default. Investigate why some GAN 108 | # setups need the gradient of gradient FusedBatchNormGrad. 109 | def generator(inputs, 110 | depth=64, 111 | final_size=32, 112 | num_outputs=3, 113 | is_training=True, 114 | reuse=None, 115 | scope='Generator', 116 | fused_batch_norm=False): 117 | """Generator network for DCGAN. 118 | 119 | Construct generator network from inputs to the final endpoint. 120 | 121 | Args: 122 | inputs: A tensor with any size N. [batch_size, N] 123 | depth: Number of channels in last deconvolution layer. 124 | final_size: The shape of the final output. 125 | num_outputs: Number of output features. For images, this is the number of 126 | channels. 127 | is_training: whether is training or not. 128 | reuse: Whether or not the network has its variables should be reused. scope 129 | must be given to be reused. 130 | scope: Optional variable_scope. 131 | fused_batch_norm: If `True`, use a faster, fused implementation of 132 | batch norm. 133 | 134 | Returns: 135 | logits: the pre-softmax activations, a tensor of size 136 | [batch_size, 32, 32, channels] 137 | end_points: a dictionary from components of the network to their activation. 138 | 139 | Raises: 140 | ValueError: If `inputs` is not 2-dimensional. 141 | ValueError: If `final_size` isn't a power of 2 or is less than 8. 142 | """ 143 | normalizer_fn = slim.batch_norm 144 | normalizer_fn_args = { 145 | 'is_training': is_training, 146 | 'zero_debias_moving_mean': True, 147 | 'fused': fused_batch_norm, 148 | } 149 | 150 | inputs.get_shape().assert_has_rank(2) 151 | if log(final_size, 2) != int(log(final_size, 2)): 152 | raise ValueError('`final_size` (%i) must be a power of 2.' % final_size) 153 | if final_size < 8: 154 | raise ValueError('`final_size` (%i) must be greater than 8.' % final_size) 155 | 156 | end_points = {} 157 | num_layers = int(log(final_size, 2)) - 1 158 | with tf.variable_scope(scope, values=[inputs], reuse=reuse) as scope: 159 | with slim.arg_scope([normalizer_fn], **normalizer_fn_args): 160 | with slim.arg_scope([slim.conv2d_transpose], 161 | normalizer_fn=normalizer_fn, 162 | stride=2, 163 | kernel_size=4): 164 | net = tf.expand_dims(tf.expand_dims(inputs, 1), 1) 165 | 166 | # First upscaling is different because it takes the input vector. 
167 | current_depth = depth * 2 ** (num_layers - 1) 168 | scope = 'deconv1' 169 | net = slim.conv2d_transpose( 170 | net, current_depth, stride=1, padding='VALID', scope=scope) 171 | end_points[scope] = net 172 | 173 | for i in xrange(2, num_layers): 174 | scope = 'deconv%i' % (i) 175 | current_depth = depth * 2 ** (num_layers - i) 176 | net = slim.conv2d_transpose(net, current_depth, scope=scope) 177 | end_points[scope] = net 178 | 179 | # Last layer has different normalizer and activation. 180 | scope = 'deconv%i' % (num_layers) 181 | net = slim.conv2d_transpose( 182 | net, depth, normalizer_fn=None, activation_fn=None, scope=scope) 183 | end_points[scope] = net 184 | 185 | # Convert to proper channels. 186 | scope = 'logits' 187 | logits = slim.conv2d( 188 | net, 189 | num_outputs, 190 | normalizer_fn=None, 191 | activation_fn=None, 192 | kernel_size=1, 193 | stride=1, 194 | padding='VALID', 195 | scope=scope) 196 | end_points[scope] = logits 197 | 198 | logits.get_shape().assert_has_rank(4) 199 | logits.get_shape().assert_is_compatible_with( 200 | [None, final_size, final_size, num_outputs]) 201 | 202 | return logits, end_points 203 | -------------------------------------------------------------------------------- /nets/dcgan_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for dcgan.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from six.moves import xrange # pylint: disable=redefined-builtin 22 | import tensorflow as tf 23 | 24 | from nets import dcgan 25 | 26 | 27 | class DCGANTest(tf.test.TestCase): 28 | 29 | def test_generator_run(self): 30 | tf.set_random_seed(1234) 31 | noise = tf.random_normal([100, 64]) 32 | image, _ = dcgan.generator(noise) 33 | with self.test_session() as sess: 34 | sess.run(tf.global_variables_initializer()) 35 | image.eval() 36 | 37 | def test_generator_graph(self): 38 | tf.set_random_seed(1234) 39 | # Check graph construction for a number of image size/depths and batch 40 | # sizes. 41 | for i, batch_size in zip(xrange(3, 7), xrange(3, 8)): 42 | tf.reset_default_graph() 43 | final_size = 2 ** i 44 | noise = tf.random_normal([batch_size, 64]) 45 | image, end_points = dcgan.generator( 46 | noise, 47 | depth=32, 48 | final_size=final_size) 49 | 50 | self.assertAllEqual([batch_size, final_size, final_size, 3], 51 | image.shape.as_list()) 52 | 53 | expected_names = ['deconv%i' % j for j in xrange(1, i)] + ['logits'] 54 | self.assertSetEqual(set(expected_names), set(end_points.keys())) 55 | 56 | # Check layer depths. 
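# (The assertions below encode the generator's depth schedule: deconv j has
# depth 32 * 2**(i-j-1), i.e. the channel count halves at each upscaling step.)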
57 | for j in range(1, i): 58 | layer = end_points['deconv%i' % j] 59 | self.assertEqual(32 * 2**(i-j-1), layer.get_shape().as_list()[-1]) 60 | 61 | def test_generator_invalid_input(self): 62 | wrong_dim_input = tf.zeros([5, 32, 32]) 63 | with self.assertRaises(ValueError): 64 | dcgan.generator(wrong_dim_input) 65 | 66 | correct_input = tf.zeros([3, 2]) 67 | with self.assertRaisesRegexp(ValueError, 'must be a power of 2'): 68 | dcgan.generator(correct_input, final_size=30) 69 | 70 | with self.assertRaisesRegexp(ValueError, 'must be at least 8'): 71 | dcgan.generator(correct_input, final_size=4) 72 | 73 | def test_discriminator_run(self): 74 | image = tf.random_uniform([5, 32, 32, 3], -1, 1) 75 | output, _ = dcgan.discriminator(image) 76 | with self.test_session() as sess: 77 | sess.run(tf.global_variables_initializer()) 78 | output.eval() 79 | 80 | def test_discriminator_graph(self): 81 | # Check graph construction for a number of image size/depths and batch 82 | # sizes. 83 | for i, batch_size in zip(xrange(1, 6), xrange(3, 8)): 84 | tf.reset_default_graph() 85 | img_w = 2 ** i 86 | image = tf.random_uniform([batch_size, img_w, img_w, 3], -1, 1) 87 | output, end_points = dcgan.discriminator( 88 | image, 89 | depth=32) 90 | 91 | self.assertAllEqual([batch_size, 1], output.get_shape().as_list()) 92 | 93 | expected_names = ['conv%i' % j for j in xrange(1, i+1)] + ['logits'] 94 | self.assertSetEqual(set(expected_names), set(end_points.keys())) 95 | 96 | # Check layer depths. 97 | for j in range(1, i+1): 98 | layer = end_points['conv%i' % j] 99 | self.assertEqual(32 * 2**(j-1), layer.get_shape().as_list()[-1]) 100 | 101 | def test_discriminator_invalid_input(self): 102 | wrong_dim_img = tf.zeros([5, 32, 32]) 103 | with self.assertRaises(ValueError): 104 | dcgan.discriminator(wrong_dim_img) 105 | 106 | spatially_undefined_shape = tf.placeholder(tf.float32, [5, 32, None, 3]) 107 | with self.assertRaises(ValueError): 108 | dcgan.discriminator(spatially_undefined_shape) 109 | 110 | not_square = tf.zeros([5, 32, 16, 3]) 111 | with self.assertRaisesRegexp(ValueError, 'not have equal width and height'): 112 | dcgan.discriminator(not_square) 113 | 114 | not_power_2 = tf.zeros([5, 30, 30, 3]) 115 | with self.assertRaisesRegexp(ValueError, 'not a power of 2'): 116 | dcgan.discriminator(not_power_2) 117 | 118 | 119 | if __name__ == '__main__': 120 | tf.test.main() 121 | -------------------------------------------------------------------------------- /nets/inception.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | """Brings all inception models under one namespace.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # pylint: disable=unused-import 22 | from nets.inception_resnet_v2 import inception_resnet_v2 23 | from nets.inception_resnet_v2 import inception_resnet_v2_arg_scope 24 | from nets.inception_resnet_v2 import inception_resnet_v2_base 25 | from nets.inception_v1 import inception_v1 26 | from nets.inception_v1 import inception_v1_arg_scope 27 | from nets.inception_v1 import inception_v1_base 28 | from nets.inception_v2 import inception_v2 29 | from nets.inception_v2 import inception_v2_arg_scope 30 | from nets.inception_v2 import inception_v2_base 31 | from nets.inception_v3 import inception_v3 32 | from nets.inception_v3 import inception_v3_arg_scope 33 | from nets.inception_v3 import inception_v3_base 34 | from nets.inception_v4 import inception_v4 35 | from nets.inception_v4 import inception_v4_arg_scope 36 | from nets.inception_v4 import inception_v4_base 37 | 38 | from nets.inception_v3_bap import inception_v3_bap 39 | from nets.inception_v3_bap import inception_v3_bap_arg_scope 40 | from nets.inception_v3_bap import inception_v3_bap_base 41 | 42 | from nets.inception_v3_topk import inception_v3_topk 43 | from nets.inception_v3_topk import inception_v3_topk_arg_scope 44 | from nets.inception_v3_topk import inception_v3_topk_base 45 | # pylint: enable=unused-import 46 | -------------------------------------------------------------------------------- /nets/inception_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains common code shared by all inception models. 16 | 17 | Usage of arg scope: 18 | with slim.arg_scope(inception_arg_scope()): 19 | logits, end_points = inception.inception_v3(images, num_classes, 20 | is_training=is_training) 21 | 22 | """ 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import tensorflow as tf 28 | 29 | slim = tf.contrib.slim 30 | 31 | 32 | def inception_arg_scope(weight_decay=0.00004, 33 | use_batch_norm=True, 34 | batch_norm_decay=0.9997, 35 | batch_norm_epsilon=0.001, 36 | activation_fn=tf.nn.relu, 37 | batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS): 38 | """Defines the default arg scope for inception models. 39 | 40 | Args: 41 | weight_decay: The weight decay to use for regularizing the model. 42 | use_batch_norm: If `True`, batch_norm is applied after each convolution. 43 | batch_norm_decay: Decay for batch norm moving average.
44 | batch_norm_epsilon: Small float added to variance to avoid dividing by zero 45 | in batch norm. 46 | activation_fn: Activation function for conv2d. 47 | batch_norm_updates_collections: Collection for the update ops for 48 | batch norm. 49 | 50 | Returns: 51 | An `arg_scope` to use for the inception models. 52 | """ 53 | batch_norm_params = { 54 | # Decay for the moving averages. 55 | 'decay': batch_norm_decay, 56 | # epsilon to prevent 0s in variance. 57 | 'epsilon': batch_norm_epsilon, 58 | # collection containing update_ops. 59 | 'updates_collections': batch_norm_updates_collections, 60 | # use fused batch norm if possible. 61 | 'fused': None, 62 | } 63 | if use_batch_norm: 64 | normalizer_fn = slim.batch_norm 65 | normalizer_params = batch_norm_params 66 | else: 67 | normalizer_fn = None 68 | normalizer_params = {} 69 | # Set weight_decay for weights in Conv and FC layers. 70 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 71 | weights_regularizer=slim.l2_regularizer(weight_decay)): 72 | with slim.arg_scope( 73 | [slim.conv2d, slim.fully_connected], 74 | weights_initializer=slim.variance_scaling_initializer(), 75 | activation_fn=activation_fn, 76 | normalizer_fn=normalizer_fn, 77 | normalizer_params=normalizer_params) as sc: 78 | return sc 79 | -------------------------------------------------------------------------------- /nets/lenet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a variant of the LeNet model definition.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | def lenet(images, num_classes=10, is_training=False, 27 | dropout_keep_prob=0.5, 28 | prediction_fn=slim.softmax, 29 | scope='LeNet'): 30 | """Creates a variant of the LeNet model. 31 | 32 | Note that since the output is a set of 'logits', the values fall in the 33 | interval of (-infinity, infinity). Consequently, to convert the outputs to a 34 | probability distribution over the characters, one will need to convert them 35 | using the softmax function: 36 | 37 | logits = lenet.lenet(images, is_training=False) 38 | probabilities = tf.nn.softmax(logits) 39 | predictions = tf.argmax(logits, 1) 40 | 41 | Args: 42 | images: A batch of `Tensors` of size [batch_size, height, width, channels]. 43 | num_classes: the number of classes in the dataset. If 0 or None, the logits 44 | layer is omitted and the input features to the logits layer are returned 45 | instead. 46 | is_training: specifies whether or not we're currently training the model. 47 | This variable will determine the behaviour of the dropout layer. 
48 | dropout_keep_prob: the percentage of activation values that are retained. 49 | prediction_fn: a function to get predictions out of logits. 50 | scope: Optional variable_scope. 51 | 52 | Returns: 53 | net: a 2D Tensor with the logits (pre-softmax activations) if num_classes 54 | is a non-zero integer, or the non-dropped-out input to the logits layer 55 | if num_classes is 0 or None. 56 | end_points: a dictionary from components of the network to the corresponding 57 | activation. 58 | """ 59 | end_points = {} 60 | 61 | with tf.variable_scope(scope, 'LeNet', [images]): 62 | net = end_points['conv1'] = slim.conv2d(images, 32, [5, 5], scope='conv1') 63 | net = end_points['pool1'] = slim.max_pool2d(net, [2, 2], 2, scope='pool1') 64 | net = end_points['conv2'] = slim.conv2d(net, 64, [5, 5], scope='conv2') 65 | net = end_points['pool2'] = slim.max_pool2d(net, [2, 2], 2, scope='pool2') 66 | net = slim.flatten(net) 67 | end_points['Flatten'] = net 68 | 69 | net = end_points['fc3'] = slim.fully_connected(net, 1024, scope='fc3') 70 | if not num_classes: 71 | return net, end_points 72 | net = end_points['dropout3'] = slim.dropout( 73 | net, dropout_keep_prob, is_training=is_training, scope='dropout3') 74 | logits = end_points['Logits'] = slim.fully_connected( 75 | net, num_classes, activation_fn=None, scope='fc4') 76 | 77 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 78 | 79 | return logits, end_points 80 | lenet.default_image_size = 28 81 | 82 | 83 | def lenet_arg_scope(weight_decay=0.0): 84 | """Defines the default lenet argument scope. 85 | 86 | Args: 87 | weight_decay: The weight decay to use for regularizing the model. 88 | 89 | Returns: 90 | An `arg_scope` to use for the lenet model. 91 | """ 92 | with slim.arg_scope( 93 | [slim.conv2d, slim.fully_connected], 94 | weights_regularizer=slim.l2_regularizer(weight_decay), 95 | weights_initializer=tf.truncated_normal_initializer(stddev=0.1), 96 | activation_fn=tf.nn.relu) as sc: 97 | return sc 98 | -------------------------------------------------------------------------------- /nets/mobilenet/README.md: -------------------------------------------------------------------------------- 1 | # MobileNetV2 2 | This folder contains building code for MobileNetV2, based on 3 | [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381) 4 | 5 | # Performance 6 | ## Latency 7 | This is the timing of [MobileNetV1](../mobilenet_v1.md) vs MobileNetV2 using 8 | TF-Lite on the large core of a Pixel 1 phone. 9 | 10 | ![mnet_v1_vs_v2_pixel1_latency.png](mnet_v1_vs_v2_pixel1_latency.png) 11 | 12 | ## MACs 13 | MACs, also sometimes known as MADDs, are the number of multiply-accumulates needed 14 | to compute an inference on a single image; they are a common metric of model efficiency. 15 | 16 | Below is a graph comparing V2 with a few selected networks. The size 17 | of each blob represents the number of parameters. Note that for [ShuffleNet](https://arxiv.org/abs/1707.01083) there 18 | are no published size numbers. We estimate it to be comparable to MobileNetV2 numbers.
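For a back-of-the-envelope feel for what is being counted, the MACs of a single conv layer follow directly from its shapes (a sketch only; the table below uses exact per-layer accounting, and the example layer is just a typical 224x224 stem convolution):

```python
def conv2d_macs(out_h, out_w, out_ch, k_h, k_w, in_ch):
    # Each output element costs k_h * k_w * in_ch multiply-accumulates.
    return out_h * out_w * out_ch * k_h * k_w * in_ch

# 3x3 conv, stride 2, 32 filters on a 224x224x3 input -> 112x112x32 output.
print(conv2d_macs(112, 112, 32, 3, 3, 3))  # 10838016, i.e. ~10.8M MACs
```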
19 | 20 | ![madds_top1_accuracy](madds_top1_accuracy.png) 21 | 22 | # Pretrained models 23 | ## Imagenet Checkpoints 24 | 25 | Classification Checkpoint | MACs (M)| Parameters (M)| Top 1 Accuracy| Top 5 Accuracy | Mobile CPU (ms) Pixel 1 26 | ---------------------------|---------|---------------|---------|----|------------- 27 | | [mobilenet_v2_1.4_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz) | 582 | 6.06 | 75.0 | 92.5 | 138.0 28 | | [mobilenet_v2_1.3_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.3_224.tgz) | 509 | 5.34 | 74.4 | 92.1 | 123.0 29 | | [mobilenet_v2_1.0_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz) | 300 | 3.47 | 71.8 | 91.0 | 73.8 30 | | [mobilenet_v2_1.0_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_192.tgz) | 221 | 3.47 | 70.7 | 90.1 | 55.1 31 | | [mobilenet_v2_1.0_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_160.tgz) | 154 | 3.47 | 68.8 | 89.0 | 40.2 32 | | [mobilenet_v2_1.0_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_128.tgz) | 99 | 3.47 | 65.3 | 86.9 | 27.6 33 | | [mobilenet_v2_1.0_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_96.tgz) | 56 | 3.47 | 60.3 | 83.2 | 17.6 34 | | [mobilenet_v2_0.75_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_224.tgz) | 209 | 2.61 | 69.8 | 89.6 | 55.8 35 | | [mobilenet_v2_0.75_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_192.tgz) | 153 | 2.61 | 68.7 | 88.9 | 41.6 36 | | [mobilenet_v2_0.75_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_160.tgz) | 107 | 2.61 | 66.4 | 87.3 | 30.4 37 | | [mobilenet_v2_0.75_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_128.tgz) | 69 | 2.61 | 63.2 | 85.3 | 21.9 38 | | [mobilenet_v2_0.75_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_96.tgz) | 39 | 2.61 | 58.8 | 81.6 | 14.2 39 | | [mobilenet_v2_0.5_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_224.tgz) | 97 | 1.95 | 65.4 | 86.4 | 28.7 40 | | [mobilenet_v2_0.5_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_192.tgz) | 71 | 1.95 | 63.9 | 85.4 | 21.1 41 | | [mobilenet_v2_0.5_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_160.tgz) | 50 | 1.95 | 61.0 | 83.2 | 14.9 42 | | [mobilenet_v2_0.5_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_128.tgz) | 32 | 1.95 | 57.7 | 80.8 | 9.9 43 | | [mobilenet_v2_0.5_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_96.tgz) | 18 | 1.95 | 51.2 | 75.8 | 6.4 44 | | [mobilenet_v2_0.35_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_224.tgz) | 59 | 1.66 | 60.3 | 82.9 | 19.7 45 | | [mobilenet_v2_0.35_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_192.tgz) | 43 | 1.66 | 58.2 | 81.2 | 14.6 46 | | [mobilenet_v2_0.35_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_160.tgz) | 30 | 1.66 | 55.7 | 79.1 | 10.5 47 | | [mobilenet_v2_0.35_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_128.tgz) | 20 | 1.66 | 50.8 | 75.0 | 6.9 48 | | [mobilenet_v2_0.35_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_96.tgz) | 11 | 1.66 | 45.5 | 70.4 | 4.5 49 | 50 | # Training 
51 | The numbers above can be reproduced using slim's `train_image_classifier`. 52 | Below is the set of parameters that achieves 72.0% top-1 for the full-size MobileNetV2, after about 700K steps when trained on 8 GPUs. 53 | If trained on a single GPU, full convergence takes about 5.5M steps. Also note that the learning rate and 54 | num_epochs_per_decay both need to be adjusted depending on how many GPUs are being 55 | used, due to slim's internal averaging. 56 | 57 | ```bash 58 | --model_name="mobilenet_v2" 59 | --learning_rate=0.045 * NUM_GPUS # slim internally averages clones, so we compensate 60 | --preprocessing_name="inception_v2" 61 | --label_smoothing=0.1 62 | --moving_average_decay=0.9999 63 | --batch_size=96 64 | --num_clones=NUM_GPUS # any number between 1 and 8, depending on your hardware setup 65 | --learning_rate_decay_factor=0.98 66 | --num_epochs_per_decay=2.5 / NUM_GPUS # train_image_classifier does per-clone epochs 67 | ``` 68 | 69 | # Example 70 | 71 | 72 | See this [ipython notebook](mobilenet_example.ipynb) or open and run the network directly in [Colaboratory](https://colab.research.google.com/github/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_example.ipynb). 73 | 74 | -------------------------------------------------------------------------------- /nets/mobilenet/madds_top1_accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/nets/mobilenet/madds_top1_accuracy.png -------------------------------------------------------------------------------- /nets/mobilenet/mnet_v1_vs_v2_pixel1_latency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/nets/mobilenet/mnet_v1_vs_v2_pixel1_latency.png -------------------------------------------------------------------------------- /nets/mobilenet/mobilenet_v2_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for mobilenet_v2.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import copy 21 | import tensorflow as tf 22 | from nets.mobilenet import conv_blocks as ops 23 | from nets.mobilenet import mobilenet 24 | from nets.mobilenet import mobilenet_v2 25 | 26 | 27 | slim = tf.contrib.slim 28 | 29 | 30 | def find_ops(optype): 31 | """Find ops of a given type in graphdef or a graph. 32 | 33 | Args: 34 | optype: operation type (e.g. Conv2D) 35 | Returns: 36 | List of operations.
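  Example (illustrative):
    conv_ops = find_ops('Conv2D')  # every Conv2D op in the default graph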
37 | """ 38 | gd = tf.get_default_graph() 39 | return [var for var in gd.get_operations() if var.type == optype] 40 | 41 | 42 | class MobilenetV2Test(tf.test.TestCase): 43 | 44 | def setUp(self): 45 | tf.reset_default_graph() 46 | 47 | def testCreation(self): 48 | spec = dict(mobilenet_v2.V2_DEF) 49 | _, ep = mobilenet.mobilenet( 50 | tf.placeholder(tf.float32, (10, 224, 224, 16)), conv_defs=spec) 51 | num_convs = len(find_ops('Conv2D')) 52 | 53 | # This is mostly a sanity test. No deep reason for these particular 54 | # constants. 55 | # 56 | # All but first 2 and last one have two convolutions, and there is one 57 | # extra conv that is not in the spec. (logits) 58 | self.assertEqual(num_convs, len(spec['spec']) * 2 - 2) 59 | # Check that depthwise are exposed. 60 | for i in range(2, 17): 61 | self.assertIn('layer_%d/depthwise_output' % i, ep) 62 | 63 | def testCreationNoClasses(self): 64 | spec = copy.deepcopy(mobilenet_v2.V2_DEF) 65 | net, ep = mobilenet.mobilenet( 66 | tf.placeholder(tf.float32, (10, 224, 224, 16)), conv_defs=spec, 67 | num_classes=None) 68 | self.assertIs(net, ep['global_pool']) 69 | 70 | def testImageSizes(self): 71 | for input_size, output_size in [(224, 7), (192, 6), (160, 5), 72 | (128, 4), (96, 3)]: 73 | tf.reset_default_graph() 74 | _, ep = mobilenet_v2.mobilenet( 75 | tf.placeholder(tf.float32, (10, input_size, input_size, 3))) 76 | 77 | self.assertEqual(ep['layer_18/output'].get_shape().as_list()[1:3], 78 | [output_size] * 2) 79 | 80 | def testWithSplits(self): 81 | spec = copy.deepcopy(mobilenet_v2.V2_DEF) 82 | spec['overrides'] = { 83 | (ops.expanded_conv,): dict(split_expansion=2), 84 | } 85 | _, _ = mobilenet.mobilenet( 86 | tf.placeholder(tf.float32, (10, 224, 224, 16)), conv_defs=spec) 87 | num_convs = len(find_ops('Conv2D')) 88 | # All but 3 op has 3 conv operatore, the remainign 3 have one 89 | # and there is one unaccounted. 90 | self.assertEqual(num_convs, len(spec['spec']) * 3 - 5) 91 | 92 | def testWithOutputStride8(self): 93 | out, _ = mobilenet.mobilenet_base( 94 | tf.placeholder(tf.float32, (10, 224, 224, 16)), 95 | conv_defs=mobilenet_v2.V2_DEF, 96 | output_stride=8, 97 | scope='MobilenetV2') 98 | self.assertEqual(out.get_shape().as_list()[1:3], [28, 28]) 99 | 100 | def testDivisibleBy(self): 101 | tf.reset_default_graph() 102 | mobilenet_v2.mobilenet( 103 | tf.placeholder(tf.float32, (10, 224, 224, 16)), 104 | conv_defs=mobilenet_v2.V2_DEF, 105 | divisible_by=16, 106 | min_depth=32) 107 | s = [op.outputs[0].get_shape().as_list()[-1] for op in find_ops('Conv2D')] 108 | s = set(s) 109 | self.assertSameElements([32, 64, 96, 160, 192, 320, 384, 576, 960, 1280, 110 | 1001], s) 111 | 112 | def testDivisibleByWithArgScope(self): 113 | tf.reset_default_graph() 114 | # Verifies that depth_multiplier arg scope actually works 115 | # if no default min_depth is provided. 116 | with slim.arg_scope((mobilenet.depth_multiplier,), min_depth=32): 117 | mobilenet_v2.mobilenet( 118 | tf.placeholder(tf.float32, (10, 224, 224, 2)), 119 | conv_defs=mobilenet_v2.V2_DEF, depth_multiplier=0.1) 120 | s = [op.outputs[0].get_shape().as_list()[-1] for op in find_ops('Conv2D')] 121 | s = set(s) 122 | self.assertSameElements(s, [32, 192, 128, 1001]) 123 | 124 | def testFineGrained(self): 125 | tf.reset_default_graph() 126 | # Verifies that depth_multiplier arg scope actually works 127 | # if no default min_depth is provided. 
128 | 129 | mobilenet_v2.mobilenet( 130 | tf.placeholder(tf.float32, (10, 224, 224, 2)), 131 | conv_defs=mobilenet_v2.V2_DEF, depth_multiplier=0.01, 132 | finegrain_classification_mode=True) 133 | s = [op.outputs[0].get_shape().as_list()[-1] for op in find_ops('Conv2D')] 134 | s = set(s) 135 | # All convolutions will be 8->48, except for the last one. 136 | self.assertSameElements(s, [8, 48, 1001, 1280]) 137 | 138 | def testMobilenetBase(self): 139 | tf.reset_default_graph() 140 | # Verifies that mobilenet_base returns pre-pooling layer. 141 | with slim.arg_scope((mobilenet.depth_multiplier,), min_depth=32): 142 | net, _ = mobilenet_v2.mobilenet_base( 143 | tf.placeholder(tf.float32, (10, 224, 224, 16)), 144 | conv_defs=mobilenet_v2.V2_DEF, depth_multiplier=0.1) 145 | self.assertEqual(net.get_shape().as_list(), [10, 7, 7, 128]) 146 | 147 | def testWithOutputStride16(self): 148 | tf.reset_default_graph() 149 | out, _ = mobilenet.mobilenet_base( 150 | tf.placeholder(tf.float32, (10, 224, 224, 16)), 151 | conv_defs=mobilenet_v2.V2_DEF, 152 | output_stride=16) 153 | self.assertEqual(out.get_shape().as_list()[1:3], [14, 14]) 154 | 155 | def testWithOutputStride8AndExplicitPadding(self): 156 | tf.reset_default_graph() 157 | out, _ = mobilenet.mobilenet_base( 158 | tf.placeholder(tf.float32, (10, 224, 224, 16)), 159 | conv_defs=mobilenet_v2.V2_DEF, 160 | output_stride=8, 161 | use_explicit_padding=True, 162 | scope='MobilenetV2') 163 | self.assertEqual(out.get_shape().as_list()[1:3], [28, 28]) 164 | 165 | def testWithOutputStride16AndExplicitPadding(self): 166 | tf.reset_default_graph() 167 | out, _ = mobilenet.mobilenet_base( 168 | tf.placeholder(tf.float32, (10, 224, 224, 16)), 169 | conv_defs=mobilenet_v2.V2_DEF, 170 | output_stride=16, 171 | use_explicit_padding=True) 172 | self.assertEqual(out.get_shape().as_list()[1:3], [14, 14]) 173 | 174 | def testBatchNormScopeDoesNotHaveIsTrainingWhenItsSetToNone(self): 175 | sc = mobilenet.training_scope(is_training=None) 176 | self.assertNotIn('is_training', sc[slim.arg_scope_func_key( 177 | slim.batch_norm)]) 178 | 179 | def testBatchNormScopeDoesHasIsTrainingWhenItsNotNone(self): 180 | sc = mobilenet.training_scope(is_training=False) 181 | self.assertIn('is_training', sc[slim.arg_scope_func_key(slim.batch_norm)]) 182 | sc = mobilenet.training_scope(is_training=True) 183 | self.assertIn('is_training', sc[slim.arg_scope_func_key(slim.batch_norm)]) 184 | sc = mobilenet.training_scope() 185 | self.assertIn('is_training', sc[slim.arg_scope_func_key(slim.batch_norm)]) 186 | 187 | 188 | if __name__ == '__main__': 189 | tf.test.main() 190 | -------------------------------------------------------------------------------- /nets/mobilenet_v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/nets/mobilenet_v1.png -------------------------------------------------------------------------------- /nets/mobilenet_v1_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Validate mobilenet_v1 with options for quantization.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import math 22 | import tensorflow as tf 23 | 24 | from datasets import dataset_factory 25 | from nets import mobilenet_v1 26 | from preprocessing import preprocessing_factory 27 | 28 | slim = tf.contrib.slim 29 | 30 | flags = tf.app.flags 31 | 32 | flags.DEFINE_string('master', '', 'Session master') 33 | flags.DEFINE_integer('batch_size', 250, 'Batch size') 34 | flags.DEFINE_integer('num_classes', 1001, 'Number of classes to distinguish') 35 | flags.DEFINE_integer('num_examples', 50000, 'Number of examples to evaluate') 36 | flags.DEFINE_integer('image_size', 224, 'Input image resolution') 37 | flags.DEFINE_float('depth_multiplier', 1.0, 'Depth multiplier for mobilenet') 38 | flags.DEFINE_bool('quantize', False, 'Quantize training') 39 | flags.DEFINE_string('checkpoint_dir', '', 'The directory for checkpoints') 40 | flags.DEFINE_string('eval_dir', '', 'Directory for writing eval event logs') 41 | flags.DEFINE_string('dataset_dir', '', 'Location of dataset') 42 | 43 | FLAGS = flags.FLAGS 44 | 45 | 46 | def imagenet_input(is_training): 47 | """Data reader for imagenet. 48 | 49 | Reads in imagenet data and performs pre-processing on the images. 50 | 51 | Args: 52 | is_training: bool specifying if train or validation dataset is needed. 53 | Returns: 54 | A batch of images and labels. 55 | """ 56 | if is_training: 57 | dataset = dataset_factory.get_dataset('imagenet', 'train', 58 | FLAGS.dataset_dir) 59 | else: 60 | dataset = dataset_factory.get_dataset('imagenet', 'validation', 61 | FLAGS.dataset_dir) 62 | 63 | provider = slim.dataset_data_provider.DatasetDataProvider( 64 | dataset, 65 | shuffle=is_training, 66 | common_queue_capacity=2 * FLAGS.batch_size, 67 | common_queue_min=FLAGS.batch_size) 68 | [image, label] = provider.get(['image', 'label']) 69 | 70 | image_preprocessing_fn = preprocessing_factory.get_preprocessing( 71 | 'mobilenet_v1', is_training=is_training) 72 | 73 | image = image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size) 74 | 75 | images, labels = tf.train.batch( 76 | tensors=[image, label], 77 | batch_size=FLAGS.batch_size, 78 | num_threads=4, 79 | capacity=5 * FLAGS.batch_size) 80 | return images, labels 81 | 82 | 83 | def metrics(logits, labels): 84 | """Specify the metrics for eval. 85 | 86 | Args: 87 | logits: Logits output from the graph. 88 | labels: Ground truth labels for inputs. 89 | 90 | Returns: 91 | Eval Op for the graph. 
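  Example (illustrative):
    eval_ops = metrics(logits, labels)  # accuracy and recall@5 update ops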
92 | """ 93 | labels = tf.squeeze(labels) 94 | names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({ 95 | 'Accuracy': tf.metrics.accuracy(tf.argmax(logits, 1), labels), 96 | 'Recall_5': tf.metrics.recall_at_k(labels, logits, 5), 97 | }) 98 | for name, value in names_to_values.iteritems(): 99 | slim.summaries.add_scalar_summary( 100 | value, name, prefix='eval', print_summary=True) 101 | return names_to_updates.values() 102 | 103 | 104 | def build_model(): 105 | """Build the mobilenet_v1 model for evaluation. 106 | 107 | Returns: 108 | g: graph with rewrites after insertion of quantization ops and batch norm 109 | folding. 110 | eval_ops: eval ops for inference. 111 | variables_to_restore: List of variables to restore from checkpoint. 112 | """ 113 | g = tf.Graph() 114 | with g.as_default(): 115 | inputs, labels = imagenet_input(is_training=False) 116 | 117 | scope = mobilenet_v1.mobilenet_v1_arg_scope( 118 | is_training=False, weight_decay=0.0) 119 | with slim.arg_scope(scope): 120 | logits, _ = mobilenet_v1.mobilenet_v1( 121 | inputs, 122 | is_training=False, 123 | depth_multiplier=FLAGS.depth_multiplier, 124 | num_classes=FLAGS.num_classes) 125 | 126 | if FLAGS.quantize: 127 | tf.contrib.quantize.create_eval_graph() 128 | 129 | eval_ops = metrics(logits, labels) 130 | 131 | return g, eval_ops 132 | 133 | 134 | def eval_model(): 135 | """Evaluates mobilenet_v1.""" 136 | g, eval_ops = build_model() 137 | with g.as_default(): 138 | num_batches = math.ceil(FLAGS.num_examples / float(FLAGS.batch_size)) 139 | slim.evaluation.evaluate_once( 140 | FLAGS.master, 141 | FLAGS.checkpoint_dir, 142 | logdir=FLAGS.eval_dir, 143 | num_evals=num_batches, 144 | eval_op=eval_ops) 145 | 146 | 147 | def main(unused_arg): 148 | eval_model() 149 | 150 | 151 | if __name__ == '__main__': 152 | tf.app.run(main) 153 | -------------------------------------------------------------------------------- /nets/mobilenet_v1_train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Build and train mobilenet_v1 with options for quantization.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from datasets import dataset_factory 24 | from nets import mobilenet_v1 25 | from preprocessing import preprocessing_factory 26 | 27 | slim = tf.contrib.slim 28 | 29 | flags = tf.app.flags 30 | 31 | flags.DEFINE_string('master', '', 'Session master') 32 | flags.DEFINE_integer('task', 0, 'Task') 33 | flags.DEFINE_integer('ps_tasks', 0, 'Number of ps') 34 | flags.DEFINE_integer('batch_size', 64, 'Batch size') 35 | flags.DEFINE_integer('num_classes', 1001, 'Number of classes to distinguish') 36 | flags.DEFINE_integer('number_of_steps', None, 37 | 'Number of training steps to perform before stopping') 38 | flags.DEFINE_integer('image_size', 224, 'Input image resolution') 39 | flags.DEFINE_float('depth_multiplier', 1.0, 'Depth multiplier for mobilenet') 40 | flags.DEFINE_bool('quantize', False, 'Quantize training') 41 | flags.DEFINE_string('fine_tune_checkpoint', '', 42 | 'Checkpoint from which to start finetuning.') 43 | flags.DEFINE_string('checkpoint_dir', '', 44 | 'Directory for writing training checkpoints and logs') 45 | flags.DEFINE_string('dataset_dir', '', 'Location of dataset') 46 | flags.DEFINE_integer('log_every_n_steps', 100, 'Number of steps per log') 47 | flags.DEFINE_integer('save_summaries_secs', 100, 48 | 'How often to save summaries, secs') 49 | flags.DEFINE_integer('save_interval_secs', 100, 50 | 'How often to save checkpoints, secs') 51 | 52 | FLAGS = flags.FLAGS 53 | 54 | _LEARNING_RATE_DECAY_FACTOR = 0.94 55 | 56 | 57 | def get_learning_rate(): 58 | if FLAGS.fine_tune_checkpoint: 59 | # If we are fine tuning a checkpoint we need to start at a lower learning 60 | # rate since we are farther along on training. 61 | return 1e-4 62 | else: 63 | return 0.045 64 | 65 | 66 | def get_quant_delay(): 67 | if FLAGS.fine_tune_checkpoint: 68 | # We can start quantizing immediately if we are finetuning. 69 | return 0 70 | else: 71 | # We need to wait for the model to train a bit before we quantize if we are 72 | # training from scratch. 73 | return 250000 74 | 75 | 76 | def imagenet_input(is_training): 77 | """Data reader for imagenet. 78 | 79 | Reads in imagenet data and performs pre-processing on the images. 80 | 81 | Args: 82 | is_training: bool specifying if train or validation dataset is needed. 83 | Returns: 84 | A batch of images and labels. 
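  Example (illustrative; shapes assume the default flag values):
    images, labels = imagenet_input(is_training=True)
    # images: [64, 224, 224, 3], labels: [64, 1001] one-hot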
85 | """ 86 | if is_training: 87 | dataset = dataset_factory.get_dataset('imagenet', 'train', 88 | FLAGS.dataset_dir) 89 | else: 90 | dataset = dataset_factory.get_dataset('imagenet', 'validation', 91 | FLAGS.dataset_dir) 92 | 93 | provider = slim.dataset_data_provider.DatasetDataProvider( 94 | dataset, 95 | shuffle=is_training, 96 | common_queue_capacity=2 * FLAGS.batch_size, 97 | common_queue_min=FLAGS.batch_size) 98 | [image, label] = provider.get(['image', 'label']) 99 | 100 | image_preprocessing_fn = preprocessing_factory.get_preprocessing( 101 | 'mobilenet_v1', is_training=is_training) 102 | 103 | image = image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size) 104 | 105 | images, labels = tf.train.batch( 106 | [image, label], 107 | batch_size=FLAGS.batch_size, 108 | num_threads=4, 109 | capacity=5 * FLAGS.batch_size) 110 | labels = slim.one_hot_encoding(labels, FLAGS.num_classes) 111 | return images, labels 112 | 113 | 114 | def build_model(): 115 | """Builds graph for model to train with rewrites for quantization. 116 | 117 | Returns: 118 | g: Graph with fake quantization ops and batch norm folding suitable for 119 | training quantized weights. 120 | train_tensor: Train op for execution during training. 121 | """ 122 | g = tf.Graph() 123 | with g.as_default(), tf.device( 124 | tf.train.replica_device_setter(FLAGS.ps_tasks)): 125 | inputs, labels = imagenet_input(is_training=True) 126 | with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(is_training=True)): 127 | logits, _ = mobilenet_v1.mobilenet_v1( 128 | inputs, 129 | is_training=True, 130 | depth_multiplier=FLAGS.depth_multiplier, 131 | num_classes=FLAGS.num_classes) 132 | 133 | tf.losses.softmax_cross_entropy(labels, logits) 134 | 135 | # Call rewriter to produce graph with fake quant ops and folded batch norms 136 | # quant_delay delays start of quantization till quant_delay steps, allowing 137 | # for better model accuracy. 138 | if FLAGS.quantize: 139 | tf.contrib.quantize.create_training_graph(quant_delay=get_quant_delay()) 140 | 141 | total_loss = tf.losses.get_total_loss(name='total_loss') 142 | # Configure the learning rate using an exponential decay. 143 | num_epochs_per_decay = 2.5 144 | imagenet_size = 1271167 145 | decay_steps = int(imagenet_size / FLAGS.batch_size * num_epochs_per_decay) 146 | 147 | learning_rate = tf.train.exponential_decay( 148 | get_learning_rate(), 149 | tf.train.get_or_create_global_step(), 150 | decay_steps, 151 | _LEARNING_RATE_DECAY_FACTOR, 152 | staircase=True) 153 | opt = tf.train.GradientDescentOptimizer(learning_rate) 154 | 155 | train_tensor = slim.learning.create_train_op( 156 | total_loss, 157 | optimizer=opt) 158 | 159 | slim.summaries.add_scalar_summary(total_loss, 'total_loss', 'losses') 160 | slim.summaries.add_scalar_summary(learning_rate, 'learning_rate', 'training') 161 | return g, train_tensor 162 | 163 | 164 | def get_checkpoint_init_fn(): 165 | """Returns the checkpoint init_fn if the checkpoint is provided.""" 166 | if FLAGS.fine_tune_checkpoint: 167 | variables_to_restore = slim.get_variables_to_restore() 168 | global_step_reset = tf.assign(tf.train.get_or_create_global_step(), 0) 169 | # When restoring from a floating point model, the min/max values for 170 | # quantized weights and activations are not present. 
171 | # We instruct slim to ignore variables that are missing during restoration 172 | # by setting ignore_missing_vars=True. 173 | slim_init_fn = slim.assign_from_checkpoint_fn( 174 | FLAGS.fine_tune_checkpoint, 175 | variables_to_restore, 176 | ignore_missing_vars=True) 177 | 178 | def init_fn(sess): 179 | slim_init_fn(sess) 180 | # If we are restoring from a floating point model, we need to initialize 181 | # the global step to zero for the exponential decay to result in 182 | # reasonable learning rates. 183 | sess.run(global_step_reset) 184 | return init_fn 185 | else: 186 | return None 187 | 188 | 189 | def train_model(): 190 | """Trains mobilenet_v1.""" 191 | g, train_tensor = build_model() 192 | with g.as_default(): 193 | slim.learning.train( 194 | train_tensor, 195 | FLAGS.checkpoint_dir, 196 | is_chief=(FLAGS.task == 0), 197 | master=FLAGS.master, 198 | log_every_n_steps=FLAGS.log_every_n_steps, 199 | graph=g, 200 | number_of_steps=FLAGS.number_of_steps, 201 | save_summaries_secs=FLAGS.save_summaries_secs, 202 | save_interval_secs=FLAGS.save_interval_secs, 203 | init_fn=get_checkpoint_init_fn(), 204 | global_step=tf.train.get_global_step()) 205 | 206 | 207 | def main(unused_arg): 208 | train_model() 209 | 210 | 211 | if __name__ == '__main__': 212 | tf.app.run(main) 213 | -------------------------------------------------------------------------------- /nets/nasnet/README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow-Slim NASNet-A Implementation/Checkpoints 2 | This directory contains the code for the NASNet-A model from the paper 3 | [Learning Transferable Architectures for Scalable Image Recognition](https://arxiv.org/abs/1707.07012) by Zoph et al. 4 | In nasnet.py there are three different configurations of NASNet-A that are implemented. One of the models is the NASNet-A built for CIFAR-10 and the 5 | other two are variants of NASNet-A trained on ImageNet, which are listed below. 6 | 7 | # Pre-Trained Models 8 | Two NASNet-A checkpoints are available that have been trained on the 9 | [ILSVRC-2012-CLS](http://www.image-net.org/challenges/LSVRC/2012/) 10 | image classification dataset. Accuracies were computed by evaluating using a single image crop. 11 | 12 | Model Checkpoint | Million MACs | Million Parameters | Top-1 Accuracy| Top-5 Accuracy | 13 | :----:|:------------:|:----------:|:-------:|:-------:| 14 | [NASNet-A_Mobile_224](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz)|564|5.3|74.0|91.6| 15 | [NASNet-A_Large_331](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_large_04_10_2017.tar.gz)|23800|88.9|82.7|96.2| 16 | 17 | 18 | Here is an example of how to download the NASNet-A_Mobile_224 checkpoint. The way to download the NASNet-A_Large_331 is the same. 19 | 20 | ```shell 21 | CHECKPOINT_DIR=/tmp/checkpoints 22 | mkdir ${CHECKPOINT_DIR} 23 | cd ${CHECKPOINT_DIR} 24 | wget https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz 25 | tar -xvf nasnet-a_mobile_04_10_2017.tar.gz 26 | rm nasnet-a_mobile_04_10_2017.tar.gz 27 | ``` 28 | More information on integrating NASNet Models into your project can be found at the [TF-Slim Image Classification Library](https://github.com/tensorflow/models/blob/master/research/slim/README.md). 29 | 30 | To get started running models on-device, go to [TensorFlow Mobile](https://www.tensorflow.org/mobile/).
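To sanity-check a checkpoint from Python, something like the following works (a minimal sketch: the checkpoint path matches the tarball extracted above, and `num_classes=1001` is assumed, as is standard for the slim ImageNet checkpoints):

```python
import tensorflow as tf
from nets.nasnet import nasnet

slim = tf.contrib.slim

# Build NASNet-A Mobile for inference and restore the downloaded weights.
images = tf.placeholder(tf.float32, [None, 224, 224, 3])
with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
  logits, end_points = nasnet.build_nasnet_mobile(
      images, num_classes=1001, is_training=False)

saver = tf.train.Saver()
with tf.Session() as sess:
  saver.restore(sess, '/tmp/checkpoints/model.ckpt')
```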
31 | 32 | ## Sample Commands for using NASNet-A Mobile and Large Checkpoints for Inference 33 | ------- 34 | Run eval with the NASNet-A mobile ImageNet model 35 | 36 | ```shell 37 | DATASET_DIR=/tmp/imagenet 38 | EVAL_DIR=/tmp/tfmodel/eval 39 | CHECKPOINT_DIR=/tmp/checkpoints/model.ckpt 40 | python tensorflow_models/research/slim/eval_image_classifier \ 41 | --checkpoint_path=${CHECKPOINT_DIR} \ 42 | --eval_dir=${EVAL_DIR} \ 43 | --dataset_dir=${DATASET_DIR} \ 44 | --dataset_name=imagenet \ 45 | --dataset_split_name=validation \ 46 | --model_name=nasnet_mobile \ 47 | --eval_image_size=224 48 | ``` 49 | 50 | Run eval with the NASNet-A large ImageNet model 51 | 52 | ```shell 53 | DATASET_DIR=/tmp/imagenet 54 | EVAL_DIR=/tmp/tfmodel/eval 55 | CHECKPOINT_DIR=/tmp/checkpoints/model.ckpt 56 | python tensorflow_models/research/slim/eval_image_classifier \ 57 | --checkpoint_path=${CHECKPOINT_DIR} \ 58 | --eval_dir=${EVAL_DIR} \ 59 | --dataset_dir=${DATASET_DIR} \ 60 | --dataset_name=imagenet \ 61 | --dataset_split_name=validation \ 62 | --model_name=nasnet_large \ 63 | --eval_image_size=331 64 | ``` 65 | -------------------------------------------------------------------------------- /nets/nasnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/nets/nasnet/__init__.py -------------------------------------------------------------------------------- /nets/nasnet/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/nets/nasnet/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /nets/nasnet/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/nets/nasnet/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /nets/nasnet/__pycache__/nasnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/nets/nasnet/__pycache__/nasnet.cpython-36.pyc -------------------------------------------------------------------------------- /nets/nasnet/__pycache__/nasnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/nets/nasnet/__pycache__/nasnet.cpython-37.pyc -------------------------------------------------------------------------------- /nets/nasnet/__pycache__/nasnet_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/nets/nasnet/__pycache__/nasnet_utils.cpython-36.pyc -------------------------------------------------------------------------------- /nets/nasnet/__pycache__/nasnet_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/nets/nasnet/__pycache__/nasnet_utils.cpython-37.pyc 
-------------------------------------------------------------------------------- /nets/nasnet/nasnet_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for slim.nets.nasnet.nasnet_utils.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from nets.nasnet import nasnet_utils 24 | 25 | 26 | class NasnetUtilsTest(tf.test.TestCase): 27 | 28 | def testCalcReductionLayers(self): 29 | num_cells = 18 30 | num_reduction_layers = 2 31 | reduction_layers = nasnet_utils.calc_reduction_layers( 32 | num_cells, num_reduction_layers) 33 | self.assertEqual(len(reduction_layers), 2) 34 | self.assertEqual(reduction_layers[0], 6) 35 | self.assertEqual(reduction_layers[1], 12) 36 | 37 | def testGetChannelIndex(self): 38 | data_formats = ['NHWC', 'NCHW'] 39 | for data_format in data_formats: 40 | index = nasnet_utils.get_channel_index(data_format) 41 | correct_index = 3 if data_format == 'NHWC' else 1 42 | self.assertEqual(index, correct_index) 43 | 44 | def testGetChannelDim(self): 45 | data_formats = ['NHWC', 'NCHW'] 46 | shape = [10, 20, 30, 40] 47 | for data_format in data_formats: 48 | dim = nasnet_utils.get_channel_dim(shape, data_format) 49 | correct_dim = shape[3] if data_format == 'NHWC' else shape[1] 50 | self.assertEqual(dim, correct_dim) 51 | 52 | def testGlobalAvgPool(self): 53 | data_formats = ['NHWC', 'NCHW'] 54 | inputs = tf.placeholder(tf.float32, (5, 10, 20, 10)) 55 | for data_format in data_formats: 56 | output = nasnet_utils.global_avg_pool( 57 | inputs, data_format) 58 | self.assertEqual(output.shape, [5, 10]) 59 | 60 | 61 | if __name__ == '__main__': 62 | tf.test.main() 63 | -------------------------------------------------------------------------------- /nets/nets_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Contains a factory for building various models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import functools 21 | 22 | import tensorflow as tf 23 | 24 | from nets import alexnet 25 | from nets import cifarnet 26 | from nets import inception 27 | from nets import lenet 28 | from nets import mobilenet_v1 29 | from nets import overfeat 30 | from nets import resnet_v1 31 | from nets import resnet_v2 32 | from nets import vgg 33 | from nets.nasnet import nasnet 34 | 35 | slim = tf.contrib.slim 36 | 37 | networks_map = {'alexnet_v2': alexnet.alexnet_v2, 38 | 'cifarnet': cifarnet.cifarnet, 39 | 'overfeat': overfeat.overfeat, 40 | 'vgg_a': vgg.vgg_a, 41 | 'vgg_16': vgg.vgg_16, 42 | 'vgg_19': vgg.vgg_19, 43 | 'inception_v1': inception.inception_v1, 44 | 'inception_v2': inception.inception_v2, 45 | 'inception_v3': inception.inception_v3, 46 | 'inception_v3_bap': inception.inception_v3_bap, 47 | 'inception_v3_topk': inception.inception_v3_topk, 48 | 'inception_v4': inception.inception_v4, 49 | 'inception_resnet_v2': inception.inception_resnet_v2, 50 | 'lenet': lenet.lenet, 51 | 'resnet_v1_50': resnet_v1.resnet_v1_50, 52 | 'resnet_v1_101': resnet_v1.resnet_v1_101, 53 | 'resnet_v1_152': resnet_v1.resnet_v1_152, 54 | 'resnet_v1_200': resnet_v1.resnet_v1_200, 55 | 'resnet_v2_50': resnet_v2.resnet_v2_50, 56 | 'resnet_v2_101': resnet_v2.resnet_v2_101, 57 | 'resnet_v2_152': resnet_v2.resnet_v2_152, 58 | 'resnet_v2_200': resnet_v2.resnet_v2_200, 59 | 'mobilenet_v1': mobilenet_v1.mobilenet_v1, 60 | 'mobilenet_v1_075': mobilenet_v1.mobilenet_v1_075, 61 | 'mobilenet_v1_050': mobilenet_v1.mobilenet_v1_050, 62 | 'mobilenet_v1_025': mobilenet_v1.mobilenet_v1_025, 63 | 'nasnet_cifar': nasnet.build_nasnet_cifar, 64 | 'nasnet_mobile': nasnet.build_nasnet_mobile, 65 | 'nasnet_large': nasnet.build_nasnet_large, 66 | } 67 | 68 | arg_scopes_map = {'alexnet_v2': alexnet.alexnet_v2_arg_scope, 69 | 'cifarnet': cifarnet.cifarnet_arg_scope, 70 | 'overfeat': overfeat.overfeat_arg_scope, 71 | 'vgg_a': vgg.vgg_arg_scope, 72 | 'vgg_16': vgg.vgg_arg_scope, 73 | 'vgg_19': vgg.vgg_arg_scope, 74 | 'inception_v1': inception.inception_v3_arg_scope, 75 | 'inception_v2': inception.inception_v3_arg_scope, 76 | 'inception_v3': inception.inception_v3_arg_scope, 77 | 'inception_v3_bap': inception.inception_v3_bap_arg_scope, 78 | 'inception_v3_topk': inception.inception_v3_topk_arg_scope, 79 | 'inception_v4': inception.inception_v4_arg_scope, 80 | 'inception_resnet_v2': 81 | inception.inception_resnet_v2_arg_scope, 82 | 'lenet': lenet.lenet_arg_scope, 83 | 'resnet_v1_50': resnet_v1.resnet_arg_scope, 84 | 'resnet_v1_101': resnet_v1.resnet_arg_scope, 85 | 'resnet_v1_152': resnet_v1.resnet_arg_scope, 86 | 'resnet_v1_200': resnet_v1.resnet_arg_scope, 87 | 'resnet_v2_50': resnet_v2.resnet_arg_scope, 88 | 'resnet_v2_101': resnet_v2.resnet_arg_scope, 89 | 'resnet_v2_152': resnet_v2.resnet_arg_scope, 90 | 'resnet_v2_200': resnet_v2.resnet_arg_scope, 91 | 'mobilenet_v1': mobilenet_v1.mobilenet_v1_arg_scope, 92 | 'mobilenet_v1_075': mobilenet_v1.mobilenet_v1_arg_scope, 93 | 'mobilenet_v1_050': mobilenet_v1.mobilenet_v1_arg_scope, 94 | 'mobilenet_v1_025': mobilenet_v1.mobilenet_v1_arg_scope, 95 | 'nasnet_cifar': nasnet.nasnet_cifar_arg_scope, 96 | 'nasnet_mobile': nasnet.nasnet_mobile_arg_scope, 97 | 'nasnet_large': nasnet.nasnet_large_arg_scope, 98 | } 99 | 100 | 101 | 
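# The two maps above are kept in lockstep: every key in networks_map has a
# matching entry in arg_scopes_map, so get_network_fn below can wrap any
# model function in its corresponding regularization/normalization scope.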
def get_network_fn(name, dataset_name, num_classes, weight_decay=0.0, is_training=False): 109 | """Returns a network_fn such as `logits, end_points = network_fn(images, indices)`. 110 | 111 | Args: 112 | name: The name of the network. 113 | dataset_name: The name of the dataset the model is applied to (e.g. 114 | 'Bird'); it is forwarded unchanged to the underlying network builder. 115 | num_classes: The number of classes to use for classification. If 0 or None, 116 | the logits layer is omitted and its input features are returned instead. 117 | weight_decay: The l2 coefficient for the model weights. 118 | is_training: `True` if the model is being used for training and `False` 119 | otherwise. 120 | 121 | Returns: 122 | network_fn: A function that applies the model to a batch of images. It has 123 | the following signature: 124 | net, end_points = network_fn(images, indices) 125 | The `images` input is a tensor of shape [batch_size, height, width, 3] 126 | with height = width = network_fn.default_image_size. (The permissibility 127 | and treatment of other sizes depends on the network_fn.) 128 | The `indices` input is an additional tensor that is passed through, as the 129 | second positional argument, to the underlying network builder. 130 | The returned `end_points` are a dictionary of intermediate activations. 131 | The returned `net` is the topmost layer, depending on `num_classes`: 132 | If `num_classes` was a non-zero integer, `net` is a logits tensor 133 | of shape [batch_size, num_classes]. 134 | If `num_classes` was 0 or `None`, `net` is a tensor with the input 135 | to the logits layer of shape [batch_size, 1, 1, num_features] or 136 | [batch_size, num_features]. Dropout has not been applied to this 137 | (even if the network's original classifier does); it remains for 138 | the caller to do this or not. 139 | 140 | Raises: 141 | ValueError: If network `name` is not recognized. 142 | """ 143 | if name not in networks_map: 144 | raise ValueError('Unknown network name: %s' % name) 145 | func = networks_map[name] 146 | @functools.wraps(func) 147 | def network_fn(images, indices, **kwargs): 148 | arg_scope = arg_scopes_map[name](weight_decay=weight_decay) 149 | with slim.arg_scope(arg_scope): 150 | return func(images, indices, dataset_name, num_classes, is_training=is_training, **kwargs) 151 | if hasattr(func, 'default_image_size'): 152 | network_fn.default_image_size = func.default_image_size 153 | 154 | return network_fn 155 | -------------------------------------------------------------------------------- /nets/nets_factory_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | """Tests for nets.nets_factory.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | 23 | import tensorflow as tf 24 | 25 | from nets import nets_factory 26 | 27 | 28 | class NetworksTest(tf.test.TestCase): 29 | 30 | def testGetNetworkFnFirstHalf(self): 31 | batch_size = 5 32 | num_classes = 1000 33 | for net in list(nets_factory.networks_map.keys())[:10]: 34 | with tf.Graph().as_default() as g, self.test_session(g): 35 | net_fn = nets_factory.get_network_fn(net, num_classes) 36 | # Most networks use 224 as their default_image_size 37 | image_size = getattr(net_fn, 'default_image_size', 224) 38 | inputs = tf.random_uniform((batch_size, image_size, image_size, 3)) 39 | logits, end_points = net_fn(inputs) 40 | self.assertTrue(isinstance(logits, tf.Tensor)) 41 | self.assertTrue(isinstance(end_points, dict)) 42 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 43 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 44 | 45 | def testGetNetworkFnSecondHalf(self): 46 | batch_size = 5 47 | num_classes = 1000 48 | for net in list(nets_factory.networks_map.keys())[10:]: 49 | with tf.Graph().as_default() as g, self.test_session(g): 50 | net_fn = nets_factory.get_network_fn(net, num_classes) 51 | # Most networks use 224 as their default_image_size 52 | image_size = getattr(net_fn, 'default_image_size', 224) 53 | inputs = tf.random_uniform((batch_size, image_size, image_size, 3)) 54 | logits, end_points = net_fn(inputs) 55 | self.assertTrue(isinstance(logits, tf.Tensor)) 56 | self.assertTrue(isinstance(end_points, dict)) 57 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 58 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 59 | 60 | if __name__ == '__main__': 61 | tf.test.main() 62 | -------------------------------------------------------------------------------- /nets/overfeat.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains the model definition for the OverFeat network.
16 | 17 | The definition for the network was obtained from: 18 | OverFeat: Integrated Recognition, Localization and Detection using 19 | Convolutional Networks 20 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and 21 | Yann LeCun, 2014 22 | http://arxiv.org/abs/1312.6229 23 | 24 | Usage: 25 | with slim.arg_scope(overfeat.overfeat_arg_scope()): 26 | outputs, end_points = overfeat.overfeat(inputs) 27 | 28 | @@overfeat 29 | """ 30 | from __future__ import absolute_import 31 | from __future__ import division 32 | from __future__ import print_function 33 | 34 | import tensorflow as tf 35 | 36 | slim = tf.contrib.slim 37 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 38 | 39 | 40 | def overfeat_arg_scope(weight_decay=0.0005): 41 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 42 | activation_fn=tf.nn.relu, 43 | weights_regularizer=slim.l2_regularizer(weight_decay), 44 | biases_initializer=tf.zeros_initializer()): 45 | with slim.arg_scope([slim.conv2d], padding='SAME'): 46 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: 47 | return arg_sc 48 | 49 | 50 | def overfeat(inputs, 51 | num_classes=1000, 52 | is_training=True, 53 | dropout_keep_prob=0.5, 54 | spatial_squeeze=True, 55 | scope='overfeat', 56 | global_pool=False): 57 | """Contains the model definition for the OverFeat network. 58 | 59 | The definition for the network was obtained from: 60 | OverFeat: Integrated Recognition, Localization and Detection using 61 | Convolutional Networks 62 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and 63 | Yann LeCun, 2014 64 | http://arxiv.org/abs/1312.6229 65 | 66 | Note: All the fully_connected layers have been transformed to conv2d layers. 67 | To use in classification mode, resize input to 231x231. To use in fully 68 | convolutional mode, set spatial_squeeze to false. 69 | 70 | Args: 71 | inputs: a tensor of size [batch_size, height, width, channels]. 72 | num_classes: number of predicted classes. If 0 or None, the logits layer is 73 | omitted and the input features to the logits layer are returned instead. 74 | is_training: whether or not the model is being trained. 75 | dropout_keep_prob: the probability that activations are kept in the dropout 76 | layers during training. 77 | spatial_squeeze: whether or not the spatial dimensions of the outputs should 78 | be squeezed. Useful to remove unnecessary dimensions for classification. 79 | scope: Optional scope for the variables. 80 | global_pool: Optional boolean flag. If True, the input to the classification 81 | layer is avgpooled to size 1x1, for any input size. (This is not part 82 | of the original OverFeat.) 83 | 84 | Returns: 85 | net: the output of the logits layer (if num_classes is a non-zero integer), 86 | or the non-dropped-out input to the logits layer (if num_classes is 0 or 87 | None). 88 | end_points: a dict of tensors with intermediate activations.
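89 | 90 | For example (shapes taken from overfeat_test.py in this repository): a 231x231 91 | input with the default arguments yields `net` of shape [batch_size, num_classes], 92 | while a 281x281 input with spatial_squeeze=False yields 93 | [batch_size, 2, 2, num_classes].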
89 | """ 90 | with tf.variable_scope(scope, 'overfeat', [inputs]) as sc: 91 | end_points_collection = sc.original_name_scope + '_end_points' 92 | # Collect outputs for conv2d, fully_connected and max_pool2d 93 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 94 | outputs_collections=end_points_collection): 95 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', 96 | scope='conv1') 97 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 98 | net = slim.conv2d(net, 256, [5, 5], padding='VALID', scope='conv2') 99 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 100 | net = slim.conv2d(net, 512, [3, 3], scope='conv3') 101 | net = slim.conv2d(net, 1024, [3, 3], scope='conv4') 102 | net = slim.conv2d(net, 1024, [3, 3], scope='conv5') 103 | net = slim.max_pool2d(net, [2, 2], scope='pool5') 104 | 105 | # Use conv2d instead of fully_connected layers. 106 | with slim.arg_scope([slim.conv2d], 107 | weights_initializer=trunc_normal(0.005), 108 | biases_initializer=tf.constant_initializer(0.1)): 109 | net = slim.conv2d(net, 3072, [6, 6], padding='VALID', scope='fc6') 110 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 111 | scope='dropout6') 112 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 113 | # Convert end_points_collection into a end_point dict. 114 | end_points = slim.utils.convert_collection_to_dict( 115 | end_points_collection) 116 | if global_pool: 117 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool') 118 | end_points['global_pool'] = net 119 | if num_classes: 120 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 121 | scope='dropout7') 122 | net = slim.conv2d(net, num_classes, [1, 1], 123 | activation_fn=None, 124 | normalizer_fn=None, 125 | biases_initializer=tf.zeros_initializer(), 126 | scope='fc8') 127 | if spatial_squeeze: 128 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 129 | end_points[sc.name + '/fc8'] = net 130 | return net, end_points 131 | overfeat.default_image_size = 231 132 | -------------------------------------------------------------------------------- /nets/overfeat_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for slim.nets.overfeat.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | from nets import overfeat 23 | 24 | slim = tf.contrib.slim 25 | 26 | 27 | class OverFeatTest(tf.test.TestCase): 28 | 29 | def testBuild(self): 30 | batch_size = 5 31 | height, width = 231, 231 32 | num_classes = 1000 33 | with self.test_session(): 34 | inputs = tf.random_uniform((batch_size, height, width, 3)) 35 | logits, _ = overfeat.overfeat(inputs, num_classes) 36 | self.assertEquals(logits.op.name, 'overfeat/fc8/squeezed') 37 | self.assertListEqual(logits.get_shape().as_list(), 38 | [batch_size, num_classes]) 39 | 40 | def testFullyConvolutional(self): 41 | batch_size = 1 42 | height, width = 281, 281 43 | num_classes = 1000 44 | with self.test_session(): 45 | inputs = tf.random_uniform((batch_size, height, width, 3)) 46 | logits, _ = overfeat.overfeat(inputs, num_classes, spatial_squeeze=False) 47 | self.assertEquals(logits.op.name, 'overfeat/fc8/BiasAdd') 48 | self.assertListEqual(logits.get_shape().as_list(), 49 | [batch_size, 2, 2, num_classes]) 50 | 51 | def testGlobalPool(self): 52 | batch_size = 1 53 | height, width = 281, 281 54 | num_classes = 1000 55 | with self.test_session(): 56 | inputs = tf.random_uniform((batch_size, height, width, 3)) 57 | logits, _ = overfeat.overfeat(inputs, num_classes, spatial_squeeze=False, 58 | global_pool=True) 59 | self.assertEquals(logits.op.name, 'overfeat/fc8/BiasAdd') 60 | self.assertListEqual(logits.get_shape().as_list(), 61 | [batch_size, 1, 1, num_classes]) 62 | 63 | def testEndPoints(self): 64 | batch_size = 5 65 | height, width = 231, 231 66 | num_classes = 1000 67 | with self.test_session(): 68 | inputs = tf.random_uniform((batch_size, height, width, 3)) 69 | _, end_points = overfeat.overfeat(inputs, num_classes) 70 | expected_names = ['overfeat/conv1', 71 | 'overfeat/pool1', 72 | 'overfeat/conv2', 73 | 'overfeat/pool2', 74 | 'overfeat/conv3', 75 | 'overfeat/conv4', 76 | 'overfeat/conv5', 77 | 'overfeat/pool5', 78 | 'overfeat/fc6', 79 | 'overfeat/fc7', 80 | 'overfeat/fc8' 81 | ] 82 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 83 | 84 | def testNoClasses(self): 85 | batch_size = 5 86 | height, width = 231, 231 87 | num_classes = None 88 | with self.test_session(): 89 | inputs = tf.random_uniform((batch_size, height, width, 3)) 90 | net, end_points = overfeat.overfeat(inputs, num_classes) 91 | expected_names = ['overfeat/conv1', 92 | 'overfeat/pool1', 93 | 'overfeat/conv2', 94 | 'overfeat/pool2', 95 | 'overfeat/conv3', 96 | 'overfeat/conv4', 97 | 'overfeat/conv5', 98 | 'overfeat/pool5', 99 | 'overfeat/fc6', 100 | 'overfeat/fc7' 101 | ] 102 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 103 | self.assertTrue(net.op.name.startswith('overfeat/fc7')) 104 | 105 | def testModelVariables(self): 106 | batch_size = 5 107 | height, width = 231, 231 108 | num_classes = 1000 109 | with self.test_session(): 110 | inputs = tf.random_uniform((batch_size, height, width, 3)) 111 | overfeat.overfeat(inputs, num_classes) 112 | expected_names = ['overfeat/conv1/weights', 113 | 'overfeat/conv1/biases', 114 | 'overfeat/conv2/weights', 115 | 'overfeat/conv2/biases', 116 | 'overfeat/conv3/weights', 117 | 'overfeat/conv3/biases', 118 | 'overfeat/conv4/weights', 119 | 'overfeat/conv4/biases', 120 | 'overfeat/conv5/weights', 121 | 
'overfeat/conv5/biases', 122 | 'overfeat/fc6/weights', 123 | 'overfeat/fc6/biases', 124 | 'overfeat/fc7/weights', 125 | 'overfeat/fc7/biases', 126 | 'overfeat/fc8/weights', 127 | 'overfeat/fc8/biases', 128 | ] 129 | model_variables = [v.op.name for v in slim.get_model_variables()] 130 | self.assertSetEqual(set(model_variables), set(expected_names)) 131 | 132 | def testEvaluation(self): 133 | batch_size = 2 134 | height, width = 231, 231 135 | num_classes = 1000 136 | with self.test_session(): 137 | eval_inputs = tf.random_uniform((batch_size, height, width, 3)) 138 | logits, _ = overfeat.overfeat(eval_inputs, is_training=False) 139 | self.assertListEqual(logits.get_shape().as_list(), 140 | [batch_size, num_classes]) 141 | predictions = tf.argmax(logits, 1) 142 | self.assertListEqual(predictions.get_shape().as_list(), [batch_size]) 143 | 144 | def testTrainEvalWithReuse(self): 145 | train_batch_size = 2 146 | eval_batch_size = 1 147 | train_height, train_width = 231, 231 148 | eval_height, eval_width = 281, 281 149 | num_classes = 1000 150 | with self.test_session(): 151 | train_inputs = tf.random_uniform( 152 | (train_batch_size, train_height, train_width, 3)) 153 | logits, _ = overfeat.overfeat(train_inputs) 154 | self.assertListEqual(logits.get_shape().as_list(), 155 | [train_batch_size, num_classes]) 156 | tf.get_variable_scope().reuse_variables() 157 | eval_inputs = tf.random_uniform( 158 | (eval_batch_size, eval_height, eval_width, 3)) 159 | logits, _ = overfeat.overfeat(eval_inputs, is_training=False, 160 | spatial_squeeze=False) 161 | self.assertListEqual(logits.get_shape().as_list(), 162 | [eval_batch_size, 2, 2, num_classes]) 163 | logits = tf.reduce_mean(logits, [1, 2]) 164 | predictions = tf.argmax(logits, 1) 165 | self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size]) 166 | 167 | def testForward(self): 168 | batch_size = 1 169 | height, width = 231, 231 170 | with self.test_session() as sess: 171 | inputs = tf.random_uniform((batch_size, height, width, 3)) 172 | logits, _ = overfeat.overfeat(inputs) 173 | sess.run(tf.global_variables_initializer()) 174 | output = sess.run(logits) 175 | self.assertTrue(output.any()) 176 | 177 | if __name__ == '__main__': 178 | tf.test.main() 179 | -------------------------------------------------------------------------------- /nets/pix2pix_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================= 15 | """Tests for pix2pix.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | from nets import pix2pix 23 | 24 | 25 | class GeneratorTest(tf.test.TestCase): 26 | 27 | def _reduced_default_blocks(self): 28 | """Returns the default blocks, scaled down to make test run faster.""" 29 | return [pix2pix.Block(b.num_filters // 32, b.decoder_keep_prob) 30 | for b in pix2pix._default_generator_blocks()] 31 | 32 | def test_output_size_nn_upsample_conv(self): 33 | batch_size = 2 34 | height, width = 256, 256 35 | num_outputs = 4 36 | 37 | images = tf.ones((batch_size, height, width, 3)) 38 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 39 | logits, _ = pix2pix.pix2pix_generator( 40 | images, num_outputs, blocks=self._reduced_default_blocks(), 41 | upsample_method='nn_upsample_conv') 42 | 43 | with self.test_session() as session: 44 | session.run(tf.global_variables_initializer()) 45 | np_outputs = session.run(logits) 46 | self.assertListEqual([batch_size, height, width, num_outputs], 47 | list(np_outputs.shape)) 48 | 49 | def test_output_size_conv2d_transpose(self): 50 | batch_size = 2 51 | height, width = 256, 256 52 | num_outputs = 4 53 | 54 | images = tf.ones((batch_size, height, width, 3)) 55 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 56 | logits, _ = pix2pix.pix2pix_generator( 57 | images, num_outputs, blocks=self._reduced_default_blocks(), 58 | upsample_method='conv2d_transpose') 59 | 60 | with self.test_session() as session: 61 | session.run(tf.global_variables_initializer()) 62 | np_outputs = session.run(logits) 63 | self.assertListEqual([batch_size, height, width, num_outputs], 64 | list(np_outputs.shape)) 65 | 66 | def test_block_number_dictates_number_of_layers(self): 67 | batch_size = 2 68 | height, width = 256, 256 69 | num_outputs = 4 70 | 71 | images = tf.ones((batch_size, height, width, 3)) 72 | blocks = [ 73 | pix2pix.Block(64, 0.5), 74 | pix2pix.Block(128, 0), 75 | ] 76 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 77 | _, end_points = pix2pix.pix2pix_generator( 78 | images, num_outputs, blocks) 79 | 80 | num_encoder_layers = 0 81 | num_decoder_layers = 0 82 | for end_point in end_points: 83 | if end_point.startswith('encoder'): 84 | num_encoder_layers += 1 85 | elif end_point.startswith('decoder'): 86 | num_decoder_layers += 1 87 | 88 | self.assertEqual(num_encoder_layers, len(blocks)) 89 | self.assertEqual(num_decoder_layers, len(blocks)) 90 | 91 | 92 | class DiscriminatorTest(tf.test.TestCase): 93 | 94 | def _layer_output_size(self, input_size, kernel_size=4, stride=2, pad=2): 95 | return (input_size + pad * 2 - kernel_size) // stride + 1 96 | 97 | def test_four_layers(self): 98 | batch_size = 2 99 | input_size = 256 100 | 101 | output_size = self._layer_output_size(input_size) 102 | output_size = self._layer_output_size(output_size) 103 | output_size = self._layer_output_size(output_size) 104 | output_size = self._layer_output_size(output_size, stride=1) 105 | output_size = self._layer_output_size(output_size, stride=1) 106 | 107 | images = tf.ones((batch_size, input_size, input_size, 3)) 108 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 109 | logits, end_points = pix2pix.pix2pix_discriminator( 110 | images, num_filters=[64, 128, 256, 512]) 111 | self.assertListEqual([batch_size, 
output_size, output_size, 1], 112 | logits.shape.as_list()) 113 | self.assertListEqual([batch_size, output_size, output_size, 1], 114 | end_points['predictions'].shape.as_list()) 115 | 116 | def test_four_layers_no_padding(self): 117 | batch_size = 2 118 | input_size = 256 119 | 120 | output_size = self._layer_output_size(input_size, pad=0) 121 | output_size = self._layer_output_size(output_size, pad=0) 122 | output_size = self._layer_output_size(output_size, pad=0) 123 | output_size = self._layer_output_size(output_size, stride=1, pad=0) 124 | output_size = self._layer_output_size(output_size, stride=1, pad=0) 125 | 126 | images = tf.ones((batch_size, input_size, input_size, 3)) 127 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 128 | logits, end_points = pix2pix.pix2pix_discriminator( 129 | images, num_filters=[64, 128, 256, 512], padding=0) 130 | self.assertListEqual([batch_size, output_size, output_size, 1], 131 | logits.shape.as_list()) 132 | self.assertListEqual([batch_size, output_size, output_size, 1], 133 | end_points['predictions'].shape.as_list()) 134 | 135 | def test_four_layers_wrong_padding(self): 136 | batch_size = 2 137 | input_size = 256 138 | 139 | images = tf.ones((batch_size, input_size, input_size, 3)) 140 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 141 | with self.assertRaises(TypeError): 142 | pix2pix.pix2pix_discriminator( 143 | images, num_filters=[64, 128, 256, 512], padding=1.5) 144 | 145 | def test_four_layers_negative_padding(self): 146 | batch_size = 2 147 | input_size = 256 148 | 149 | images = tf.ones((batch_size, input_size, input_size, 3)) 150 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 151 | with self.assertRaises(ValueError): 152 | pix2pix.pix2pix_discriminator( 153 | images, num_filters=[64, 128, 256, 512], padding=-1) 154 | 155 | if __name__ == '__main__': 156 | tf.test.main() 157 | -------------------------------------------------------------------------------- /num_bboxes.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/num_bboxes.pkl -------------------------------------------------------------------------------- /preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/preprocessing/__init__.py -------------------------------------------------------------------------------- /preprocessing/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/preprocessing/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /preprocessing/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/preprocessing/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /preprocessing/__pycache__/cifarnet_preprocessing.cpython-36.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/preprocessing/__pycache__/cifarnet_preprocessing.cpython-36.pyc -------------------------------------------------------------------------------- /preprocessing/__pycache__/cifarnet_preprocessing.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/preprocessing/__pycache__/cifarnet_preprocessing.cpython-37.pyc -------------------------------------------------------------------------------- /preprocessing/__pycache__/inception_preprocessing.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/preprocessing/__pycache__/inception_preprocessing.cpython-36.pyc -------------------------------------------------------------------------------- /preprocessing/__pycache__/inception_preprocessing.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/preprocessing/__pycache__/inception_preprocessing.cpython-37.pyc -------------------------------------------------------------------------------- /preprocessing/__pycache__/lenet_preprocessing.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/preprocessing/__pycache__/lenet_preprocessing.cpython-36.pyc -------------------------------------------------------------------------------- /preprocessing/__pycache__/lenet_preprocessing.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/preprocessing/__pycache__/lenet_preprocessing.cpython-37.pyc -------------------------------------------------------------------------------- /preprocessing/__pycache__/preprocessing_factory.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/preprocessing/__pycache__/preprocessing_factory.cpython-36.pyc -------------------------------------------------------------------------------- /preprocessing/__pycache__/preprocessing_factory.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/preprocessing/__pycache__/preprocessing_factory.cpython-37.pyc -------------------------------------------------------------------------------- /preprocessing/__pycache__/vgg_preprocessing.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/preprocessing/__pycache__/vgg_preprocessing.cpython-36.pyc -------------------------------------------------------------------------------- /preprocessing/__pycache__/vgg_preprocessing.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/preprocessing/__pycache__/vgg_preprocessing.cpython-37.pyc 
-------------------------------------------------------------------------------- /preprocessing/cifarnet_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities to preprocess images in CIFAR-10. 16 | 17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow as tf 24 | 25 | _PADDING = 4 26 | 27 | slim = tf.contrib.slim 28 | 29 | 30 | def preprocess_for_train(image, 31 | output_height, 32 | output_width, 33 | padding=_PADDING, 34 | add_image_summaries=True): 35 | """Preprocesses the given image for training. 36 | 37 | Note that no resizing takes place here: the image is padded, randomly 38 | cropped back to [output_height, output_width], flipped, and color-jittered. 39 | 40 | Args: 41 | image: A `Tensor` representing an image of arbitrary size. 42 | output_height: The height of the image after preprocessing. 43 | output_width: The width of the image after preprocessing. 44 | padding: The amount of padding before and after each dimension of the image. 45 | add_image_summaries: Enable image summaries. 46 | 47 | Returns: 48 | A preprocessed image. 49 | """ 50 | if add_image_summaries: 51 | tf.summary.image('image', tf.expand_dims(image, 0)) 52 | 53 | # Transform the image to floats. 54 | image = tf.to_float(image) 55 | if padding > 0: 56 | image = tf.pad(image, [[padding, padding], [padding, padding], [0, 0]]) 57 | # Randomly crop a [height, width] section of the image. 58 | distorted_image = tf.random_crop(image, 59 | [output_height, output_width, 3]) 60 | 61 | # Randomly flip the image horizontally. 62 | distorted_image = tf.image.random_flip_left_right(distorted_image) 63 | 64 | if add_image_summaries: 65 | tf.summary.image('distorted_image', tf.expand_dims(distorted_image, 0)) 66 | 67 | # Because these operations are not commutative, consider randomizing 68 | # the order in which they are applied. 69 | distorted_image = tf.image.random_brightness(distorted_image, 70 | max_delta=63) 71 | distorted_image = tf.image.random_contrast(distorted_image, 72 | lower=0.2, upper=1.8) 73 | # Subtract off the mean and divide by the variance of the pixels. 74 | return tf.image.per_image_standardization(distorted_image) 75 | 76 | 77 | def preprocess_for_eval(image, output_height, output_width, 78 | add_image_summaries=True): 79 | """Preprocesses the given image for evaluation. 80 | 81 | Args: 82 | image: A `Tensor` representing an image of arbitrary size. 83 | output_height: The height of the image after preprocessing. 84 | output_width: The width of the image after preprocessing. 85 | add_image_summaries: Enable image summaries. 86 | 87 | Returns: 88 | A preprocessed image.
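89 | 90 | Standardization follows tf.image.per_image_standardization: the image is 91 | shifted to zero mean and divided by max(stddev, 1/sqrt(number of elements)), 92 | so a constant image maps to zeros instead of dividing by zero.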
89 | """ 90 | if add_image_summaries: 91 | tf.summary.image('image', tf.expand_dims(image, 0)) 92 | # Transform the image to floats. 93 | image = tf.to_float(image) 94 | 95 | # Resize and crop if needed. 96 | resized_image = tf.image.resize_image_with_crop_or_pad(image, 97 | output_width, 98 | output_height) 99 | if add_image_summaries: 100 | tf.summary.image('resized_image', tf.expand_dims(resized_image, 0)) 101 | 102 | # Subtract off the mean and divide by the variance of the pixels. 103 | return tf.image.per_image_standardization(resized_image) 104 | 105 | 106 | def preprocess_image(image, output_height, output_width, is_training=False, 107 | add_image_summaries=True): 108 | """Preprocesses the given image. 109 | 110 | Args: 111 | image: A `Tensor` representing an image of arbitrary size. 112 | output_height: The height of the image after preprocessing. 113 | output_width: The width of the image after preprocessing. 114 | is_training: `True` if we're preprocessing the image for training and 115 | `False` otherwise. 116 | add_image_summaries: Enable image summaries. 117 | 118 | Returns: 119 | A preprocessed image. 120 | """ 121 | if is_training: 122 | return preprocess_for_train( 123 | image, output_height, output_width, 124 | add_image_summaries=add_image_summaries) 125 | else: 126 | return preprocess_for_eval( 127 | image, output_height, output_width, 128 | add_image_summaries=add_image_summaries) 129 | -------------------------------------------------------------------------------- /preprocessing/lenet_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities for preprocessing.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | def preprocess_image(image, output_height, output_width, is_training): 27 | """Preprocesses the given image. 28 | 29 | Args: 30 | image: A `Tensor` representing an image of arbitrary size. 31 | output_height: The height of the image after preprocessing. 32 | output_width: The width of the image after preprocessing. 33 | is_training: `True` if we're preprocessing the image for training and 34 | `False` otherwise. 35 | 36 | Returns: 37 | A preprocessed image. 
38 | """ 39 | image = tf.to_float(image) 40 | image = tf.image.resize_image_with_crop_or_pad( 41 | image, output_width, output_height) 42 | image = tf.subtract(image, 128.0) 43 | image = tf.div(image, 128.0) 44 | return image 45 | -------------------------------------------------------------------------------- /preprocessing/preprocessing_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a factory for building various models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from preprocessing import cifarnet_preprocessing 24 | from preprocessing import inception_preprocessing 25 | from preprocessing import lenet_preprocessing 26 | from preprocessing import vgg_preprocessing 27 | 28 | slim = tf.contrib.slim 29 | 30 | 31 | def get_preprocessing(name, is_training=False): 32 | """Returns preprocessing_fn(image, height, width, **kwargs). 33 | 34 | Args: 35 | name: The name of the preprocessing function. 36 | is_training: `True` if the model is being used for training and `False` 37 | otherwise. 38 | 39 | Returns: 40 | preprocessing_fn: A function that preprocessing a single image (pre-batch). 41 | It has the following signature: 42 | image = preprocessing_fn(image, output_height, output_width, ...). 43 | 44 | Raises: 45 | ValueError: If Preprocessing `name` is not recognized. 
46 | """ 47 | preprocessing_fn_map = { 48 | 'cifarnet': cifarnet_preprocessing, 49 | 'inception': inception_preprocessing, 50 | 'inception_v1': inception_preprocessing, 51 | 'inception_v2': inception_preprocessing, 52 | 'inception_v3': inception_preprocessing, 53 | 'inception_v3_bap': inception_preprocessing, 54 | 'inception_v3_topk': inception_preprocessing, 55 | 'inception_v4': inception_preprocessing, 56 | 'inception_resnet_v2': inception_preprocessing, 57 | 'lenet': lenet_preprocessing, 58 | 'mobilenet_v1': inception_preprocessing, 59 | 'nasnet_mobile': inception_preprocessing, 60 | 'nasnet_large': inception_preprocessing, 61 | 'resnet_v1_50': vgg_preprocessing, 62 | 'resnet_v1_101': vgg_preprocessing, 63 | 'resnet_v1_152': vgg_preprocessing, 64 | 'resnet_v1_200': vgg_preprocessing, 65 | 'resnet_v2_50': vgg_preprocessing, 66 | 'resnet_v2_101': vgg_preprocessing, 67 | 'resnet_v2_152': vgg_preprocessing, 68 | 'resnet_v2_200': vgg_preprocessing, 69 | 'vgg': vgg_preprocessing, 70 | 'vgg_a': vgg_preprocessing, 71 | 'vgg_16': vgg_preprocessing, 72 | 'vgg_19': vgg_preprocessing, 73 | } 74 | 75 | if name not in preprocessing_fn_map: 76 | raise ValueError('Preprocessing name [%s] was not recognized' % name) 77 | 78 | def preprocessing_fn(image, output_height, output_width, **kwargs): 79 | return preprocessing_fn_map[name].preprocess_image( 80 | image, output_height, output_width, is_training=is_training, **kwargs) 81 | 82 | return preprocessing_fn 83 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.19.5 2 | opencv_python==4.5.2.54 3 | scipy==1.5.4 4 | setuptools==57.0.0 5 | six==1.15.0 6 | tensorflow==1.14.0 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Setup script for slim.""" 16 | 17 | from setuptools import find_packages 18 | from setuptools import setup 19 | 20 | 21 | setup( 22 | name='slim', 23 | version='0.1', 24 | include_package_data=True, 25 | packages=find_packages(), 26 | description='tf-slim', 27 | ) 28 | -------------------------------------------------------------------------------- /train_sample_aircraft.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATASET="Aircraft" 4 | TRAIN_DIR="./$DATASET/SAC/TRAIN/Aircraft" 5 | MODEL_PATH='./pre_trained/inception_v3.ckpt' 6 | 7 | python train_sample.py --learning_rate=0.0007 \ 8 | --dataset_name=$DATASET \ 9 | --dataset_dir="./$DATASET/Data/tfrecords" \ 10 | --train_dir=$TRAIN_DIR \ 11 | --checkpoint_path=$MODEL_PATH \ 12 | --max_number_of_steps=80000 \ 13 | --weight_decay=1e-5 \ 14 | --model_name='inception_v3_topk' \ 15 | --checkpoint_exclude_scopes="InceptionV3/bilinear_attention_pooling" \ 16 | --batch_size=12 \ 17 | --train_image_size=448 \ 18 | --num_clones=1 \ 19 | --gpus="0" \ 20 | --feature_maps="Mixed_6e" \ 21 | --attention_maps="Mixed_7a_b0" \ 22 | --num_parts=32 23 | -------------------------------------------------------------------------------- /train_sample_bird.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATASET="Bird" 4 | TRAIN_DIR="./$DATASET/SAC/TRAIN/Bird" 5 | MODEL_PATH='./pre_trained/inception_v3.ckpt' 6 | 7 | python train_sample.py --learning_rate=0.001 \ 8 | --dataset_name=$DATASET \ 9 | --dataset_dir="./$DATASET/Data/tfrecords" \ 10 | --train_dir=$TRAIN_DIR \ 11 | --checkpoint_path=$MODEL_PATH \ 12 | --max_number_of_steps=80000 \ 13 | --weight_decay=1e-5 \ 14 | --model_name='inception_v3_topk' \ 15 | --checkpoint_exclude_scopes="InceptionV3/bilinear_attention_pooling" \ 16 | --batch_size=12 \ 17 | --train_image_size=299 \ 18 | --num_clones=1 \ 19 | --gpus="0" \ 20 | --feature_maps="Mixed_6e" \ 21 | --attention_maps="Mixed_7a_b0" \ 22 | --num_parts=32 \ 23 | --ignore_missing_vars=True \ 24 | --save_interval_secs=120 25 | -------------------------------------------------------------------------------- /train_sample_dog.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATASET="Dog" 4 | TRAIN_DIR="./$DATASET/SAC/TRAIN/Dog" 5 | MODEL_PATH='./pre_trained/inception_v3.ckpt' 6 | 7 | python train_sample.py --learning_rate=0.00001 \ 8 | --dataset_name=$DATASET \ 9 | --dataset_dir="./$DATASET/Data/tfrecords" \ 10 | --train_dir=$TRAIN_DIR \ 11 | --checkpoint_path=$MODEL_PATH \ 12 | --max_number_of_steps=80000 \ 13 | --weight_decay=1e-5 \ 14 | --model_name='inception_v3_topk' \ 15 | --checkpoint_exclude_scopes="InceptionV3/bilinear_attention_pooling" \ 16 | --batch_size=1 \ 17 | --train_image_size=448 \ 18 | --num_clones=1 \ 19 | --gpus="0" \ 20 | --feature_maps="Mixed_7c" \ 21 | --attention_maps="Mixed_7c" \ 22 | --num_parts=32 23 | -------------------------------------------------------------------------------- /utils/__pycache__/lstm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/utils/__pycache__/lstm.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/lstm.cpython-37.pyc:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/utils/__pycache__/lstm.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/utils/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aioz-ai/SAC/834424cae2fcbb1a80a1ac7b5c575be0e15bbed6/utils/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/lstm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | np.random.seed(1) 5 | rnn_cell = tf.contrib.rnn 6 | 7 | 8 | def class_embedding(sentence, word2vec, word2index, emb_dim): 9 | """Embeds tokenized class names with a two-layer LSTM. 10 | 11 | Args: 12 | sentence: tensor of shape [batch, num_class, max_words] holding word 13 | indices for each class name. 14 | word2vec: pretrained word-embedding matrix used to initialize the lookup. 15 | word2index: word-to-index dictionary (unused here; kept for signature 16 | compatibility). 17 | emb_dim: dimensionality of the word embeddings. 18 | 19 | Returns: 20 | A tensor of shape [batch, num_class, rnn_size] with one LSTM output per 21 | class name. 22 | """ 23 | batch, num_class, max_words = sentence.shape 24 | rnn_size = 1024 25 | # Flatten so every class name is processed as one sequence. 26 | sentence = tf.reshape(sentence, [batch*num_class, max_words]) 27 | sentence = tf.cast(sentence, dtype=tf.int32) 28 | # create word embedding 29 | embed_ques_W = tf.Variable(word2vec) 30 | # create a two-layer LSTM with dropout on each layer's output 31 | lstm_1 = rnn_cell.LSTMCell(rnn_size, emb_dim) 32 | lstm_dropout_1 = rnn_cell.DropoutWrapper(lstm_1, output_keep_prob=0.8) 33 | lstm_2 = rnn_cell.LSTMCell(rnn_size, rnn_size) 34 | lstm_dropout_2 = rnn_cell.DropoutWrapper(lstm_2, output_keep_prob=0.8) 35 | stacked_lstm = rnn_cell.MultiRNNCell([lstm_dropout_1, lstm_dropout_2]) 36 | state = stacked_lstm.zero_state(batch*num_class, tf.float32) 37 | 38 | with tf.variable_scope("embed"): 39 | for i in range(max_words): 40 | if i > 0: 41 | tf.get_variable_scope().reuse_variables() 42 | 43 | # Look up, regularize, and squash the embedding of the ith word. 44 | cls_emb_linear = tf.nn.embedding_lookup(embed_ques_W, sentence[:, i]) 45 | cls_emb_drop = tf.nn.dropout(cls_emb_linear, 0.8) 46 | cls_emb = tf.tanh(cls_emb_drop) 47 | 48 | output, state = stacked_lstm(cls_emb, state) 49 | output = tf.reshape(output, [batch, num_class, rnn_size]) 50 | return output -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import _pickle as pickle 3 | from math import * 4 | 5 | def convert_index_to_word(cls_ids, classes): 6 | words = [] 7 | for ids in cls_ids: 8 | b_words = [] 9 | for idx in ids: 10 | cls = classes[idx] 11 | b_words.append(cls) 12 | words.append(b_words) 13 | words = np.array(words) 14 | return words 15 | 16 | 17 | def read_glove_vecs(glove_file, dictionary_file): 18 | d = pickle.load(open(dictionary_file, 'rb')) 19 | word_to_vec_map = np.load(glove_file) 20 | words_to_index = d[0] 21 | index_to_words = d[1] 22 | return words_to_index, index_to_words, word_to_vec_map 23 | 24 | 25 | def sentences_to_indices(X, word_to_index, max_len): 26 | """ 27 | Converts an array of sentences (strings) into an array of indices corresponding to words in the sentences.
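28 | For example (hypothetical indices): with word_to_index = {'laysan': 5, 29 | 'albatross': 9} and max_len = 4, 'Laysan Albatross' becomes [5, 9, 0, 0].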
30 | The output shape is fixed so the result can be fed directly to an embedding lookup. 31 | 32 | Arguments: 33 | X -- array of sentences (strings), of shape (m, 1) 34 | word_to_index -- a dictionary mapping each word to its index 35 | max_len -- maximum number of words kept per sentence; longer sentences are truncated 36 | 37 | Returns: 38 | X_indices -- array of indices corresponding to words in the sentences from X, of shape (m, max_len) 39 | """ 40 | 41 | m = X.shape[0]  # number of training examples 42 | # Initialize X_indices as a numpy matrix of zeros with the correct shape. 43 | X_indices = np.zeros((m, max_len)) 44 | 45 | for i in range(m): 46 | # Lower-case the ith sentence, split it into words, and truncate to max_len. 47 | sentence_words = (X[i].lower()).split() 48 | sentence_words = sentence_words[:max_len] 49 | # Record the vocabulary index of each remaining word. 50 | for j, w in enumerate(sentence_words): 51 | X_indices[i, j] = word_to_index[w] 52 | return X_indices 53 | 54 | 55 | def cosine_annealing(step, n_iters, n_cycles, lrate_max): 56 | # Cosine schedule with warm restarts: every n_iters / n_cycles steps the 57 | # learning rate decays from lrate_max towards 0, then restarts. 58 | iter_per_cycle = n_iters / n_cycles 59 | cos_inner = (pi * (step % iter_per_cycle)) / iter_per_cycle 60 | lr = lrate_max / 2 * (cos(cos_inner) + 1) 61 | return np.array(lr).astype(np.float32) 62 | 63 | 64 | def num_params(variables): 65 | # Count the total number of scalar parameters across the given variables. 66 | total_params = 0 67 | for v in variables: 68 | shape = v.get_shape() 69 | params = 1 70 | for dim in shape: 71 | params *= dim 72 | total_params += params 73 | return total_params 74 | --------------------------------------------------------------------------------