├── .gitignore ├── LICENSE ├── README.md ├── assets ├── 128collage_old.png └── 256_external_collage.png ├── aws.py ├── c_train.sh ├── check_images.py ├── discriminators.py ├── experiments ├── eval_can_external_style.sh ├── eval_can_paper.sh ├── mnist_train_can.sh ├── mnist_train_wgan.sh ├── train_can_external_style.sh ├── train_can_paper.sh ├── wiki_can.sh ├── wiki_can_256.sh ├── wiki_can_larger_batch.sh ├── wiki_can_s3.sh ├── wiki_classify_gan.sh ├── wiki_external_can.sh ├── wiki_exxternal_can_ht=128,bs=16.sh ├── wiki_wgan.sh └── wiki_wgan_small.sh ├── generators.py ├── losses.py ├── main.py ├── model.py ├── ops.py ├── slim ├── 2 ├── BUILD ├── README.md ├── WORKSPACE ├── __init__.py ├── convert_wikiart.sh ├── datasets │ ├── __init__.py │ ├── build_imagenet_data.py │ ├── cifar10.py │ ├── convert_wikiart.py │ ├── dataset_factory.py │ ├── dataset_utils.py │ ├── download_and_convert_cifar10.py │ ├── download_and_convert_flowers.py │ ├── download_and_convert_imagenet.sh │ ├── download_and_convert_mnist.py │ ├── download_imagenet.sh │ ├── flowers.py │ ├── imagenet.py │ ├── imagenet_2012_validation_synset_labels.txt │ ├── imagenet_lsvrc_2015_synsets.txt │ ├── imagenet_metadata.txt │ ├── mnist.py │ ├── preprocess_imagenet_validation_data.py │ ├── process_bounding_boxes.py │ └── wikiart.py ├── deployment │ ├── __init__.py │ ├── model_deploy.py │ └── model_deploy_test.py ├── download_and_convert_data.py ├── eval_image_classifier.py ├── eval_wikiart_cpu.sh ├── export_inference_graph.py ├── export_inference_graph_test.py ├── finetune_inception_resnet_v2_on_wikiart.sh ├── nets │ ├── __init__.py │ ├── alexnet.py │ ├── alexnet_test.py │ ├── cifarnet.py │ ├── cyclegan.py │ ├── cyclegan_test.py │ ├── dcgan.py │ ├── dcgan_test.py │ ├── inception.py │ ├── inception_resnet_v2.py │ ├── inception_resnet_v2_test.py │ ├── inception_utils.py │ ├── inception_v1.py │ ├── inception_v1_test.py │ ├── inception_v2.py │ ├── inception_v2_test.py │ ├── inception_v3.py │ ├── inception_v3_test.py │ ├── inception_v4.py │ ├── inception_v4_test.py │ ├── lenet.py │ ├── mobilenet_v1.md │ ├── mobilenet_v1.png │ ├── mobilenet_v1.py │ ├── mobilenet_v1_test.py │ ├── nasnet │ │ ├── README.md │ │ ├── __init__.py │ │ ├── nasnet.py │ │ ├── nasnet_test.py │ │ ├── nasnet_utils.py │ │ └── nasnet_utils_test.py │ ├── nets_factory.py │ ├── nets_factory_test.py │ ├── overfeat.py │ ├── overfeat_test.py │ ├── pix2pix.py │ ├── pix2pix_test.py │ ├── resnet_utils.py │ ├── resnet_v1.py │ ├── resnet_v1_test.py │ ├── resnet_v2.py │ ├── resnet_v2_test.py │ ├── vgg.py │ └── vgg_test.py ├── preprocessing │ ├── __init__.py │ ├── cifarnet_preprocessing.py │ ├── inception_preprocessing.py │ ├── lenet_preprocessing.py │ ├── preprocessing_factory.py │ └── vgg_preprocessing.py ├── scripts │ ├── export_mobilenet.sh │ ├── finetune_inception_resnet_v2_on_flowers.sh │ ├── finetune_inception_v1_on_flowers.sh │ ├── finetune_inception_v3_on_flowers.sh │ ├── finetune_resnet_v1_50_on_flowers.sh │ ├── train_cifarnet_on_cifar10.sh │ └── train_lenet_on_mnist.sh ├── setup.py ├── slim_walkthrough.ipynb ├── train_image_classifier.py └── wikiart │ └── flowers_train_00000-of-00005.tfrecord └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Data 2 | data 3 | buffer 4 | samples 5 | *.zip 6 | logs 7 | test* 8 | 9 | web/js/gen_layers.js 10 | 11 | # checkpoint 12 | checkpoint 13 | 14 | # trash 15 | .dropbox 16 | .DS_Store 17 | 18 | # Created by https://www.gitignore.io/api/python,vim 19 | 20 | ### Python ### 21 | # 
Byte-compiled / optimized / DLL files 22 | __pycache__/ 23 | *.py[cod] 24 | *$py.class 25 | 26 | # Jupyter notebooks 27 | .ipynb_checkpoints/ 28 | */.ipynb_checkpoints/ 29 | 30 | # C extensions 31 | *.so 32 | 33 | # Distribution / packaging 34 | .Python 35 | env/ 36 | build/ 37 | develop-eggs/ 38 | dist/ 39 | downloads/ 40 | eggs/ 41 | .eggs/ 42 | lib/ 43 | lib64/ 44 | parts/ 45 | sdist/ 46 | var/ 47 | *.egg-info/ 48 | .installed.cfg 49 | *.egg 50 | 51 | # PyInstaller 52 | # Usually these files are written by a python script from a template 53 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 54 | *.manifest 55 | *.spec 56 | 57 | # Installer logs 58 | pip-log.txt 59 | pip-delete-this-directory.txt 60 | 61 | # Unit test / coverage reports 62 | htmlcov/ 63 | .tox/ 64 | .coverage 65 | .coverage.* 66 | .cache 67 | nosetests.xml 68 | coverage.xml 69 | *,cover 70 | .hypothesis/ 71 | 72 | # Translations 73 | *.mo 74 | *.pot 75 | 76 | # Django stuff: 77 | *.log 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | 86 | ### Vim ### 87 | [._]*.s[a-w][a-z] 88 | [._]s[a-w][a-z] 89 | *.un~ 90 | Session.vim 91 | .netrwhist 92 | *~ 93 | 94 | 95 | # philkuz random folders 96 | _*/ 97 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Creative Adversarial Networks 2 | ![collage](assets/256_external_collage.png) 3 | 4 | *256x256 samples directly from CAN (no cherry picking) with fixed classification network trained on WikiArt* 5 | 6 | 7 | An implementation of [CAN: Creative Adversarial Networks, Generating "Art" 8 | by Learning About Styles and Deviating from Style Norms](https://arxiv.org/abs/1706.07068) with a variation that improves sample variance and quality significantly. 9 | 10 | Repo based on [DCGAN-tensorflow](https://github.com/carpedm20/DCGAN-tensorflow). 11 | 12 | 13 | 14 | 15 | 16 | ## Getting the Dataset 17 | We used the [wikiart](https://www.wikiart.org/) dataset 18 | [available here](https://github.com/cs-chan/ICIP2016-PC/tree/f5d6f6b58a6d8a4bd05aaaedd9688d08c02df8f2/WikiArt%20Dataset). 
19 | Using the dataset is subject to wikiart's [terms of use](https://www.wikiart.org/en/terms-of-use). 20 | 21 | ``` 22 | mkdir data 23 | cd data 24 | wget http://www.cs-chan.com/source/ICIP2017/wikiart.zip 25 | unzip wikiart.zip 26 | ``` 27 | 28 | ## Getting pretrained models 29 | We uploaded all of our models to this [google drive folder](https://drive.google.com/open?id=1FNDxvpb_UY5MZ3zBnOOfGDQCXzeE7hbs). 30 | 31 | ## Training a CAN model from scratch (architecture used in the paper) 32 | ``` 33 | bash experiments/train_can_paper.sh # must run from the root directory of the project 34 | ``` 35 | ## Evaluating an existing CAN model 36 | ``` 37 | # make sure that load_dir is set correctly 38 | bash experiments/eval_can_paper.sh 39 | ``` 40 | 41 | # External Style Classification network 42 | We ran an experiment where we trained an Inception-ResNet to classify style (60% accuracy) 43 | and then used this for the style classification loss, removing the need to learn the layers 44 | in the discriminator. We hold the style classification network constant, so the style distribution 45 | doesn't change as the generator improves. We found that this improved the quality and diversity 46 | of our samples. 47 | 48 | ## Training CAN with External Style Network 49 | ``` 50 | # make sure that `style_net_checkpoint` is set correctly, or you will error out 51 | bash experiments/train_can_external_style.sh 52 | ``` 53 | 54 | ## Training the (ImageNet pre-trained) Inception Resnet 55 | Everything you need should be included in the script. The gist is that it converts the wikiart images into TFRecords, 56 | trains the last layer of the model on these images, then fine-tunes the entire model for 100 epochs, at the end of which 57 | you should get roughly 60% validation accuracy. Since we're looking to generate artwork, this level of accuracy 58 | is sufficient for our purposes. 59 | ``` 60 | cd slim/ 61 | vim finetune_inception_resnet_v2_on_wikiart.sh # edit INPUT_DATASET_DIR to match the location of where you downloaded wikiart 62 | bash finetune_inception_resnet_v2_on_wikiart.sh 63 | ``` 64 | ## Evaluating CAN with External Style Network 65 | ``` 66 | # make sure that `style_net_checkpoint` and `load_dir` point to the downloaded models. 67 | bash experiments/eval_can_external_style.sh 68 | ``` 69 | 70 | ## Experiments 71 | We have run a variety of experiments, all of which are available in the `experiments/` directory.
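For reference, the sketch below shows one common way to implement the style-ambiguity term that the CAN experiments weight with `--lambda_val`: the cross-entropy between the style classifier's prediction on a generated image and the uniform distribution over the 27 WikiArt style classes. This is an illustrative approximation rather than the repository's exact code (the real losses live in `losses.py` / `model.py`), and `style_logits` is a stand-in for either the discriminator's style head or the frozen external Inception-ResNet classifier.

```python
import tensorflow as tf

def style_ambiguity_loss(style_logits):
    """Encourage generated images to be hard to assign to any single style.

    Cross-entropy between the style classifier's prediction and the uniform
    distribution over style classes (27 for WikiArt); minimizing it pushes
    the generator toward stylistically ambiguous images.
    """
    num_styles = style_logits.get_shape().as_list()[-1]
    uniform_target = tf.fill(tf.shape(style_logits), 1.0 / num_styles)
    return tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=uniform_target,
                                                logits=style_logits))

# Hypothetical generator objective (names illustrative): the usual adversarial
# term plus the ambiguity term, weighted by the --lambda_val flag.
# g_loss = g_adv_loss + lambda_val * style_ambiguity_loss(style_logits_fake)
```

In the paper-style setup the style logits come from a classification head on the discriminator; in the external-style variant described above they come from the fixed Inception-ResNet, which is why the style distribution does not drift as the generator improves.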
72 | ## Authors 73 | [Phillip Kravtsov](https://github.com/phillip-kravtsov) 74 | 75 | [Phillip Kuznetsov](https://github.com/philkuz) 76 | 77 | ## Citation 78 | 79 | If you use this implementation in your own work please cite the following 80 | ``` 81 | @misc{2017cans, 82 | author = {Phillip Kravtsov and Phillip Kuznetsov}, 83 | title = {Creative Adversarial Networks}, 84 | year = {2017}, 85 | howpublished = {\url{https://github.com/mlberkeley/Creative-Adversarial-Networks}}, 86 | note = {commit xxxxxxx} 87 | } 88 | ``` 89 | -------------------------------------------------------------------------------- /assets/128collage_old.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlberkeley/Creative-Adversarial-Networks/fea29d4348a650a40322fc4da645395d3d0f089c/assets/128collage_old.png -------------------------------------------------------------------------------- /assets/256_external_collage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlberkeley/Creative-Adversarial-Networks/fea29d4348a650a40322fc4da645395d3d0f089c/assets/256_external_collage.png -------------------------------------------------------------------------------- /aws.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | import os 3 | 4 | # s3 = boto3.resource('s3') 5 | # s3.meta.client.upload_file('tmp/', BUCKET_NAME, 'tmp/') 6 | 7 | 8 | def bucket_exists(bucket): 9 | s3 = boto3.resource('s3') 10 | return s3.Bucket(bucket) in s3.buckets.all() 11 | def upload_path(local_directory, bucket, destination, certain_upload=False): 12 | client = boto3.client('s3') 13 | # enumerate local files recursively 14 | for root, dirs, files in os.walk(local_directory): 15 | 16 | for filename in files: 17 | 18 | # construct the full local path 19 | local_path = os.path.join(root, filename) 20 | 21 | # construct the full Dropbox path 22 | relative_path = os.path.relpath(local_path, local_directory) 23 | s3_path = os.path.join(destination, relative_path) 24 | 25 | if certain_upload: 26 | client.upload_file(local_path, bucket, s3_path) 27 | return 28 | 29 | print('Searching "%s" in "%s"' % (s3_path, bucket)) 30 | try: 31 | client.head_object(Bucket=bucket, Key=s3_path) 32 | # print("Path found on S3! Skipping %s..." % s3_path) 33 | except: 34 | print("Uploading %s..." % s3_path) 35 | client.upload_file(local_path, bucket, s3_path) 36 | -------------------------------------------------------------------------------- /c_train.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | python3 train_classify.py 3 | -------------------------------------------------------------------------------- /check_images.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Script to check whether images are corrupted. Without an argument, checks `data/wikiart/`. 
Otherwise checks 3 | the `data/<dataset>/` directory. 4 | Usage: 5 | python check_images.py <dataset> 6 | 7 | <dataset>: checks `data/<dataset>/` 8 | ''' 9 | from utils import * 10 | from glob import glob 11 | import sys 12 | 13 | if len(sys.argv) == 1: 14 | test_images(glob("./data/wikiart/*/*.jpg")) 15 | else: 16 | test_images(glob("./data/" + str(sys.argv[1]) + "/*/*.jpg")) 17 | 18 | -------------------------------------------------------------------------------- /experiments/eval_can_external_style.sh: -------------------------------------------------------------------------------- 1 | # trains gan with an outside can network instead of having the discriminator learn style classification 2 | export PYTHONPATH="slim/:$PYTHONPATH" 3 | export CUDA_VISIBLE_DEVICES=0 4 | BATCH_SIZE=16 5 | python3 main.py \ 6 | --epoch 25 \ 7 | --learning_rate .0001 \ 8 | --beta 0.5 \ 9 | --batch_size $BATCH_SIZE \ 10 | --sample_size $BATCH_SIZE \ 11 | --input_height 256 \ 12 | --output_height 256 \ 13 | --lambda_val 1.0 \ 14 | --smoothing 1.0 \ 15 | --use_resize True \ 16 | --dataset wikiart \ 17 | --input_fname_pattern */*.jpg \ 18 | --crop False \ 19 | --visualize False \ 20 | --use_s3 False \ 21 | --can True \ 22 | --train \ 23 | --style_net_checkpoint "slim/logs/wikiart/inception_resnet_v2/all/bs=16,lr=0.0001,epochs=100/smol_adam_fixedLR" 24 | # --style_net_checkpoint "logs/inception_resnet_v2/" 25 | -------------------------------------------------------------------------------- /experiments/eval_can_paper.sh: -------------------------------------------------------------------------------- 1 | # export CUDA_VISIBLE_DEVICES=0 # edit this if you want to limit yourself to GPU 2 | export PYTHONPATH="slim/:$PYTHONPATH" 3 | python3 main.py \ 4 | --epoch 25 \ 5 | --learning_rate .0001 \ 6 | --beta 0.5 \ 7 | --batch_size 16 \ 8 | --sample_size 16 \ 9 | --input_height 256 \ 10 | --output_height 256 \ 11 | --lambda_val 1.0 \ 12 | --smoothing 1.0 \ 13 | --use_resize True \ 14 | --dataset wikiart \ 15 | --input_fname_pattern */*.jpg \ 16 | --load_dir "logs/can_paper" \ 17 | --crop False \ 18 | --visualize False \ 19 | --can True \ 20 | --train False 21 | -------------------------------------------------------------------------------- /experiments/mnist_train_can.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | python3 main.py \ 3 | --epoch 25 \ 4 | --learning_rate .0001 \ 5 | --beta 0.5 \ 6 | --batch_size 4 \ 7 | --sample_size 9 \ 8 | --input_height 28 \ 9 | --output_height 28 \ 10 | --lambda_val 1.0 \ 11 | --smoothing 1 \ 12 | --dataset mnist \ 13 | --input_fname_pattern */*.jpg \ 14 | --checkpoint_dir checkpoint \ 15 | --sample_dir samples \ 16 | --crop False \ 17 | --visualize False \ 18 | --can True \ 19 | --wgan False \ 20 | --train 21 | -------------------------------------------------------------------------------- /experiments/mnist_train_wgan.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | python3 main.py \ 3 | --epoch 25 \ 4 | --learning_rate .0001 \ 5 | --beta 0.5 \ 6 | --batch_size 128 \ 7 | --sample_size 100 \ 8 | --input_height 28 \ 9 | --output_height 28 \ 10 | --lambda_val 1.0 \ 11 | --smoothing 1 \ 12 | --dataset mnist \ 13 | --input_fname_pattern */*.jpg \ 14 | --checkpoint_dir checkpoint \ 15 | --sample_dir samples \ 16 | --crop False \ 17 | --visualize False \ 18 | --can False \ 19 | --use_s3 False \ 20 | --s3_bucket mlberkeley-cans \ 21 | --replay False \ 22 | --wgan True \ 23 | --train 24 | 25 | 
-------------------------------------------------------------------------------- /experiments/train_can_external_style.sh: -------------------------------------------------------------------------------- 1 | # trains gan with an outside can network instead of having the discriminator learn style classification 2 | export PYTHONPATH="slim/:$PYTHONPATH" 3 | export CUDA_VISIBLE_DEVICES=0 4 | BATCH_SIZE=16 5 | python3 main.py \ 6 | --epoch 25 \ 7 | --learning_rate .0001 \ 8 | --beta 0.5 \ 9 | --batch_size $BATCH_SIZE \ 10 | --sample_size $BATCH_SIZE \ 11 | --input_height 256 \ 12 | --output_height 256 \ 13 | --lambda_val 1.0 \ 14 | --smoothing 1.0 \ 15 | --use_resize True \ 16 | --dataset wikiart \ 17 | --input_fname_pattern */*.jpg \ 18 | --crop False \ 19 | --visualize False \ 20 | --use_s3 False \ 21 | --can True \ 22 | --train False \ 23 | --load_dir "logs/can_external_style" \ 24 | --style_net_checkpoint "logs/inception_resnet_v2/" 25 | -------------------------------------------------------------------------------- /experiments/train_can_paper.sh: -------------------------------------------------------------------------------- 1 | # export CUDA_VISIBLE_DEVICES=0 # edit this if you want to limit yourself to GPU 2 | export PYTHONPATH="slim/:$PYTHONPATH" 3 | python3 main.py \ 4 | --epoch 25 \ 5 | --learning_rate .0001 \ 6 | --beta 0.5 \ 7 | --batch_size 16 \ 8 | --sample_size 16 \ 9 | --input_height 256 \ 10 | --output_height 256 \ 11 | --lambda_val 1.0 \ 12 | --smoothing 1.0 \ 13 | --use_resize True \ 14 | --dataset wikiart \ 15 | --input_fname_pattern */*.jpg \ 16 | --crop False \ 17 | --visualize False \ 18 | --can True \ 19 | --train \ 20 | -------------------------------------------------------------------------------- /experiments/wiki_can.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | python3 main.py \ 3 | --epoch 25 \ 4 | --learning_rate .0001 \ 5 | --beta 0.5 \ 6 | --batch_size 35 \ 7 | --sample_size 72 \ 8 | --input_height 128 \ 9 | --output_height 128 \ 10 | --lambda_val 1.0 \ 11 | --smoothing 1.0 \ 12 | --use_resize True \ 13 | --dataset wikiart \ 14 | --input_fname_pattern */*.jpg \ 15 | --crop False \ 16 | --visualize False \ 17 | --use_s3 False \ 18 | --can True \ 19 | --train \ 20 | -------------------------------------------------------------------------------- /experiments/wiki_can_256.sh: -------------------------------------------------------------------------------- 1 | # export CUDA_VISIBLE_DEVICES=0 # edit this if you want to limit yourself to GPU 2 | python3 main.py \ 3 | --epoch 25 \ 4 | --learning_rate .0001 \ 5 | --beta 0.5 \ 6 | --batch_size 16 \ 7 | --sample_size 16 \ 8 | --input_height 256 \ 9 | --output_height 256 \ 10 | --lambda_val 1.0 \ 11 | --smoothing 1.0 \ 12 | --use_resize True \ 13 | --dataset wikiart \ 14 | --input_fname_pattern */*.jpg \ 15 | --checkpoint_dir checkpoint \ 16 | --sample_dir samples \ 17 | --crop False \ 18 | --visualize False \ 19 | --can True \ 20 | --train \ 21 | -------------------------------------------------------------------------------- /experiments/wiki_can_larger_batch.sh: -------------------------------------------------------------------------------- 1 | # export CUDA_VISIBLE_DEVICES=1 2 | python3 main.py \ 3 | --epoch 100 \ 4 | --learning_rate .0001 \ 5 | --beta 0.5 \ 6 | --batch_size 35 \ 7 | --save_itr 250 \ 8 | --sample_size 30 \ 9 | --input_height 256 \ 10 | --output_height 256 \ 11 | --lambda_val 1.0 \ 12 | --smoothing 1.0 \ 13 | --use_resize 
True \ 14 | --dataset wikiart \ 15 | --input_fname_pattern */*.jpg \ 16 | --checkpoint_dir checkpoint \ 17 | --sample_dir samples \ 18 | --crop False \ 19 | --visualize False \ 20 | --use_s3 \ 21 | --s3_bucket "creative-adv-nets" \ 22 | --can True \ 23 | --train \ 24 | -------------------------------------------------------------------------------- /experiments/wiki_can_s3.sh: -------------------------------------------------------------------------------- 1 | # export CUDA_VISIBLE_DEVICES=1 2 | python3 main.py \ 3 | --epoch 25 \ 4 | --learning_rate .0001 \ 5 | --beta 0.5 \ 6 | --batch_size 35 \ 7 | --sample_size 72 \ 8 | --input_height 128 \ 9 | --output_height 128 \ 10 | --lambda_val 1.0 \ 11 | --smoothing 1.0 \ 12 | --use_resize True \ 13 | --dataset wikiart \ 14 | --input_fname_pattern */*.jpg \ 15 | --checkpoint_dir checkpoint \ 16 | --sample_dir samples \ 17 | --crop False \ 18 | --visualize False \ 19 | --use_s3 \ 20 | --can True \ 21 | --train \ 22 | --use_s3 True\ 23 | --s3_bucket 'adv-maml-models' \ 24 | -------------------------------------------------------------------------------- /experiments/wiki_classify_gan.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | python3 main.py \ 3 | --epoch 25 \ 4 | --learning_rate .0001 \ 5 | --beta 0.5 \ 6 | --batch_size 32 \ 7 | --sample_size 72 \ 8 | --input_height 128 \ 9 | --output_height 128 \ 10 | --lambda_val 0.0 \ 11 | --smoothing 1.0 \ 12 | --use_resize True \ 13 | --dataset wikiart \ 14 | --input_fname_pattern */*.jpg \ 15 | --checkpoint_dir checkpoint \ 16 | --sample_dir samples \ 17 | --crop False \ 18 | --visualize False \ 19 | --can True \ 20 | --train 21 | -------------------------------------------------------------------------------- /experiments/wiki_external_can.sh: -------------------------------------------------------------------------------- 1 | # trains gan with an outside can network instead of having the discriminator learn style classification 2 | export PYTHONPATH="slim/:$PYTHONPATH" 3 | export CUDA_VISIBLE_DEVICES=0 4 | BATCH_SIZE=16 5 | python3 main.py \ 6 | --epoch 25 \ 7 | --learning_rate .0001 \ 8 | --beta 0.5 \ 9 | --batch_size $BATCH_SIZE \ 10 | --sample_size $BATCH_SIZE \ 11 | --input_height 256 \ 12 | --output_height 256 \ 13 | --lambda_val 1.0 \ 14 | --smoothing 1.0 \ 15 | --use_resize True \ 16 | --dataset wikiart \ 17 | --input_fname_pattern */*.jpg \ 18 | --crop False \ 19 | --visualize False \ 20 | --use_s3 False \ 21 | --can True \ 22 | --train \ 23 | --style_net_checkpoint "slim/logs/wikiart/inception_resnet_v2/all/bs=16,lr=0.0001,epochs=100/smol_adam_fixedLR" 24 | # --style_net_checkpoint "logs/inception_resnet_v2/" 25 | -------------------------------------------------------------------------------- /experiments/wiki_exxternal_can_ht=128,bs=16.sh: -------------------------------------------------------------------------------- 1 | # trains gan with an outside can network instead of having the discriminator learn style classification 2 | export PYTHONPATH="slim/:$PYTHONPATH" 3 | export CUDA_VISIBLE_DEVICES=1 4 | python3 main.py \ 5 | --epoch 25 \ 6 | --learning_rate .0001 \ 7 | --beta 0.5 \ 8 | --batch_size 16 \ 9 | --sample_size 72 \ 10 | --input_height 128 \ 11 | --output_height 128 \ 12 | --lambda_val 1.0 \ 13 | --smoothing 1.0 \ 14 | --use_resize True \ 15 | --dataset wikiart \ 16 | --input_fname_pattern */*.jpg \ 17 | --crop False \ 18 | --visualize False \ 19 | --use_s3 False \ 20 | --can True \ 21 | --train \ 22 | 
--style_net_checkpoint "slim/logs/wikiart/inception_resnet_v2/all/bs=16,lr=0.0001,epochs=100/smol_adam_fixedLR" 23 | -------------------------------------------------------------------------------- /experiments/wiki_wgan.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | python3 main.py \ 3 | --epoch 25 \ 4 | --learning_rate .0001 \ 5 | --beta 0.5 \ 6 | --batch_size 27 \ 7 | --sample_size 36 \ 8 | --input_height 128 \ 9 | --output_height 128 \ 10 | --lambda_val 1.0 \ 11 | --smoothing 1.0 \ 12 | --use_resize True \ 13 | --dataset wikiart \ 14 | --input_fname_pattern */*.jpg \ 15 | --checkpoint_dir checkpoint \ 16 | --sample_dir samples \ 17 | --crop False \ 18 | --visualize False \ 19 | --replay False \ 20 | --can False \ 21 | --wgan True \ 22 | --train 23 | -------------------------------------------------------------------------------- /experiments/wiki_wgan_small.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | python3 main.py \ 3 | --epoch 25 \ 4 | --learning_rate .0001 \ 5 | --beta 0.5 \ 6 | --batch_size 34 \ 7 | --sample_size 64 \ 8 | --input_height 64 \ 9 | --output_height 64 \ 10 | --lambda_val 1.0 \ 11 | --smoothing 1.0 \ 12 | --use_resize True \ 13 | --dataset wikiart \ 14 | --input_fname_pattern */*.jpg \ 15 | --checkpoint_dir checkpoint \ 16 | --sample_dir samples \ 17 | --crop False \ 18 | --visualize False \ 19 | --replay False \ 20 | --can False \ 21 | --wgan True \ 22 | --train 23 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import scipy.misc 3 | import numpy as np 4 | from glob import glob 5 | 6 | from model import DCGAN 7 | from utils import pp, visualize, show_all_variables 8 | 9 | import tensorflow as tf 10 | 11 | flags = tf.app.flags 12 | flags.DEFINE_integer("epoch", 25, "Epoch to train [25]") 13 | flags.DEFINE_float("learning_rate", 0.0002, "Learning rate for adam [0.0002]") 14 | flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]") 15 | flags.DEFINE_float("smoothing", 0.9, "Smoothing term for discriminator real (class) loss [0.9]") 16 | flags.DEFINE_float("lambda_val", 1.0, "determines the relative importance of style ambiguity loss [1.0]") 17 | flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]") 18 | flags.DEFINE_integer("save_itr", 500, "The number of iterations to run for saving checkpoints") 19 | flags.DEFINE_integer("sample_itr", 500, "The number of iterations to run for sampling from the sampler") 20 | flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]") 21 | flags.DEFINE_integer("sample_size", 64, "the size of sample images [64]") 22 | flags.DEFINE_integer("input_height", 108, "The size of image to use (will be center cropped). [108]") 23 | flags.DEFINE_integer("input_width", None, "The size of image to use (will be center cropped). If None, same value as input_height [None]") 24 | flags.DEFINE_integer("output_height", 64, "The size of the output images to produce [64]") 25 | flags.DEFINE_integer("output_width", None, "The size of the output images to produce. 
If None, same value as output_height [None]") 26 | flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]") 27 | flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]") 28 | flags.DEFINE_string("log_dir", 'logs', "Directory to store logs [logs]") 29 | flags.DEFINE_string("checkpoint_dir", None, "Directory name to save the checkpoints [/checkpoint]") 30 | flags.DEFINE_string("sample_dir", None, "Directory name to save the image samples [/samples]") 31 | flags.DEFINE_string("load_dir", None, "Directory that specifies checkpoint to load") 32 | flags.DEFINE_boolean("train", False, "True for training, False for testing [False]") 33 | flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]") 34 | flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]") 35 | flags.DEFINE_boolean("wgan", False, "True if WGAN, False if regular [G/C]AN [False]") 36 | flags.DEFINE_boolean("can", True, "True if CAN, False if regular GAN [True]") 37 | flags.DEFINE_boolean("use_s3", False, "True if you want to use s3 buckets, False if you don't. Need to set s3_bucket if True.") 38 | flags.DEFINE_string("s3_bucket", None, "the s3_bucket to upload results to") 39 | flags.DEFINE_boolean("replay", True, "True if using experience replay [True]") 40 | flags.DEFINE_boolean("use_resize", False, "True if resize conv for upsampling, False for fractionally strided conv [False]") 41 | flags.DEFINE_boolean("use_default_checkpoint", False, "True only if checkpoint_dir is None. Don't set this") 42 | flags.DEFINE_string("style_net_checkpoint", None, "The checkpoint to get style net. Leave default to note use stylenet") 43 | flags.DEFINE_boolean("allow_gpu_growth", False, "True if you want Tensorflow only to allocate the gpu memory it requires. Good for debugging, but can impact performance") 44 | FLAGS = flags.FLAGS 45 | 46 | def main(_): 47 | print('Before processing flags') 48 | pp.pprint(flags.FLAGS.__flags) 49 | if FLAGS.use_s3: 50 | import aws 51 | if FLAGS.s3_bucket is None: 52 | raise ValueError('use_s3 flag set, but no bucket set. ') 53 | # check to see if s3 bucket exists: 54 | elif not aws.bucket_exists(FLAGS.s3_bucket): 55 | raise ValueError('`use_s3` flag set, but bucket "%s" doesn\'t exist. 
Not using s3' % FLAGS.s3_bucket) 56 | 57 | 58 | if FLAGS.input_width is None: 59 | FLAGS.input_width = FLAGS.input_height 60 | if FLAGS.output_width is None: 61 | FLAGS.output_width = FLAGS.output_height 62 | 63 | 64 | 65 | # configure the log_dir to match the params 66 | log_dir = os.path.join(FLAGS.log_dir, "dataset={},isCan={},lr={},imsize={},hasStyleNet={},batch_size={}".format( 67 | FLAGS.dataset, 68 | FLAGS.can, 69 | FLAGS.learning_rate, 70 | FLAGS.input_height, 71 | FLAGS.style_net_checkpoint is not None, 72 | FLAGS.batch_size)) 73 | if not glob(log_dir + "*"): 74 | log_dir = os.path.join(log_dir, "000") 75 | else: 76 | containing_dir=os.path.join(log_dir, "[0-9][0-9][0-9]") 77 | nums = [int(x[-3:]) for x in glob(containing_dir)] # TODO FIX THESE HACKS 78 | if nums == []: 79 | num = 0 80 | else: 81 | num = max(nums) + 1 82 | log_dir = os.path.join(log_dir,"{:03d}".format(num)) 83 | FLAGS.log_dir = log_dir 84 | 85 | if FLAGS.checkpoint_dir is None: 86 | FLAGS.checkpoint_dir = os.path.join(FLAGS.log_dir, 'checkpoint') 87 | FLAGS.use_default_checkpoint = True 88 | elif FLAGS.use_default_checkpoint: 89 | raise ValueError('`use_default_checkpoint` flag only works if you keep checkpoint_dir as None') 90 | 91 | if FLAGS.sample_dir is None: 92 | FLAGS.sample_dir = os.path.join(FLAGS.log_dir, 'samples') 93 | 94 | if not os.path.exists(FLAGS.checkpoint_dir): 95 | os.makedirs(FLAGS.checkpoint_dir) 96 | if not os.path.exists(FLAGS.sample_dir): 97 | os.makedirs(FLAGS.sample_dir) 98 | print('After processing flags') 99 | pp.pprint(flags.FLAGS.__flags) 100 | if FLAGS.style_net_checkpoint: 101 | from slim.nets import nets_factory 102 | network_fn = nets_factory 103 | 104 | 105 | sess = None 106 | if FLAGS.dataset == 'mnist': 107 | y_dim = 10 108 | elif FLAGS.dataset == 'wikiart': 109 | y_dim = 27 110 | else: 111 | y_dim = None 112 | dcgan = DCGAN( 113 | sess, 114 | input_width=FLAGS.input_width, 115 | input_height=FLAGS.input_height, 116 | output_width=FLAGS.output_width, 117 | output_height=FLAGS.output_height, 118 | batch_size=FLAGS.batch_size, 119 | sample_num=FLAGS.sample_size, 120 | use_resize=FLAGS.use_resize, 121 | replay=FLAGS.replay, 122 | y_dim=y_dim, 123 | smoothing=FLAGS.smoothing, 124 | lamb = FLAGS.lambda_val, 125 | dataset_name=FLAGS.dataset, 126 | input_fname_pattern=FLAGS.input_fname_pattern, 127 | crop=FLAGS.crop, 128 | checkpoint_dir=FLAGS.checkpoint_dir, 129 | sample_dir=FLAGS.sample_dir, 130 | wgan=FLAGS.wgan, 131 | learning_rate = FLAGS.learning_rate, 132 | style_net_checkpoint=FLAGS.style_net_checkpoint, 133 | can=FLAGS.can) 134 | 135 | 136 | run_config = tf.ConfigProto() 137 | run_config.gpu_options.allow_growth=FLAGS.allow_gpu_growth 138 | with tf.Session(config=run_config) as sess: 139 | dcgan.set_sess(sess) 140 | # show_all_variables() 141 | 142 | if FLAGS.train: 143 | dcgan.train(FLAGS) 144 | else: 145 | if not dcgan.load(FLAGS.checkpoint_dir)[0]: 146 | raise Exception("[!] 
Train a model first, then run test mode") 147 | 148 | OPTION = 0 149 | visualize(sess, dcgan, FLAGS, OPTION) 150 | 151 | if __name__ == '__main__': 152 | tf.app.run() 153 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | from tensorflow.python.framework import ops 6 | 7 | from utils import * 8 | 9 | try: 10 | image_summary = tf.image_summary 11 | scalar_summary = tf.scalar_summary 12 | histogram_summary = tf.histogram_summary 13 | merge_summary = tf.merge_summary 14 | SummaryWriter = tf.train.SummaryWriter 15 | except: 16 | image_summary = tf.summary.image 17 | scalar_summary = tf.summary.scalar 18 | histogram_summary = tf.summary.histogram 19 | merge_summary = tf.summary.merge 20 | SummaryWriter = tf.summary.FileWriter 21 | 22 | if "concat_v2" in dir(tf): 23 | def concat(tensors, axis, *args, **kwargs): 24 | return tf.concat_v2(tensors, axis, *args, **kwargs) 25 | else: 26 | def concat(tensors, axis, *args, **kwargs): 27 | return tf.concat(tensors, axis, *args, **kwargs) 28 | 29 | def conv_out_size_same(size, stride): 30 | return int(math.ceil(float(size) / float(stride))) 31 | 32 | def sigmoid_cross_entropy_with_logits(x, y): 33 | try: 34 | return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, labels=y) 35 | except: 36 | return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, targets=y) 37 | 38 | def layer_norm(inputs, name): 39 | return tf.contrib.layers.layer_norm(inputs, scope=name) 40 | 41 | class batch_norm(object): 42 | def __init__(self, epsilon=1e-5, momentum = 0.9, name="batch_norm"): 43 | with tf.variable_scope(name): 44 | self.epsilon = epsilon 45 | self.momentum = momentum 46 | self.name = name 47 | 48 | def __call__(self, x, train=True): 49 | return tf.contrib.layers.batch_norm(x, 50 | decay=self.momentum, 51 | updates_collections=None, 52 | epsilon=self.epsilon, 53 | scale=True, 54 | is_training=train, 55 | scope=self.name) 56 | 57 | def conv_cond_concat(x, y): 58 | """Concatenate conditioning vector on feature map axis.""" 59 | x_shapes = tf.shape(x) 60 | y_shapes = tf.shape(y) 61 | return concat([ 62 | x, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])], 3) 63 | 64 | def conv2d(input_, output_dim, 65 | k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, 66 | name="conv2d",padding='SAME'): 67 | with tf.variable_scope(name): 68 | if padding=='VALID': 69 | paddings = np.array([[0,0],[1,1],[1,1],[0,0]]) 70 | input_ = tf.pad(input_, paddings) 71 | w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim], 72 | initializer=tf.truncated_normal_initializer(stddev=stddev)) 73 | conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding=padding) 74 | 75 | biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0)) 76 | out_shape = [-1] + conv.get_shape()[1:].as_list() 77 | conv = tf.reshape(tf.nn.bias_add(conv, biases), out_shape) 78 | 79 | return conv 80 | 81 | def resizeconv(input_, output_shape, 82 | k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, 83 | name="resconv"): 84 | with tf.variable_scope(name): 85 | 86 | resized = tf.image.resize_nearest_neighbor(input_,((output_shape[1]-1)*d_h + k_h-4, (output_shape[2]-1)*d_w + k_w-4)) 87 | #The 4 is because of same padding in tf.nn.conv2d. 
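# Concretely: a stride-d 'SAME' conv yields ceil(n/d) outputs, so with the default k=5, d=2 the
# resize target is (out-1)*2 + 1 and ceil(((out-1)*2 + 1)/2) = out, i.e. the conv below lands
# exactly on output_shape's spatial size; the hard-coded 4 assumes the default 5x5 kernel.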
88 | w = tf.get_variable('w', [k_h, k_w, resized.get_shape()[-1], output_shape[-1]], 89 | initializer=tf.truncated_normal_initializer(stddev=stddev)) 90 | resconv = tf.nn.conv2d(resized, w, strides=[1, d_h, d_w, 1], padding='SAME') 91 | biases = tf.get_variable('biases', output_shape[-1], initializer=tf.constant_initializer(0.0)) 92 | 93 | return tf.nn.bias_add(resconv, biases) 94 | 95 | def deconv2d(input_, output_shape, 96 | k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, 97 | name="deconv2d"): 98 | with tf.variable_scope(name): 99 | static_shape = input_.get_shape().as_list() 100 | dyn_input_shape = tf.shape(input_) 101 | batch_size = dyn_input_shape[0] 102 | out_h = output_shape[1] 103 | out_w = output_shape[2] 104 | out_shape = tf.stack([batch_size, out_h, out_w, output_shape[-1]]) 105 | 106 | w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]], 107 | initializer=tf.random_normal_initializer(stddev=stddev)) 108 | 109 | deconv = tf.nn.conv2d_transpose(input_, w, output_shape=out_shape, 110 | strides=[1, d_h, d_w, 1]) 111 | biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0)) 112 | deconv = tf.nn.bias_add(deconv, biases) 113 | #deconv = tf.reshape(tf.nn.bias_add(deconv, biases), tf.shape(deconv)) 114 | deconv.set_shape([None] + output_shape[1:]) 115 | return deconv 116 | 117 | def lrelu(x, leak=0.2, name="lrelu"): 118 | return tf.maximum(x, leak*x) 119 | 120 | def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0): 121 | shape = input_.get_shape().as_list() 122 | with tf.variable_scope(scope or "Linear"): 123 | matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, 124 | tf.random_normal_initializer(stddev=stddev)) 125 | bias = tf.get_variable("bias", [output_size], 126 | initializer=tf.constant_initializer(bias_start)) 127 | return tf.matmul(input_, matrix) + bias 128 | -------------------------------------------------------------------------------- /slim/2: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the wikiart dataset 19 | # 2. Fine-tunes an Inception Resnet V2 model on the wikiart training set. 20 | # 3. Evaluates the model on the wikiart validation set. 21 | # 22 | # Usage: 23 | # cd slim 24 | # ./slim/scripts/finetune_inception_resnet_v2_on_wikiart.sh 25 | export CUDA_VISIBLE_DEVICES=1 26 | set -e 27 | 28 | # Where the pre-trained Inception Resnet V2 checkpoint is saved to. 29 | PRETRAINED_CHECKPOINT_DIR=logs/pretrained 30 | # Where the pre-trained Inception Resnet V2 checkpoint is saved to. 31 | MODEL_NAME=inception_resnet_v2 32 | 33 | # Where the training (fine-tuned) checkpoint and logs will be saved to. 
34 | TRAIN_DIR=logs/wikiart/${MODEL_NAME} 35 | 36 | # Where the dataset is saved to. 37 | INPUT_DATASET_DIR=/data/wikiart/ 38 | DATASET_DIR=/data/wikiart-records 39 | 40 | # Download the pre-trained checkpoint. 41 | if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then 42 | mkdir -p ${PRETRAINED_CHECKPOINT_DIR} 43 | fi 44 | if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/${MODEL_NAME}.ckpt ]; then 45 | wget http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz 46 | tar -xvf inception_resnet_v2_2016_08_30.tar.gz 47 | mv inception_resnet_v2_2016_08_30.ckpt ${PRETRAINED_CHECKPOINT_DIR}/${MODEL_NAME}.ckpt 48 | rm inception_resnet_v2_2016_08_30.tar.gz 49 | fi 50 | 51 | # # Download the dataset 52 | # python download_and_convert_data.py \ 53 | # --dataset_name=wikiart \ 54 | # --dataset_dir=${DATASET_DIR} 55 | # --input_dataset_dir=${INPUT_DATASET_DIR} 56 | 57 | # @philkuz I use this to create a nice initialization - haven't tried random 58 | # TODO try out if your'e curious to see whether random initialization of last 59 | # layer makes sense in this case. 60 | # Fine-tune only the new layers for 1000 steps. 61 | # python3 train_image_classifier.py \ 62 | # --train_dir=${TRAIN_DIR} \ 63 | # --dataset_name=wikiart \ 64 | # --dataset_split_name=train \ 65 | # --dataset_dir=${DATASET_DIR} \ 66 | # --model_name=${MODEL_NAME} \ 67 | # --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/${MODEL_NAME}.ckpt \ 68 | # --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \ 69 | # --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \ 70 | # --max_number_of_steps=10000 \ 71 | # --batch_size=32 \ 72 | # --learning_rate=0.01 \ 73 | # --learning_rate_decay_type=fixed \ 74 | # --save_interval_secs=300 \ 75 | # --save_summaries_secs=60 \ 76 | # --log_every_n_steps=200 \ 77 | # --optimizer=rmsprop \ 78 | # --train_image_size=256 79 | # --weight_decay=0.00004 80 | 81 | # Run evaluation. 82 | python3 eval_image_classifier.py \ 83 | --checkpoint_path=${TRAIN_DIR} \ 84 | --eval_dir=${TRAIN_DIR} \ 85 | --dataset_name=wikiart \ 86 | --dataset_split_name=validation \ 87 | --dataset_dir=${DATASET_DIR} \ 88 | --model_name=${MODEL_NAME} 89 | --eval_image_size=256 90 | 91 | # Fine-tune all the new layers for 500 steps. 92 | NUM_EPOCHS=2 93 | BATCH_SIZE=16 94 | EXPERIMENT_NAME=smol 95 | LR=0.0001 \ 96 | 97 | python3 train_image_classifier.py \ 98 | --train_dir=${TRAIN_DIR}/all \ 99 | --dataset_name=wikiart \ 100 | --dataset_split_name=train \ 101 | --dataset_dir=${DATASET_DIR} \ 102 | --model_name=${MODEL_NAME} \ 103 | --checkpoint_path=${TRAIN_DIR} \ 104 | --batch_size=${BATCH_SIZE} \ 105 | --learning_rate=${LR} \ 106 | --learning_rate_decay_type=fixed \ 107 | --save_interval_secs=300 \ 108 | --save_summaries_secs=60 \ 109 | --log_every_n_steps=200 \ 110 | --optimizer=adam \ 111 | --weight_decay=0.00004 \ 112 | --experiment_name=${EXPERIMENT_NAME} \ 113 | --num_epochs=${NUM_EPOCHS} \ 114 | --train_image_size=256 \ 115 | # --experiment_numbering # TODO flag to flip on experiment numbering independent of experiement name arg existing 116 | # # TODO catch the naming convention 117 | 118 | # Run evaluation. 
119 | # TRAIN_DIR=logs/wikiart/inception_resnet_v2/all/bs=16,lr=0.0001,epochs=None/first/ 120 | python3 eval_image_classifier.py \ 121 | --checkpoint_path=${TRAIN_DIR}/all \ 122 | --eval_dir=${TRAIN_DIR} \ 123 | --dataset_name=wikiart \ 124 | --dataset_split_name=validation \ 125 | --dataset_dir=${DATASET_DIR} \ 126 | --model_name=${MODEL_NAME} 127 | --eval_image_size=256 128 | # --checkpoint_path=${TRAIN_DIR}/all \ 129 | -------------------------------------------------------------------------------- /slim/WORKSPACE: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlberkeley/Creative-Adversarial-Networks/fea29d4348a650a40322fc4da645395d3d0f089c/slim/WORKSPACE -------------------------------------------------------------------------------- /slim/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlberkeley/Creative-Adversarial-Networks/fea29d4348a650a40322fc4da645395d3d0f089c/slim/__init__.py -------------------------------------------------------------------------------- /slim/convert_wikiart.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | python3 download_and_convert_data.py \ 3 | --dataset_name "wikiart" \ 4 | --dataset_dir "/data/wikiart-records" \ 5 | --input_dataset_dir "/data/wikiart/" 6 | -------------------------------------------------------------------------------- /slim/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /slim/datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides data for the Cifar10 dataset. 16 | 17 | The dataset scripts used to create the dataset can be found at: 18 | tensorflow/models/research/slim/datasets/download_and_convert_cifar10.py 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import os 26 | import tensorflow as tf 27 | 28 | from datasets import dataset_utils 29 | 30 | slim = tf.contrib.slim 31 | 32 | _FILE_PATTERN = 'cifar10_%s.tfrecord' 33 | 34 | SPLITS_TO_SIZES = {'train': 50000, 'test': 10000} 35 | 36 | _NUM_CLASSES = 10 37 | 38 | _ITEMS_TO_DESCRIPTIONS = { 39 | 'image': 'A [32 x 32 x 3] color image.', 40 | 'label': 'A single integer between 0 and 9', 41 | } 42 | 43 | 44 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None): 45 | """Gets a dataset tuple with instructions for reading cifar10. 46 | 47 | Args: 48 | split_name: A train/test split name. 
49 | dataset_dir: The base directory of the dataset sources. 50 | file_pattern: The file pattern to use when matching the dataset sources. 51 | It is assumed that the pattern contains a '%s' string so that the split 52 | name can be inserted. 53 | reader: The TensorFlow reader type. 54 | 55 | Returns: 56 | A `Dataset` namedtuple. 57 | 58 | Raises: 59 | ValueError: if `split_name` is not a valid train/test split. 60 | """ 61 | if split_name not in SPLITS_TO_SIZES: 62 | raise ValueError('split name %s was not recognized.' % split_name) 63 | 64 | if not file_pattern: 65 | file_pattern = _FILE_PATTERN 66 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name) 67 | 68 | # Allowing None in the signature so that dataset_factory can use the default. 69 | if not reader: 70 | reader = tf.TFRecordReader 71 | 72 | keys_to_features = { 73 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 74 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='png'), 75 | 'image/class/label': tf.FixedLenFeature( 76 | [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)), 77 | } 78 | 79 | items_to_handlers = { 80 | 'image': slim.tfexample_decoder.Image(shape=[32, 32, 3]), 81 | 'label': slim.tfexample_decoder.Tensor('image/class/label'), 82 | } 83 | 84 | decoder = slim.tfexample_decoder.TFExampleDecoder( 85 | keys_to_features, items_to_handlers) 86 | 87 | labels_to_names = None 88 | if dataset_utils.has_labels(dataset_dir): 89 | labels_to_names = dataset_utils.read_label_file(dataset_dir) 90 | 91 | return slim.dataset.Dataset( 92 | data_sources=file_pattern, 93 | reader=reader, 94 | decoder=decoder, 95 | num_samples=SPLITS_TO_SIZES[split_name], 96 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, 97 | num_classes=_NUM_CLASSES, 98 | labels_to_names=labels_to_names) 99 | -------------------------------------------------------------------------------- /slim/datasets/dataset_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """A factory-pattern class which returns classification image/label pairs.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from datasets import cifar10 22 | from datasets import flowers 23 | from datasets import imagenet 24 | from datasets import mnist 25 | from datasets import wikiart 26 | 27 | datasets_map = { 28 | 'cifar10': cifar10, 29 | 'flowers': flowers, 30 | 'imagenet': imagenet, 31 | 'mnist': mnist, 32 | 'wikiart': wikiart, 33 | } 34 | 35 | 36 | def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None): 37 | """Given a dataset name and a split_name returns a Dataset. 38 | 39 | Args: 40 | name: String, the name of the dataset. 
41 | split_name: A train/test split name. 42 | dataset_dir: The directory where the dataset files are stored. 43 | file_pattern: The file pattern to use for matching the dataset source files. 44 | reader: The subclass of tf.ReaderBase. If left as `None`, then the default 45 | reader defined by each dataset is used. 46 | 47 | Returns: 48 | A `Dataset` class. 49 | 50 | Raises: 51 | ValueError: If the dataset `name` is unknown. 52 | """ 53 | if name not in datasets_map: 54 | raise ValueError('Name of dataset unknown %s' % name) 55 | return datasets_map[name].get_split( 56 | split_name, 57 | dataset_dir, 58 | file_pattern, 59 | reader) 60 | -------------------------------------------------------------------------------- /slim/datasets/dataset_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains utilities for downloading and converting datasets.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import os 21 | import sys 22 | import tarfile 23 | 24 | from six.moves import urllib 25 | import tensorflow as tf 26 | 27 | LABELS_FILENAME = 'labels.txt' 28 | 29 | 30 | def int64_feature(values): 31 | """Returns a TF-Feature of int64s. 32 | 33 | Args: 34 | values: A scalar or list of values. 35 | 36 | Returns: 37 | A TF-Feature. 38 | """ 39 | if not isinstance(values, (tuple, list)): 40 | values = [values] 41 | return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) 42 | 43 | 44 | def bytes_feature(values): 45 | """Returns a TF-Feature of bytes. 46 | 47 | Args: 48 | values: A string. 49 | 50 | Returns: 51 | A TF-Feature. 52 | """ 53 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) 54 | 55 | 56 | def float_feature(values): 57 | """Returns a TF-Feature of floats. 58 | 59 | Args: 60 | values: A scalar of list of values. 61 | 62 | Returns: 63 | A TF-Feature. 64 | """ 65 | if not isinstance(values, (tuple, list)): 66 | values = [values] 67 | return tf.train.Feature(float_list=tf.train.FloatList(value=values)) 68 | 69 | 70 | def image_to_tfexample(image_data, image_format, height, width, class_id): 71 | return tf.train.Example(features=tf.train.Features(feature={ 72 | 'image/encoded': bytes_feature(image_data), 73 | 'image/format': bytes_feature(image_format), 74 | 'image/class/label': int64_feature(class_id), 75 | 'image/height': int64_feature(height), 76 | 'image/width': int64_feature(width), 77 | })) 78 | 79 | 80 | def download_and_uncompress_tarball(tarball_url, dataset_dir): 81 | """Downloads the `tarball_url` and uncompresses it locally. 82 | 83 | Args: 84 | tarball_url: The URL of a tarball file. 85 | dataset_dir: The directory where the temporary files are stored. 
86 | """ 87 | filename = tarball_url.split('/')[-1] 88 | filepath = os.path.join(dataset_dir, filename) 89 | 90 | def _progress(count, block_size, total_size): 91 | sys.stdout.write('\r>> Downloading %s %.1f%%' % ( 92 | filename, float(count * block_size) / float(total_size) * 100.0)) 93 | sys.stdout.flush() 94 | filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress) 95 | print() 96 | statinfo = os.stat(filepath) 97 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 98 | tarfile.open(filepath, 'r:gz').extractall(dataset_dir) 99 | 100 | 101 | def write_label_file(labels_to_class_names, dataset_dir, 102 | filename=LABELS_FILENAME): 103 | """Writes a file with the list of class names. 104 | 105 | Args: 106 | labels_to_class_names: A map of (integer) labels to class names. 107 | dataset_dir: The directory in which the labels file should be written. 108 | filename: The filename where the class names are written. 109 | """ 110 | labels_filename = os.path.join(dataset_dir, filename) 111 | with tf.gfile.Open(labels_filename, 'w') as f: 112 | for label in labels_to_class_names: 113 | class_name = labels_to_class_names[label] 114 | f.write('%d:%s\n' % (label, class_name)) 115 | 116 | 117 | def has_labels(dataset_dir, filename=LABELS_FILENAME): 118 | """Specifies whether or not the dataset directory contains a label map file. 119 | 120 | Args: 121 | dataset_dir: The directory in which the labels file is found. 122 | filename: The filename where the class names are written. 123 | 124 | Returns: 125 | `True` if the labels file exists and `False` otherwise. 126 | """ 127 | return tf.gfile.Exists(os.path.join(dataset_dir, filename)) 128 | 129 | 130 | def read_label_file(dataset_dir, filename=LABELS_FILENAME): 131 | """Reads the labels file and returns a mapping from ID to class name. 132 | 133 | Args: 134 | dataset_dir: The directory in which the labels file is found. 135 | filename: The filename where the class names are written. 136 | 137 | Returns: 138 | A map from a label (integer) to class name. 139 | """ 140 | labels_filename = os.path.join(dataset_dir, filename) 141 | with tf.gfile.Open(labels_filename, 'rb') as f: 142 | lines = f.read().decode() 143 | lines = lines.split('\n') 144 | lines = filter(None, lines) 145 | 146 | labels_to_class_names = {} 147 | for line in lines: 148 | index = line.index(':') 149 | labels_to_class_names[int(line[:index])] = line[index+1:] 150 | return labels_to_class_names 151 | -------------------------------------------------------------------------------- /slim/datasets/download_and_convert_cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | r"""Downloads and converts cifar10 data to TFRecords of TF-Example protos. 
16 | 17 | This module downloads the cifar10 data, uncompresses it, reads the files 18 | that make up the cifar10 data and creates two TFRecord datasets: one for train 19 | and one for test. Each TFRecord dataset is comprised of a set of TF-Example 20 | protocol buffers, each of which contain a single image and label. 21 | 22 | The script should take several minutes to run. 23 | 24 | """ 25 | from __future__ import absolute_import 26 | from __future__ import division 27 | from __future__ import print_function 28 | 29 | import os 30 | import sys 31 | import tarfile 32 | 33 | import numpy as np 34 | from six.moves import cPickle 35 | from six.moves import urllib 36 | import tensorflow as tf 37 | 38 | from datasets import dataset_utils 39 | 40 | # The URL where the CIFAR data can be downloaded. 41 | _DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' 42 | 43 | # The number of training files. 44 | _NUM_TRAIN_FILES = 5 45 | 46 | # The height and width of each image. 47 | _IMAGE_SIZE = 32 48 | 49 | # The names of the classes. 50 | _CLASS_NAMES = [ 51 | 'airplane', 52 | 'automobile', 53 | 'bird', 54 | 'cat', 55 | 'deer', 56 | 'dog', 57 | 'frog', 58 | 'horse', 59 | 'ship', 60 | 'truck', 61 | ] 62 | 63 | 64 | def _add_to_tfrecord(filename, tfrecord_writer, offset=0): 65 | """Loads data from the cifar10 pickle files and writes files to a TFRecord. 66 | 67 | Args: 68 | filename: The filename of the cifar10 pickle file. 69 | tfrecord_writer: The TFRecord writer to use for writing. 70 | offset: An offset into the absolute number of images previously written. 71 | 72 | Returns: 73 | The new offset. 74 | """ 75 | with tf.gfile.Open(filename, 'rb') as f: 76 | if sys.version_info < (3,): 77 | data = cPickle.load(f) 78 | else: 79 | data = cPickle.load(f, encoding='bytes') 80 | 81 | images = data[b'data'] 82 | num_images = images.shape[0] 83 | 84 | images = images.reshape((num_images, 3, 32, 32)) 85 | labels = data[b'labels'] 86 | 87 | with tf.Graph().as_default(): 88 | image_placeholder = tf.placeholder(dtype=tf.uint8) 89 | encoded_image = tf.image.encode_png(image_placeholder) 90 | 91 | with tf.Session('') as sess: 92 | 93 | for j in range(num_images): 94 | sys.stdout.write('\r>> Reading file [%s] image %d/%d' % ( 95 | filename, offset + j + 1, offset + num_images)) 96 | sys.stdout.flush() 97 | 98 | image = np.squeeze(images[j]).transpose((1, 2, 0)) 99 | label = labels[j] 100 | 101 | png_string = sess.run(encoded_image, 102 | feed_dict={image_placeholder: image}) 103 | 104 | example = dataset_utils.image_to_tfexample( 105 | png_string, b'png', _IMAGE_SIZE, _IMAGE_SIZE, label) 106 | tfrecord_writer.write(example.SerializeToString()) 107 | 108 | return offset + num_images 109 | 110 | 111 | def _get_output_filename(dataset_dir, split_name): 112 | """Creates the output filename. 113 | 114 | Args: 115 | dataset_dir: The dataset directory where the dataset is stored. 116 | split_name: The name of the train/test split. 117 | 118 | Returns: 119 | An absolute file path. 120 | """ 121 | return '%s/cifar10_%s.tfrecord' % (dataset_dir, split_name) 122 | 123 | 124 | def _download_and_uncompress_dataset(dataset_dir): 125 | """Downloads cifar10 and uncompresses it locally. 126 | 127 | Args: 128 | dataset_dir: The directory where the temporary files are stored. 
129 | """ 130 | filename = _DATA_URL.split('/')[-1] 131 | filepath = os.path.join(dataset_dir, filename) 132 | 133 | if not os.path.exists(filepath): 134 | def _progress(count, block_size, total_size): 135 | sys.stdout.write('\r>> Downloading %s %.1f%%' % ( 136 | filename, float(count * block_size) / float(total_size) * 100.0)) 137 | sys.stdout.flush() 138 | filepath, _ = urllib.request.urlretrieve(_DATA_URL, filepath, _progress) 139 | print() 140 | statinfo = os.stat(filepath) 141 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 142 | tarfile.open(filepath, 'r:gz').extractall(dataset_dir) 143 | 144 | 145 | def _clean_up_temporary_files(dataset_dir): 146 | """Removes temporary files used to create the dataset. 147 | 148 | Args: 149 | dataset_dir: The directory where the temporary files are stored. 150 | """ 151 | filename = _DATA_URL.split('/')[-1] 152 | filepath = os.path.join(dataset_dir, filename) 153 | tf.gfile.Remove(filepath) 154 | 155 | tmp_dir = os.path.join(dataset_dir, 'cifar-10-batches-py') 156 | tf.gfile.DeleteRecursively(tmp_dir) 157 | 158 | 159 | def run(dataset_dir): 160 | """Runs the download and conversion operation. 161 | 162 | Args: 163 | dataset_dir: The dataset directory where the dataset is stored. 164 | """ 165 | if not tf.gfile.Exists(dataset_dir): 166 | tf.gfile.MakeDirs(dataset_dir) 167 | 168 | training_filename = _get_output_filename(dataset_dir, 'train') 169 | testing_filename = _get_output_filename(dataset_dir, 'test') 170 | 171 | if tf.gfile.Exists(training_filename) and tf.gfile.Exists(testing_filename): 172 | print('Dataset files already exist. Exiting without re-creating them.') 173 | return 174 | 175 | dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir) 176 | 177 | # First, process the training data: 178 | with tf.python_io.TFRecordWriter(training_filename) as tfrecord_writer: 179 | offset = 0 180 | for i in range(_NUM_TRAIN_FILES): 181 | filename = os.path.join(dataset_dir, 182 | 'cifar-10-batches-py', 183 | 'data_batch_%d' % (i + 1)) # 1-indexed. 184 | offset = _add_to_tfrecord(filename, tfrecord_writer, offset) 185 | 186 | # Next, process the testing data: 187 | with tf.python_io.TFRecordWriter(testing_filename) as tfrecord_writer: 188 | filename = os.path.join(dataset_dir, 189 | 'cifar-10-batches-py', 190 | 'test_batch') 191 | _add_to_tfrecord(filename, tfrecord_writer) 192 | 193 | # Finally, write the labels file: 194 | labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES)) 195 | dataset_utils.write_label_file(labels_to_class_names, dataset_dir) 196 | 197 | _clean_up_temporary_files(dataset_dir) 198 | print('\nFinished converting the Cifar10 dataset!') 199 | -------------------------------------------------------------------------------- /slim/datasets/download_and_convert_imagenet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2016 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | # Script to download and preprocess ImageNet Challenge 2012 18 | # training and validation data set. 19 | # 20 | # The final output of this script are sharded TFRecord files containing 21 | # serialized Example protocol buffers. See build_imagenet_data.py for 22 | # details of how the Example protocol buffers contain the ImageNet data. 23 | # 24 | # The final output of this script appears as such: 25 | # 26 | # data_dir/train-00000-of-01024 27 | # data_dir/train-00001-of-01024 28 | # ... 29 | # data_dir/train-00127-of-01024 30 | # 31 | # and 32 | # 33 | # data_dir/validation-00000-of-00128 34 | # data_dir/validation-00001-of-00128 35 | # ... 36 | # data_dir/validation-00127-of-00128 37 | # 38 | # Note that this script may take several hours to run to completion. The 39 | # conversion of the ImageNet data to TFRecords alone takes 2-3 hours depending 40 | # on the speed of your machine. Please be patient. 41 | # 42 | # **IMPORTANT** 43 | # To download the raw images, the user must create an account with image-net.org 44 | # and generate a username and access_key. The latter two are required for 45 | # downloading the raw images. 46 | # 47 | # usage: 48 | # cd research/slim 49 | # bazel build :download_and_convert_imagenet 50 | # ./bazel-bin/download_and_convert_imagenet.sh [data-dir] 51 | set -e 52 | 53 | if [ -z "$1" ]; then 54 | echo "usage download_and_convert_imagenet.sh [data dir]" 55 | exit 56 | fi 57 | 58 | # Create the output and temporary directories. 59 | DATA_DIR="${1%/}" 60 | SCRATCH_DIR="${DATA_DIR}/raw-data/" 61 | mkdir -p "${DATA_DIR}" 62 | mkdir -p "${SCRATCH_DIR}" 63 | WORK_DIR="$0.runfiles/__main__" 64 | 65 | # Download the ImageNet data. 66 | LABELS_FILE="${WORK_DIR}/datasets/imagenet_lsvrc_2015_synsets.txt" 67 | DOWNLOAD_SCRIPT="${WORK_DIR}/datasets/download_imagenet.sh" 68 | "${DOWNLOAD_SCRIPT}" "${SCRATCH_DIR}" "${LABELS_FILE}" 69 | 70 | # Note the locations of the train and validation data. 71 | TRAIN_DIRECTORY="${SCRATCH_DIR}train/" 72 | VALIDATION_DIRECTORY="${SCRATCH_DIR}validation/" 73 | 74 | # Preprocess the validation data by moving the images into the appropriate 75 | # sub-directory based on the label (synset) of the image. 76 | echo "Organizing the validation data into sub-directories." 77 | PREPROCESS_VAL_SCRIPT="${WORK_DIR}/datasets/preprocess_imagenet_validation_data.py" 78 | VAL_LABELS_FILE="${WORK_DIR}/datasets/imagenet_2012_validation_synset_labels.txt" 79 | 80 | "${PREPROCESS_VAL_SCRIPT}" "${VALIDATION_DIRECTORY}" "${VAL_LABELS_FILE}" 81 | 82 | # Convert the XML files for bounding box annotations into a single CSV. 83 | echo "Extracting bounding box information from XML." 84 | BOUNDING_BOX_SCRIPT="${WORK_DIR}/datasets/process_bounding_boxes.py" 85 | BOUNDING_BOX_FILE="${SCRATCH_DIR}/imagenet_2012_bounding_boxes.csv" 86 | BOUNDING_BOX_DIR="${SCRATCH_DIR}bounding_boxes/" 87 | 88 | "${BOUNDING_BOX_SCRIPT}" "${BOUNDING_BOX_DIR}" "${LABELS_FILE}" \ 89 | | sort >"${BOUNDING_BOX_FILE}" 90 | echo "Finished downloading and preprocessing the ImageNet data." 91 | 92 | # Build the TFRecords version of the ImageNet data. 
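# The build step below shards the raw images into the TFRecord files described
# in the header of this script (train-xxxxx-of-01024 and
# validation-xxxxx-of-00128), using the synset labels file, the ImageNet
# metadata file and the bounding-box CSV produced above.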
93 | BUILD_SCRIPT="${WORK_DIR}/build_imagenet_data" 94 | OUTPUT_DIRECTORY="${DATA_DIR}" 95 | IMAGENET_METADATA_FILE="${WORK_DIR}/datasets/imagenet_metadata.txt" 96 | 97 | "${BUILD_SCRIPT}" \ 98 | --train_directory="${TRAIN_DIRECTORY}" \ 99 | --validation_directory="${VALIDATION_DIRECTORY}" \ 100 | --output_directory="${OUTPUT_DIRECTORY}" \ 101 | --imagenet_metadata_file="${IMAGENET_METADATA_FILE}" \ 102 | --labels_file="${LABELS_FILE}" \ 103 | --bounding_box_file="${BOUNDING_BOX_FILE}" 104 | -------------------------------------------------------------------------------- /slim/datasets/download_imagenet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2016 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | # Script to download ImageNet Challenge 2012 training and validation data set. 18 | # 19 | # Downloads and decompresses raw images and bounding boxes. 20 | # 21 | # **IMPORTANT** 22 | # To download the raw images, the user must create an account with image-net.org 23 | # and generate a username and access_key. The latter two are required for 24 | # downloading the raw images. 25 | # 26 | # usage: 27 | # ./download_imagenet.sh [dirname] 28 | set -e 29 | 30 | if [ "x$IMAGENET_ACCESS_KEY" == x -o "x$IMAGENET_USERNAME" == x ]; then 31 | cat < ') 61 | sys.exit(-1) 62 | data_dir = sys.argv[1] 63 | validation_labels_file = sys.argv[2] 64 | 65 | # Read in the 50000 synsets associated with the validation data set. 66 | labels = [l.strip() for l in open(validation_labels_file).readlines()] 67 | unique_labels = set(labels) 68 | 69 | # Make all sub-directories in the validation data dir. 70 | for label in unique_labels: 71 | labeled_data_dir = os.path.join(data_dir, label) 72 | os.makedirs(labeled_data_dir) 73 | 74 | # Move all of the image to the appropriate sub-directory. 75 | for i in xrange(len(labels)): 76 | basename = 'ILSVRC2012_val_000%.5d.JPEG' % (i + 1) 77 | original_filename = os.path.join(data_dir, basename) 78 | if not os.path.exists(original_filename): 79 | print('Failed to find: ' % original_filename) 80 | sys.exit(-1) 81 | new_filename = os.path.join(data_dir, labels[i], basename) 82 | os.rename(original_filename, new_filename) 83 | -------------------------------------------------------------------------------- /slim/datasets/wikiart.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides data for the wikiart dataset. 16 | 17 | The dataset scripts used to create the dataset can be found at: 18 | datasets/convert_wikiart.py 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import os 26 | import tensorflow as tf 27 | 28 | from datasets import dataset_utils 29 | 30 | slim = tf.contrib.slim 31 | 32 | _FILE_PATTERN = 'wikiart_%s_*.tfrecord' 33 | 34 | SPLITS_TO_SIZES = {'train': 65445, 'validation': 8000, 'test': 8000} 35 | 36 | _NUM_CLASSES = 27 37 | 38 | _ITEMS_TO_DESCRIPTIONS = { 39 | 'image': 'A color image of varying size.', 40 | 'label': 'A single integer between 0 and 26', 41 | } 42 | 43 | 44 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None): 45 | """Gets a dataset tuple with instructions for reading wikiart. 46 | 47 | Args: 48 | split_name: A train/validation split name. 49 | dataset_dir: The base directory of the dataset sources. 50 | file_pattern: The file pattern to use when matching the dataset sources. 51 | It is assumed that the pattern contains a '%s' string so that the split 52 | name can be inserted. 53 | reader: The TensorFlow reader type. 54 | 55 | Returns: 56 | A `Dataset` namedtuple. 57 | 58 | Raises: 59 | ValueError: if `split_name` is not a valid train/validation split. 60 | """ 61 | if split_name not in SPLITS_TO_SIZES: 62 | raise ValueError('split name %s was not recognized.' % split_name) 63 | 64 | if not file_pattern: 65 | file_pattern = _FILE_PATTERN 66 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name) 67 | 68 | # Allowing None in the signature so that dataset_factory can use the default.
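  # (dataset_factory.get_dataset() relies on this default, so the wikiart
  # TFRecords are read with the plain tf.TFRecordReader chosen below.)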
69 | if reader is None: 70 | reader = tf.TFRecordReader 71 | 72 | keys_to_features = { 73 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 74 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'), 75 | 'image/class/label': tf.FixedLenFeature( 76 | [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)), 77 | } 78 | 79 | items_to_handlers = { 80 | 'image': slim.tfexample_decoder.Image(), 81 | 'label': slim.tfexample_decoder.Tensor('image/class/label'), 82 | } 83 | 84 | decoder = slim.tfexample_decoder.TFExampleDecoder( 85 | keys_to_features, items_to_handlers) 86 | 87 | labels_to_names = None 88 | if dataset_utils.has_labels(dataset_dir): 89 | labels_to_names = dataset_utils.read_label_file(dataset_dir) 90 | 91 | return slim.dataset.Dataset( 92 | data_sources=file_pattern, 93 | reader=reader, 94 | decoder=decoder, 95 | num_samples=SPLITS_TO_SIZES[split_name], 96 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, 97 | num_classes=_NUM_CLASSES, 98 | labels_to_names=labels_to_names) 99 | -------------------------------------------------------------------------------- /slim/deployment/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /slim/download_and_convert_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | r"""Downloads and converts a particular dataset. 
16 | 17 | Usage: 18 | ```shell 19 | 20 | $ python download_and_convert_data.py \ 21 | --dataset_name=mnist \ 22 | --dataset_dir=/tmp/mnist 23 | 24 | $ python download_and_convert_data.py \ 25 | --dataset_name=cifar10 \ 26 | --dataset_dir=/tmp/cifar10 27 | 28 | $ python download_and_convert_data.py \ 29 | --dataset_name=flowers \ 30 | --dataset_dir=/tmp/flowers 31 | ``` 32 | """ 33 | from __future__ import absolute_import 34 | from __future__ import division 35 | from __future__ import print_function 36 | 37 | import tensorflow as tf 38 | 39 | from datasets import download_and_convert_cifar10 40 | from datasets import download_and_convert_flowers 41 | from datasets import download_and_convert_mnist 42 | from datasets import convert_wikiart 43 | 44 | FLAGS = tf.app.flags.FLAGS 45 | 46 | tf.app.flags.DEFINE_string( 47 | 'dataset_name', 48 | None, 49 | 'The name of the dataset to convert, one of "cifar10", "flowers", "mnist".') 50 | 51 | tf.app.flags.DEFINE_string( 52 | 'dataset_dir', 53 | None, 54 | 'The directory where the output TFRecords and temporary files are saved.') 55 | tf.app.flags.DEFINE_string( 56 | 'input_dataset_dir', 57 | None, 58 | 'The input directory where the images are stored') 59 | 60 | 61 | def main(_): 62 | if not FLAGS.dataset_name: 63 | raise ValueError('You must supply the dataset name with --dataset_name') 64 | if not FLAGS.dataset_dir: 65 | raise ValueError('You must supply the dataset directory with --dataset_dir') 66 | 67 | if FLAGS.dataset_name == 'cifar10': 68 | download_and_convert_cifar10.run(FLAGS.dataset_dir) 69 | elif FLAGS.dataset_name == 'flowers': 70 | download_and_convert_flowers.run(FLAGS.dataset_dir) 71 | elif FLAGS.dataset_name == 'mnist': 72 | download_and_convert_mnist.run(FLAGS.dataset_dir) 73 | elif FLAGS.dataset_name == 'wikiart': 74 | if not FLAGS.input_dataset_dir is None: 75 | convert_wikiart.run(FLAGS.input_dataset_dir, FLAGS.dataset_dir) 76 | 77 | else: 78 | raise ValueError("For wikiart, you must supply a valid input directory with --input_dataset_dir") 79 | else: 80 | raise ValueError( 81 | 'dataset_name [%s] was not recognized.' % FLAGS.dataset_name) 82 | 83 | if __name__ == '__main__': 84 | tf.app.run() 85 | -------------------------------------------------------------------------------- /slim/eval_image_classifier.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Generic evaluation script that evaluates a model using a given dataset.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import math 22 | import tensorflow as tf 23 | 24 | from datasets import dataset_factory 25 | from nets import nets_factory 26 | from preprocessing import preprocessing_factory 27 | 28 | slim = tf.contrib.slim 29 | 30 | tf.app.flags.DEFINE_integer( 31 | 'batch_size', 50, 'The number of samples in each batch.') 32 | 33 | tf.app.flags.DEFINE_integer( 34 | 'max_num_batches', None, 35 | 'Max number of batches to evaluate by default use all.') 36 | 37 | tf.app.flags.DEFINE_string( 38 | 'master', '', 'The address of the TensorFlow master to use.') 39 | 40 | tf.app.flags.DEFINE_string( 41 | 'checkpoint_path', '/tmp/tfmodel/', 42 | 'The directory where the model was written to or an absolute path to a ' 43 | 'checkpoint file.') 44 | 45 | tf.app.flags.DEFINE_string( 46 | 'eval_dir', '/tmp/tfmodel/', 'Directory where the results are saved to.') 47 | 48 | tf.app.flags.DEFINE_integer( 49 | 'num_preprocessing_threads', 4, 50 | 'The number of threads used to create the batches.') 51 | 52 | tf.app.flags.DEFINE_string( 53 | 'dataset_name', 'imagenet', 'The name of the dataset to load.') 54 | 55 | tf.app.flags.DEFINE_string( 56 | 'dataset_split_name', 'test', 'The name of the train/test split.') 57 | 58 | tf.app.flags.DEFINE_string( 59 | 'dataset_dir', None, 'The directory where the dataset files are stored.') 60 | 61 | tf.app.flags.DEFINE_integer( 62 | 'labels_offset', 0, 63 | 'An offset for the labels in the dataset. This flag is primarily used to ' 64 | 'evaluate the VGG and ResNet architectures which do not use a background ' 65 | 'class for the ImageNet dataset.') 66 | 67 | tf.app.flags.DEFINE_string( 68 | 'model_name', 'inception_v3', 'The name of the architecture to evaluate.') 69 | 70 | tf.app.flags.DEFINE_string( 71 | 'preprocessing_name', None, 'The name of the preprocessing to use. If left ' 72 | 'as `None`, then the model_name flag is used.') 73 | 74 | tf.app.flags.DEFINE_float( 75 | 'moving_average_decay', None, 76 | 'The decay to use for the moving average.' 
77 | 'If left as None, then moving averages are not used.') 78 | 79 | tf.app.flags.DEFINE_integer( 80 | 'eval_image_size', None, 'Eval image size') 81 | 82 | FLAGS = tf.app.flags.FLAGS 83 | 84 | 85 | def main(_): 86 | if not FLAGS.dataset_dir: 87 | raise ValueError('You must supply the dataset directory with --dataset_dir') 88 | 89 | tf.logging.set_verbosity(tf.logging.INFO) 90 | with tf.Graph().as_default(): 91 | tf_global_step = slim.get_or_create_global_step() 92 | 93 | ###################### 94 | # Select the dataset # 95 | ###################### 96 | dataset = dataset_factory.get_dataset( 97 | FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir) 98 | 99 | #################### 100 | # Select the model # 101 | #################### 102 | network_fn = nets_factory.get_network_fn( 103 | FLAGS.model_name, 104 | num_classes=(dataset.num_classes - FLAGS.labels_offset), 105 | is_training=False) 106 | 107 | ############################################################## 108 | # Create a dataset provider that loads data from the dataset # 109 | ############################################################## 110 | provider = slim.dataset_data_provider.DatasetDataProvider( 111 | dataset, 112 | shuffle=False, 113 | common_queue_capacity=2 * FLAGS.batch_size, 114 | common_queue_min=FLAGS.batch_size) 115 | [image, label] = provider.get(['image', 'label']) 116 | label -= FLAGS.labels_offset 117 | 118 | ##################################### 119 | # Select the preprocessing function # 120 | ##################################### 121 | preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name 122 | image_preprocessing_fn = preprocessing_factory.get_preprocessing( 123 | preprocessing_name, 124 | is_training=False) 125 | 126 | eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size 127 | 128 | image = image_preprocessing_fn(image, eval_image_size, eval_image_size) 129 | 130 | images, labels = tf.train.batch( 131 | [image, label], 132 | batch_size=FLAGS.batch_size, 133 | num_threads=FLAGS.num_preprocessing_threads, 134 | capacity=5 * FLAGS.batch_size) 135 | 136 | #################### 137 | # Define the model # 138 | #################### 139 | logits, _ = network_fn(images) 140 | 141 | if FLAGS.moving_average_decay: 142 | variable_averages = tf.train.ExponentialMovingAverage( 143 | FLAGS.moving_average_decay, tf_global_step) 144 | variables_to_restore = variable_averages.variables_to_restore( 145 | slim.get_model_variables()) 146 | variables_to_restore[tf_global_step.op.name] = tf_global_step 147 | else: 148 | variables_to_restore = slim.get_variables_to_restore() 149 | 150 | predictions = tf.argmax(logits, 1) 151 | labels = tf.squeeze(labels) 152 | 153 | # Define the metrics: 154 | names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({ 155 | 'Accuracy': slim.metrics.streaming_accuracy(predictions, labels), 156 | 'Recall_5': slim.metrics.streaming_recall_at_k( 157 | logits, labels, 5), 158 | }) 159 | 160 | # Print the summaries to screen. 161 | for name, value in names_to_values.items(): 162 | summary_name = 'eval/%s' % name 163 | op = tf.summary.scalar(summary_name, value, collections=[]) 164 | op = tf.Print(op, [value], summary_name) 165 | tf.add_to_collection(tf.GraphKeys.SUMMARIES, op) 166 | 167 | # TODO(sguada) use num_epochs=1 168 | if FLAGS.max_num_batches: 169 | num_batches = FLAGS.max_num_batches 170 | else: 171 | # This ensures that we make a single pass over all of the data. 
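      # For example, with the wikiart validation split (8,000 examples, see
      # datasets/wikiart.py) and the default batch_size of 50, this evaluates
      # ceil(8000 / 50) = 160 batches.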
172 | num_batches = math.ceil(dataset.num_samples / float(FLAGS.batch_size)) 173 | 174 | if tf.gfile.IsDirectory(FLAGS.checkpoint_path): 175 | checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path) 176 | else: 177 | checkpoint_path = FLAGS.checkpoint_path 178 | 179 | tf.logging.info('Evaluating %s' % checkpoint_path) 180 | 181 | slim.evaluation.evaluate_once( 182 | master=FLAGS.master, 183 | checkpoint_path=checkpoint_path, 184 | logdir=FLAGS.eval_dir, 185 | num_evals=num_batches, 186 | eval_op=list(names_to_updates.values()), 187 | variables_to_restore=variables_to_restore) 188 | 189 | 190 | if __name__ == '__main__': 191 | tf.app.run() 192 | -------------------------------------------------------------------------------- /slim/eval_wikiart_cpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the wikiart dataset 19 | # 2. Fine-tunes an Inception Resnet V2 model on the wikiart training set. 20 | # 3. Evaluates the model on the wikiart validation set. 21 | # 22 | # Usage: 23 | # cd slim 24 | # ./slim/scripts/finetune_inception_resnet_v2_on_wikiart.sh 25 | export CUDA_VISIBLE_DEVICES=1 26 | set -e 27 | 28 | # Where the pre-trained Inception Resnet V2 checkpoint is saved to. 29 | PRETRAINED_CHECKPOINT_DIR=logs/pretrained 30 | # Where the pre-trained Inception Resnet V2 checkpoint is saved to. 31 | MODEL_NAME=inception_resnet_v2 32 | 33 | # Where the training (fine-tuned) checkpoint and logs will be saved to. 34 | TRAIN_DIR=logs/wikiart/${MODEL_NAME} 35 | 36 | # Where the dataset is saved to. 37 | INPUT_DATASET_DIR=/data/wikiart/ 38 | DATASET_DIR=/data/wikiart-records 39 | 40 | # Download the pre-trained checkpoint. 41 | if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then 42 | mkdir -p ${PRETRAINED_CHECKPOINT_DIR} 43 | fi 44 | if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/${MODEL_NAME}.ckpt ]; then 45 | wget http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz 46 | tar -xvf inception_resnet_v2_2016_08_30.tar.gz 47 | mv inception_resnet_v2_2016_08_30.ckpt ${PRETRAINED_CHECKPOINT_DIR}/${MODEL_NAME}.ckpt 48 | rm inception_resnet_v2_2016_08_30.tar.gz 49 | fi 50 | 51 | # # Download the dataset 52 | # python download_and_convert_data.py \ 53 | # --dataset_name=wikiart \ 54 | # --dataset_dir=${DATASET_DIR} 55 | # --input_dataset_dir=${INPUT_DATASET_DIR} 56 | 57 | # @philkuz I use this to create a nice initialization - haven't tried random 58 | # TODO try out if your'e curious to see whether random initialization of last 59 | # layer makes sense in this case. 60 | # Fine-tune only the new layers for 1000 steps. 
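# The two training stages below are left commented out, so running this script
# as-is only performs the final evaluation at the bottom. Stage 1 warm-starts
# from the ImageNet checkpoint and trains only the new
# InceptionResnetV2/Logits and AuxLogits layers (--trainable_scopes); stage 2
# further down then fine-tunes all layers starting from that initialization.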
61 | # python3 train_image_classifier.py \ 62 | # --train_dir=${TRAIN_DIR} \ 63 | # --dataset_name=wikiart \ 64 | # --dataset_split_name=train \ 65 | # --dataset_dir=${DATASET_DIR} \ 66 | # --model_name=${MODEL_NAME} \ 67 | # --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/${MODEL_NAME}.ckpt \ 68 | # --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \ 69 | # --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \ 70 | # --max_number_of_steps=10000 \ 71 | # --batch_size=32 \ 72 | # --learning_rate=0.01 \ 73 | # --learning_rate_decay_type=fixed \ 74 | # --save_interval_secs=300 \ 75 | # --save_summaries_secs=60 \ 76 | # --log_every_n_steps=200 \ 77 | # --optimizer=rmsprop \ 78 | # --train_image_size=256 \ 79 | # --weight_decay=0.00004 80 | 81 | # Run evaluation. 82 | # python3 eval_image_classifier.py \ 83 | # --checkpoint_path=${TRAIN_DIR} \ 84 | # --eval_dir=${TRAIN_DIR} \ 85 | # --dataset_name=wikiart \ 86 | # --dataset_split_name=validation \ 87 | # --dataset_dir=${DATASET_DIR} \ 88 | # --model_name=${MODEL_NAME} \ 89 | # --eval_image_size=256 90 | 91 | # Fine-tune all the new layers for 500 steps. 92 | NUM_EPOCHS=2 93 | BATCH_SIZE=16 94 | EXPERIMENT_NAME=smol_adam 95 | LR=0.0001 \ 96 | 97 | # python3 train_image_classifier.py \ 98 | # --train_dir=${TRAIN_DIR}/all \ 99 | # --dataset_name=wikiart \ 100 | # --dataset_split_name=train \ 101 | # --dataset_dir=${DATASET_DIR} \ 102 | # --model_name=${MODEL_NAME} \ 103 | # --checkpoint_path=${TRAIN_DIR} \ 104 | # --batch_size=${BATCH_SIZE} \ 105 | # --learning_rate=${LR} \ 106 | # --learning_rate_decay_type=exponential \ 107 | # --save_interval_secs=300 \ 108 | # --save_summaries_secs=60 \ 109 | # --num_epochs_per_decay=0.2 \ 110 | # --log_every_n_steps=200 \ 111 | # --optimizer=adam \ 112 | # --weight_decay=0.00004 \ 113 | # --experiment_name=${EXPERIMENT_NAME} \ 114 | # --num_epochs=${NUM_EPOCHS} \ 115 | # --train_image_size=256 \ 116 | # --continue_training False \ 117 | # --experiment_numbering # TODO flag to flip on experiment numbering independent of experiement name arg existing 118 | # # TODO catch the naming convention 119 | 120 | # Run evaluation. 121 | EVAL_DIR=logs/wikiart/inception_resnet_v2/all/bs=${BATCH_SIZE},lr=${LR},epochs=${NUM_EPOCHS}/${EXPERIMENT_NAME} 122 | python3 eval_image_classifier.py \ 123 | --checkpoint_path=${EVAL_DIR} \ 124 | --eval_dir=${EVAL_DIR} \ 125 | --dataset_name=wikiart \ 126 | --dataset_split_name=validation \ 127 | --dataset_dir=${DATASET_DIR} \ 128 | --model_name=${MODEL_NAME} \ 129 | --eval_image_size=256 \ 130 | -------------------------------------------------------------------------------- /slim/export_inference_graph.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | r"""Saves out a GraphDef containing the architecture of the model. 16 | 17 | To use it, run something like this, with a model name defined by slim: 18 | 19 | bazel build tensorflow_models/research/slim:export_inference_graph 20 | bazel-bin/tensorflow_models/research/slim/export_inference_graph \ 21 | --model_name=inception_v3 --output_file=/tmp/inception_v3_inf_graph.pb 22 | 23 | If you then want to use the resulting model with your own or pretrained 24 | checkpoints as part of a mobile model, you can run freeze_graph to get a graph 25 | def with the variables inlined as constants using: 26 | 27 | bazel build tensorflow/python/tools:freeze_graph 28 | bazel-bin/tensorflow/python/tools/freeze_graph \ 29 | --input_graph=/tmp/inception_v3_inf_graph.pb \ 30 | --input_checkpoint=/tmp/checkpoints/inception_v3.ckpt \ 31 | --input_binary=true --output_graph=/tmp/frozen_inception_v3.pb \ 32 | --output_node_names=InceptionV3/Predictions/Reshape_1 33 | 34 | The output node names will vary depending on the model, but you can inspect and 35 | estimate them using the summarize_graph tool: 36 | 37 | bazel build tensorflow/tools/graph_transforms:summarize_graph 38 | bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \ 39 | --in_graph=/tmp/inception_v3_inf_graph.pb 40 | 41 | To run the resulting graph in C++, you can look at the label_image sample code: 42 | 43 | bazel build tensorflow/examples/label_image:label_image 44 | bazel-bin/tensorflow/examples/label_image/label_image \ 45 | --image=${HOME}/Pictures/flowers.jpg \ 46 | --input_layer=input \ 47 | --output_layer=InceptionV3/Predictions/Reshape_1 \ 48 | --graph=/tmp/frozen_inception_v3.pb \ 49 | --labels=/tmp/imagenet_slim_labels.txt \ 50 | --input_mean=0 \ 51 | --input_std=255 52 | 53 | """ 54 | 55 | from __future__ import absolute_import 56 | from __future__ import division 57 | from __future__ import print_function 58 | 59 | import tensorflow as tf 60 | 61 | from tensorflow.python.platform import gfile 62 | from datasets import dataset_factory 63 | from nets import nets_factory 64 | 65 | 66 | slim = tf.contrib.slim 67 | 68 | tf.app.flags.DEFINE_string( 69 | 'model_name', 'inception_v3', 'The name of the architecture to save.') 70 | 71 | tf.app.flags.DEFINE_boolean( 72 | 'is_training', False, 73 | 'Whether to save out a training-focused version of the model.') 74 | 75 | tf.app.flags.DEFINE_integer( 76 | 'image_size', None, 77 | 'The image size to use, otherwise use the model default_image_size.') 78 | 79 | tf.app.flags.DEFINE_integer( 80 | 'batch_size', None, 81 | 'Batch size for the exported model. Defaulted to "None" so batch size can ' 82 | 'be specified at model runtime.') 83 | 84 | tf.app.flags.DEFINE_string('dataset_name', 'imagenet', 85 | 'The name of the dataset to use with the model.') 86 | 87 | tf.app.flags.DEFINE_integer( 88 | 'labels_offset', 0, 89 | 'An offset for the labels in the dataset. 
This flag is primarily used to ' 90 | 'evaluate the VGG and ResNet architectures which do not use a background ' 91 | 'class for the ImageNet dataset.') 92 | 93 | tf.app.flags.DEFINE_string( 94 | 'output_file', '', 'Where to save the resulting file to.') 95 | 96 | tf.app.flags.DEFINE_string( 97 | 'dataset_dir', '', 'Directory to save intermediate dataset files to') 98 | 99 | FLAGS = tf.app.flags.FLAGS 100 | 101 | 102 | def main(_): 103 | if not FLAGS.output_file: 104 | raise ValueError('You must supply the path to save to with --output_file') 105 | tf.logging.set_verbosity(tf.logging.INFO) 106 | with tf.Graph().as_default() as graph: 107 | dataset = dataset_factory.get_dataset(FLAGS.dataset_name, 'train', 108 | FLAGS.dataset_dir) 109 | network_fn = nets_factory.get_network_fn( 110 | FLAGS.model_name, 111 | num_classes=(dataset.num_classes - FLAGS.labels_offset), 112 | is_training=FLAGS.is_training) 113 | image_size = FLAGS.image_size or network_fn.default_image_size 114 | placeholder = tf.placeholder(name='input', dtype=tf.float32, 115 | shape=[FLAGS.batch_size, image_size, 116 | image_size, 3]) 117 | network_fn(placeholder) 118 | graph_def = graph.as_graph_def() 119 | with gfile.GFile(FLAGS.output_file, 'wb') as f: 120 | f.write(graph_def.SerializeToString()) 121 | 122 | 123 | if __name__ == '__main__': 124 | tf.app.run() 125 | -------------------------------------------------------------------------------- /slim/export_inference_graph_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for export_inference_graph.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | 24 | 25 | import tensorflow as tf 26 | 27 | from tensorflow.python.platform import gfile 28 | import export_inference_graph 29 | 30 | 31 | class ExportInferenceGraphTest(tf.test.TestCase): 32 | 33 | def testExportInferenceGraph(self): 34 | tmpdir = self.get_temp_dir() 35 | output_file = os.path.join(tmpdir, 'inception_v3.pb') 36 | flags = tf.app.flags.FLAGS 37 | flags.output_file = output_file 38 | flags.model_name = 'inception_v3' 39 | flags.dataset_dir = tmpdir 40 | export_inference_graph.main(None) 41 | self.assertTrue(gfile.Exists(output_file)) 42 | 43 | if __name__ == '__main__': 44 | tf.test.main() 45 | -------------------------------------------------------------------------------- /slim/finetune_inception_resnet_v2_on_wikiart.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the wikiart dataset 19 | # 2. Fine-tunes an Inception Resnet V2 model on the wikiart training set. 20 | # 3. Evaluates the model on the wikiart validation set. 21 | # 22 | # Usage: 23 | # cd slim 24 | # ./slim/scripts/finetune_inception_resnet_v2_on_wikiart.sh 25 | export CUDA_VISIBLE_DEVICES=1 26 | set -e 27 | 28 | # Where the pre-trained Inception Resnet V2 checkpoint is saved to. 29 | PRETRAINED_CHECKPOINT_DIR=logs/pretrained 30 | # Where the pre-trained Inception Resnet V2 checkpoint is saved to. 31 | MODEL_NAME=inception_resnet_v2 32 | 33 | # Where the training (fine-tuned) checkpoint and logs will be saved to. 34 | TRAIN_DIR=logs/wikiart/${MODEL_NAME} 35 | 36 | # Where the dataset is saved to. 37 | INPUT_DATASET_DIR=/data/wikiart/ 38 | DATASET_DIR=/data/wikiart-records 39 | 40 | # Download the pre-trained checkpoint. 41 | if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then 42 | mkdir -p ${PRETRAINED_CHECKPOINT_DIR} 43 | fi 44 | if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/${MODEL_NAME}.ckpt ]; then 45 | wget http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz 46 | tar -xvf inception_resnet_v2_2016_08_30.tar.gz 47 | mv inception_resnet_v2_2016_08_30.ckpt ${PRETRAINED_CHECKPOINT_DIR}/${MODEL_NAME}.ckpt 48 | rm inception_resnet_v2_2016_08_30.tar.gz 49 | fi 50 | 51 | # # Download the dataset 52 | python download_and_convert_data.py \ 53 | --dataset_name=wikiart \ 54 | --dataset_dir=${DATASET_DIR} 55 | --input_dataset_dir=${INPUT_DATASET_DIR} 56 | 57 | # @philkuz I use this to create a nice initialization - haven't tried random 58 | # TODO try out if your'e curious to see whether random initialization of last 59 | # layer makes sense in this case. 60 | # Fine-tune only the last layer for 1000 steps. 61 | python3 train_image_classifier.py \ 62 | --train_dir=${TRAIN_DIR} \ 63 | --dataset_name=wikiart \ 64 | --dataset_split_name=train \ 65 | --dataset_dir=${DATASET_DIR} \ 66 | --model_name=${MODEL_NAME} \ 67 | --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/${MODEL_NAME}.ckpt \ 68 | --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \ 69 | --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \ 70 | --max_number_of_steps=10000 \ 71 | --batch_size=32 \ 72 | --learning_rate=0.01 \ 73 | --learning_rate_decay_type=fixed \ 74 | --save_interval_secs=300 \ 75 | --save_summaries_secs=60 \ 76 | --log_every_n_steps=200 \ 77 | --optimizer=rmsprop \ 78 | --train_image_size=256 \ 79 | --weight_decay=0.00004 80 | 81 | # Run evaluation. 
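# eval_image_classifier.py reports streaming top-1 Accuracy and Recall_5
# (recall@5) on the wikiart validation split and writes them as summaries under
# ${TRAIN_DIR}; when --checkpoint_path is a directory, the most recent
# checkpoint found in it is evaluated.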
82 | python3 eval_image_classifier.py \ 83 | --checkpoint_path=${TRAIN_DIR} \ 84 | --eval_dir=${TRAIN_DIR} \ 85 | --dataset_name=wikiart \ 86 | --dataset_split_name=validation \ 87 | --dataset_dir=${DATASET_DIR} \ 88 | --model_name=${MODEL_NAME} \ 89 | --eval_image_size=256 90 | 91 | # Fine-tune all the new layers for 500 steps. 92 | NUM_EPOCHS=100 93 | BATCH_SIZE=16 94 | EXPERIMENT_NAME=inception_resnet_v2 95 | LR=0.0001 \ 96 | 97 | TRAIN_DIR=logs/wikiart/inception_resnet_v2/experiments/${EXPERIMENT_NAME}/bs=${BATCH_SIZE},lr=${LR},epochs=${NUM_EPOCHS}/ 98 | 99 | python3 train_image_classifier.py \ 100 | --train_dir=${TRAIN_DIR}/all \ 101 | --dataset_name=wikiart \ 102 | --dataset_split_name=train \ 103 | --dataset_dir=${DATASET_DIR} \ 104 | --model_name=${MODEL_NAME} \ 105 | --checkpoint_path=${TRAIN_DIR} \ 106 | --batch_size=${BATCH_SIZE} \ 107 | --learning_rate=${LR} \ 108 | --learning_rate_decay_type=fixed \ 109 | --save_interval_secs=300 \ 110 | --save_summaries_secs=60 \ 111 | --num_epochs_per_decay=1 \ 112 | --log_every_n_steps=200 \ 113 | --optimizer=adam \ 114 | --weight_decay=0.00004 \ 115 | --experiment_name=${EXPERIMENT_NAME} \ 116 | --num_epochs=${NUM_EPOCHS} \ 117 | --train_image_size=256 \ 118 | --continue_training False \ 119 | # --experiment_numbering # TODO flag to flip on experiment numbering independent of experiement name arg existing 120 | # # TODO catch the naming convention 121 | 122 | # Run evaluation. 123 | EVAL_DIR=logs/wikiart/inception_resnet_v2/all/bs=${BATCH_SIZE},lr=${LR},epochs=${NUM_EPOCHS}/${EXPERIMENT_NAME} 124 | python3 eval_image_classifier.py \ 125 | --checkpoint_path=${EVAL_DIR} \ 126 | --eval_dir=${EVAL_DIR} \ 127 | --dataset_name=wikiart \ 128 | --dataset_split_name=validation \ 129 | --dataset_dir=${DATASET_DIR} \ 130 | --model_name=${MODEL_NAME} \ 131 | --eval_image_size=256 \ 132 | -------------------------------------------------------------------------------- /slim/nets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /slim/nets/alexnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a model definition for AlexNet. 16 | 17 | This work was first described in: 18 | ImageNet Classification with Deep Convolutional Neural Networks 19 | Alex Krizhevsky, Ilya Sutskever and Geoffrey E. Hinton 20 | 21 | and later refined in: 22 | One weird trick for parallelizing convolutional neural networks 23 | Alex Krizhevsky, 2014 24 | 25 | Here we provide the implementation proposed in "One weird trick" and not 26 | "ImageNet Classification", as per the paper, the LRN layers have been removed. 
27 | 28 | Usage: 29 | with slim.arg_scope(alexnet.alexnet_v2_arg_scope()): 30 | outputs, end_points = alexnet.alexnet_v2(inputs) 31 | 32 | @@alexnet_v2 33 | """ 34 | 35 | from __future__ import absolute_import 36 | from __future__ import division 37 | from __future__ import print_function 38 | 39 | import tensorflow as tf 40 | 41 | slim = tf.contrib.slim 42 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 43 | 44 | 45 | def alexnet_v2_arg_scope(weight_decay=0.0005): 46 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 47 | activation_fn=tf.nn.relu, 48 | biases_initializer=tf.constant_initializer(0.1), 49 | weights_regularizer=slim.l2_regularizer(weight_decay)): 50 | with slim.arg_scope([slim.conv2d], padding='SAME'): 51 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: 52 | return arg_sc 53 | 54 | 55 | def alexnet_v2(inputs, 56 | num_classes=1000, 57 | is_training=True, 58 | dropout_keep_prob=0.5, 59 | spatial_squeeze=True, 60 | scope='alexnet_v2', 61 | global_pool=False): 62 | """AlexNet version 2. 63 | 64 | Described in: http://arxiv.org/pdf/1404.5997v2.pdf 65 | Parameters from: 66 | github.com/akrizhevsky/cuda-convnet2/blob/master/layers/ 67 | layers-imagenet-1gpu.cfg 68 | 69 | Note: All the fully_connected layers have been transformed to conv2d layers. 70 | To use in classification mode, resize input to 224x224 or set 71 | global_pool=True. To use in fully convolutional mode, set 72 | spatial_squeeze to false. 73 | The LRN layers have been removed and change the initializers from 74 | random_normal_initializer to xavier_initializer. 75 | 76 | Args: 77 | inputs: a tensor of size [batch_size, height, width, channels]. 78 | num_classes: the number of predicted classes. If 0 or None, the logits layer 79 | is omitted and the input features to the logits layer are returned instead. 80 | is_training: whether or not the model is being trained. 81 | dropout_keep_prob: the probability that activations are kept in the dropout 82 | layers during training. 83 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 84 | logits. Useful to remove unnecessary dimensions for classification. 85 | scope: Optional scope for the variables. 86 | global_pool: Optional boolean flag. If True, the input to the classification 87 | layer is avgpooled to size 1x1, for any input size. (This is not part 88 | of the original AlexNet.) 89 | 90 | Returns: 91 | net: the output of the logits layer (if num_classes is a non-zero integer), 92 | or the non-dropped-out input to the logits layer (if num_classes is 0 93 | or None). 94 | end_points: a dict of tensors with intermediate activations. 95 | """ 96 | with tf.variable_scope(scope, 'alexnet_v2', [inputs]) as sc: 97 | end_points_collection = sc.original_name_scope + '_end_points' 98 | # Collect outputs for conv2d, fully_connected and max_pool2d. 
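    # Each layer built inside this arg_scope is also recorded in
    # end_points_collection; convert_collection_to_dict() below turns that
    # collection into the `end_points` dict returned to the caller, keyed by
    # the layers' scope names (e.g. something like 'alexnet_v2/conv1').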
99 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 100 | outputs_collections=[end_points_collection]): 101 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', 102 | scope='conv1') 103 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool1') 104 | net = slim.conv2d(net, 192, [5, 5], scope='conv2') 105 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool2') 106 | net = slim.conv2d(net, 384, [3, 3], scope='conv3') 107 | net = slim.conv2d(net, 384, [3, 3], scope='conv4') 108 | net = slim.conv2d(net, 256, [3, 3], scope='conv5') 109 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool5') 110 | 111 | # Use conv2d instead of fully_connected layers. 112 | with slim.arg_scope([slim.conv2d], 113 | weights_initializer=trunc_normal(0.005), 114 | biases_initializer=tf.constant_initializer(0.1)): 115 | net = slim.conv2d(net, 4096, [5, 5], padding='VALID', 116 | scope='fc6') 117 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 118 | scope='dropout6') 119 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 120 | # Convert end_points_collection into a end_point dict. 121 | end_points = slim.utils.convert_collection_to_dict( 122 | end_points_collection) 123 | if global_pool: 124 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool') 125 | end_points['global_pool'] = net 126 | if num_classes: 127 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 128 | scope='dropout7') 129 | net = slim.conv2d(net, num_classes, [1, 1], 130 | activation_fn=None, 131 | normalizer_fn=None, 132 | biases_initializer=tf.zeros_initializer(), 133 | scope='fc8') 134 | if spatial_squeeze: 135 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 136 | end_points[sc.name + '/fc8'] = net 137 | return net, end_points 138 | alexnet_v2.default_image_size = 224 139 | -------------------------------------------------------------------------------- /slim/nets/cifarnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a variant of the CIFAR-10 model definition.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(stddev=stddev) 26 | 27 | 28 | def cifarnet(images, num_classes=10, is_training=False, 29 | dropout_keep_prob=0.5, 30 | prediction_fn=slim.softmax, 31 | scope='CifarNet'): 32 | """Creates a variant of the CifarNet model. 33 | 34 | Note that since the output is a set of 'logits', the values fall in the 35 | interval of (-infinity, infinity). 
Consequently, to convert the outputs to a 36 | probability distribution over the characters, one will need to convert them 37 | using the softmax function: 38 | 39 | logits = cifarnet.cifarnet(images, is_training=False) 40 | probabilities = tf.nn.softmax(logits) 41 | predictions = tf.argmax(logits, 1) 42 | 43 | Args: 44 | images: A batch of `Tensors` of size [batch_size, height, width, channels]. 45 | num_classes: the number of classes in the dataset. If 0 or None, the logits 46 | layer is omitted and the input features to the logits layer are returned 47 | instead. 48 | is_training: specifies whether or not we're currently training the model. 49 | This variable will determine the behaviour of the dropout layer. 50 | dropout_keep_prob: the percentage of activation values that are retained. 51 | prediction_fn: a function to get predictions out of logits. 52 | scope: Optional variable_scope. 53 | 54 | Returns: 55 | net: a 2D Tensor with the logits (pre-softmax activations) if num_classes 56 | is a non-zero integer, or the input to the logits layer if num_classes 57 | is 0 or None. 58 | end_points: a dictionary from components of the network to the corresponding 59 | activation. 60 | """ 61 | end_points = {} 62 | 63 | with tf.variable_scope(scope, 'CifarNet', [images]): 64 | net = slim.conv2d(images, 64, [5, 5], scope='conv1') 65 | end_points['conv1'] = net 66 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool1') 67 | end_points['pool1'] = net 68 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm1') 69 | net = slim.conv2d(net, 64, [5, 5], scope='conv2') 70 | end_points['conv2'] = net 71 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm2') 72 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool2') 73 | end_points['pool2'] = net 74 | net = slim.flatten(net) 75 | end_points['Flatten'] = net 76 | net = slim.fully_connected(net, 384, scope='fc3') 77 | end_points['fc3'] = net 78 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 79 | scope='dropout3') 80 | net = slim.fully_connected(net, 192, scope='fc4') 81 | end_points['fc4'] = net 82 | if not num_classes: 83 | return net, end_points 84 | logits = slim.fully_connected(net, num_classes, 85 | biases_initializer=tf.zeros_initializer(), 86 | weights_initializer=trunc_normal(1/192.0), 87 | weights_regularizer=None, 88 | activation_fn=None, 89 | scope='logits') 90 | 91 | end_points['Logits'] = logits 92 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 93 | 94 | return logits, end_points 95 | cifarnet.default_image_size = 32 96 | 97 | 98 | def cifarnet_arg_scope(weight_decay=0.004): 99 | """Defines the default cifarnet argument scope. 100 | 101 | Args: 102 | weight_decay: The weight decay to use for regularizing the model. 103 | 104 | Returns: 105 | An `arg_scope` to use for the inception v3 model. 
106 | """ 107 | with slim.arg_scope( 108 | [slim.conv2d], 109 | weights_initializer=tf.truncated_normal_initializer(stddev=5e-2), 110 | activation_fn=tf.nn.relu): 111 | with slim.arg_scope( 112 | [slim.fully_connected], 113 | biases_initializer=tf.constant_initializer(0.1), 114 | weights_initializer=trunc_normal(0.04), 115 | weights_regularizer=slim.l2_regularizer(weight_decay), 116 | activation_fn=tf.nn.relu) as sc: 117 | return sc 118 | -------------------------------------------------------------------------------- /slim/nets/cyclegan_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for tensorflow.contrib.slim.nets.cyclegan.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from nets import cyclegan 24 | 25 | 26 | # TODO(joelshor): Add a test to check generator endpoints. 27 | class CycleganTest(tf.test.TestCase): 28 | 29 | def test_generator_inference(self): 30 | """Check one inference step.""" 31 | img_batch = tf.zeros([2, 32, 32, 3]) 32 | model_output, _ = cyclegan.cyclegan_generator_resnet(img_batch) 33 | with self.test_session() as sess: 34 | sess.run(tf.global_variables_initializer()) 35 | sess.run(model_output) 36 | 37 | def _test_generator_graph_helper(self, shape): 38 | """Check that generator can take small and non-square inputs.""" 39 | output_imgs, _ = cyclegan.cyclegan_generator_resnet(tf.ones(shape)) 40 | self.assertAllEqual(shape, output_imgs.shape.as_list()) 41 | 42 | def test_generator_graph_small(self): 43 | self._test_generator_graph_helper([4, 32, 32, 3]) 44 | 45 | def test_generator_graph_medium(self): 46 | self._test_generator_graph_helper([3, 128, 128, 3]) 47 | 48 | def test_generator_graph_nonsquare(self): 49 | self._test_generator_graph_helper([2, 80, 400, 3]) 50 | 51 | def test_generator_unknown_batch_dim(self): 52 | """Check that generator can take unknown batch dimension inputs.""" 53 | img = tf.placeholder(tf.float32, shape=[None, 32, None, 3]) 54 | output_imgs, _ = cyclegan.cyclegan_generator_resnet(img) 55 | 56 | self.assertAllEqual([None, 32, None, 3], output_imgs.shape.as_list()) 57 | 58 | def _input_and_output_same_shape_helper(self, kernel_size): 59 | img_batch = tf.placeholder(tf.float32, shape=[None, 32, 32, 3]) 60 | output_img_batch, _ = cyclegan.cyclegan_generator_resnet( 61 | img_batch, kernel_size=kernel_size) 62 | 63 | self.assertAllEqual(img_batch.shape.as_list(), 64 | output_img_batch.shape.as_list()) 65 | 66 | def input_and_output_same_shape_kernel3(self): 67 | self._input_and_output_same_shape_helper(3) 68 | 69 | def input_and_output_same_shape_kernel4(self): 70 | self._input_and_output_same_shape_helper(4) 71 | 72 | def 
input_and_output_same_shape_kernel5(self): 73 | self._input_and_output_same_shape_helper(5) 74 | 75 | def input_and_output_same_shape_kernel6(self): 76 | self._input_and_output_same_shape_helper(6) 77 | 78 | def _error_if_height_not_multiple_of_four_helper(self, height): 79 | self.assertRaisesRegexp( 80 | ValueError, 81 | 'The input height must be a multiple of 4.', 82 | cyclegan.cyclegan_generator_resnet, 83 | tf.placeholder(tf.float32, shape=[None, height, 32, 3])) 84 | 85 | def test_error_if_height_not_multiple_of_four_height29(self): 86 | self._error_if_height_not_multiple_of_four_helper(29) 87 | 88 | def test_error_if_height_not_multiple_of_four_height30(self): 89 | self._error_if_height_not_multiple_of_four_helper(30) 90 | 91 | def test_error_if_height_not_multiple_of_four_height31(self): 92 | self._error_if_height_not_multiple_of_four_helper(31) 93 | 94 | def _error_if_width_not_multiple_of_four_helper(self, width): 95 | self.assertRaisesRegexp( 96 | ValueError, 97 | 'The input width must be a multiple of 4.', 98 | cyclegan.cyclegan_generator_resnet, 99 | tf.placeholder(tf.float32, shape=[None, 32, width, 3])) 100 | 101 | def test_error_if_width_not_multiple_of_four_width29(self): 102 | self._error_if_width_not_multiple_of_four_helper(29) 103 | 104 | def test_error_if_width_not_multiple_of_four_width30(self): 105 | self._error_if_width_not_multiple_of_four_helper(30) 106 | 107 | def test_error_if_width_not_multiple_of_four_width31(self): 108 | self._error_if_width_not_multiple_of_four_helper(31) 109 | 110 | 111 | if __name__ == '__main__': 112 | tf.test.main() 113 | -------------------------------------------------------------------------------- /slim/nets/dcgan_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for dcgan.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | from nets import dcgan 23 | 24 | 25 | class DCGANTest(tf.test.TestCase): 26 | 27 | def test_generator_run(self): 28 | tf.set_random_seed(1234) 29 | noise = tf.random_normal([100, 64]) 30 | image, _ = dcgan.generator(noise) 31 | with self.test_session() as sess: 32 | sess.run(tf.global_variables_initializer()) 33 | image.eval() 34 | 35 | def test_generator_graph(self): 36 | tf.set_random_seed(1234) 37 | # Check graph construction for a number of image size/depths and batch 38 | # sizes. 
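    # zip(xrange(3, 7), xrange(3, 8)) pairs i = 3..6 with batch sizes 3..6, so
    # the generator is exercised at final_size = 2**i, i.e. 8x8, 16x16, 32x32
    # and 64x64 outputs.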
39 | for i, batch_size in zip(xrange(3, 7), xrange(3, 8)): 40 | tf.reset_default_graph() 41 | final_size = 2 ** i 42 | noise = tf.random_normal([batch_size, 64]) 43 | image, end_points = dcgan.generator( 44 | noise, 45 | depth=32, 46 | final_size=final_size) 47 | 48 | self.assertAllEqual([batch_size, final_size, final_size, 3], 49 | image.shape.as_list()) 50 | 51 | expected_names = ['deconv%i' % j for j in xrange(1, i)] + ['logits'] 52 | self.assertSetEqual(set(expected_names), set(end_points.keys())) 53 | 54 | # Check layer depths. 55 | for j in range(1, i): 56 | layer = end_points['deconv%i' % j] 57 | self.assertEqual(32 * 2**(i-j-1), layer.get_shape().as_list()[-1]) 58 | 59 | def test_generator_invalid_input(self): 60 | wrong_dim_input = tf.zeros([5, 32, 32]) 61 | with self.assertRaises(ValueError): 62 | dcgan.generator(wrong_dim_input) 63 | 64 | correct_input = tf.zeros([3, 2]) 65 | with self.assertRaisesRegexp(ValueError, 'must be a power of 2'): 66 | dcgan.generator(correct_input, final_size=30) 67 | 68 | with self.assertRaisesRegexp(ValueError, 'must be greater than 8'): 69 | dcgan.generator(correct_input, final_size=4) 70 | 71 | def test_discriminator_run(self): 72 | image = tf.random_uniform([5, 32, 32, 3], -1, 1) 73 | output, _ = dcgan.discriminator(image) 74 | with self.test_session() as sess: 75 | sess.run(tf.global_variables_initializer()) 76 | output.eval() 77 | 78 | def test_discriminator_graph(self): 79 | # Check graph construction for a number of image size/depths and batch 80 | # sizes. 81 | for i, batch_size in zip(xrange(1, 6), xrange(3, 8)): 82 | tf.reset_default_graph() 83 | img_w = 2 ** i 84 | image = tf.random_uniform([batch_size, img_w, img_w, 3], -1, 1) 85 | output, end_points = dcgan.discriminator( 86 | image, 87 | depth=32) 88 | 89 | self.assertAllEqual([batch_size, 1], output.get_shape().as_list()) 90 | 91 | expected_names = ['conv%i' % j for j in xrange(1, i+1)] + ['logits'] 92 | self.assertSetEqual(set(expected_names), set(end_points.keys())) 93 | 94 | # Check layer depths. 95 | for j in range(1, i+1): 96 | layer = end_points['conv%i' % j] 97 | self.assertEqual(32 * 2**(j-1), layer.get_shape().as_list()[-1]) 98 | 99 | def test_discriminator_invalid_input(self): 100 | wrong_dim_img = tf.zeros([5, 32, 32]) 101 | with self.assertRaises(ValueError): 102 | dcgan.discriminator(wrong_dim_img) 103 | 104 | spatially_undefined_shape = tf.placeholder(tf.float32, [5, 32, None, 3]) 105 | with self.assertRaises(ValueError): 106 | dcgan.discriminator(spatially_undefined_shape) 107 | 108 | not_square = tf.zeros([5, 32, 16, 3]) 109 | with self.assertRaisesRegexp(ValueError, 'not have equal width and height'): 110 | dcgan.discriminator(not_square) 111 | 112 | not_power_2 = tf.zeros([5, 30, 30, 3]) 113 | with self.assertRaisesRegexp(ValueError, 'not a power of 2'): 114 | dcgan.discriminator(not_power_2) 115 | 116 | 117 | if __name__ == '__main__': 118 | tf.test.main() 119 | -------------------------------------------------------------------------------- /slim/nets/inception.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Brings all inception models under one namespace.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # pylint: disable=unused-import 22 | from nets.inception_resnet_v2 import inception_resnet_v2 23 | from nets.inception_resnet_v2 import inception_resnet_v2_arg_scope 24 | from nets.inception_resnet_v2 import inception_resnet_v2_base 25 | from nets.inception_v1 import inception_v1 26 | from nets.inception_v1 import inception_v1_arg_scope 27 | from nets.inception_v1 import inception_v1_base 28 | from nets.inception_v2 import inception_v2 29 | from nets.inception_v2 import inception_v2_arg_scope 30 | from nets.inception_v2 import inception_v2_base 31 | from nets.inception_v3 import inception_v3 32 | from nets.inception_v3 import inception_v3_arg_scope 33 | from nets.inception_v3 import inception_v3_base 34 | from nets.inception_v4 import inception_v4 35 | from nets.inception_v4 import inception_v4_arg_scope 36 | from nets.inception_v4 import inception_v4_base 37 | # pylint: enable=unused-import 38 | -------------------------------------------------------------------------------- /slim/nets/inception_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains common code shared by all inception models. 16 | 17 | Usage of arg scope: 18 | with slim.arg_scope(inception_arg_scope()): 19 | logits, end_points = inception.inception_v3(images, num_classes, 20 | is_training=is_training) 21 | 22 | """ 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import tensorflow as tf 28 | 29 | slim = tf.contrib.slim 30 | 31 | 32 | def inception_arg_scope(weight_decay=0.00004, 33 | use_batch_norm=True, 34 | batch_norm_decay=0.9997, 35 | batch_norm_epsilon=0.001, 36 | activation_fn=tf.nn.relu): 37 | """Defines the default arg scope for inception models. 38 | 39 | Args: 40 | weight_decay: The weight decay to use for regularizing the model. 41 | use_batch_norm: "If `True`, batch_norm is applied after each convolution. 42 | batch_norm_decay: Decay for batch norm moving average. 
43 | batch_norm_epsilon: Small float added to variance to avoid dividing by zero 44 | in batch norm. 45 | activation_fn: Activation function for conv2d. 46 | 47 | Returns: 48 | An `arg_scope` to use for the inception models. 49 | """ 50 | batch_norm_params = { 51 | # Decay for the moving averages. 52 | 'decay': batch_norm_decay, 53 | # epsilon to prevent 0s in variance. 54 | 'epsilon': batch_norm_epsilon, 55 | # collection containing update_ops. 56 | 'updates_collections': tf.GraphKeys.UPDATE_OPS, 57 | # use fused batch norm if possible. 58 | 'fused': None, 59 | } 60 | if use_batch_norm: 61 | normalizer_fn = slim.batch_norm 62 | normalizer_params = batch_norm_params 63 | else: 64 | normalizer_fn = None 65 | normalizer_params = {} 66 | # Set weight_decay for weights in Conv and FC layers. 67 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 68 | weights_regularizer=slim.l2_regularizer(weight_decay)): 69 | with slim.arg_scope( 70 | [slim.conv2d], 71 | weights_initializer=slim.variance_scaling_initializer(), 72 | activation_fn=activation_fn, 73 | normalizer_fn=normalizer_fn, 74 | normalizer_params=normalizer_params) as sc: 75 | return sc 76 | -------------------------------------------------------------------------------- /slim/nets/lenet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a variant of the LeNet model definition.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | def lenet(images, num_classes=10, is_training=False, 27 | dropout_keep_prob=0.5, 28 | prediction_fn=slim.softmax, 29 | scope='LeNet'): 30 | """Creates a variant of the LeNet model. 31 | 32 | Note that since the output is a set of 'logits', the values fall in the 33 | interval of (-infinity, infinity). Consequently, to convert the outputs to a 34 | probability distribution over the characters, one will need to convert them 35 | using the softmax function: 36 | 37 | logits = lenet.lenet(images, is_training=False) 38 | probabilities = tf.nn.softmax(logits) 39 | predictions = tf.argmax(logits, 1) 40 | 41 | Args: 42 | images: A batch of `Tensors` of size [batch_size, height, width, channels]. 43 | num_classes: the number of classes in the dataset. If 0 or None, the logits 44 | layer is omitted and the input features to the logits layer are returned 45 | instead. 46 | is_training: specifies whether or not we're currently training the model. 47 | This variable will determine the behaviour of the dropout layer. 48 | dropout_keep_prob: the percentage of activation values that are retained. 49 | prediction_fn: a function to get predictions out of logits. 
50 | scope: Optional variable_scope. 51 | 52 | Returns: 53 | net: a 2D Tensor with the logits (pre-softmax activations) if num_classes 54 | is a non-zero integer, or the inon-dropped-out nput to the logits layer 55 | if num_classes is 0 or None. 56 | end_points: a dictionary from components of the network to the corresponding 57 | activation. 58 | """ 59 | end_points = {} 60 | 61 | with tf.variable_scope(scope, 'LeNet', [images]): 62 | net = end_points['conv1'] = slim.conv2d(images, 32, [5, 5], scope='conv1') 63 | net = end_points['pool1'] = slim.max_pool2d(net, [2, 2], 2, scope='pool1') 64 | net = end_points['conv2'] = slim.conv2d(net, 64, [5, 5], scope='conv2') 65 | net = end_points['pool2'] = slim.max_pool2d(net, [2, 2], 2, scope='pool2') 66 | net = slim.flatten(net) 67 | end_points['Flatten'] = net 68 | 69 | net = end_points['fc3'] = slim.fully_connected(net, 1024, scope='fc3') 70 | if not num_classes: 71 | return net, end_points 72 | net = end_points['dropout3'] = slim.dropout( 73 | net, dropout_keep_prob, is_training=is_training, scope='dropout3') 74 | logits = end_points['Logits'] = slim.fully_connected( 75 | net, num_classes, activation_fn=None, scope='fc4') 76 | 77 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 78 | 79 | return logits, end_points 80 | lenet.default_image_size = 28 81 | 82 | 83 | def lenet_arg_scope(weight_decay=0.0): 84 | """Defines the default lenet argument scope. 85 | 86 | Args: 87 | weight_decay: The weight decay to use for regularizing the model. 88 | 89 | Returns: 90 | An `arg_scope` to use for the inception v3 model. 91 | """ 92 | with slim.arg_scope( 93 | [slim.conv2d, slim.fully_connected], 94 | weights_regularizer=slim.l2_regularizer(weight_decay), 95 | weights_initializer=tf.truncated_normal_initializer(stddev=0.1), 96 | activation_fn=tf.nn.relu) as sc: 97 | return sc 98 | -------------------------------------------------------------------------------- /slim/nets/mobilenet_v1.md: -------------------------------------------------------------------------------- 1 | # MobileNet_v1 2 | 3 | [MobileNets](https://arxiv.org/abs/1704.04861) are small, low-latency, low-power models parameterized to meet the resource constraints of a variety of use cases. They can be built upon for classification, detection, embeddings and segmentation similar to how other popular large scale models, such as Inception, are used. MobileNets can be run efficiently on mobile devices with [TensorFlow Mobile](https://www.tensorflow.org/mobile/). 4 | 5 | MobileNets trade off between latency, size and accuracy while comparing favorably with popular models from the literature. 6 | 7 | ![alt text](mobilenet_v1.png "MobileNet Graph") 8 | 9 | # Pre-trained Models 10 | 11 | Choose the right MobileNet model to fit your latency and size budget. The size of the network in memory and on disk is proportional to the number of parameters. The latency and power usage of the network scales with the number of Multiply-Accumulates (MACs) which measures the number of fused Multiplication and Addition operations. These MobileNet models have been trained on the 12 | [ILSVRC-2012-CLS](http://www.image-net.org/challenges/LSVRC/2012/) 13 | image classification dataset. Accuracies were computed by evaluating using a single image crop. 
14 | 15 | Model Checkpoint | Million MACs | Million Parameters | Top-1 Accuracy| Top-5 Accuracy | 16 | :----:|:------------:|:----------:|:-------:|:-------:| 17 | [MobileNet_v1_1.0_224](http://download.tensorflow.org/models/mobilenet_v1_1.0_224_2017_06_14.tar.gz)|569|4.24|70.7|89.5| 18 | [MobileNet_v1_1.0_192](http://download.tensorflow.org/models/mobilenet_v1_1.0_192_2017_06_14.tar.gz)|418|4.24|69.3|88.9| 19 | [MobileNet_v1_1.0_160](http://download.tensorflow.org/models/mobilenet_v1_1.0_160_2017_06_14.tar.gz)|291|4.24|67.2|87.5| 20 | [MobileNet_v1_1.0_128](http://download.tensorflow.org/models/mobilenet_v1_1.0_128_2017_06_14.tar.gz)|186|4.24|64.1|85.3| 21 | [MobileNet_v1_0.75_224](http://download.tensorflow.org/models/mobilenet_v1_0.75_224_2017_06_14.tar.gz)|317|2.59|68.4|88.2| 22 | [MobileNet_v1_0.75_192](http://download.tensorflow.org/models/mobilenet_v1_0.75_192_2017_06_14.tar.gz)|233|2.59|67.4|87.3| 23 | [MobileNet_v1_0.75_160](http://download.tensorflow.org/models/mobilenet_v1_0.75_160_2017_06_14.tar.gz)|162|2.59|65.2|86.1| 24 | [MobileNet_v1_0.75_128](http://download.tensorflow.org/models/mobilenet_v1_0.75_128_2017_06_14.tar.gz)|104|2.59|61.8|83.6| 25 | [MobileNet_v1_0.50_224](http://download.tensorflow.org/models/mobilenet_v1_0.50_224_2017_06_14.tar.gz)|150|1.34|64.0|85.4| 26 | [MobileNet_v1_0.50_192](http://download.tensorflow.org/models/mobilenet_v1_0.50_192_2017_06_14.tar.gz)|110|1.34|62.1|84.0| 27 | [MobileNet_v1_0.50_160](http://download.tensorflow.org/models/mobilenet_v1_0.50_160_2017_06_14.tar.gz)|77|1.34|59.9|82.5| 28 | [MobileNet_v1_0.50_128](http://download.tensorflow.org/models/mobilenet_v1_0.50_128_2017_06_14.tar.gz)|49|1.34|56.2|79.6| 29 | [MobileNet_v1_0.25_224](http://download.tensorflow.org/models/mobilenet_v1_0.25_224_2017_06_14.tar.gz)|41|0.47|50.6|75.0| 30 | [MobileNet_v1_0.25_192](http://download.tensorflow.org/models/mobilenet_v1_0.25_192_2017_06_14.tar.gz)|34|0.47|49.0|73.6| 31 | [MobileNet_v1_0.25_160](http://download.tensorflow.org/models/mobilenet_v1_0.25_160_2017_06_14.tar.gz)|21|0.47|46.0|70.7| 32 | [MobileNet_v1_0.25_128](http://download.tensorflow.org/models/mobilenet_v1_0.25_128_2017_06_14.tar.gz)|14|0.47|41.3|66.2| 33 | 34 | 35 | Here is an example of how to download the MobileNet_v1_1.0_224 checkpoint: 36 | 37 | ```shell 38 | $ CHECKPOINT_DIR=/tmp/checkpoints 39 | $ mkdir ${CHECKPOINT_DIR} 40 | $ wget http://download.tensorflow.org/models/mobilenet_v1_1.0_224_2017_06_14.tar.gz 41 | $ tar -xvf mobilenet_v1_1.0_224_2017_06_14.tar.gz 42 | $ mv mobilenet_v1_1.0_224.ckpt.* ${CHECKPOINT_DIR} 43 | $ rm mobilenet_v1_1.0_224_2017_06_14.tar.gz 44 | ``` 45 | More information on integrating MobileNets into your project can be found at the [TF-Slim Image Classification Library](https://github.com/tensorflow/models/blob/master/research/slim/README.md). 46 | 47 | To get started running models on-device go to [TensorFlow Mobile](https://www.tensorflow.org/mobile/). 
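For reference, here is a minimal Python sketch (not part of the upstream README) of restoring one of these checkpoints for inference with the `mobilenet_v1` definition under `slim/nets`. The checkpoint path follows the download example above; the 1001-class output (ImageNet classes plus a background class) and the `MobilenetV1` variable scope are assumptions based on the standard TF-Slim releases.

```python
# Hedged sketch: restore MobileNet_v1_1.0_224 for inference (TF 1.x assumed).
import tensorflow as tf
from nets import mobilenet_v1

slim = tf.contrib.slim

# Inputs are expected in [-1, 1], e.g. as produced by inception_preprocessing.
images = tf.placeholder(tf.float32, [None, 224, 224, 3])
with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(is_training=False)):
  logits, end_points = mobilenet_v1.mobilenet_v1(
      images, num_classes=1001, is_training=False)
probabilities = tf.nn.softmax(logits)

# Restore only the model variables saved under the default 'MobilenetV1' scope.
saver = tf.train.Saver(slim.get_model_variables('MobilenetV1'))
with tf.Session() as sess:
  saver.restore(sess, '/tmp/checkpoints/mobilenet_v1_1.0_224.ckpt')
  # probs = sess.run(probabilities, feed_dict={images: batch})
```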
48 | -------------------------------------------------------------------------------- /slim/nets/mobilenet_v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlberkeley/Creative-Adversarial-Networks/fea29d4348a650a40322fc4da645395d3d0f089c/slim/nets/mobilenet_v1.png -------------------------------------------------------------------------------- /slim/nets/nasnet/README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow-Slim NASNet-A Implementation/Checkpoints 2 | This directory contains the code for the NASNet-A model from the paper 3 | [Learning Transferable Architectures for Scalable Image Recognition](https://arxiv.org/abs/1707.07012) by Zoph et al. 4 | In nasnet.py there are three different configurations of NASNet-A that are implementented. One of the models is the NASNet-A built for CIFAR-10 and the 5 | other two are variants of NASNet-A trained on ImageNet, which are listed below. 6 | 7 | # Pre-Trained Models 8 | Two NASNet-A checkpoints are available that have been trained on the 9 | [ILSVRC-2012-CLS](http://www.image-net.org/challenges/LSVRC/2012/) 10 | image classification dataset. Accuracies were computed by evaluating using a single image crop. 11 | 12 | Model Checkpoint | Million MACs | Million Parameters | Top-1 Accuracy| Top-5 Accuracy | 13 | :----:|:------------:|:----------:|:-------:|:-------:| 14 | [NASNet-A_Mobile_224](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz)|564|5.3|74.0|91.6| 15 | [NASNet-A_Large_331](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_large_04_10_2017.tar.gz)|23800|88.9|82.7|96.2| 16 | 17 | 18 | Here is an example of how to download the NASNet-A_Mobile_224 checkpoint. The way to download the NASNet-A_Large_331 is the same. 19 | 20 | ```shell 21 | CHECKPOINT_DIR=/tmp/checkpoints 22 | mkdir ${CHECKPOINT_DIR} 23 | cd ${CHECKPOINT_DIR} 24 | wget https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz 25 | tar -xvf nasnet-a_mobile_04_10_2017.tar.gz 26 | rm nasnet-a_mobile_04_10_2017.tar.gz 27 | ``` 28 | More information on integrating NASNet Models into your project can be found at the [TF-Slim Image Classification Library](https://github.com/tensorflow/models/blob/master/research/slim/README.md). 29 | 30 | To get started running models on-device go to [TensorFlow Mobile](https://www.tensorflow.org/mobile/). 
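Before the command-line examples in the next section, a hedged Python sketch of constructing the same model programmatically through this repo's `nets_factory` may be useful. The 1001-class count, the `Predictions` end-point name, and the checkpoint path (mirroring the eval commands below) are assumptions about the released checkpoints.

```python
# Hedged sketch: build NASNet-A Mobile for inference via nets_factory (TF 1.x assumed).
import tensorflow as tf
from nets import nets_factory

network_fn = nets_factory.get_network_fn(
    'nasnet_mobile', num_classes=1001, is_training=False)
size = network_fn.default_image_size  # 224 for the mobile variant
images = tf.placeholder(tf.float32, [None, size, size, 3])
logits, end_points = network_fn(images)
predictions = end_points['Predictions']  # assumed end-point name

saver = tf.train.Saver()
with tf.Session() as sess:
  # Path from the download step above; the eval commands below additionally
  # handle the moving-average variables via --moving_average_decay.
  saver.restore(sess, '/tmp/checkpoints/model.ckpt')
```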
31 | 32 | ## Sample Commands for using NASNet-A Mobile and Large Checkpoints for Inference 33 | ------- 34 | Run eval with the NASNet-A mobile ImageNet model 35 | 36 | ```shell 37 | DATASET_DIR=/tmp/imagenet 38 | EVAL_DIR=/tmp/tfmodel/eval 39 | CHECKPOINT_DIR=/tmp/checkpoints/model.ckpt 40 | python tensorflow_models/research/slim/eval_image_classifier \ 41 | --checkpoint_path=${CHECKPOINT_DIR} \ 42 | --eval_dir=${EVAL_DIR} \ 43 | --dataset_dir=${DATASET_DIR} \ 44 | --dataset_name=imagenet \ 45 | --dataset_split_name=validation \ 46 | --model_name=nasnet_mobile \ 47 | --eval_image_size=224 \ 48 | --moving_average_decay=0.9999 49 | ``` 50 | 51 | Run eval with the NASNet-A large ImageNet model 52 | 53 | ```shell 54 | DATASET_DIR=/tmp/imagenet 55 | EVAL_DIR=/tmp/tfmodel/eval 56 | CHECKPOINT_DIR=/tmp/checkpoints/model.ckpt 57 | python tensorflow_models/research/slim/eval_image_classifier \ 58 | --checkpoint_path=${CHECKPOINT_DIR} \ 59 | --eval_dir=${EVAL_DIR} \ 60 | --dataset_dir=${DATASET_DIR} \ 61 | --dataset_name=imagenet \ 62 | --dataset_split_name=validation \ 63 | --model_name=nasnet_large \ 64 | --eval_image_size=331 \ 65 | --moving_average_decay=0.9999 66 | ``` 67 | -------------------------------------------------------------------------------- /slim/nets/nasnet/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /slim/nets/nasnet/nasnet_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for slim.nets.nasnet.nasnet_utils.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from nets.nasnet import nasnet_utils 24 | 25 | 26 | class NasnetUtilsTest(tf.test.TestCase): 27 | 28 | def testCalcReductionLayers(self): 29 | num_cells = 18 30 | num_reduction_layers = 2 31 | reduction_layers = nasnet_utils.calc_reduction_layers( 32 | num_cells, num_reduction_layers) 33 | self.assertEqual(len(reduction_layers), 2) 34 | self.assertEqual(reduction_layers[0], 6) 35 | self.assertEqual(reduction_layers[1], 12) 36 | 37 | def testGetChannelIndex(self): 38 | data_formats = ['NHWC', 'NCHW'] 39 | for data_format in data_formats: 40 | index = nasnet_utils.get_channel_index(data_format) 41 | correct_index = 3 if data_format == 'NHWC' else 1 42 | self.assertEqual(index, correct_index) 43 | 44 | def testGetChannelDim(self): 45 | data_formats = ['NHWC', 'NCHW'] 46 | shape = [10, 20, 30, 40] 47 | for data_format in data_formats: 48 | dim = nasnet_utils.get_channel_dim(shape, data_format) 49 | correct_dim = shape[3] if data_format == 'NHWC' else shape[1] 50 | self.assertEqual(dim, correct_dim) 51 | 52 | def testGlobalAvgPool(self): 53 | data_formats = ['NHWC', 'NCHW'] 54 | inputs = tf.placeholder(tf.float32, (5, 10, 20, 10)) 55 | for data_format in data_formats: 56 | output = nasnet_utils.global_avg_pool( 57 | inputs, data_format) 58 | self.assertEqual(output.shape, [5, 10]) 59 | 60 | 61 | if __name__ == '__main__': 62 | tf.test.main() 63 | -------------------------------------------------------------------------------- /slim/nets/nets_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Contains a factory for building various models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import functools 21 | 22 | import tensorflow as tf 23 | 24 | from nets import alexnet 25 | from nets import cifarnet 26 | from nets import inception 27 | from nets import lenet 28 | from nets import mobilenet_v1 29 | from nets import overfeat 30 | from nets import resnet_v1 31 | from nets import resnet_v2 32 | from nets import vgg 33 | from nets.nasnet import nasnet 34 | 35 | slim = tf.contrib.slim 36 | 37 | networks_map = {'alexnet_v2': alexnet.alexnet_v2, 38 | 'cifarnet': cifarnet.cifarnet, 39 | 'overfeat': overfeat.overfeat, 40 | 'vgg_a': vgg.vgg_a, 41 | 'vgg_16': vgg.vgg_16, 42 | 'vgg_19': vgg.vgg_19, 43 | 'inception_v1': inception.inception_v1, 44 | 'inception_v2': inception.inception_v2, 45 | 'inception_v3': inception.inception_v3, 46 | 'inception_v4': inception.inception_v4, 47 | 'inception_resnet_v2': inception.inception_resnet_v2, 48 | 'lenet': lenet.lenet, 49 | 'resnet_v1_50': resnet_v1.resnet_v1_50, 50 | 'resnet_v1_101': resnet_v1.resnet_v1_101, 51 | 'resnet_v1_152': resnet_v1.resnet_v1_152, 52 | 'resnet_v1_200': resnet_v1.resnet_v1_200, 53 | 'resnet_v2_50': resnet_v2.resnet_v2_50, 54 | 'resnet_v2_101': resnet_v2.resnet_v2_101, 55 | 'resnet_v2_152': resnet_v2.resnet_v2_152, 56 | 'resnet_v2_200': resnet_v2.resnet_v2_200, 57 | 'mobilenet_v1': mobilenet_v1.mobilenet_v1, 58 | 'mobilenet_v1_075': mobilenet_v1.mobilenet_v1_075, 59 | 'mobilenet_v1_050': mobilenet_v1.mobilenet_v1_050, 60 | 'mobilenet_v1_025': mobilenet_v1.mobilenet_v1_025, 61 | 'nasnet_cifar': nasnet.build_nasnet_cifar, 62 | 'nasnet_mobile': nasnet.build_nasnet_mobile, 63 | 'nasnet_large': nasnet.build_nasnet_large, 64 | } 65 | 66 | arg_scopes_map = {'alexnet_v2': alexnet.alexnet_v2_arg_scope, 67 | 'cifarnet': cifarnet.cifarnet_arg_scope, 68 | 'overfeat': overfeat.overfeat_arg_scope, 69 | 'vgg_a': vgg.vgg_arg_scope, 70 | 'vgg_16': vgg.vgg_arg_scope, 71 | 'vgg_19': vgg.vgg_arg_scope, 72 | 'inception_v1': inception.inception_v3_arg_scope, 73 | 'inception_v2': inception.inception_v3_arg_scope, 74 | 'inception_v3': inception.inception_v3_arg_scope, 75 | 'inception_v4': inception.inception_v4_arg_scope, 76 | 'inception_resnet_v2': 77 | inception.inception_resnet_v2_arg_scope, 78 | 'lenet': lenet.lenet_arg_scope, 79 | 'resnet_v1_50': resnet_v1.resnet_arg_scope, 80 | 'resnet_v1_101': resnet_v1.resnet_arg_scope, 81 | 'resnet_v1_152': resnet_v1.resnet_arg_scope, 82 | 'resnet_v1_200': resnet_v1.resnet_arg_scope, 83 | 'resnet_v2_50': resnet_v2.resnet_arg_scope, 84 | 'resnet_v2_101': resnet_v2.resnet_arg_scope, 85 | 'resnet_v2_152': resnet_v2.resnet_arg_scope, 86 | 'resnet_v2_200': resnet_v2.resnet_arg_scope, 87 | 'mobilenet_v1': mobilenet_v1.mobilenet_v1_arg_scope, 88 | 'mobilenet_v1_075': mobilenet_v1.mobilenet_v1_arg_scope, 89 | 'mobilenet_v1_050': mobilenet_v1.mobilenet_v1_arg_scope, 90 | 'mobilenet_v1_025': mobilenet_v1.mobilenet_v1_arg_scope, 91 | 'nasnet_cifar': nasnet.nasnet_cifar_arg_scope, 92 | 'nasnet_mobile': nasnet.nasnet_mobile_arg_scope, 93 | 'nasnet_large': nasnet.nasnet_large_arg_scope, 94 | } 95 | 96 | 97 | def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False): 98 | """Returns a network_fn such as `logits, end_points = network_fn(images)`. 99 | 100 | Args: 101 | name: The name of the network. 
102 | num_classes: The number of classes to use for classification. If 0 or None, 103 | the logits layer is omitted and its input features are returned instead. 104 | weight_decay: The l2 coefficient for the model weights. 105 | is_training: `True` if the model is being used for training and `False` 106 | otherwise. 107 | 108 | Returns: 109 | network_fn: A function that applies the model to a batch of images. It has 110 | the following signature: 111 | net, end_points = network_fn(images) 112 | The `images` input is a tensor of shape [batch_size, height, width, 3] 113 | with height = width = network_fn.default_image_size. (The permissibility 114 | and treatment of other sizes depends on the network_fn.) 115 | The returned `end_points` are a dictionary of intermediate activations. 116 | The returned `net` is the topmost layer, depending on `num_classes`: 117 | If `num_classes` was a non-zero integer, `net` is a logits tensor 118 | of shape [batch_size, num_classes]. 119 | If `num_classes` was 0 or `None`, `net` is a tensor with the input 120 | to the logits layer of shape [batch_size, 1, 1, num_features] or 121 | [batch_size, num_features]. Dropout has not been applied to this 122 | (even if the network's original classification does); it remains for 123 | the caller to do this or not. 124 | 125 | Raises: 126 | ValueError: If network `name` is not recognized. 127 | """ 128 | if name not in networks_map: 129 | raise ValueError('Name of network unknown %s' % name) 130 | func = networks_map[name] 131 | @functools.wraps(func) 132 | def network_fn(images, **kwargs): 133 | arg_scope = arg_scopes_map[name](weight_decay=weight_decay) 134 | with slim.arg_scope(arg_scope): 135 | return func(images, num_classes, is_training=is_training, **kwargs) 136 | if hasattr(func, 'default_image_size'): 137 | network_fn.default_image_size = func.default_image_size 138 | 139 | return network_fn 140 | -------------------------------------------------------------------------------- /slim/nets/nets_factory_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for slim.inception.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | 23 | import tensorflow as tf 24 | 25 | from nets import nets_factory 26 | 27 | 28 | class NetworksTest(tf.test.TestCase): 29 | 30 | def testGetNetworkFnFirstHalf(self): 31 | batch_size = 5 32 | num_classes = 1000 33 | for net in nets_factory.networks_map.keys()[:10]: 34 | with tf.Graph().as_default() as g, self.test_session(g): 35 | net_fn = nets_factory.get_network_fn(net, num_classes) 36 | # Most networks use 224 as their default_image_size 37 | image_size = getattr(net_fn, 'default_image_size', 224) 38 | inputs = tf.random_uniform((batch_size, image_size, image_size, 3)) 39 | logits, end_points = net_fn(inputs) 40 | self.assertTrue(isinstance(logits, tf.Tensor)) 41 | self.assertTrue(isinstance(end_points, dict)) 42 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 43 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 44 | 45 | def testGetNetworkFnSecondHalf(self): 46 | batch_size = 5 47 | num_classes = 1000 48 | for net in nets_factory.networks_map.keys()[10:]: 49 | with tf.Graph().as_default() as g, self.test_session(g): 50 | net_fn = nets_factory.get_network_fn(net, num_classes) 51 | # Most networks use 224 as their default_image_size 52 | image_size = getattr(net_fn, 'default_image_size', 224) 53 | inputs = tf.random_uniform((batch_size, image_size, image_size, 3)) 54 | logits, end_points = net_fn(inputs) 55 | self.assertTrue(isinstance(logits, tf.Tensor)) 56 | self.assertTrue(isinstance(end_points, dict)) 57 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 58 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 59 | 60 | if __name__ == '__main__': 61 | tf.test.main() 62 | -------------------------------------------------------------------------------- /slim/nets/overfeat.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains the model definition for the OverFeat network. 
16 | 17 | The definition for the network was obtained from: 18 | OverFeat: Integrated Recognition, Localization and Detection using 19 | Convolutional Networks 20 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and 21 | Yann LeCun, 2014 22 | http://arxiv.org/abs/1312.6229 23 | 24 | Usage: 25 | with slim.arg_scope(overfeat.overfeat_arg_scope()): 26 | outputs, end_points = overfeat.overfeat(inputs) 27 | 28 | @@overfeat 29 | """ 30 | from __future__ import absolute_import 31 | from __future__ import division 32 | from __future__ import print_function 33 | 34 | import tensorflow as tf 35 | 36 | slim = tf.contrib.slim 37 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 38 | 39 | 40 | def overfeat_arg_scope(weight_decay=0.0005): 41 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 42 | activation_fn=tf.nn.relu, 43 | weights_regularizer=slim.l2_regularizer(weight_decay), 44 | biases_initializer=tf.zeros_initializer()): 45 | with slim.arg_scope([slim.conv2d], padding='SAME'): 46 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: 47 | return arg_sc 48 | 49 | 50 | def overfeat(inputs, 51 | num_classes=1000, 52 | is_training=True, 53 | dropout_keep_prob=0.5, 54 | spatial_squeeze=True, 55 | scope='overfeat', 56 | global_pool=False): 57 | """Contains the model definition for the OverFeat network. 58 | 59 | The definition for the network was obtained from: 60 | OverFeat: Integrated Recognition, Localization and Detection using 61 | Convolutional Networks 62 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and 63 | Yann LeCun, 2014 64 | http://arxiv.org/abs/1312.6229 65 | 66 | Note: All the fully_connected layers have been transformed to conv2d layers. 67 | To use in classification mode, resize input to 231x231. To use in fully 68 | convolutional mode, set spatial_squeeze to false. 69 | 70 | Args: 71 | inputs: a tensor of size [batch_size, height, width, channels]. 72 | num_classes: number of predicted classes. If 0 or None, the logits layer is 73 | omitted and the input features to the logits layer are returned instead. 74 | is_training: whether or not the model is being trained. 75 | dropout_keep_prob: the probability that activations are kept in the dropout 76 | layers during training. 77 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 78 | outputs. Useful to remove unnecessary dimensions for classification. 79 | scope: Optional scope for the variables. 80 | global_pool: Optional boolean flag. If True, the input to the classification 81 | layer is avgpooled to size 1x1, for any input size. (This is not part 82 | of the original OverFeat.) 83 | 84 | Returns: 85 | net: the output of the logits layer (if num_classes is a non-zero integer), 86 | or the non-dropped-out input to the logits layer (if num_classes is 0 or 87 | None). 88 | end_points: a dict of tensors with intermediate activations. 
89 | """ 90 | with tf.variable_scope(scope, 'overfeat', [inputs]) as sc: 91 | end_points_collection = sc.original_name_scope + '_end_points' 92 | # Collect outputs for conv2d, fully_connected and max_pool2d 93 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 94 | outputs_collections=end_points_collection): 95 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', 96 | scope='conv1') 97 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 98 | net = slim.conv2d(net, 256, [5, 5], padding='VALID', scope='conv2') 99 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 100 | net = slim.conv2d(net, 512, [3, 3], scope='conv3') 101 | net = slim.conv2d(net, 1024, [3, 3], scope='conv4') 102 | net = slim.conv2d(net, 1024, [3, 3], scope='conv5') 103 | net = slim.max_pool2d(net, [2, 2], scope='pool5') 104 | 105 | # Use conv2d instead of fully_connected layers. 106 | with slim.arg_scope([slim.conv2d], 107 | weights_initializer=trunc_normal(0.005), 108 | biases_initializer=tf.constant_initializer(0.1)): 109 | net = slim.conv2d(net, 3072, [6, 6], padding='VALID', scope='fc6') 110 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 111 | scope='dropout6') 112 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 113 | # Convert end_points_collection into a end_point dict. 114 | end_points = slim.utils.convert_collection_to_dict( 115 | end_points_collection) 116 | if global_pool: 117 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool') 118 | end_points['global_pool'] = net 119 | if num_classes: 120 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 121 | scope='dropout7') 122 | net = slim.conv2d(net, num_classes, [1, 1], 123 | activation_fn=None, 124 | normalizer_fn=None, 125 | biases_initializer=tf.zeros_initializer(), 126 | scope='fc8') 127 | if spatial_squeeze: 128 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 129 | end_points[sc.name + '/fc8'] = net 130 | return net, end_points 131 | overfeat.default_image_size = 231 132 | -------------------------------------------------------------------------------- /slim/nets/overfeat_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for slim.nets.overfeat.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | from nets import overfeat 23 | 24 | slim = tf.contrib.slim 25 | 26 | 27 | class OverFeatTest(tf.test.TestCase): 28 | 29 | def testBuild(self): 30 | batch_size = 5 31 | height, width = 231, 231 32 | num_classes = 1000 33 | with self.test_session(): 34 | inputs = tf.random_uniform((batch_size, height, width, 3)) 35 | logits, _ = overfeat.overfeat(inputs, num_classes) 36 | self.assertEquals(logits.op.name, 'overfeat/fc8/squeezed') 37 | self.assertListEqual(logits.get_shape().as_list(), 38 | [batch_size, num_classes]) 39 | 40 | def testFullyConvolutional(self): 41 | batch_size = 1 42 | height, width = 281, 281 43 | num_classes = 1000 44 | with self.test_session(): 45 | inputs = tf.random_uniform((batch_size, height, width, 3)) 46 | logits, _ = overfeat.overfeat(inputs, num_classes, spatial_squeeze=False) 47 | self.assertEquals(logits.op.name, 'overfeat/fc8/BiasAdd') 48 | self.assertListEqual(logits.get_shape().as_list(), 49 | [batch_size, 2, 2, num_classes]) 50 | 51 | def testGlobalPool(self): 52 | batch_size = 1 53 | height, width = 281, 281 54 | num_classes = 1000 55 | with self.test_session(): 56 | inputs = tf.random_uniform((batch_size, height, width, 3)) 57 | logits, _ = overfeat.overfeat(inputs, num_classes, spatial_squeeze=False, 58 | global_pool=True) 59 | self.assertEquals(logits.op.name, 'overfeat/fc8/BiasAdd') 60 | self.assertListEqual(logits.get_shape().as_list(), 61 | [batch_size, 1, 1, num_classes]) 62 | 63 | def testEndPoints(self): 64 | batch_size = 5 65 | height, width = 231, 231 66 | num_classes = 1000 67 | with self.test_session(): 68 | inputs = tf.random_uniform((batch_size, height, width, 3)) 69 | _, end_points = overfeat.overfeat(inputs, num_classes) 70 | expected_names = ['overfeat/conv1', 71 | 'overfeat/pool1', 72 | 'overfeat/conv2', 73 | 'overfeat/pool2', 74 | 'overfeat/conv3', 75 | 'overfeat/conv4', 76 | 'overfeat/conv5', 77 | 'overfeat/pool5', 78 | 'overfeat/fc6', 79 | 'overfeat/fc7', 80 | 'overfeat/fc8' 81 | ] 82 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 83 | 84 | def testNoClasses(self): 85 | batch_size = 5 86 | height, width = 231, 231 87 | num_classes = None 88 | with self.test_session(): 89 | inputs = tf.random_uniform((batch_size, height, width, 3)) 90 | net, end_points = overfeat.overfeat(inputs, num_classes) 91 | expected_names = ['overfeat/conv1', 92 | 'overfeat/pool1', 93 | 'overfeat/conv2', 94 | 'overfeat/pool2', 95 | 'overfeat/conv3', 96 | 'overfeat/conv4', 97 | 'overfeat/conv5', 98 | 'overfeat/pool5', 99 | 'overfeat/fc6', 100 | 'overfeat/fc7' 101 | ] 102 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 103 | self.assertTrue(net.op.name.startswith('overfeat/fc7')) 104 | 105 | def testModelVariables(self): 106 | batch_size = 5 107 | height, width = 231, 231 108 | num_classes = 1000 109 | with self.test_session(): 110 | inputs = tf.random_uniform((batch_size, height, width, 3)) 111 | overfeat.overfeat(inputs, num_classes) 112 | expected_names = ['overfeat/conv1/weights', 113 | 'overfeat/conv1/biases', 114 | 'overfeat/conv2/weights', 115 | 'overfeat/conv2/biases', 116 | 'overfeat/conv3/weights', 117 | 'overfeat/conv3/biases', 118 | 'overfeat/conv4/weights', 119 | 'overfeat/conv4/biases', 120 | 'overfeat/conv5/weights', 121 | 
'overfeat/conv5/biases', 122 | 'overfeat/fc6/weights', 123 | 'overfeat/fc6/biases', 124 | 'overfeat/fc7/weights', 125 | 'overfeat/fc7/biases', 126 | 'overfeat/fc8/weights', 127 | 'overfeat/fc8/biases', 128 | ] 129 | model_variables = [v.op.name for v in slim.get_model_variables()] 130 | self.assertSetEqual(set(model_variables), set(expected_names)) 131 | 132 | def testEvaluation(self): 133 | batch_size = 2 134 | height, width = 231, 231 135 | num_classes = 1000 136 | with self.test_session(): 137 | eval_inputs = tf.random_uniform((batch_size, height, width, 3)) 138 | logits, _ = overfeat.overfeat(eval_inputs, is_training=False) 139 | self.assertListEqual(logits.get_shape().as_list(), 140 | [batch_size, num_classes]) 141 | predictions = tf.argmax(logits, 1) 142 | self.assertListEqual(predictions.get_shape().as_list(), [batch_size]) 143 | 144 | def testTrainEvalWithReuse(self): 145 | train_batch_size = 2 146 | eval_batch_size = 1 147 | train_height, train_width = 231, 231 148 | eval_height, eval_width = 281, 281 149 | num_classes = 1000 150 | with self.test_session(): 151 | train_inputs = tf.random_uniform( 152 | (train_batch_size, train_height, train_width, 3)) 153 | logits, _ = overfeat.overfeat(train_inputs) 154 | self.assertListEqual(logits.get_shape().as_list(), 155 | [train_batch_size, num_classes]) 156 | tf.get_variable_scope().reuse_variables() 157 | eval_inputs = tf.random_uniform( 158 | (eval_batch_size, eval_height, eval_width, 3)) 159 | logits, _ = overfeat.overfeat(eval_inputs, is_training=False, 160 | spatial_squeeze=False) 161 | self.assertListEqual(logits.get_shape().as_list(), 162 | [eval_batch_size, 2, 2, num_classes]) 163 | logits = tf.reduce_mean(logits, [1, 2]) 164 | predictions = tf.argmax(logits, 1) 165 | self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size]) 166 | 167 | def testForward(self): 168 | batch_size = 1 169 | height, width = 231, 231 170 | with self.test_session() as sess: 171 | inputs = tf.random_uniform((batch_size, height, width, 3)) 172 | logits, _ = overfeat.overfeat(inputs) 173 | sess.run(tf.global_variables_initializer()) 174 | output = sess.run(logits) 175 | self.assertTrue(output.any()) 176 | 177 | if __name__ == '__main__': 178 | tf.test.main() 179 | -------------------------------------------------------------------------------- /slim/nets/pix2pix_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================= 15 | """Tests for pix2pix.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | from nets import pix2pix 23 | 24 | 25 | class GeneratorTest(tf.test.TestCase): 26 | 27 | def test_nonsquare_inputs_raise_exception(self): 28 | batch_size = 2 29 | height, width = 240, 320 30 | num_outputs = 4 31 | 32 | images = tf.ones((batch_size, height, width, 3)) 33 | 34 | with self.assertRaises(ValueError): 35 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 36 | pix2pix.pix2pix_generator( 37 | images, num_outputs, upsample_method='nn_upsample_conv') 38 | 39 | def _reduced_default_blocks(self): 40 | """Returns the default blocks, scaled down to make test run faster.""" 41 | return [pix2pix.Block(b.num_filters // 32, b.decoder_keep_prob) 42 | for b in pix2pix._default_generator_blocks()] 43 | 44 | def test_output_size_nn_upsample_conv(self): 45 | batch_size = 2 46 | height, width = 256, 256 47 | num_outputs = 4 48 | 49 | images = tf.ones((batch_size, height, width, 3)) 50 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 51 | logits, _ = pix2pix.pix2pix_generator( 52 | images, num_outputs, blocks=self._reduced_default_blocks(), 53 | upsample_method='nn_upsample_conv') 54 | 55 | with self.test_session() as session: 56 | session.run(tf.global_variables_initializer()) 57 | np_outputs = session.run(logits) 58 | self.assertListEqual([batch_size, height, width, num_outputs], 59 | list(np_outputs.shape)) 60 | 61 | def test_output_size_conv2d_transpose(self): 62 | batch_size = 2 63 | height, width = 256, 256 64 | num_outputs = 4 65 | 66 | images = tf.ones((batch_size, height, width, 3)) 67 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 68 | logits, _ = pix2pix.pix2pix_generator( 69 | images, num_outputs, blocks=self._reduced_default_blocks(), 70 | upsample_method='conv2d_transpose') 71 | 72 | with self.test_session() as session: 73 | session.run(tf.global_variables_initializer()) 74 | np_outputs = session.run(logits) 75 | self.assertListEqual([batch_size, height, width, num_outputs], 76 | list(np_outputs.shape)) 77 | 78 | def test_block_number_dictates_number_of_layers(self): 79 | batch_size = 2 80 | height, width = 256, 256 81 | num_outputs = 4 82 | 83 | images = tf.ones((batch_size, height, width, 3)) 84 | blocks = [ 85 | pix2pix.Block(64, 0.5), 86 | pix2pix.Block(128, 0), 87 | ] 88 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 89 | _, end_points = pix2pix.pix2pix_generator( 90 | images, num_outputs, blocks) 91 | 92 | num_encoder_layers = 0 93 | num_decoder_layers = 0 94 | for end_point in end_points: 95 | if end_point.startswith('encoder'): 96 | num_encoder_layers += 1 97 | elif end_point.startswith('decoder'): 98 | num_decoder_layers += 1 99 | 100 | self.assertEqual(num_encoder_layers, len(blocks)) 101 | self.assertEqual(num_decoder_layers, len(blocks)) 102 | 103 | 104 | class DiscriminatorTest(tf.test.TestCase): 105 | 106 | def _layer_output_size(self, input_size, kernel_size=4, stride=2, pad=2): 107 | return (input_size + pad * 2 - kernel_size) // stride + 1 108 | 109 | def test_four_layers(self): 110 | batch_size = 2 111 | input_size = 256 112 | 113 | output_size = self._layer_output_size(input_size) 114 | output_size = self._layer_output_size(output_size) 115 | output_size = self._layer_output_size(output_size) 116 | output_size 
= self._layer_output_size(output_size, stride=1) 117 | output_size = self._layer_output_size(output_size, stride=1) 118 | 119 | images = tf.ones((batch_size, input_size, input_size, 3)) 120 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 121 | logits, end_points = pix2pix.pix2pix_discriminator( 122 | images, num_filters=[64, 128, 256, 512]) 123 | self.assertListEqual([batch_size, output_size, output_size, 1], 124 | logits.shape.as_list()) 125 | self.assertListEqual([batch_size, output_size, output_size, 1], 126 | end_points['predictions'].shape.as_list()) 127 | 128 | def test_four_layers_no_padding(self): 129 | batch_size = 2 130 | input_size = 256 131 | 132 | output_size = self._layer_output_size(input_size, pad=0) 133 | output_size = self._layer_output_size(output_size, pad=0) 134 | output_size = self._layer_output_size(output_size, pad=0) 135 | output_size = self._layer_output_size(output_size, stride=1, pad=0) 136 | output_size = self._layer_output_size(output_size, stride=1, pad=0) 137 | 138 | images = tf.ones((batch_size, input_size, input_size, 3)) 139 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 140 | logits, end_points = pix2pix.pix2pix_discriminator( 141 | images, num_filters=[64, 128, 256, 512], padding=0) 142 | self.assertListEqual([batch_size, output_size, output_size, 1], 143 | logits.shape.as_list()) 144 | self.assertListEqual([batch_size, output_size, output_size, 1], 145 | end_points['predictions'].shape.as_list()) 146 | 147 | def test_four_layers_wrog_paddig(self): 148 | batch_size = 2 149 | input_size = 256 150 | 151 | images = tf.ones((batch_size, input_size, input_size, 3)) 152 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 153 | with self.assertRaises(TypeError): 154 | pix2pix.pix2pix_discriminator( 155 | images, num_filters=[64, 128, 256, 512], padding=1.5) 156 | 157 | def test_four_layers_negative_padding(self): 158 | batch_size = 2 159 | input_size = 256 160 | 161 | images = tf.ones((batch_size, input_size, input_size, 3)) 162 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 163 | with self.assertRaises(ValueError): 164 | pix2pix.pix2pix_discriminator( 165 | images, num_filters=[64, 128, 256, 512], padding=-1) 166 | 167 | if __name__ == '__main__': 168 | tf.test.main() 169 | -------------------------------------------------------------------------------- /slim/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /slim/preprocessing/cifarnet_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities to preprocess images in CIFAR-10. 
16 | 17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow as tf 24 | 25 | _PADDING = 4 26 | 27 | slim = tf.contrib.slim 28 | 29 | 30 | def preprocess_for_train(image, 31 | output_height, 32 | output_width, 33 | padding=_PADDING, 34 | add_image_summaries=True): 35 | """Preprocesses the given image for training. 36 | 37 | Note that the actual resizing scale is sampled from 38 | [`resize_size_min`, `resize_size_max`]. 39 | 40 | Args: 41 | image: A `Tensor` representing an image of arbitrary size. 42 | output_height: The height of the image after preprocessing. 43 | output_width: The width of the image after preprocessing. 44 | padding: The amound of padding before and after each dimension of the image. 45 | add_image_summaries: Enable image summaries. 46 | 47 | Returns: 48 | A preprocessed image. 49 | """ 50 | if add_image_summaries: 51 | tf.summary.image('image', tf.expand_dims(image, 0)) 52 | 53 | # Transform the image to floats. 54 | image = tf.to_float(image) 55 | if padding > 0: 56 | image = tf.pad(image, [[padding, padding], [padding, padding], [0, 0]]) 57 | # Randomly crop a [height, width] section of the image. 58 | distorted_image = tf.random_crop(image, 59 | [output_height, output_width, 3]) 60 | 61 | # Randomly flip the image horizontally. 62 | distorted_image = tf.image.random_flip_left_right(distorted_image) 63 | 64 | if add_image_summaries: 65 | tf.summary.image('distorted_image', tf.expand_dims(distorted_image, 0)) 66 | 67 | # Because these operations are not commutative, consider randomizing 68 | # the order their operation. 69 | distorted_image = tf.image.random_brightness(distorted_image, 70 | max_delta=63) 71 | distorted_image = tf.image.random_contrast(distorted_image, 72 | lower=0.2, upper=1.8) 73 | # Subtract off the mean and divide by the variance of the pixels. 74 | return tf.image.per_image_standardization(distorted_image) 75 | 76 | 77 | def preprocess_for_eval(image, output_height, output_width, 78 | add_image_summaries=True): 79 | """Preprocesses the given image for evaluation. 80 | 81 | Args: 82 | image: A `Tensor` representing an image of arbitrary size. 83 | output_height: The height of the image after preprocessing. 84 | output_width: The width of the image after preprocessing. 85 | add_image_summaries: Enable image summaries. 86 | 87 | Returns: 88 | A preprocessed image. 89 | """ 90 | if add_image_summaries: 91 | tf.summary.image('image', tf.expand_dims(image, 0)) 92 | # Transform the image to floats. 93 | image = tf.to_float(image) 94 | 95 | # Resize and crop if needed. 96 | resized_image = tf.image.resize_image_with_crop_or_pad(image, 97 | output_width, 98 | output_height) 99 | if add_image_summaries: 100 | tf.summary.image('resized_image', tf.expand_dims(resized_image, 0)) 101 | 102 | # Subtract off the mean and divide by the variance of the pixels. 103 | return tf.image.per_image_standardization(resized_image) 104 | 105 | 106 | def preprocess_image(image, output_height, output_width, is_training=False, 107 | add_image_summaries=True): 108 | """Preprocesses the given image. 109 | 110 | Args: 111 | image: A `Tensor` representing an image of arbitrary size. 112 | output_height: The height of the image after preprocessing. 113 | output_width: The width of the image after preprocessing. 114 | is_training: `True` if we're preprocessing the image for training and 115 | `False` otherwise. 116 | add_image_summaries: Enable image summaries. 
117 | 118 | Returns: 119 | A preprocessed image. 120 | """ 121 | if is_training: 122 | return preprocess_for_train( 123 | image, output_height, output_width, 124 | add_image_summaries=add_image_summaries) 125 | else: 126 | return preprocess_for_eval( 127 | image, output_height, output_width, 128 | add_image_summaries=add_image_summaries) 129 | -------------------------------------------------------------------------------- /slim/preprocessing/lenet_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities for preprocessing.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | def preprocess_image(image, output_height, output_width, is_training): 27 | """Preprocesses the given image. 28 | 29 | Args: 30 | image: A `Tensor` representing an image of arbitrary size. 31 | output_height: The height of the image after preprocessing. 32 | output_width: The width of the image after preprocessing. 33 | is_training: `True` if we're preprocessing the image for training and 34 | `False` otherwise. 35 | 36 | Returns: 37 | A preprocessed image. 38 | """ 39 | image = tf.to_float(image) 40 | image = tf.image.resize_image_with_crop_or_pad( 41 | image, output_width, output_height) 42 | image = tf.subtract(image, 128.0) 43 | image = tf.div(image, 128.0) 44 | return image 45 | -------------------------------------------------------------------------------- /slim/preprocessing/preprocessing_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Contains a factory for building preprocessing functions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from preprocessing import cifarnet_preprocessing 24 | from preprocessing import inception_preprocessing 25 | from preprocessing import lenet_preprocessing 26 | from preprocessing import vgg_preprocessing 27 | 28 | slim = tf.contrib.slim 29 | 30 | 31 | def get_preprocessing(name, is_training=False): 32 | """Returns preprocessing_fn(image, height, width, **kwargs). 33 | 34 | Args: 35 | name: The name of the preprocessing function. 36 | is_training: `True` if the model is being used for training and `False` 37 | otherwise. 38 | 39 | Returns: 40 | preprocessing_fn: A function that preprocesses a single image (pre-batch). 41 | It has the following signature: 42 | image = preprocessing_fn(image, output_height, output_width, ...). 43 | 44 | Raises: 45 | ValueError: If preprocessing `name` is not recognized. 46 | """ 47 | preprocessing_fn_map = { 48 | 'cifarnet': cifarnet_preprocessing, 49 | 'inception': inception_preprocessing, 50 | 'inception_v1': inception_preprocessing, 51 | 'inception_v2': inception_preprocessing, 52 | 'inception_v3': inception_preprocessing, 53 | 'inception_v4': inception_preprocessing, 54 | 'inception_resnet_v2': inception_preprocessing, 55 | 'lenet': lenet_preprocessing, 56 | 'mobilenet_v1': inception_preprocessing, 57 | 'nasnet_mobile': inception_preprocessing, 58 | 'nasnet_large': inception_preprocessing, 59 | 'resnet_v1_50': vgg_preprocessing, 60 | 'resnet_v1_101': vgg_preprocessing, 61 | 'resnet_v1_152': vgg_preprocessing, 62 | 'resnet_v1_200': vgg_preprocessing, 63 | 'resnet_v2_50': vgg_preprocessing, 64 | 'resnet_v2_101': vgg_preprocessing, 65 | 'resnet_v2_152': vgg_preprocessing, 66 | 'resnet_v2_200': vgg_preprocessing, 67 | 'vgg': vgg_preprocessing, 68 | 'vgg_a': vgg_preprocessing, 69 | 'vgg_16': vgg_preprocessing, 70 | 'vgg_19': vgg_preprocessing, 71 | } 72 | 73 | if name not in preprocessing_fn_map: 74 | raise ValueError('Preprocessing name [%s] was not recognized' % name) 75 | 76 | def preprocessing_fn(image, output_height, output_width, **kwargs): 77 | return preprocessing_fn_map[name].preprocess_image( 78 | image, output_height, output_width, is_training=is_training, **kwargs) 79 | 80 | return preprocessing_fn 81 | -------------------------------------------------------------------------------- /slim/scripts/export_mobilenet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
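The factory above is how the rest of slim resolves a preprocessing function from a model name. A short sketch of typical usage (again editorial rather than repo code; the `'lenet'` name and the 28x28 output size are only illustrative choices):

```python
import tensorflow as tf

from preprocessing import preprocessing_factory

# Look up the preprocessing module registered for a model name.
image_preprocessing_fn = preprocessing_factory.get_preprocessing(
    'lenet', is_training=True)

raw_image = tf.placeholder(tf.uint8, shape=[None, None, 3])
# For 'lenet' this resolves to lenet_preprocessing.preprocess_image above.
processed_image = image_preprocessing_fn(raw_image, 28, 28)
```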
15 | # ============================================================================== 16 | 17 | # This script prepares the various versions of MobileNet models for 18 | # use in a mobile application. If you don't specify your own trained checkpoint 19 | # file, it will download pretrained checkpoints for ImageNet. You'll also need 20 | # to have a copy of the TensorFlow source code to run some of the commands; 21 | # by default it is looked for in ../tensorflow, but you can set the 22 | # TENSORFLOW_PATH environment variable before calling the script if your source 23 | # is in a different location. 24 | # The main slim/nets/mobilenet_v1.md description has more details about the 25 | # model, but the main points are that it comes in four size versions, 1.0, 0.75, 26 | # 0.50, and 0.25, which control the number of parameters and so the file size 27 | # of the model, and the input image size, which can be 224, 192, 160, or 128 28 | # pixels, and affects the amount of computation needed, and the latency. 29 | # Here's an example generating a frozen model from pretrained weights: 30 | # 31 | 32 | set -e 33 | 34 | print_usage () { 35 | echo "Creates a frozen mobilenet model suitable for mobile use" 36 | echo "Usage:" 37 | echo "$0 <mobilenet version> <image size> [checkpoint path]" 38 | } 39 | 40 | MOBILENET_VERSION=$1 41 | IMAGE_SIZE=$2 42 | CHECKPOINT=$3 43 | 44 | if [[ ${MOBILENET_VERSION} = "1.0" ]]; then 45 | SLIM_NAME=mobilenet_v1 46 | elif [[ ${MOBILENET_VERSION} = "0.75" ]]; then 47 | SLIM_NAME=mobilenet_v1_075 48 | elif [[ ${MOBILENET_VERSION} = "0.50" ]]; then 49 | SLIM_NAME=mobilenet_v1_050 50 | elif [[ ${MOBILENET_VERSION} = "0.25" ]]; then 51 | SLIM_NAME=mobilenet_v1_025 52 | else 53 | echo "Bad mobilenet version, should be one of 1.0, 0.75, 0.50, or 0.25" 54 | print_usage 55 | exit 1 56 | fi 57 | 58 | if [[ ${IMAGE_SIZE} -ne "224" ]] && [[ ${IMAGE_SIZE} -ne "192" ]] && [[ ${IMAGE_SIZE} -ne "160" ]] && [[ ${IMAGE_SIZE} -ne "128" ]]; then 59 | echo "Bad input image size, should be one of 224, 192, 160, or 128" 60 | print_usage 61 | exit 1 62 | fi 63 | 64 | if [[ -z "${TENSORFLOW_PATH}" ]]; then 65 | TENSORFLOW_PATH=../tensorflow 66 | fi 67 | 68 | if [[ ! -d ${TENSORFLOW_PATH} ]]; then 69 | echo "TensorFlow source folder not found. You should download the source and then set" 70 | echo "the TENSORFLOW_PATH environment variable to point to it, like this:" 71 | echo "export TENSORFLOW_PATH=/my/path/to/tensorflow" 72 | print_usage 73 | exit 1 74 | fi 75 | 76 | MODEL_FOLDER=/tmp/mobilenet_v1_${MOBILENET_VERSION}_${IMAGE_SIZE} 77 | if [[ -d ${MODEL_FOLDER} ]]; then 78 | echo "Model folder ${MODEL_FOLDER} already exists!" 79 | echo "If you want to overwrite it, then 'rm -rf ${MODEL_FOLDER}' first."
80 | print_usage 81 | exit 1 82 | fi 83 | mkdir ${MODEL_FOLDER} 84 | 85 | if [[ ${CHECKPOINT} = "" ]]; then 86 | echo "*******" 87 | echo "Downloading pretrained weights" 88 | echo "*******" 89 | curl "http://download.tensorflow.org/models/mobilenet_v1_${MOBILENET_VERSION}_${IMAGE_SIZE}_2017_06_14.tar.gz" \ 90 | -o ${MODEL_FOLDER}/checkpoints.tar.gz 91 | tar xzf ${MODEL_FOLDER}/checkpoints.tar.gz --directory ${MODEL_FOLDER} 92 | CHECKPOINT=${MODEL_FOLDER}/mobilenet_v1_${MOBILENET_VERSION}_${IMAGE_SIZE}.ckpt 93 | fi 94 | 95 | echo "*******" 96 | echo "Exporting graph architecture to ${MODEL_FOLDER}/unfrozen_graph.pb" 97 | echo "*******" 98 | bazel run slim:export_inference_graph -- \ 99 | --model_name=${SLIM_NAME} --image_size=${IMAGE_SIZE} --logtostderr \ 100 | --output_file=${MODEL_FOLDER}/unfrozen_graph.pb --dataset_dir=${MODEL_FOLDER} 101 | 102 | cd ../tensorflow 103 | 104 | echo "*******" 105 | echo "Freezing graph to ${MODEL_FOLDER}/frozen_graph.pb" 106 | echo "*******" 107 | bazel run tensorflow/python/tools:freeze_graph -- \ 108 | --input_graph=${MODEL_FOLDER}/unfrozen_graph.pb \ 109 | --input_checkpoint=${CHECKPOINT} \ 110 | --input_binary=true --output_graph=${MODEL_FOLDER}/frozen_graph.pb \ 111 | --output_node_names=MobilenetV1/Predictions/Reshape_1 112 | 113 | echo "Quantizing weights to ${MODEL_FOLDER}/quantized_graph.pb" 114 | bazel run tensorflow/tools/graph_transforms:transform_graph -- \ 115 | --in_graph=${MODEL_FOLDER}/frozen_graph.pb \ 116 | --out_graph=${MODEL_FOLDER}/quantized_graph.pb \ 117 | --inputs=input --outputs=MobilenetV1/Predictions/Reshape_1 \ 118 | --transforms='fold_constants fold_batch_norms quantize_weights' 119 | 120 | echo "*******" 121 | echo "Running label_image using the graph" 122 | echo "*******" 123 | bazel build tensorflow/examples/label_image:label_image 124 | bazel-bin/tensorflow/examples/label_image/label_image \ 125 | --input_layer=input --output_layer=MobilenetV1/Predictions/Reshape_1 \ 126 | --graph=${MODEL_FOLDER}/quantized_graph.pb --input_mean=-127 --input_std=127 \ 127 | --image=tensorflow/examples/label_image/data/grace_hopper.jpg \ 128 | --input_width=${IMAGE_SIZE} --input_height=${IMAGE_SIZE} --labels=${MODEL_FOLDER}/labels.txt 129 | 130 | echo "*******" 131 | echo "Saved graphs to ${MODEL_FOLDER}/frozen_graph.pb and ${MODEL_FOLDER}/quantized_graph.pb" 132 | echo "*******" 133 | -------------------------------------------------------------------------------- /slim/scripts/finetune_inception_resnet_v2_on_flowers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the Flowers dataset 19 | # 2. Fine-tunes an Inception Resnet V2 model on the Flowers training set. 20 | # 3. 
Evaluates the model on the Flowers validation set. 21 | # 22 | # Usage: 23 | # cd slim 24 | # ./slim/scripts/finetune_inception_resnet_v2_on_flowers.sh 25 | set -e 26 | 27 | # Where the pre-trained Inception Resnet V2 checkpoint is saved to. 28 | PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints 29 | 30 | # Where the pre-trained Inception Resnet V2 checkpoint is saved to. 31 | MODEL_NAME=inception_resnet_v2 32 | 33 | # Where the training (fine-tuned) checkpoint and logs will be saved to. 34 | TRAIN_DIR=/tmp/flowers-models/${MODEL_NAME} 35 | 36 | # Where the dataset is saved to. 37 | DATASET_DIR=/tmp/flowers 38 | 39 | # Download the pre-trained checkpoint. 40 | if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then 41 | mkdir ${PRETRAINED_CHECKPOINT_DIR} 42 | fi 43 | if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/${MODEL_NAME}.ckpt ]; then 44 | wget http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz 45 | tar -xvf inception_resnet_v2_2016_08_30.tar.gz 46 | mv inception_resnet_v2.ckpt ${PRETRAINED_CHECKPOINT_DIR}/${MODEL_NAME}.ckpt 47 | rm inception_resnet_v2_2016_08_30.tar.gz 48 | fi 49 | 50 | # Download the dataset 51 | python download_and_convert_data.py \ 52 | --dataset_name=flowers \ 53 | --dataset_dir=${DATASET_DIR} 54 | 55 | # Fine-tune only the new layers for 1000 steps. 56 | python train_image_classifier.py \ 57 | --train_dir=${TRAIN_DIR} \ 58 | --dataset_name=flowers \ 59 | --dataset_split_name=train \ 60 | --dataset_dir=${DATASET_DIR} \ 61 | --model_name=${MODEL_NAME} \ 62 | --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/${MODEL_NAME}.ckpt \ 63 | --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \ 64 | --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \ 65 | --max_number_of_steps=1000 \ 66 | --batch_size=32 \ 67 | --learning_rate=0.01 \ 68 | --learning_rate_decay_type=fixed \ 69 | --save_interval_secs=60 \ 70 | --save_summaries_secs=60 \ 71 | --log_every_n_steps=10 \ 72 | --optimizer=rmsprop \ 73 | --weight_decay=0.00004 74 | 75 | # Run evaluation. 76 | python eval_image_classifier.py \ 77 | --checkpoint_path=${TRAIN_DIR} \ 78 | --eval_dir=${TRAIN_DIR} \ 79 | --dataset_name=flowers \ 80 | --dataset_split_name=validation \ 81 | --dataset_dir=${DATASET_DIR} \ 82 | --model_name=${MODEL_NAME} 83 | 84 | # Fine-tune all the new layers for 500 steps. 85 | python train_image_classifier.py \ 86 | --train_dir=${TRAIN_DIR}/all \ 87 | --dataset_name=flowers \ 88 | --dataset_split_name=train \ 89 | --dataset_dir=${DATASET_DIR} \ 90 | --model_name=${MODEL_NAME} \ 91 | --checkpoint_path=${TRAIN_DIR} \ 92 | --max_number_of_steps=500 \ 93 | --batch_size=32 \ 94 | --learning_rate=0.0001 \ 95 | --learning_rate_decay_type=fixed \ 96 | --save_interval_secs=60 \ 97 | --save_summaries_secs=60 \ 98 | --log_every_n_steps=10 \ 99 | --optimizer=rmsprop \ 100 | --weight_decay=0.00004 101 | 102 | # Run evaluation. 103 | python eval_image_classifier.py \ 104 | --checkpoint_path=${TRAIN_DIR}/all \ 105 | --eval_dir=${TRAIN_DIR}/all \ 106 | --dataset_name=flowers \ 107 | --dataset_split_name=validation \ 108 | --dataset_dir=${DATASET_DIR} \ 109 | --model_name=${MODEL_NAME} 110 | -------------------------------------------------------------------------------- /slim/scripts/finetune_inception_v1_on_flowers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the Flowers dataset 19 | # 2. Fine-tunes an InceptionV1 model on the Flowers training set. 20 | # 3. Evaluates the model on the Flowers validation set. 21 | # 22 | # Usage: 23 | # cd slim 24 | # ./slim/scripts/finetune_inception_v1_on_flowers.sh 25 | set -e 26 | 27 | # Where the pre-trained InceptionV1 checkpoint is saved to. 28 | PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints 29 | 30 | # Where the training (fine-tuned) checkpoint and logs will be saved to. 31 | TRAIN_DIR=/tmp/flowers-models/inception_v1 32 | 33 | # Where the dataset is saved to. 34 | DATASET_DIR=/tmp/flowers 35 | 36 | # Download the pre-trained checkpoint. 37 | if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then 38 | mkdir ${PRETRAINED_CHECKPOINT_DIR} 39 | fi 40 | if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/inception_v1.ckpt ]; then 41 | wget http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz 42 | tar -xvf inception_v1_2016_08_28.tar.gz 43 | mv inception_v1.ckpt ${PRETRAINED_CHECKPOINT_DIR}/inception_v1.ckpt 44 | rm inception_v1_2016_08_28.tar.gz 45 | fi 46 | 47 | # Download the dataset 48 | python download_and_convert_data.py \ 49 | --dataset_name=flowers \ 50 | --dataset_dir=${DATASET_DIR} 51 | 52 | # Fine-tune only the new layers for 2000 steps. 53 | python train_image_classifier.py \ 54 | --train_dir=${TRAIN_DIR} \ 55 | --dataset_name=flowers \ 56 | --dataset_split_name=train \ 57 | --dataset_dir=${DATASET_DIR} \ 58 | --model_name=inception_v1 \ 59 | --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/inception_v1.ckpt \ 60 | --checkpoint_exclude_scopes=InceptionV1/Logits \ 61 | --trainable_scopes=InceptionV1/Logits \ 62 | --max_number_of_steps=3000 \ 63 | --batch_size=32 \ 64 | --learning_rate=0.01 \ 65 | --save_interval_secs=60 \ 66 | --save_summaries_secs=60 \ 67 | --log_every_n_steps=100 \ 68 | --optimizer=rmsprop \ 69 | --weight_decay=0.00004 70 | 71 | # Run evaluation. 72 | python eval_image_classifier.py \ 73 | --checkpoint_path=${TRAIN_DIR} \ 74 | --eval_dir=${TRAIN_DIR} \ 75 | --dataset_name=flowers \ 76 | --dataset_split_name=validation \ 77 | --dataset_dir=${DATASET_DIR} \ 78 | --model_name=inception_v1 79 | 80 | # Fine-tune all the new layers for 1000 steps. 81 | python train_image_classifier.py \ 82 | --train_dir=${TRAIN_DIR}/all \ 83 | --dataset_name=flowers \ 84 | --dataset_split_name=train \ 85 | --dataset_dir=${DATASET_DIR} \ 86 | --checkpoint_path=${TRAIN_DIR} \ 87 | --model_name=inception_v1 \ 88 | --max_number_of_steps=1000 \ 89 | --batch_size=32 \ 90 | --learning_rate=0.001 \ 91 | --save_interval_secs=60 \ 92 | --save_summaries_secs=60 \ 93 | --log_every_n_steps=100 \ 94 | --optimizer=rmsprop \ 95 | --weight_decay=0.00004 96 | 97 | # Run evaluation. 
98 | python eval_image_classifier.py \ 99 | --checkpoint_path=${TRAIN_DIR}/all \ 100 | --eval_dir=${TRAIN_DIR}/all \ 101 | --dataset_name=flowers \ 102 | --dataset_split_name=validation \ 103 | --dataset_dir=${DATASET_DIR} \ 104 | --model_name=inception_v1 105 | -------------------------------------------------------------------------------- /slim/scripts/finetune_inception_v3_on_flowers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the Flowers dataset 19 | # 2. Fine-tunes an InceptionV3 model on the Flowers training set. 20 | # 3. Evaluates the model on the Flowers validation set. 21 | # 22 | # Usage: 23 | # cd slim 24 | # ./slim/scripts/finetune_inception_v3_on_flowers.sh 25 | set -e 26 | 27 | # Where the pre-trained InceptionV3 checkpoint is saved to. 28 | PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints 29 | 30 | # Where the training (fine-tuned) checkpoint and logs will be saved to. 31 | TRAIN_DIR=/tmp/flowers-models/inception_v3 32 | 33 | # Where the dataset is saved to. 34 | DATASET_DIR=/tmp/flowers 35 | 36 | # Download the pre-trained checkpoint. 37 | if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then 38 | mkdir ${PRETRAINED_CHECKPOINT_DIR} 39 | fi 40 | if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/inception_v3.ckpt ]; then 41 | wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz 42 | tar -xvf inception_v3_2016_08_28.tar.gz 43 | mv inception_v3.ckpt ${PRETRAINED_CHECKPOINT_DIR}/inception_v3.ckpt 44 | rm inception_v3_2016_08_28.tar.gz 45 | fi 46 | 47 | # Download the dataset 48 | python download_and_convert_data.py \ 49 | --dataset_name=flowers \ 50 | --dataset_dir=${DATASET_DIR} 51 | 52 | # Fine-tune only the new layers for 1000 steps. 53 | python train_image_classifier.py \ 54 | --train_dir=${TRAIN_DIR} \ 55 | --dataset_name=flowers \ 56 | --dataset_split_name=train \ 57 | --dataset_dir=${DATASET_DIR} \ 58 | --model_name=inception_v3 \ 59 | --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/inception_v3.ckpt \ 60 | --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \ 61 | --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \ 62 | --max_number_of_steps=1000 \ 63 | --batch_size=32 \ 64 | --learning_rate=0.01 \ 65 | --learning_rate_decay_type=fixed \ 66 | --save_interval_secs=60 \ 67 | --save_summaries_secs=60 \ 68 | --log_every_n_steps=100 \ 69 | --optimizer=rmsprop \ 70 | --weight_decay=0.00004 71 | 72 | # Run evaluation. 
73 | python eval_image_classifier.py \ 74 | --checkpoint_path=${TRAIN_DIR} \ 75 | --eval_dir=${TRAIN_DIR} \ 76 | --dataset_name=flowers \ 77 | --dataset_split_name=validation \ 78 | --dataset_dir=${DATASET_DIR} \ 79 | --model_name=inception_v3 80 | 81 | # Fine-tune all the new layers for 500 steps. 82 | python train_image_classifier.py \ 83 | --train_dir=${TRAIN_DIR}/all \ 84 | --dataset_name=flowers \ 85 | --dataset_split_name=train \ 86 | --dataset_dir=${DATASET_DIR} \ 87 | --model_name=inception_v3 \ 88 | --checkpoint_path=${TRAIN_DIR} \ 89 | --max_number_of_steps=500 \ 90 | --batch_size=32 \ 91 | --learning_rate=0.0001 \ 92 | --learning_rate_decay_type=fixed \ 93 | --save_interval_secs=60 \ 94 | --save_summaries_secs=60 \ 95 | --log_every_n_steps=10 \ 96 | --optimizer=rmsprop \ 97 | --weight_decay=0.00004 98 | 99 | # Run evaluation. 100 | python eval_image_classifier.py \ 101 | --checkpoint_path=${TRAIN_DIR}/all \ 102 | --eval_dir=${TRAIN_DIR}/all \ 103 | --dataset_name=flowers \ 104 | --dataset_split_name=validation \ 105 | --dataset_dir=${DATASET_DIR} \ 106 | --model_name=inception_v3 107 | -------------------------------------------------------------------------------- /slim/scripts/finetune_resnet_v1_50_on_flowers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the Flowers dataset 19 | # 2. Fine-tunes a ResNetV1-50 model on the Flowers training set. 20 | # 3. Evaluates the model on the Flowers validation set. 21 | # 22 | # Usage: 23 | # cd slim 24 | # ./slim/scripts/finetune_resnet_v1_50_on_flowers.sh 25 | set -e 26 | 27 | # Where the pre-trained ResNetV1-50 checkpoint is saved to. 28 | PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints 29 | 30 | # Where the training (fine-tuned) checkpoint and logs will be saved to. 31 | TRAIN_DIR=/tmp/flowers-models/resnet_v1_50 32 | 33 | # Where the dataset is saved to. 34 | DATASET_DIR=/tmp/flowers 35 | 36 | # Download the pre-trained checkpoint. 37 | if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then 38 | mkdir ${PRETRAINED_CHECKPOINT_DIR} 39 | fi 40 | if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt ]; then 41 | wget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz 42 | tar -xvf resnet_v1_50_2016_08_28.tar.gz 43 | mv resnet_v1_50.ckpt ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt 44 | rm resnet_v1_50_2016_08_28.tar.gz 45 | fi 46 | 47 | # Download the dataset 48 | python download_and_convert_data.py \ 49 | --dataset_name=flowers \ 50 | --dataset_dir=${DATASET_DIR} 51 | 52 | # Fine-tune only the new layers for 3000 steps. 
53 | python train_image_classifier.py \ 54 | --train_dir=${TRAIN_DIR} \ 55 | --dataset_name=flowers \ 56 | --dataset_split_name=train \ 57 | --dataset_dir=${DATASET_DIR} \ 58 | --model_name=resnet_v1_50 \ 59 | --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt \ 60 | --checkpoint_exclude_scopes=resnet_v1_50/logits \ 61 | --trainable_scopes=resnet_v1_50/logits \ 62 | --max_number_of_steps=3000 \ 63 | --batch_size=32 \ 64 | --learning_rate=0.01 \ 65 | --save_interval_secs=60 \ 66 | --save_summaries_secs=60 \ 67 | --log_every_n_steps=100 \ 68 | --optimizer=rmsprop \ 69 | --weight_decay=0.00004 70 | 71 | # Run evaluation. 72 | python eval_image_classifier.py \ 73 | --checkpoint_path=${TRAIN_DIR} \ 74 | --eval_dir=${TRAIN_DIR} \ 75 | --dataset_name=flowers \ 76 | --dataset_split_name=validation \ 77 | --dataset_dir=${DATASET_DIR} \ 78 | --model_name=resnet_v1_50 79 | 80 | # Fine-tune all the new layers for 1000 steps. 81 | python train_image_classifier.py \ 82 | --train_dir=${TRAIN_DIR}/all \ 83 | --dataset_name=flowers \ 84 | --dataset_split_name=train \ 85 | --dataset_dir=${DATASET_DIR} \ 86 | --checkpoint_path=${TRAIN_DIR} \ 87 | --model_name=resnet_v1_50 \ 88 | --max_number_of_steps=1000 \ 89 | --batch_size=32 \ 90 | --learning_rate=0.001 \ 91 | --save_interval_secs=60 \ 92 | --save_summaries_secs=60 \ 93 | --log_every_n_steps=100 \ 94 | --optimizer=rmsprop \ 95 | --weight_decay=0.00004 96 | 97 | # Run evaluation. 98 | python eval_image_classifier.py \ 99 | --checkpoint_path=${TRAIN_DIR}/all \ 100 | --eval_dir=${TRAIN_DIR}/all \ 101 | --dataset_name=flowers \ 102 | --dataset_split_name=validation \ 103 | --dataset_dir=${DATASET_DIR} \ 104 | --model_name=resnet_v1_50 105 | -------------------------------------------------------------------------------- /slim/scripts/train_cifarnet_on_cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the Cifar10 dataset 19 | # 2. Trains a CifarNet model on the Cifar10 training set. 20 | # 3. Evaluates the model on the Cifar10 testing set. 21 | # 22 | # Usage: 23 | # cd slim 24 | # ./scripts/train_cifarnet_on_cifar10.sh 25 | set -e 26 | 27 | # Where the checkpoint and logs will be saved to. 28 | TRAIN_DIR=/tmp/cifarnet-model 29 | 30 | # Where the dataset is saved to. 31 | DATASET_DIR=/tmp/cifar10 32 | 33 | # Download the dataset 34 | python download_and_convert_data.py \ 35 | --dataset_name=cifar10 \ 36 | --dataset_dir=${DATASET_DIR} 37 | 38 | # Run training. 
39 | python train_image_classifier.py \ 40 | --train_dir=${TRAIN_DIR} \ 41 | --dataset_name=cifar10 \ 42 | --dataset_split_name=train \ 43 | --dataset_dir=${DATASET_DIR} \ 44 | --model_name=cifarnet \ 45 | --preprocessing_name=cifarnet \ 46 | --max_number_of_steps=100000 \ 47 | --batch_size=128 \ 48 | --save_interval_secs=120 \ 49 | --save_summaries_secs=120 \ 50 | --log_every_n_steps=100 \ 51 | --optimizer=sgd \ 52 | --learning_rate=0.1 \ 53 | --learning_rate_decay_factor=0.1 \ 54 | --num_epochs_per_decay=200 \ 55 | --weight_decay=0.004 56 | 57 | # Run evaluation. 58 | python eval_image_classifier.py \ 59 | --checkpoint_path=${TRAIN_DIR} \ 60 | --eval_dir=${TRAIN_DIR} \ 61 | --dataset_name=cifar10 \ 62 | --dataset_split_name=test \ 63 | --dataset_dir=${DATASET_DIR} \ 64 | --model_name=cifarnet 65 | -------------------------------------------------------------------------------- /slim/scripts/train_lenet_on_mnist.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the MNIST dataset 19 | # 2. Trains a LeNet model on the MNIST training set. 20 | # 3. Evaluates the model on the MNIST testing set. 21 | # 22 | # Usage: 23 | # cd slim 24 | # ./slim/scripts/train_lenet_on_mnist.sh 25 | set -e 26 | 27 | # Where the checkpoint and logs will be saved to. 28 | TRAIN_DIR=/tmp/lenet-model 29 | 30 | # Where the dataset is saved to. 31 | DATASET_DIR=/tmp/mnist 32 | 33 | # Download the dataset 34 | python download_and_convert_data.py \ 35 | --dataset_name=mnist \ 36 | --dataset_dir=${DATASET_DIR} 37 | 38 | # Run training. 39 | python train_image_classifier.py \ 40 | --train_dir=${TRAIN_DIR} \ 41 | --dataset_name=mnist \ 42 | --dataset_split_name=train \ 43 | --dataset_dir=${DATASET_DIR} \ 44 | --model_name=lenet \ 45 | --preprocessing_name=lenet \ 46 | --max_number_of_steps=20000 \ 47 | --batch_size=50 \ 48 | --learning_rate=0.01 \ 49 | --save_interval_secs=60 \ 50 | --save_summaries_secs=60 \ 51 | --log_every_n_steps=100 \ 52 | --optimizer=sgd \ 53 | --learning_rate_decay_type=fixed \ 54 | --weight_decay=0 55 | 56 | # Run evaluation. 57 | python eval_image_classifier.py \ 58 | --checkpoint_path=${TRAIN_DIR} \ 59 | --eval_dir=${TRAIN_DIR} \ 60 | --dataset_name=mnist \ 61 | --dataset_split_name=test \ 62 | --dataset_dir=${DATASET_DIR} \ 63 | --model_name=lenet 64 | -------------------------------------------------------------------------------- /slim/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Setup script for slim.""" 16 | 17 | from setuptools import find_packages 18 | from setuptools import setup 19 | 20 | 21 | setup( 22 | name='slim', 23 | version='0.1', 24 | include_package_data=True, 25 | packages=find_packages(), 26 | description='tf-slim', 27 | ) 28 | -------------------------------------------------------------------------------- /slim/wikiart/flowers_train_00000-of-00005.tfrecord: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlberkeley/Creative-Adversarial-Networks/fea29d4348a650a40322fc4da645395d3d0f089c/slim/wikiart/flowers_train_00000-of-00005.tfrecord --------------------------------------------------------------------------------