├── .gitignore
├── README.md
├── conf
│   ├── candy.yml
│   ├── cubist.yml
│   ├── denoised_starry.yml
│   ├── feathers.yml
│   ├── mosaic.yml
│   ├── scream.yml
│   ├── udnie.yml
│   └── wave.yml
├── eval.py
├── export.py
├── img
│   ├── candy.jpg
│   ├── cubist.jpg
│   ├── denoised_starry.jpg
│   ├── feathers.jpg
│   ├── mosaic.jpg
│   ├── results
│   │   ├── cubist.jpg
│   │   ├── denoised_starry.jpg
│   │   ├── feathers.jpg
│   │   ├── mosaic.jpg
│   │   ├── scream.jpg
│   │   ├── style_cubist.jpg
│   │   ├── style_denoised_starry.jpg
│   │   ├── style_feathers.jpg
│   │   ├── style_mosaic.jpg
│   │   ├── style_scream.jpg
│   │   ├── style_udnie.jpg
│   │   ├── style_wave.jpg
│   │   ├── udnie.jpg
│   │   └── wave.jpg
│   ├── scream.jpg
│   ├── starry.jpg
│   ├── test.jpg
│   ├── test1.jpg
│   ├── test2.jpg
│   ├── test3.jpg
│   ├── test4.jpg
│   ├── test5.jpg
│   ├── udnie.jpg
│   └── wave.jpg
├── losses.py
├── model.py
├── nets
│   ├── __init__.py
│   ├── alexnet.py
│   ├── alexnet_test.py
│   ├── cifarnet.py
│   ├── inception.py
│   ├── inception_resnet_v2.py
│   ├── inception_resnet_v2_test.py
│   ├── inception_utils.py
│   ├── inception_v1.py
│   ├── inception_v1_test.py
│   ├── inception_v2.py
│   ├── inception_v2_test.py
│   ├── inception_v3.py
│   ├── inception_v3_test.py
│   ├── inception_v4.py
│   ├── inception_v4_test.py
│   ├── lenet.py
│   ├── nets_factory.py
│   ├── nets_factory_test.py
│   ├── overfeat.py
│   ├── overfeat_test.py
│   ├── resnet_utils.py
│   ├── resnet_v1.py
│   ├── resnet_v1_test.py
│   ├── resnet_v2.py
│   ├── resnet_v2_test.py
│   ├── vgg.py
│   └── vgg_test.py
├── preprocessing
│   ├── __init__.py
│   ├── cifarnet_preprocessing.py
│   ├── inception_preprocessing.py
│   ├── lenet_preprocessing.py
│   ├── preprocessing_factory.py
│   └── vgg_preprocessing.py
├── reader.py
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *.cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 |
60 | # Scrapy stuff:
61 | .scrapy
62 |
63 | # Sphinx documentation
64 | docs/_build/
65 |
66 | # PyBuilder
67 | target/
68 |
69 | # IPython Notebook
70 | .ipynb_checkpoints
71 |
72 | # pyenv
73 | .python-version
74 |
75 | # celery beat schedule file
76 | celerybeat-schedule
77 |
78 | # dotenv
79 | .env
80 |
81 | # virtualenv
82 | venv/
83 | ENV/
84 |
85 | # Spyder project settings
86 | .spyderproject
87 |
88 | # Rope project settings
89 | .ropeproject
90 |
91 |
92 | slim-official/
93 | train2014
94 | generated/
95 | models/
96 | tensorboard/
97 | result/
98 | pretrained/
99 |
100 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # fast-neural-style-tensorflow
2 |
3 | A TensorFlow implementation of [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/abs/1603.08155).
4 |
5 | This code is based on [Tensorflow-Slim](https://github.com/tensorflow/models/tree/master/slim) and [OlavHN/fast-neural-style](https://github.com/OlavHN/fast-neural-style).
6 |
7 | ## Samples:
8 |
9 | | configuration | style | sample |
10 | | :---: | :----: | :----: |
11 | | [wave.yml](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/conf/wave.yml) | ![](img/results/style_wave.jpg) | ![](img/results/wave.jpg) |
12 | | [cubist.yml](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/conf/cubist.yml) | ![](img/results/style_cubist.jpg) | ![](img/results/cubist.jpg) |
13 | | [denoised_starry.yml](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/conf/denoised_starry.yml) | ![](img/results/style_denoised_starry.jpg) | ![](img/results/denoised_starry.jpg) |
14 | | [mosaic.yml](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/conf/mosaic.yml) | ![](img/results/style_mosaic.jpg) | ![](img/results/mosaic.jpg) |
15 | | [scream.yml](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/conf/scream.yml) | ![](img/results/style_scream.jpg) | ![](img/results/scream.jpg) |
16 | | [feathers.yml](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/conf/feathers.yml) | ![](img/results/style_feathers.jpg) | ![](img/results/feathers.jpg) |
17 | | [udnie.yml](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/conf/udnie.yml) | ![](img/results/style_udnie.jpg) | ![](img/results/udnie.jpg) |
18 |
19 | ## Requirements and Prerequisites:
20 | - Python 2.7.x
21 | - TensorFlow >= 1.0 is now supported
22 |
23 | Attention: this code also supports TensorFlow 0.11. If that is your version, use commit 5309a2a (`git reset --hard 5309a2a`).
24 |
25 | Also make sure pyyaml is installed:
26 | ```
27 | pip install pyyaml
28 | ```
29 |
30 | ## Use Trained Models:
31 |
32 | You can download all 7 trained models from [Baidu Drive](https://pan.baidu.com/s/1i4GTS4d).
33 |
34 | To generate a sample from the model "wave.ckpt-done", run:
35 |
36 | ```
37 | python eval.py --model_file <your path to wave.ckpt-done> --image_file img/test.jpg
38 | ```
39 |
40 | Then check out generated/res.jpg.
41 |
42 | ## Train a Model:
43 | To train a model from scratch, first download the [VGG16 model](http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz) from TensorFlow Slim, extract the file vgg_16.ckpt, and copy it into the folder pretrained/:
44 | ```
45 | cd <this repo>
46 | mkdir pretrained
47 | cp <your path to vgg_16.ckpt> pretrained/
48 | ```
49 |
50 | Then download the [COCO dataset](http://msvocds.blob.core.windows.net/coco2014/train2014.zip). Unzip it to get a folder named "train2014" containing the raw images, and create a symbolic link to it:
51 | ```
52 | cd <this repo>
53 | ln -s <your path to the folder "train2014"> train2014
54 | ```
55 |
56 | Train the model of "wave":
57 | ```
58 | python train.py -c conf/wave.yml
59 | ```
60 |
61 | (Optional) Use tensorboard:
62 | ```
63 | tensorboard --logdir models/wave/
64 | ```
65 |
66 | Checkpoints will be written to "models/wave/".
67 |
68 | View the [configuration file](https://github.com/hzy46/fast-neural-style-tensorflow/blob/master/conf/wave.yml) for details.
69 |
--------------------------------------------------------------------------------
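Note: the eval command above always writes its result to generated/res.jpg, so stylizing several images means re-running it and copying the output each time. The helper below is a minimal sketch of that loop; it is not part of the repository, only an illustration built on the documented --model_file/--image_file flags, and the checkpoint path is a placeholder.

```
# Hypothetical batch helper around eval.py (not part of this repo).
# Assumes eval.py writes its output to generated/res.jpg, as the README states.
import os
import shutil
import subprocess

def stylize_all(model_file, image_files, out_dir='stylized'):
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    for image_file in image_files:
        # Run eval.py once per content image.
        subprocess.check_call([
            'python', 'eval.py',
            '--model_file', model_file,
            '--image_file', image_file,
        ])
        # Copy the fixed output file to a per-image name.
        name = os.path.splitext(os.path.basename(image_file))[0]
        shutil.copy('generated/res.jpg', os.path.join(out_dir, name + '.jpg'))

if __name__ == '__main__':
    stylize_all('wave.ckpt-done', ['img/test.jpg', 'img/test1.jpg', 'img/test2.jpg'])
```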
/conf/candy.yml:
--------------------------------------------------------------------------------
1 | ## Basic configuration
2 | style_image: img/candy.jpg # targeted style image
3 | naming: "candy" # the name of this model. Determines the path used to save checkpoints and event files.
4 | model_path: models # root path for saving checkpoints and event files. The final path would be <model_path>/<naming>.
5 |
6 | ## Weight of the loss
7 | content_weight: 1.0 # weight for content features loss
8 | style_weight: 50.0 # weight for style features loss
9 | tv_weight: 0.0 # weight for total variation loss
10 |
11 | ## Image size, batch size and number of epochs
12 | image_size: 256
13 | batch_size: 4
14 | epoch: 2
15 |
16 | ## Loss Network
17 | loss_model: "vgg_16"
18 | content_layers: # use these layers for content loss
19 | - "vgg_16/conv3/conv3_3"
20 | style_layers: # use these layers for style loss
21 | - "vgg_16/conv1/conv1_2"
22 | - "vgg_16/conv2/conv2_2"
23 | - "vgg_16/conv3/conv3_3"
24 | - "vgg_16/conv4/conv4_3"
25 | checkpoint_exclude_scopes: "vgg_16/fc" # we only use the convolution layers, so ignore fc layers.
26 | loss_model_file: "pretrained/vgg_16.ckpt" # the path to the checkpoint
27 |
--------------------------------------------------------------------------------
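Note: the conf/*.yml files are plain YAML read with pyyaml (listed as a requirement in the README). The sketch below shows one way such a file can be turned into an attribute-style FLAGS object; the Flags wrapper here is an illustrative assumption, and the actual loading logic lives in train.py/utils.py and may differ.

```
# Minimal sketch: load a style configuration with PyYAML (illustrative only).
import yaml

class Flags(object):
    """Wrap a dict so entries read like attributes, e.g. FLAGS.style_weight."""
    def __init__(self, **entries):
        self.__dict__.update(entries)

with open('conf/candy.yml') as f:
    FLAGS = Flags(**yaml.safe_load(f))

print(FLAGS.style_image)    # img/candy.jpg
print(FLAGS.style_weight)   # 50.0
print(FLAGS.style_layers)   # ['vgg_16/conv1/conv1_2', ...]
```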
/conf/cubist.yml:
--------------------------------------------------------------------------------
1 | ## Basic configuration
2 | style_image: img/cubist.jpg # targeted style image
3 | naming: "cubist" # the name of this model. Determines the path used to save checkpoints and event files.
4 | model_path: models # root path for saving checkpoints and event files. The final path would be <model_path>/<naming>.
5 |
6 | ## Weight of the loss
7 | content_weight: 1.0 # weight for content features loss
8 | style_weight: 180.0 # weight for style features loss
9 | tv_weight: 0.0 # weight for total variation loss
10 |
11 | ## Image size, batch size and number of epochs
12 | image_size: 256
13 | batch_size: 4
14 | epoch: 2
15 |
16 | ## Loss Network
17 | loss_model: "vgg_16"
18 | content_layers: # use these layers for content loss
19 | - "vgg_16/conv3/conv3_3"
20 | style_layers: # use these layers for style loss
21 | - "vgg_16/conv1/conv1_2"
22 | - "vgg_16/conv2/conv2_2"
23 | - "vgg_16/conv3/conv3_3"
24 | - "vgg_16/conv4/conv4_3"
25 | checkpoint_exclude_scopes: "vgg_16/fc" # we only use the convolution layers, so ignore fc layers.
26 | loss_model_file: "pretrained/vgg_16.ckpt" # the path to the checkpoint
--------------------------------------------------------------------------------
/conf/denoised_starry.yml:
--------------------------------------------------------------------------------
1 | ## Basic configuration
2 | style_image: img/denoised_starry.jpg # targeted style image
3 | naming: "denoised_starry" # the name of this model. Determines the path used to save checkpoints and event files.
4 | model_path: models # root path for saving checkpoints and event files. The final path would be <model_path>/<naming>.
5 |
6 | ## Weight of the loss
7 | content_weight: 1.0 # weight for content features loss
8 | style_weight: 250 # weight for style features loss
9 | tv_weight: 0.0 # weight for total variation loss
10 |
11 | ## Image size, batch size and number of epochs
12 | image_size: 256
13 | batch_size: 4
14 | epoch: 2
15 |
16 | ## Loss Network
17 | loss_model: "vgg_16"
18 | content_layers: # use these layers for content loss
19 | - "vgg_16/conv3/conv3_3"
20 | style_layers: # use these layers for style loss
21 | - "vgg_16/conv1/conv1_2"
22 | - "vgg_16/conv2/conv2_2"
23 | - "vgg_16/conv3/conv3_3"
24 | - "vgg_16/conv4/conv4_3"
25 | checkpoint_exclude_scopes: "vgg_16/fc" # we only use the convolution layers, so ignore fc layers.
26 | loss_model_file: "pretrained/vgg_16.ckpt" # the path to the checkpoint
--------------------------------------------------------------------------------
/conf/feathers.yml:
--------------------------------------------------------------------------------
1 | ## Basic configuration
2 | style_image: img/feathers.jpg # targeted style image
3 | naming: "feathers" # the name of this model. Determines the path used to save checkpoints and event files.
4 | model_path: models # root path for saving checkpoints and event files. The final path would be <model_path>/<naming>.
5 |
6 | ## Weight of the loss
7 | content_weight: 1.0 # weight for content features loss
8 | style_weight: 220.0 # weight for style features loss
9 | tv_weight: 0.0 # weight for total variation loss
10 |
11 | ## Image size, batch size and number of epochs
12 | image_size: 256
13 | batch_size: 4
14 | epoch: 2
15 |
16 | ## Loss Network
17 | loss_model: "vgg_16"
18 | content_layers: # use these layers for content loss
19 | - "vgg_16/conv3/conv3_3"
20 | style_layers: # use these layers for style loss
21 | - "vgg_16/conv1/conv1_2"
22 | - "vgg_16/conv2/conv2_2"
23 | - "vgg_16/conv3/conv3_3"
24 | - "vgg_16/conv4/conv4_3"
25 | checkpoint_exclude_scopes: "vgg_16/fc" # we only use the convolution layers, so ignore fc layers.
26 | loss_model_file: "pretrained/vgg_16.ckpt" # the path to the checkpoint
--------------------------------------------------------------------------------
/conf/mosaic.yml:
--------------------------------------------------------------------------------
1 | ## Basic configuration
2 | style_image: img/mosaic.jpg # targeted style image
3 | naming: "mosaic" # the name of this model. Determines the path used to save checkpoints and event files.
4 | model_path: models # root path for saving checkpoints and event files. The final path would be <model_path>/<naming>.
5 |
6 | ## Weight of the loss
7 | content_weight: 1.0 # weight for content features loss
8 | style_weight: 100.0 # weight for style features loss
9 | tv_weight: 0.0 # weight for total variation loss
10 |
11 | ## Image size, batch size and number of epochs
12 | image_size: 256
13 | batch_size: 4
14 | epoch: 2
15 |
16 | ## Loss Network
17 | loss_model: "vgg_16"
18 | content_layers: # use these layers for content loss
19 | - "vgg_16/conv3/conv3_3"
20 | style_layers: # use these layers for style loss
21 | - "vgg_16/conv1/conv1_2"
22 | - "vgg_16/conv2/conv2_2"
23 | - "vgg_16/conv3/conv3_3"
24 | - "vgg_16/conv4/conv4_3"
25 | checkpoint_exclude_scopes: "vgg_16/fc" # we only use the convolution layers, so ignore fc layers.
26 | loss_model_file: "pretrained/vgg_16.ckpt" # the path to the checkpoint
--------------------------------------------------------------------------------
/conf/scream.yml:
--------------------------------------------------------------------------------
1 | ## Basic configuration
2 | style_image: img/scream.jpg # targeted style image
3 | naming: "scream" # the name of this model. Determines the path used to save checkpoints and event files.
4 | model_path: models # root path for saving checkpoints and event files. The final path would be <model_path>/<naming>.
5 |
6 | ## Weight of the loss
7 | content_weight: 1.0 # weight for content features loss
8 | style_weight: 250.0 # weight for style features loss
9 | tv_weight: 0.0 # weight for total variation loss
10 |
11 | ## Image size, batch size and number of epochs
12 | image_size: 256
13 | batch_size: 4
14 | epoch: 2
15 |
16 | ## Loss Network
17 | loss_model: "vgg_16"
18 | content_layers: # use these layers for content loss
19 | - "vgg_16/conv3/conv3_3"
20 | style_layers: # use these layers for style loss
21 | - "vgg_16/conv1/conv1_2"
22 | - "vgg_16/conv2/conv2_2"
23 | - "vgg_16/conv3/conv3_3"
24 | - "vgg_16/conv4/conv4_3"
25 | checkpoint_exclude_scopes: "vgg_16/fc" # we only use the convolution layers, so ignore fc layers.
26 | loss_model_file: "pretrained/vgg_16.ckpt" # the path to the checkpoint
--------------------------------------------------------------------------------
/conf/udnie.yml:
--------------------------------------------------------------------------------
1 | ## Basic configuration
2 | style_image: img/udnie.jpg # targeted style image
3 | naming: "udnie" # the name of this model. Determines the path used to save checkpoints and event files.
4 | model_path: models # root path for saving checkpoints and event files. The final path would be <model_path>/<naming>.
5 |
6 | ## Weight of the loss
7 | content_weight: 1.0 # weight for content features loss
8 | style_weight: 200.0 # weight for style features loss
9 | tv_weight: 0.0 # weight for total variation loss
10 |
11 | ## Image size, batch size and number of epochs
12 | image_size: 256
13 | batch_size: 4
14 | epoch: 2
15 |
16 | ## Loss Network
17 | loss_model: "vgg_16"
18 | content_layers: # use these layers for content loss
19 | - "vgg_16/conv3/conv3_3"
20 | style_layers: # use these layers for style loss
21 | - "vgg_16/conv1/conv1_2"
22 | - "vgg_16/conv2/conv2_2"
23 | - "vgg_16/conv3/conv3_3"
24 | - "vgg_16/conv4/conv4_3"
25 | checkpoint_exclude_scopes: "vgg_16/fc" # we only use the convolution layers, so ignore fc layers.
26 | loss_model_file: "pretrained/vgg_16.ckpt" # the path to the checkpoint
--------------------------------------------------------------------------------
/conf/wave.yml:
--------------------------------------------------------------------------------
1 | ## Basic configuration
2 | style_image: img/wave.jpg # targeted style image
3 | naming: "wave" # the name of this model. Determines the path used to save checkpoints and event files.
4 | model_path: models # root path for saving checkpoints and event files. The final path would be <model_path>/<naming>.
5 |
6 | ## Weight of the loss
7 | content_weight: 1.0 # weight for content features loss
8 | style_weight: 220.0 # weight for style features loss
9 | tv_weight: 0.0 # weight for total variation loss
10 |
11 | ## Image size, batch size and number of epochs
12 | image_size: 256
13 | batch_size: 4
14 | epoch: 2
15 |
16 | ## Loss Network
17 | loss_model: "vgg_16"
18 | content_layers: # use these layers for content loss
19 | - "vgg_16/conv3/conv3_3"
20 | style_layers: # use these layers for style loss
21 | - "vgg_16/conv1/conv1_2"
22 | - "vgg_16/conv2/conv2_2"
23 | - "vgg_16/conv3/conv3_3"
24 | - "vgg_16/conv4/conv4_3"
25 | checkpoint_exclude_scopes: "vgg_16/fc" # we only use the convolution layers, so ignore fc layers.
26 | loss_model_file: "pretrained/vgg_16.ckpt" # the path to the checkpoint
27 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | from __future__ import print_function
3 | import tensorflow as tf
4 | from preprocessing import preprocessing_factory
5 | import reader
6 | import model
7 | import time
8 | import os
9 |
10 | tf.app.flags.DEFINE_string('loss_model', 'vgg_16', 'The name of the architecture to evaluate. '
11 | 'You can view all the supported models in nets/nets_factory.py')
12 | tf.app.flags.DEFINE_integer('image_size', 256, 'Image size to train.')
13 | tf.app.flags.DEFINE_string("model_file", "models.ckpt", "")
14 | tf.app.flags.DEFINE_string("image_file", "a.jpg", "")
15 |
16 | FLAGS = tf.app.flags.FLAGS
17 |
18 |
19 | def main(_):
20 |
21 | # Get image's height and width.
22 | height = 0
23 | width = 0
24 | with open(FLAGS.image_file, 'rb') as img:
25 | with tf.Session().as_default() as sess:
26 | if FLAGS.image_file.lower().endswith('png'):
27 | image = sess.run(tf.image.decode_png(img.read()))
28 | else:
29 | image = sess.run(tf.image.decode_jpeg(img.read()))
30 | height = image.shape[0]
31 | width = image.shape[1]
32 | tf.logging.info('Image size: %dx%d' % (width, height))
33 |
34 | with tf.Graph().as_default():
35 | with tf.Session().as_default() as sess:
36 |
37 | # Read image data.
38 | image_preprocessing_fn, _ = preprocessing_factory.get_preprocessing(
39 | FLAGS.loss_model,
40 | is_training=False)
41 | image = reader.get_image(FLAGS.image_file, height, width, image_preprocessing_fn)
42 |
43 | # Add batch dimension
44 | image = tf.expand_dims(image, 0)
45 |
46 | generated = model.net(image, training=False)
47 | generated = tf.cast(generated, tf.uint8)
48 |
49 | # Remove batch dimension
50 | generated = tf.squeeze(generated, [0])
51 |
52 | # Restore model variables.
53 | saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V1)
54 | sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
55 | # Use absolute path
56 | FLAGS.model_file = os.path.abspath(FLAGS.model_file)
57 | saver.restore(sess, FLAGS.model_file)
58 |
59 | # Make sure 'generated' directory exists.
60 | generated_file = 'generated/res.jpg'
61 | if os.path.exists('generated') is False:
62 | os.makedirs('generated')
63 |
64 | # Generate and write image data to file.
65 | with open(generated_file, 'wb') as img:
66 | start_time = time.time()
67 | img.write(sess.run(tf.image.encode_jpeg(generated)))
68 | end_time = time.time()
69 | tf.logging.info('Elapsed time: %fs' % (end_time - start_time))
70 |
71 | tf.logging.info('Done. Please check %s.' % generated_file)
72 |
73 |
74 | if __name__ == '__main__':
75 | tf.logging.set_verbosity(tf.logging.INFO)
76 | tf.app.run()
77 |
--------------------------------------------------------------------------------
/export.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | from __future__ import print_function
3 | import tensorflow as tf
4 | import argparse
5 | import time
6 | import os
7 |
8 | import model
9 | import utils
10 |
11 |
12 | def parse_args():
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('-m', '--model_file', help='the path to the model file')
15 | parser.add_argument('-n', '--model_name', default='transfer', help='the name of the model')
16 | parser.add_argument('-d', dest='is_debug', action='store_true')
17 | parser.set_defaults(is_debug=False)
18 | return parser.parse_args()
19 |
20 |
21 | def main(args):
22 | g = tf.Graph() # A new graph
23 | with g.as_default():
24 | with tf.Session() as sess:
25 | # Building graph.
26 | image_data = tf.placeholder(tf.int32, name='input_image')
27 | height = tf.placeholder(tf.int32, name='height')
28 | width = tf.placeholder(tf.int32, name='width')
29 |
30 | # Reshape data
31 | image = tf.reshape(image_data, [height, width, 3])
32 |
33 | processed_image = utils.mean_image_subtraction(
34 | image, [123.68, 116.779, 103.939]) # Preprocessing image
35 | batched_image = tf.expand_dims(processed_image, 0) # Add batch dimension
36 | generated_image = model.net(batched_image, training=False)
37 | casted_image = tf.cast(generated_image, tf.int32)
38 | # Remove batch dimension
39 | squeezed_image = tf.squeeze(casted_image, [0])
40 | cropped_image = tf.slice(squeezed_image, [0, 0, 0], [height, width, 3])
41 | # stylized_image = tf.image.encode_jpeg(squeezed_image, name='output_image')
42 | stylized_image_data = tf.reshape(cropped_image, [-1], name='output_image')
43 |
44 | # Restore model variables.
45 | saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V1)
46 | sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
47 | # Use absolute path.
48 | model_file = os.path.abspath(args.model_file)
49 | saver.restore(sess, model_file)
50 |
51 | if args.is_debug:
52 | content_file = '/Users/Lex/Desktop/t.jpg'
53 | generated_file = '/Users/Lex/Desktop/xwz-stylized.jpg'
54 |
55 | with open(generated_file, 'wb') as img:
56 | image_bytes = tf.read_file(content_file)
57 | input_array, decoded_image = sess.run([
58 | tf.reshape(tf.image.decode_jpeg(image_bytes, channels=3), [-1]),
59 | tf.image.decode_jpeg(image_bytes, channels=3)])
60 |
61 | start_time = time.time()
62 | img.write(sess.run(tf.image.encode_jpeg(tf.cast(cropped_image, tf.uint8)), feed_dict={
63 | image_data: input_array,
64 | height: decoded_image.shape[0],
65 | width: decoded_image.shape[1]}))
66 | end_time = time.time()
67 |
68 | tf.logging.info('Elapsed time: %fs' % (end_time - start_time))
69 | else:
70 | output_graph_def = tf.graph_util.convert_variables_to_constants(
71 | sess, sess.graph_def, output_node_names=['output_image'])
72 |
73 | with tf.gfile.FastGFile('/Users/Lex/Desktop/' + args.model_name + '.pb', mode='wb') as f:
74 | f.write(output_graph_def.SerializeToString())
75 |
76 | # tf.train.write_graph(g.as_graph_def(), '/Users/Lex/Desktop',
77 | # args.model_name + '.pb', as_text=False)
78 |
79 |
80 | if __name__ == '__main__':
81 | tf.logging.set_verbosity(tf.logging.INFO)
82 | args = parse_args()
83 | print(args)
84 | main(args)
85 |
--------------------------------------------------------------------------------
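Note: export.py freezes the transform network into a single .pb file with input nodes input_image, height and width and output node output_image (a flattened int32 array). The sketch below shows one way to load and run such a frozen graph; the file name transfer.pb and the content image path are assumptions for illustration, while the node names come from export.py itself.

```
# Minimal sketch: run a graph frozen by export.py (file name is an assumption).
import numpy as np
import tensorflow as tf

# Decode the content image in a throwaway default graph.
with tf.Session() as s:
    raw = tf.gfile.GFile('img/test.jpg', 'rb').read()
    image = s.run(tf.image.decode_jpeg(raw, channels=3))

graph_def = tf.GraphDef()
with tf.gfile.GFile('transfer.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default():
    tf.import_graph_def(graph_def, name='')
    with tf.Session() as sess:
        # Node names match the placeholders/ops defined in export.py above.
        flat = sess.run('output_image:0', feed_dict={
            'input_image:0': image.reshape(-1).astype(np.int32),
            'height:0': image.shape[0],
            'width:0': image.shape[1],
        })

# output_image is the stylized image flattened to 1-D int32 values.
stylized = np.asarray(flat).reshape(image.shape).astype(np.uint8)
```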
/img/candy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/candy.jpg
--------------------------------------------------------------------------------
/img/cubist.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/cubist.jpg
--------------------------------------------------------------------------------
/img/denoised_starry.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/denoised_starry.jpg
--------------------------------------------------------------------------------
/img/feathers.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/feathers.jpg
--------------------------------------------------------------------------------
/img/mosaic.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/mosaic.jpg
--------------------------------------------------------------------------------
/img/results/cubist.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/cubist.jpg
--------------------------------------------------------------------------------
/img/results/denoised_starry.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/denoised_starry.jpg
--------------------------------------------------------------------------------
/img/results/feathers.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/feathers.jpg
--------------------------------------------------------------------------------
/img/results/mosaic.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/mosaic.jpg
--------------------------------------------------------------------------------
/img/results/scream.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/scream.jpg
--------------------------------------------------------------------------------
/img/results/style_cubist.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/style_cubist.jpg
--------------------------------------------------------------------------------
/img/results/style_denoised_starry.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/style_denoised_starry.jpg
--------------------------------------------------------------------------------
/img/results/style_feathers.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/style_feathers.jpg
--------------------------------------------------------------------------------
/img/results/style_mosaic.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/style_mosaic.jpg
--------------------------------------------------------------------------------
/img/results/style_scream.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/style_scream.jpg
--------------------------------------------------------------------------------
/img/results/style_udnie.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/style_udnie.jpg
--------------------------------------------------------------------------------
/img/results/style_wave.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/style_wave.jpg
--------------------------------------------------------------------------------
/img/results/udnie.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/udnie.jpg
--------------------------------------------------------------------------------
/img/results/wave.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/results/wave.jpg
--------------------------------------------------------------------------------
/img/scream.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/scream.jpg
--------------------------------------------------------------------------------
/img/starry.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/starry.jpg
--------------------------------------------------------------------------------
/img/test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/test.jpg
--------------------------------------------------------------------------------
/img/test1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/test1.jpg
--------------------------------------------------------------------------------
/img/test2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/test2.jpg
--------------------------------------------------------------------------------
/img/test3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/test3.jpg
--------------------------------------------------------------------------------
/img/test4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/test4.jpg
--------------------------------------------------------------------------------
/img/test5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/test5.jpg
--------------------------------------------------------------------------------
/img/udnie.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/udnie.jpg
--------------------------------------------------------------------------------
/img/wave.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hzy46/fast-neural-style-tensorflow/eeaa47d359e5c589a4cc6ccbf8c0450ccc657d2d/img/wave.jpg
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | from __future__ import print_function
3 | import tensorflow as tf
4 | from nets import nets_factory
5 | from preprocessing import preprocessing_factory
6 | import utils
7 | import os
8 |
9 | slim = tf.contrib.slim
10 |
11 |
12 | def gram(layer):
13 | shape = tf.shape(layer)
14 | num_images = shape[0]
15 | height = shape[1]
16 | width = shape[2]
17 | num_filters = shape[3]
18 | filters = tf.reshape(layer, tf.stack([num_images, -1, num_filters]))
19 | grams = tf.matmul(filters, filters, transpose_a=True) / tf.to_float(width * height * num_filters)
20 |
21 | return grams
22 |
23 |
24 | def get_style_features(FLAGS):
25 | """
26 | For the "style_image", the preprocessing step is:
27 | 1. Resize the shorter side to FLAGS.image_size
28 | 2. Apply central crop
29 | """
30 | with tf.Graph().as_default():
31 | network_fn = nets_factory.get_network_fn(
32 | FLAGS.loss_model,
33 | num_classes=1,
34 | is_training=False)
35 | image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
36 | FLAGS.loss_model,
37 | is_training=False)
38 |
39 | # Get the style image data
40 | size = FLAGS.image_size
41 | img_bytes = tf.read_file(FLAGS.style_image)
42 | if FLAGS.style_image.lower().endswith('png'):
43 | image = tf.image.decode_png(img_bytes)
44 | else:
45 | image = tf.image.decode_jpeg(img_bytes)
46 | # image = _aspect_preserving_resize(image, size)
47 |
48 | # Add the batch dimension
49 | images = tf.expand_dims(image_preprocessing_fn(image, size, size), 0)
50 | # images = tf.stack([image_preprocessing_fn(image, size, size)])
51 |
52 | _, endpoints_dict = network_fn(images, spatial_squeeze=False)
53 | features = []
54 | for layer in FLAGS.style_layers:
55 | feature = endpoints_dict[layer]
56 | feature = tf.squeeze(gram(feature), [0]) # remove the batch dimension
57 | features.append(feature)
58 |
59 | with tf.Session() as sess:
60 | # Restore variables for loss network.
61 | init_func = utils._get_init_fn(FLAGS)
62 | init_func(sess)
63 |
64 | # Make sure the 'generated' directory exists.
65 | if os.path.exists('generated') is False:
66 | os.makedirs('generated')
67 | # Indicate cropped style image path
68 | save_file = 'generated/target_style_' + FLAGS.naming + '.jpg'
69 | # Write preprocessed style image to indicated path
70 | with open(save_file, 'wb') as f:
71 | target_image = image_unprocessing_fn(images[0, :])
72 | value = tf.image.encode_jpeg(tf.cast(target_image, tf.uint8))
73 | f.write(sess.run(value))
74 | tf.logging.info('Target style pattern is saved to: %s.' % save_file)
75 |
76 | # Return the features of the layers used for measuring the style loss.
77 | return sess.run(features)
78 |
79 |
80 | def style_loss(endpoints_dict, style_features_t, style_layers):
81 | style_loss = 0
82 | style_loss_summary = {}
83 | for style_gram, layer in zip(style_features_t, style_layers):
84 | generated_images, _ = tf.split(endpoints_dict[layer], 2, 0)
85 | size = tf.size(generated_images)
86 | layer_style_loss = tf.nn.l2_loss(gram(generated_images) - style_gram) * 2 / tf.to_float(size)
87 | style_loss_summary[layer] = layer_style_loss
88 | style_loss += layer_style_loss
89 | return style_loss, style_loss_summary
90 |
91 |
92 | def content_loss(endpoints_dict, content_layers):
93 | content_loss = 0
94 | for layer in content_layers:
95 | generated_images, content_images = tf.split(endpoints_dict[layer], 2, 0)
96 | size = tf.size(generated_images)
97 | content_loss += tf.nn.l2_loss(generated_images - content_images) * 2 / tf.to_float(size) # remain the same as in the paper
98 | return content_loss
99 |
100 |
101 | def total_variation_loss(layer):
102 | shape = tf.shape(layer)
103 | height = shape[1]
104 | width = shape[2]
105 | y = tf.slice(layer, [0, 0, 0, 0], tf.stack([-1, height - 1, -1, -1])) - tf.slice(layer, [0, 1, 0, 0], [-1, -1, -1, -1])
106 | x = tf.slice(layer, [0, 0, 0, 0], tf.stack([-1, -1, width - 1, -1])) - tf.slice(layer, [0, 0, 1, 0], [-1, -1, -1, -1])
107 | loss = tf.nn.l2_loss(x) / tf.to_float(tf.size(x)) + tf.nn.l2_loss(y) / tf.to_float(tf.size(y))
108 | return loss
109 |
--------------------------------------------------------------------------------
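Note: gram() above computes, per image, the channel-by-channel Gram matrix of a feature map, normalized by H * W * C. The following is a small NumPy sketch of the same computation for a single feature map; the shapes and values are illustrative only.

```
# NumPy equivalent of gram() for one feature map of shape [1, H, W, C] (illustrative).
import numpy as np

H, W, C = 4, 4, 3
layer = np.random.rand(1, H, W, C).astype(np.float32)

F = layer[0].reshape(-1, C)       # flatten spatial positions: [H*W, C]
gram = F.T.dot(F) / (H * W * C)   # [C, C], matches gram(layer)[0] above
print(gram.shape)                 # (3, 3)
```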
/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def conv2d(x, input_filters, output_filters, kernel, strides, mode='REFLECT'):
5 | with tf.variable_scope('conv'):
6 |
7 | shape = [kernel, kernel, input_filters, output_filters]
8 | weight = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name='weight')
9 | x_padded = tf.pad(x, [[0, 0], [int(kernel / 2), int(kernel / 2)], [int(kernel / 2), int(kernel / 2)], [0, 0]], mode=mode)
10 | return tf.nn.conv2d(x_padded, weight, strides=[1, strides, strides, 1], padding='VALID', name='conv')
11 |
12 |
13 | def conv2d_transpose(x, input_filters, output_filters, kernel, strides):
14 | with tf.variable_scope('conv_transpose'):
15 |
16 | shape = [kernel, kernel, output_filters, input_filters]
17 | weight = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name='weight')
18 |
19 | batch_size = tf.shape(x)[0]
20 | height = tf.shape(x)[1] * strides
21 | width = tf.shape(x)[2] * strides
22 | output_shape = tf.stack([batch_size, height, width, output_filters])
23 | return tf.nn.conv2d_transpose(x, weight, output_shape, strides=[1, strides, strides, 1], name='conv_transpose')
24 |
25 |
26 | def resize_conv2d(x, input_filters, output_filters, kernel, strides, training):
27 | '''
28 | An alternative to transposed convolution where we first resize, then convolve.
29 | See http://distill.pub/2016/deconv-checkerboard/
30 |
31 | For some reason the shape needs to be statically known for gradient propagation
32 | through tf.image.resize_images, but we only know that for fixed image size, so we
33 | plumb through a "training" argument
34 | '''
35 | with tf.variable_scope('conv_transpose'):
36 | height = x.get_shape()[1].value if training else tf.shape(x)[1]
37 | width = x.get_shape()[2].value if training else tf.shape(x)[2]
38 |
39 | new_height = height * strides * 2
40 | new_width = width * strides * 2
41 |
42 | x_resized = tf.image.resize_images(x, [new_height, new_width], tf.image.ResizeMethod.NEAREST_NEIGHBOR)
43 |
44 | # shape = [kernel, kernel, input_filters, output_filters]
45 | # weight = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name='weight')
46 | return conv2d(x_resized, input_filters, output_filters, kernel, strides)
47 |
48 |
49 | def instance_norm(x):
50 | epsilon = 1e-9
51 |
52 | mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)
53 |
54 | return tf.div(tf.subtract(x, mean), tf.sqrt(tf.add(var, epsilon)))
55 |
56 |
57 | def batch_norm(x, size, training, decay=0.999):
58 | beta = tf.Variable(tf.zeros([size]), name='beta')
59 | scale = tf.Variable(tf.ones([size]), name='scale')
60 | pop_mean = tf.Variable(tf.zeros([size]))
61 | pop_var = tf.Variable(tf.ones([size]))
62 | epsilon = 1e-3
63 |
64 | batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2])
65 | train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
66 | train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
67 |
68 | def batch_statistics():
69 | with tf.control_dependencies([train_mean, train_var]):
70 | return tf.nn.batch_normalization(x, batch_mean, batch_var, beta, scale, epsilon, name='batch_norm')
71 |
72 | def population_statistics():
73 | return tf.nn.batch_normalization(x, pop_mean, pop_var, beta, scale, epsilon, name='batch_norm')
74 |
75 | return tf.cond(training, batch_statistics, population_statistics)
76 |
77 |
78 | def relu(input):
79 | relu = tf.nn.relu(input)
80 | # convert nan to zero (nan != nan)
81 | nan_to_zero = tf.where(tf.equal(relu, relu), relu, tf.zeros_like(relu))
82 | return nan_to_zero
83 |
84 |
85 | def residual(x, filters, kernel, strides):
86 | with tf.variable_scope('residual'):
87 | conv1 = conv2d(x, filters, filters, kernel, strides)
88 | conv2 = conv2d(relu(conv1), filters, filters, kernel, strides)
89 |
90 | residual = x + conv2
91 |
92 | return residual
93 |
94 |
95 | def net(image, training):
96 | # Padding a little before passing through the network reduces border artifacts.
97 | image = tf.pad(image, [[0, 0], [10, 10], [10, 10], [0, 0]], mode='REFLECT')
98 |
99 | with tf.variable_scope('conv1'):
100 | conv1 = relu(instance_norm(conv2d(image, 3, 32, 9, 1)))
101 | with tf.variable_scope('conv2'):
102 | conv2 = relu(instance_norm(conv2d(conv1, 32, 64, 3, 2)))
103 | with tf.variable_scope('conv3'):
104 | conv3 = relu(instance_norm(conv2d(conv2, 64, 128, 3, 2)))
105 | with tf.variable_scope('res1'):
106 | res1 = residual(conv3, 128, 3, 1)
107 | with tf.variable_scope('res2'):
108 | res2 = residual(res1, 128, 3, 1)
109 | with tf.variable_scope('res3'):
110 | res3 = residual(res2, 128, 3, 1)
111 | with tf.variable_scope('res4'):
112 | res4 = residual(res3, 128, 3, 1)
113 | with tf.variable_scope('res5'):
114 | res5 = residual(res4, 128, 3, 1)
115 | # print(res5.get_shape())
116 | with tf.variable_scope('deconv1'):
117 | # deconv1 = relu(instance_norm(conv2d_transpose(res5, 128, 64, 3, 2)))
118 | deconv1 = relu(instance_norm(resize_conv2d(res5, 128, 64, 3, 2, training)))
119 | with tf.variable_scope('deconv2'):
120 | # deconv2 = relu(instance_norm(conv2d_transpose(deconv1, 64, 32, 3, 2)))
121 | deconv2 = relu(instance_norm(resize_conv2d(deconv1, 64, 32, 3, 2, training)))
122 | with tf.variable_scope('deconv3'):
123 | # deconv_test = relu(instance_norm(conv2d(deconv2, 32, 32, 2, 1)))
124 | deconv3 = tf.nn.tanh(instance_norm(conv2d(deconv2, 32, 3, 9, 1)))
125 |
126 | y = (deconv3 + 1) * 127.5
127 |
128 | # Remove the border effect by cropping off the padding.
129 | height = tf.shape(y)[1]
130 | width = tf.shape(y)[2]
131 | y = tf.slice(y, [0, 10, 10, 0], tf.stack([-1, height - 20, width - 20, -1]))
132 |
133 | return y
134 |
--------------------------------------------------------------------------------
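Note: instance_norm() above normalizes each image and each channel over its spatial dimensions only, with no learned scale or offset. The following is a small NumPy sketch of the same operation; the input shapes are illustrative.

```
# NumPy equivalent of instance_norm(): per-image, per-channel normalization
# over the spatial axes only (illustrative sketch).
import numpy as np

def instance_norm_np(x, epsilon=1e-9):
    # x: [batch, height, width, channels]
    mean = x.mean(axis=(1, 2), keepdims=True)
    var = x.var(axis=(1, 2), keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon)

x = np.random.rand(2, 8, 8, 32).astype(np.float32)
print(instance_norm_np(x).shape)  # (2, 8, 8, 32)
```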
/nets/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/nets/alexnet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains a model definition for AlexNet.
16 |
17 | This work was first described in:
18 | ImageNet Classification with Deep Convolutional Neural Networks
19 | Alex Krizhevsky, Ilya Sutskever and Geoffrey E. Hinton
20 |
21 | and later refined in:
22 | One weird trick for parallelizing convolutional neural networks
23 | Alex Krizhevsky, 2014
24 |
25 | Here we provide the implementation proposed in "One weird trick" and not
26 | "ImageNet Classification", as per the paper, the LRN layers have been removed.
27 |
28 | Usage:
29 | with slim.arg_scope(alexnet.alexnet_v2_arg_scope()):
30 | outputs, end_points = alexnet.alexnet_v2(inputs)
31 |
32 | @@alexnet_v2
33 | """
34 |
35 | from __future__ import absolute_import
36 | from __future__ import division
37 | from __future__ import print_function
38 |
39 | import tensorflow as tf
40 |
41 | slim = tf.contrib.slim
42 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
43 |
44 |
45 | def alexnet_v2_arg_scope(weight_decay=0.0005):
46 | with slim.arg_scope([slim.conv2d, slim.fully_connected],
47 | activation_fn=tf.nn.relu,
48 | biases_initializer=tf.constant_initializer(0.1),
49 | weights_regularizer=slim.l2_regularizer(weight_decay)):
50 | with slim.arg_scope([slim.conv2d], padding='SAME'):
51 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc:
52 | return arg_sc
53 |
54 |
55 | def alexnet_v2(inputs,
56 | num_classes=1000,
57 | is_training=True,
58 | dropout_keep_prob=0.5,
59 | spatial_squeeze=True,
60 | scope='alexnet_v2'):
61 | """AlexNet version 2.
62 |
63 | Described in: http://arxiv.org/pdf/1404.5997v2.pdf
64 | Parameters from:
65 | github.com/akrizhevsky/cuda-convnet2/blob/master/layers/
66 | layers-imagenet-1gpu.cfg
67 |
68 | Note: All the fully_connected layers have been transformed to conv2d layers.
69 | To use in classification mode, resize input to 224x224. To use in fully
70 | convolutional mode, set spatial_squeeze to false.
71 | The LRN layers have been removed and the initializers changed from
72 | random_normal_initializer to xavier_initializer.
73 |
74 | Args:
75 | inputs: a tensor of size [batch_size, height, width, channels].
76 | num_classes: number of predicted classes.
77 | is_training: whether or not the model is being trained.
78 | dropout_keep_prob: the probability that activations are kept in the dropout
79 | layers during training.
80 | spatial_squeeze: whether or not to squeeze the spatial dimensions of the
81 | outputs. Useful to remove unnecessary dimensions for classification.
82 | scope: Optional scope for the variables.
83 |
84 | Returns:
85 | the last op containing the log predictions and end_points dict.
86 | """
87 | with tf.variable_scope(scope, 'alexnet_v2', [inputs]) as sc:
88 | end_points_collection = sc.name + '_end_points'
89 | # Collect outputs for conv2d, fully_connected and max_pool2d.
90 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
91 | outputs_collections=[end_points_collection]):
92 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID',
93 | scope='conv1')
94 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool1')
95 | net = slim.conv2d(net, 192, [5, 5], scope='conv2')
96 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool2')
97 | net = slim.conv2d(net, 384, [3, 3], scope='conv3')
98 | net = slim.conv2d(net, 384, [3, 3], scope='conv4')
99 | net = slim.conv2d(net, 256, [3, 3], scope='conv5')
100 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool5')
101 |
102 | # Use conv2d instead of fully_connected layers.
103 | with slim.arg_scope([slim.conv2d],
104 | weights_initializer=trunc_normal(0.005),
105 | biases_initializer=tf.constant_initializer(0.1)):
106 | net = slim.conv2d(net, 4096, [5, 5], padding='VALID',
107 | scope='fc6')
108 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
109 | scope='dropout6')
110 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
111 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
112 | scope='dropout7')
113 | net = slim.conv2d(net, num_classes, [1, 1],
114 | activation_fn=None,
115 | normalizer_fn=None,
116 | biases_initializer=tf.zeros_initializer,
117 | scope='fc8')
118 |
119 | # Convert end_points_collection into an end_point dict.
120 | end_points = slim.utils.convert_collection_to_dict(end_points_collection)
121 | if spatial_squeeze:
122 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
123 | end_points[sc.name + '/fc8'] = net
124 | return net, end_points
125 | alexnet_v2.default_image_size = 224
126 |
--------------------------------------------------------------------------------
/nets/alexnet_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for slim.nets.alexnet."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import tensorflow as tf
21 |
22 | from nets import alexnet
23 |
24 | slim = tf.contrib.slim
25 |
26 |
27 | class AlexnetV2Test(tf.test.TestCase):
28 |
29 | def testBuild(self):
30 | batch_size = 5
31 | height, width = 224, 224
32 | num_classes = 1000
33 | with self.test_session():
34 | inputs = tf.random_uniform((batch_size, height, width, 3))
35 | logits, _ = alexnet.alexnet_v2(inputs, num_classes)
36 | self.assertEquals(logits.op.name, 'alexnet_v2/fc8/squeezed')
37 | self.assertListEqual(logits.get_shape().as_list(),
38 | [batch_size, num_classes])
39 |
40 | def testFullyConvolutional(self):
41 | batch_size = 1
42 | height, width = 300, 400
43 | num_classes = 1000
44 | with self.test_session():
45 | inputs = tf.random_uniform((batch_size, height, width, 3))
46 | logits, _ = alexnet.alexnet_v2(inputs, num_classes, spatial_squeeze=False)
47 | self.assertEquals(logits.op.name, 'alexnet_v2/fc8/BiasAdd')
48 | self.assertListEqual(logits.get_shape().as_list(),
49 | [batch_size, 4, 7, num_classes])
50 |
51 | def testEndPoints(self):
52 | batch_size = 5
53 | height, width = 224, 224
54 | num_classes = 1000
55 | with self.test_session():
56 | inputs = tf.random_uniform((batch_size, height, width, 3))
57 | _, end_points = alexnet.alexnet_v2(inputs, num_classes)
58 | expected_names = ['alexnet_v2/conv1',
59 | 'alexnet_v2/pool1',
60 | 'alexnet_v2/conv2',
61 | 'alexnet_v2/pool2',
62 | 'alexnet_v2/conv3',
63 | 'alexnet_v2/conv4',
64 | 'alexnet_v2/conv5',
65 | 'alexnet_v2/pool5',
66 | 'alexnet_v2/fc6',
67 | 'alexnet_v2/fc7',
68 | 'alexnet_v2/fc8'
69 | ]
70 | self.assertSetEqual(set(end_points.keys()), set(expected_names))
71 |
72 | def testModelVariables(self):
73 | batch_size = 5
74 | height, width = 224, 224
75 | num_classes = 1000
76 | with self.test_session():
77 | inputs = tf.random_uniform((batch_size, height, width, 3))
78 | alexnet.alexnet_v2(inputs, num_classes)
79 | expected_names = ['alexnet_v2/conv1/weights',
80 | 'alexnet_v2/conv1/biases',
81 | 'alexnet_v2/conv2/weights',
82 | 'alexnet_v2/conv2/biases',
83 | 'alexnet_v2/conv3/weights',
84 | 'alexnet_v2/conv3/biases',
85 | 'alexnet_v2/conv4/weights',
86 | 'alexnet_v2/conv4/biases',
87 | 'alexnet_v2/conv5/weights',
88 | 'alexnet_v2/conv5/biases',
89 | 'alexnet_v2/fc6/weights',
90 | 'alexnet_v2/fc6/biases',
91 | 'alexnet_v2/fc7/weights',
92 | 'alexnet_v2/fc7/biases',
93 | 'alexnet_v2/fc8/weights',
94 | 'alexnet_v2/fc8/biases',
95 | ]
96 | model_variables = [v.op.name for v in slim.get_model_variables()]
97 | self.assertSetEqual(set(model_variables), set(expected_names))
98 |
99 | def testEvaluation(self):
100 | batch_size = 2
101 | height, width = 224, 224
102 | num_classes = 1000
103 | with self.test_session():
104 | eval_inputs = tf.random_uniform((batch_size, height, width, 3))
105 | logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False)
106 | self.assertListEqual(logits.get_shape().as_list(),
107 | [batch_size, num_classes])
108 | predictions = tf.argmax(logits, 1)
109 | self.assertListEqual(predictions.get_shape().as_list(), [batch_size])
110 |
111 | def testTrainEvalWithReuse(self):
112 | train_batch_size = 2
113 | eval_batch_size = 1
114 | train_height, train_width = 224, 224
115 | eval_height, eval_width = 300, 400
116 | num_classes = 1000
117 | with self.test_session():
118 | train_inputs = tf.random_uniform(
119 | (train_batch_size, train_height, train_width, 3))
120 | logits, _ = alexnet.alexnet_v2(train_inputs)
121 | self.assertListEqual(logits.get_shape().as_list(),
122 | [train_batch_size, num_classes])
123 | tf.get_variable_scope().reuse_variables()
124 | eval_inputs = tf.random_uniform(
125 | (eval_batch_size, eval_height, eval_width, 3))
126 | logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False,
127 | spatial_squeeze=False)
128 | self.assertListEqual(logits.get_shape().as_list(),
129 | [eval_batch_size, 4, 7, num_classes])
130 | logits = tf.reduce_mean(logits, [1, 2])
131 | predictions = tf.argmax(logits, 1)
132 | self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size])
133 |
134 | def testForward(self):
135 | batch_size = 1
136 | height, width = 224, 224
137 | with self.test_session() as sess:
138 | inputs = tf.random_uniform((batch_size, height, width, 3))
139 | logits, _ = alexnet.alexnet_v2(inputs)
140 | sess.run(tf.initialize_all_variables())
141 | output = sess.run(logits)
142 | self.assertTrue(output.any())
143 |
144 | if __name__ == '__main__':
145 | tf.test.main()
146 |
--------------------------------------------------------------------------------
/nets/cifarnet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains a variant of the CIFAR-10 model definition."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | slim = tf.contrib.slim
24 |
25 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(stddev=stddev)
26 |
27 |
28 | def cifarnet(images, num_classes=10, is_training=False,
29 | dropout_keep_prob=0.5,
30 | prediction_fn=slim.softmax,
31 | scope='CifarNet'):
32 | """Creates a variant of the CifarNet model.
33 |
34 | Note that since the output is a set of 'logits', the values fall in the
35 | interval of (-infinity, infinity). Consequently, to convert the outputs to a
36 | probability distribution over the classes, one will need to convert them
37 | using the softmax function:
38 |
39 | logits = cifarnet.cifarnet(images, is_training=False)
40 | probabilities = tf.nn.softmax(logits)
41 | predictions = tf.argmax(logits, 1)
42 |
43 | Args:
44 | images: A batch of `Tensors` of size [batch_size, height, width, channels].
45 | num_classes: the number of classes in the dataset.
46 | is_training: specifies whether or not we're currently training the model.
47 | This variable will determine the behaviour of the dropout layer.
48 | dropout_keep_prob: the percentage of activation values that are retained.
49 | prediction_fn: a function to get predictions out of logits.
50 | scope: Optional variable_scope.
51 |
52 | Returns:
53 | logits: the pre-softmax activations, a tensor of size
54 | [batch_size, `num_classes`]
55 | end_points: a dictionary from components of the network to the corresponding
56 | activation.
57 | """
58 | end_points = {}
59 |
60 | with tf.variable_scope(scope, 'CifarNet', [images, num_classes]):
61 | net = slim.conv2d(images, 64, [5, 5], scope='conv1')
62 | end_points['conv1'] = net
63 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
64 | end_points['pool1'] = net
65 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm1')
66 | net = slim.conv2d(net, 64, [5, 5], scope='conv2')
67 | end_points['conv2'] = net
68 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm2')
69 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
70 | end_points['pool2'] = net
71 | net = slim.flatten(net)
72 | end_points['Flatten'] = net
73 | net = slim.fully_connected(net, 384, scope='fc3')
74 | end_points['fc3'] = net
75 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
76 | scope='dropout3')
77 | net = slim.fully_connected(net, 192, scope='fc4')
78 | end_points['fc4'] = net
79 | logits = slim.fully_connected(net, num_classes,
80 | biases_initializer=tf.zeros_initializer,
81 | weights_initializer=trunc_normal(1/192.0),
82 | weights_regularizer=None,
83 | activation_fn=None,
84 | scope='logits')
85 |
86 | end_points['Logits'] = logits
87 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
88 |
89 | return logits, end_points
90 | cifarnet.default_image_size = 32
91 |
92 |
93 | def cifarnet_arg_scope(weight_decay=0.004):
94 | """Defines the default cifarnet argument scope.
95 |
96 | Args:
97 | weight_decay: The weight decay to use for regularizing the model.
98 |
99 | Returns:
100 | An `arg_scope` to use for the CifarNet model.
101 | """
102 | with slim.arg_scope(
103 | [slim.conv2d],
104 | weights_initializer=tf.truncated_normal_initializer(stddev=5e-2),
105 | activation_fn=tf.nn.relu):
106 | with slim.arg_scope(
107 | [slim.fully_connected],
108 | biases_initializer=tf.constant_initializer(0.1),
109 | weights_initializer=trunc_normal(0.04),
110 | weights_regularizer=slim.l2_regularizer(weight_decay),
111 | activation_fn=tf.nn.relu) as sc:
112 | return sc
113 |
--------------------------------------------------------------------------------
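
A minimal usage sketch for the CifarNet definition above, assuming a TensorFlow 1.x environment where tf.contrib.slim is available; the placeholder input below is illustrative and not part of this repository:

import tensorflow as tf
from nets import cifarnet

slim = tf.contrib.slim

# CIFAR-10 sized inputs; cifarnet.default_image_size is 32.
images = tf.placeholder(tf.float32, [None, 32, 32, 3])
with slim.arg_scope(cifarnet.cifarnet_arg_scope(weight_decay=0.004)):
    # is_training=False disables the dropout layers for inference.
    logits, end_points = cifarnet.cifarnet(images, num_classes=10, is_training=False)

# As the docstring notes, the logits are unbounded; softmax yields class probabilities.
probabilities = tf.nn.softmax(logits)
predictions = tf.argmax(logits, 1)
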
/nets/inception.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Brings all inception models under one namespace."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | # pylint: disable=unused-import
22 | from nets.inception_resnet_v2 import inception_resnet_v2
23 | from nets.inception_resnet_v2 import inception_resnet_v2_arg_scope
24 | from nets.inception_v1 import inception_v1
25 | from nets.inception_v1 import inception_v1_arg_scope
26 | from nets.inception_v1 import inception_v1_base
27 | from nets.inception_v2 import inception_v2
28 | from nets.inception_v2 import inception_v2_arg_scope
29 | from nets.inception_v2 import inception_v2_base
30 | from nets.inception_v3 import inception_v3
31 | from nets.inception_v3 import inception_v3_arg_scope
32 | from nets.inception_v3 import inception_v3_base
33 | from nets.inception_v4 import inception_v4
34 | from nets.inception_v4 import inception_v4_arg_scope
35 | from nets.inception_v4 import inception_v4_base
36 | # pylint: enable=unused-import
37 |
--------------------------------------------------------------------------------
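
Because every model function and its matching arg_scope are re-exported here, callers can select a network by name from a single module. The table below is a hypothetical excerpt of that pattern (the repository's actual lookup lives in nets/nets_factory.py), again assuming TensorFlow 1.x with tf.contrib.slim:

import tensorflow as tf
from nets import inception

slim = tf.contrib.slim

# Hypothetical name -> (model_fn, arg_scope_fn) table; see nets/nets_factory.py for the real one.
networks = {
    'inception_v1': (inception.inception_v1, inception.inception_v1_arg_scope),
    'inception_v3': (inception.inception_v3, inception.inception_v3_arg_scope),
    'inception_resnet_v2': (inception.inception_resnet_v2,
                            inception.inception_resnet_v2_arg_scope),
}

def build(name, images, num_classes, is_training=False):
    model_fn, arg_scope_fn = networks[name]
    with slim.arg_scope(arg_scope_fn()):
        return model_fn(images, num_classes, is_training=is_training)

images = tf.random_uniform([1, 299, 299, 3])
logits, end_points = build('inception_v3', images, num_classes=1001)
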
/nets/inception_resnet_v2.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains the definition of the Inception Resnet V2 architecture.
16 |
17 | As described in http://arxiv.org/abs/1602.07261.
18 |
19 | Inception-v4, Inception-ResNet and the Impact of Residual Connections
20 | on Learning
21 | Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, Alex Alemi
22 | """
23 | from __future__ import absolute_import
24 | from __future__ import division
25 | from __future__ import print_function
26 |
27 |
28 | import tensorflow as tf
29 |
30 | slim = tf.contrib.slim
31 |
32 |
33 | def block35(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
34 | """Builds the 35x35 resnet block."""
35 | with tf.variable_scope(scope, 'Block35', [net], reuse=reuse):
36 | with tf.variable_scope('Branch_0'):
37 | tower_conv = slim.conv2d(net, 32, 1, scope='Conv2d_1x1')
38 | with tf.variable_scope('Branch_1'):
39 | tower_conv1_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1')
40 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 32, 3, scope='Conv2d_0b_3x3')
41 | with tf.variable_scope('Branch_2'):
42 | tower_conv2_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1')
43 | tower_conv2_1 = slim.conv2d(tower_conv2_0, 48, 3, scope='Conv2d_0b_3x3')
44 | tower_conv2_2 = slim.conv2d(tower_conv2_1, 64, 3, scope='Conv2d_0c_3x3')
45 | mixed = tf.concat(3, [tower_conv, tower_conv1_1, tower_conv2_2])
46 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
47 | activation_fn=None, scope='Conv2d_1x1')
48 | net += scale * up
49 | if activation_fn:
50 | net = activation_fn(net)
51 | return net
52 |
53 |
54 | def block17(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
55 | """Builds the 17x17 resnet block."""
56 | with tf.variable_scope(scope, 'Block17', [net], reuse=reuse):
57 | with tf.variable_scope('Branch_0'):
58 | tower_conv = slim.conv2d(net, 192, 1, scope='Conv2d_1x1')
59 | with tf.variable_scope('Branch_1'):
60 | tower_conv1_0 = slim.conv2d(net, 128, 1, scope='Conv2d_0a_1x1')
61 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 160, [1, 7],
62 | scope='Conv2d_0b_1x7')
63 | tower_conv1_2 = slim.conv2d(tower_conv1_1, 192, [7, 1],
64 | scope='Conv2d_0c_7x1')
65 | mixed = tf.concat(3, [tower_conv, tower_conv1_2])
66 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
67 | activation_fn=None, scope='Conv2d_1x1')
68 | net += scale * up
69 | if activation_fn:
70 | net = activation_fn(net)
71 | return net
72 |
73 |
74 | def block8(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
75 | """Builds the 8x8 resnet block."""
76 | with tf.variable_scope(scope, 'Block8', [net], reuse=reuse):
77 | with tf.variable_scope('Branch_0'):
78 | tower_conv = slim.conv2d(net, 192, 1, scope='Conv2d_1x1')
79 | with tf.variable_scope('Branch_1'):
80 | tower_conv1_0 = slim.conv2d(net, 192, 1, scope='Conv2d_0a_1x1')
81 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 224, [1, 3],
82 | scope='Conv2d_0b_1x3')
83 | tower_conv1_2 = slim.conv2d(tower_conv1_1, 256, [3, 1],
84 | scope='Conv2d_0c_3x1')
85 | mixed = tf.concat(3, [tower_conv, tower_conv1_2])
86 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
87 | activation_fn=None, scope='Conv2d_1x1')
88 | net += scale * up
89 | if activation_fn:
90 | net = activation_fn(net)
91 | return net
92 |
93 |
94 | def inception_resnet_v2(inputs, num_classes=1001, is_training=True,
95 | dropout_keep_prob=0.8,
96 | reuse=None,
97 | scope='InceptionResnetV2'):
98 | """Creates the Inception Resnet V2 model.
99 |
100 | Args:
101 | inputs: a 4-D tensor of size [batch_size, height, width, 3].
102 | num_classes: number of predicted classes.
103 | is_training: whether the network is being trained or not.
104 | dropout_keep_prob: float, the fraction to keep before final layer.
105 | reuse: whether or not the network and its variables should be reused. To be
106 | able to reuse them, 'scope' must be given.
107 | scope: Optional variable_scope.
108 |
109 | Returns:
110 | logits: the logits outputs of the model.
111 | end_points: the set of end_points from the inception model.
112 | """
113 | end_points = {}
114 |
115 | with tf.variable_scope(scope, 'InceptionResnetV2', [inputs], reuse=reuse):
116 | with slim.arg_scope([slim.batch_norm, slim.dropout],
117 | is_training=is_training):
118 | with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
119 | stride=1, padding='SAME'):
120 |
121 | # 149 x 149 x 32
122 | net = slim.conv2d(inputs, 32, 3, stride=2, padding='VALID',
123 | scope='Conv2d_1a_3x3')
124 | end_points['Conv2d_1a_3x3'] = net
125 | # 147 x 147 x 32
126 | net = slim.conv2d(net, 32, 3, padding='VALID',
127 | scope='Conv2d_2a_3x3')
128 | end_points['Conv2d_2a_3x3'] = net
129 | # 147 x 147 x 64
130 | net = slim.conv2d(net, 64, 3, scope='Conv2d_2b_3x3')
131 | end_points['Conv2d_2b_3x3'] = net
132 | # 73 x 73 x 64
133 | net = slim.max_pool2d(net, 3, stride=2, padding='VALID',
134 | scope='MaxPool_3a_3x3')
135 | end_points['MaxPool_3a_3x3'] = net
136 | # 73 x 73 x 80
137 | net = slim.conv2d(net, 80, 1, padding='VALID',
138 | scope='Conv2d_3b_1x1')
139 | end_points['Conv2d_3b_1x1'] = net
140 | # 71 x 71 x 192
141 | net = slim.conv2d(net, 192, 3, padding='VALID',
142 | scope='Conv2d_4a_3x3')
143 | end_points['Conv2d_4a_3x3'] = net
144 | # 35 x 35 x 192
145 | net = slim.max_pool2d(net, 3, stride=2, padding='VALID',
146 | scope='MaxPool_5a_3x3')
147 | end_points['MaxPool_5a_3x3'] = net
148 |
149 | # 35 x 35 x 320
150 | with tf.variable_scope('Mixed_5b'):
151 | with tf.variable_scope('Branch_0'):
152 | tower_conv = slim.conv2d(net, 96, 1, scope='Conv2d_1x1')
153 | with tf.variable_scope('Branch_1'):
154 | tower_conv1_0 = slim.conv2d(net, 48, 1, scope='Conv2d_0a_1x1')
155 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 64, 5,
156 | scope='Conv2d_0b_5x5')
157 | with tf.variable_scope('Branch_2'):
158 | tower_conv2_0 = slim.conv2d(net, 64, 1, scope='Conv2d_0a_1x1')
159 | tower_conv2_1 = slim.conv2d(tower_conv2_0, 96, 3,
160 | scope='Conv2d_0b_3x3')
161 | tower_conv2_2 = slim.conv2d(tower_conv2_1, 96, 3,
162 | scope='Conv2d_0c_3x3')
163 | with tf.variable_scope('Branch_3'):
164 | tower_pool = slim.avg_pool2d(net, 3, stride=1, padding='SAME',
165 | scope='AvgPool_0a_3x3')
166 | tower_pool_1 = slim.conv2d(tower_pool, 64, 1,
167 | scope='Conv2d_0b_1x1')
168 | net = tf.concat(3, [tower_conv, tower_conv1_1,
169 | tower_conv2_2, tower_pool_1])
170 |
171 | end_points['Mixed_5b'] = net
172 | net = slim.repeat(net, 10, block35, scale=0.17)
173 |
174 | # 17 x 17 x 1088
175 | with tf.variable_scope('Mixed_6a'):
176 | with tf.variable_scope('Branch_0'):
177 | tower_conv = slim.conv2d(net, 384, 3, stride=2, padding='VALID',
178 | scope='Conv2d_1a_3x3')
179 | with tf.variable_scope('Branch_1'):
180 | tower_conv1_0 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
181 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 256, 3,
182 | scope='Conv2d_0b_3x3')
183 | tower_conv1_2 = slim.conv2d(tower_conv1_1, 384, 3,
184 | stride=2, padding='VALID',
185 | scope='Conv2d_1a_3x3')
186 | with tf.variable_scope('Branch_2'):
187 | tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID',
188 | scope='MaxPool_1a_3x3')
189 | net = tf.concat(3, [tower_conv, tower_conv1_2, tower_pool])
190 |
191 | end_points['Mixed_6a'] = net
192 | net = slim.repeat(net, 20, block17, scale=0.10)
193 |
194 | # Auxiliary tower
195 | with tf.variable_scope('AuxLogits'):
196 | aux = slim.avg_pool2d(net, 5, stride=3, padding='VALID',
197 | scope='Conv2d_1a_3x3')
198 | aux = slim.conv2d(aux, 128, 1, scope='Conv2d_1b_1x1')
199 | aux = slim.conv2d(aux, 768, aux.get_shape()[1:3],
200 | padding='VALID', scope='Conv2d_2a_5x5')
201 | aux = slim.flatten(aux)
202 | aux = slim.fully_connected(aux, num_classes, activation_fn=None,
203 | scope='Logits')
204 | end_points['AuxLogits'] = aux
205 |
206 | with tf.variable_scope('Mixed_7a'):
207 | with tf.variable_scope('Branch_0'):
208 | tower_conv = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
209 | tower_conv_1 = slim.conv2d(tower_conv, 384, 3, stride=2,
210 | padding='VALID', scope='Conv2d_1a_3x3')
211 | with tf.variable_scope('Branch_1'):
212 | tower_conv1 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
213 | tower_conv1_1 = slim.conv2d(tower_conv1, 288, 3, stride=2,
214 | padding='VALID', scope='Conv2d_1a_3x3')
215 | with tf.variable_scope('Branch_2'):
216 | tower_conv2 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
217 | tower_conv2_1 = slim.conv2d(tower_conv2, 288, 3,
218 | scope='Conv2d_0b_3x3')
219 | tower_conv2_2 = slim.conv2d(tower_conv2_1, 320, 3, stride=2,
220 | padding='VALID', scope='Conv2d_1a_3x3')
221 | with tf.variable_scope('Branch_3'):
222 | tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID',
223 | scope='MaxPool_1a_3x3')
224 | net = tf.concat(3, [tower_conv_1, tower_conv1_1,
225 | tower_conv2_2, tower_pool])
226 |
227 | end_points['Mixed_7a'] = net
228 |
229 | net = slim.repeat(net, 9, block8, scale=0.20)
230 | net = block8(net, activation_fn=None)
231 |
232 | net = slim.conv2d(net, 1536, 1, scope='Conv2d_7b_1x1')
233 | end_points['Conv2d_7b_1x1'] = net
234 |
235 | with tf.variable_scope('Logits'):
236 | end_points['PrePool'] = net
237 | net = slim.avg_pool2d(net, net.get_shape()[1:3], padding='VALID',
238 | scope='AvgPool_1a_8x8')
239 | net = slim.flatten(net)
240 |
241 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
242 | scope='Dropout')
243 |
244 | end_points['PreLogitsFlatten'] = net
245 | logits = slim.fully_connected(net, num_classes, activation_fn=None,
246 | scope='Logits')
247 | end_points['Logits'] = logits
248 | end_points['Predictions'] = tf.nn.softmax(logits, name='Predictions')
249 |
250 | return logits, end_points
251 | inception_resnet_v2.default_image_size = 299
252 |
253 |
254 | def inception_resnet_v2_arg_scope(weight_decay=0.00004,
255 | batch_norm_decay=0.9997,
256 | batch_norm_epsilon=0.001):
257 | """Yields the scope with the default parameters for inception_resnet_v2.
258 |
259 | Args:
260 | weight_decay: the weight decay for weights variables.
261 | batch_norm_decay: decay for the moving averages of batch norm (means and variances).
262 | batch_norm_epsilon: small float added to variance to avoid dividing by zero.
263 |
264 | Returns:
265 | an arg_scope with the parameters needed for inception_resnet_v2.
266 | """
267 | # Set weight_decay for weights in conv2d and fully_connected layers.
268 | with slim.arg_scope([slim.conv2d, slim.fully_connected],
269 | weights_regularizer=slim.l2_regularizer(weight_decay),
270 | biases_regularizer=slim.l2_regularizer(weight_decay)):
271 |
272 | batch_norm_params = {
273 | 'decay': batch_norm_decay,
274 | 'epsilon': batch_norm_epsilon,
275 | }
276 | # Set activation_fn and parameters for batch_norm.
277 | with slim.arg_scope([slim.conv2d], activation_fn=tf.nn.relu,
278 | normalizer_fn=slim.batch_norm,
279 | normalizer_params=batch_norm_params) as scope:
280 | return scope
281 |
--------------------------------------------------------------------------------
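
The end_points dictionary returned above makes intermediate activations addressable by name, which is convenient when the network is used as a feature extractor rather than a classifier. A minimal inference sketch, assuming TensorFlow 1.x with tf.contrib.slim and synthetic inputs:

import tensorflow as tf
from nets import inception_resnet_v2

slim = tf.contrib.slim

images = tf.random_uniform([2, 299, 299, 3])  # 299 is inception_resnet_v2.default_image_size
with slim.arg_scope(inception_resnet_v2.inception_resnet_v2_arg_scope()):
    # is_training=False keeps batch norm in inference mode and disables dropout.
    logits, end_points = inception_resnet_v2.inception_resnet_v2(
        images, num_classes=1001, is_training=False)

# Intermediate activations by name, e.g. the 8x8x1536 pre-pool feature map.
features = end_points['PrePool']
probabilities = end_points['Predictions']
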
/nets/inception_resnet_v2_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for slim.inception_resnet_v2."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import tensorflow as tf
21 |
22 | from nets import inception
23 |
24 |
25 | class InceptionTest(tf.test.TestCase):
26 |
27 | def testBuildLogits(self):
28 | batch_size = 5
29 | height, width = 299, 299
30 | num_classes = 1000
31 | with self.test_session():
32 | inputs = tf.random_uniform((batch_size, height, width, 3))
33 | logits, _ = inception.inception_resnet_v2(inputs, num_classes)
34 | self.assertTrue(logits.op.name.startswith('InceptionResnetV2/Logits'))
35 | self.assertListEqual(logits.get_shape().as_list(),
36 | [batch_size, num_classes])
37 |
38 | def testBuildEndPoints(self):
39 | batch_size = 5
40 | height, width = 299, 299
41 | num_classes = 1000
42 | with self.test_session():
43 | inputs = tf.random_uniform((batch_size, height, width, 3))
44 | _, end_points = inception.inception_resnet_v2(inputs, num_classes)
45 | self.assertTrue('Logits' in end_points)
46 | logits = end_points['Logits']
47 | self.assertListEqual(logits.get_shape().as_list(),
48 | [batch_size, num_classes])
49 | self.assertTrue('AuxLogits' in end_points)
50 | aux_logits = end_points['AuxLogits']
51 | self.assertListEqual(aux_logits.get_shape().as_list(),
52 | [batch_size, num_classes])
53 | pre_pool = end_points['PrePool']
54 | self.assertListEqual(pre_pool.get_shape().as_list(),
55 | [batch_size, 8, 8, 1536])
56 |
57 | def testVariablesSetDevice(self):
58 | batch_size = 5
59 | height, width = 299, 299
60 | num_classes = 1000
61 | with self.test_session():
62 | inputs = tf.random_uniform((batch_size, height, width, 3))
63 | # Force all Variables to reside on the device.
64 | with tf.variable_scope('on_cpu'), tf.device('/cpu:0'):
65 | inception.inception_resnet_v2(inputs, num_classes)
66 | with tf.variable_scope('on_gpu'), tf.device('/gpu:0'):
67 | inception.inception_resnet_v2(inputs, num_classes)
68 | for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_cpu'):
69 | self.assertDeviceEqual(v.device, '/cpu:0')
70 | for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_gpu'):
71 | self.assertDeviceEqual(v.device, '/gpu:0')
72 |
73 | def testHalfSizeImages(self):
74 | batch_size = 5
75 | height, width = 150, 150
76 | num_classes = 1000
77 | with self.test_session():
78 | inputs = tf.random_uniform((batch_size, height, width, 3))
79 | logits, end_points = inception.inception_resnet_v2(inputs, num_classes)
80 | self.assertTrue(logits.op.name.startswith('InceptionResnetV2/Logits'))
81 | self.assertListEqual(logits.get_shape().as_list(),
82 | [batch_size, num_classes])
83 | pre_pool = end_points['PrePool']
84 | self.assertListEqual(pre_pool.get_shape().as_list(),
85 | [batch_size, 3, 3, 1536])
86 |
87 | def testUnknownBatchSize(self):
88 | batch_size = 1
89 | height, width = 299, 299
90 | num_classes = 1000
91 | with self.test_session() as sess:
92 | inputs = tf.placeholder(tf.float32, (None, height, width, 3))
93 | logits, _ = inception.inception_resnet_v2(inputs, num_classes)
94 | self.assertTrue(logits.op.name.startswith('InceptionResnetV2/Logits'))
95 | self.assertListEqual(logits.get_shape().as_list(),
96 | [None, num_classes])
97 | images = tf.random_uniform((batch_size, height, width, 3))
98 | sess.run(tf.initialize_all_variables())
99 | output = sess.run(logits, {inputs: images.eval()})
100 | self.assertEquals(output.shape, (batch_size, num_classes))
101 |
102 | def testEvaluation(self):
103 | batch_size = 2
104 | height, width = 299, 299
105 | num_classes = 1000
106 | with self.test_session() as sess:
107 | eval_inputs = tf.random_uniform((batch_size, height, width, 3))
108 | logits, _ = inception.inception_resnet_v2(eval_inputs,
109 | num_classes,
110 | is_training=False)
111 | predictions = tf.argmax(logits, 1)
112 | sess.run(tf.initialize_all_variables())
113 | output = sess.run(predictions)
114 | self.assertEquals(output.shape, (batch_size,))
115 |
116 | def testTrainEvalWithReuse(self):
117 | train_batch_size = 5
118 | eval_batch_size = 2
119 | height, width = 150, 150
120 | num_classes = 1000
121 | with self.test_session() as sess:
122 | train_inputs = tf.random_uniform((train_batch_size, height, width, 3))
123 | inception.inception_resnet_v2(train_inputs, num_classes)
124 | eval_inputs = tf.random_uniform((eval_batch_size, height, width, 3))
125 | logits, _ = inception.inception_resnet_v2(eval_inputs,
126 | num_classes,
127 | is_training=False,
128 | reuse=True)
129 | predictions = tf.argmax(logits, 1)
130 | sess.run(tf.initialize_all_variables())
131 | output = sess.run(predictions)
132 | self.assertEquals(output.shape, (eval_batch_size,))
133 |
134 |
135 | if __name__ == '__main__':
136 | tf.test.main()
137 |
--------------------------------------------------------------------------------
/nets/inception_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains common code shared by all inception models.
16 |
17 | Usage of arg scope:
18 | with slim.arg_scope(inception_arg_scope()):
19 | logits, end_points = inception.inception_v3(images, num_classes,
20 | is_training=is_training)
21 |
22 | """
23 | from __future__ import absolute_import
24 | from __future__ import division
25 | from __future__ import print_function
26 |
27 | import tensorflow as tf
28 |
29 | slim = tf.contrib.slim
30 |
31 |
32 | def inception_arg_scope(weight_decay=0.00004,
33 | use_batch_norm=True,
34 | batch_norm_decay=0.9997,
35 | batch_norm_epsilon=0.001):
36 | """Defines the default arg scope for inception models.
37 |
38 | Args:
39 | weight_decay: The weight decay to use for regularizing the model.
40 | use_batch_norm: If `True`, batch_norm is applied after each convolution.
41 | batch_norm_decay: Decay for batch norm moving average.
42 | batch_norm_epsilon: Small float added to variance to avoid dividing by zero
43 | in batch norm.
44 |
45 | Returns:
46 | An `arg_scope` to use for the inception models.
47 | """
48 | batch_norm_params = {
49 | # Decay for the moving averages.
50 | 'decay': batch_norm_decay,
51 | # epsilon to prevent 0s in variance.
52 | 'epsilon': batch_norm_epsilon,
53 | # collection containing update_ops.
54 | 'updates_collections': tf.GraphKeys.UPDATE_OPS,
55 | }
56 | if use_batch_norm:
57 | normalizer_fn = slim.batch_norm
58 | normalizer_params = batch_norm_params
59 | else:
60 | normalizer_fn = None
61 | normalizer_params = {}
62 | # Set weight_decay for weights in Conv and FC layers.
63 | with slim.arg_scope([slim.conv2d, slim.fully_connected],
64 | weights_regularizer=slim.l2_regularizer(weight_decay)):
65 | with slim.arg_scope(
66 | [slim.conv2d],
67 | weights_initializer=slim.variance_scaling_initializer(),
68 | activation_fn=tf.nn.relu,
69 | normalizer_fn=normalizer_fn,
70 | normalizer_params=normalizer_params) as sc:
71 | return sc
72 |
--------------------------------------------------------------------------------
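
Note that inception_arg_scope registers the batch norm update ops in tf.GraphKeys.UPDATE_OPS (the 'updates_collections' entry above), so a training loop must run those ops or the moving averages are never refreshed. A minimal training sketch, assuming TensorFlow 1.x with tf.contrib.slim and purely synthetic data:

import tensorflow as tf
from nets import inception, inception_utils

slim = tf.contrib.slim

images = tf.random_uniform([8, 224, 224, 3])
onehot_labels = tf.one_hot(tf.zeros([8], dtype=tf.int32), depth=1000)

with slim.arg_scope(inception_utils.inception_arg_scope(weight_decay=0.00004)):
    logits, _ = inception.inception_v1(images, num_classes=1000, is_training=True)

loss = tf.losses.softmax_cross_entropy(onehot_labels=onehot_labels, logits=logits)

# Make the train op depend on UPDATE_OPS so the batch norm moving averages get updated.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
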
/nets/inception_v1_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for nets.inception_v1."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import numpy as np
22 | import tensorflow as tf
23 |
24 | from nets import inception
25 |
26 | slim = tf.contrib.slim
27 |
28 |
29 | class InceptionV1Test(tf.test.TestCase):
30 |
31 | def testBuildClassificationNetwork(self):
32 | batch_size = 5
33 | height, width = 224, 224
34 | num_classes = 1000
35 |
36 | inputs = tf.random_uniform((batch_size, height, width, 3))
37 | logits, end_points = inception.inception_v1(inputs, num_classes)
38 | self.assertTrue(logits.op.name.startswith('InceptionV1/Logits'))
39 | self.assertListEqual(logits.get_shape().as_list(),
40 | [batch_size, num_classes])
41 | self.assertTrue('Predictions' in end_points)
42 | self.assertListEqual(end_points['Predictions'].get_shape().as_list(),
43 | [batch_size, num_classes])
44 |
45 | def testBuildBaseNetwork(self):
46 | batch_size = 5
47 | height, width = 224, 224
48 |
49 | inputs = tf.random_uniform((batch_size, height, width, 3))
50 | mixed_5c, end_points = inception.inception_v1_base(inputs)
51 | self.assertTrue(mixed_5c.op.name.startswith('InceptionV1/Mixed_5c'))
52 | self.assertListEqual(mixed_5c.get_shape().as_list(),
53 | [batch_size, 7, 7, 1024])
54 | expected_endpoints = ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1',
55 | 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b',
56 | 'Mixed_3c', 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c',
57 | 'Mixed_4d', 'Mixed_4e', 'Mixed_4f', 'MaxPool_5a_2x2',
58 | 'Mixed_5b', 'Mixed_5c']
59 | self.assertItemsEqual(end_points.keys(), expected_endpoints)
60 |
61 | def testBuildOnlyUptoFinalEndpoint(self):
62 | batch_size = 5
63 | height, width = 224, 224
64 | endpoints = ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1',
65 | 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c',
66 | 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d',
67 | 'Mixed_4e', 'Mixed_4f', 'MaxPool_5a_2x2', 'Mixed_5b',
68 | 'Mixed_5c']
69 | for index, endpoint in enumerate(endpoints):
70 | with tf.Graph().as_default():
71 | inputs = tf.random_uniform((batch_size, height, width, 3))
72 | out_tensor, end_points = inception.inception_v1_base(
73 | inputs, final_endpoint=endpoint)
74 | self.assertTrue(out_tensor.op.name.startswith(
75 | 'InceptionV1/' + endpoint))
76 | self.assertItemsEqual(endpoints[:index+1], end_points)
77 |
78 | def testBuildAndCheckAllEndPointsUptoMixed5c(self):
79 | batch_size = 5
80 | height, width = 224, 224
81 |
82 | inputs = tf.random_uniform((batch_size, height, width, 3))
83 | _, end_points = inception.inception_v1_base(inputs,
84 | final_endpoint='Mixed_5c')
85 | endpoints_shapes = {'Conv2d_1a_7x7': [5, 112, 112, 64],
86 | 'MaxPool_2a_3x3': [5, 56, 56, 64],
87 | 'Conv2d_2b_1x1': [5, 56, 56, 64],
88 | 'Conv2d_2c_3x3': [5, 56, 56, 192],
89 | 'MaxPool_3a_3x3': [5, 28, 28, 192],
90 | 'Mixed_3b': [5, 28, 28, 256],
91 | 'Mixed_3c': [5, 28, 28, 480],
92 | 'MaxPool_4a_3x3': [5, 14, 14, 480],
93 | 'Mixed_4b': [5, 14, 14, 512],
94 | 'Mixed_4c': [5, 14, 14, 512],
95 | 'Mixed_4d': [5, 14, 14, 512],
96 | 'Mixed_4e': [5, 14, 14, 528],
97 | 'Mixed_4f': [5, 14, 14, 832],
98 | 'MaxPool_5a_2x2': [5, 7, 7, 832],
99 | 'Mixed_5b': [5, 7, 7, 832],
100 | 'Mixed_5c': [5, 7, 7, 1024]}
101 |
102 | self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys())
103 | for endpoint_name in endpoints_shapes:
104 | expected_shape = endpoints_shapes[endpoint_name]
105 | self.assertTrue(endpoint_name in end_points)
106 | self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
107 | expected_shape)
108 |
109 | def testModelHasExpectedNumberOfParameters(self):
110 | batch_size = 5
111 | height, width = 224, 224
112 | inputs = tf.random_uniform((batch_size, height, width, 3))
113 | with slim.arg_scope(inception.inception_v1_arg_scope()):
114 | inception.inception_v1_base(inputs)
115 | total_params, _ = slim.model_analyzer.analyze_vars(
116 | slim.get_model_variables())
117 | self.assertAlmostEqual(5607184, total_params)
118 |
119 | def testHalfSizeImages(self):
120 | batch_size = 5
121 | height, width = 112, 112
122 |
123 | inputs = tf.random_uniform((batch_size, height, width, 3))
124 | mixed_5c, _ = inception.inception_v1_base(inputs)
125 | self.assertTrue(mixed_5c.op.name.startswith('InceptionV1/Mixed_5c'))
126 | self.assertListEqual(mixed_5c.get_shape().as_list(),
127 | [batch_size, 4, 4, 1024])
128 |
129 | def testUnknownImageShape(self):
130 | tf.reset_default_graph()
131 | batch_size = 2
132 | height, width = 224, 224
133 | num_classes = 1000
134 | input_np = np.random.uniform(0, 1, (batch_size, height, width, 3))
135 | with self.test_session() as sess:
136 | inputs = tf.placeholder(tf.float32, shape=(batch_size, None, None, 3))
137 | logits, end_points = inception.inception_v1(inputs, num_classes)
138 | self.assertTrue(logits.op.name.startswith('InceptionV1/Logits'))
139 | self.assertListEqual(logits.get_shape().as_list(),
140 | [batch_size, num_classes])
141 | pre_pool = end_points['Mixed_5c']
142 | feed_dict = {inputs: input_np}
143 | tf.initialize_all_variables().run()
144 | pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict)
145 | self.assertListEqual(list(pre_pool_out.shape), [batch_size, 7, 7, 1024])
146 |
147 | def testUnknowBatchSize(self):
148 | batch_size = 1
149 | height, width = 224, 224
150 | num_classes = 1000
151 |
152 | inputs = tf.placeholder(tf.float32, (None, height, width, 3))
153 | logits, _ = inception.inception_v1(inputs, num_classes)
154 | self.assertTrue(logits.op.name.startswith('InceptionV1/Logits'))
155 | self.assertListEqual(logits.get_shape().as_list(),
156 | [None, num_classes])
157 | images = tf.random_uniform((batch_size, height, width, 3))
158 |
159 | with self.test_session() as sess:
160 | sess.run(tf.initialize_all_variables())
161 | output = sess.run(logits, {inputs: images.eval()})
162 | self.assertEquals(output.shape, (batch_size, num_classes))
163 |
164 | def testEvaluation(self):
165 | batch_size = 2
166 | height, width = 224, 224
167 | num_classes = 1000
168 |
169 | eval_inputs = tf.random_uniform((batch_size, height, width, 3))
170 | logits, _ = inception.inception_v1(eval_inputs, num_classes,
171 | is_training=False)
172 | predictions = tf.argmax(logits, 1)
173 |
174 | with self.test_session() as sess:
175 | sess.run(tf.initialize_all_variables())
176 | output = sess.run(predictions)
177 | self.assertEquals(output.shape, (batch_size,))
178 |
179 | def testTrainEvalWithReuse(self):
180 | train_batch_size = 5
181 | eval_batch_size = 2
182 | height, width = 224, 224
183 | num_classes = 1000
184 |
185 | train_inputs = tf.random_uniform((train_batch_size, height, width, 3))
186 | inception.inception_v1(train_inputs, num_classes)
187 | eval_inputs = tf.random_uniform((eval_batch_size, height, width, 3))
188 | logits, _ = inception.inception_v1(eval_inputs, num_classes, reuse=True)
189 | predictions = tf.argmax(logits, 1)
190 |
191 | with self.test_session() as sess:
192 | sess.run(tf.initialize_all_variables())
193 | output = sess.run(predictions)
194 | self.assertEquals(output.shape, (eval_batch_size,))
195 |
196 | def testLogitsNotSqueezed(self):
197 | num_classes = 25
198 | images = tf.random_uniform([1, 224, 224, 3])
199 | logits, _ = inception.inception_v1(images,
200 | num_classes=num_classes,
201 | spatial_squeeze=False)
202 |
203 | with self.test_session() as sess:
204 | tf.initialize_all_variables().run()
205 | logits_out = sess.run(logits)
206 | self.assertListEqual(list(logits_out.shape), [1, 1, 1, num_classes])
207 |
208 |
209 | if __name__ == '__main__':
210 | tf.test.main()
211 |
--------------------------------------------------------------------------------
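
The testBuildOnlyUptoFinalEndpoint cases above exercise a useful feature of inception_v1_base: graph construction can be stopped at any named endpoint, which turns the network into a mid-level feature extractor. A minimal sketch, assuming TensorFlow 1.x with tf.contrib.slim:

import tensorflow as tf
from nets import inception

slim = tf.contrib.slim

images = tf.random_uniform([4, 224, 224, 3])
with slim.arg_scope(inception.inception_v1_arg_scope()):
    # Stop at Mixed_4f; no pooling, dropout or logits layers are created.
    features, end_points = inception.inception_v1_base(
        images, final_endpoint='Mixed_4f')

# Per the shape table in the test above, Mixed_4f is a 14x14x832 feature map.
print(features.get_shape())  # (4, 14, 14, 832)
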
/nets/inception_v2_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for nets.inception_v2."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import numpy as np
22 | import tensorflow as tf
23 |
24 | from nets import inception
25 |
26 | slim = tf.contrib.slim
27 |
28 |
29 | class InceptionV2Test(tf.test.TestCase):
30 |
31 | def testBuildClassificationNetwork(self):
32 | batch_size = 5
33 | height, width = 224, 224
34 | num_classes = 1000
35 |
36 | inputs = tf.random_uniform((batch_size, height, width, 3))
37 | logits, end_points = inception.inception_v2(inputs, num_classes)
38 | self.assertTrue(logits.op.name.startswith('InceptionV2/Logits'))
39 | self.assertListEqual(logits.get_shape().as_list(),
40 | [batch_size, num_classes])
41 | self.assertTrue('Predictions' in end_points)
42 | self.assertListEqual(end_points['Predictions'].get_shape().as_list(),
43 | [batch_size, num_classes])
44 |
45 | def testBuildBaseNetwork(self):
46 | batch_size = 5
47 | height, width = 224, 224
48 |
49 | inputs = tf.random_uniform((batch_size, height, width, 3))
50 | mixed_5c, end_points = inception.inception_v2_base(inputs)
51 | self.assertTrue(mixed_5c.op.name.startswith('InceptionV2/Mixed_5c'))
52 | self.assertListEqual(mixed_5c.get_shape().as_list(),
53 | [batch_size, 7, 7, 1024])
54 | expected_endpoints = ['Mixed_3b', 'Mixed_3c', 'Mixed_4a', 'Mixed_4b',
55 | 'Mixed_4c', 'Mixed_4d', 'Mixed_4e', 'Mixed_5a',
56 | 'Mixed_5b', 'Mixed_5c', 'Conv2d_1a_7x7',
57 | 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 'Conv2d_2c_3x3',
58 | 'MaxPool_3a_3x3']
59 | self.assertItemsEqual(end_points.keys(), expected_endpoints)
60 |
61 | def testBuildOnlyUptoFinalEndpoint(self):
62 | batch_size = 5
63 | height, width = 224, 224
64 | endpoints = ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1',
65 | 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c',
66 | 'Mixed_4a', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 'Mixed_4e',
67 | 'Mixed_5a', 'Mixed_5b', 'Mixed_5c']
68 | for index, endpoint in enumerate(endpoints):
69 | with tf.Graph().as_default():
70 | inputs = tf.random_uniform((batch_size, height, width, 3))
71 | out_tensor, end_points = inception.inception_v2_base(
72 | inputs, final_endpoint=endpoint)
73 | self.assertTrue(out_tensor.op.name.startswith(
74 | 'InceptionV2/' + endpoint))
75 | self.assertItemsEqual(endpoints[:index+1], end_points)
76 |
77 | def testBuildAndCheckAllEndPointsUptoMixed5c(self):
78 | batch_size = 5
79 | height, width = 224, 224
80 |
81 | inputs = tf.random_uniform((batch_size, height, width, 3))
82 | _, end_points = inception.inception_v2_base(inputs,
83 | final_endpoint='Mixed_5c')
84 | endpoints_shapes = {'Mixed_3b': [batch_size, 28, 28, 256],
85 | 'Mixed_3c': [batch_size, 28, 28, 320],
86 | 'Mixed_4a': [batch_size, 14, 14, 576],
87 | 'Mixed_4b': [batch_size, 14, 14, 576],
88 | 'Mixed_4c': [batch_size, 14, 14, 576],
89 | 'Mixed_4d': [batch_size, 14, 14, 576],
90 | 'Mixed_4e': [batch_size, 14, 14, 576],
91 | 'Mixed_5a': [batch_size, 7, 7, 1024],
92 | 'Mixed_5b': [batch_size, 7, 7, 1024],
93 | 'Mixed_5c': [batch_size, 7, 7, 1024],
94 | 'Conv2d_1a_7x7': [batch_size, 112, 112, 64],
95 | 'MaxPool_2a_3x3': [batch_size, 56, 56, 64],
96 | 'Conv2d_2b_1x1': [batch_size, 56, 56, 64],
97 | 'Conv2d_2c_3x3': [batch_size, 56, 56, 192],
98 | 'MaxPool_3a_3x3': [batch_size, 28, 28, 192]}
99 | self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys())
100 | for endpoint_name in endpoints_shapes:
101 | expected_shape = endpoints_shapes[endpoint_name]
102 | self.assertTrue(endpoint_name in end_points)
103 | self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
104 | expected_shape)
105 |
106 | def testModelHasExpectedNumberOfParameters(self):
107 | batch_size = 5
108 | height, width = 224, 224
109 | inputs = tf.random_uniform((batch_size, height, width, 3))
110 | with slim.arg_scope(inception.inception_v2_arg_scope()):
111 | inception.inception_v2_base(inputs)
112 | total_params, _ = slim.model_analyzer.analyze_vars(
113 | slim.get_model_variables())
114 | self.assertAlmostEqual(10173112, total_params)
115 |
116 | def testBuildEndPointsWithDepthMultiplierLessThanOne(self):
117 | batch_size = 5
118 | height, width = 224, 224
119 | num_classes = 1000
120 |
121 | inputs = tf.random_uniform((batch_size, height, width, 3))
122 | _, end_points = inception.inception_v2(inputs, num_classes)
123 |
124 | endpoint_keys = [key for key in end_points.keys()
125 | if key.startswith('Mixed') or key.startswith('Conv')]
126 |
127 | _, end_points_with_multiplier = inception.inception_v2(
128 | inputs, num_classes, scope='depth_multiplied_net',
129 | depth_multiplier=0.5)
130 |
131 | for key in endpoint_keys:
132 | original_depth = end_points[key].get_shape().as_list()[3]
133 | new_depth = end_points_with_multiplier[key].get_shape().as_list()[3]
134 | self.assertEqual(0.5 * original_depth, new_depth)
135 |
136 | def testBuildEndPointsWithDepthMultiplierGreaterThanOne(self):
137 | batch_size = 5
138 | height, width = 224, 224
139 | num_classes = 1000
140 |
141 | inputs = tf.random_uniform((batch_size, height, width, 3))
142 | _, end_points = inception.inception_v2(inputs, num_classes)
143 |
144 | endpoint_keys = [key for key in end_points.keys()
145 | if key.startswith('Mixed') or key.startswith('Conv')]
146 |
147 | _, end_points_with_multiplier = inception.inception_v2(
148 | inputs, num_classes, scope='depth_multiplied_net',
149 | depth_multiplier=2.0)
150 |
151 | for key in endpoint_keys:
152 | original_depth = end_points[key].get_shape().as_list()[3]
153 | new_depth = end_points_with_multiplier[key].get_shape().as_list()[3]
154 | self.assertEqual(2.0 * original_depth, new_depth)
155 |
156 | def testRaiseValueErrorWithInvalidDepthMultiplier(self):
157 | batch_size = 5
158 | height, width = 224, 224
159 | num_classes = 1000
160 |
161 | inputs = tf.random_uniform((batch_size, height, width, 3))
162 | with self.assertRaises(ValueError):
163 | _ = inception.inception_v2(inputs, num_classes, depth_multiplier=-0.1)
164 | with self.assertRaises(ValueError):
165 | _ = inception.inception_v2(inputs, num_classes, depth_multiplier=0.0)
166 |
167 | def testHalfSizeImages(self):
168 | batch_size = 5
169 | height, width = 112, 112
170 | num_classes = 1000
171 |
172 | inputs = tf.random_uniform((batch_size, height, width, 3))
173 | logits, end_points = inception.inception_v2(inputs, num_classes)
174 | self.assertTrue(logits.op.name.startswith('InceptionV2/Logits'))
175 | self.assertListEqual(logits.get_shape().as_list(),
176 | [batch_size, num_classes])
177 | pre_pool = end_points['Mixed_5c']
178 | self.assertListEqual(pre_pool.get_shape().as_list(),
179 | [batch_size, 4, 4, 1024])
180 |
181 | def testUnknownImageShape(self):
182 | tf.reset_default_graph()
183 | batch_size = 2
184 | height, width = 224, 224
185 | num_classes = 1000
186 | input_np = np.random.uniform(0, 1, (batch_size, height, width, 3))
187 | with self.test_session() as sess:
188 | inputs = tf.placeholder(tf.float32, shape=(batch_size, None, None, 3))
189 | logits, end_points = inception.inception_v2(inputs, num_classes)
190 | self.assertTrue(logits.op.name.startswith('InceptionV2/Logits'))
191 | self.assertListEqual(logits.get_shape().as_list(),
192 | [batch_size, num_classes])
193 | pre_pool = end_points['Mixed_5c']
194 | feed_dict = {inputs: input_np}
195 | tf.initialize_all_variables().run()
196 | pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict)
197 | self.assertListEqual(list(pre_pool_out.shape), [batch_size, 7, 7, 1024])
198 |
199 | def testUnknowBatchSize(self):
200 | batch_size = 1
201 | height, width = 224, 224
202 | num_classes = 1000
203 |
204 | inputs = tf.placeholder(tf.float32, (None, height, width, 3))
205 | logits, _ = inception.inception_v2(inputs, num_classes)
206 | self.assertTrue(logits.op.name.startswith('InceptionV2/Logits'))
207 | self.assertListEqual(logits.get_shape().as_list(),
208 | [None, num_classes])
209 | images = tf.random_uniform((batch_size, height, width, 3))
210 |
211 | with self.test_session() as sess:
212 | sess.run(tf.initialize_all_variables())
213 | output = sess.run(logits, {inputs: images.eval()})
214 | self.assertEquals(output.shape, (batch_size, num_classes))
215 |
216 | def testEvaluation(self):
217 | batch_size = 2
218 | height, width = 224, 224
219 | num_classes = 1000
220 |
221 | eval_inputs = tf.random_uniform((batch_size, height, width, 3))
222 | logits, _ = inception.inception_v2(eval_inputs, num_classes,
223 | is_training=False)
224 | predictions = tf.argmax(logits, 1)
225 |
226 | with self.test_session() as sess:
227 | sess.run(tf.initialize_all_variables())
228 | output = sess.run(predictions)
229 | self.assertEquals(output.shape, (batch_size,))
230 |
231 | def testTrainEvalWithReuse(self):
232 | train_batch_size = 5
233 | eval_batch_size = 2
234 | height, width = 150, 150
235 | num_classes = 1000
236 |
237 | train_inputs = tf.random_uniform((train_batch_size, height, width, 3))
238 | inception.inception_v2(train_inputs, num_classes)
239 | eval_inputs = tf.random_uniform((eval_batch_size, height, width, 3))
240 | logits, _ = inception.inception_v2(eval_inputs, num_classes, reuse=True)
241 | predictions = tf.argmax(logits, 1)
242 |
243 | with self.test_session() as sess:
244 | sess.run(tf.initialize_all_variables())
245 | output = sess.run(predictions)
246 | self.assertEquals(output.shape, (eval_batch_size,))
247 |
248 | def testLogitsNotSqueezed(self):
249 | num_classes = 25
250 | images = tf.random_uniform([1, 224, 224, 3])
251 | logits, _ = inception.inception_v2(images,
252 | num_classes=num_classes,
253 | spatial_squeeze=False)
254 |
255 | with self.test_session() as sess:
256 | tf.initialize_all_variables().run()
257 | logits_out = sess.run(logits)
258 | self.assertListEqual(list(logits_out.shape), [1, 1, 1, num_classes])
259 |
260 |
261 | if __name__ == '__main__':
262 | tf.test.main()
263 |
--------------------------------------------------------------------------------
/nets/inception_v3_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for nets.inception_v1."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import numpy as np
22 | import tensorflow as tf
23 |
24 | from nets import inception
25 |
26 | slim = tf.contrib.slim
27 |
28 |
29 | class InceptionV3Test(tf.test.TestCase):
30 |
31 | def testBuildClassificationNetwork(self):
32 | batch_size = 5
33 | height, width = 299, 299
34 | num_classes = 1000
35 |
36 | inputs = tf.random_uniform((batch_size, height, width, 3))
37 | logits, end_points = inception.inception_v3(inputs, num_classes)
38 | self.assertTrue(logits.op.name.startswith('InceptionV3/Logits'))
39 | self.assertListEqual(logits.get_shape().as_list(),
40 | [batch_size, num_classes])
41 | self.assertTrue('Predictions' in end_points)
42 | self.assertListEqual(end_points['Predictions'].get_shape().as_list(),
43 | [batch_size, num_classes])
44 |
45 | def testBuildBaseNetwork(self):
46 | batch_size = 5
47 | height, width = 299, 299
48 |
49 | inputs = tf.random_uniform((batch_size, height, width, 3))
50 | final_endpoint, end_points = inception.inception_v3_base(inputs)
51 | self.assertTrue(final_endpoint.op.name.startswith(
52 | 'InceptionV3/Mixed_7c'))
53 | self.assertListEqual(final_endpoint.get_shape().as_list(),
54 | [batch_size, 8, 8, 2048])
55 | expected_endpoints = ['Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3',
56 | 'MaxPool_3a_3x3', 'Conv2d_3b_1x1', 'Conv2d_4a_3x3',
57 | 'MaxPool_5a_3x3', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d',
58 | 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d',
59 | 'Mixed_6e', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c']
60 | self.assertItemsEqual(end_points.keys(), expected_endpoints)
61 |
62 | def testBuildOnlyUptoFinalEndpoint(self):
63 | batch_size = 5
64 | height, width = 299, 299
65 | endpoints = ['Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3',
66 | 'MaxPool_3a_3x3', 'Conv2d_3b_1x1', 'Conv2d_4a_3x3',
67 | 'MaxPool_5a_3x3', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d',
68 | 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d',
69 | 'Mixed_6e', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c']
70 |
71 | for index, endpoint in enumerate(endpoints):
72 | with tf.Graph().as_default():
73 | inputs = tf.random_uniform((batch_size, height, width, 3))
74 | out_tensor, end_points = inception.inception_v3_base(
75 | inputs, final_endpoint=endpoint)
76 | self.assertTrue(out_tensor.op.name.startswith(
77 | 'InceptionV3/' + endpoint))
78 | self.assertItemsEqual(endpoints[:index+1], end_points)
79 |
80 | def testBuildAndCheckAllEndPointsUptoMixed7c(self):
81 | batch_size = 5
82 | height, width = 299, 299
83 |
84 | inputs = tf.random_uniform((batch_size, height, width, 3))
85 | _, end_points = inception.inception_v3_base(
86 | inputs, final_endpoint='Mixed_7c')
87 | endpoints_shapes = {'Conv2d_1a_3x3': [batch_size, 149, 149, 32],
88 | 'Conv2d_2a_3x3': [batch_size, 147, 147, 32],
89 | 'Conv2d_2b_3x3': [batch_size, 147, 147, 64],
90 | 'MaxPool_3a_3x3': [batch_size, 73, 73, 64],
91 | 'Conv2d_3b_1x1': [batch_size, 73, 73, 80],
92 | 'Conv2d_4a_3x3': [batch_size, 71, 71, 192],
93 | 'MaxPool_5a_3x3': [batch_size, 35, 35, 192],
94 | 'Mixed_5b': [batch_size, 35, 35, 256],
95 | 'Mixed_5c': [batch_size, 35, 35, 288],
96 | 'Mixed_5d': [batch_size, 35, 35, 288],
97 | 'Mixed_6a': [batch_size, 17, 17, 768],
98 | 'Mixed_6b': [batch_size, 17, 17, 768],
99 | 'Mixed_6c': [batch_size, 17, 17, 768],
100 | 'Mixed_6d': [batch_size, 17, 17, 768],
101 | 'Mixed_6e': [batch_size, 17, 17, 768],
102 | 'Mixed_7a': [batch_size, 8, 8, 1280],
103 | 'Mixed_7b': [batch_size, 8, 8, 2048],
104 | 'Mixed_7c': [batch_size, 8, 8, 2048]}
105 | self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys())
106 | for endpoint_name in endpoints_shapes:
107 | expected_shape = endpoints_shapes[endpoint_name]
108 | self.assertTrue(endpoint_name in end_points)
109 | self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
110 | expected_shape)
111 |
112 | def testModelHasExpectedNumberOfParameters(self):
113 | batch_size = 5
114 | height, width = 299, 299
115 | inputs = tf.random_uniform((batch_size, height, width, 3))
116 | with slim.arg_scope(inception.inception_v3_arg_scope()):
117 | inception.inception_v3_base(inputs)
118 | total_params, _ = slim.model_analyzer.analyze_vars(
119 | slim.get_model_variables())
120 | self.assertAlmostEqual(21802784, total_params)
121 |
122 | def testBuildEndPoints(self):
123 | batch_size = 5
124 | height, width = 299, 299
125 | num_classes = 1000
126 |
127 | inputs = tf.random_uniform((batch_size, height, width, 3))
128 | _, end_points = inception.inception_v3(inputs, num_classes)
129 | self.assertTrue('Logits' in end_points)
130 | logits = end_points['Logits']
131 | self.assertListEqual(logits.get_shape().as_list(),
132 | [batch_size, num_classes])
133 | self.assertTrue('AuxLogits' in end_points)
134 | aux_logits = end_points['AuxLogits']
135 | self.assertListEqual(aux_logits.get_shape().as_list(),
136 | [batch_size, num_classes])
137 | self.assertTrue('Mixed_7c' in end_points)
138 | pre_pool = end_points['Mixed_7c']
139 | self.assertListEqual(pre_pool.get_shape().as_list(),
140 | [batch_size, 8, 8, 2048])
141 | self.assertTrue('PreLogits' in end_points)
142 | pre_logits = end_points['PreLogits']
143 | self.assertListEqual(pre_logits.get_shape().as_list(),
144 | [batch_size, 1, 1, 2048])
145 |
146 | def testBuildEndPointsWithDepthMultiplierLessThanOne(self):
147 | batch_size = 5
148 | height, width = 299, 299
149 | num_classes = 1000
150 |
151 | inputs = tf.random_uniform((batch_size, height, width, 3))
152 | _, end_points = inception.inception_v3(inputs, num_classes)
153 |
154 | endpoint_keys = [key for key in end_points.keys()
155 | if key.startswith('Mixed') or key.startswith('Conv')]
156 |
157 | _, end_points_with_multiplier = inception.inception_v3(
158 | inputs, num_classes, scope='depth_multiplied_net',
159 | depth_multiplier=0.5)
160 |
161 | for key in endpoint_keys:
162 | original_depth = end_points[key].get_shape().as_list()[3]
163 | new_depth = end_points_with_multiplier[key].get_shape().as_list()[3]
164 | self.assertEqual(0.5 * original_depth, new_depth)
165 |
166 | def testBuildEndPointsWithDepthMultiplierGreaterThanOne(self):
167 | batch_size = 5
168 | height, width = 299, 299
169 | num_classes = 1000
170 |
171 | inputs = tf.random_uniform((batch_size, height, width, 3))
172 | _, end_points = inception.inception_v3(inputs, num_classes)
173 |
174 | endpoint_keys = [key for key in end_points.keys()
175 | if key.startswith('Mixed') or key.startswith('Conv')]
176 |
177 | _, end_points_with_multiplier = inception.inception_v3(
178 | inputs, num_classes, scope='depth_multiplied_net',
179 | depth_multiplier=2.0)
180 |
181 | for key in endpoint_keys:
182 | original_depth = end_points[key].get_shape().as_list()[3]
183 | new_depth = end_points_with_multiplier[key].get_shape().as_list()[3]
184 | self.assertEqual(2.0 * original_depth, new_depth)
185 |
186 | def testRaiseValueErrorWithInvalidDepthMultiplier(self):
187 | batch_size = 5
188 | height, width = 299, 299
189 | num_classes = 1000
190 |
191 | inputs = tf.random_uniform((batch_size, height, width, 3))
192 | with self.assertRaises(ValueError):
193 | _ = inception.inception_v3(inputs, num_classes, depth_multiplier=-0.1)
194 | with self.assertRaises(ValueError):
195 | _ = inception.inception_v3(inputs, num_classes, depth_multiplier=0.0)
196 |
197 | def testHalfSizeImages(self):
198 | batch_size = 5
199 | height, width = 150, 150
200 | num_classes = 1000
201 |
202 | inputs = tf.random_uniform((batch_size, height, width, 3))
203 | logits, end_points = inception.inception_v3(inputs, num_classes)
204 | self.assertTrue(logits.op.name.startswith('InceptionV3/Logits'))
205 | self.assertListEqual(logits.get_shape().as_list(),
206 | [batch_size, num_classes])
207 | pre_pool = end_points['Mixed_7c']
208 | self.assertListEqual(pre_pool.get_shape().as_list(),
209 | [batch_size, 3, 3, 2048])
210 |
211 | def testUnknownImageShape(self):
212 | tf.reset_default_graph()
213 | batch_size = 2
214 | height, width = 299, 299
215 | num_classes = 1000
216 | input_np = np.random.uniform(0, 1, (batch_size, height, width, 3))
217 | with self.test_session() as sess:
218 | inputs = tf.placeholder(tf.float32, shape=(batch_size, None, None, 3))
219 | logits, end_points = inception.inception_v3(inputs, num_classes)
220 | self.assertListEqual(logits.get_shape().as_list(),
221 | [batch_size, num_classes])
222 | pre_pool = end_points['Mixed_7c']
223 | feed_dict = {inputs: input_np}
224 | tf.initialize_all_variables().run()
225 | pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict)
226 | self.assertListEqual(list(pre_pool_out.shape), [batch_size, 8, 8, 2048])
227 |
228 | def testUnknowBatchSize(self):
229 | batch_size = 1
230 | height, width = 299, 299
231 | num_classes = 1000
232 |
233 | inputs = tf.placeholder(tf.float32, (None, height, width, 3))
234 | logits, _ = inception.inception_v3(inputs, num_classes)
235 | self.assertTrue(logits.op.name.startswith('InceptionV3/Logits'))
236 | self.assertListEqual(logits.get_shape().as_list(),
237 | [None, num_classes])
238 | images = tf.random_uniform((batch_size, height, width, 3))
239 |
240 | with self.test_session() as sess:
241 | sess.run(tf.initialize_all_variables())
242 | output = sess.run(logits, {inputs: images.eval()})
243 | self.assertEquals(output.shape, (batch_size, num_classes))
244 |
245 | def testEvaluation(self):
246 | batch_size = 2
247 | height, width = 299, 299
248 | num_classes = 1000
249 |
250 | eval_inputs = tf.random_uniform((batch_size, height, width, 3))
251 | logits, _ = inception.inception_v3(eval_inputs, num_classes,
252 | is_training=False)
253 | predictions = tf.argmax(logits, 1)
254 |
255 | with self.test_session() as sess:
256 | sess.run(tf.initialize_all_variables())
257 | output = sess.run(predictions)
258 | self.assertEquals(output.shape, (batch_size,))
259 |
260 | def testTrainEvalWithReuse(self):
261 | train_batch_size = 5
262 | eval_batch_size = 2
263 | height, width = 150, 150
264 | num_classes = 1000
265 |
266 | train_inputs = tf.random_uniform((train_batch_size, height, width, 3))
267 | inception.inception_v3(train_inputs, num_classes)
268 | eval_inputs = tf.random_uniform((eval_batch_size, height, width, 3))
269 | logits, _ = inception.inception_v3(eval_inputs, num_classes,
270 | is_training=False, reuse=True)
271 | predictions = tf.argmax(logits, 1)
272 |
273 | with self.test_session() as sess:
274 | sess.run(tf.initialize_all_variables())
275 | output = sess.run(predictions)
276 | self.assertEquals(output.shape, (eval_batch_size,))
277 |
278 | def testLogitsNotSqueezed(self):
279 | num_classes = 25
280 | images = tf.random_uniform([1, 299, 299, 3])
281 | logits, _ = inception.inception_v3(images,
282 | num_classes=num_classes,
283 | spatial_squeeze=False)
284 |
285 | with self.test_session() as sess:
286 | tf.initialize_all_variables().run()
287 | logits_out = sess.run(logits)
288 | self.assertListEqual(list(logits_out.shape), [1, 1, 1, num_classes])
289 |
290 |
291 | if __name__ == '__main__':
292 | tf.test.main()
293 |
--------------------------------------------------------------------------------
/nets/inception_v4_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for slim.inception_v4."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import tensorflow as tf
21 |
22 | from nets import inception
23 |
24 |
25 | class InceptionTest(tf.test.TestCase):
26 |
27 | def testBuildLogits(self):
28 | batch_size = 5
29 | height, width = 299, 299
30 | num_classes = 1000
31 | inputs = tf.random_uniform((batch_size, height, width, 3))
32 | logits, end_points = inception.inception_v4(inputs, num_classes)
33 | auxlogits = end_points['AuxLogits']
34 | predictions = end_points['Predictions']
35 | self.assertTrue(auxlogits.op.name.startswith('InceptionV4/AuxLogits'))
36 | self.assertListEqual(auxlogits.get_shape().as_list(),
37 | [batch_size, num_classes])
38 | self.assertTrue(logits.op.name.startswith('InceptionV4/Logits'))
39 | self.assertListEqual(logits.get_shape().as_list(),
40 | [batch_size, num_classes])
41 | self.assertTrue(predictions.op.name.startswith(
42 | 'InceptionV4/Logits/Predictions'))
43 | self.assertListEqual(predictions.get_shape().as_list(),
44 | [batch_size, num_classes])
45 |
46 | def testBuildWithoutAuxLogits(self):
47 | batch_size = 5
48 | height, width = 299, 299
49 | num_classes = 1000
50 | inputs = tf.random_uniform((batch_size, height, width, 3))
51 | logits, endpoints = inception.inception_v4(inputs, num_classes,
52 | create_aux_logits=False)
53 | self.assertFalse('AuxLogits' in endpoints)
54 | self.assertTrue(logits.op.name.startswith('InceptionV4/Logits'))
55 | self.assertListEqual(logits.get_shape().as_list(),
56 | [batch_size, num_classes])
57 |
58 | def testAllEndPointsShapes(self):
59 | batch_size = 5
60 | height, width = 299, 299
61 | num_classes = 1000
62 | inputs = tf.random_uniform((batch_size, height, width, 3))
63 | _, end_points = inception.inception_v4(inputs, num_classes)
64 | endpoints_shapes = {'Conv2d_1a_3x3': [batch_size, 149, 149, 32],
65 | 'Conv2d_2a_3x3': [batch_size, 147, 147, 32],
66 | 'Conv2d_2b_3x3': [batch_size, 147, 147, 64],
67 | 'Mixed_3a': [batch_size, 73, 73, 160],
68 | 'Mixed_4a': [batch_size, 71, 71, 192],
69 | 'Mixed_5a': [batch_size, 35, 35, 384],
70 | # 4 x Inception-A blocks
71 | 'Mixed_5b': [batch_size, 35, 35, 384],
72 | 'Mixed_5c': [batch_size, 35, 35, 384],
73 | 'Mixed_5d': [batch_size, 35, 35, 384],
74 | 'Mixed_5e': [batch_size, 35, 35, 384],
75 | # Reduction-A block
76 | 'Mixed_6a': [batch_size, 17, 17, 1024],
77 | # 7 x Inception-B blocks
78 | 'Mixed_6b': [batch_size, 17, 17, 1024],
79 | 'Mixed_6c': [batch_size, 17, 17, 1024],
80 | 'Mixed_6d': [batch_size, 17, 17, 1024],
81 | 'Mixed_6e': [batch_size, 17, 17, 1024],
82 | 'Mixed_6f': [batch_size, 17, 17, 1024],
83 | 'Mixed_6g': [batch_size, 17, 17, 1024],
84 | 'Mixed_6h': [batch_size, 17, 17, 1024],
 85 |                         # Reduction-B block
86 | 'Mixed_7a': [batch_size, 8, 8, 1536],
87 | # 3 x Inception-C blocks
88 | 'Mixed_7b': [batch_size, 8, 8, 1536],
89 | 'Mixed_7c': [batch_size, 8, 8, 1536],
90 | 'Mixed_7d': [batch_size, 8, 8, 1536],
91 | # Logits and predictions
92 | 'AuxLogits': [batch_size, num_classes],
93 | 'PreLogitsFlatten': [batch_size, 1536],
94 | 'Logits': [batch_size, num_classes],
95 | 'Predictions': [batch_size, num_classes]}
96 | self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys())
97 | for endpoint_name in endpoints_shapes:
98 | expected_shape = endpoints_shapes[endpoint_name]
99 | self.assertTrue(endpoint_name in end_points)
100 | self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
101 | expected_shape)
102 |
103 | def testBuildBaseNetwork(self):
104 | batch_size = 5
105 | height, width = 299, 299
106 | inputs = tf.random_uniform((batch_size, height, width, 3))
107 | net, end_points = inception.inception_v4_base(inputs)
108 | self.assertTrue(net.op.name.startswith(
109 | 'InceptionV4/Mixed_7d'))
110 | self.assertListEqual(net.get_shape().as_list(), [batch_size, 8, 8, 1536])
111 | expected_endpoints = [
112 | 'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'Mixed_3a',
113 | 'Mixed_4a', 'Mixed_5a', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d',
114 | 'Mixed_5e', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d',
115 | 'Mixed_6e', 'Mixed_6f', 'Mixed_6g', 'Mixed_6h', 'Mixed_7a',
116 | 'Mixed_7b', 'Mixed_7c', 'Mixed_7d']
117 | self.assertItemsEqual(end_points.keys(), expected_endpoints)
118 |     for name, op in end_points.items():
119 | self.assertTrue(op.name.startswith('InceptionV4/' + name))
120 |
121 | def testBuildOnlyUpToFinalEndpoint(self):
122 | batch_size = 5
123 | height, width = 299, 299
124 | all_endpoints = [
125 | 'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'Mixed_3a',
126 | 'Mixed_4a', 'Mixed_5a', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d',
127 | 'Mixed_5e', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d',
128 | 'Mixed_6e', 'Mixed_6f', 'Mixed_6g', 'Mixed_6h', 'Mixed_7a',
129 | 'Mixed_7b', 'Mixed_7c', 'Mixed_7d']
130 | for index, endpoint in enumerate(all_endpoints):
131 | with tf.Graph().as_default():
132 | inputs = tf.random_uniform((batch_size, height, width, 3))
133 | out_tensor, end_points = inception.inception_v4_base(
134 | inputs, final_endpoint=endpoint)
135 | self.assertTrue(out_tensor.op.name.startswith(
136 | 'InceptionV4/' + endpoint))
137 | self.assertItemsEqual(all_endpoints[:index+1], end_points)
138 |
139 | def testVariablesSetDevice(self):
140 | batch_size = 5
141 | height, width = 299, 299
142 | num_classes = 1000
143 | inputs = tf.random_uniform((batch_size, height, width, 3))
144 | # Force all Variables to reside on the device.
145 | with tf.variable_scope('on_cpu'), tf.device('/cpu:0'):
146 | inception.inception_v4(inputs, num_classes)
147 | with tf.variable_scope('on_gpu'), tf.device('/gpu:0'):
148 | inception.inception_v4(inputs, num_classes)
149 | for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_cpu'):
150 | self.assertDeviceEqual(v.device, '/cpu:0')
151 | for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_gpu'):
152 | self.assertDeviceEqual(v.device, '/gpu:0')
153 |
154 | def testHalfSizeImages(self):
155 | batch_size = 5
156 | height, width = 150, 150
157 | num_classes = 1000
158 | inputs = tf.random_uniform((batch_size, height, width, 3))
159 | logits, end_points = inception.inception_v4(inputs, num_classes)
160 | self.assertTrue(logits.op.name.startswith('InceptionV4/Logits'))
161 | self.assertListEqual(logits.get_shape().as_list(),
162 | [batch_size, num_classes])
163 | pre_pool = end_points['Mixed_7d']
164 | self.assertListEqual(pre_pool.get_shape().as_list(),
165 | [batch_size, 3, 3, 1536])
166 |
167 | def testUnknownBatchSize(self):
168 | batch_size = 1
169 | height, width = 299, 299
170 | num_classes = 1000
171 | with self.test_session() as sess:
172 | inputs = tf.placeholder(tf.float32, (None, height, width, 3))
173 | logits, _ = inception.inception_v4(inputs, num_classes)
174 | self.assertTrue(logits.op.name.startswith('InceptionV4/Logits'))
175 | self.assertListEqual(logits.get_shape().as_list(),
176 | [None, num_classes])
177 | images = tf.random_uniform((batch_size, height, width, 3))
178 | sess.run(tf.initialize_all_variables())
179 | output = sess.run(logits, {inputs: images.eval()})
180 | self.assertEquals(output.shape, (batch_size, num_classes))
181 |
182 | def testEvaluation(self):
183 | batch_size = 2
184 | height, width = 299, 299
185 | num_classes = 1000
186 | with self.test_session() as sess:
187 | eval_inputs = tf.random_uniform((batch_size, height, width, 3))
188 | logits, _ = inception.inception_v4(eval_inputs,
189 | num_classes,
190 | is_training=False)
191 | predictions = tf.argmax(logits, 1)
192 | sess.run(tf.initialize_all_variables())
193 | output = sess.run(predictions)
194 | self.assertEquals(output.shape, (batch_size,))
195 |
196 | def testTrainEvalWithReuse(self):
197 | train_batch_size = 5
198 | eval_batch_size = 2
199 | height, width = 150, 150
200 | num_classes = 1000
201 | with self.test_session() as sess:
202 | train_inputs = tf.random_uniform((train_batch_size, height, width, 3))
203 | inception.inception_v4(train_inputs, num_classes)
204 | eval_inputs = tf.random_uniform((eval_batch_size, height, width, 3))
205 | logits, _ = inception.inception_v4(eval_inputs,
206 | num_classes,
207 | is_training=False,
208 | reuse=True)
209 | predictions = tf.argmax(logits, 1)
210 | sess.run(tf.initialize_all_variables())
211 | output = sess.run(predictions)
212 | self.assertEquals(output.shape, (eval_batch_size,))
213 |
214 |
215 | if __name__ == '__main__':
216 | tf.test.main()
217 |
--------------------------------------------------------------------------------
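
A minimal sketch of building InceptionV4 directly, outside the test harness above, assuming random tensors as stand-ins for real images and the same `inception_v4_arg_scope` that nets_factory.py below maps to this model:

import tensorflow as tf
from nets import inception

slim = tf.contrib.slim

# Two random 299x299 RGB images as a stand-in batch.
images = tf.random_uniform((2, 299, 299, 3))
with slim.arg_scope(inception.inception_v4_arg_scope()):
  logits, end_points = inception.inception_v4(images, num_classes=1000,
                                              is_training=False)
predictions = tf.argmax(end_points['Predictions'], 1)

with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  print(sess.run(predictions).shape)  # (2,)
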
/nets/lenet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains a variant of the LeNet model definition."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | slim = tf.contrib.slim
24 |
25 |
26 | def lenet(images, num_classes=10, is_training=False,
27 | dropout_keep_prob=0.5,
28 | prediction_fn=slim.softmax,
29 | scope='LeNet'):
30 | """Creates a variant of the LeNet model.
31 |
32 | Note that since the output is a set of 'logits', the values fall in the
33 |   interval of (-infinity, infinity). Consequently, to convert the outputs into a
34 |   probability distribution over the classes, one will need to apply the
35 |   softmax function:
36 |
37 | logits = lenet.lenet(images, is_training=False)
38 | probabilities = tf.nn.softmax(logits)
39 | predictions = tf.argmax(logits, 1)
40 |
41 | Args:
42 | images: A batch of `Tensors` of size [batch_size, height, width, channels].
43 | num_classes: the number of classes in the dataset.
44 | is_training: specifies whether or not we're currently training the model.
45 | This variable will determine the behaviour of the dropout layer.
46 |     dropout_keep_prob: the probability that activation values are retained.
47 | prediction_fn: a function to get predictions out of logits.
48 | scope: Optional variable_scope.
49 |
50 | Returns:
51 | logits: the pre-softmax activations, a tensor of size
52 | [batch_size, `num_classes`]
53 | end_points: a dictionary from components of the network to the corresponding
54 | activation.
55 | """
56 | end_points = {}
57 |
58 | with tf.variable_scope(scope, 'LeNet', [images, num_classes]):
59 | net = slim.conv2d(images, 32, [5, 5], scope='conv1')
60 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
61 | net = slim.conv2d(net, 64, [5, 5], scope='conv2')
62 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
63 | net = slim.flatten(net)
64 | end_points['Flatten'] = net
65 |
66 | net = slim.fully_connected(net, 1024, scope='fc3')
67 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
68 | scope='dropout3')
69 | logits = slim.fully_connected(net, num_classes, activation_fn=None,
70 | scope='fc4')
71 |
72 | end_points['Logits'] = logits
73 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
74 |
75 | return logits, end_points
76 | lenet.default_image_size = 28
77 |
78 |
79 | def lenet_arg_scope(weight_decay=0.0):
80 | """Defines the default lenet argument scope.
81 |
82 | Args:
83 | weight_decay: The weight decay to use for regularizing the model.
84 |
85 | Returns:
86 |     An `arg_scope` to use for the lenet model.
87 | """
88 | with slim.arg_scope(
89 | [slim.conv2d, slim.fully_connected],
90 | weights_regularizer=slim.l2_regularizer(weight_decay),
91 | weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
92 | activation_fn=tf.nn.relu) as sc:
93 | return sc
94 |
--------------------------------------------------------------------------------
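
A minimal sketch of using the LeNet variant defined above on random 28x28 grayscale inputs, following the logits-to-probabilities note in its docstring (the data here is a random stand-in, not real MNIST):

import tensorflow as tf
from nets import lenet

slim = tf.contrib.slim

images = tf.random_uniform((8, 28, 28, 1))  # batch of 8 random 28x28 grayscale images
with slim.arg_scope(lenet.lenet_arg_scope(weight_decay=0.0)):
  logits, end_points = lenet.lenet(images, num_classes=10, is_training=False)

probabilities = end_points['Predictions']  # softmax over the 10 classes
predictions = tf.argmax(logits, 1)

with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  print(sess.run(predictions))  # 8 predicted class labels
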
/nets/nets_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains a factory for building various models."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | import functools
21 |
22 | import tensorflow as tf
23 |
24 | from nets import alexnet
25 | from nets import cifarnet
26 | from nets import inception
27 | from nets import lenet
28 | from nets import overfeat
29 | from nets import resnet_v1
30 | from nets import resnet_v2
31 | from nets import vgg
32 |
33 | slim = tf.contrib.slim
34 |
35 | networks_map = {'alexnet_v2': alexnet.alexnet_v2,
36 | 'cifarnet': cifarnet.cifarnet,
37 | 'overfeat': overfeat.overfeat,
38 | 'vgg_a': vgg.vgg_a,
39 | 'vgg_16': vgg.vgg_16,
40 | 'vgg_19': vgg.vgg_19,
41 | 'inception_v1': inception.inception_v1,
42 | 'inception_v2': inception.inception_v2,
43 | 'inception_v3': inception.inception_v3,
44 | 'inception_v4': inception.inception_v4,
45 | 'inception_resnet_v2': inception.inception_resnet_v2,
46 | 'lenet': lenet.lenet,
47 | 'resnet_v1_50': resnet_v1.resnet_v1_50,
48 | 'resnet_v1_101': resnet_v1.resnet_v1_101,
49 | 'resnet_v1_152': resnet_v1.resnet_v1_152,
50 | 'resnet_v1_200': resnet_v1.resnet_v1_200,
51 | 'resnet_v2_50': resnet_v2.resnet_v2_50,
52 | 'resnet_v2_101': resnet_v2.resnet_v2_101,
53 | 'resnet_v2_152': resnet_v2.resnet_v2_152,
54 | 'resnet_v2_200': resnet_v2.resnet_v2_200,
55 | }
56 |
57 | arg_scopes_map = {'alexnet_v2': alexnet.alexnet_v2_arg_scope,
58 | 'cifarnet': cifarnet.cifarnet_arg_scope,
59 | 'overfeat': overfeat.overfeat_arg_scope,
60 | 'vgg_a': vgg.vgg_arg_scope,
61 | 'vgg_16': vgg.vgg_arg_scope,
62 | 'vgg_19': vgg.vgg_arg_scope,
63 | 'inception_v1': inception.inception_v3_arg_scope,
64 | 'inception_v2': inception.inception_v3_arg_scope,
65 | 'inception_v3': inception.inception_v3_arg_scope,
66 | 'inception_v4': inception.inception_v4_arg_scope,
67 | 'inception_resnet_v2':
68 | inception.inception_resnet_v2_arg_scope,
69 | 'lenet': lenet.lenet_arg_scope,
70 | 'resnet_v1_50': resnet_v1.resnet_arg_scope,
71 | 'resnet_v1_101': resnet_v1.resnet_arg_scope,
72 | 'resnet_v1_152': resnet_v1.resnet_arg_scope,
73 | 'resnet_v1_200': resnet_v1.resnet_arg_scope,
74 | 'resnet_v2_50': resnet_v2.resnet_arg_scope,
75 | 'resnet_v2_101': resnet_v2.resnet_arg_scope,
76 | 'resnet_v2_152': resnet_v2.resnet_arg_scope,
77 | 'resnet_v2_200': resnet_v2.resnet_arg_scope,
78 | }
79 |
80 |
81 | def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False):
82 | """Returns a network_fn such as `logits, end_points = network_fn(images)`.
83 |
84 | Args:
85 | name: The name of the network.
86 | num_classes: The number of classes to use for classification.
87 | weight_decay: The l2 coefficient for the model weights.
88 | is_training: `True` if the model is being used for training and `False`
89 | otherwise.
90 |
91 | Returns:
92 | network_fn: A function that applies the model to a batch of images. It has
93 | the following signature:
94 | logits, end_points = network_fn(images)
95 | Raises:
96 | ValueError: If network `name` is not recognized.
97 | """
98 | if name not in networks_map:
99 |     raise ValueError('Unknown network name: %s' % name)
100 | arg_scope = arg_scopes_map[name](weight_decay=weight_decay)
101 | func = networks_map[name]
102 | @functools.wraps(func)
103 | def network_fn(images, **kwargs):
104 | with slim.arg_scope(arg_scope):
105 | return func(images, num_classes, is_training=is_training, **kwargs)
106 | if hasattr(func, 'default_image_size'):
107 | network_fn.default_image_size = func.default_image_size
108 |
109 | return network_fn
110 |
--------------------------------------------------------------------------------
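
A minimal sketch of the get_network_fn contract documented above (`logits, end_points = network_fn(images)`), using random inputs; the returned function also carries default_image_size when the underlying model defines one:

import tensorflow as tf
from nets import nets_factory

# Build a vgg_16 classifier function with its arg_scope already applied.
network_fn = nets_factory.get_network_fn('vgg_16', num_classes=1000,
                                         weight_decay=0.0005, is_training=False)
image_size = network_fn.default_image_size  # 224 for vgg_16

images = tf.random_uniform((4, image_size, image_size, 3))
logits, end_points = network_fn(images)  # logits has shape [4, 1000]
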
/nets/nets_factory_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Tests for slim.inception."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 |
23 | import tensorflow as tf
24 |
25 | from nets import nets_factory
26 |
27 |
28 | class NetworksTest(tf.test.TestCase):
29 |
30 | def testGetNetworkFn(self):
31 | batch_size = 5
32 | num_classes = 1000
33 | for net in nets_factory.networks_map:
34 | with self.test_session():
35 | net_fn = nets_factory.get_network_fn(net, num_classes)
36 | # Most networks use 224 as their default_image_size
37 | image_size = getattr(net_fn, 'default_image_size', 224)
38 | inputs = tf.random_uniform((batch_size, image_size, image_size, 3))
39 | logits, end_points = net_fn(inputs)
40 | self.assertTrue(isinstance(logits, tf.Tensor))
41 | self.assertTrue(isinstance(end_points, dict))
42 | self.assertEqual(logits.get_shape().as_list()[0], batch_size)
43 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes)
44 |
45 | if __name__ == '__main__':
46 | tf.test.main()
47 |
--------------------------------------------------------------------------------
/nets/overfeat.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains the model definition for the OverFeat network.
16 |
17 | The definition for the network was obtained from:
18 | OverFeat: Integrated Recognition, Localization and Detection using
19 | Convolutional Networks
20 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and
21 | Yann LeCun, 2014
22 | http://arxiv.org/abs/1312.6229
23 |
24 | Usage:
25 | with slim.arg_scope(overfeat.overfeat_arg_scope()):
26 | outputs, end_points = overfeat.overfeat(inputs)
27 |
28 | @@overfeat
29 | """
30 | from __future__ import absolute_import
31 | from __future__ import division
32 | from __future__ import print_function
33 |
34 | import tensorflow as tf
35 |
36 | slim = tf.contrib.slim
37 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
38 |
39 |
40 | def overfeat_arg_scope(weight_decay=0.0005):
41 | with slim.arg_scope([slim.conv2d, slim.fully_connected],
42 | activation_fn=tf.nn.relu,
43 | weights_regularizer=slim.l2_regularizer(weight_decay),
44 |                       biases_initializer=tf.zeros_initializer()):
45 | with slim.arg_scope([slim.conv2d], padding='SAME'):
46 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc:
47 | return arg_sc
48 |
49 |
50 | def overfeat(inputs,
51 | num_classes=1000,
52 | is_training=True,
53 | dropout_keep_prob=0.5,
54 | spatial_squeeze=True,
55 | scope='overfeat'):
56 | """Contains the model definition for the OverFeat network.
57 |
58 | The definition for the network was obtained from:
59 | OverFeat: Integrated Recognition, Localization and Detection using
60 | Convolutional Networks
61 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and
62 | Yann LeCun, 2014
63 | http://arxiv.org/abs/1312.6229
64 |
65 | Note: All the fully_connected layers have been transformed to conv2d layers.
66 | To use in classification mode, resize input to 231x231. To use in fully
67 | convolutional mode, set spatial_squeeze to false.
68 |
69 | Args:
70 | inputs: a tensor of size [batch_size, height, width, channels].
71 | num_classes: number of predicted classes.
72 | is_training: whether or not the model is being trained.
73 | dropout_keep_prob: the probability that activations are kept in the dropout
74 | layers during training.
75 |     spatial_squeeze: whether or not the spatial dimensions of the outputs should
76 |       be squeezed. Useful to remove unnecessary dimensions for classification.
77 | scope: Optional scope for the variables.
78 |
79 | Returns:
80 | the last op containing the log predictions and end_points dict.
81 |
82 | """
83 | with tf.variable_scope(scope, 'overfeat', [inputs]) as sc:
84 | end_points_collection = sc.name + '_end_points'
85 | # Collect outputs for conv2d, fully_connected and max_pool2d
86 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
87 | outputs_collections=end_points_collection):
88 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID',
89 | scope='conv1')
90 | net = slim.max_pool2d(net, [2, 2], scope='pool1')
91 | net = slim.conv2d(net, 256, [5, 5], padding='VALID', scope='conv2')
92 | net = slim.max_pool2d(net, [2, 2], scope='pool2')
93 | net = slim.conv2d(net, 512, [3, 3], scope='conv3')
94 | net = slim.conv2d(net, 1024, [3, 3], scope='conv4')
95 | net = slim.conv2d(net, 1024, [3, 3], scope='conv5')
96 | net = slim.max_pool2d(net, [2, 2], scope='pool5')
97 | with slim.arg_scope([slim.conv2d],
98 | weights_initializer=trunc_normal(0.005),
99 | biases_initializer=tf.constant_initializer(0.1)):
100 | # Use conv2d instead of fully_connected layers.
101 | net = slim.conv2d(net, 3072, [6, 6], padding='VALID', scope='fc6')
102 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
103 | scope='dropout6')
104 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
105 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
106 | scope='dropout7')
107 | net = slim.conv2d(net, num_classes, [1, 1],
108 | activation_fn=None,
109 | normalizer_fn=None,
110 |                           biases_initializer=tf.zeros_initializer(),
111 | scope='fc8')
112 |       # Convert end_points_collection into an end_points dict.
113 | end_points = slim.utils.convert_collection_to_dict(end_points_collection)
114 | if spatial_squeeze:
115 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
116 | end_points[sc.name + '/fc8'] = net
117 | return net, end_points
118 | overfeat.default_image_size = 231
119 |
--------------------------------------------------------------------------------
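
A minimal sketch of the two usage modes described in the overfeat docstring above: classification on 231x231 inputs, and fully convolutional inference with spatial_squeeze=False on a larger input (variable reuse mirrors what overfeat_test.py below does):

import tensorflow as tf
from nets import overfeat

slim = tf.contrib.slim

# Classification mode: 231x231 inputs give squeezed [batch, num_classes] logits.
inputs = tf.random_uniform((2, 231, 231, 3))
with slim.arg_scope(overfeat.overfeat_arg_scope()):
  logits, _ = overfeat.overfeat(inputs, num_classes=1000, is_training=False)
predictions = tf.argmax(logits, 1)

# Fully convolutional mode: reuse the same variables on a larger input, keep the
# spatial grid of logits, then average it into a single prediction per image.
tf.get_variable_scope().reuse_variables()
large_inputs = tf.random_uniform((1, 281, 281, 3))
with slim.arg_scope(overfeat.overfeat_arg_scope()):
  spatial_logits, _ = overfeat.overfeat(large_inputs, is_training=False,
                                        spatial_squeeze=False)
avg_logits = tf.reduce_mean(spatial_logits, [1, 2])  # [1, 1000]
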
/nets/overfeat_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for slim.nets.overfeat."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import tensorflow as tf
21 |
22 | from nets import overfeat
23 |
24 | slim = tf.contrib.slim
25 |
26 |
27 | class OverFeatTest(tf.test.TestCase):
28 |
29 | def testBuild(self):
30 | batch_size = 5
31 | height, width = 231, 231
32 | num_classes = 1000
33 | with self.test_session():
34 | inputs = tf.random_uniform((batch_size, height, width, 3))
35 | logits, _ = overfeat.overfeat(inputs, num_classes)
36 | self.assertEquals(logits.op.name, 'overfeat/fc8/squeezed')
37 | self.assertListEqual(logits.get_shape().as_list(),
38 | [batch_size, num_classes])
39 |
40 | def testFullyConvolutional(self):
41 | batch_size = 1
42 | height, width = 281, 281
43 | num_classes = 1000
44 | with self.test_session():
45 | inputs = tf.random_uniform((batch_size, height, width, 3))
46 | logits, _ = overfeat.overfeat(inputs, num_classes, spatial_squeeze=False)
47 | self.assertEquals(logits.op.name, 'overfeat/fc8/BiasAdd')
48 | self.assertListEqual(logits.get_shape().as_list(),
49 | [batch_size, 2, 2, num_classes])
50 |
51 | def testEndPoints(self):
52 | batch_size = 5
53 | height, width = 231, 231
54 | num_classes = 1000
55 | with self.test_session():
56 | inputs = tf.random_uniform((batch_size, height, width, 3))
57 | _, end_points = overfeat.overfeat(inputs, num_classes)
58 | expected_names = ['overfeat/conv1',
59 | 'overfeat/pool1',
60 | 'overfeat/conv2',
61 | 'overfeat/pool2',
62 | 'overfeat/conv3',
63 | 'overfeat/conv4',
64 | 'overfeat/conv5',
65 | 'overfeat/pool5',
66 | 'overfeat/fc6',
67 | 'overfeat/fc7',
68 | 'overfeat/fc8'
69 | ]
70 | self.assertSetEqual(set(end_points.keys()), set(expected_names))
71 |
72 | def testModelVariables(self):
73 | batch_size = 5
74 | height, width = 231, 231
75 | num_classes = 1000
76 | with self.test_session():
77 | inputs = tf.random_uniform((batch_size, height, width, 3))
78 | overfeat.overfeat(inputs, num_classes)
79 | expected_names = ['overfeat/conv1/weights',
80 | 'overfeat/conv1/biases',
81 | 'overfeat/conv2/weights',
82 | 'overfeat/conv2/biases',
83 | 'overfeat/conv3/weights',
84 | 'overfeat/conv3/biases',
85 | 'overfeat/conv4/weights',
86 | 'overfeat/conv4/biases',
87 | 'overfeat/conv5/weights',
88 | 'overfeat/conv5/biases',
89 | 'overfeat/fc6/weights',
90 | 'overfeat/fc6/biases',
91 | 'overfeat/fc7/weights',
92 | 'overfeat/fc7/biases',
93 | 'overfeat/fc8/weights',
94 | 'overfeat/fc8/biases',
95 | ]
96 | model_variables = [v.op.name for v in slim.get_model_variables()]
97 | self.assertSetEqual(set(model_variables), set(expected_names))
98 |
99 | def testEvaluation(self):
100 | batch_size = 2
101 | height, width = 231, 231
102 | num_classes = 1000
103 | with self.test_session():
104 | eval_inputs = tf.random_uniform((batch_size, height, width, 3))
105 | logits, _ = overfeat.overfeat(eval_inputs, is_training=False)
106 | self.assertListEqual(logits.get_shape().as_list(),
107 | [batch_size, num_classes])
108 | predictions = tf.argmax(logits, 1)
109 | self.assertListEqual(predictions.get_shape().as_list(), [batch_size])
110 |
111 | def testTrainEvalWithReuse(self):
112 | train_batch_size = 2
113 | eval_batch_size = 1
114 | train_height, train_width = 231, 231
115 | eval_height, eval_width = 281, 281
116 | num_classes = 1000
117 | with self.test_session():
118 | train_inputs = tf.random_uniform(
119 | (train_batch_size, train_height, train_width, 3))
120 | logits, _ = overfeat.overfeat(train_inputs)
121 | self.assertListEqual(logits.get_shape().as_list(),
122 | [train_batch_size, num_classes])
123 | tf.get_variable_scope().reuse_variables()
124 | eval_inputs = tf.random_uniform(
125 | (eval_batch_size, eval_height, eval_width, 3))
126 | logits, _ = overfeat.overfeat(eval_inputs, is_training=False,
127 | spatial_squeeze=False)
128 | self.assertListEqual(logits.get_shape().as_list(),
129 | [eval_batch_size, 2, 2, num_classes])
130 | logits = tf.reduce_mean(logits, [1, 2])
131 | predictions = tf.argmax(logits, 1)
132 | self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size])
133 |
134 | def testForward(self):
135 | batch_size = 1
136 | height, width = 231, 231
137 | with self.test_session() as sess:
138 | inputs = tf.random_uniform((batch_size, height, width, 3))
139 | logits, _ = overfeat.overfeat(inputs)
140 | sess.run(tf.initialize_all_variables())
141 | output = sess.run(logits)
142 | self.assertTrue(output.any())
143 |
144 | if __name__ == '__main__':
145 | tf.test.main()
146 |
--------------------------------------------------------------------------------
/nets/resnet_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains building blocks for various versions of Residual Networks.
16 |
17 | Residual networks (ResNets) were proposed in:
18 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385, 2015
20 |
21 | More variants were introduced in:
22 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
23 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027, 2016
24 |
25 | We can obtain different ResNet variants by changing the network depth, width,
26 | and form of residual unit. This module implements the infrastructure for
27 | building them. Concrete ResNet units and full ResNet networks are implemented in
28 | the accompanying resnet_v1.py and resnet_v2.py modules.
29 |
30 | Compared to https://github.com/KaimingHe/deep-residual-networks, in the current
31 | implementation we subsample the output activations in the last residual unit of
32 | each block, instead of subsampling the input activations in the first residual
33 | unit of each block. The two implementations give identical results but our
34 | implementation is more memory efficient.
35 | """
36 | from __future__ import absolute_import
37 | from __future__ import division
38 | from __future__ import print_function
39 |
40 | import collections
41 | import tensorflow as tf
42 |
43 | slim = tf.contrib.slim
44 |
45 |
46 | class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])):
47 | """A named tuple describing a ResNet block.
48 |
49 | Its parts are:
50 | scope: The scope of the `Block`.
51 | unit_fn: The ResNet unit function which takes as input a `Tensor` and
52 | returns another `Tensor` with the output of the ResNet unit.
53 | args: A list of length equal to the number of units in the `Block`. The list
54 | contains one (depth, depth_bottleneck, stride) tuple for each unit in the
55 | block to serve as argument to unit_fn.
56 | """
57 |
58 |
59 | def subsample(inputs, factor, scope=None):
60 | """Subsamples the input along the spatial dimensions.
61 |
62 | Args:
63 | inputs: A `Tensor` of size [batch, height_in, width_in, channels].
64 | factor: The subsampling factor.
65 | scope: Optional variable_scope.
66 |
67 | Returns:
68 | output: A `Tensor` of size [batch, height_out, width_out, channels] with the
69 | input, either intact (if factor == 1) or subsampled (if factor > 1).
70 | """
71 | if factor == 1:
72 | return inputs
73 | else:
74 | return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope)
75 |
76 |
77 | def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None):
78 | """Strided 2-D convolution with 'SAME' padding.
79 |
80 | When stride > 1, then we do explicit zero-padding, followed by conv2d with
81 | 'VALID' padding.
82 |
83 | Note that
84 |
85 | net = conv2d_same(inputs, num_outputs, 3, stride=stride)
86 |
87 | is equivalent to
88 |
89 | net = slim.conv2d(inputs, num_outputs, 3, stride=1, padding='SAME')
90 | net = subsample(net, factor=stride)
91 |
92 | whereas
93 |
94 | net = slim.conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME')
95 |
96 | is different when the input's height or width is even, which is why we add the
97 | current function. For more details, see ResnetUtilsTest.testConv2DSameEven().
98 |
99 | Args:
100 | inputs: A 4-D tensor of size [batch, height_in, width_in, channels].
101 | num_outputs: An integer, the number of output filters.
102 | kernel_size: An int with the kernel_size of the filters.
103 | stride: An integer, the output stride.
104 | rate: An integer, rate for atrous convolution.
105 | scope: Scope.
106 |
107 | Returns:
108 | output: A 4-D tensor of size [batch, height_out, width_out, channels] with
109 | the convolution output.
110 | """
111 | if stride == 1:
112 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=1, rate=rate,
113 | padding='SAME', scope=scope)
114 | else:
115 | kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
116 | pad_total = kernel_size_effective - 1
117 | pad_beg = pad_total // 2
118 | pad_end = pad_total - pad_beg
119 | inputs = tf.pad(inputs,
120 | [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]])
121 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=stride,
122 | rate=rate, padding='VALID', scope=scope)
123 |
124 |
125 | @slim.add_arg_scope
126 | def stack_blocks_dense(net, blocks, output_stride=None,
127 | outputs_collections=None):
128 | """Stacks ResNet `Blocks` and controls output feature density.
129 |
130 | First, this function creates scopes for the ResNet in the form of
131 | 'block_name/unit_1', 'block_name/unit_2', etc.
132 |
133 | Second, this function allows the user to explicitly control the ResNet
134 | output_stride, which is the ratio of the input to output spatial resolution.
135 | This is useful for dense prediction tasks such as semantic segmentation or
136 | object detection.
137 |
138 | Most ResNets consist of 4 ResNet blocks and subsample the activations by a
139 | factor of 2 when transitioning between consecutive ResNet blocks. This results
140 |   in a nominal ResNet output_stride equal to 8. If we set the output_stride to
141 |   half the nominal network stride (e.g., output_stride=4), then we compute
142 |   responses twice as densely.
143 |
144 | Control of the output feature density is implemented by atrous convolution.
145 |
146 | Args:
147 | net: A `Tensor` of size [batch, height, width, channels].
148 | blocks: A list of length equal to the number of ResNet `Blocks`. Each
149 | element is a ResNet `Block` object describing the units in the `Block`.
150 | output_stride: If `None`, then the output will be computed at the nominal
151 | network stride. If output_stride is not `None`, it specifies the requested
152 | ratio of input to output spatial resolution, which needs to be equal to
153 | the product of unit strides from the start up to some level of the ResNet.
154 | For example, if the ResNet employs units with strides 1, 2, 1, 3, 4, 1,
155 | then valid values for the output_stride are 1, 2, 6, 24 or None (which
156 | is equivalent to output_stride=24).
157 | outputs_collections: Collection to add the ResNet block outputs.
158 |
159 | Returns:
160 | net: Output tensor with stride equal to the specified output_stride.
161 |
162 | Raises:
163 | ValueError: If the target output_stride is not valid.
164 | """
165 | # The current_stride variable keeps track of the effective stride of the
166 | # activations. This allows us to invoke atrous convolution whenever applying
167 | # the next residual unit would result in the activations having stride larger
168 | # than the target output_stride.
169 | current_stride = 1
170 |
171 | # The atrous convolution rate parameter.
172 | rate = 1
173 |
174 | for block in blocks:
175 | with tf.variable_scope(block.scope, 'block', [net]) as sc:
176 | for i, unit in enumerate(block.args):
177 | if output_stride is not None and current_stride > output_stride:
178 | raise ValueError('The target output_stride cannot be reached.')
179 |
180 | with tf.variable_scope('unit_%d' % (i + 1), values=[net]):
181 | unit_depth, unit_depth_bottleneck, unit_stride = unit
182 |
183 | # If we have reached the target output_stride, then we need to employ
184 | # atrous convolution with stride=1 and multiply the atrous rate by the
185 | # current unit's stride for use in subsequent layers.
186 | if output_stride is not None and current_stride == output_stride:
187 | net = block.unit_fn(net,
188 | depth=unit_depth,
189 | depth_bottleneck=unit_depth_bottleneck,
190 | stride=1,
191 | rate=rate)
192 | rate *= unit_stride
193 |
194 | else:
195 | net = block.unit_fn(net,
196 | depth=unit_depth,
197 | depth_bottleneck=unit_depth_bottleneck,
198 | stride=unit_stride,
199 | rate=1)
200 | current_stride *= unit_stride
201 | net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net)
202 |
203 | if output_stride is not None and current_stride != output_stride:
204 | raise ValueError('The target output_stride cannot be reached.')
205 |
206 | return net
207 |
208 |
209 | def resnet_arg_scope(weight_decay=0.0001,
210 | batch_norm_decay=0.997,
211 | batch_norm_epsilon=1e-5,
212 | batch_norm_scale=True):
213 | """Defines the default ResNet arg scope.
214 |
215 | TODO(gpapan): The batch-normalization related default values above are
216 | appropriate for use in conjunction with the reference ResNet models
217 | released at https://github.com/KaimingHe/deep-residual-networks. When
218 | training ResNets from scratch, they might need to be tuned.
219 |
220 | Args:
221 | weight_decay: The weight decay to use for regularizing the model.
222 | batch_norm_decay: The moving average decay when estimating layer activation
223 | statistics in batch normalization.
224 | batch_norm_epsilon: Small constant to prevent division by zero when
225 | normalizing activations by their variance in batch normalization.
226 | batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the
227 | activations in the batch normalization layer.
228 |
229 | Returns:
230 | An `arg_scope` to use for the resnet models.
231 | """
232 | batch_norm_params = {
233 | 'decay': batch_norm_decay,
234 | 'epsilon': batch_norm_epsilon,
235 | 'scale': batch_norm_scale,
236 | 'updates_collections': tf.GraphKeys.UPDATE_OPS,
237 | }
238 |
239 | with slim.arg_scope(
240 | [slim.conv2d],
241 | weights_regularizer=slim.l2_regularizer(weight_decay),
242 | weights_initializer=slim.variance_scaling_initializer(),
243 | activation_fn=tf.nn.relu,
244 | normalizer_fn=slim.batch_norm,
245 | normalizer_params=batch_norm_params):
246 | with slim.arg_scope([slim.batch_norm], **batch_norm_params):
247 | # The following implies padding='SAME' for pool1, which makes feature
248 | # alignment easier for dense prediction tasks. This is also used in
249 | # https://github.com/facebook/fb.resnet.torch. However the accompanying
250 | # code of 'Deep Residual Learning for Image Recognition' uses
251 | # padding='VALID' for pool1. You can switch to that choice by setting
252 | # slim.arg_scope([slim.max_pool2d], padding='VALID').
253 | with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc:
254 | return arg_sc
255 |
--------------------------------------------------------------------------------
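
To make the conv2d_same note above concrete, a small sketch contrasting it with the stride-1 SAME convolution plus subsample composition that its docstring describes; the two branches here use separate weights, so only the shapes are expected to match (sharing the weights would make the values match as well):

import tensorflow as tf
from nets import resnet_utils

slim = tf.contrib.slim

inputs = tf.random_uniform((1, 224, 224, 3))

# Explicitly padded stride-2 3x3 convolution: output shape [1, 112, 112, 64].
net = resnet_utils.conv2d_same(inputs, 64, 3, stride=2, scope='conv_same')

# The composition from the docstring: stride-1 SAME conv, then subsampling.
ref = slim.conv2d(inputs, 64, [3, 3], stride=1, padding='SAME', scope='conv_ref')
ref = resnet_utils.subsample(ref, factor=2)  # also [1, 112, 112, 64]
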
/nets/resnet_v1.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains definitions for the original form of Residual Networks.
16 |
17 | The 'v1' residual networks (ResNets) implemented in this module were proposed
18 | by:
19 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
20 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
21 |
22 | Other variants were introduced in:
23 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
24 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027
25 |
26 | The networks defined in this module utilize the bottleneck building block of
27 | [1] with projection shortcuts only for increasing depths. They employ batch
28 | normalization *after* every weight layer. This is the architecture used by
29 | MSRA in the Imagenet and MSCOCO 2016 competition models ResNet-101 and
30 | ResNet-152. See [2; Fig. 1a] for a comparison between the current 'v1'
31 | architecture and the alternative 'v2' architecture of [2] which uses batch
32 | normalization *before* every weight layer in the so-called full pre-activation
33 | units.
34 |
35 | Typical use:
36 |
37 | from tensorflow.contrib.slim.nets import resnet_v1
38 |
39 | ResNet-101 for image classification into 1000 classes:
40 |
41 | # inputs has shape [batch, 224, 224, 3]
42 | with slim.arg_scope(resnet_v1.resnet_arg_scope()):
43 | net, end_points = resnet_v1.resnet_v1_101(inputs, 1000, is_training=False)
44 |
45 | ResNet-101 for semantic segmentation into 21 classes:
46 |
47 | # inputs has shape [batch, 513, 513, 3]
48 | with slim.arg_scope(resnet_v1.resnet_arg_scope()):
49 | net, end_points = resnet_v1.resnet_v1_101(inputs,
50 | 21,
51 | is_training=False,
52 | global_pool=False,
53 | output_stride=16)
54 | """
55 | from __future__ import absolute_import
56 | from __future__ import division
57 | from __future__ import print_function
58 |
59 | import tensorflow as tf
60 |
61 | from nets import resnet_utils
62 |
63 |
64 | resnet_arg_scope = resnet_utils.resnet_arg_scope
65 | slim = tf.contrib.slim
66 |
67 |
68 | @slim.add_arg_scope
69 | def bottleneck(inputs, depth, depth_bottleneck, stride, rate=1,
70 | outputs_collections=None, scope=None):
71 | """Bottleneck residual unit variant with BN after convolutions.
72 |
73 | This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for
74 | its definition. Note that we use here the bottleneck variant which has an
75 | extra bottleneck layer.
76 |
77 | When putting together two consecutive ResNet blocks that use this unit, one
78 | should use stride = 2 in the last unit of the first block.
79 |
80 | Args:
81 | inputs: A tensor of size [batch, height, width, channels].
82 | depth: The depth of the ResNet unit output.
83 | depth_bottleneck: The depth of the bottleneck layers.
84 | stride: The ResNet unit's stride. Determines the amount of downsampling of
85 | the units output compared to its input.
86 | rate: An integer, rate for atrous convolution.
87 | outputs_collections: Collection to add the ResNet unit output.
88 | scope: Optional variable_scope.
89 |
90 | Returns:
91 | The ResNet unit's output.
92 | """
93 | with tf.variable_scope(scope, 'bottleneck_v1', [inputs]) as sc:
94 | depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4)
95 | if depth == depth_in:
96 | shortcut = resnet_utils.subsample(inputs, stride, 'shortcut')
97 | else:
98 | shortcut = slim.conv2d(inputs, depth, [1, 1], stride=stride,
99 | activation_fn=None, scope='shortcut')
100 |
101 | residual = slim.conv2d(inputs, depth_bottleneck, [1, 1], stride=1,
102 | scope='conv1')
103 | residual = resnet_utils.conv2d_same(residual, depth_bottleneck, 3, stride,
104 | rate=rate, scope='conv2')
105 | residual = slim.conv2d(residual, depth, [1, 1], stride=1,
106 | activation_fn=None, scope='conv3')
107 |
108 | output = tf.nn.relu(shortcut + residual)
109 |
110 | return slim.utils.collect_named_outputs(outputs_collections,
111 | sc.original_name_scope,
112 | output)
113 |
114 |
115 | def resnet_v1(inputs,
116 | blocks,
117 | num_classes=None,
118 | is_training=True,
119 | global_pool=True,
120 | output_stride=None,
121 | include_root_block=True,
122 | reuse=None,
123 | scope=None):
124 | """Generator for v1 ResNet models.
125 |
126 | This function generates a family of ResNet v1 models. See the resnet_v1_*()
127 | methods for specific model instantiations, obtained by selecting different
128 | block instantiations that produce ResNets of various depths.
129 |
130 | Training for image classification on Imagenet is usually done with [224, 224]
131 | inputs, resulting in [7, 7] feature maps at the output of the last ResNet
132 | block for the ResNets defined in [1] that have nominal stride equal to 32.
133 | However, for dense prediction tasks we advise that one uses inputs with
134 | spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In
135 | this case the feature maps at the ResNet output will have spatial shape
136 | [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1]
137 | and corners exactly aligned with the input image corners, which greatly
138 | facilitates alignment of the features to the image. Using as input [225, 225]
139 | images results in [8, 8] feature maps at the output of the last ResNet block.
140 |
141 | For dense prediction tasks, the ResNet needs to run in fully-convolutional
142 | (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all
143 | have nominal stride equal to 32 and a good choice in FCN mode is to use
144 | output_stride=16 in order to increase the density of the computed features at
145 | small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915.
146 |
147 | Args:
148 | inputs: A tensor of size [batch, height_in, width_in, channels].
149 | blocks: A list of length equal to the number of ResNet blocks. Each element
150 | is a resnet_utils.Block object describing the units in the block.
151 | num_classes: Number of predicted classes for classification tasks. If None
152 | we return the features before the logit layer.
153 |     is_training: whether the model is being trained or not.
154 | global_pool: If True, we perform global average pooling before computing the
155 | logits. Set to True for image classification, False for dense prediction.
156 | output_stride: If None, then the output will be computed at the nominal
157 | network stride. If output_stride is not None, it specifies the requested
158 | ratio of input to output spatial resolution.
159 | include_root_block: If True, include the initial convolution followed by
160 | max-pooling, if False excludes it.
161 | reuse: whether or not the network and its variables should be reused. To be
162 | able to reuse 'scope' must be given.
163 | scope: Optional variable_scope.
164 |
165 | Returns:
166 | net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
167 | If global_pool is False, then height_out and width_out are reduced by a
168 | factor of output_stride compared to the respective height_in and width_in,
169 | else both height_out and width_out equal one. If num_classes is None, then
170 | net is the output of the last ResNet block, potentially after global
171 | average pooling. If num_classes is not None, net contains the pre-softmax
172 | activations.
173 | end_points: A dictionary from components of the network to the corresponding
174 | activation.
175 |
176 | Raises:
177 | ValueError: If the target output_stride is not valid.
178 | """
179 | with tf.variable_scope(scope, 'resnet_v1', [inputs], reuse=reuse) as sc:
180 | end_points_collection = sc.name + '_end_points'
181 | with slim.arg_scope([slim.conv2d, bottleneck,
182 | resnet_utils.stack_blocks_dense],
183 | outputs_collections=end_points_collection):
184 | with slim.arg_scope([slim.batch_norm], is_training=is_training):
185 | net = inputs
186 | if include_root_block:
187 | if output_stride is not None:
188 | if output_stride % 4 != 0:
189 | raise ValueError('The output_stride needs to be a multiple of 4.')
190 | output_stride /= 4
191 | net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1')
192 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1')
193 | net = resnet_utils.stack_blocks_dense(net, blocks, output_stride)
194 | if global_pool:
195 | # Global average pooling.
196 | net = tf.reduce_mean(net, [1, 2], name='pool5', keep_dims=True)
197 | if num_classes is not None:
198 | net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
199 | normalizer_fn=None, scope='logits')
200 | # Convert end_points_collection into a dictionary of end_points.
201 | end_points = slim.utils.convert_collection_to_dict(end_points_collection)
202 | if num_classes is not None:
203 | end_points['predictions'] = slim.softmax(net, scope='predictions')
204 | return net, end_points
205 | resnet_v1.default_image_size = 224
206 |
207 |
208 | def resnet_v1_50(inputs,
209 | num_classes=None,
210 | is_training=True,
211 | global_pool=True,
212 | output_stride=None,
213 | reuse=None,
214 | scope='resnet_v1_50'):
215 | """ResNet-50 model of [1]. See resnet_v1() for arg and return description."""
216 | blocks = [
217 | resnet_utils.Block(
218 | 'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
219 | resnet_utils.Block(
220 | 'block2', bottleneck, [(512, 128, 1)] * 3 + [(512, 128, 2)]),
221 | resnet_utils.Block(
222 | 'block3', bottleneck, [(1024, 256, 1)] * 5 + [(1024, 256, 2)]),
223 | resnet_utils.Block(
224 | 'block4', bottleneck, [(2048, 512, 1)] * 3)
225 | ]
226 | return resnet_v1(inputs, blocks, num_classes, is_training,
227 | global_pool=global_pool, output_stride=output_stride,
228 | include_root_block=True, reuse=reuse, scope=scope)
229 |
230 |
231 | def resnet_v1_101(inputs,
232 | num_classes=None,
233 | is_training=True,
234 | global_pool=True,
235 | output_stride=None,
236 | reuse=None,
237 | scope='resnet_v1_101'):
238 | """ResNet-101 model of [1]. See resnet_v1() for arg and return description."""
239 | blocks = [
240 | resnet_utils.Block(
241 | 'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
242 | resnet_utils.Block(
243 | 'block2', bottleneck, [(512, 128, 1)] * 3 + [(512, 128, 2)]),
244 | resnet_utils.Block(
245 | 'block3', bottleneck, [(1024, 256, 1)] * 22 + [(1024, 256, 2)]),
246 | resnet_utils.Block(
247 | 'block4', bottleneck, [(2048, 512, 1)] * 3)
248 | ]
249 | return resnet_v1(inputs, blocks, num_classes, is_training,
250 | global_pool=global_pool, output_stride=output_stride,
251 | include_root_block=True, reuse=reuse, scope=scope)
252 |
253 |
254 | def resnet_v1_152(inputs,
255 | num_classes=None,
256 | is_training=True,
257 | global_pool=True,
258 | output_stride=None,
259 | reuse=None,
260 | scope='resnet_v1_152'):
261 | """ResNet-152 model of [1]. See resnet_v1() for arg and return description."""
262 | blocks = [
263 | resnet_utils.Block(
264 | 'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
265 | resnet_utils.Block(
266 | 'block2', bottleneck, [(512, 128, 1)] * 7 + [(512, 128, 2)]),
267 | resnet_utils.Block(
268 | 'block3', bottleneck, [(1024, 256, 1)] * 35 + [(1024, 256, 2)]),
269 | resnet_utils.Block(
270 | 'block4', bottleneck, [(2048, 512, 1)] * 3)]
271 | return resnet_v1(inputs, blocks, num_classes, is_training,
272 | global_pool=global_pool, output_stride=output_stride,
273 | include_root_block=True, reuse=reuse, scope=scope)
274 |
275 |
276 | def resnet_v1_200(inputs,
277 | num_classes=None,
278 | is_training=True,
279 | global_pool=True,
280 | output_stride=None,
281 | reuse=None,
282 | scope='resnet_v1_200'):
283 | """ResNet-200 model of [2]. See resnet_v1() for arg and return description."""
284 | blocks = [
285 | resnet_utils.Block(
286 | 'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
287 | resnet_utils.Block(
288 | 'block2', bottleneck, [(512, 128, 1)] * 23 + [(512, 128, 2)]),
289 | resnet_utils.Block(
290 | 'block3', bottleneck, [(1024, 256, 1)] * 35 + [(1024, 256, 2)]),
291 | resnet_utils.Block(
292 | 'block4', bottleneck, [(2048, 512, 1)] * 3)]
293 | return resnet_v1(inputs, blocks, num_classes, is_training,
294 | global_pool=global_pool, output_stride=output_stride,
295 | include_root_block=True, reuse=reuse, scope=scope)
296 |
--------------------------------------------------------------------------------
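
A minimal sketch of the two modes from the resnet_v1 docstring above, with random inputs: a 1000-class classifier at 224x224, and a fully convolutional feature extractor at 321x321 with output_stride=16 (the second call uses a distinct, illustrative scope name so the two sets of variables do not collide):

import tensorflow as tf
from nets import resnet_v1

slim = tf.contrib.slim

# Image classification: global pooling followed by a 1000-way logits layer.
inputs = tf.random_uniform((1, 224, 224, 3))
with slim.arg_scope(resnet_v1.resnet_arg_scope()):
  logits, end_points = resnet_v1.resnet_v1_50(inputs, num_classes=1000,
                                              is_training=False)

# Dense prediction: no global pooling, features at 1/16 of the input resolution.
dense_inputs = tf.random_uniform((1, 321, 321, 3))
with slim.arg_scope(resnet_v1.resnet_arg_scope()):
  features, _ = resnet_v1.resnet_v1_50(dense_inputs, num_classes=None,
                                       is_training=False, global_pool=False,
                                       output_stride=16,
                                       scope='resnet_v1_50_dense')
# features has spatial size (321 - 1) / 16 + 1 = 21 in each dimension.
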
/nets/vgg.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains model definitions for versions of the Oxford VGG network.
16 | These model definitions were introduced in the following technical report:
17 | Very Deep Convolutional Networks For Large-Scale Image Recognition
18 | Karen Simonyan and Andrew Zisserman
19 | arXiv technical report, 2015
20 | PDF: http://arxiv.org/pdf/1409.1556.pdf
21 | ILSVRC 2014 Slides: http://www.robots.ox.ac.uk/~karen/pdf/ILSVRC_2014.pdf
22 | CC-BY-4.0
23 | More information can be obtained from the VGG website:
24 | www.robots.ox.ac.uk/~vgg/research/very_deep/
25 | Usage:
26 | with slim.arg_scope(vgg.vgg_arg_scope()):
27 | outputs, end_points = vgg.vgg_a(inputs)
28 | with slim.arg_scope(vgg.vgg_arg_scope()):
29 | outputs, end_points = vgg.vgg_16(inputs)
30 | @@vgg_a
31 | @@vgg_16
32 | @@vgg_19
33 | """
34 | from __future__ import absolute_import
35 | from __future__ import division
36 | from __future__ import print_function
37 |
38 | import tensorflow as tf
39 |
40 | slim = tf.contrib.slim
41 |
42 |
43 | def vgg_arg_scope(weight_decay=0.0005):
44 | """Defines the VGG arg scope.
45 | Args:
46 | weight_decay: The l2 regularization coefficient.
47 | Returns:
48 | An arg_scope.
49 | """
50 | with slim.arg_scope([slim.conv2d, slim.fully_connected],
51 | activation_fn=tf.nn.relu,
52 | weights_regularizer=slim.l2_regularizer(weight_decay),
53 | biases_initializer=tf.zeros_initializer()):
54 | with slim.arg_scope([slim.conv2d], padding='SAME') as arg_sc:
55 | return arg_sc
56 |
57 |
58 | def vgg_a(inputs,
59 | num_classes=1000,
60 | is_training=True,
61 | dropout_keep_prob=0.5,
62 | spatial_squeeze=True,
63 | scope='vgg_a'):
64 | """Oxford Net VGG 11-Layers version A Example.
65 | Note: All the fully_connected layers have been transformed to conv2d layers.
66 | To use in classification mode, resize input to 224x224.
67 | Args:
68 | inputs: a tensor of size [batch_size, height, width, channels].
69 | num_classes: number of predicted classes.
70 | is_training: whether or not the model is being trained.
71 | dropout_keep_prob: the probability that activations are kept in the dropout
72 | layers during training.
73 |     spatial_squeeze: whether or not the spatial dimensions of the outputs should
74 |       be squeezed. Useful to remove unnecessary dimensions for classification.
75 | scope: Optional scope for the variables.
76 | Returns:
77 | the last op containing the log predictions and end_points dict.
78 | """
79 | with tf.variable_scope(scope, 'vgg_a', [inputs]) as sc:
80 | end_points_collection = sc.name + '_end_points'
81 | # Collect outputs for conv2d, fully_connected and max_pool2d.
82 | with slim.arg_scope([slim.conv2d, slim.max_pool2d],
83 | outputs_collections=end_points_collection):
84 | net = slim.repeat(inputs, 1, slim.conv2d, 64, [3, 3], scope='conv1')
85 | net = slim.max_pool2d(net, [2, 2], scope='pool1')
86 | net = slim.repeat(net, 1, slim.conv2d, 128, [3, 3], scope='conv2')
87 | net = slim.max_pool2d(net, [2, 2], scope='pool2')
88 | net = slim.repeat(net, 2, slim.conv2d, 256, [3, 3], scope='conv3')
89 | net = slim.max_pool2d(net, [2, 2], scope='pool3')
90 | net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv4')
91 | net = slim.max_pool2d(net, [2, 2], scope='pool4')
92 | net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv5')
93 | net = slim.max_pool2d(net, [2, 2], scope='pool5')
94 | # Use conv2d instead of fully_connected layers.
95 | net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6')
96 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
97 | scope='dropout6')
98 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
99 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
100 | scope='dropout7')
101 | net = slim.conv2d(net, num_classes, [1, 1],
102 | activation_fn=None,
103 | normalizer_fn=None,
104 | scope='fc8')
105 |       # Convert end_points_collection into an end_point dict.
106 | end_points = slim.utils.convert_collection_to_dict(end_points_collection)
107 | if spatial_squeeze:
108 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
109 | end_points[sc.name + '/fc8'] = net
110 | return net, end_points
111 | vgg_a.default_image_size = 224
112 |
113 |
114 | def vgg_16(inputs,
115 | num_classes=1000,
116 | is_training=True,
117 | dropout_keep_prob=0.5,
118 | spatial_squeeze=True,
119 | scope='vgg_16'):
120 | """Oxford Net VGG 16-Layers version D Example.
121 | Note: All the fully_connected layers have been transformed to conv2d layers.
122 | To use in classification mode, resize input to 224x224.
123 | Args:
124 | inputs: a tensor of size [batch_size, height, width, channels].
125 | num_classes: number of predicted classes.
126 | is_training: whether or not the model is being trained.
127 | dropout_keep_prob: the probability that activations are kept in the dropout
128 | layers during training.
129 |     spatial_squeeze: whether or not the spatial dimensions of the outputs
130 |       should be squeezed. Useful to remove unnecessary dimensions for classification.
131 | scope: Optional scope for the variables.
132 | Returns:
133 | the last op containing the log predictions and end_points dict.
134 | """
135 | with tf.variable_scope(scope, 'vgg_16', [inputs]) as sc:
136 | end_points_collection = sc.name + '_end_points'
137 | # Collect outputs for conv2d, fully_connected and max_pool2d.
138 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
139 | outputs_collections=end_points_collection):
140 | net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
141 | net = slim.max_pool2d(net, [2, 2], scope='pool1')
142 | net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
143 | net = slim.max_pool2d(net, [2, 2], scope='pool2')
144 | net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
145 | net = slim.max_pool2d(net, [2, 2], scope='pool3')
146 | net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4')
147 | net = slim.max_pool2d(net, [2, 2], scope='pool4')
148 | net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5')
149 | net = slim.max_pool2d(net, [2, 2], scope='pool5')
150 | # Use conv2d instead of fully_connected layers.
151 | net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6')
152 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
153 | scope='dropout6')
154 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
155 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
156 | scope='dropout7')
157 | net = slim.conv2d(net, num_classes, [1, 1],
158 | activation_fn=None,
159 | normalizer_fn=None,
160 | scope='fc8')
161 |       # Convert end_points_collection into an end_point dict.
162 | end_points = slim.utils.convert_collection_to_dict(end_points_collection)
163 | if spatial_squeeze:
164 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
165 | end_points[sc.name + '/fc8'] = net
166 | return net, end_points
167 | vgg_16.default_image_size = 224
168 |
169 |
170 | def vgg_19(inputs,
171 | num_classes=1000,
172 | is_training=True,
173 | dropout_keep_prob=0.5,
174 | spatial_squeeze=True,
175 | scope='vgg_19'):
176 | """Oxford Net VGG 19-Layers version E Example.
177 | Note: All the fully_connected layers have been transformed to conv2d layers.
178 | To use in classification mode, resize input to 224x224.
179 | Args:
180 | inputs: a tensor of size [batch_size, height, width, channels].
181 | num_classes: number of predicted classes.
182 | is_training: whether or not the model is being trained.
183 | dropout_keep_prob: the probability that activations are kept in the dropout
184 | layers during training.
185 |     spatial_squeeze: whether or not the spatial dimensions of the outputs
186 |       should be squeezed. Useful to remove unnecessary dimensions for classification.
187 | scope: Optional scope for the variables.
188 | Returns:
189 | the last op containing the log predictions and end_points dict.
190 | """
191 | with tf.variable_scope(scope, 'vgg_19', [inputs]) as sc:
192 | end_points_collection = sc.name + '_end_points'
193 | # Collect outputs for conv2d, fully_connected and max_pool2d.
194 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
195 | outputs_collections=end_points_collection):
196 | net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
197 | net = slim.max_pool2d(net, [2, 2], scope='pool1')
198 | net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
199 | net = slim.max_pool2d(net, [2, 2], scope='pool2')
200 | net = slim.repeat(net, 4, slim.conv2d, 256, [3, 3], scope='conv3')
201 | net = slim.max_pool2d(net, [2, 2], scope='pool3')
202 | net = slim.repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv4')
203 | net = slim.max_pool2d(net, [2, 2], scope='pool4')
204 | net = slim.repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv5')
205 | net = slim.max_pool2d(net, [2, 2], scope='pool5')
206 | # Use conv2d instead of fully_connected layers.
207 | net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6')
208 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
209 | scope='dropout6')
210 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
211 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
212 | scope='dropout7')
213 | net = slim.conv2d(net, num_classes, [1, 1],
214 | activation_fn=None,
215 | normalizer_fn=None,
216 | scope='fc8')
217 |       # Convert end_points_collection into an end_point dict.
218 | end_points = slim.utils.convert_collection_to_dict(end_points_collection)
219 | if spatial_squeeze:
220 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
221 | end_points[sc.name + '/fc8'] = net
222 | return net, end_points
223 | vgg_19.default_image_size = 224
224 |
225 | # Alias
226 | vgg_d = vgg_16
227 | vgg_e = vgg_19
--------------------------------------------------------------------------------
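
A minimal usage sketch for the definitions above, assuming TensorFlow 1.x with
tf.contrib.slim: build vgg_16 under vgg_arg_scope and read intermediate
activations from the end_points dict, which is how train.py taps the loss
network. The placeholder shape and the layer name in the comment are
illustrative only.

    import tensorflow as tf
    from nets import vgg

    slim = tf.contrib.slim

    images = tf.placeholder(tf.float32, [None, 224, 224, 3])
    with slim.arg_scope(vgg.vgg_arg_scope()):
        # spatial_squeeze=False keeps fc8 as a 1x1 spatial map, as train.py does.
        logits, end_points = vgg.vgg_16(images, num_classes=1,
                                        is_training=False, spatial_squeeze=False)

    # end_points maps scope names (e.g. 'vgg_16/conv3/conv3_3') to activations,
    # which is what the content and style losses index into.
    for name in sorted(end_points):
        print(name, end_points[name].get_shape())
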
/preprocessing/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/preprocessing/cifarnet_preprocessing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Provides utilities to preprocess images in CIFAR-10.
16 |
17 | """
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | import tensorflow as tf
24 |
25 | _PADDING = 4
26 |
27 | slim = tf.contrib.slim
28 |
29 |
30 | def preprocess_for_train(image,
31 | output_height,
32 | output_width,
33 | padding=_PADDING):
34 | """Preprocesses the given image for training.
35 |
36 |   Note that the image is zero-padded by `padding` pixels on each side and a
37 |   random [`output_height`, `output_width`] crop is then taken from it.
38 |
39 | Args:
40 | image: A `Tensor` representing an image of arbitrary size.
41 | output_height: The height of the image after preprocessing.
42 | output_width: The width of the image after preprocessing.
43 |     padding: The amount of padding before and after each dimension of the image.
44 |
45 | Returns:
46 | A preprocessed image.
47 | """
48 | tf.image_summary('image', tf.expand_dims(image, 0))
49 |
50 | # Transform the image to floats.
51 | image = tf.to_float(image)
52 | if padding > 0:
53 | image = tf.pad(image, [[padding, padding], [padding, padding], [0, 0]])
54 | # Randomly crop a [height, width] section of the image.
55 | distorted_image = tf.random_crop(image,
56 | [output_height, output_width, 3])
57 |
58 | # Randomly flip the image horizontally.
59 | distorted_image = tf.image.random_flip_left_right(distorted_image)
60 |
61 | tf.image_summary('distorted_image', tf.expand_dims(distorted_image, 0))
62 |
63 | # Because these operations are not commutative, consider randomizing
64 |   # the order of their operations.
65 | distorted_image = tf.image.random_brightness(distorted_image,
66 | max_delta=63)
67 | distorted_image = tf.image.random_contrast(distorted_image,
68 | lower=0.2, upper=1.8)
69 | # Subtract off the mean and divide by the variance of the pixels.
70 | return tf.image.per_image_whitening(distorted_image)
71 |
72 |
73 | def preprocess_for_eval(image, output_height, output_width):
74 | """Preprocesses the given image for evaluation.
75 |
76 | Args:
77 | image: A `Tensor` representing an image of arbitrary size.
78 | output_height: The height of the image after preprocessing.
79 | output_width: The width of the image after preprocessing.
80 |
81 | Returns:
82 | A preprocessed image.
83 | """
84 | tf.image_summary('image', tf.expand_dims(image, 0))
85 | # Transform the image to floats.
86 | image = tf.to_float(image)
87 |
88 | # Resize and crop if needed.
89 | resized_image = tf.image.resize_image_with_crop_or_pad(image,
90 | output_width,
91 | output_height)
92 | tf.image_summary('resized_image', tf.expand_dims(resized_image, 0))
93 |
94 | # Subtract off the mean and divide by the variance of the pixels.
95 | return tf.image.per_image_whitening(resized_image)
96 |
97 |
98 | def preprocess_image(image, output_height, output_width, is_training=False):
99 | """Preprocesses the given image.
100 |
101 | Args:
102 | image: A `Tensor` representing an image of arbitrary size.
103 | output_height: The height of the image after preprocessing.
104 | output_width: The width of the image after preprocessing.
105 | is_training: `True` if we're preprocessing the image for training and
106 | `False` otherwise.
107 |
108 | Returns:
109 | A preprocessed image.
110 | """
111 | if is_training:
112 | return preprocess_for_train(image, output_height, output_width)
113 | else:
114 | return preprocess_for_eval(image, output_height, output_width)
115 |
--------------------------------------------------------------------------------
/preprocessing/inception_preprocessing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Provides utilities to preprocess images for the Inception networks."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | from tensorflow.python.ops import control_flow_ops
24 |
25 |
26 | def apply_with_random_selector(x, func, num_cases):
27 | """Computes func(x, sel), with sel sampled from [0...num_cases-1].
28 |
29 | Args:
30 | x: input Tensor.
31 | func: Python function to apply.
32 | num_cases: Python int32, number of cases to sample sel from.
33 |
34 | Returns:
35 | The result of func(x, sel), where func receives the value of the
36 | selector as a python integer, but sel is sampled dynamically.
37 | """
38 | sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)
39 | # Pass the real x only to one of the func calls.
40 | return control_flow_ops.merge([
41 | func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case)
42 | for case in range(num_cases)])[0]
43 |
44 |
45 | def distort_color(image, color_ordering=0, fast_mode=True, scope=None):
46 | """Distort the color of a Tensor image.
47 |
48 | Each color distortion is non-commutative and thus ordering of the color ops
49 | matters. Ideally we would randomly permute the ordering of the color ops.
50 |   Rather than adding that level of complication, we select a distinct ordering
51 | of color ops for each preprocessing thread.
52 |
53 | Args:
54 | image: 3-D Tensor containing single image in [0, 1].
55 | color_ordering: Python int, a type of distortion (valid values: 0-3).
56 | fast_mode: Avoids slower ops (random_hue and random_contrast)
57 | scope: Optional scope for name_scope.
58 | Returns:
59 | 3-D Tensor color-distorted image on range [0, 1]
60 | Raises:
61 | ValueError: if color_ordering not in [0, 3]
62 | """
63 | with tf.name_scope(scope, 'distort_color', [image]):
64 | if fast_mode:
65 | if color_ordering == 0:
66 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
67 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
68 | else:
69 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
70 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
71 | else:
72 | if color_ordering == 0:
73 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
74 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
75 | image = tf.image.random_hue(image, max_delta=0.2)
76 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
77 | elif color_ordering == 1:
78 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
79 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
80 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
81 | image = tf.image.random_hue(image, max_delta=0.2)
82 | elif color_ordering == 2:
83 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
84 | image = tf.image.random_hue(image, max_delta=0.2)
85 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
86 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
87 | elif color_ordering == 3:
88 | image = tf.image.random_hue(image, max_delta=0.2)
89 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
90 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
91 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
92 | else:
93 | raise ValueError('color_ordering must be in [0, 3]')
94 |
95 | # The random_* ops do not necessarily clamp.
96 | return tf.clip_by_value(image, 0.0, 1.0)
97 |
98 |
99 | def distorted_bounding_box_crop(image,
100 | bbox,
101 | min_object_covered=0.1,
102 | aspect_ratio_range=(0.75, 1.33),
103 | area_range=(0.05, 1.0),
104 | max_attempts=100,
105 | scope=None):
106 |   """Generates cropped_image using one of the bboxes, randomly distorted.
107 |
108 | See `tf.image.sample_distorted_bounding_box` for more documentation.
109 |
110 | Args:
111 | image: 3-D Tensor of image (it will be converted to floats in [0, 1]).
112 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
113 | where each coordinate is [0, 1) and the coordinates are arranged
114 | as [ymin, xmin, ymax, xmax]. If num_boxes is 0 then it would use the whole
115 | image.
116 | min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
117 | area of the image must contain at least this fraction of any bounding box
118 | supplied.
119 | aspect_ratio_range: An optional list of `floats`. The cropped area of the
120 | image must have an aspect ratio = width / height within this range.
121 | area_range: An optional list of `floats`. The cropped area of the image
122 |     must contain a fraction of the supplied image within this range.
123 | max_attempts: An optional `int`. Number of attempts at generating a cropped
124 | region of the image of the specified constraints. After `max_attempts`
125 | failures, return the entire image.
126 | scope: Optional scope for name_scope.
127 | Returns:
128 | A tuple, a 3-D Tensor cropped_image and the distorted bbox
129 | """
130 | with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]):
131 | # Each bounding box has shape [1, num_boxes, box coords] and
132 | # the coordinates are ordered [ymin, xmin, ymax, xmax].
133 |
134 | # A large fraction of image datasets contain a human-annotated bounding
135 | # box delineating the region of the image containing the object of interest.
136 | # We choose to create a new bounding box for the object which is a randomly
137 | # distorted version of the human-annotated bounding box that obeys an
138 | # allowed range of aspect ratios, sizes and overlap with the human-annotated
139 | # bounding box. If no box is supplied, then we assume the bounding box is
140 | # the entire image.
141 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
142 | tf.shape(image),
143 | bounding_boxes=bbox,
144 | min_object_covered=min_object_covered,
145 | aspect_ratio_range=aspect_ratio_range,
146 | area_range=area_range,
147 | max_attempts=max_attempts,
148 | use_image_if_no_bounding_boxes=True)
149 | bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box
150 |
151 | # Crop the image to the specified bounding box.
152 | cropped_image = tf.slice(image, bbox_begin, bbox_size)
153 | return cropped_image, distort_bbox
154 |
155 |
156 | def preprocess_for_train(image, height, width, bbox,
157 | fast_mode=True,
158 | scope=None):
159 | """Distort one image for training a network.
160 |
161 | Distorting images provides a useful technique for augmenting the data
162 | set during training in order to make the network invariant to aspects
163 |   of the image that do not affect the label.
164 |
165 |   Additionally, it creates image summaries to display the different
166 | transformations applied to the image.
167 |
168 | Args:
169 | image: 3-D Tensor of image. If dtype is tf.float32 then the range should be
170 |       [0, 1], otherwise it is converted to tf.float32 assuming that the range
171 |       is [0, MAX], where MAX is the largest positive representable number for
172 | int(8/16/32) data type (see `tf.image.convert_image_dtype` for details).
173 | height: integer
174 | width: integer
175 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
176 | where each coordinate is [0, 1) and the coordinates are arranged
177 | as [ymin, xmin, ymax, xmax].
178 | fast_mode: Optional boolean, if True avoids slower transformations (i.e.
179 | bi-cubic resizing, random_hue or random_contrast).
180 | scope: Optional scope for name_scope.
181 | Returns:
182 | 3-D float Tensor of distorted image used for training with range [-1, 1].
183 | """
184 | with tf.name_scope(scope, 'distort_image', [image, height, width, bbox]):
185 | if bbox is None:
186 | bbox = tf.constant([0.0, 0.0, 1.0, 1.0],
187 | dtype=tf.float32,
188 | shape=[1, 1, 4])
189 | if image.dtype != tf.float32:
190 | image = tf.image.convert_image_dtype(image, dtype=tf.float32)
191 | # Each bounding box has shape [1, num_boxes, box coords] and
192 | # the coordinates are ordered [ymin, xmin, ymax, xmax].
193 | image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
194 | bbox)
195 | tf.image_summary('image_with_bounding_boxes', image_with_box)
196 |
197 | distorted_image, distorted_bbox = distorted_bounding_box_crop(image, bbox)
198 | # Restore the shape since the dynamic slice based upon the bbox_size loses
199 | # the third dimension.
200 | distorted_image.set_shape([None, None, 3])
201 | image_with_distorted_box = tf.image.draw_bounding_boxes(
202 | tf.expand_dims(image, 0), distorted_bbox)
203 | tf.image_summary('images_with_distorted_bounding_box',
204 | image_with_distorted_box)
205 |
206 | # This resizing operation may distort the images because the aspect
207 |     # ratio is not respected. A resize method is chosen at random for each
208 |     # image via apply_with_random_selector.
209 | # Note that ResizeMethod contains 4 enumerated resizing methods.
210 |
211 |     # In fast_mode only 1 case (bilinear) is used.
212 | num_resize_cases = 1 if fast_mode else 4
213 | distorted_image = apply_with_random_selector(
214 | distorted_image,
215 | lambda x, method: tf.image.resize_images(x, [height, width], method=method),
216 | num_cases=num_resize_cases)
217 |
218 | tf.image_summary('cropped_resized_image',
219 | tf.expand_dims(distorted_image, 0))
220 |
221 | # Randomly flip the image horizontally.
222 | distorted_image = tf.image.random_flip_left_right(distorted_image)
223 |
224 | # Randomly distort the colors. There are 4 ways to do it.
225 | distorted_image = apply_with_random_selector(
226 | distorted_image,
227 | lambda x, ordering: distort_color(x, ordering, fast_mode),
228 | num_cases=4)
229 |
230 | tf.image_summary('final_distorted_image',
231 | tf.expand_dims(distorted_image, 0))
232 | distorted_image = tf.sub(distorted_image, 0.5)
233 | distorted_image = tf.mul(distorted_image, 2.0)
234 | return distorted_image
235 |
236 |
237 | def preprocess_for_eval(image, height, width,
238 | central_fraction=0.875, scope=None):
239 | """Prepare one image for evaluation.
240 |
241 | If height and width are specified it would output an image with that size by
242 | applying resize_bilinear.
243 |
244 |   If central_fraction is specified it would crop the central fraction of the
245 | input image.
246 |
247 | Args:
248 | image: 3-D Tensor of image. If dtype is tf.float32 then the range should be
249 |     [0, 1], otherwise it is converted to tf.float32 assuming that the range
250 |     is [0, MAX], where MAX is the largest positive representable number for
251 | int(8/16/32) data type (see `tf.image.convert_image_dtype` for details)
252 | height: integer
253 | width: integer
254 | central_fraction: Optional Float, fraction of the image to crop.
255 | scope: Optional scope for name_scope.
256 | Returns:
257 | 3-D float Tensor of prepared image.
258 | """
259 | with tf.name_scope(scope, 'eval_image', [image, height, width]):
260 | if image.dtype != tf.float32:
261 | image = tf.image.convert_image_dtype(image, dtype=tf.float32)
262 | # Crop the central region of the image with an area containing 87.5% of
263 | # the original image.
264 | if central_fraction:
265 | image = tf.image.central_crop(image, central_fraction=central_fraction)
266 |
267 | if height and width:
268 | # Resize the image to the specified height and width.
269 | image = tf.expand_dims(image, 0)
270 | image = tf.image.resize_bilinear(image, [height, width],
271 | align_corners=False)
272 | image = tf.squeeze(image, [0])
273 | image = tf.sub(image, 0.5)
274 | image = tf.mul(image, 2.0)
275 | return image
276 |
277 |
278 | def preprocess_image(image, height, width,
279 | is_training=False,
280 | bbox=None,
281 | fast_mode=True):
282 | """Pre-process one image for training or evaluation.
283 |
284 | Args:
285 | image: 3-D Tensor [height, width, channels] with the image.
286 | height: integer, image expected height.
287 | width: integer, image expected width.
288 | is_training: Boolean. If true it would transform an image for train,
289 | otherwise it would transform it for evaluation.
290 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
291 | where each coordinate is [0, 1) and the coordinates are arranged as
292 | [ymin, xmin, ymax, xmax].
293 | fast_mode: Optional boolean, if True avoids slower transformations.
294 |
295 | Returns:
296 | 3-D float Tensor containing an appropriately scaled image
297 |
298 | Raises:
299 | ValueError: if user does not provide bounding box
300 | """
301 | if is_training:
302 | return preprocess_for_train(image, height, width, bbox, fast_mode)
303 | else:
304 | return preprocess_for_eval(image, height, width)
305 |
--------------------------------------------------------------------------------
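
apply_with_random_selector builds every branch of func and merges them, so the
branch is chosen per evaluation of the graph rather than once at construction
time. A small sketch of the same pattern the file itself uses for resizing; the
32x32 input and 224x224 target are arbitrary:

    import tensorflow as tf
    from preprocessing import inception_preprocessing

    image = tf.random_uniform([32, 32, 3])
    # Each evaluation picks one of the four tf.image.ResizeMethod values at random.
    resized = inception_preprocessing.apply_with_random_selector(
        image,
        lambda x, method: tf.image.resize_images(x, [224, 224], method=method),
        num_cases=4)

    with tf.Session() as sess:
        print(sess.run(resized).shape)  # (224, 224, 3)
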
/preprocessing/lenet_preprocessing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Provides utilities for preprocessing."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | slim = tf.contrib.slim
24 |
25 |
26 | def preprocess_image(image, output_height, output_width, is_training):
27 | """Preprocesses the given image.
28 |
29 | Args:
30 | image: A `Tensor` representing an image of arbitrary size.
31 | output_height: The height of the image after preprocessing.
32 | output_width: The width of the image after preprocessing.
33 | is_training: `True` if we're preprocessing the image for training and
34 | `False` otherwise.
35 |
36 | Returns:
37 | A preprocessed image.
38 | """
39 | image = tf.to_float(image)
40 | image = tf.image.resize_image_with_crop_or_pad(
41 | image, output_width, output_height)
42 | image = tf.sub(image, 128.0)
43 | image = tf.div(image, 128.0)
44 | return image
45 |
--------------------------------------------------------------------------------
/preprocessing/preprocessing_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains a factory for building preprocessing functions for various models."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | from preprocessing import cifarnet_preprocessing
24 | from preprocessing import inception_preprocessing
25 | from preprocessing import lenet_preprocessing
26 | from preprocessing import vgg_preprocessing
27 |
28 | slim = tf.contrib.slim
29 |
30 |
31 | def get_preprocessing(name, is_training=False):
32 |   """Returns preprocessing_fn(image, height, width, **kwargs) and an unprocessing_fn.
33 |
34 | Args:
35 | name: The name of the preprocessing function.
36 | is_training: `True` if the model is being used for training and `False`
37 | otherwise.
38 |
39 | Returns:
40 |     preprocessing_fn, unprocessing_fn: Functions that preprocess and unprocess a
41 |       single image (pre-batch). preprocessing_fn has the following signature:
42 |         image = preprocessing_fn(image, output_height, output_width, ...).
43 |
44 | Raises:
45 | ValueError: If Preprocessing `name` is not recognized.
46 | """
47 | preprocessing_fn_map = {
48 | 'cifarnet': cifarnet_preprocessing,
49 | 'inception': inception_preprocessing,
50 | 'inception_v1': inception_preprocessing,
51 | 'inception_v2': inception_preprocessing,
52 | 'inception_v3': inception_preprocessing,
53 | 'inception_v4': inception_preprocessing,
54 | 'inception_resnet_v2': inception_preprocessing,
55 | 'lenet': lenet_preprocessing,
56 | 'resnet_v1_50': vgg_preprocessing,
57 | 'resnet_v1_101': vgg_preprocessing,
58 | 'resnet_v1_152': vgg_preprocessing,
59 | 'vgg': vgg_preprocessing,
60 | 'vgg_a': vgg_preprocessing,
61 | 'vgg_16': vgg_preprocessing,
62 | 'vgg_19': vgg_preprocessing,
63 | }
64 |
65 | if name not in preprocessing_fn_map:
66 | raise ValueError('Preprocessing name [%s] was not recognized' % name)
67 |
68 | def preprocessing_fn(image, output_height, output_width, **kwargs):
69 | return preprocessing_fn_map[name].preprocess_image(
70 | image, output_height, output_width, is_training=is_training, **kwargs)
71 |
72 | def unprocessing_fn(image, **kwargs):
73 | return preprocessing_fn_map[name].unprocess_image(
74 | image, **kwargs)
75 |
76 | return preprocessing_fn, unprocessing_fn
77 |
--------------------------------------------------------------------------------
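
A short sketch of how train.py consumes this factory. The returned
unprocessing_fn assumes the selected preprocessing module exposes an
unprocess_image function (train.py relies on this for its 'origin' image
summary); the sizes below are illustrative.

    import tensorflow as tf
    from preprocessing import preprocessing_factory

    image = tf.placeholder(tf.uint8, [256, 256, 3])
    preprocessing_fn, unprocessing_fn = preprocessing_factory.get_preprocessing(
        'vgg_16', is_training=False)

    processed = preprocessing_fn(image, 256, 256)  # float image ready for the loss network
    restored = unprocessing_fn(processed)          # undo the preprocessing for display
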
/reader.py:
--------------------------------------------------------------------------------
1 | from os import listdir
2 | from os.path import isfile, join
3 | import tensorflow as tf
4 |
5 |
6 | def get_image(path, height, width, preprocess_fn):
7 | png = path.lower().endswith('png')
8 | img_bytes = tf.read_file(path)
9 | image = tf.image.decode_png(img_bytes, channels=3) if png else tf.image.decode_jpeg(img_bytes, channels=3)
10 | return preprocess_fn(image, height, width)
11 |
12 |
13 | def image(batch_size, height, width, path, preprocess_fn, epochs=2, shuffle=True):
14 | filenames = [join(path, f) for f in listdir(path) if isfile(join(path, f))]
15 | if not shuffle:
16 | filenames = sorted(filenames)
17 |
18 | png = filenames[0].lower().endswith('png') # If first file is a png, assume they all are
19 |
20 | filename_queue = tf.train.string_input_producer(filenames, shuffle=shuffle, num_epochs=epochs)
21 | reader = tf.WholeFileReader()
22 | _, img_bytes = reader.read(filename_queue)
23 | image = tf.image.decode_png(img_bytes, channels=3) if png else tf.image.decode_jpeg(img_bytes, channels=3)
24 |
25 | processed_image = preprocess_fn(image, height, width)
26 | return tf.train.batch([processed_image], batch_size, dynamic_pad=True)
27 |
--------------------------------------------------------------------------------
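
A sketch of driving reader.image with a queue-based session, mirroring
train.py; 'train2014/' is assumed to be a directory of JPEG images and the
sizes are illustrative. Because string_input_producer is given num_epochs,
local variables must be initialized before the queue runners start.

    import tensorflow as tf
    import reader
    from preprocessing import preprocessing_factory

    preprocessing_fn, _ = preprocessing_factory.get_preprocessing('vgg_16', is_training=False)
    images = reader.image(4, 256, 256, 'train2014/', preprocessing_fn, epochs=2)

    with tf.Session() as sess:
        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        batch = sess.run(images)  # shape: (4, 256, 256, 3)
        coord.request_stop()
        coord.join(threads)
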
/train.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | from __future__ import print_function
3 | from __future__ import division
4 | import tensorflow as tf
5 | from nets import nets_factory
6 | from preprocessing import preprocessing_factory
7 | import reader
8 | import model
9 | import time
10 | import losses
11 | import utils
12 | import os
13 | import argparse
14 |
15 | slim = tf.contrib.slim
16 |
17 |
18 | def parse_args():
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('-c', '--conf', default='conf/mosaic.yml', help='the path to the conf file')
21 | return parser.parse_args()
22 |
23 |
24 | def main(FLAGS):
25 | style_features_t = losses.get_style_features(FLAGS)
26 |
27 | # Make sure the training path exists.
28 | training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
29 | if not(os.path.exists(training_path)):
30 | os.makedirs(training_path)
31 |
32 | with tf.Graph().as_default():
33 | with tf.Session() as sess:
34 | """Build Network"""
35 | network_fn = nets_factory.get_network_fn(
36 | FLAGS.loss_model,
37 | num_classes=1,
38 | is_training=False)
39 |
40 | image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
41 | FLAGS.loss_model,
42 | is_training=False)
43 | processed_images = reader.image(FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size,
44 | 'train2014/', image_preprocessing_fn, epochs=FLAGS.epoch)
45 | generated = model.net(processed_images, training=True)
46 | processed_generated = [image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size)
47 | for image in tf.unstack(generated, axis=0, num=FLAGS.batch_size)
48 | ]
49 | processed_generated = tf.stack(processed_generated)
50 | _, endpoints_dict = network_fn(tf.concat([processed_generated, processed_images], 0), spatial_squeeze=False)
51 |
52 | # Log the structure of loss network
53 |             tf.logging.info('Loss network layers (you can define them in "content_layers" and "style_layers"):')
54 | for key in endpoints_dict:
55 | tf.logging.info(key)
56 |
57 | """Build Losses"""
58 | content_loss = losses.content_loss(endpoints_dict, FLAGS.content_layers)
59 | style_loss, style_loss_summary = losses.style_loss(endpoints_dict, style_features_t, FLAGS.style_layers)
60 | tv_loss = losses.total_variation_loss(generated) # use the unprocessed image
61 |
62 | loss = FLAGS.style_weight * style_loss + FLAGS.content_weight * content_loss + FLAGS.tv_weight * tv_loss
63 |
64 |             # Add summaries for visualization in TensorBoard.
65 | """Add Summary"""
66 | tf.summary.scalar('losses/content_loss', content_loss)
67 | tf.summary.scalar('losses/style_loss', style_loss)
68 | tf.summary.scalar('losses/regularizer_loss', tv_loss)
69 |
70 | tf.summary.scalar('weighted_losses/weighted_content_loss', content_loss * FLAGS.content_weight)
71 | tf.summary.scalar('weighted_losses/weighted_style_loss', style_loss * FLAGS.style_weight)
72 | tf.summary.scalar('weighted_losses/weighted_regularizer_loss', tv_loss * FLAGS.tv_weight)
73 | tf.summary.scalar('total_loss', loss)
74 |
75 | for layer in FLAGS.style_layers:
76 | tf.summary.scalar('style_losses/' + layer, style_loss_summary[layer])
77 | tf.summary.image('generated', generated)
78 | # tf.image_summary('processed_generated', processed_generated) # May be better?
79 | tf.summary.image('origin', tf.stack([
80 | image_unprocessing_fn(image) for image in tf.unstack(processed_images, axis=0, num=FLAGS.batch_size)
81 | ]))
82 | summary = tf.summary.merge_all()
83 | writer = tf.summary.FileWriter(training_path)
84 |
85 | """Prepare to Train"""
86 | global_step = tf.Variable(0, name="global_step", trainable=False)
87 |
88 | variable_to_train = []
89 | for variable in tf.trainable_variables():
90 | if not(variable.name.startswith(FLAGS.loss_model)):
91 | variable_to_train.append(variable)
92 | train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step, var_list=variable_to_train)
93 |
94 | variables_to_restore = []
95 | for v in tf.global_variables():
96 | if not(v.name.startswith(FLAGS.loss_model)):
97 | variables_to_restore.append(v)
98 | saver = tf.train.Saver(variables_to_restore, write_version=tf.train.SaverDef.V1)
99 |
100 | sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
101 |
102 | # Restore variables for loss network.
103 | init_func = utils._get_init_fn(FLAGS)
104 | init_func(sess)
105 |
106 | # Restore variables for training model if the checkpoint file exists.
107 | last_file = tf.train.latest_checkpoint(training_path)
108 | if last_file:
109 | tf.logging.info('Restoring model from {}'.format(last_file))
110 | saver.restore(sess, last_file)
111 |
112 | """Start Training"""
113 | coord = tf.train.Coordinator()
114 | threads = tf.train.start_queue_runners(coord=coord)
115 | start_time = time.time()
116 | try:
117 | while not coord.should_stop():
118 | _, loss_t, step = sess.run([train_op, loss, global_step])
119 | elapsed_time = time.time() - start_time
120 | start_time = time.time()
121 | """logging"""
122 | # print(step)
123 | if step % 10 == 0:
124 | tf.logging.info('step: %d, total Loss %f, secs/step: %f' % (step, loss_t, elapsed_time))
125 | """summary"""
126 | if step % 25 == 0:
127 | tf.logging.info('adding summary...')
128 | summary_str = sess.run(summary)
129 | writer.add_summary(summary_str, step)
130 | writer.flush()
131 | """checkpoint"""
132 | if step % 1000 == 0:
133 | saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt'), global_step=step)
134 | except tf.errors.OutOfRangeError:
135 | saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt-done'))
136 | tf.logging.info('Done training -- epoch limit reached')
137 | finally:
138 | coord.request_stop()
139 | coord.join(threads)
140 |
141 |
142 | if __name__ == '__main__':
143 | tf.logging.set_verbosity(tf.logging.INFO)
144 | args = parse_args()
145 | FLAGS = utils.read_conf_file(args.conf)
146 | main(FLAGS)
147 |
--------------------------------------------------------------------------------
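
The FLAGS attributes read by train.py and utils.py come from the YAML files
under conf/. A hedged sketch of that minimum field set, built directly with
utils.Flag; the values and layer names below are placeholders, and
losses.get_style_features needs further fields (such as the style image path)
that live in the real conf files.

    from utils import Flag

    FLAGS = Flag(
        loss_model='vgg_16',
        loss_model_file='pretrained/vgg_16.ckpt',  # assumed checkpoint path
        checkpoint_exclude_scopes='vgg_16/fc',     # assumed exclusion scopes
        model_path='models',
        naming='mosaic',
        batch_size=4,
        image_size=256,
        epoch=2,
        content_layers=['vgg_16/conv3/conv3_3'],   # placeholder layer names
        style_layers=['vgg_16/conv1/conv1_2'],
        content_weight=1.0,
        style_weight=100.0,
        tv_weight=0.0,
    )
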
/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import yaml
3 |
4 | slim = tf.contrib.slim
5 |
6 |
7 | def _get_init_fn(FLAGS):
8 | """
9 | This function is copied from TF slim.
10 |
11 | Returns a function run by the chief worker to warm-start the training.
12 |
13 | Note that the init_fn is only run when initializing the model during the very
14 | first global step.
15 |
16 | Returns:
17 | An init function run by the supervisor.
18 | """
19 | tf.logging.info('Use pretrained model %s' % FLAGS.loss_model_file)
20 |
21 | exclusions = []
22 | if FLAGS.checkpoint_exclude_scopes:
23 | exclusions = [scope.strip()
24 | for scope in FLAGS.checkpoint_exclude_scopes.split(',')]
25 |
26 | # TODO(sguada) variables.filter_variables()
27 | variables_to_restore = []
28 | for var in slim.get_model_variables():
29 | excluded = False
30 | for exclusion in exclusions:
31 | if var.op.name.startswith(exclusion):
32 | excluded = True
33 | break
34 | if not excluded:
35 | variables_to_restore.append(var)
36 |
37 | return slim.assign_from_checkpoint_fn(
38 | FLAGS.loss_model_file,
39 | variables_to_restore,
40 | ignore_missing_vars=True)
41 |
42 |
43 | class Flag(object):
44 | def __init__(self, **entries):
45 | self.__dict__.update(entries)
46 |
47 |
48 | def read_conf_file(conf_file):
49 | with open(conf_file) as f:
50 | FLAGS = Flag(**yaml.load(f))
51 | return FLAGS
52 |
53 |
54 | def mean_image_subtraction(image, means):
55 | image = tf.to_float(image)
56 |
57 | num_channels = 3
58 | channels = tf.split(image, num_channels, 2)
59 | for i in range(num_channels):
60 | channels[i] -= means[i]
61 | return tf.concat(channels, 2)
62 |
63 |
64 | if __name__ == '__main__':
65 | f = read_conf_file('conf/mosaic.yml')
66 | print(f.loss_model_file)
67 |
--------------------------------------------------------------------------------
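
A small sketch of mean_image_subtraction; the channel means shown are the
common VGG ImageNet means (R, G, B), not constants defined in this file.

    import tensorflow as tf
    from utils import mean_image_subtraction

    image = tf.placeholder(tf.uint8, [256, 256, 3])
    centered = mean_image_subtraction(image, [123.68, 116.78, 103.94])
    # 'centered' is float32 with each RGB channel shifted by its mean.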