├── .gitignore
├── README.md
├── conf
│   ├── candy.yml
│   ├── cubist.yml
│   ├── denoised_starry.yml
│   ├── feathers.yml
│   ├── mosaic.yml
│   ├── scream.yml
│   ├── udnie.yml
│   └── wave.yml
├── eval.py
├── export.py
├── img
│   ├── candy.jpg
│   ├── cubist.jpg
│   ├── denoised_starry.jpg
│   ├── feathers.jpg
│   ├── mosaic.jpg
│   ├── results
│   │   ├── cubist.jpg
│   │   ├── denoised_starry.jpg
│   │   ├── feathers.jpg
│   │   ├── mosaic.jpg
│   │   ├── scream.jpg
│   │   ├── style_cubist.jpg
│   │   ├── style_denoised_starry.jpg
│   │   ├── style_feathers.jpg
│   │   ├── style_mosaic.jpg
│   │   ├── style_scream.jpg
│   │   ├── style_udnie.jpg
│   │   ├── style_wave.jpg
│   │   ├── udnie.jpg
│   │   └── wave.jpg
│   ├── scream.jpg
│   ├── starry.jpg
│   ├── test.jpg
│   ├── test1.jpg
│   ├── test2.jpg
│   ├── test3.jpg
│   ├── test4.jpg
│   ├── test5.jpg
│   ├── udnie.jpg
│   └── wave.jpg
├── losses.py
├── model.py
├── nets
│   ├── __init__.py
│   ├── alexnet.py
│   ├── alexnet_test.py
│   ├── cifarnet.py
│   ├── inception.py
│   ├── inception_resnet_v2.py
│   ├── inception_resnet_v2_test.py
│   ├── inception_utils.py
│   ├── inception_v1.py
│   ├── inception_v1_test.py
│   ├── inception_v2.py
│   ├── inception_v2_test.py
│   ├── inception_v3.py
│   ├── inception_v3_test.py
│   ├── inception_v4.py
│   ├── inception_v4_test.py
│   ├── lenet.py
│   ├── nets_factory.py
│   ├── nets_factory_test.py
│   ├── overfeat.py
│   ├── overfeat_test.py
│   ├── resnet_utils.py
│   ├── resnet_v1.py
│   ├── resnet_v1_test.py
│   ├── resnet_v2.py
│   ├── resnet_v2_test.py
│   ├── vgg.py
│   └── vgg_test.py
├── preprocessing
│   ├── __init__.py
│   ├── cifarnet_preprocessing.py
│   ├── inception_preprocessing.py
│   ├── lenet_preprocessing.py
│   ├── preprocessing_factory.py
│   └── vgg_preprocessing.py
├── reader.py
├── train.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | 91 | 92 | slim-official/ 93 | train2014 94 | generated/ 95 | models/ 96 | tensorboard/ 97 | result/ 98 | pretrained/ 99 | 100 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # fast-neural-style-tensorflow 2 | 3 | A tensorflow implementation for [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/abs/1603.08155). 4 | 5 | This code is based on [Tensorflow-Slim](https://github.com/tensorflow/models/tree/master/slim) and [OlavHN/fast-neural-style](https://github.com/OlavHN/fast-neural-style). 6 | 7 | ## Samples: 8 | 9 | | configuration | style | sample | 10 | | :---: | :----: | :----: | 11 | | [wave.yml](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/conf/wave.yml) |![](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/img/results/style_wave.jpg)| ![](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/img/results/wave.jpg) | 12 | | [cubist.yml](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/conf/cubist.yml) |![](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/img/results/style_cubist.jpg)| ![](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/img/results/cubist.jpg) | 13 | | [denoised_starry.yml](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/conf/denoised_starry.yml) |![](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/img/results/style_denoised_starry.jpg)| ![](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/img/results/denoised_starry.jpg) | 14 | | [mosaic.yml](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/conf/mosaic.yml) |![](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/img/results/style_mosaic.jpg)| ![](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/img/results/mosaic.jpg) | 15 | | [scream.yml](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/conf/scream.yml) |![](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/img/results/style_scream.jpg)| ![](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/img/results/scream.jpg) | 16 | | [feathers.yml](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/conf/feathers.yml) |![](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/img/results/style_feathers.jpg)| ![](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/img/results/feathers.jpg) | 17 | | 
[udnie.yml](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/conf/udnie.yml) |![](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/img/results/style_udnie.jpg)| ![](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/img/results/udnie.jpg) |
18 | 
19 | ## Requirements and Prerequisites:
20 | - Python 2.7.x
21 | - Tensorflow >= 1.0 is now supported
22 | 
23 | Attention: This code also supports Tensorflow == 0.11. If that is your version, use the commit 5309a2a (git reset --hard 5309a2a).
24 | 
25 | Also make sure you have installed pyyaml:
26 | ```
27 | pip install pyyaml
28 | ```
29 | 
30 | ## Use Trained Models:
31 | 
32 | You can download all 7 trained models from [Baidu Drive](https://pan.baidu.com/s/1i4GTS4d).
33 | 
34 | To generate a sample from the model "wave.ckpt-done", run:
35 | 
36 | ```
37 | python eval.py --model_file <your path to wave.ckpt-done> --image_file img/test.jpg
38 | ```
39 | 
40 | Then check out generated/res.jpg.
41 | 
42 | ## Train a Model:
43 | To train a model from scratch, first download the [VGG16 model](http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz) from Tensorflow Slim. Extract the file vgg_16.ckpt, then copy it to the folder pretrained/:
44 | ```
45 | cd <this repo>
46 | mkdir pretrained
47 | cp <your path to vgg_16.ckpt> pretrained/
48 | ```
49 | 
50 | Then download the [COCO dataset](http://msvocds.blob.core.windows.net/coco2014/train2014.zip). Unzip it, and you will have a folder named "train2014" with many raw images in it. Then create a symbolic link to it:
51 | ```
52 | cd <this repo>
53 | ln -s <your path to the folder "train2014"> train2014
54 | ```
55 | 
56 | Train the model of "wave":
57 | ```
58 | python train.py -c conf/wave.yml
59 | ```
60 | 
61 | (Optional) Use tensorboard:
62 | ```
63 | tensorboard --logdir models/wave/
64 | ```
65 | 
66 | Checkpoints will be written to "models/wave/".
67 | 
68 | View the [configuration file](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/conf/wave.yml) for details.
69 | 
--------------------------------------------------------------------------------
/conf/candy.yml:
--------------------------------------------------------------------------------
1 | ## Basic configuration
2 | style_image: img/candy.jpg # targeted style image
3 | naming: "candy" # the name of this model. Determine the path to save checkpoint and events file.
4 | model_path: models # root path to save checkpoint and events file. The final path would be /
5 | 
6 | ## Weight of the loss
7 | content_weight: 1.0 # weight for content features loss
8 | style_weight: 50.0 # weight for style features loss
9 | tv_weight: 0.0 # weight for total variation loss
10 | 
11 | ## The size, the iter number to run
12 | image_size: 256
13 | batch_size: 4
14 | epoch: 2
15 | 
16 | ## Loss Network
17 | loss_model: "vgg_16"
18 | content_layers: # use these layers for content loss
19 | - "vgg_16/conv3/conv3_3"
20 | style_layers: # use these layers for style loss
21 | - "vgg_16/conv1/conv1_2"
22 | - "vgg_16/conv2/conv2_2"
23 | - "vgg_16/conv3/conv3_3"
24 | - "vgg_16/conv4/conv4_3"
25 | checkpoint_exclude_scopes: "vgg_16/fc" # we only use the convolution layers, so ignore fc layers.
26 | loss_model_file: "pretrained/vgg_16.ckpt" # the path to the checkpoint
27 | 
--------------------------------------------------------------------------------
/conf/cubist.yml:
--------------------------------------------------------------------------------
1 | ## Basic configuration
2 | style_image: img/cubist.jpg # targeted style image
3 | naming: "cubist" # the name of this model. 
Determine the path to save checkpoint and events file. 4 | model_path: models # root path to save checkpoint and events file. The final path would be / 5 | 6 | ## Weight of the loss 7 | content_weight: 1.0 # weight for content features loss 8 | style_weight: 180.0 # weight for style features loss 9 | tv_weight: 0.0 # weight for total variation loss 10 | 11 | ## The size, the iter number to run 12 | image_size: 256 13 | batch_size: 4 14 | epoch: 2 15 | 16 | ## Loss Network 17 | loss_model: "vgg_16" 18 | content_layers: # use these layers for content loss 19 | - "vgg_16/conv3/conv3_3" 20 | style_layers: # use these layers for style loss 21 | - "vgg_16/conv1/conv1_2" 22 | - "vgg_16/conv2/conv2_2" 23 | - "vgg_16/conv3/conv3_3" 24 | - "vgg_16/conv4/conv4_3" 25 | checkpoint_exclude_scopes: "vgg_16/fc" # we only use the convolution layers, so ignore fc layers. 26 | loss_model_file: "pretrained/vgg_16.ckpt" # the path to the checkpoint -------------------------------------------------------------------------------- /conf/denoised_starry.yml: -------------------------------------------------------------------------------- 1 | ## Basic configuration 2 | style_image: img/denoised_starry.jpg # targeted style image 3 | naming: "denoised_starry" # the name of this model. Determine the path to save checkpoint and events file. 4 | model_path: models # root path to save checkpoint and events file. The final path would be / 5 | 6 | ## Weight of the loss 7 | content_weight: 1.0 # weight for content features loss 8 | style_weight: 250 # weight for style features loss 9 | tv_weight: 0.0 # weight for total variation loss 10 | 11 | ## The size, the iter number to run 12 | image_size: 256 13 | batch_size: 4 14 | epoch: 2 15 | 16 | ## Loss Network 17 | loss_model: "vgg_16" 18 | content_layers: # use these layers for content loss 19 | - "vgg_16/conv3/conv3_3" 20 | style_layers: # use these layers for style loss 21 | - "vgg_16/conv1/conv1_2" 22 | - "vgg_16/conv2/conv2_2" 23 | - "vgg_16/conv3/conv3_3" 24 | - "vgg_16/conv4/conv4_3" 25 | checkpoint_exclude_scopes: "vgg_16/fc" # we only use the convolution layers, so ignore fc layers. 26 | loss_model_file: "pretrained/vgg_16.ckpt" # the path to the checkpoint -------------------------------------------------------------------------------- /conf/feathers.yml: -------------------------------------------------------------------------------- 1 | ## Basic configuration 2 | style_image: img/feathers.jpg # targeted style image 3 | naming: "feathers" # the name of this model. Determine the path to save checkpoint and events file. 4 | model_path: models # root path to save checkpoint and events file. The final path would be / 5 | 6 | ## Weight of the loss 7 | content_weight: 1.0 # weight for content features loss 8 | style_weight: 220.0 # weight for style features loss 9 | tv_weight: 0.0 # weight for total variation loss 10 | 11 | ## The size, the iter number to run 12 | image_size: 256 13 | batch_size: 4 14 | epoch: 2 15 | 16 | ## Loss Network 17 | loss_model: "vgg_16" 18 | content_layers: # use these layers for content loss 19 | - "vgg_16/conv3/conv3_3" 20 | style_layers: # use these layers for style loss 21 | - "vgg_16/conv1/conv1_2" 22 | - "vgg_16/conv2/conv2_2" 23 | - "vgg_16/conv3/conv3_3" 24 | - "vgg_16/conv4/conv4_3" 25 | checkpoint_exclude_scopes: "vgg_16/fc" # we only use the convolution layers, so ignore fc layers. 
26 | loss_model_file: "pretrained/vgg_16.ckpt" # the path to the checkpoint -------------------------------------------------------------------------------- /conf/mosaic.yml: -------------------------------------------------------------------------------- 1 | ## Basic configuration 2 | style_image: img/mosaic.jpg # targeted style image 3 | naming: "mosaic" # the name of this model. Determine the path to save checkpoint and events file. 4 | model_path: models # root path to save checkpoint and events file. The final path would be / 5 | 6 | ## Weight of the loss 7 | content_weight: 1.0 # weight for content features loss 8 | style_weight: 100.0 # weight for style features loss 9 | tv_weight: 0.0 # weight for total variation loss 10 | 11 | ## The size, the iter number to run 12 | image_size: 256 13 | batch_size: 4 14 | epoch: 2 15 | 16 | ## Loss Network 17 | loss_model: "vgg_16" 18 | content_layers: # use these layers for content loss 19 | - "vgg_16/conv3/conv3_3" 20 | style_layers: # use these layers for style loss 21 | - "vgg_16/conv1/conv1_2" 22 | - "vgg_16/conv2/conv2_2" 23 | - "vgg_16/conv3/conv3_3" 24 | - "vgg_16/conv4/conv4_3" 25 | checkpoint_exclude_scopes: "vgg_16/fc" # we only use the convolution layers, so ignore fc layers. 26 | loss_model_file: "pretrained/vgg_16.ckpt" # the path to the checkpoint -------------------------------------------------------------------------------- /conf/scream.yml: -------------------------------------------------------------------------------- 1 | ## Basic configuration 2 | style_image: img/scream.jpg # targeted style image 3 | naming: "scream" # the name of this model. Determine the path to save checkpoint and events file. 4 | model_path: models # root path to save checkpoint and events file. The final path would be / 5 | 6 | ## Weight of the loss 7 | content_weight: 1.0 # weight for content features loss 8 | style_weight: 250.0 # weight for style features loss 9 | tv_weight: 0.0 # weight for total variation loss 10 | 11 | ## The size, the iter number to run 12 | image_size: 256 13 | batch_size: 4 14 | epoch: 2 15 | 16 | ## Loss Network 17 | loss_model: "vgg_16" 18 | content_layers: # use these layers for content loss 19 | - "vgg_16/conv3/conv3_3" 20 | style_layers: # use these layers for style loss 21 | - "vgg_16/conv1/conv1_2" 22 | - "vgg_16/conv2/conv2_2" 23 | - "vgg_16/conv3/conv3_3" 24 | - "vgg_16/conv4/conv4_3" 25 | checkpoint_exclude_scopes: "vgg_16/fc" # we only use the convolution layers, so ignore fc layers. 26 | loss_model_file: "pretrained/vgg_16.ckpt" # the path to the checkpoint -------------------------------------------------------------------------------- /conf/udnie.yml: -------------------------------------------------------------------------------- 1 | ## Basic configuration 2 | style_image: img/udnie.jpg # targeted style image 3 | naming: "udnie" # the name of this model. Determine the path to save checkpoint and events file. 4 | model_path: models # root path to save checkpoint and events file. 
The final path would be / 5 | 6 | ## Weight of the loss 7 | content_weight: 1.0 # weight for content features loss 8 | style_weight: 200.0 # weight for style features loss 9 | tv_weight: 0.0 # weight for total variation loss 10 | 11 | ## The size, the iter number to run 12 | image_size: 256 13 | batch_size: 4 14 | epoch: 2 15 | 16 | ## Loss Network 17 | loss_model: "vgg_16" 18 | content_layers: # use these layers for content loss 19 | - "vgg_16/conv3/conv3_3" 20 | style_layers: # use these layers for style loss 21 | - "vgg_16/conv1/conv1_2" 22 | - "vgg_16/conv2/conv2_2" 23 | - "vgg_16/conv3/conv3_3" 24 | - "vgg_16/conv4/conv4_3" 25 | checkpoint_exclude_scopes: "vgg_16/fc" # we only use the convolution layers, so ignore fc layers. 26 | loss_model_file: "pretrained/vgg_16.ckpt" # the path to the checkpoint -------------------------------------------------------------------------------- /conf/wave.yml: -------------------------------------------------------------------------------- 1 | ## Basic configuration 2 | style_image: img/wave.jpg # targeted style image 3 | naming: "wave" # the name of this model. Determine the path to save checkpoint and events file. 4 | model_path: models # root path to save checkpoint and events file. The final path would be / 5 | 6 | ## Weight of the loss 7 | content_weight: 1.0 # weight for content features loss 8 | style_weight: 220.0 # weight for style features loss 9 | tv_weight: 0.0 # weight for total variation loss 10 | 11 | ## The size, the iter number to run 12 | image_size: 256 13 | batch_size: 4 14 | epoch: 2 15 | 16 | ## Loss Network 17 | loss_model: "vgg_16" 18 | content_layers: # use these layers for content loss 19 | - "vgg_16/conv3/conv3_3" 20 | style_layers: # use these layers for style loss 21 | - "vgg_16/conv1/conv1_2" 22 | - "vgg_16/conv2/conv2_2" 23 | - "vgg_16/conv3/conv3_3" 24 | - "vgg_16/conv4/conv4_3" 25 | checkpoint_exclude_scopes: "vgg_16/fc" # we only use the convolution layers, so ignore fc layers. 26 | loss_model_file: "pretrained/vgg_16.ckpt" # the path to the checkpoint 27 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from __future__ import print_function 3 | import tensorflow as tf 4 | from preprocessing import preprocessing_factory 5 | import reader 6 | import model 7 | import time 8 | import os 9 | 10 | tf.app.flags.DEFINE_string('loss_model', 'vgg_16', 'The name of the architecture to evaluate. ' 11 | 'You can view all the support models in nets/nets_factory.py') 12 | tf.app.flags.DEFINE_integer('image_size', 256, 'Image size to train.') 13 | tf.app.flags.DEFINE_string("model_file", "models.ckpt", "") 14 | tf.app.flags.DEFINE_string("image_file", "a.jpg", "") 15 | 16 | FLAGS = tf.app.flags.FLAGS 17 | 18 | 19 | def main(_): 20 | 21 | # Get image's height and width. 22 | height = 0 23 | width = 0 24 | with open(FLAGS.image_file, 'rb') as img: 25 | with tf.Session().as_default() as sess: 26 | if FLAGS.image_file.lower().endswith('png'): 27 | image = sess.run(tf.image.decode_png(img.read())) 28 | else: 29 | image = sess.run(tf.image.decode_jpeg(img.read())) 30 | height = image.shape[0] 31 | width = image.shape[1] 32 | tf.logging.info('Image size: %dx%d' % (width, height)) 33 | 34 | with tf.Graph().as_default(): 35 | with tf.Session().as_default() as sess: 36 | 37 | # Read image data. 
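            # The preprocessing function obtained below depends on the chosen loss model; for
            # the default "vgg_16" it essentially subtracts the per-channel ImageNet means
            # (and resizes when asked to), so the input matches what the loss network expects.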
38 | image_preprocessing_fn, _ = preprocessing_factory.get_preprocessing( 39 | FLAGS.loss_model, 40 | is_training=False) 41 | image = reader.get_image(FLAGS.image_file, height, width, image_preprocessing_fn) 42 | 43 | # Add batch dimension 44 | image = tf.expand_dims(image, 0) 45 | 46 | generated = model.net(image, training=False) 47 | generated = tf.cast(generated, tf.uint8) 48 | 49 | # Remove batch dimension 50 | generated = tf.squeeze(generated, [0]) 51 | 52 | # Restore model variables. 53 | saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V1) 54 | sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()]) 55 | # Use absolute path 56 | FLAGS.model_file = os.path.abspath(FLAGS.model_file) 57 | saver.restore(sess, FLAGS.model_file) 58 | 59 | # Make sure 'generated' directory exists. 60 | generated_file = 'generated/res.jpg' 61 | if os.path.exists('generated') is False: 62 | os.makedirs('generated') 63 | 64 | # Generate and write image data to file. 65 | with open(generated_file, 'wb') as img: 66 | start_time = time.time() 67 | img.write(sess.run(tf.image.encode_jpeg(generated))) 68 | end_time = time.time() 69 | tf.logging.info('Elapsed time: %fs' % (end_time - start_time)) 70 | 71 | tf.logging.info('Done. Please check %s.' % generated_file) 72 | 73 | 74 | if __name__ == '__main__': 75 | tf.logging.set_verbosity(tf.logging.INFO) 76 | tf.app.run() 77 | -------------------------------------------------------------------------------- /export.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from __future__ import print_function 3 | import tensorflow as tf 4 | import argparse 5 | import time 6 | import os 7 | 8 | import model 9 | import utils 10 | 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('-m', '--model_file', help='the path to the model file') 15 | parser.add_argument('-n', '--model_name', default='transfer', help='the name of the model') 16 | parser.add_argument('-d', dest='is_debug', action='store_true') 17 | parser.set_defaults(is_debug=False) 18 | return parser.parse_args() 19 | 20 | 21 | def main(args): 22 | g = tf.Graph() # A new graph 23 | with g.as_default(): 24 | with tf.Session() as sess: 25 | # Building graph. 26 | image_data = tf.placeholder(tf.int32, name='input_image') 27 | height = tf.placeholder(tf.int32, name='height') 28 | width = tf.placeholder(tf.int32, name='width') 29 | 30 | # Reshape data 31 | image = tf.reshape(image_data, [height, width, 3]) 32 | 33 | processed_image = utils.mean_image_subtraction( 34 | image, [123.68, 116.779, 103.939]) # Preprocessing image 35 | batched_image = tf.expand_dims(processed_image, 0) # Add batch dimension 36 | generated_image = model.net(batched_image, training=False) 37 | casted_image = tf.cast(generated_image, tf.int32) 38 | # Remove batch dimension 39 | squeezed_image = tf.squeeze(casted_image, [0]) 40 | cropped_image = tf.slice(squeezed_image, [0, 0, 0], [height, width, 3]) 41 | # stylized_image = tf.image.encode_jpeg(squeezed_image, name='output_image') 42 | stylized_image_data = tf.reshape(cropped_image, [-1], name='output_image') 43 | 44 | # Restore model variables. 45 | saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V1) 46 | sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()]) 47 | # Use absolute path. 
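            # Saver.restore resolves a relative checkpoint path against the current working
            # directory, so converting it to an absolute path keeps the export working no
            # matter where the script is launched from.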
48 | model_file = os.path.abspath(args.model_file) 49 | saver.restore(sess, model_file) 50 | 51 | if args.is_debug: 52 | content_file = '/Users/Lex/Desktop/t.jpg' 53 | generated_file = '/Users/Lex/Desktop/xwz-stylized.jpg' 54 | 55 | with open(generated_file, 'wb') as img: 56 | image_bytes = tf.read_file(content_file) 57 | input_array, decoded_image = sess.run([ 58 | tf.reshape(tf.image.decode_jpeg(image_bytes, channels=3), [-1]), 59 | tf.image.decode_jpeg(image_bytes, channels=3)]) 60 | 61 | start_time = time.time() 62 | img.write(sess.run(tf.image.encode_jpeg(tf.cast(cropped_image, tf.uint8)), feed_dict={ 63 | image_data: input_array, 64 | height: decoded_image.shape[0], 65 | width: decoded_image.shape[1]})) 66 | end_time = time.time() 67 | 68 | tf.logging.info('Elapsed time: %fs' % (end_time - start_time)) 69 | else: 70 | output_graph_def = tf.graph_util.convert_variables_to_constants( 71 | sess, sess.graph_def, output_node_names=['output_image']) 72 | 73 | with tf.gfile.FastGFile('/Users/Lex/Desktop/' + args.model_name + '.pb', mode='wb') as f: 74 | f.write(output_graph_def.SerializeToString()) 75 | 76 | # tf.train.write_graph(g.as_graph_def(), '/Users/Lex/Desktop', 77 | # args.model_name + '.pb', as_text=False) 78 | 79 | 80 | if __name__ == '__main__': 81 | tf.logging.set_verbosity(tf.logging.INFO) 82 | args = parse_args() 83 | print(args) 84 | main(args) 85 | -------------------------------------------------------------------------------- /img/candy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/candy.jpg -------------------------------------------------------------------------------- /img/cubist.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/cubist.jpg -------------------------------------------------------------------------------- /img/denoised_starry.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/denoised_starry.jpg -------------------------------------------------------------------------------- /img/feathers.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/feathers.jpg -------------------------------------------------------------------------------- /img/mosaic.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/mosaic.jpg -------------------------------------------------------------------------------- /img/results/cubist.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/cubist.jpg -------------------------------------------------------------------------------- /img/results/denoised_starry.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/denoised_starry.jpg -------------------------------------------------------------------------------- /img/results/feathers.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/feathers.jpg -------------------------------------------------------------------------------- /img/results/mosaic.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/mosaic.jpg -------------------------------------------------------------------------------- /img/results/scream.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/scream.jpg -------------------------------------------------------------------------------- /img/results/style_cubist.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/style_cubist.jpg -------------------------------------------------------------------------------- /img/results/style_denoised_starry.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/style_denoised_starry.jpg -------------------------------------------------------------------------------- /img/results/style_feathers.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/style_feathers.jpg -------------------------------------------------------------------------------- /img/results/style_mosaic.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/style_mosaic.jpg -------------------------------------------------------------------------------- /img/results/style_scream.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/style_scream.jpg -------------------------------------------------------------------------------- /img/results/style_udnie.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/style_udnie.jpg -------------------------------------------------------------------------------- /img/results/style_wave.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/style_wave.jpg -------------------------------------------------------------------------------- /img/results/udnie.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/udnie.jpg -------------------------------------------------------------------------------- /img/results/wave.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/wave.jpg -------------------------------------------------------------------------------- /img/scream.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/scream.jpg -------------------------------------------------------------------------------- /img/starry.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/starry.jpg -------------------------------------------------------------------------------- /img/test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/test.jpg -------------------------------------------------------------------------------- /img/test1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/test1.jpg -------------------------------------------------------------------------------- /img/test2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/test2.jpg -------------------------------------------------------------------------------- /img/test3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/test3.jpg -------------------------------------------------------------------------------- /img/test4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/test4.jpg -------------------------------------------------------------------------------- /img/test5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/test5.jpg -------------------------------------------------------------------------------- /img/udnie.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/udnie.jpg -------------------------------------------------------------------------------- /img/wave.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/wave.jpg 
-------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from __future__ import print_function 3 | import tensorflow as tf 4 | from nets import nets_factory 5 | from preprocessing import preprocessing_factory 6 | import utils 7 | import os 8 | 9 | slim = tf.contrib.slim 10 | 11 | 12 | def gram(layer): 13 | shape = tf.shape(layer) 14 | num_images = shape[0] 15 | width = shape[1] 16 | height = shape[2] 17 | num_filters = shape[3] 18 | filters = tf.reshape(layer, tf.stack([num_images, -1, num_filters])) 19 | grams = tf.matmul(filters, filters, transpose_a=True) / tf.to_float(width * height * num_filters) 20 | 21 | return grams 22 | 23 | 24 | def get_style_features(FLAGS): 25 | """ 26 | For the "style_image", the preprocessing step is: 27 | 1. Resize the shorter side to FLAGS.image_size 28 | 2. Apply central crop 29 | """ 30 | with tf.Graph().as_default(): 31 | network_fn = nets_factory.get_network_fn( 32 | FLAGS.loss_model, 33 | num_classes=1, 34 | is_training=False) 35 | image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing( 36 | FLAGS.loss_model, 37 | is_training=False) 38 | 39 | # Get the style image data 40 | size = FLAGS.image_size 41 | img_bytes = tf.read_file(FLAGS.style_image) 42 | if FLAGS.style_image.lower().endswith('png'): 43 | image = tf.image.decode_png(img_bytes) 44 | else: 45 | image = tf.image.decode_jpeg(img_bytes) 46 | # image = _aspect_preserving_resize(image, size) 47 | 48 | # Add the batch dimension 49 | images = tf.expand_dims(image_preprocessing_fn(image, size, size), 0) 50 | # images = tf.stack([image_preprocessing_fn(image, size, size)]) 51 | 52 | _, endpoints_dict = network_fn(images, spatial_squeeze=False) 53 | features = [] 54 | for layer in FLAGS.style_layers: 55 | feature = endpoints_dict[layer] 56 | feature = tf.squeeze(gram(feature), [0]) # remove the batch dimension 57 | features.append(feature) 58 | 59 | with tf.Session() as sess: 60 | # Restore variables for loss network. 61 | init_func = utils._get_init_fn(FLAGS) 62 | init_func(sess) 63 | 64 | # Make sure the 'generated' directory is exists. 65 | if os.path.exists('generated') is False: 66 | os.makedirs('generated') 67 | # Indicate cropped style image path 68 | save_file = 'generated/target_style_' + FLAGS.naming + '.jpg' 69 | # Write preprocessed style image to indicated path 70 | with open(save_file, 'wb') as f: 71 | target_image = image_unprocessing_fn(images[0, :]) 72 | value = tf.image.encode_jpeg(tf.cast(target_image, tf.uint8)) 73 | f.write(sess.run(value)) 74 | tf.logging.info('Target style pattern is saved to: %s.' % save_file) 75 | 76 | # Return the features those layers are use for measuring style loss. 
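            # Each entry in `features` is the Gram matrix of one style layer; evaluating them
            # once here with sess.run lets the caller treat the style targets as fixed numpy
            # constants for the rest of training.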
77 | return sess.run(features) 78 | 79 | 80 | def style_loss(endpoints_dict, style_features_t, style_layers): 81 | style_loss = 0 82 | style_loss_summary = {} 83 | for style_gram, layer in zip(style_features_t, style_layers): 84 | generated_images, _ = tf.split(endpoints_dict[layer], 2, 0) 85 | size = tf.size(generated_images) 86 | layer_style_loss = tf.nn.l2_loss(gram(generated_images) - style_gram) * 2 / tf.to_float(size) 87 | style_loss_summary[layer] = layer_style_loss 88 | style_loss += layer_style_loss 89 | return style_loss, style_loss_summary 90 | 91 | 92 | def content_loss(endpoints_dict, content_layers): 93 | content_loss = 0 94 | for layer in content_layers: 95 | generated_images, content_images = tf.split(endpoints_dict[layer], 2, 0) 96 | size = tf.size(generated_images) 97 | content_loss += tf.nn.l2_loss(generated_images - content_images) * 2 / tf.to_float(size) # remain the same as in the paper 98 | return content_loss 99 | 100 | 101 | def total_variation_loss(layer): 102 | shape = tf.shape(layer) 103 | height = shape[1] 104 | width = shape[2] 105 | y = tf.slice(layer, [0, 0, 0, 0], tf.stack([-1, height - 1, -1, -1])) - tf.slice(layer, [0, 1, 0, 0], [-1, -1, -1, -1]) 106 | x = tf.slice(layer, [0, 0, 0, 0], tf.stack([-1, -1, width - 1, -1])) - tf.slice(layer, [0, 0, 1, 0], [-1, -1, -1, -1]) 107 | loss = tf.nn.l2_loss(x) / tf.to_float(tf.size(x)) + tf.nn.l2_loss(y) / tf.to_float(tf.size(y)) 108 | return loss 109 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def conv2d(x, input_filters, output_filters, kernel, strides, mode='REFLECT'): 5 | with tf.variable_scope('conv'): 6 | 7 | shape = [kernel, kernel, input_filters, output_filters] 8 | weight = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name='weight') 9 | x_padded = tf.pad(x, [[0, 0], [int(kernel / 2), int(kernel / 2)], [int(kernel / 2), int(kernel / 2)], [0, 0]], mode=mode) 10 | return tf.nn.conv2d(x_padded, weight, strides=[1, strides, strides, 1], padding='VALID', name='conv') 11 | 12 | 13 | def conv2d_transpose(x, input_filters, output_filters, kernel, strides): 14 | with tf.variable_scope('conv_transpose'): 15 | 16 | shape = [kernel, kernel, output_filters, input_filters] 17 | weight = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name='weight') 18 | 19 | batch_size = tf.shape(x)[0] 20 | height = tf.shape(x)[1] * strides 21 | width = tf.shape(x)[2] * strides 22 | output_shape = tf.stack([batch_size, height, width, output_filters]) 23 | return tf.nn.conv2d_transpose(x, weight, output_shape, strides=[1, strides, strides, 1], name='conv_transpose') 24 | 25 | 26 | def resize_conv2d(x, input_filters, output_filters, kernel, strides, training): 27 | ''' 28 | An alternative to transposed convolution where we first resize, then convolve. 
29 | See http://distill.pub/2016/deconv-checkerboard/ 30 | 31 | For some reason the shape needs to be statically known for gradient propagation 32 | through tf.image.resize_images, but we only know that for fixed image size, so we 33 | plumb through a "training" argument 34 | ''' 35 | with tf.variable_scope('conv_transpose'): 36 | height = x.get_shape()[1].value if training else tf.shape(x)[1] 37 | width = x.get_shape()[2].value if training else tf.shape(x)[2] 38 | 39 | new_height = height * strides * 2 40 | new_width = width * strides * 2 41 | 42 | x_resized = tf.image.resize_images(x, [new_height, new_width], tf.image.ResizeMethod.NEAREST_NEIGHBOR) 43 | 44 | # shape = [kernel, kernel, input_filters, output_filters] 45 | # weight = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name='weight') 46 | return conv2d(x_resized, input_filters, output_filters, kernel, strides) 47 | 48 | 49 | def instance_norm(x): 50 | epsilon = 1e-9 51 | 52 | mean, var = tf.nn.moments(x, [1, 2], keep_dims=True) 53 | 54 | return tf.div(tf.subtract(x, mean), tf.sqrt(tf.add(var, epsilon))) 55 | 56 | 57 | def batch_norm(x, size, training, decay=0.999): 58 | beta = tf.Variable(tf.zeros([size]), name='beta') 59 | scale = tf.Variable(tf.ones([size]), name='scale') 60 | pop_mean = tf.Variable(tf.zeros([size])) 61 | pop_var = tf.Variable(tf.ones([size])) 62 | epsilon = 1e-3 63 | 64 | batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2]) 65 | train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay)) 66 | train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay)) 67 | 68 | def batch_statistics(): 69 | with tf.control_dependencies([train_mean, train_var]): 70 | return tf.nn.batch_normalization(x, batch_mean, batch_var, beta, scale, epsilon, name='batch_norm') 71 | 72 | def population_statistics(): 73 | return tf.nn.batch_normalization(x, pop_mean, pop_var, beta, scale, epsilon, name='batch_norm') 74 | 75 | return tf.cond(training, batch_statistics, population_statistics) 76 | 77 | 78 | def relu(input): 79 | relu = tf.nn.relu(input) 80 | # convert nan to zero (nan != nan) 81 | nan_to_zero = tf.where(tf.equal(relu, relu), relu, tf.zeros_like(relu)) 82 | return nan_to_zero 83 | 84 | 85 | def residual(x, filters, kernel, strides): 86 | with tf.variable_scope('residual'): 87 | conv1 = conv2d(x, filters, filters, kernel, strides) 88 | conv2 = conv2d(relu(conv1), filters, filters, kernel, strides) 89 | 90 | residual = x + conv2 91 | 92 | return residual 93 | 94 | 95 | def net(image, training): 96 | # Less border effects when padding a little before passing through .. 
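    # REFLECT padding mirrors the image at its borders instead of filling with zeros, so
    # convolutions near the edges see plausible pixel values; the 10-pixel margin added here
    # is sliced off again at the end of net().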
97 | image = tf.pad(image, [[0, 0], [10, 10], [10, 10], [0, 0]], mode='REFLECT') 98 | 99 | with tf.variable_scope('conv1'): 100 | conv1 = relu(instance_norm(conv2d(image, 3, 32, 9, 1))) 101 | with tf.variable_scope('conv2'): 102 | conv2 = relu(instance_norm(conv2d(conv1, 32, 64, 3, 2))) 103 | with tf.variable_scope('conv3'): 104 | conv3 = relu(instance_norm(conv2d(conv2, 64, 128, 3, 2))) 105 | with tf.variable_scope('res1'): 106 | res1 = residual(conv3, 128, 3, 1) 107 | with tf.variable_scope('res2'): 108 | res2 = residual(res1, 128, 3, 1) 109 | with tf.variable_scope('res3'): 110 | res3 = residual(res2, 128, 3, 1) 111 | with tf.variable_scope('res4'): 112 | res4 = residual(res3, 128, 3, 1) 113 | with tf.variable_scope('res5'): 114 | res5 = residual(res4, 128, 3, 1) 115 | # print(res5.get_shape()) 116 | with tf.variable_scope('deconv1'): 117 | # deconv1 = relu(instance_norm(conv2d_transpose(res5, 128, 64, 3, 2))) 118 | deconv1 = relu(instance_norm(resize_conv2d(res5, 128, 64, 3, 2, training))) 119 | with tf.variable_scope('deconv2'): 120 | # deconv2 = relu(instance_norm(conv2d_transpose(deconv1, 64, 32, 3, 2))) 121 | deconv2 = relu(instance_norm(resize_conv2d(deconv1, 64, 32, 3, 2, training))) 122 | with tf.variable_scope('deconv3'): 123 | # deconv_test = relu(instance_norm(conv2d(deconv2, 32, 32, 2, 1))) 124 | deconv3 = tf.nn.tanh(instance_norm(conv2d(deconv2, 32, 3, 9, 1))) 125 | 126 | y = (deconv3 + 1) * 127.5 127 | 128 | # Remove border effect reducing padding. 129 | height = tf.shape(y)[1] 130 | width = tf.shape(y)[2] 131 | y = tf.slice(y, [0, 10, 10, 0], tf.stack([-1, height - 20, width - 20, -1])) 132 | 133 | return y 134 | -------------------------------------------------------------------------------- /nets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /nets/alexnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a model definition for AlexNet. 16 | 17 | This work was first described in: 18 | ImageNet Classification with Deep Convolutional Neural Networks 19 | Alex Krizhevsky, Ilya Sutskever and Geoffrey E. Hinton 20 | 21 | and later refined in: 22 | One weird trick for parallelizing convolutional neural networks 23 | Alex Krizhevsky, 2014 24 | 25 | Here we provide the implementation proposed in "One weird trick" and not 26 | "ImageNet Classification", as per the paper, the LRN layers have been removed. 
27 | 28 | Usage: 29 | with slim.arg_scope(alexnet.alexnet_v2_arg_scope()): 30 | outputs, end_points = alexnet.alexnet_v2(inputs) 31 | 32 | @@alexnet_v2 33 | """ 34 | 35 | from __future__ import absolute_import 36 | from __future__ import division 37 | from __future__ import print_function 38 | 39 | import tensorflow as tf 40 | 41 | slim = tf.contrib.slim 42 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 43 | 44 | 45 | def alexnet_v2_arg_scope(weight_decay=0.0005): 46 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 47 | activation_fn=tf.nn.relu, 48 | biases_initializer=tf.constant_initializer(0.1), 49 | weights_regularizer=slim.l2_regularizer(weight_decay)): 50 | with slim.arg_scope([slim.conv2d], padding='SAME'): 51 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: 52 | return arg_sc 53 | 54 | 55 | def alexnet_v2(inputs, 56 | num_classes=1000, 57 | is_training=True, 58 | dropout_keep_prob=0.5, 59 | spatial_squeeze=True, 60 | scope='alexnet_v2'): 61 | """AlexNet version 2. 62 | 63 | Described in: http://arxiv.org/pdf/1404.5997v2.pdf 64 | Parameters from: 65 | github.com/akrizhevsky/cuda-convnet2/blob/master/layers/ 66 | layers-imagenet-1gpu.cfg 67 | 68 | Note: All the fully_connected layers have been transformed to conv2d layers. 69 | To use in classification mode, resize input to 224x224. To use in fully 70 | convolutional mode, set spatial_squeeze to false. 71 | The LRN layers have been removed and change the initializers from 72 | random_normal_initializer to xavier_initializer. 73 | 74 | Args: 75 | inputs: a tensor of size [batch_size, height, width, channels]. 76 | num_classes: number of predicted classes. 77 | is_training: whether or not the model is being trained. 78 | dropout_keep_prob: the probability that activations are kept in the dropout 79 | layers during training. 80 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 81 | outputs. Useful to remove unnecessary dimensions for classification. 82 | scope: Optional scope for the variables. 83 | 84 | Returns: 85 | the last op containing the log predictions and end_points dict. 86 | """ 87 | with tf.variable_scope(scope, 'alexnet_v2', [inputs]) as sc: 88 | end_points_collection = sc.name + '_end_points' 89 | # Collect outputs for conv2d, fully_connected and max_pool2d. 90 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 91 | outputs_collections=[end_points_collection]): 92 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', 93 | scope='conv1') 94 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool1') 95 | net = slim.conv2d(net, 192, [5, 5], scope='conv2') 96 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool2') 97 | net = slim.conv2d(net, 384, [3, 3], scope='conv3') 98 | net = slim.conv2d(net, 384, [3, 3], scope='conv4') 99 | net = slim.conv2d(net, 256, [3, 3], scope='conv5') 100 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool5') 101 | 102 | # Use conv2d instead of fully_connected layers. 
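      # For a 224x224 input the feature map reaching this point is 5x5, so the 5x5 VALID
      # convolution used for fc6 is equivalent to a fully connected layer while keeping the
      # network usable on larger inputs in fully convolutional mode.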
103 | with slim.arg_scope([slim.conv2d], 104 | weights_initializer=trunc_normal(0.005), 105 | biases_initializer=tf.constant_initializer(0.1)): 106 | net = slim.conv2d(net, 4096, [5, 5], padding='VALID', 107 | scope='fc6') 108 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 109 | scope='dropout6') 110 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 111 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 112 | scope='dropout7') 113 | net = slim.conv2d(net, num_classes, [1, 1], 114 | activation_fn=None, 115 | normalizer_fn=None, 116 | biases_initializer=tf.zeros_initializer, 117 | scope='fc8') 118 | 119 | # Convert end_points_collection into a end_point dict. 120 | end_points = slim.utils.convert_collection_to_dict(end_points_collection) 121 | if spatial_squeeze: 122 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 123 | end_points[sc.name + '/fc8'] = net 124 | return net, end_points 125 | alexnet_v2.default_image_size = 224 126 | -------------------------------------------------------------------------------- /nets/alexnet_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for slim.nets.alexnet.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | from nets import alexnet 23 | 24 | slim = tf.contrib.slim 25 | 26 | 27 | class AlexnetV2Test(tf.test.TestCase): 28 | 29 | def testBuild(self): 30 | batch_size = 5 31 | height, width = 224, 224 32 | num_classes = 1000 33 | with self.test_session(): 34 | inputs = tf.random_uniform((batch_size, height, width, 3)) 35 | logits, _ = alexnet.alexnet_v2(inputs, num_classes) 36 | self.assertEquals(logits.op.name, 'alexnet_v2/fc8/squeezed') 37 | self.assertListEqual(logits.get_shape().as_list(), 38 | [batch_size, num_classes]) 39 | 40 | def testFullyConvolutional(self): 41 | batch_size = 1 42 | height, width = 300, 400 43 | num_classes = 1000 44 | with self.test_session(): 45 | inputs = tf.random_uniform((batch_size, height, width, 3)) 46 | logits, _ = alexnet.alexnet_v2(inputs, num_classes, spatial_squeeze=False) 47 | self.assertEquals(logits.op.name, 'alexnet_v2/fc8/BiasAdd') 48 | self.assertListEqual(logits.get_shape().as_list(), 49 | [batch_size, 4, 7, num_classes]) 50 | 51 | def testEndPoints(self): 52 | batch_size = 5 53 | height, width = 224, 224 54 | num_classes = 1000 55 | with self.test_session(): 56 | inputs = tf.random_uniform((batch_size, height, width, 3)) 57 | _, end_points = alexnet.alexnet_v2(inputs, num_classes) 58 | expected_names = ['alexnet_v2/conv1', 59 | 'alexnet_v2/pool1', 60 | 'alexnet_v2/conv2', 61 | 'alexnet_v2/pool2', 62 | 'alexnet_v2/conv3', 63 | 'alexnet_v2/conv4', 64 | 'alexnet_v2/conv5', 65 | 'alexnet_v2/pool5', 66 | 'alexnet_v2/fc6', 67 | 'alexnet_v2/fc7', 68 | 'alexnet_v2/fc8' 69 | ] 70 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 71 | 72 | def testModelVariables(self): 73 | batch_size = 5 74 | height, width = 224, 224 75 | num_classes = 1000 76 | with self.test_session(): 77 | inputs = tf.random_uniform((batch_size, height, width, 3)) 78 | alexnet.alexnet_v2(inputs, num_classes) 79 | expected_names = ['alexnet_v2/conv1/weights', 80 | 'alexnet_v2/conv1/biases', 81 | 'alexnet_v2/conv2/weights', 82 | 'alexnet_v2/conv2/biases', 83 | 'alexnet_v2/conv3/weights', 84 | 'alexnet_v2/conv3/biases', 85 | 'alexnet_v2/conv4/weights', 86 | 'alexnet_v2/conv4/biases', 87 | 'alexnet_v2/conv5/weights', 88 | 'alexnet_v2/conv5/biases', 89 | 'alexnet_v2/fc6/weights', 90 | 'alexnet_v2/fc6/biases', 91 | 'alexnet_v2/fc7/weights', 92 | 'alexnet_v2/fc7/biases', 93 | 'alexnet_v2/fc8/weights', 94 | 'alexnet_v2/fc8/biases', 95 | ] 96 | model_variables = [v.op.name for v in slim.get_model_variables()] 97 | self.assertSetEqual(set(model_variables), set(expected_names)) 98 | 99 | def testEvaluation(self): 100 | batch_size = 2 101 | height, width = 224, 224 102 | num_classes = 1000 103 | with self.test_session(): 104 | eval_inputs = tf.random_uniform((batch_size, height, width, 3)) 105 | logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False) 106 | self.assertListEqual(logits.get_shape().as_list(), 107 | [batch_size, num_classes]) 108 | predictions = tf.argmax(logits, 1) 109 | self.assertListEqual(predictions.get_shape().as_list(), [batch_size]) 110 | 111 | def testTrainEvalWithReuse(self): 112 | train_batch_size = 2 113 | eval_batch_size = 1 114 | train_height, train_width = 224, 224 115 | eval_height, eval_width = 300, 400 116 | num_classes = 1000 117 | with self.test_session(): 
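      # Build the 224x224 training graph first; reuse_variables() below shares its weights
      # with a 300x400 evaluation graph run in fully convolutional mode.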
118 | train_inputs = tf.random_uniform( 119 | (train_batch_size, train_height, train_width, 3)) 120 | logits, _ = alexnet.alexnet_v2(train_inputs) 121 | self.assertListEqual(logits.get_shape().as_list(), 122 | [train_batch_size, num_classes]) 123 | tf.get_variable_scope().reuse_variables() 124 | eval_inputs = tf.random_uniform( 125 | (eval_batch_size, eval_height, eval_width, 3)) 126 | logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False, 127 | spatial_squeeze=False) 128 | self.assertListEqual(logits.get_shape().as_list(), 129 | [eval_batch_size, 4, 7, num_classes]) 130 | logits = tf.reduce_mean(logits, [1, 2]) 131 | predictions = tf.argmax(logits, 1) 132 | self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size]) 133 | 134 | def testForward(self): 135 | batch_size = 1 136 | height, width = 224, 224 137 | with self.test_session() as sess: 138 | inputs = tf.random_uniform((batch_size, height, width, 3)) 139 | logits, _ = alexnet.alexnet_v2(inputs) 140 | sess.run(tf.initialize_all_variables()) 141 | output = sess.run(logits) 142 | self.assertTrue(output.any()) 143 | 144 | if __name__ == '__main__': 145 | tf.test.main() 146 | -------------------------------------------------------------------------------- /nets/cifarnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a variant of the CIFAR-10 model definition.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(stddev=stddev) 26 | 27 | 28 | def cifarnet(images, num_classes=10, is_training=False, 29 | dropout_keep_prob=0.5, 30 | prediction_fn=slim.softmax, 31 | scope='CifarNet'): 32 | """Creates a variant of the CifarNet model. 33 | 34 | Note that since the output is a set of 'logits', the values fall in the 35 | interval of (-infinity, infinity). Consequently, to convert the outputs to a 36 | probability distribution over the characters, one will need to convert them 37 | using the softmax function: 38 | 39 | logits = cifarnet.cifarnet(images, is_training=False) 40 | probabilities = tf.nn.softmax(logits) 41 | predictions = tf.argmax(logits, 1) 42 | 43 | Args: 44 | images: A batch of `Tensors` of size [batch_size, height, width, channels]. 45 | num_classes: the number of classes in the dataset. 46 | is_training: specifies whether or not we're currently training the model. 47 | This variable will determine the behaviour of the dropout layer. 48 | dropout_keep_prob: the percentage of activation values that are retained. 49 | prediction_fn: a function to get predictions out of logits. 50 | scope: Optional variable_scope. 
51 | 52 | Returns: 53 | logits: the pre-softmax activations, a tensor of size 54 | [batch_size, `num_classes`] 55 | end_points: a dictionary from components of the network to the corresponding 56 | activation. 57 | """ 58 | end_points = {} 59 | 60 | with tf.variable_scope(scope, 'CifarNet', [images, num_classes]): 61 | net = slim.conv2d(images, 64, [5, 5], scope='conv1') 62 | end_points['conv1'] = net 63 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool1') 64 | end_points['pool1'] = net 65 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm1') 66 | net = slim.conv2d(net, 64, [5, 5], scope='conv2') 67 | end_points['conv2'] = net 68 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm2') 69 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool2') 70 | end_points['pool2'] = net 71 | net = slim.flatten(net) 72 | end_points['Flatten'] = net 73 | net = slim.fully_connected(net, 384, scope='fc3') 74 | end_points['fc3'] = net 75 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 76 | scope='dropout3') 77 | net = slim.fully_connected(net, 192, scope='fc4') 78 | end_points['fc4'] = net 79 | logits = slim.fully_connected(net, num_classes, 80 | biases_initializer=tf.zeros_initializer, 81 | weights_initializer=trunc_normal(1/192.0), 82 | weights_regularizer=None, 83 | activation_fn=None, 84 | scope='logits') 85 | 86 | end_points['Logits'] = logits 87 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 88 | 89 | return logits, end_points 90 | cifarnet.default_image_size = 32 91 | 92 | 93 | def cifarnet_arg_scope(weight_decay=0.004): 94 | """Defines the default cifarnet argument scope. 95 | 96 | Args: 97 | weight_decay: The weight decay to use for regularizing the model. 98 | 99 | Returns: 100 | An `arg_scope` to use for the inception v3 model. 101 | """ 102 | with slim.arg_scope( 103 | [slim.conv2d], 104 | weights_initializer=tf.truncated_normal_initializer(stddev=5e-2), 105 | activation_fn=tf.nn.relu): 106 | with slim.arg_scope( 107 | [slim.fully_connected], 108 | biases_initializer=tf.constant_initializer(0.1), 109 | weights_initializer=trunc_normal(0.04), 110 | weights_regularizer=slim.l2_regularizer(weight_decay), 111 | activation_fn=tf.nn.relu) as sc: 112 | return sc 113 | -------------------------------------------------------------------------------- /nets/inception.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Brings all inception models under one namespace.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # pylint: disable=unused-import 22 | from nets.inception_resnet_v2 import inception_resnet_v2 23 | from nets.inception_resnet_v2 import inception_resnet_v2_arg_scope 24 | from nets.inception_v1 import inception_v1 25 | from nets.inception_v1 import inception_v1_arg_scope 26 | from nets.inception_v1 import inception_v1_base 27 | from nets.inception_v2 import inception_v2 28 | from nets.inception_v2 import inception_v2_arg_scope 29 | from nets.inception_v2 import inception_v2_base 30 | from nets.inception_v3 import inception_v3 31 | from nets.inception_v3 import inception_v3_arg_scope 32 | from nets.inception_v3 import inception_v3_base 33 | from nets.inception_v4 import inception_v4 34 | from nets.inception_v4 import inception_v4_arg_scope 35 | from nets.inception_v4 import inception_v4_base 36 | # pylint: enable=unused-import 37 | -------------------------------------------------------------------------------- /nets/inception_resnet_v2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains the definition of the Inception Resnet V2 architecture. 16 | 17 | As described in http://arxiv.org/abs/1602.07261. 
18 | 19 | Inception-v4, Inception-ResNet and the Impact of Residual Connections 20 | on Learning 21 | Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, Alex Alemi 22 | """ 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | 28 | import tensorflow as tf 29 | 30 | slim = tf.contrib.slim 31 | 32 | 33 | def block35(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None): 34 | """Builds the 35x35 resnet block.""" 35 | with tf.variable_scope(scope, 'Block35', [net], reuse=reuse): 36 | with tf.variable_scope('Branch_0'): 37 | tower_conv = slim.conv2d(net, 32, 1, scope='Conv2d_1x1') 38 | with tf.variable_scope('Branch_1'): 39 | tower_conv1_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1') 40 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 32, 3, scope='Conv2d_0b_3x3') 41 | with tf.variable_scope('Branch_2'): 42 | tower_conv2_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1') 43 | tower_conv2_1 = slim.conv2d(tower_conv2_0, 48, 3, scope='Conv2d_0b_3x3') 44 | tower_conv2_2 = slim.conv2d(tower_conv2_1, 64, 3, scope='Conv2d_0c_3x3') 45 | mixed = tf.concat(3, [tower_conv, tower_conv1_1, tower_conv2_2]) 46 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None, 47 | activation_fn=None, scope='Conv2d_1x1') 48 | net += scale * up 49 | if activation_fn: 50 | net = activation_fn(net) 51 | return net 52 | 53 | 54 | def block17(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None): 55 | """Builds the 17x17 resnet block.""" 56 | with tf.variable_scope(scope, 'Block17', [net], reuse=reuse): 57 | with tf.variable_scope('Branch_0'): 58 | tower_conv = slim.conv2d(net, 192, 1, scope='Conv2d_1x1') 59 | with tf.variable_scope('Branch_1'): 60 | tower_conv1_0 = slim.conv2d(net, 128, 1, scope='Conv2d_0a_1x1') 61 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 160, [1, 7], 62 | scope='Conv2d_0b_1x7') 63 | tower_conv1_2 = slim.conv2d(tower_conv1_1, 192, [7, 1], 64 | scope='Conv2d_0c_7x1') 65 | mixed = tf.concat(3, [tower_conv, tower_conv1_2]) 66 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None, 67 | activation_fn=None, scope='Conv2d_1x1') 68 | net += scale * up 69 | if activation_fn: 70 | net = activation_fn(net) 71 | return net 72 | 73 | 74 | def block8(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None): 75 | """Builds the 8x8 resnet block.""" 76 | with tf.variable_scope(scope, 'Block8', [net], reuse=reuse): 77 | with tf.variable_scope('Branch_0'): 78 | tower_conv = slim.conv2d(net, 192, 1, scope='Conv2d_1x1') 79 | with tf.variable_scope('Branch_1'): 80 | tower_conv1_0 = slim.conv2d(net, 192, 1, scope='Conv2d_0a_1x1') 81 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 224, [1, 3], 82 | scope='Conv2d_0b_1x3') 83 | tower_conv1_2 = slim.conv2d(tower_conv1_1, 256, [3, 1], 84 | scope='Conv2d_0c_3x1') 85 | mixed = tf.concat(3, [tower_conv, tower_conv1_2]) 86 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None, 87 | activation_fn=None, scope='Conv2d_1x1') 88 | net += scale * up 89 | if activation_fn: 90 | net = activation_fn(net) 91 | return net 92 | 93 | 94 | def inception_resnet_v2(inputs, num_classes=1001, is_training=True, 95 | dropout_keep_prob=0.8, 96 | reuse=None, 97 | scope='InceptionResnetV2'): 98 | """Creates the Inception Resnet V2 model. 99 | 100 | Args: 101 | inputs: a 4-D tensor of size [batch_size, height, width, 3]. 102 | num_classes: number of predicted classes. 103 | is_training: whether is training or not. 
104 | dropout_keep_prob: float, the fraction to keep before final layer. 105 | reuse: whether or not the network and its variables should be reused. To be 106 | able to reuse 'scope' must be given. 107 | scope: Optional variable_scope. 108 | 109 | Returns: 110 | logits: the logits outputs of the model. 111 | end_points: the set of end_points from the inception model. 112 | """ 113 | end_points = {} 114 | 115 | with tf.variable_scope(scope, 'InceptionResnetV2', [inputs], reuse=reuse): 116 | with slim.arg_scope([slim.batch_norm, slim.dropout], 117 | is_training=is_training): 118 | with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d], 119 | stride=1, padding='SAME'): 120 | 121 | # 149 x 149 x 32 122 | net = slim.conv2d(inputs, 32, 3, stride=2, padding='VALID', 123 | scope='Conv2d_1a_3x3') 124 | end_points['Conv2d_1a_3x3'] = net 125 | # 147 x 147 x 32 126 | net = slim.conv2d(net, 32, 3, padding='VALID', 127 | scope='Conv2d_2a_3x3') 128 | end_points['Conv2d_2a_3x3'] = net 129 | # 147 x 147 x 64 130 | net = slim.conv2d(net, 64, 3, scope='Conv2d_2b_3x3') 131 | end_points['Conv2d_2b_3x3'] = net 132 | # 73 x 73 x 64 133 | net = slim.max_pool2d(net, 3, stride=2, padding='VALID', 134 | scope='MaxPool_3a_3x3') 135 | end_points['MaxPool_3a_3x3'] = net 136 | # 73 x 73 x 80 137 | net = slim.conv2d(net, 80, 1, padding='VALID', 138 | scope='Conv2d_3b_1x1') 139 | end_points['Conv2d_3b_1x1'] = net 140 | # 71 x 71 x 192 141 | net = slim.conv2d(net, 192, 3, padding='VALID', 142 | scope='Conv2d_4a_3x3') 143 | end_points['Conv2d_4a_3x3'] = net 144 | # 35 x 35 x 192 145 | net = slim.max_pool2d(net, 3, stride=2, padding='VALID', 146 | scope='MaxPool_5a_3x3') 147 | end_points['MaxPool_5a_3x3'] = net 148 | 149 | # 35 x 35 x 320 150 | with tf.variable_scope('Mixed_5b'): 151 | with tf.variable_scope('Branch_0'): 152 | tower_conv = slim.conv2d(net, 96, 1, scope='Conv2d_1x1') 153 | with tf.variable_scope('Branch_1'): 154 | tower_conv1_0 = slim.conv2d(net, 48, 1, scope='Conv2d_0a_1x1') 155 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 64, 5, 156 | scope='Conv2d_0b_5x5') 157 | with tf.variable_scope('Branch_2'): 158 | tower_conv2_0 = slim.conv2d(net, 64, 1, scope='Conv2d_0a_1x1') 159 | tower_conv2_1 = slim.conv2d(tower_conv2_0, 96, 3, 160 | scope='Conv2d_0b_3x3') 161 | tower_conv2_2 = slim.conv2d(tower_conv2_1, 96, 3, 162 | scope='Conv2d_0c_3x3') 163 | with tf.variable_scope('Branch_3'): 164 | tower_pool = slim.avg_pool2d(net, 3, stride=1, padding='SAME', 165 | scope='AvgPool_0a_3x3') 166 | tower_pool_1 = slim.conv2d(tower_pool, 64, 1, 167 | scope='Conv2d_0b_1x1') 168 | net = tf.concat(3, [tower_conv, tower_conv1_1, 169 | tower_conv2_2, tower_pool_1]) 170 | 171 | end_points['Mixed_5b'] = net 172 | net = slim.repeat(net, 10, block35, scale=0.17) 173 | 174 | # 17 x 17 x 1024 175 | with tf.variable_scope('Mixed_6a'): 176 | with tf.variable_scope('Branch_0'): 177 | tower_conv = slim.conv2d(net, 384, 3, stride=2, padding='VALID', 178 | scope='Conv2d_1a_3x3') 179 | with tf.variable_scope('Branch_1'): 180 | tower_conv1_0 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1') 181 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 256, 3, 182 | scope='Conv2d_0b_3x3') 183 | tower_conv1_2 = slim.conv2d(tower_conv1_1, 384, 3, 184 | stride=2, padding='VALID', 185 | scope='Conv2d_1a_3x3') 186 | with tf.variable_scope('Branch_2'): 187 | tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID', 188 | scope='MaxPool_1a_3x3') 189 | net = tf.concat(3, [tower_conv, tower_conv1_2, tower_pool]) 190 | 191 | end_points['Mixed_6a'] = net 
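# The slim.repeat call below stacks twenty block17 residual units on the 17 x 17
# feature map, each under its own numbered variable scope. Conceptually (scope
# naming aside) it is close to the hand-written loop:
#   for _ in range(20):
#     net = block17(net, scale=0.10)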
192 | net = slim.repeat(net, 20, block17, scale=0.10) 193 | 194 | # Auxillary tower 195 | with tf.variable_scope('AuxLogits'): 196 | aux = slim.avg_pool2d(net, 5, stride=3, padding='VALID', 197 | scope='Conv2d_1a_3x3') 198 | aux = slim.conv2d(aux, 128, 1, scope='Conv2d_1b_1x1') 199 | aux = slim.conv2d(aux, 768, aux.get_shape()[1:3], 200 | padding='VALID', scope='Conv2d_2a_5x5') 201 | aux = slim.flatten(aux) 202 | aux = slim.fully_connected(aux, num_classes, activation_fn=None, 203 | scope='Logits') 204 | end_points['AuxLogits'] = aux 205 | 206 | with tf.variable_scope('Mixed_7a'): 207 | with tf.variable_scope('Branch_0'): 208 | tower_conv = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1') 209 | tower_conv_1 = slim.conv2d(tower_conv, 384, 3, stride=2, 210 | padding='VALID', scope='Conv2d_1a_3x3') 211 | with tf.variable_scope('Branch_1'): 212 | tower_conv1 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1') 213 | tower_conv1_1 = slim.conv2d(tower_conv1, 288, 3, stride=2, 214 | padding='VALID', scope='Conv2d_1a_3x3') 215 | with tf.variable_scope('Branch_2'): 216 | tower_conv2 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1') 217 | tower_conv2_1 = slim.conv2d(tower_conv2, 288, 3, 218 | scope='Conv2d_0b_3x3') 219 | tower_conv2_2 = slim.conv2d(tower_conv2_1, 320, 3, stride=2, 220 | padding='VALID', scope='Conv2d_1a_3x3') 221 | with tf.variable_scope('Branch_3'): 222 | tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID', 223 | scope='MaxPool_1a_3x3') 224 | net = tf.concat(3, [tower_conv_1, tower_conv1_1, 225 | tower_conv2_2, tower_pool]) 226 | 227 | end_points['Mixed_7a'] = net 228 | 229 | net = slim.repeat(net, 9, block8, scale=0.20) 230 | net = block8(net, activation_fn=None) 231 | 232 | net = slim.conv2d(net, 1536, 1, scope='Conv2d_7b_1x1') 233 | end_points['Conv2d_7b_1x1'] = net 234 | 235 | with tf.variable_scope('Logits'): 236 | end_points['PrePool'] = net 237 | net = slim.avg_pool2d(net, net.get_shape()[1:3], padding='VALID', 238 | scope='AvgPool_1a_8x8') 239 | net = slim.flatten(net) 240 | 241 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 242 | scope='Dropout') 243 | 244 | end_points['PreLogitsFlatten'] = net 245 | logits = slim.fully_connected(net, num_classes, activation_fn=None, 246 | scope='Logits') 247 | end_points['Logits'] = logits 248 | end_points['Predictions'] = tf.nn.softmax(logits, name='Predictions') 249 | 250 | return logits, end_points 251 | inception_resnet_v2.default_image_size = 299 252 | 253 | 254 | def inception_resnet_v2_arg_scope(weight_decay=0.00004, 255 | batch_norm_decay=0.9997, 256 | batch_norm_epsilon=0.001): 257 | """Yields the scope with the default parameters for inception_resnet_v2. 258 | 259 | Args: 260 | weight_decay: the weight decay for weights variables. 261 | batch_norm_decay: decay for the moving average of batch_norm momentums. 262 | batch_norm_epsilon: small float added to variance to avoid dividing by zero. 263 | 264 | Returns: 265 | a arg_scope with the parameters needed for inception_resnet_v2. 266 | """ 267 | # Set weight_decay for weights in conv2d and fully_connected layers. 268 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 269 | weights_regularizer=slim.l2_regularizer(weight_decay), 270 | biases_regularizer=slim.l2_regularizer(weight_decay)): 271 | 272 | batch_norm_params = { 273 | 'decay': batch_norm_decay, 274 | 'epsilon': batch_norm_epsilon, 275 | } 276 | # Set activation_fn and parameters for batch_norm. 
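# The scope returned below is meant to wrap model construction, in the same way
# inception_utils.inception_arg_scope is used elsewhere in this repo, e.g.:
#   with slim.arg_scope(inception_resnet_v2_arg_scope()):
#     logits, end_points = inception_resnet_v2(images, num_classes,
#                                               is_training=is_training)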
277 | with slim.arg_scope([slim.conv2d], activation_fn=tf.nn.relu, 278 | normalizer_fn=slim.batch_norm, 279 | normalizer_params=batch_norm_params) as scope: 280 | return scope 281 | -------------------------------------------------------------------------------- /nets/inception_resnet_v2_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for slim.inception_resnet_v2.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | from nets import inception 23 | 24 | 25 | class InceptionTest(tf.test.TestCase): 26 | 27 | def testBuildLogits(self): 28 | batch_size = 5 29 | height, width = 299, 299 30 | num_classes = 1000 31 | with self.test_session(): 32 | inputs = tf.random_uniform((batch_size, height, width, 3)) 33 | logits, _ = inception.inception_resnet_v2(inputs, num_classes) 34 | self.assertTrue(logits.op.name.startswith('InceptionResnetV2/Logits')) 35 | self.assertListEqual(logits.get_shape().as_list(), 36 | [batch_size, num_classes]) 37 | 38 | def testBuildEndPoints(self): 39 | batch_size = 5 40 | height, width = 299, 299 41 | num_classes = 1000 42 | with self.test_session(): 43 | inputs = tf.random_uniform((batch_size, height, width, 3)) 44 | _, end_points = inception.inception_resnet_v2(inputs, num_classes) 45 | self.assertTrue('Logits' in end_points) 46 | logits = end_points['Logits'] 47 | self.assertListEqual(logits.get_shape().as_list(), 48 | [batch_size, num_classes]) 49 | self.assertTrue('AuxLogits' in end_points) 50 | aux_logits = end_points['AuxLogits'] 51 | self.assertListEqual(aux_logits.get_shape().as_list(), 52 | [batch_size, num_classes]) 53 | pre_pool = end_points['PrePool'] 54 | self.assertListEqual(pre_pool.get_shape().as_list(), 55 | [batch_size, 8, 8, 1536]) 56 | 57 | def testVariablesSetDevice(self): 58 | batch_size = 5 59 | height, width = 299, 299 60 | num_classes = 1000 61 | with self.test_session(): 62 | inputs = tf.random_uniform((batch_size, height, width, 3)) 63 | # Force all Variables to reside on the device. 
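# Variables created under tf.device('/cpu:0') or tf.device('/gpu:0') inherit that
# placement; the assertDeviceEqual checks below read each variable's .device
# attribute to confirm it.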
64 | with tf.variable_scope('on_cpu'), tf.device('/cpu:0'): 65 | inception.inception_resnet_v2(inputs, num_classes) 66 | with tf.variable_scope('on_gpu'), tf.device('/gpu:0'): 67 | inception.inception_resnet_v2(inputs, num_classes) 68 | for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_cpu'): 69 | self.assertDeviceEqual(v.device, '/cpu:0') 70 | for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_gpu'): 71 | self.assertDeviceEqual(v.device, '/gpu:0') 72 | 73 | def testHalfSizeImages(self): 74 | batch_size = 5 75 | height, width = 150, 150 76 | num_classes = 1000 77 | with self.test_session(): 78 | inputs = tf.random_uniform((batch_size, height, width, 3)) 79 | logits, end_points = inception.inception_resnet_v2(inputs, num_classes) 80 | self.assertTrue(logits.op.name.startswith('InceptionResnetV2/Logits')) 81 | self.assertListEqual(logits.get_shape().as_list(), 82 | [batch_size, num_classes]) 83 | pre_pool = end_points['PrePool'] 84 | self.assertListEqual(pre_pool.get_shape().as_list(), 85 | [batch_size, 3, 3, 1536]) 86 | 87 | def testUnknownBatchSize(self): 88 | batch_size = 1 89 | height, width = 299, 299 90 | num_classes = 1000 91 | with self.test_session() as sess: 92 | inputs = tf.placeholder(tf.float32, (None, height, width, 3)) 93 | logits, _ = inception.inception_resnet_v2(inputs, num_classes) 94 | self.assertTrue(logits.op.name.startswith('InceptionResnetV2/Logits')) 95 | self.assertListEqual(logits.get_shape().as_list(), 96 | [None, num_classes]) 97 | images = tf.random_uniform((batch_size, height, width, 3)) 98 | sess.run(tf.initialize_all_variables()) 99 | output = sess.run(logits, {inputs: images.eval()}) 100 | self.assertEquals(output.shape, (batch_size, num_classes)) 101 | 102 | def testEvaluation(self): 103 | batch_size = 2 104 | height, width = 299, 299 105 | num_classes = 1000 106 | with self.test_session() as sess: 107 | eval_inputs = tf.random_uniform((batch_size, height, width, 3)) 108 | logits, _ = inception.inception_resnet_v2(eval_inputs, 109 | num_classes, 110 | is_training=False) 111 | predictions = tf.argmax(logits, 1) 112 | sess.run(tf.initialize_all_variables()) 113 | output = sess.run(predictions) 114 | self.assertEquals(output.shape, (batch_size,)) 115 | 116 | def testTrainEvalWithReuse(self): 117 | train_batch_size = 5 118 | eval_batch_size = 2 119 | height, width = 150, 150 120 | num_classes = 1000 121 | with self.test_session() as sess: 122 | train_inputs = tf.random_uniform((train_batch_size, height, width, 3)) 123 | inception.inception_resnet_v2(train_inputs, num_classes) 124 | eval_inputs = tf.random_uniform((eval_batch_size, height, width, 3)) 125 | logits, _ = inception.inception_resnet_v2(eval_inputs, 126 | num_classes, 127 | is_training=False, 128 | reuse=True) 129 | predictions = tf.argmax(logits, 1) 130 | sess.run(tf.initialize_all_variables()) 131 | output = sess.run(predictions) 132 | self.assertEquals(output.shape, (eval_batch_size,)) 133 | 134 | 135 | if __name__ == '__main__': 136 | tf.test.main() 137 | -------------------------------------------------------------------------------- /nets/inception_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains common code shared by all inception models. 16 | 17 | Usage of arg scope: 18 | with slim.arg_scope(inception_arg_scope()): 19 | logits, end_points = inception.inception_v3(images, num_classes, 20 | is_training=is_training) 21 | 22 | """ 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import tensorflow as tf 28 | 29 | slim = tf.contrib.slim 30 | 31 | 32 | def inception_arg_scope(weight_decay=0.00004, 33 | use_batch_norm=True, 34 | batch_norm_decay=0.9997, 35 | batch_norm_epsilon=0.001): 36 | """Defines the default arg scope for inception models. 37 | 38 | Args: 39 | weight_decay: The weight decay to use for regularizing the model. 40 | use_batch_norm: "If `True`, batch_norm is applied after each convolution. 41 | batch_norm_decay: Decay for batch norm moving average. 42 | batch_norm_epsilon: Small float added to variance to avoid dividing by zero 43 | in batch norm. 44 | 45 | Returns: 46 | An `arg_scope` to use for the inception models. 47 | """ 48 | batch_norm_params = { 49 | # Decay for the moving averages. 50 | 'decay': batch_norm_decay, 51 | # epsilon to prevent 0s in variance. 52 | 'epsilon': batch_norm_epsilon, 53 | # collection containing update_ops. 54 | 'updates_collections': tf.GraphKeys.UPDATE_OPS, 55 | } 56 | if use_batch_norm: 57 | normalizer_fn = slim.batch_norm 58 | normalizer_params = batch_norm_params 59 | else: 60 | normalizer_fn = None 61 | normalizer_params = {} 62 | # Set weight_decay for weights in Conv and FC layers. 63 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 64 | weights_regularizer=slim.l2_regularizer(weight_decay)): 65 | with slim.arg_scope( 66 | [slim.conv2d], 67 | weights_initializer=slim.variance_scaling_initializer(), 68 | activation_fn=tf.nn.relu, 69 | normalizer_fn=normalizer_fn, 70 | normalizer_params=normalizer_params) as sc: 71 | return sc 72 | -------------------------------------------------------------------------------- /nets/inception_v1_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for nets.inception_v1.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numpy as np 22 | import tensorflow as tf 23 | 24 | from nets import inception 25 | 26 | slim = tf.contrib.slim 27 | 28 | 29 | class InceptionV1Test(tf.test.TestCase): 30 | 31 | def testBuildClassificationNetwork(self): 32 | batch_size = 5 33 | height, width = 224, 224 34 | num_classes = 1000 35 | 36 | inputs = tf.random_uniform((batch_size, height, width, 3)) 37 | logits, end_points = inception.inception_v1(inputs, num_classes) 38 | self.assertTrue(logits.op.name.startswith('InceptionV1/Logits')) 39 | self.assertListEqual(logits.get_shape().as_list(), 40 | [batch_size, num_classes]) 41 | self.assertTrue('Predictions' in end_points) 42 | self.assertListEqual(end_points['Predictions'].get_shape().as_list(), 43 | [batch_size, num_classes]) 44 | 45 | def testBuildBaseNetwork(self): 46 | batch_size = 5 47 | height, width = 224, 224 48 | 49 | inputs = tf.random_uniform((batch_size, height, width, 3)) 50 | mixed_6c, end_points = inception.inception_v1_base(inputs) 51 | self.assertTrue(mixed_6c.op.name.startswith('InceptionV1/Mixed_5c')) 52 | self.assertListEqual(mixed_6c.get_shape().as_list(), 53 | [batch_size, 7, 7, 1024]) 54 | expected_endpoints = ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 55 | 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 56 | 'Mixed_3c', 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 57 | 'Mixed_4d', 'Mixed_4e', 'Mixed_4f', 'MaxPool_5a_2x2', 58 | 'Mixed_5b', 'Mixed_5c'] 59 | self.assertItemsEqual(end_points.keys(), expected_endpoints) 60 | 61 | def testBuildOnlyUptoFinalEndpoint(self): 62 | batch_size = 5 63 | height, width = 224, 224 64 | endpoints = ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 65 | 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c', 66 | 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 67 | 'Mixed_4e', 'Mixed_4f', 'MaxPool_5a_2x2', 'Mixed_5b', 68 | 'Mixed_5c'] 69 | for index, endpoint in enumerate(endpoints): 70 | with tf.Graph().as_default(): 71 | inputs = tf.random_uniform((batch_size, height, width, 3)) 72 | out_tensor, end_points = inception.inception_v1_base( 73 | inputs, final_endpoint=endpoint) 74 | self.assertTrue(out_tensor.op.name.startswith( 75 | 'InceptionV1/' + endpoint)) 76 | self.assertItemsEqual(endpoints[:index+1], end_points) 77 | 78 | def testBuildAndCheckAllEndPointsUptoMixed5c(self): 79 | batch_size = 5 80 | height, width = 224, 224 81 | 82 | inputs = tf.random_uniform((batch_size, height, width, 3)) 83 | _, end_points = inception.inception_v1_base(inputs, 84 | final_endpoint='Mixed_5c') 85 | endpoints_shapes = {'Conv2d_1a_7x7': [5, 112, 112, 64], 86 | 'MaxPool_2a_3x3': [5, 56, 56, 64], 87 | 'Conv2d_2b_1x1': [5, 56, 56, 64], 88 | 'Conv2d_2c_3x3': [5, 56, 56, 192], 89 | 'MaxPool_3a_3x3': [5, 28, 28, 192], 90 | 'Mixed_3b': [5, 28, 28, 256], 91 | 'Mixed_3c': [5, 28, 28, 480], 92 | 'MaxPool_4a_3x3': [5, 14, 14, 480], 93 | 'Mixed_4b': [5, 14, 14, 512], 94 | 'Mixed_4c': [5, 14, 14, 512], 95 | 'Mixed_4d': [5, 14, 14, 512], 96 | 'Mixed_4e': [5, 14, 14, 528], 97 | 'Mixed_4f': [5, 14, 14, 832], 98 | 'MaxPool_5a_2x2': [5, 7, 7, 832], 99 | 'Mixed_5b': [5, 7, 7, 832], 100 | 'Mixed_5c': [5, 7, 7, 1024]} 101 | 102 | self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys()) 103 | for endpoint_name in endpoints_shapes: 104 | expected_shape = 
endpoints_shapes[endpoint_name] 105 | self.assertTrue(endpoint_name in end_points) 106 | self.assertListEqual(end_points[endpoint_name].get_shape().as_list(), 107 | expected_shape) 108 | 109 | def testModelHasExpectedNumberOfParameters(self): 110 | batch_size = 5 111 | height, width = 224, 224 112 | inputs = tf.random_uniform((batch_size, height, width, 3)) 113 | with slim.arg_scope(inception.inception_v1_arg_scope()): 114 | inception.inception_v1_base(inputs) 115 | total_params, _ = slim.model_analyzer.analyze_vars( 116 | slim.get_model_variables()) 117 | self.assertAlmostEqual(5607184, total_params) 118 | 119 | def testHalfSizeImages(self): 120 | batch_size = 5 121 | height, width = 112, 112 122 | 123 | inputs = tf.random_uniform((batch_size, height, width, 3)) 124 | mixed_5c, _ = inception.inception_v1_base(inputs) 125 | self.assertTrue(mixed_5c.op.name.startswith('InceptionV1/Mixed_5c')) 126 | self.assertListEqual(mixed_5c.get_shape().as_list(), 127 | [batch_size, 4, 4, 1024]) 128 | 129 | def testUnknownImageShape(self): 130 | tf.reset_default_graph() 131 | batch_size = 2 132 | height, width = 224, 224 133 | num_classes = 1000 134 | input_np = np.random.uniform(0, 1, (batch_size, height, width, 3)) 135 | with self.test_session() as sess: 136 | inputs = tf.placeholder(tf.float32, shape=(batch_size, None, None, 3)) 137 | logits, end_points = inception.inception_v1(inputs, num_classes) 138 | self.assertTrue(logits.op.name.startswith('InceptionV1/Logits')) 139 | self.assertListEqual(logits.get_shape().as_list(), 140 | [batch_size, num_classes]) 141 | pre_pool = end_points['Mixed_5c'] 142 | feed_dict = {inputs: input_np} 143 | tf.initialize_all_variables().run() 144 | pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict) 145 | self.assertListEqual(list(pre_pool_out.shape), [batch_size, 7, 7, 1024]) 146 | 147 | def testUnknowBatchSize(self): 148 | batch_size = 1 149 | height, width = 224, 224 150 | num_classes = 1000 151 | 152 | inputs = tf.placeholder(tf.float32, (None, height, width, 3)) 153 | logits, _ = inception.inception_v1(inputs, num_classes) 154 | self.assertTrue(logits.op.name.startswith('InceptionV1/Logits')) 155 | self.assertListEqual(logits.get_shape().as_list(), 156 | [None, num_classes]) 157 | images = tf.random_uniform((batch_size, height, width, 3)) 158 | 159 | with self.test_session() as sess: 160 | sess.run(tf.initialize_all_variables()) 161 | output = sess.run(logits, {inputs: images.eval()}) 162 | self.assertEquals(output.shape, (batch_size, num_classes)) 163 | 164 | def testEvaluation(self): 165 | batch_size = 2 166 | height, width = 224, 224 167 | num_classes = 1000 168 | 169 | eval_inputs = tf.random_uniform((batch_size, height, width, 3)) 170 | logits, _ = inception.inception_v1(eval_inputs, num_classes, 171 | is_training=False) 172 | predictions = tf.argmax(logits, 1) 173 | 174 | with self.test_session() as sess: 175 | sess.run(tf.initialize_all_variables()) 176 | output = sess.run(predictions) 177 | self.assertEquals(output.shape, (batch_size,)) 178 | 179 | def testTrainEvalWithReuse(self): 180 | train_batch_size = 5 181 | eval_batch_size = 2 182 | height, width = 224, 224 183 | num_classes = 1000 184 | 185 | train_inputs = tf.random_uniform((train_batch_size, height, width, 3)) 186 | inception.inception_v1(train_inputs, num_classes) 187 | eval_inputs = tf.random_uniform((eval_batch_size, height, width, 3)) 188 | logits, _ = inception.inception_v1(eval_inputs, num_classes, reuse=True) 189 | predictions = tf.argmax(logits, 1) 190 | 191 | with self.test_session() as 
sess: 192 | sess.run(tf.initialize_all_variables()) 193 | output = sess.run(predictions) 194 | self.assertEquals(output.shape, (eval_batch_size,)) 195 | 196 | def testLogitsNotSqueezed(self): 197 | num_classes = 25 198 | images = tf.random_uniform([1, 224, 224, 3]) 199 | logits, _ = inception.inception_v1(images, 200 | num_classes=num_classes, 201 | spatial_squeeze=False) 202 | 203 | with self.test_session() as sess: 204 | tf.initialize_all_variables().run() 205 | logits_out = sess.run(logits) 206 | self.assertListEqual(list(logits_out.shape), [1, 1, 1, num_classes]) 207 | 208 | 209 | if __name__ == '__main__': 210 | tf.test.main() 211 | -------------------------------------------------------------------------------- /nets/inception_v2_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for nets.inception_v2.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numpy as np 22 | import tensorflow as tf 23 | 24 | from nets import inception 25 | 26 | slim = tf.contrib.slim 27 | 28 | 29 | class InceptionV2Test(tf.test.TestCase): 30 | 31 | def testBuildClassificationNetwork(self): 32 | batch_size = 5 33 | height, width = 224, 224 34 | num_classes = 1000 35 | 36 | inputs = tf.random_uniform((batch_size, height, width, 3)) 37 | logits, end_points = inception.inception_v2(inputs, num_classes) 38 | self.assertTrue(logits.op.name.startswith('InceptionV2/Logits')) 39 | self.assertListEqual(logits.get_shape().as_list(), 40 | [batch_size, num_classes]) 41 | self.assertTrue('Predictions' in end_points) 42 | self.assertListEqual(end_points['Predictions'].get_shape().as_list(), 43 | [batch_size, num_classes]) 44 | 45 | def testBuildBaseNetwork(self): 46 | batch_size = 5 47 | height, width = 224, 224 48 | 49 | inputs = tf.random_uniform((batch_size, height, width, 3)) 50 | mixed_5c, end_points = inception.inception_v2_base(inputs) 51 | self.assertTrue(mixed_5c.op.name.startswith('InceptionV2/Mixed_5c')) 52 | self.assertListEqual(mixed_5c.get_shape().as_list(), 53 | [batch_size, 7, 7, 1024]) 54 | expected_endpoints = ['Mixed_3b', 'Mixed_3c', 'Mixed_4a', 'Mixed_4b', 55 | 'Mixed_4c', 'Mixed_4d', 'Mixed_4e', 'Mixed_5a', 56 | 'Mixed_5b', 'Mixed_5c', 'Conv2d_1a_7x7', 57 | 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 'Conv2d_2c_3x3', 58 | 'MaxPool_3a_3x3'] 59 | self.assertItemsEqual(end_points.keys(), expected_endpoints) 60 | 61 | def testBuildOnlyUptoFinalEndpoint(self): 62 | batch_size = 5 63 | height, width = 224, 224 64 | endpoints = ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 65 | 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c', 66 | 'Mixed_4a', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 'Mixed_4e', 67 | 'Mixed_5a', 'Mixed_5b', 
'Mixed_5c'] 68 | for index, endpoint in enumerate(endpoints): 69 | with tf.Graph().as_default(): 70 | inputs = tf.random_uniform((batch_size, height, width, 3)) 71 | out_tensor, end_points = inception.inception_v2_base( 72 | inputs, final_endpoint=endpoint) 73 | self.assertTrue(out_tensor.op.name.startswith( 74 | 'InceptionV2/' + endpoint)) 75 | self.assertItemsEqual(endpoints[:index+1], end_points) 76 | 77 | def testBuildAndCheckAllEndPointsUptoMixed5c(self): 78 | batch_size = 5 79 | height, width = 224, 224 80 | 81 | inputs = tf.random_uniform((batch_size, height, width, 3)) 82 | _, end_points = inception.inception_v2_base(inputs, 83 | final_endpoint='Mixed_5c') 84 | endpoints_shapes = {'Mixed_3b': [batch_size, 28, 28, 256], 85 | 'Mixed_3c': [batch_size, 28, 28, 320], 86 | 'Mixed_4a': [batch_size, 14, 14, 576], 87 | 'Mixed_4b': [batch_size, 14, 14, 576], 88 | 'Mixed_4c': [batch_size, 14, 14, 576], 89 | 'Mixed_4d': [batch_size, 14, 14, 576], 90 | 'Mixed_4e': [batch_size, 14, 14, 576], 91 | 'Mixed_5a': [batch_size, 7, 7, 1024], 92 | 'Mixed_5b': [batch_size, 7, 7, 1024], 93 | 'Mixed_5c': [batch_size, 7, 7, 1024], 94 | 'Conv2d_1a_7x7': [batch_size, 112, 112, 64], 95 | 'MaxPool_2a_3x3': [batch_size, 56, 56, 64], 96 | 'Conv2d_2b_1x1': [batch_size, 56, 56, 64], 97 | 'Conv2d_2c_3x3': [batch_size, 56, 56, 192], 98 | 'MaxPool_3a_3x3': [batch_size, 28, 28, 192]} 99 | self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys()) 100 | for endpoint_name in endpoints_shapes: 101 | expected_shape = endpoints_shapes[endpoint_name] 102 | self.assertTrue(endpoint_name in end_points) 103 | self.assertListEqual(end_points[endpoint_name].get_shape().as_list(), 104 | expected_shape) 105 | 106 | def testModelHasExpectedNumberOfParameters(self): 107 | batch_size = 5 108 | height, width = 224, 224 109 | inputs = tf.random_uniform((batch_size, height, width, 3)) 110 | with slim.arg_scope(inception.inception_v2_arg_scope()): 111 | inception.inception_v2_base(inputs) 112 | total_params, _ = slim.model_analyzer.analyze_vars( 113 | slim.get_model_variables()) 114 | self.assertAlmostEqual(10173112, total_params) 115 | 116 | def testBuildEndPointsWithDepthMultiplierLessThanOne(self): 117 | batch_size = 5 118 | height, width = 224, 224 119 | num_classes = 1000 120 | 121 | inputs = tf.random_uniform((batch_size, height, width, 3)) 122 | _, end_points = inception.inception_v2(inputs, num_classes) 123 | 124 | endpoint_keys = [key for key in end_points.keys() 125 | if key.startswith('Mixed') or key.startswith('Conv')] 126 | 127 | _, end_points_with_multiplier = inception.inception_v2( 128 | inputs, num_classes, scope='depth_multiplied_net', 129 | depth_multiplier=0.5) 130 | 131 | for key in endpoint_keys: 132 | original_depth = end_points[key].get_shape().as_list()[3] 133 | new_depth = end_points_with_multiplier[key].get_shape().as_list()[3] 134 | self.assertEqual(0.5 * original_depth, new_depth) 135 | 136 | def testBuildEndPointsWithDepthMultiplierGreaterThanOne(self): 137 | batch_size = 5 138 | height, width = 224, 224 139 | num_classes = 1000 140 | 141 | inputs = tf.random_uniform((batch_size, height, width, 3)) 142 | _, end_points = inception.inception_v2(inputs, num_classes) 143 | 144 | endpoint_keys = [key for key in end_points.keys() 145 | if key.startswith('Mixed') or key.startswith('Conv')] 146 | 147 | _, end_points_with_multiplier = inception.inception_v2( 148 | inputs, num_classes, scope='depth_multiplied_net', 149 | depth_multiplier=2.0) 150 | 151 | for key in endpoint_keys: 152 | original_depth = 
end_points[key].get_shape().as_list()[3] 153 | new_depth = end_points_with_multiplier[key].get_shape().as_list()[3] 154 | self.assertEqual(2.0 * original_depth, new_depth) 155 | 156 | def testRaiseValueErrorWithInvalidDepthMultiplier(self): 157 | batch_size = 5 158 | height, width = 224, 224 159 | num_classes = 1000 160 | 161 | inputs = tf.random_uniform((batch_size, height, width, 3)) 162 | with self.assertRaises(ValueError): 163 | _ = inception.inception_v2(inputs, num_classes, depth_multiplier=-0.1) 164 | with self.assertRaises(ValueError): 165 | _ = inception.inception_v2(inputs, num_classes, depth_multiplier=0.0) 166 | 167 | def testHalfSizeImages(self): 168 | batch_size = 5 169 | height, width = 112, 112 170 | num_classes = 1000 171 | 172 | inputs = tf.random_uniform((batch_size, height, width, 3)) 173 | logits, end_points = inception.inception_v2(inputs, num_classes) 174 | self.assertTrue(logits.op.name.startswith('InceptionV2/Logits')) 175 | self.assertListEqual(logits.get_shape().as_list(), 176 | [batch_size, num_classes]) 177 | pre_pool = end_points['Mixed_5c'] 178 | self.assertListEqual(pre_pool.get_shape().as_list(), 179 | [batch_size, 4, 4, 1024]) 180 | 181 | def testUnknownImageShape(self): 182 | tf.reset_default_graph() 183 | batch_size = 2 184 | height, width = 224, 224 185 | num_classes = 1000 186 | input_np = np.random.uniform(0, 1, (batch_size, height, width, 3)) 187 | with self.test_session() as sess: 188 | inputs = tf.placeholder(tf.float32, shape=(batch_size, None, None, 3)) 189 | logits, end_points = inception.inception_v2(inputs, num_classes) 190 | self.assertTrue(logits.op.name.startswith('InceptionV2/Logits')) 191 | self.assertListEqual(logits.get_shape().as_list(), 192 | [batch_size, num_classes]) 193 | pre_pool = end_points['Mixed_5c'] 194 | feed_dict = {inputs: input_np} 195 | tf.initialize_all_variables().run() 196 | pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict) 197 | self.assertListEqual(list(pre_pool_out.shape), [batch_size, 7, 7, 1024]) 198 | 199 | def testUnknowBatchSize(self): 200 | batch_size = 1 201 | height, width = 224, 224 202 | num_classes = 1000 203 | 204 | inputs = tf.placeholder(tf.float32, (None, height, width, 3)) 205 | logits, _ = inception.inception_v2(inputs, num_classes) 206 | self.assertTrue(logits.op.name.startswith('InceptionV2/Logits')) 207 | self.assertListEqual(logits.get_shape().as_list(), 208 | [None, num_classes]) 209 | images = tf.random_uniform((batch_size, height, width, 3)) 210 | 211 | with self.test_session() as sess: 212 | sess.run(tf.initialize_all_variables()) 213 | output = sess.run(logits, {inputs: images.eval()}) 214 | self.assertEquals(output.shape, (batch_size, num_classes)) 215 | 216 | def testEvaluation(self): 217 | batch_size = 2 218 | height, width = 224, 224 219 | num_classes = 1000 220 | 221 | eval_inputs = tf.random_uniform((batch_size, height, width, 3)) 222 | logits, _ = inception.inception_v2(eval_inputs, num_classes, 223 | is_training=False) 224 | predictions = tf.argmax(logits, 1) 225 | 226 | with self.test_session() as sess: 227 | sess.run(tf.initialize_all_variables()) 228 | output = sess.run(predictions) 229 | self.assertEquals(output.shape, (batch_size,)) 230 | 231 | def testTrainEvalWithReuse(self): 232 | train_batch_size = 5 233 | eval_batch_size = 2 234 | height, width = 150, 150 235 | num_classes = 1000 236 | 237 | train_inputs = tf.random_uniform((train_batch_size, height, width, 3)) 238 | inception.inception_v2(train_inputs, num_classes) 239 | eval_inputs = 
tf.random_uniform((eval_batch_size, height, width, 3)) 240 | logits, _ = inception.inception_v2(eval_inputs, num_classes, reuse=True) 241 | predictions = tf.argmax(logits, 1) 242 | 243 | with self.test_session() as sess: 244 | sess.run(tf.initialize_all_variables()) 245 | output = sess.run(predictions) 246 | self.assertEquals(output.shape, (eval_batch_size,)) 247 | 248 | def testLogitsNotSqueezed(self): 249 | num_classes = 25 250 | images = tf.random_uniform([1, 224, 224, 3]) 251 | logits, _ = inception.inception_v2(images, 252 | num_classes=num_classes, 253 | spatial_squeeze=False) 254 | 255 | with self.test_session() as sess: 256 | tf.initialize_all_variables().run() 257 | logits_out = sess.run(logits) 258 | self.assertListEqual(list(logits_out.shape), [1, 1, 1, num_classes]) 259 | 260 | 261 | if __name__ == '__main__': 262 | tf.test.main() 263 | -------------------------------------------------------------------------------- /nets/inception_v3_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for nets.inception_v1.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numpy as np 22 | import tensorflow as tf 23 | 24 | from nets import inception 25 | 26 | slim = tf.contrib.slim 27 | 28 | 29 | class InceptionV3Test(tf.test.TestCase): 30 | 31 | def testBuildClassificationNetwork(self): 32 | batch_size = 5 33 | height, width = 299, 299 34 | num_classes = 1000 35 | 36 | inputs = tf.random_uniform((batch_size, height, width, 3)) 37 | logits, end_points = inception.inception_v3(inputs, num_classes) 38 | self.assertTrue(logits.op.name.startswith('InceptionV3/Logits')) 39 | self.assertListEqual(logits.get_shape().as_list(), 40 | [batch_size, num_classes]) 41 | self.assertTrue('Predictions' in end_points) 42 | self.assertListEqual(end_points['Predictions'].get_shape().as_list(), 43 | [batch_size, num_classes]) 44 | 45 | def testBuildBaseNetwork(self): 46 | batch_size = 5 47 | height, width = 299, 299 48 | 49 | inputs = tf.random_uniform((batch_size, height, width, 3)) 50 | final_endpoint, end_points = inception.inception_v3_base(inputs) 51 | self.assertTrue(final_endpoint.op.name.startswith( 52 | 'InceptionV3/Mixed_7c')) 53 | self.assertListEqual(final_endpoint.get_shape().as_list(), 54 | [batch_size, 8, 8, 2048]) 55 | expected_endpoints = ['Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 56 | 'MaxPool_3a_3x3', 'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 57 | 'MaxPool_5a_3x3', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d', 58 | 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 59 | 'Mixed_6e', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c'] 60 | self.assertItemsEqual(end_points.keys(), expected_endpoints) 61 | 62 | def 
testBuildOnlyUptoFinalEndpoint(self): 63 | batch_size = 5 64 | height, width = 299, 299 65 | endpoints = ['Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 66 | 'MaxPool_3a_3x3', 'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 67 | 'MaxPool_5a_3x3', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d', 68 | 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 69 | 'Mixed_6e', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c'] 70 | 71 | for index, endpoint in enumerate(endpoints): 72 | with tf.Graph().as_default(): 73 | inputs = tf.random_uniform((batch_size, height, width, 3)) 74 | out_tensor, end_points = inception.inception_v3_base( 75 | inputs, final_endpoint=endpoint) 76 | self.assertTrue(out_tensor.op.name.startswith( 77 | 'InceptionV3/' + endpoint)) 78 | self.assertItemsEqual(endpoints[:index+1], end_points) 79 | 80 | def testBuildAndCheckAllEndPointsUptoMixed7c(self): 81 | batch_size = 5 82 | height, width = 299, 299 83 | 84 | inputs = tf.random_uniform((batch_size, height, width, 3)) 85 | _, end_points = inception.inception_v3_base( 86 | inputs, final_endpoint='Mixed_7c') 87 | endpoints_shapes = {'Conv2d_1a_3x3': [batch_size, 149, 149, 32], 88 | 'Conv2d_2a_3x3': [batch_size, 147, 147, 32], 89 | 'Conv2d_2b_3x3': [batch_size, 147, 147, 64], 90 | 'MaxPool_3a_3x3': [batch_size, 73, 73, 64], 91 | 'Conv2d_3b_1x1': [batch_size, 73, 73, 80], 92 | 'Conv2d_4a_3x3': [batch_size, 71, 71, 192], 93 | 'MaxPool_5a_3x3': [batch_size, 35, 35, 192], 94 | 'Mixed_5b': [batch_size, 35, 35, 256], 95 | 'Mixed_5c': [batch_size, 35, 35, 288], 96 | 'Mixed_5d': [batch_size, 35, 35, 288], 97 | 'Mixed_6a': [batch_size, 17, 17, 768], 98 | 'Mixed_6b': [batch_size, 17, 17, 768], 99 | 'Mixed_6c': [batch_size, 17, 17, 768], 100 | 'Mixed_6d': [batch_size, 17, 17, 768], 101 | 'Mixed_6e': [batch_size, 17, 17, 768], 102 | 'Mixed_7a': [batch_size, 8, 8, 1280], 103 | 'Mixed_7b': [batch_size, 8, 8, 2048], 104 | 'Mixed_7c': [batch_size, 8, 8, 2048]} 105 | self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys()) 106 | for endpoint_name in endpoints_shapes: 107 | expected_shape = endpoints_shapes[endpoint_name] 108 | self.assertTrue(endpoint_name in end_points) 109 | self.assertListEqual(end_points[endpoint_name].get_shape().as_list(), 110 | expected_shape) 111 | 112 | def testModelHasExpectedNumberOfParameters(self): 113 | batch_size = 5 114 | height, width = 299, 299 115 | inputs = tf.random_uniform((batch_size, height, width, 3)) 116 | with slim.arg_scope(inception.inception_v3_arg_scope()): 117 | inception.inception_v3_base(inputs) 118 | total_params, _ = slim.model_analyzer.analyze_vars( 119 | slim.get_model_variables()) 120 | self.assertAlmostEqual(21802784, total_params) 121 | 122 | def testBuildEndPoints(self): 123 | batch_size = 5 124 | height, width = 299, 299 125 | num_classes = 1000 126 | 127 | inputs = tf.random_uniform((batch_size, height, width, 3)) 128 | _, end_points = inception.inception_v3(inputs, num_classes) 129 | self.assertTrue('Logits' in end_points) 130 | logits = end_points['Logits'] 131 | self.assertListEqual(logits.get_shape().as_list(), 132 | [batch_size, num_classes]) 133 | self.assertTrue('AuxLogits' in end_points) 134 | aux_logits = end_points['AuxLogits'] 135 | self.assertListEqual(aux_logits.get_shape().as_list(), 136 | [batch_size, num_classes]) 137 | self.assertTrue('Mixed_7c' in end_points) 138 | pre_pool = end_points['Mixed_7c'] 139 | self.assertListEqual(pre_pool.get_shape().as_list(), 140 | [batch_size, 8, 8, 2048]) 141 | self.assertTrue('PreLogits' in end_points) 142 | pre_logits = end_points['PreLogits'] 143 | 
self.assertListEqual(pre_logits.get_shape().as_list(), 144 | [batch_size, 1, 1, 2048]) 145 | 146 | def testBuildEndPointsWithDepthMultiplierLessThanOne(self): 147 | batch_size = 5 148 | height, width = 299, 299 149 | num_classes = 1000 150 | 151 | inputs = tf.random_uniform((batch_size, height, width, 3)) 152 | _, end_points = inception.inception_v3(inputs, num_classes) 153 | 154 | endpoint_keys = [key for key in end_points.keys() 155 | if key.startswith('Mixed') or key.startswith('Conv')] 156 | 157 | _, end_points_with_multiplier = inception.inception_v3( 158 | inputs, num_classes, scope='depth_multiplied_net', 159 | depth_multiplier=0.5) 160 | 161 | for key in endpoint_keys: 162 | original_depth = end_points[key].get_shape().as_list()[3] 163 | new_depth = end_points_with_multiplier[key].get_shape().as_list()[3] 164 | self.assertEqual(0.5 * original_depth, new_depth) 165 | 166 | def testBuildEndPointsWithDepthMultiplierGreaterThanOne(self): 167 | batch_size = 5 168 | height, width = 299, 299 169 | num_classes = 1000 170 | 171 | inputs = tf.random_uniform((batch_size, height, width, 3)) 172 | _, end_points = inception.inception_v3(inputs, num_classes) 173 | 174 | endpoint_keys = [key for key in end_points.keys() 175 | if key.startswith('Mixed') or key.startswith('Conv')] 176 | 177 | _, end_points_with_multiplier = inception.inception_v3( 178 | inputs, num_classes, scope='depth_multiplied_net', 179 | depth_multiplier=2.0) 180 | 181 | for key in endpoint_keys: 182 | original_depth = end_points[key].get_shape().as_list()[3] 183 | new_depth = end_points_with_multiplier[key].get_shape().as_list()[3] 184 | self.assertEqual(2.0 * original_depth, new_depth) 185 | 186 | def testRaiseValueErrorWithInvalidDepthMultiplier(self): 187 | batch_size = 5 188 | height, width = 299, 299 189 | num_classes = 1000 190 | 191 | inputs = tf.random_uniform((batch_size, height, width, 3)) 192 | with self.assertRaises(ValueError): 193 | _ = inception.inception_v3(inputs, num_classes, depth_multiplier=-0.1) 194 | with self.assertRaises(ValueError): 195 | _ = inception.inception_v3(inputs, num_classes, depth_multiplier=0.0) 196 | 197 | def testHalfSizeImages(self): 198 | batch_size = 5 199 | height, width = 150, 150 200 | num_classes = 1000 201 | 202 | inputs = tf.random_uniform((batch_size, height, width, 3)) 203 | logits, end_points = inception.inception_v3(inputs, num_classes) 204 | self.assertTrue(logits.op.name.startswith('InceptionV3/Logits')) 205 | self.assertListEqual(logits.get_shape().as_list(), 206 | [batch_size, num_classes]) 207 | pre_pool = end_points['Mixed_7c'] 208 | self.assertListEqual(pre_pool.get_shape().as_list(), 209 | [batch_size, 3, 3, 2048]) 210 | 211 | def testUnknownImageShape(self): 212 | tf.reset_default_graph() 213 | batch_size = 2 214 | height, width = 299, 299 215 | num_classes = 1000 216 | input_np = np.random.uniform(0, 1, (batch_size, height, width, 3)) 217 | with self.test_session() as sess: 218 | inputs = tf.placeholder(tf.float32, shape=(batch_size, None, None, 3)) 219 | logits, end_points = inception.inception_v3(inputs, num_classes) 220 | self.assertListEqual(logits.get_shape().as_list(), 221 | [batch_size, num_classes]) 222 | pre_pool = end_points['Mixed_7c'] 223 | feed_dict = {inputs: input_np} 224 | tf.initialize_all_variables().run() 225 | pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict) 226 | self.assertListEqual(list(pre_pool_out.shape), [batch_size, 8, 8, 2048]) 227 | 228 | def testUnknowBatchSize(self): 229 | batch_size = 1 230 | height, width = 299, 299 231 | 
num_classes = 1000 232 | 233 | inputs = tf.placeholder(tf.float32, (None, height, width, 3)) 234 | logits, _ = inception.inception_v3(inputs, num_classes) 235 | self.assertTrue(logits.op.name.startswith('InceptionV3/Logits')) 236 | self.assertListEqual(logits.get_shape().as_list(), 237 | [None, num_classes]) 238 | images = tf.random_uniform((batch_size, height, width, 3)) 239 | 240 | with self.test_session() as sess: 241 | sess.run(tf.initialize_all_variables()) 242 | output = sess.run(logits, {inputs: images.eval()}) 243 | self.assertEquals(output.shape, (batch_size, num_classes)) 244 | 245 | def testEvaluation(self): 246 | batch_size = 2 247 | height, width = 299, 299 248 | num_classes = 1000 249 | 250 | eval_inputs = tf.random_uniform((batch_size, height, width, 3)) 251 | logits, _ = inception.inception_v3(eval_inputs, num_classes, 252 | is_training=False) 253 | predictions = tf.argmax(logits, 1) 254 | 255 | with self.test_session() as sess: 256 | sess.run(tf.initialize_all_variables()) 257 | output = sess.run(predictions) 258 | self.assertEquals(output.shape, (batch_size,)) 259 | 260 | def testTrainEvalWithReuse(self): 261 | train_batch_size = 5 262 | eval_batch_size = 2 263 | height, width = 150, 150 264 | num_classes = 1000 265 | 266 | train_inputs = tf.random_uniform((train_batch_size, height, width, 3)) 267 | inception.inception_v3(train_inputs, num_classes) 268 | eval_inputs = tf.random_uniform((eval_batch_size, height, width, 3)) 269 | logits, _ = inception.inception_v3(eval_inputs, num_classes, 270 | is_training=False, reuse=True) 271 | predictions = tf.argmax(logits, 1) 272 | 273 | with self.test_session() as sess: 274 | sess.run(tf.initialize_all_variables()) 275 | output = sess.run(predictions) 276 | self.assertEquals(output.shape, (eval_batch_size,)) 277 | 278 | def testLogitsNotSqueezed(self): 279 | num_classes = 25 280 | images = tf.random_uniform([1, 299, 299, 3]) 281 | logits, _ = inception.inception_v3(images, 282 | num_classes=num_classes, 283 | spatial_squeeze=False) 284 | 285 | with self.test_session() as sess: 286 | tf.initialize_all_variables().run() 287 | logits_out = sess.run(logits) 288 | self.assertListEqual(list(logits_out.shape), [1, 1, 1, num_classes]) 289 | 290 | 291 | if __name__ == '__main__': 292 | tf.test.main() 293 | -------------------------------------------------------------------------------- /nets/inception_v4_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for slim.inception_v4.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | from nets import inception 23 | 24 | 25 | class InceptionTest(tf.test.TestCase): 26 | 27 | def testBuildLogits(self): 28 | batch_size = 5 29 | height, width = 299, 299 30 | num_classes = 1000 31 | inputs = tf.random_uniform((batch_size, height, width, 3)) 32 | logits, end_points = inception.inception_v4(inputs, num_classes) 33 | auxlogits = end_points['AuxLogits'] 34 | predictions = end_points['Predictions'] 35 | self.assertTrue(auxlogits.op.name.startswith('InceptionV4/AuxLogits')) 36 | self.assertListEqual(auxlogits.get_shape().as_list(), 37 | [batch_size, num_classes]) 38 | self.assertTrue(logits.op.name.startswith('InceptionV4/Logits')) 39 | self.assertListEqual(logits.get_shape().as_list(), 40 | [batch_size, num_classes]) 41 | self.assertTrue(predictions.op.name.startswith( 42 | 'InceptionV4/Logits/Predictions')) 43 | self.assertListEqual(predictions.get_shape().as_list(), 44 | [batch_size, num_classes]) 45 | 46 | def testBuildWithoutAuxLogits(self): 47 | batch_size = 5 48 | height, width = 299, 299 49 | num_classes = 1000 50 | inputs = tf.random_uniform((batch_size, height, width, 3)) 51 | logits, endpoints = inception.inception_v4(inputs, num_classes, 52 | create_aux_logits=False) 53 | self.assertFalse('AuxLogits' in endpoints) 54 | self.assertTrue(logits.op.name.startswith('InceptionV4/Logits')) 55 | self.assertListEqual(logits.get_shape().as_list(), 56 | [batch_size, num_classes]) 57 | 58 | def testAllEndPointsShapes(self): 59 | batch_size = 5 60 | height, width = 299, 299 61 | num_classes = 1000 62 | inputs = tf.random_uniform((batch_size, height, width, 3)) 63 | _, end_points = inception.inception_v4(inputs, num_classes) 64 | endpoints_shapes = {'Conv2d_1a_3x3': [batch_size, 149, 149, 32], 65 | 'Conv2d_2a_3x3': [batch_size, 147, 147, 32], 66 | 'Conv2d_2b_3x3': [batch_size, 147, 147, 64], 67 | 'Mixed_3a': [batch_size, 73, 73, 160], 68 | 'Mixed_4a': [batch_size, 71, 71, 192], 69 | 'Mixed_5a': [batch_size, 35, 35, 384], 70 | # 4 x Inception-A blocks 71 | 'Mixed_5b': [batch_size, 35, 35, 384], 72 | 'Mixed_5c': [batch_size, 35, 35, 384], 73 | 'Mixed_5d': [batch_size, 35, 35, 384], 74 | 'Mixed_5e': [batch_size, 35, 35, 384], 75 | # Reduction-A block 76 | 'Mixed_6a': [batch_size, 17, 17, 1024], 77 | # 7 x Inception-B blocks 78 | 'Mixed_6b': [batch_size, 17, 17, 1024], 79 | 'Mixed_6c': [batch_size, 17, 17, 1024], 80 | 'Mixed_6d': [batch_size, 17, 17, 1024], 81 | 'Mixed_6e': [batch_size, 17, 17, 1024], 82 | 'Mixed_6f': [batch_size, 17, 17, 1024], 83 | 'Mixed_6g': [batch_size, 17, 17, 1024], 84 | 'Mixed_6h': [batch_size, 17, 17, 1024], 85 | # Reduction-A block 86 | 'Mixed_7a': [batch_size, 8, 8, 1536], 87 | # 3 x Inception-C blocks 88 | 'Mixed_7b': [batch_size, 8, 8, 1536], 89 | 'Mixed_7c': [batch_size, 8, 8, 1536], 90 | 'Mixed_7d': [batch_size, 8, 8, 1536], 91 | # Logits and predictions 92 | 'AuxLogits': [batch_size, num_classes], 93 | 'PreLogitsFlatten': [batch_size, 1536], 94 | 'Logits': [batch_size, num_classes], 95 | 'Predictions': [batch_size, num_classes]} 96 | self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys()) 97 | for endpoint_name in endpoints_shapes: 98 | expected_shape = endpoints_shapes[endpoint_name] 99 | self.assertTrue(endpoint_name in end_points) 100 | 
self.assertListEqual(end_points[endpoint_name].get_shape().as_list(), 101 | expected_shape) 102 | 103 | def testBuildBaseNetwork(self): 104 | batch_size = 5 105 | height, width = 299, 299 106 | inputs = tf.random_uniform((batch_size, height, width, 3)) 107 | net, end_points = inception.inception_v4_base(inputs) 108 | self.assertTrue(net.op.name.startswith( 109 | 'InceptionV4/Mixed_7d')) 110 | self.assertListEqual(net.get_shape().as_list(), [batch_size, 8, 8, 1536]) 111 | expected_endpoints = [ 112 | 'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'Mixed_3a', 113 | 'Mixed_4a', 'Mixed_5a', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d', 114 | 'Mixed_5e', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 115 | 'Mixed_6e', 'Mixed_6f', 'Mixed_6g', 'Mixed_6h', 'Mixed_7a', 116 | 'Mixed_7b', 'Mixed_7c', 'Mixed_7d'] 117 | self.assertItemsEqual(end_points.keys(), expected_endpoints) 118 | for name, op in end_points.iteritems(): 119 | self.assertTrue(op.name.startswith('InceptionV4/' + name)) 120 | 121 | def testBuildOnlyUpToFinalEndpoint(self): 122 | batch_size = 5 123 | height, width = 299, 299 124 | all_endpoints = [ 125 | 'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'Mixed_3a', 126 | 'Mixed_4a', 'Mixed_5a', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d', 127 | 'Mixed_5e', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 128 | 'Mixed_6e', 'Mixed_6f', 'Mixed_6g', 'Mixed_6h', 'Mixed_7a', 129 | 'Mixed_7b', 'Mixed_7c', 'Mixed_7d'] 130 | for index, endpoint in enumerate(all_endpoints): 131 | with tf.Graph().as_default(): 132 | inputs = tf.random_uniform((batch_size, height, width, 3)) 133 | out_tensor, end_points = inception.inception_v4_base( 134 | inputs, final_endpoint=endpoint) 135 | self.assertTrue(out_tensor.op.name.startswith( 136 | 'InceptionV4/' + endpoint)) 137 | self.assertItemsEqual(all_endpoints[:index+1], end_points) 138 | 139 | def testVariablesSetDevice(self): 140 | batch_size = 5 141 | height, width = 299, 299 142 | num_classes = 1000 143 | inputs = tf.random_uniform((batch_size, height, width, 3)) 144 | # Force all Variables to reside on the device. 
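# The two device contexts below build one copy of the network with its
# variables placed on the CPU and a second copy placed on the first GPU; the
# loops that follow read each variable's .device attribute to verify the
# placement.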
145 | with tf.variable_scope('on_cpu'), tf.device('/cpu:0'): 146 | inception.inception_v4(inputs, num_classes) 147 | with tf.variable_scope('on_gpu'), tf.device('/gpu:0'): 148 | inception.inception_v4(inputs, num_classes) 149 | for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_cpu'): 150 | self.assertDeviceEqual(v.device, '/cpu:0') 151 | for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_gpu'): 152 | self.assertDeviceEqual(v.device, '/gpu:0') 153 | 154 | def testHalfSizeImages(self): 155 | batch_size = 5 156 | height, width = 150, 150 157 | num_classes = 1000 158 | inputs = tf.random_uniform((batch_size, height, width, 3)) 159 | logits, end_points = inception.inception_v4(inputs, num_classes) 160 | self.assertTrue(logits.op.name.startswith('InceptionV4/Logits')) 161 | self.assertListEqual(logits.get_shape().as_list(), 162 | [batch_size, num_classes]) 163 | pre_pool = end_points['Mixed_7d'] 164 | self.assertListEqual(pre_pool.get_shape().as_list(), 165 | [batch_size, 3, 3, 1536]) 166 | 167 | def testUnknownBatchSize(self): 168 | batch_size = 1 169 | height, width = 299, 299 170 | num_classes = 1000 171 | with self.test_session() as sess: 172 | inputs = tf.placeholder(tf.float32, (None, height, width, 3)) 173 | logits, _ = inception.inception_v4(inputs, num_classes) 174 | self.assertTrue(logits.op.name.startswith('InceptionV4/Logits')) 175 | self.assertListEqual(logits.get_shape().as_list(), 176 | [None, num_classes]) 177 | images = tf.random_uniform((batch_size, height, width, 3)) 178 | sess.run(tf.initialize_all_variables()) 179 | output = sess.run(logits, {inputs: images.eval()}) 180 | self.assertEquals(output.shape, (batch_size, num_classes)) 181 | 182 | def testEvaluation(self): 183 | batch_size = 2 184 | height, width = 299, 299 185 | num_classes = 1000 186 | with self.test_session() as sess: 187 | eval_inputs = tf.random_uniform((batch_size, height, width, 3)) 188 | logits, _ = inception.inception_v4(eval_inputs, 189 | num_classes, 190 | is_training=False) 191 | predictions = tf.argmax(logits, 1) 192 | sess.run(tf.initialize_all_variables()) 193 | output = sess.run(predictions) 194 | self.assertEquals(output.shape, (batch_size,)) 195 | 196 | def testTrainEvalWithReuse(self): 197 | train_batch_size = 5 198 | eval_batch_size = 2 199 | height, width = 150, 150 200 | num_classes = 1000 201 | with self.test_session() as sess: 202 | train_inputs = tf.random_uniform((train_batch_size, height, width, 3)) 203 | inception.inception_v4(train_inputs, num_classes) 204 | eval_inputs = tf.random_uniform((eval_batch_size, height, width, 3)) 205 | logits, _ = inception.inception_v4(eval_inputs, 206 | num_classes, 207 | is_training=False, 208 | reuse=True) 209 | predictions = tf.argmax(logits, 1) 210 | sess.run(tf.initialize_all_variables()) 211 | output = sess.run(predictions) 212 | self.assertEquals(output.shape, (eval_batch_size,)) 213 | 214 | 215 | if __name__ == '__main__': 216 | tf.test.main() 217 | -------------------------------------------------------------------------------- /nets/lenet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a variant of the LeNet model definition.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | def lenet(images, num_classes=10, is_training=False, 27 | dropout_keep_prob=0.5, 28 | prediction_fn=slim.softmax, 29 | scope='LeNet'): 30 | """Creates a variant of the LeNet model. 31 | 32 | Note that since the output is a set of 'logits', the values fall in the 33 | interval of (-infinity, infinity). Consequently, to convert the outputs to a 34 | probability distribution over the characters, one will need to convert them 35 | using the softmax function: 36 | 37 | logits = lenet.lenet(images, is_training=False) 38 | probabilities = tf.nn.softmax(logits) 39 | predictions = tf.argmax(logits, 1) 40 | 41 | Args: 42 | images: A batch of `Tensors` of size [batch_size, height, width, channels]. 43 | num_classes: the number of classes in the dataset. 44 | is_training: specifies whether or not we're currently training the model. 45 | This variable will determine the behaviour of the dropout layer. 46 | dropout_keep_prob: the percentage of activation values that are retained. 47 | prediction_fn: a function to get predictions out of logits. 48 | scope: Optional variable_scope. 49 | 50 | Returns: 51 | logits: the pre-softmax activations, a tensor of size 52 | [batch_size, `num_classes`] 53 | end_points: a dictionary from components of the network to the corresponding 54 | activation. 55 | """ 56 | end_points = {} 57 | 58 | with tf.variable_scope(scope, 'LeNet', [images, num_classes]): 59 | net = slim.conv2d(images, 32, [5, 5], scope='conv1') 60 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool1') 61 | net = slim.conv2d(net, 64, [5, 5], scope='conv2') 62 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool2') 63 | net = slim.flatten(net) 64 | end_points['Flatten'] = net 65 | 66 | net = slim.fully_connected(net, 1024, scope='fc3') 67 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 68 | scope='dropout3') 69 | logits = slim.fully_connected(net, num_classes, activation_fn=None, 70 | scope='fc4') 71 | 72 | end_points['Logits'] = logits 73 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 74 | 75 | return logits, end_points 76 | lenet.default_image_size = 28 77 | 78 | 79 | def lenet_arg_scope(weight_decay=0.0): 80 | """Defines the default lenet argument scope. 81 | 82 | Args: 83 | weight_decay: The weight decay to use for regularizing the model. 84 | 85 | Returns: 86 | An `arg_scope` to use for the inception v3 model. 
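    In this file the scope is applied to the LeNet model defined above: it adds
    l2 weight regularization, truncated-normal weight initialization and ReLU
    activations to the conv and fully-connected layers.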
87 | """ 88 | with slim.arg_scope( 89 | [slim.conv2d, slim.fully_connected], 90 | weights_regularizer=slim.l2_regularizer(weight_decay), 91 | weights_initializer=tf.truncated_normal_initializer(stddev=0.1), 92 | activation_fn=tf.nn.relu) as sc: 93 | return sc 94 | -------------------------------------------------------------------------------- /nets/nets_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a factory for building various models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import functools 21 | 22 | import tensorflow as tf 23 | 24 | from nets import alexnet 25 | from nets import cifarnet 26 | from nets import inception 27 | from nets import lenet 28 | from nets import overfeat 29 | from nets import resnet_v1 30 | from nets import resnet_v2 31 | from nets import vgg 32 | 33 | slim = tf.contrib.slim 34 | 35 | networks_map = {'alexnet_v2': alexnet.alexnet_v2, 36 | 'cifarnet': cifarnet.cifarnet, 37 | 'overfeat': overfeat.overfeat, 38 | 'vgg_a': vgg.vgg_a, 39 | 'vgg_16': vgg.vgg_16, 40 | 'vgg_19': vgg.vgg_19, 41 | 'inception_v1': inception.inception_v1, 42 | 'inception_v2': inception.inception_v2, 43 | 'inception_v3': inception.inception_v3, 44 | 'inception_v4': inception.inception_v4, 45 | 'inception_resnet_v2': inception.inception_resnet_v2, 46 | 'lenet': lenet.lenet, 47 | 'resnet_v1_50': resnet_v1.resnet_v1_50, 48 | 'resnet_v1_101': resnet_v1.resnet_v1_101, 49 | 'resnet_v1_152': resnet_v1.resnet_v1_152, 50 | 'resnet_v1_200': resnet_v1.resnet_v1_200, 51 | 'resnet_v2_50': resnet_v2.resnet_v2_50, 52 | 'resnet_v2_101': resnet_v2.resnet_v2_101, 53 | 'resnet_v2_152': resnet_v2.resnet_v2_152, 54 | 'resnet_v2_200': resnet_v2.resnet_v2_200, 55 | } 56 | 57 | arg_scopes_map = {'alexnet_v2': alexnet.alexnet_v2_arg_scope, 58 | 'cifarnet': cifarnet.cifarnet_arg_scope, 59 | 'overfeat': overfeat.overfeat_arg_scope, 60 | 'vgg_a': vgg.vgg_arg_scope, 61 | 'vgg_16': vgg.vgg_arg_scope, 62 | 'vgg_19': vgg.vgg_arg_scope, 63 | 'inception_v1': inception.inception_v3_arg_scope, 64 | 'inception_v2': inception.inception_v3_arg_scope, 65 | 'inception_v3': inception.inception_v3_arg_scope, 66 | 'inception_v4': inception.inception_v4_arg_scope, 67 | 'inception_resnet_v2': 68 | inception.inception_resnet_v2_arg_scope, 69 | 'lenet': lenet.lenet_arg_scope, 70 | 'resnet_v1_50': resnet_v1.resnet_arg_scope, 71 | 'resnet_v1_101': resnet_v1.resnet_arg_scope, 72 | 'resnet_v1_152': resnet_v1.resnet_arg_scope, 73 | 'resnet_v1_200': resnet_v1.resnet_arg_scope, 74 | 'resnet_v2_50': resnet_v2.resnet_arg_scope, 75 | 'resnet_v2_101': resnet_v2.resnet_arg_scope, 76 | 'resnet_v2_152': resnet_v2.resnet_arg_scope, 77 | 'resnet_v2_200': 
resnet_v2.resnet_arg_scope, 78 | } 79 | 80 | 81 | def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False): 82 | """Returns a network_fn such as `logits, end_points = network_fn(images)`. 83 | 84 | Args: 85 | name: The name of the network. 86 | num_classes: The number of classes to use for classification. 87 | weight_decay: The l2 coefficient for the model weights. 88 | is_training: `True` if the model is being used for training and `False` 89 | otherwise. 90 | 91 | Returns: 92 | network_fn: A function that applies the model to a batch of images. It has 93 | the following signature: 94 | logits, end_points = network_fn(images) 95 | Raises: 96 | ValueError: If network `name` is not recognized. 97 | """ 98 | if name not in networks_map: 99 | raise ValueError('Name of network unknown %s' % name) 100 | arg_scope = arg_scopes_map[name](weight_decay=weight_decay) 101 | func = networks_map[name] 102 | @functools.wraps(func) 103 | def network_fn(images, **kwargs): 104 | with slim.arg_scope(arg_scope): 105 | return func(images, num_classes, is_training=is_training, **kwargs) 106 | if hasattr(func, 'default_image_size'): 107 | network_fn.default_image_size = func.default_image_size 108 | 109 | return network_fn 110 | -------------------------------------------------------------------------------- /nets/nets_factory_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for slim.inception.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | 23 | import tensorflow as tf 24 | 25 | from nets import nets_factory 26 | 27 | 28 | class NetworksTest(tf.test.TestCase): 29 | 30 | def testGetNetworkFn(self): 31 | batch_size = 5 32 | num_classes = 1000 33 | for net in nets_factory.networks_map: 34 | with self.test_session(): 35 | net_fn = nets_factory.get_network_fn(net, num_classes) 36 | # Most networks use 224 as their default_image_size 37 | image_size = getattr(net_fn, 'default_image_size', 224) 38 | inputs = tf.random_uniform((batch_size, image_size, image_size, 3)) 39 | logits, end_points = net_fn(inputs) 40 | self.assertTrue(isinstance(logits, tf.Tensor)) 41 | self.assertTrue(isinstance(end_points, dict)) 42 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 43 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 44 | 45 | if __name__ == '__main__': 46 | tf.test.main() 47 | -------------------------------------------------------------------------------- /nets/overfeat.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains the model definition for the OverFeat network. 16 | 17 | The definition for the network was obtained from: 18 | OverFeat: Integrated Recognition, Localization and Detection using 19 | Convolutional Networks 20 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and 21 | Yann LeCun, 2014 22 | http://arxiv.org/abs/1312.6229 23 | 24 | Usage: 25 | with slim.arg_scope(overfeat.overfeat_arg_scope()): 26 | outputs, end_points = overfeat.overfeat(inputs) 27 | 28 | @@overfeat 29 | """ 30 | from __future__ import absolute_import 31 | from __future__ import division 32 | from __future__ import print_function 33 | 34 | import tensorflow as tf 35 | 36 | slim = tf.contrib.slim 37 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 38 | 39 | 40 | def overfeat_arg_scope(weight_decay=0.0005): 41 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 42 | activation_fn=tf.nn.relu, 43 | weights_regularizer=slim.l2_regularizer(weight_decay), 44 | biases_initializer=tf.zeros_initializer): 45 | with slim.arg_scope([slim.conv2d], padding='SAME'): 46 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: 47 | return arg_sc 48 | 49 | 50 | def overfeat(inputs, 51 | num_classes=1000, 52 | is_training=True, 53 | dropout_keep_prob=0.5, 54 | spatial_squeeze=True, 55 | scope='overfeat'): 56 | """Contains the model definition for the OverFeat network. 57 | 58 | The definition for the network was obtained from: 59 | OverFeat: Integrated Recognition, Localization and Detection using 60 | Convolutional Networks 61 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and 62 | Yann LeCun, 2014 63 | http://arxiv.org/abs/1312.6229 64 | 65 | Note: All the fully_connected layers have been transformed to conv2d layers. 66 | To use in classification mode, resize input to 231x231. To use in fully 67 | convolutional mode, set spatial_squeeze to false. 68 | 69 | Args: 70 | inputs: a tensor of size [batch_size, height, width, channels]. 71 | num_classes: number of predicted classes. 72 | is_training: whether or not the model is being trained. 73 | dropout_keep_prob: the probability that activations are kept in the dropout 74 | layers during training. 75 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 76 | outputs. Useful to remove unnecessary dimensions for classification. 77 | scope: Optional scope for the variables. 78 | 79 | Returns: 80 | the last op containing the log predictions and end_points dict. 
81 | 82 | """ 83 | with tf.variable_scope(scope, 'overfeat', [inputs]) as sc: 84 | end_points_collection = sc.name + '_end_points' 85 | # Collect outputs for conv2d, fully_connected and max_pool2d 86 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 87 | outputs_collections=end_points_collection): 88 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', 89 | scope='conv1') 90 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 91 | net = slim.conv2d(net, 256, [5, 5], padding='VALID', scope='conv2') 92 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 93 | net = slim.conv2d(net, 512, [3, 3], scope='conv3') 94 | net = slim.conv2d(net, 1024, [3, 3], scope='conv4') 95 | net = slim.conv2d(net, 1024, [3, 3], scope='conv5') 96 | net = slim.max_pool2d(net, [2, 2], scope='pool5') 97 | with slim.arg_scope([slim.conv2d], 98 | weights_initializer=trunc_normal(0.005), 99 | biases_initializer=tf.constant_initializer(0.1)): 100 | # Use conv2d instead of fully_connected layers. 101 | net = slim.conv2d(net, 3072, [6, 6], padding='VALID', scope='fc6') 102 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 103 | scope='dropout6') 104 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 105 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 106 | scope='dropout7') 107 | net = slim.conv2d(net, num_classes, [1, 1], 108 | activation_fn=None, 109 | normalizer_fn=None, 110 | biases_initializer=tf.zeros_initializer, 111 | scope='fc8') 112 | # Convert end_points_collection into a end_point dict. 113 | end_points = slim.utils.convert_collection_to_dict(end_points_collection) 114 | if spatial_squeeze: 115 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 116 | end_points[sc.name + '/fc8'] = net 117 | return net, end_points 118 | overfeat.default_image_size = 231 119 | -------------------------------------------------------------------------------- /nets/overfeat_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for slim.nets.overfeat.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | from nets import overfeat 23 | 24 | slim = tf.contrib.slim 25 | 26 | 27 | class OverFeatTest(tf.test.TestCase): 28 | 29 | def testBuild(self): 30 | batch_size = 5 31 | height, width = 231, 231 32 | num_classes = 1000 33 | with self.test_session(): 34 | inputs = tf.random_uniform((batch_size, height, width, 3)) 35 | logits, _ = overfeat.overfeat(inputs, num_classes) 36 | self.assertEquals(logits.op.name, 'overfeat/fc8/squeezed') 37 | self.assertListEqual(logits.get_shape().as_list(), 38 | [batch_size, num_classes]) 39 | 40 | def testFullyConvolutional(self): 41 | batch_size = 1 42 | height, width = 281, 281 43 | num_classes = 1000 44 | with self.test_session(): 45 | inputs = tf.random_uniform((batch_size, height, width, 3)) 46 | logits, _ = overfeat.overfeat(inputs, num_classes, spatial_squeeze=False) 47 | self.assertEquals(logits.op.name, 'overfeat/fc8/BiasAdd') 48 | self.assertListEqual(logits.get_shape().as_list(), 49 | [batch_size, 2, 2, num_classes]) 50 | 51 | def testEndPoints(self): 52 | batch_size = 5 53 | height, width = 231, 231 54 | num_classes = 1000 55 | with self.test_session(): 56 | inputs = tf.random_uniform((batch_size, height, width, 3)) 57 | _, end_points = overfeat.overfeat(inputs, num_classes) 58 | expected_names = ['overfeat/conv1', 59 | 'overfeat/pool1', 60 | 'overfeat/conv2', 61 | 'overfeat/pool2', 62 | 'overfeat/conv3', 63 | 'overfeat/conv4', 64 | 'overfeat/conv5', 65 | 'overfeat/pool5', 66 | 'overfeat/fc6', 67 | 'overfeat/fc7', 68 | 'overfeat/fc8' 69 | ] 70 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 71 | 72 | def testModelVariables(self): 73 | batch_size = 5 74 | height, width = 231, 231 75 | num_classes = 1000 76 | with self.test_session(): 77 | inputs = tf.random_uniform((batch_size, height, width, 3)) 78 | overfeat.overfeat(inputs, num_classes) 79 | expected_names = ['overfeat/conv1/weights', 80 | 'overfeat/conv1/biases', 81 | 'overfeat/conv2/weights', 82 | 'overfeat/conv2/biases', 83 | 'overfeat/conv3/weights', 84 | 'overfeat/conv3/biases', 85 | 'overfeat/conv4/weights', 86 | 'overfeat/conv4/biases', 87 | 'overfeat/conv5/weights', 88 | 'overfeat/conv5/biases', 89 | 'overfeat/fc6/weights', 90 | 'overfeat/fc6/biases', 91 | 'overfeat/fc7/weights', 92 | 'overfeat/fc7/biases', 93 | 'overfeat/fc8/weights', 94 | 'overfeat/fc8/biases', 95 | ] 96 | model_variables = [v.op.name for v in slim.get_model_variables()] 97 | self.assertSetEqual(set(model_variables), set(expected_names)) 98 | 99 | def testEvaluation(self): 100 | batch_size = 2 101 | height, width = 231, 231 102 | num_classes = 1000 103 | with self.test_session(): 104 | eval_inputs = tf.random_uniform((batch_size, height, width, 3)) 105 | logits, _ = overfeat.overfeat(eval_inputs, is_training=False) 106 | self.assertListEqual(logits.get_shape().as_list(), 107 | [batch_size, num_classes]) 108 | predictions = tf.argmax(logits, 1) 109 | self.assertListEqual(predictions.get_shape().as_list(), [batch_size]) 110 | 111 | def testTrainEvalWithReuse(self): 112 | train_batch_size = 2 113 | eval_batch_size = 1 114 | train_height, train_width = 231, 231 115 | eval_height, eval_width = 281, 281 116 | num_classes = 1000 117 | with self.test_session(): 118 | train_inputs = tf.random_uniform( 119 | 
(train_batch_size, train_height, train_width, 3)) 120 | logits, _ = overfeat.overfeat(train_inputs) 121 | self.assertListEqual(logits.get_shape().as_list(), 122 | [train_batch_size, num_classes]) 123 | tf.get_variable_scope().reuse_variables() 124 | eval_inputs = tf.random_uniform( 125 | (eval_batch_size, eval_height, eval_width, 3)) 126 | logits, _ = overfeat.overfeat(eval_inputs, is_training=False, 127 | spatial_squeeze=False) 128 | self.assertListEqual(logits.get_shape().as_list(), 129 | [eval_batch_size, 2, 2, num_classes]) 130 | logits = tf.reduce_mean(logits, [1, 2]) 131 | predictions = tf.argmax(logits, 1) 132 | self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size]) 133 | 134 | def testForward(self): 135 | batch_size = 1 136 | height, width = 231, 231 137 | with self.test_session() as sess: 138 | inputs = tf.random_uniform((batch_size, height, width, 3)) 139 | logits, _ = overfeat.overfeat(inputs) 140 | sess.run(tf.initialize_all_variables()) 141 | output = sess.run(logits) 142 | self.assertTrue(output.any()) 143 | 144 | if __name__ == '__main__': 145 | tf.test.main() 146 | -------------------------------------------------------------------------------- /nets/resnet_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains building blocks for various versions of Residual Networks. 16 | 17 | Residual networks (ResNets) were proposed in: 18 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385, 2015 20 | 21 | More variants were introduced in: 22 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 23 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027, 2016 24 | 25 | We can obtain different ResNet variants by changing the network depth, width, 26 | and form of residual unit. This module implements the infrastructure for 27 | building them. Concrete ResNet units and full ResNet networks are implemented in 28 | the accompanying resnet_v1.py and resnet_v2.py modules. 29 | 30 | Compared to https://github.com/KaimingHe/deep-residual-networks, in the current 31 | implementation we subsample the output activations in the last residual unit of 32 | each block, instead of subsampling the input activations in the first residual 33 | unit of each block. The two implementations give identical results but our 34 | implementation is more memory efficient. 35 | """ 36 | from __future__ import absolute_import 37 | from __future__ import division 38 | from __future__ import print_function 39 | 40 | import collections 41 | import tensorflow as tf 42 | 43 | slim = tf.contrib.slim 44 | 45 | 46 | class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])): 47 | """A named tuple describing a ResNet block. 
48 | 49 | Its parts are: 50 | scope: The scope of the `Block`. 51 | unit_fn: The ResNet unit function which takes as input a `Tensor` and 52 | returns another `Tensor` with the output of the ResNet unit. 53 | args: A list of length equal to the number of units in the `Block`. The list 54 | contains one (depth, depth_bottleneck, stride) tuple for each unit in the 55 | block to serve as argument to unit_fn. 56 | """ 57 | 58 | 59 | def subsample(inputs, factor, scope=None): 60 | """Subsamples the input along the spatial dimensions. 61 | 62 | Args: 63 | inputs: A `Tensor` of size [batch, height_in, width_in, channels]. 64 | factor: The subsampling factor. 65 | scope: Optional variable_scope. 66 | 67 | Returns: 68 | output: A `Tensor` of size [batch, height_out, width_out, channels] with the 69 | input, either intact (if factor == 1) or subsampled (if factor > 1). 70 | """ 71 | if factor == 1: 72 | return inputs 73 | else: 74 | return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope) 75 | 76 | 77 | def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None): 78 | """Strided 2-D convolution with 'SAME' padding. 79 | 80 | When stride > 1, then we do explicit zero-padding, followed by conv2d with 81 | 'VALID' padding. 82 | 83 | Note that 84 | 85 | net = conv2d_same(inputs, num_outputs, 3, stride=stride) 86 | 87 | is equivalent to 88 | 89 | net = slim.conv2d(inputs, num_outputs, 3, stride=1, padding='SAME') 90 | net = subsample(net, factor=stride) 91 | 92 | whereas 93 | 94 | net = slim.conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME') 95 | 96 | is different when the input's height or width is even, which is why we add the 97 | current function. For more details, see ResnetUtilsTest.testConv2DSameEven(). 98 | 99 | Args: 100 | inputs: A 4-D tensor of size [batch, height_in, width_in, channels]. 101 | num_outputs: An integer, the number of output filters. 102 | kernel_size: An int with the kernel_size of the filters. 103 | stride: An integer, the output stride. 104 | rate: An integer, rate for atrous convolution. 105 | scope: Scope. 106 | 107 | Returns: 108 | output: A 4-D tensor of size [batch, height_out, width_out, channels] with 109 | the convolution output. 110 | """ 111 | if stride == 1: 112 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=1, rate=rate, 113 | padding='SAME', scope=scope) 114 | else: 115 | kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1) 116 | pad_total = kernel_size_effective - 1 117 | pad_beg = pad_total // 2 118 | pad_end = pad_total - pad_beg 119 | inputs = tf.pad(inputs, 120 | [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]]) 121 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=stride, 122 | rate=rate, padding='VALID', scope=scope) 123 | 124 | 125 | @slim.add_arg_scope 126 | def stack_blocks_dense(net, blocks, output_stride=None, 127 | outputs_collections=None): 128 | """Stacks ResNet `Blocks` and controls output feature density. 129 | 130 | First, this function creates scopes for the ResNet in the form of 131 | 'block_name/unit_1', 'block_name/unit_2', etc. 132 | 133 | Second, this function allows the user to explicitly control the ResNet 134 | output_stride, which is the ratio of the input to output spatial resolution. 135 | This is useful for dense prediction tasks such as semantic segmentation or 136 | object detection. 137 | 138 | Most ResNets consist of 4 ResNet blocks and subsample the activations by a 139 | factor of 2 when transitioning between consecutive ResNet blocks. 
This results 140 | to a nominal ResNet output_stride equal to 8. If we set the output_stride to 141 | half the nominal network stride (e.g., output_stride=4), then we compute 142 | responses twice. 143 | 144 | Control of the output feature density is implemented by atrous convolution. 145 | 146 | Args: 147 | net: A `Tensor` of size [batch, height, width, channels]. 148 | blocks: A list of length equal to the number of ResNet `Blocks`. Each 149 | element is a ResNet `Block` object describing the units in the `Block`. 150 | output_stride: If `None`, then the output will be computed at the nominal 151 | network stride. If output_stride is not `None`, it specifies the requested 152 | ratio of input to output spatial resolution, which needs to be equal to 153 | the product of unit strides from the start up to some level of the ResNet. 154 | For example, if the ResNet employs units with strides 1, 2, 1, 3, 4, 1, 155 | then valid values for the output_stride are 1, 2, 6, 24 or None (which 156 | is equivalent to output_stride=24). 157 | outputs_collections: Collection to add the ResNet block outputs. 158 | 159 | Returns: 160 | net: Output tensor with stride equal to the specified output_stride. 161 | 162 | Raises: 163 | ValueError: If the target output_stride is not valid. 164 | """ 165 | # The current_stride variable keeps track of the effective stride of the 166 | # activations. This allows us to invoke atrous convolution whenever applying 167 | # the next residual unit would result in the activations having stride larger 168 | # than the target output_stride. 169 | current_stride = 1 170 | 171 | # The atrous convolution rate parameter. 172 | rate = 1 173 | 174 | for block in blocks: 175 | with tf.variable_scope(block.scope, 'block', [net]) as sc: 176 | for i, unit in enumerate(block.args): 177 | if output_stride is not None and current_stride > output_stride: 178 | raise ValueError('The target output_stride cannot be reached.') 179 | 180 | with tf.variable_scope('unit_%d' % (i + 1), values=[net]): 181 | unit_depth, unit_depth_bottleneck, unit_stride = unit 182 | 183 | # If we have reached the target output_stride, then we need to employ 184 | # atrous convolution with stride=1 and multiply the atrous rate by the 185 | # current unit's stride for use in subsequent layers. 186 | if output_stride is not None and current_stride == output_stride: 187 | net = block.unit_fn(net, 188 | depth=unit_depth, 189 | depth_bottleneck=unit_depth_bottleneck, 190 | stride=1, 191 | rate=rate) 192 | rate *= unit_stride 193 | 194 | else: 195 | net = block.unit_fn(net, 196 | depth=unit_depth, 197 | depth_bottleneck=unit_depth_bottleneck, 198 | stride=unit_stride, 199 | rate=1) 200 | current_stride *= unit_stride 201 | net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) 202 | 203 | if output_stride is not None and current_stride != output_stride: 204 | raise ValueError('The target output_stride cannot be reached.') 205 | 206 | return net 207 | 208 | 209 | def resnet_arg_scope(weight_decay=0.0001, 210 | batch_norm_decay=0.997, 211 | batch_norm_epsilon=1e-5, 212 | batch_norm_scale=True): 213 | """Defines the default ResNet arg scope. 214 | 215 | TODO(gpapan): The batch-normalization related default values above are 216 | appropriate for use in conjunction with the reference ResNet models 217 | released at https://github.com/KaimingHe/deep-residual-networks. When 218 | training ResNets from scratch, they might need to be tuned. 
219 | 220 | Args: 221 | weight_decay: The weight decay to use for regularizing the model. 222 | batch_norm_decay: The moving average decay when estimating layer activation 223 | statistics in batch normalization. 224 | batch_norm_epsilon: Small constant to prevent division by zero when 225 | normalizing activations by their variance in batch normalization. 226 | batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the 227 | activations in the batch normalization layer. 228 | 229 | Returns: 230 | An `arg_scope` to use for the resnet models. 231 | """ 232 | batch_norm_params = { 233 | 'decay': batch_norm_decay, 234 | 'epsilon': batch_norm_epsilon, 235 | 'scale': batch_norm_scale, 236 | 'updates_collections': tf.GraphKeys.UPDATE_OPS, 237 | } 238 | 239 | with slim.arg_scope( 240 | [slim.conv2d], 241 | weights_regularizer=slim.l2_regularizer(weight_decay), 242 | weights_initializer=slim.variance_scaling_initializer(), 243 | activation_fn=tf.nn.relu, 244 | normalizer_fn=slim.batch_norm, 245 | normalizer_params=batch_norm_params): 246 | with slim.arg_scope([slim.batch_norm], **batch_norm_params): 247 | # The following implies padding='SAME' for pool1, which makes feature 248 | # alignment easier for dense prediction tasks. This is also used in 249 | # https://github.com/facebook/fb.resnet.torch. However the accompanying 250 | # code of 'Deep Residual Learning for Image Recognition' uses 251 | # padding='VALID' for pool1. You can switch to that choice by setting 252 | # slim.arg_scope([slim.max_pool2d], padding='VALID'). 253 | with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc: 254 | return arg_sc 255 | -------------------------------------------------------------------------------- /nets/resnet_v1.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains definitions for the original form of Residual Networks. 16 | 17 | The 'v1' residual networks (ResNets) implemented in this module were proposed 18 | by: 19 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 20 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 21 | 22 | Other variants were introduced in: 23 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 24 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027 25 | 26 | The networks defined in this module utilize the bottleneck building block of 27 | [1] with projection shortcuts only for increasing depths. They employ batch 28 | normalization *after* every weight layer. This is the architecture used by 29 | MSRA in the Imagenet and MSCOCO 2016 competition models ResNet-101 and 30 | ResNet-152. See [2; Fig. 
1a] for a comparison between the current 'v1' 31 | architecture and the alternative 'v2' architecture of [2] which uses batch 32 | normalization *before* every weight layer in the so-called full pre-activation 33 | units. 34 | 35 | Typical use: 36 | 37 | from tensorflow.contrib.slim.nets import resnet_v1 38 | 39 | ResNet-101 for image classification into 1000 classes: 40 | 41 | # inputs has shape [batch, 224, 224, 3] 42 | with slim.arg_scope(resnet_v1.resnet_arg_scope()): 43 | net, end_points = resnet_v1.resnet_v1_101(inputs, 1000, is_training=False) 44 | 45 | ResNet-101 for semantic segmentation into 21 classes: 46 | 47 | # inputs has shape [batch, 513, 513, 3] 48 | with slim.arg_scope(resnet_v1.resnet_arg_scope()): 49 | net, end_points = resnet_v1.resnet_v1_101(inputs, 50 | 21, 51 | is_training=False, 52 | global_pool=False, 53 | output_stride=16) 54 | """ 55 | from __future__ import absolute_import 56 | from __future__ import division 57 | from __future__ import print_function 58 | 59 | import tensorflow as tf 60 | 61 | from nets import resnet_utils 62 | 63 | 64 | resnet_arg_scope = resnet_utils.resnet_arg_scope 65 | slim = tf.contrib.slim 66 | 67 | 68 | @slim.add_arg_scope 69 | def bottleneck(inputs, depth, depth_bottleneck, stride, rate=1, 70 | outputs_collections=None, scope=None): 71 | """Bottleneck residual unit variant with BN after convolutions. 72 | 73 | This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for 74 | its definition. Note that we use here the bottleneck variant which has an 75 | extra bottleneck layer. 76 | 77 | When putting together two consecutive ResNet blocks that use this unit, one 78 | should use stride = 2 in the last unit of the first block. 79 | 80 | Args: 81 | inputs: A tensor of size [batch, height, width, channels]. 82 | depth: The depth of the ResNet unit output. 83 | depth_bottleneck: The depth of the bottleneck layers. 84 | stride: The ResNet unit's stride. Determines the amount of downsampling of 85 | the units output compared to its input. 86 | rate: An integer, rate for atrous convolution. 87 | outputs_collections: Collection to add the ResNet unit output. 88 | scope: Optional variable_scope. 89 | 90 | Returns: 91 | The ResNet unit's output. 92 | """ 93 | with tf.variable_scope(scope, 'bottleneck_v1', [inputs]) as sc: 94 | depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4) 95 | if depth == depth_in: 96 | shortcut = resnet_utils.subsample(inputs, stride, 'shortcut') 97 | else: 98 | shortcut = slim.conv2d(inputs, depth, [1, 1], stride=stride, 99 | activation_fn=None, scope='shortcut') 100 | 101 | residual = slim.conv2d(inputs, depth_bottleneck, [1, 1], stride=1, 102 | scope='conv1') 103 | residual = resnet_utils.conv2d_same(residual, depth_bottleneck, 3, stride, 104 | rate=rate, scope='conv2') 105 | residual = slim.conv2d(residual, depth, [1, 1], stride=1, 106 | activation_fn=None, scope='conv3') 107 | 108 | output = tf.nn.relu(shortcut + residual) 109 | 110 | return slim.utils.collect_named_outputs(outputs_collections, 111 | sc.original_name_scope, 112 | output) 113 | 114 | 115 | def resnet_v1(inputs, 116 | blocks, 117 | num_classes=None, 118 | is_training=True, 119 | global_pool=True, 120 | output_stride=None, 121 | include_root_block=True, 122 | reuse=None, 123 | scope=None): 124 | """Generator for v1 ResNet models. 125 | 126 | This function generates a family of ResNet v1 models. 
See the resnet_v1_*() 127 | methods for specific model instantiations, obtained by selecting different 128 | block instantiations that produce ResNets of various depths. 129 | 130 | Training for image classification on Imagenet is usually done with [224, 224] 131 | inputs, resulting in [7, 7] feature maps at the output of the last ResNet 132 | block for the ResNets defined in [1] that have nominal stride equal to 32. 133 | However, for dense prediction tasks we advise that one uses inputs with 134 | spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In 135 | this case the feature maps at the ResNet output will have spatial shape 136 | [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1] 137 | and corners exactly aligned with the input image corners, which greatly 138 | facilitates alignment of the features to the image. Using as input [225, 225] 139 | images results in [8, 8] feature maps at the output of the last ResNet block. 140 | 141 | For dense prediction tasks, the ResNet needs to run in fully-convolutional 142 | (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all 143 | have nominal stride equal to 32 and a good choice in FCN mode is to use 144 | output_stride=16 in order to increase the density of the computed features at 145 | small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915. 146 | 147 | Args: 148 | inputs: A tensor of size [batch, height_in, width_in, channels]. 149 | blocks: A list of length equal to the number of ResNet blocks. Each element 150 | is a resnet_utils.Block object describing the units in the block. 151 | num_classes: Number of predicted classes for classification tasks. If None 152 | we return the features before the logit layer. 153 | is_training: whether is training or not. 154 | global_pool: If True, we perform global average pooling before computing the 155 | logits. Set to True for image classification, False for dense prediction. 156 | output_stride: If None, then the output will be computed at the nominal 157 | network stride. If output_stride is not None, it specifies the requested 158 | ratio of input to output spatial resolution. 159 | include_root_block: If True, include the initial convolution followed by 160 | max-pooling, if False excludes it. 161 | reuse: whether or not the network and its variables should be reused. To be 162 | able to reuse 'scope' must be given. 163 | scope: Optional variable_scope. 164 | 165 | Returns: 166 | net: A rank-4 tensor of size [batch, height_out, width_out, channels_out]. 167 | If global_pool is False, then height_out and width_out are reduced by a 168 | factor of output_stride compared to the respective height_in and width_in, 169 | else both height_out and width_out equal one. If num_classes is None, then 170 | net is the output of the last ResNet block, potentially after global 171 | average pooling. If num_classes is not None, net contains the pre-softmax 172 | activations. 173 | end_points: A dictionary from components of the network to the corresponding 174 | activation. 175 | 176 | Raises: 177 | ValueError: If the target output_stride is not valid. 
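    For example, resnet_v1_50() below builds the 50-layer variant by passing
    four Blocks whose unit counts are 3, 4, 6 and 3.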
178 | """ 179 | with tf.variable_scope(scope, 'resnet_v1', [inputs], reuse=reuse) as sc: 180 | end_points_collection = sc.name + '_end_points' 181 | with slim.arg_scope([slim.conv2d, bottleneck, 182 | resnet_utils.stack_blocks_dense], 183 | outputs_collections=end_points_collection): 184 | with slim.arg_scope([slim.batch_norm], is_training=is_training): 185 | net = inputs 186 | if include_root_block: 187 | if output_stride is not None: 188 | if output_stride % 4 != 0: 189 | raise ValueError('The output_stride needs to be a multiple of 4.') 190 | output_stride /= 4 191 | net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1') 192 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1') 193 | net = resnet_utils.stack_blocks_dense(net, blocks, output_stride) 194 | if global_pool: 195 | # Global average pooling. 196 | net = tf.reduce_mean(net, [1, 2], name='pool5', keep_dims=True) 197 | if num_classes is not None: 198 | net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, 199 | normalizer_fn=None, scope='logits') 200 | # Convert end_points_collection into a dictionary of end_points. 201 | end_points = slim.utils.convert_collection_to_dict(end_points_collection) 202 | if num_classes is not None: 203 | end_points['predictions'] = slim.softmax(net, scope='predictions') 204 | return net, end_points 205 | resnet_v1.default_image_size = 224 206 | 207 | 208 | def resnet_v1_50(inputs, 209 | num_classes=None, 210 | is_training=True, 211 | global_pool=True, 212 | output_stride=None, 213 | reuse=None, 214 | scope='resnet_v1_50'): 215 | """ResNet-50 model of [1]. See resnet_v1() for arg and return description.""" 216 | blocks = [ 217 | resnet_utils.Block( 218 | 'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]), 219 | resnet_utils.Block( 220 | 'block2', bottleneck, [(512, 128, 1)] * 3 + [(512, 128, 2)]), 221 | resnet_utils.Block( 222 | 'block3', bottleneck, [(1024, 256, 1)] * 5 + [(1024, 256, 2)]), 223 | resnet_utils.Block( 224 | 'block4', bottleneck, [(2048, 512, 1)] * 3) 225 | ] 226 | return resnet_v1(inputs, blocks, num_classes, is_training, 227 | global_pool=global_pool, output_stride=output_stride, 228 | include_root_block=True, reuse=reuse, scope=scope) 229 | 230 | 231 | def resnet_v1_101(inputs, 232 | num_classes=None, 233 | is_training=True, 234 | global_pool=True, 235 | output_stride=None, 236 | reuse=None, 237 | scope='resnet_v1_101'): 238 | """ResNet-101 model of [1]. See resnet_v1() for arg and return description.""" 239 | blocks = [ 240 | resnet_utils.Block( 241 | 'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]), 242 | resnet_utils.Block( 243 | 'block2', bottleneck, [(512, 128, 1)] * 3 + [(512, 128, 2)]), 244 | resnet_utils.Block( 245 | 'block3', bottleneck, [(1024, 256, 1)] * 22 + [(1024, 256, 2)]), 246 | resnet_utils.Block( 247 | 'block4', bottleneck, [(2048, 512, 1)] * 3) 248 | ] 249 | return resnet_v1(inputs, blocks, num_classes, is_training, 250 | global_pool=global_pool, output_stride=output_stride, 251 | include_root_block=True, reuse=reuse, scope=scope) 252 | 253 | 254 | def resnet_v1_152(inputs, 255 | num_classes=None, 256 | is_training=True, 257 | global_pool=True, 258 | output_stride=None, 259 | reuse=None, 260 | scope='resnet_v1_152'): 261 | """ResNet-152 model of [1]. 
See resnet_v1() for arg and return description.""" 262 | blocks = [ 263 | resnet_utils.Block( 264 | 'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]), 265 | resnet_utils.Block( 266 | 'block2', bottleneck, [(512, 128, 1)] * 7 + [(512, 128, 2)]), 267 | resnet_utils.Block( 268 | 'block3', bottleneck, [(1024, 256, 1)] * 35 + [(1024, 256, 2)]), 269 | resnet_utils.Block( 270 | 'block4', bottleneck, [(2048, 512, 1)] * 3)] 271 | return resnet_v1(inputs, blocks, num_classes, is_training, 272 | global_pool=global_pool, output_stride=output_stride, 273 | include_root_block=True, reuse=reuse, scope=scope) 274 | 275 | 276 | def resnet_v1_200(inputs, 277 | num_classes=None, 278 | is_training=True, 279 | global_pool=True, 280 | output_stride=None, 281 | reuse=None, 282 | scope='resnet_v1_200'): 283 | """ResNet-200 model of [2]. See resnet_v1() for arg and return description.""" 284 | blocks = [ 285 | resnet_utils.Block( 286 | 'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]), 287 | resnet_utils.Block( 288 | 'block2', bottleneck, [(512, 128, 1)] * 23 + [(512, 128, 2)]), 289 | resnet_utils.Block( 290 | 'block3', bottleneck, [(1024, 256, 1)] * 35 + [(1024, 256, 2)]), 291 | resnet_utils.Block( 292 | 'block4', bottleneck, [(2048, 512, 1)] * 3)] 293 | return resnet_v1(inputs, blocks, num_classes, is_training, 294 | global_pool=global_pool, output_stride=output_stride, 295 | include_root_block=True, reuse=reuse, scope=scope) 296 | -------------------------------------------------------------------------------- /nets/vgg.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains model definitions for versions of the Oxford VGG network. 16 | These model definitions were introduced in the following technical report: 17 | Very Deep Convolutional Networks For Large-Scale Image Recognition 18 | Karen Simonyan and Andrew Zisserman 19 | arXiv technical report, 2015 20 | PDF: http://arxiv.org/pdf/1409.1556.pdf 21 | ILSVRC 2014 Slides: http://www.robots.ox.ac.uk/~karen/pdf/ILSVRC_2014.pdf 22 | CC-BY-4.0 23 | More information can be obtained from the VGG website: 24 | www.robots.ox.ac.uk/~vgg/research/very_deep/ 25 | Usage: 26 | with slim.arg_scope(vgg.vgg_arg_scope()): 27 | outputs, end_points = vgg.vgg_a(inputs) 28 | with slim.arg_scope(vgg.vgg_arg_scope()): 29 | outputs, end_points = vgg.vgg_16(inputs) 30 | @@vgg_a 31 | @@vgg_16 32 | @@vgg_19 33 | """ 34 | from __future__ import absolute_import 35 | from __future__ import division 36 | from __future__ import print_function 37 | 38 | import tensorflow as tf 39 | 40 | slim = tf.contrib.slim 41 | 42 | 43 | def vgg_arg_scope(weight_decay=0.0005): 44 | """Defines the VGG arg scope. 45 | Args: 46 | weight_decay: The l2 regularization coefficient. 47 | Returns: 48 | An arg_scope. 
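    The same arg_scope is shared by vgg_a, vgg_16 and vgg_19 (see
    nets_factory.arg_scopes_map).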
49 | """ 50 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 51 | activation_fn=tf.nn.relu, 52 | weights_regularizer=slim.l2_regularizer(weight_decay), 53 | biases_initializer=tf.zeros_initializer()): 54 | with slim.arg_scope([slim.conv2d], padding='SAME') as arg_sc: 55 | return arg_sc 56 | 57 | 58 | def vgg_a(inputs, 59 | num_classes=1000, 60 | is_training=True, 61 | dropout_keep_prob=0.5, 62 | spatial_squeeze=True, 63 | scope='vgg_a'): 64 | """Oxford Net VGG 11-Layers version A Example. 65 | Note: All the fully_connected layers have been transformed to conv2d layers. 66 | To use in classification mode, resize input to 224x224. 67 | Args: 68 | inputs: a tensor of size [batch_size, height, width, channels]. 69 | num_classes: number of predicted classes. 70 | is_training: whether or not the model is being trained. 71 | dropout_keep_prob: the probability that activations are kept in the dropout 72 | layers during training. 73 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 74 | outputs. Useful to remove unnecessary dimensions for classification. 75 | scope: Optional scope for the variables. 76 | Returns: 77 | the last op containing the log predictions and end_points dict. 78 | """ 79 | with tf.variable_scope(scope, 'vgg_a', [inputs]) as sc: 80 | end_points_collection = sc.name + '_end_points' 81 | # Collect outputs for conv2d, fully_connected and max_pool2d. 82 | with slim.arg_scope([slim.conv2d, slim.max_pool2d], 83 | outputs_collections=end_points_collection): 84 | net = slim.repeat(inputs, 1, slim.conv2d, 64, [3, 3], scope='conv1') 85 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 86 | net = slim.repeat(net, 1, slim.conv2d, 128, [3, 3], scope='conv2') 87 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 88 | net = slim.repeat(net, 2, slim.conv2d, 256, [3, 3], scope='conv3') 89 | net = slim.max_pool2d(net, [2, 2], scope='pool3') 90 | net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv4') 91 | net = slim.max_pool2d(net, [2, 2], scope='pool4') 92 | net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv5') 93 | net = slim.max_pool2d(net, [2, 2], scope='pool5') 94 | # Use conv2d instead of fully_connected layers. 95 | net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6') 96 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 97 | scope='dropout6') 98 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 99 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 100 | scope='dropout7') 101 | net = slim.conv2d(net, num_classes, [1, 1], 102 | activation_fn=None, 103 | normalizer_fn=None, 104 | scope='fc8') 105 | # Convert end_points_collection into a end_point dict. 106 | end_points = slim.utils.convert_collection_to_dict(end_points_collection) 107 | if spatial_squeeze: 108 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 109 | end_points[sc.name + '/fc8'] = net 110 | return net, end_points 111 | vgg_a.default_image_size = 224 112 | 113 | 114 | def vgg_16(inputs, 115 | num_classes=1000, 116 | is_training=True, 117 | dropout_keep_prob=0.5, 118 | spatial_squeeze=True, 119 | scope='vgg_16'): 120 | """Oxford Net VGG 16-Layers version D Example. 121 | Note: All the fully_connected layers have been transformed to conv2d layers. 122 | To use in classification mode, resize input to 224x224. 123 | Args: 124 | inputs: a tensor of size [batch_size, height, width, channels]. 125 | num_classes: number of predicted classes. 126 | is_training: whether or not the model is being trained. 
127 | dropout_keep_prob: the probability that activations are kept in the dropout 128 | layers during training. 129 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 130 | outputs. Useful to remove unnecessary dimensions for classification. 131 | scope: Optional scope for the variables. 132 | Returns: 133 | the last op containing the log predictions and end_points dict. 134 | """ 135 | with tf.variable_scope(scope, 'vgg_16', [inputs]) as sc: 136 | end_points_collection = sc.name + '_end_points' 137 | # Collect outputs for conv2d, fully_connected and max_pool2d. 138 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 139 | outputs_collections=end_points_collection): 140 | net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1') 141 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 142 | net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2') 143 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 144 | net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3') 145 | net = slim.max_pool2d(net, [2, 2], scope='pool3') 146 | net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4') 147 | net = slim.max_pool2d(net, [2, 2], scope='pool4') 148 | net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5') 149 | net = slim.max_pool2d(net, [2, 2], scope='pool5') 150 | # Use conv2d instead of fully_connected layers. 151 | net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6') 152 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 153 | scope='dropout6') 154 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 155 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 156 | scope='dropout7') 157 | net = slim.conv2d(net, num_classes, [1, 1], 158 | activation_fn=None, 159 | normalizer_fn=None, 160 | scope='fc8') 161 | # Convert end_points_collection into a end_point dict. 162 | end_points = slim.utils.convert_collection_to_dict(end_points_collection) 163 | if spatial_squeeze: 164 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 165 | end_points[sc.name + '/fc8'] = net 166 | return net, end_points 167 | vgg_16.default_image_size = 224 168 | 169 | 170 | def vgg_19(inputs, 171 | num_classes=1000, 172 | is_training=True, 173 | dropout_keep_prob=0.5, 174 | spatial_squeeze=True, 175 | scope='vgg_19'): 176 | """Oxford Net VGG 19-Layers version E Example. 177 | Note: All the fully_connected layers have been transformed to conv2d layers. 178 | To use in classification mode, resize input to 224x224. 179 | Args: 180 | inputs: a tensor of size [batch_size, height, width, channels]. 181 | num_classes: number of predicted classes. 182 | is_training: whether or not the model is being trained. 183 | dropout_keep_prob: the probability that activations are kept in the dropout 184 | layers during training. 185 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 186 | outputs. Useful to remove unnecessary dimensions for classification. 187 | scope: Optional scope for the variables. 188 | Returns: 189 | the last op containing the log predictions and end_points dict. 190 | """ 191 | with tf.variable_scope(scope, 'vgg_19', [inputs]) as sc: 192 | end_points_collection = sc.name + '_end_points' 193 | # Collect outputs for conv2d, fully_connected and max_pool2d. 
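# Layers created inside this arg_scope are recorded in end_points_collection,
# which is converted into the end_points dict just before the function
# returns.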
194 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 195 | outputs_collections=end_points_collection): 196 | net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1') 197 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 198 | net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2') 199 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 200 | net = slim.repeat(net, 4, slim.conv2d, 256, [3, 3], scope='conv3') 201 | net = slim.max_pool2d(net, [2, 2], scope='pool3') 202 | net = slim.repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv4') 203 | net = slim.max_pool2d(net, [2, 2], scope='pool4') 204 | net = slim.repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv5') 205 | net = slim.max_pool2d(net, [2, 2], scope='pool5') 206 | # Use conv2d instead of fully_connected layers. 207 | net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6') 208 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 209 | scope='dropout6') 210 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 211 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 212 | scope='dropout7') 213 | net = slim.conv2d(net, num_classes, [1, 1], 214 | activation_fn=None, 215 | normalizer_fn=None, 216 | scope='fc8') 217 | # Convert end_points_collection into a end_point dict. 218 | end_points = slim.utils.convert_collection_to_dict(end_points_collection) 219 | if spatial_squeeze: 220 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 221 | end_points[sc.name + '/fc8'] = net 222 | return net, end_points 223 | vgg_19.default_image_size = 224 224 | 225 | # Alias 226 | vgg_d = vgg_16 227 | vgg_e = vgg_19 -------------------------------------------------------------------------------- /preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /preprocessing/cifarnet_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities to preprocess images in CIFAR-10. 16 | 17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow as tf 24 | 25 | _PADDING = 4 26 | 27 | slim = tf.contrib.slim 28 | 29 | 30 | def preprocess_for_train(image, 31 | output_height, 32 | output_width, 33 | padding=_PADDING): 34 | """Preprocesses the given image for training. 35 | 36 | Note that the actual resizing scale is sampled from 37 | [`resize_size_min`, `resize_size_max`]. 38 | 39 | Args: 40 | image: A `Tensor` representing an image of arbitrary size. 41 | output_height: The height of the image after preprocessing. 
42 | output_width: The width of the image after preprocessing. 43 | padding: The amound of padding before and after each dimension of the image. 44 | 45 | Returns: 46 | A preprocessed image. 47 | """ 48 | tf.image_summary('image', tf.expand_dims(image, 0)) 49 | 50 | # Transform the image to floats. 51 | image = tf.to_float(image) 52 | if padding > 0: 53 | image = tf.pad(image, [[padding, padding], [padding, padding], [0, 0]]) 54 | # Randomly crop a [height, width] section of the image. 55 | distorted_image = tf.random_crop(image, 56 | [output_height, output_width, 3]) 57 | 58 | # Randomly flip the image horizontally. 59 | distorted_image = tf.image.random_flip_left_right(distorted_image) 60 | 61 | tf.image_summary('distorted_image', tf.expand_dims(distorted_image, 0)) 62 | 63 | # Because these operations are not commutative, consider randomizing 64 | # the order their operation. 65 | distorted_image = tf.image.random_brightness(distorted_image, 66 | max_delta=63) 67 | distorted_image = tf.image.random_contrast(distorted_image, 68 | lower=0.2, upper=1.8) 69 | # Subtract off the mean and divide by the variance of the pixels. 70 | return tf.image.per_image_whitening(distorted_image) 71 | 72 | 73 | def preprocess_for_eval(image, output_height, output_width): 74 | """Preprocesses the given image for evaluation. 75 | 76 | Args: 77 | image: A `Tensor` representing an image of arbitrary size. 78 | output_height: The height of the image after preprocessing. 79 | output_width: The width of the image after preprocessing. 80 | 81 | Returns: 82 | A preprocessed image. 83 | """ 84 | tf.image_summary('image', tf.expand_dims(image, 0)) 85 | # Transform the image to floats. 86 | image = tf.to_float(image) 87 | 88 | # Resize and crop if needed. 89 | resized_image = tf.image.resize_image_with_crop_or_pad(image, 90 | output_width, 91 | output_height) 92 | tf.image_summary('resized_image', tf.expand_dims(resized_image, 0)) 93 | 94 | # Subtract off the mean and divide by the variance of the pixels. 95 | return tf.image.per_image_whitening(resized_image) 96 | 97 | 98 | def preprocess_image(image, output_height, output_width, is_training=False): 99 | """Preprocesses the given image. 100 | 101 | Args: 102 | image: A `Tensor` representing an image of arbitrary size. 103 | output_height: The height of the image after preprocessing. 104 | output_width: The width of the image after preprocessing. 105 | is_training: `True` if we're preprocessing the image for training and 106 | `False` otherwise. 107 | 108 | Returns: 109 | A preprocessed image. 110 | """ 111 | if is_training: 112 | return preprocess_for_train(image, output_height, output_width) 113 | else: 114 | return preprocess_for_eval(image, output_height, output_width) 115 | -------------------------------------------------------------------------------- /preprocessing/inception_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities to preprocess images for the Inception networks.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from tensorflow.python.ops import control_flow_ops 24 | 25 | 26 | def apply_with_random_selector(x, func, num_cases): 27 | """Computes func(x, sel), with sel sampled from [0...num_cases-1]. 28 | 29 | Args: 30 | x: input Tensor. 31 | func: Python function to apply. 32 | num_cases: Python int32, number of cases to sample sel from. 33 | 34 | Returns: 35 | The result of func(x, sel), where func receives the value of the 36 | selector as a python integer, but sel is sampled dynamically. 37 | """ 38 | sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32) 39 | # Pass the real x only to one of the func calls. 40 | return control_flow_ops.merge([ 41 | func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case) 42 | for case in range(num_cases)])[0] 43 | 44 | 45 | def distort_color(image, color_ordering=0, fast_mode=True, scope=None): 46 | """Distort the color of a Tensor image. 47 | 48 | Each color distortion is non-commutative and thus ordering of the color ops 49 | matters. Ideally we would randomly permute the ordering of the color ops. 50 | Rather then adding that level of complication, we select a distinct ordering 51 | of color ops for each preprocessing thread. 52 | 53 | Args: 54 | image: 3-D Tensor containing single image in [0, 1]. 55 | color_ordering: Python int, a type of distortion (valid values: 0-3). 56 | fast_mode: Avoids slower ops (random_hue and random_contrast) 57 | scope: Optional scope for name_scope. 58 | Returns: 59 | 3-D Tensor color-distorted image on range [0, 1] 60 | Raises: 61 | ValueError: if color_ordering not in [0, 3] 62 | """ 63 | with tf.name_scope(scope, 'distort_color', [image]): 64 | if fast_mode: 65 | if color_ordering == 0: 66 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 67 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 68 | else: 69 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 70 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 71 | else: 72 | if color_ordering == 0: 73 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 74 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 75 | image = tf.image.random_hue(image, max_delta=0.2) 76 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 77 | elif color_ordering == 1: 78 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 79 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 80 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 81 | image = tf.image.random_hue(image, max_delta=0.2) 82 | elif color_ordering == 2: 83 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 84 | image = tf.image.random_hue(image, max_delta=0.2) 85 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 
86 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 87 | elif color_ordering == 3: 88 | image = tf.image.random_hue(image, max_delta=0.2) 89 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 90 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 91 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 92 | else: 93 | raise ValueError('color_ordering must be in [0, 3]') 94 | 95 | # The random_* ops do not necessarily clamp. 96 | return tf.clip_by_value(image, 0.0, 1.0) 97 | 98 | 99 | def distorted_bounding_box_crop(image, 100 | bbox, 101 | min_object_covered=0.1, 102 | aspect_ratio_range=(0.75, 1.33), 103 | area_range=(0.05, 1.0), 104 | max_attempts=100, 105 | scope=None): 106 | """Generates cropped_image using a one of the bboxes randomly distorted. 107 | 108 | See `tf.image.sample_distorted_bounding_box` for more documentation. 109 | 110 | Args: 111 | image: 3-D Tensor of image (it will be converted to floats in [0, 1]). 112 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] 113 | where each coordinate is [0, 1) and the coordinates are arranged 114 | as [ymin, xmin, ymax, xmax]. If num_boxes is 0 then it would use the whole 115 | image. 116 | min_object_covered: An optional `float`. Defaults to `0.1`. The cropped 117 | area of the image must contain at least this fraction of any bounding box 118 | supplied. 119 | aspect_ratio_range: An optional list of `floats`. The cropped area of the 120 | image must have an aspect ratio = width / height within this range. 121 | area_range: An optional list of `floats`. The cropped area of the image 122 | must contain a fraction of the supplied image within in this range. 123 | max_attempts: An optional `int`. Number of attempts at generating a cropped 124 | region of the image of the specified constraints. After `max_attempts` 125 | failures, return the entire image. 126 | scope: Optional scope for name_scope. 127 | Returns: 128 | A tuple, a 3-D Tensor cropped_image and the distorted bbox 129 | """ 130 | with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]): 131 | # Each bounding box has shape [1, num_boxes, box coords] and 132 | # the coordinates are ordered [ymin, xmin, ymax, xmax]. 133 | 134 | # A large fraction of image datasets contain a human-annotated bounding 135 | # box delineating the region of the image containing the object of interest. 136 | # We choose to create a new bounding box for the object which is a randomly 137 | # distorted version of the human-annotated bounding box that obeys an 138 | # allowed range of aspect ratios, sizes and overlap with the human-annotated 139 | # bounding box. If no box is supplied, then we assume the bounding box is 140 | # the entire image. 141 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( 142 | tf.shape(image), 143 | bounding_boxes=bbox, 144 | min_object_covered=min_object_covered, 145 | aspect_ratio_range=aspect_ratio_range, 146 | area_range=area_range, 147 | max_attempts=max_attempts, 148 | use_image_if_no_bounding_boxes=True) 149 | bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box 150 | 151 | # Crop the image to the specified bounding box. 152 | cropped_image = tf.slice(image, bbox_begin, bbox_size) 153 | return cropped_image, distort_bbox 154 | 155 | 156 | def preprocess_for_train(image, height, width, bbox, 157 | fast_mode=True, 158 | scope=None): 159 | """Distort one image for training a network. 
160 | 161 | Distorting images provides a useful technique for augmenting the data 162 | set during training in order to make the network invariant to aspects 163 | of the image that do not effect the label. 164 | 165 | Additionally it would create image_summaries to display the different 166 | transformations applied to the image. 167 | 168 | Args: 169 | image: 3-D Tensor of image. If dtype is tf.float32 then the range should be 170 | [0, 1], otherwise it would converted to tf.float32 assuming that the range 171 | is [0, MAX], where MAX is largest positive representable number for 172 | int(8/16/32) data type (see `tf.image.convert_image_dtype` for details). 173 | height: integer 174 | width: integer 175 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] 176 | where each coordinate is [0, 1) and the coordinates are arranged 177 | as [ymin, xmin, ymax, xmax]. 178 | fast_mode: Optional boolean, if True avoids slower transformations (i.e. 179 | bi-cubic resizing, random_hue or random_contrast). 180 | scope: Optional scope for name_scope. 181 | Returns: 182 | 3-D float Tensor of distorted image used for training with range [-1, 1]. 183 | """ 184 | with tf.name_scope(scope, 'distort_image', [image, height, width, bbox]): 185 | if bbox is None: 186 | bbox = tf.constant([0.0, 0.0, 1.0, 1.0], 187 | dtype=tf.float32, 188 | shape=[1, 1, 4]) 189 | if image.dtype != tf.float32: 190 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 191 | # Each bounding box has shape [1, num_boxes, box coords] and 192 | # the coordinates are ordered [ymin, xmin, ymax, xmax]. 193 | image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), 194 | bbox) 195 | tf.image_summary('image_with_bounding_boxes', image_with_box) 196 | 197 | distorted_image, distorted_bbox = distorted_bounding_box_crop(image, bbox) 198 | # Restore the shape since the dynamic slice based upon the bbox_size loses 199 | # the third dimension. 200 | distorted_image.set_shape([None, None, 3]) 201 | image_with_distorted_box = tf.image.draw_bounding_boxes( 202 | tf.expand_dims(image, 0), distorted_bbox) 203 | tf.image_summary('images_with_distorted_bounding_box', 204 | image_with_distorted_box) 205 | 206 | # This resizing operation may distort the images because the aspect 207 | # ratio is not respected. We select a resize method in a round robin 208 | # fashion based on the thread number. 209 | # Note that ResizeMethod contains 4 enumerated resizing methods. 210 | 211 | # We select only 1 case for fast_mode bilinear. 212 | num_resize_cases = 1 if fast_mode else 4 213 | distorted_image = apply_with_random_selector( 214 | distorted_image, 215 | lambda x, method: tf.image.resize_images(x, [height, width], method=method), 216 | num_cases=num_resize_cases) 217 | 218 | tf.image_summary('cropped_resized_image', 219 | tf.expand_dims(distorted_image, 0)) 220 | 221 | # Randomly flip the image horizontally. 222 | distorted_image = tf.image.random_flip_left_right(distorted_image) 223 | 224 | # Randomly distort the colors. There are 4 ways to do it. 
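# Each case picks one color_ordering branch of distort_color above; with
# fast_mode=True only brightness and saturation are perturbed, so the four
# cases collapse to two cheap variants (brightness->saturation or
# saturation->brightness).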
225 | distorted_image = apply_with_random_selector( 226 | distorted_image, 227 | lambda x, ordering: distort_color(x, ordering, fast_mode), 228 | num_cases=4) 229 | 230 | tf.image_summary('final_distorted_image', 231 | tf.expand_dims(distorted_image, 0)) 232 | distorted_image = tf.sub(distorted_image, 0.5) 233 | distorted_image = tf.mul(distorted_image, 2.0) 234 | return distorted_image 235 | 236 | 237 | def preprocess_for_eval(image, height, width, 238 | central_fraction=0.875, scope=None): 239 | """Prepare one image for evaluation. 240 | 241 | If height and width are specified it would output an image with that size by 242 | applying resize_bilinear. 243 | 244 | If central_fraction is specified it would cropt the central fraction of the 245 | input image. 246 | 247 | Args: 248 | image: 3-D Tensor of image. If dtype is tf.float32 then the range should be 249 | [0, 1], otherwise it would converted to tf.float32 assuming that the range 250 | is [0, MAX], where MAX is largest positive representable number for 251 | int(8/16/32) data type (see `tf.image.convert_image_dtype` for details) 252 | height: integer 253 | width: integer 254 | central_fraction: Optional Float, fraction of the image to crop. 255 | scope: Optional scope for name_scope. 256 | Returns: 257 | 3-D float Tensor of prepared image. 258 | """ 259 | with tf.name_scope(scope, 'eval_image', [image, height, width]): 260 | if image.dtype != tf.float32: 261 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 262 | # Crop the central region of the image with an area containing 87.5% of 263 | # the original image. 264 | if central_fraction: 265 | image = tf.image.central_crop(image, central_fraction=central_fraction) 266 | 267 | if height and width: 268 | # Resize the image to the specified height and width. 269 | image = tf.expand_dims(image, 0) 270 | image = tf.image.resize_bilinear(image, [height, width], 271 | align_corners=False) 272 | image = tf.squeeze(image, [0]) 273 | image = tf.sub(image, 0.5) 274 | image = tf.mul(image, 2.0) 275 | return image 276 | 277 | 278 | def preprocess_image(image, height, width, 279 | is_training=False, 280 | bbox=None, 281 | fast_mode=True): 282 | """Pre-process one image for training or evaluation. 283 | 284 | Args: 285 | image: 3-D Tensor [height, width, channels] with the image. 286 | height: integer, image expected height. 287 | width: integer, image expected width. 288 | is_training: Boolean. If true it would transform an image for train, 289 | otherwise it would transform it for evaluation. 290 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] 291 | where each coordinate is [0, 1) and the coordinates are arranged as 292 | [ymin, xmin, ymax, xmax]. 293 | fast_mode: Optional boolean, if True avoids slower transformations. 294 | 295 | Returns: 296 | 3-D float Tensor containing an appropriately scaled image 297 | 298 | Raises: 299 | ValueError: if user does not provide bounding box 300 | """ 301 | if is_training: 302 | return preprocess_for_train(image, height, width, bbox, fast_mode) 303 | else: 304 | return preprocess_for_eval(image, height, width) 305 | -------------------------------------------------------------------------------- /preprocessing/lenet_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities for preprocessing.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | def preprocess_image(image, output_height, output_width, is_training): 27 | """Preprocesses the given image. 28 | 29 | Args: 30 | image: A `Tensor` representing an image of arbitrary size. 31 | output_height: The height of the image after preprocessing. 32 | output_width: The width of the image after preprocessing. 33 | is_training: `True` if we're preprocessing the image for training and 34 | `False` otherwise. 35 | 36 | Returns: 37 | A preprocessed image. 38 | """ 39 | image = tf.to_float(image) 40 | image = tf.image.resize_image_with_crop_or_pad( 41 | image, output_width, output_height) 42 | image = tf.sub(image, 128.0) 43 | image = tf.div(image, 128.0) 44 | return image 45 | -------------------------------------------------------------------------------- /preprocessing/preprocessing_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a factory for building various models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from preprocessing import cifarnet_preprocessing 24 | from preprocessing import inception_preprocessing 25 | from preprocessing import lenet_preprocessing 26 | from preprocessing import vgg_preprocessing 27 | 28 | slim = tf.contrib.slim 29 | 30 | 31 | def get_preprocessing(name, is_training=False): 32 | """Returns preprocessing_fn(image, height, width, **kwargs). 33 | 34 | Args: 35 | name: The name of the preprocessing function. 36 | is_training: `True` if the model is being used for training and `False` 37 | otherwise. 38 | 39 | Returns: 40 | preprocessing_fn: A function that preprocessing a single image (pre-batch). 41 | It has the following signature: 42 | image = preprocessing_fn(image, output_height, output_width, ...). 
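    unprocessing_fn: A function that reverses the preprocessing for a single
      image, with signature image = unprocessing_fn(image, ...). It forwards to
      the module's unprocess_image, so it only works for preprocessing modules
      that define one (of the modules in this package, only vgg_preprocessing
      appears to).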
43 | 44 | Raises: 45 | ValueError: If Preprocessing `name` is not recognized. 46 | """ 47 | preprocessing_fn_map = { 48 | 'cifarnet': cifarnet_preprocessing, 49 | 'inception': inception_preprocessing, 50 | 'inception_v1': inception_preprocessing, 51 | 'inception_v2': inception_preprocessing, 52 | 'inception_v3': inception_preprocessing, 53 | 'inception_v4': inception_preprocessing, 54 | 'inception_resnet_v2': inception_preprocessing, 55 | 'lenet': lenet_preprocessing, 56 | 'resnet_v1_50': vgg_preprocessing, 57 | 'resnet_v1_101': vgg_preprocessing, 58 | 'resnet_v1_152': vgg_preprocessing, 59 | 'vgg': vgg_preprocessing, 60 | 'vgg_a': vgg_preprocessing, 61 | 'vgg_16': vgg_preprocessing, 62 | 'vgg_19': vgg_preprocessing, 63 | } 64 | 65 | if name not in preprocessing_fn_map: 66 | raise ValueError('Preprocessing name [%s] was not recognized' % name) 67 | 68 | def preprocessing_fn(image, output_height, output_width, **kwargs): 69 | return preprocessing_fn_map[name].preprocess_image( 70 | image, output_height, output_width, is_training=is_training, **kwargs) 71 | 72 | def unprocessing_fn(image, **kwargs): 73 | return preprocessing_fn_map[name].unprocess_image( 74 | image, **kwargs) 75 | 76 | return preprocessing_fn, unprocessing_fn 77 | -------------------------------------------------------------------------------- /reader.py: -------------------------------------------------------------------------------- 1 | from os import listdir 2 | from os.path import isfile, join 3 | import tensorflow as tf 4 | 5 | 6 | def get_image(path, height, width, preprocess_fn): 7 | png = path.lower().endswith('png') 8 | img_bytes = tf.read_file(path) 9 | image = tf.image.decode_png(img_bytes, channels=3) if png else tf.image.decode_jpeg(img_bytes, channels=3) 10 | return preprocess_fn(image, height, width) 11 | 12 | 13 | def image(batch_size, height, width, path, preprocess_fn, epochs=2, shuffle=True): 14 | filenames = [join(path, f) for f in listdir(path) if isfile(join(path, f))] 15 | if not shuffle: 16 | filenames = sorted(filenames) 17 | 18 | png = filenames[0].lower().endswith('png') # If first file is a png, assume they all are 19 | 20 | filename_queue = tf.train.string_input_producer(filenames, shuffle=shuffle, num_epochs=epochs) 21 | reader = tf.WholeFileReader() 22 | _, img_bytes = reader.read(filename_queue) 23 | image = tf.image.decode_png(img_bytes, channels=3) if png else tf.image.decode_jpeg(img_bytes, channels=3) 24 | 25 | processed_image = preprocess_fn(image, height, width) 26 | return tf.train.batch([processed_image], batch_size, dynamic_pad=True) 27 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from __future__ import print_function 3 | from __future__ import division 4 | import tensorflow as tf 5 | from nets import nets_factory 6 | from preprocessing import preprocessing_factory 7 | import reader 8 | import model 9 | import time 10 | import losses 11 | import utils 12 | import os 13 | import argparse 14 | 15 | slim = tf.contrib.slim 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('-c', '--conf', default='conf/mosaic.yml', help='the path to the conf file') 21 | return parser.parse_args() 22 | 23 | 24 | def main(FLAGS): 25 | style_features_t = losses.get_style_features(FLAGS) 26 | 27 | # Make sure the training path exists. 
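# Checkpoints and TensorBoard summaries for this style are written to
# <model_path>/<naming>, both taken from the YAML conf (e.g. models/wave when
# model_path is "models" and naming is "wave").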
28 | training_path = os.path.join(FLAGS.model_path, FLAGS.naming) 29 | if not(os.path.exists(training_path)): 30 | os.makedirs(training_path) 31 | 32 | with tf.Graph().as_default(): 33 | with tf.Session() as sess: 34 | """Build Network""" 35 | network_fn = nets_factory.get_network_fn( 36 | FLAGS.loss_model, 37 | num_classes=1, 38 | is_training=False) 39 | 40 | image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing( 41 | FLAGS.loss_model, 42 | is_training=False) 43 | processed_images = reader.image(FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, 44 | 'train2014/', image_preprocessing_fn, epochs=FLAGS.epoch) 45 | generated = model.net(processed_images, training=True) 46 | processed_generated = [image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size) 47 | for image in tf.unstack(generated, axis=0, num=FLAGS.batch_size) 48 | ] 49 | processed_generated = tf.stack(processed_generated) 50 | _, endpoints_dict = network_fn(tf.concat([processed_generated, processed_images], 0), spatial_squeeze=False) 51 | 52 | # Log the structure of loss network 53 | tf.logging.info('Loss network layers(You can define them in "content_layers" and "style_layers"):') 54 | for key in endpoints_dict: 55 | tf.logging.info(key) 56 | 57 | """Build Losses""" 58 | content_loss = losses.content_loss(endpoints_dict, FLAGS.content_layers) 59 | style_loss, style_loss_summary = losses.style_loss(endpoints_dict, style_features_t, FLAGS.style_layers) 60 | tv_loss = losses.total_variation_loss(generated) # use the unprocessed image 61 | 62 | loss = FLAGS.style_weight * style_loss + FLAGS.content_weight * content_loss + FLAGS.tv_weight * tv_loss 63 | 64 | # Add Summary for visualization in tensorboard. 65 | """Add Summary""" 66 | tf.summary.scalar('losses/content_loss', content_loss) 67 | tf.summary.scalar('losses/style_loss', style_loss) 68 | tf.summary.scalar('losses/regularizer_loss', tv_loss) 69 | 70 | tf.summary.scalar('weighted_losses/weighted_content_loss', content_loss * FLAGS.content_weight) 71 | tf.summary.scalar('weighted_losses/weighted_style_loss', style_loss * FLAGS.style_weight) 72 | tf.summary.scalar('weighted_losses/weighted_regularizer_loss', tv_loss * FLAGS.tv_weight) 73 | tf.summary.scalar('total_loss', loss) 74 | 75 | for layer in FLAGS.style_layers: 76 | tf.summary.scalar('style_losses/' + layer, style_loss_summary[layer]) 77 | tf.summary.image('generated', generated) 78 | # tf.image_summary('processed_generated', processed_generated) # May be better? 79 | tf.summary.image('origin', tf.stack([ 80 | image_unprocessing_fn(image) for image in tf.unstack(processed_images, axis=0, num=FLAGS.batch_size) 81 | ])) 82 | summary = tf.summary.merge_all() 83 | writer = tf.summary.FileWriter(training_path) 84 | 85 | """Prepare to Train""" 86 | global_step = tf.Variable(0, name="global_step", trainable=False) 87 | 88 | variable_to_train = [] 89 | for variable in tf.trainable_variables(): 90 | if not(variable.name.startswith(FLAGS.loss_model)): 91 | variable_to_train.append(variable) 92 | train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step, var_list=variable_to_train) 93 | 94 | variables_to_restore = [] 95 | for v in tf.global_variables(): 96 | if not(v.name.startswith(FLAGS.loss_model)): 97 | variables_to_restore.append(v) 98 | saver = tf.train.Saver(variables_to_restore, write_version=tf.train.SaverDef.V1) 99 | 100 | sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()]) 101 | 102 | # Restore variables for loss network. 
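# The pretrained loss-network (VGG) weights are loaded from FLAGS.loss_model_file
# via slim.assign_from_checkpoint_fn (see utils._get_init_fn). These variables were
# excluded from variable_to_train and variables_to_restore above, so the Adam step
# only updates the transform network and the checkpoints saved below contain only
# the transform network (plus global_step).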
103 | init_func = utils._get_init_fn(FLAGS) 104 | init_func(sess) 105 | 106 | # Restore variables for training model if the checkpoint file exists. 107 | last_file = tf.train.latest_checkpoint(training_path) 108 | if last_file: 109 | tf.logging.info('Restoring model from {}'.format(last_file)) 110 | saver.restore(sess, last_file) 111 | 112 | """Start Training""" 113 | coord = tf.train.Coordinator() 114 | threads = tf.train.start_queue_runners(coord=coord) 115 | start_time = time.time() 116 | try: 117 | while not coord.should_stop(): 118 | _, loss_t, step = sess.run([train_op, loss, global_step]) 119 | elapsed_time = time.time() - start_time 120 | start_time = time.time() 121 | """logging""" 122 | # print(step) 123 | if step % 10 == 0: 124 | tf.logging.info('step: %d, total Loss %f, secs/step: %f' % (step, loss_t, elapsed_time)) 125 | """summary""" 126 | if step % 25 == 0: 127 | tf.logging.info('adding summary...') 128 | summary_str = sess.run(summary) 129 | writer.add_summary(summary_str, step) 130 | writer.flush() 131 | """checkpoint""" 132 | if step % 1000 == 0: 133 | saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt'), global_step=step) 134 | except tf.errors.OutOfRangeError: 135 | saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt-done')) 136 | tf.logging.info('Done training -- epoch limit reached') 137 | finally: 138 | coord.request_stop() 139 | coord.join(threads) 140 | 141 | 142 | if __name__ == '__main__': 143 | tf.logging.set_verbosity(tf.logging.INFO) 144 | args = parse_args() 145 | FLAGS = utils.read_conf_file(args.conf) 146 | main(FLAGS) 147 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import yaml 3 | 4 | slim = tf.contrib.slim 5 | 6 | 7 | def _get_init_fn(FLAGS): 8 | """ 9 | This function is copied from TF slim. 10 | 11 | Returns a function run by the chief worker to warm-start the training. 12 | 13 | Note that the init_fn is only run when initializing the model during the very 14 | first global step. 15 | 16 | Returns: 17 | An init function run by the supervisor. 
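    Variables whose op names start with one of FLAGS.checkpoint_exclude_scopes
    are skipped, and ignore_missing_vars=True leaves any variable that is absent
    from the checkpoint at its initialized value.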
18 | """ 19 | tf.logging.info('Use pretrained model %s' % FLAGS.loss_model_file) 20 | 21 | exclusions = [] 22 | if FLAGS.checkpoint_exclude_scopes: 23 | exclusions = [scope.strip() 24 | for scope in FLAGS.checkpoint_exclude_scopes.split(',')] 25 | 26 | # TODO(sguada) variables.filter_variables() 27 | variables_to_restore = [] 28 | for var in slim.get_model_variables(): 29 | excluded = False 30 | for exclusion in exclusions: 31 | if var.op.name.startswith(exclusion): 32 | excluded = True 33 | break 34 | if not excluded: 35 | variables_to_restore.append(var) 36 | 37 | return slim.assign_from_checkpoint_fn( 38 | FLAGS.loss_model_file, 39 | variables_to_restore, 40 | ignore_missing_vars=True) 41 | 42 | 43 | class Flag(object): 44 | def __init__(self, **entries): 45 | self.__dict__.update(entries) 46 | 47 | 48 | def read_conf_file(conf_file): 49 | with open(conf_file) as f: 50 | FLAGS = Flag(**yaml.load(f)) 51 | return FLAGS 52 | 53 | 54 | def mean_image_subtraction(image, means): 55 | image = tf.to_float(image) 56 | 57 | num_channels = 3 58 | channels = tf.split(image, num_channels, 2) 59 | for i in range(num_channels): 60 | channels[i] -= means[i] 61 | return tf.concat(channels, 2) 62 | 63 | 64 | if __name__ == '__main__': 65 | f = read_conf_file('conf/mosaic.yml') 66 | print(f.loss_model_file) 67 | --------------------------------------------------------------------------------